ferritin_plms/ligandmpnn/
proteinfeatures.rs1use crate::featurize::utilities::aa1to_int;
12use candle_core::{Device, Result, Tensor};
13use std::collections::{HashMap, HashSet};
14
15#[allow(dead_code)]
16pub struct ProteinFeatures {
17 pub(crate) s: Tensor,
19 pub(crate) x: Tensor,
21 pub(crate) x_mask: Option<Tensor>,
23 pub(crate) y: Tensor,
25 pub(crate) y_t: Tensor,
27 pub(crate) y_m: Option<Tensor>,
29 pub(crate) r_idx: Tensor,
31 pub(crate) chain_labels: Option<Vec<f64>>,
33 pub(crate) chain_letters: Vec<String>,
35 pub(crate) mask_c: Option<Tensor>,
37 pub(crate) chain_list: Vec<String>,
38}
39impl ProteinFeatures {
40 pub fn get_coords(&self) -> &Tensor {
41 &self.x
42 }
43 pub fn get_sequence(&self) -> &Tensor {
44 &self.s
45 }
46 pub fn get_sequence_mask(&self) -> Option<&Tensor> {
47 self.x_mask.as_ref()
48 }
49 pub fn get_residue_index(&self) -> &Tensor {
50 &self.r_idx
51 }
52 pub fn get_encoded(
53 &self,
54 ) -> Result<(Vec<String>, HashMap<String, usize>, HashMap<usize, String>)> {
55 let r_idx_list = &self.r_idx.flatten_all()?.to_vec1::<u32>()?;
56 let chain_letters_list = &self.chain_letters;
57 let encoded_residues: Vec<String> = r_idx_list
58 .iter()
59 .enumerate()
60 .map(|(i, r_idx)| format!("{}{}", chain_letters_list[i], r_idx))
61 .collect();
62 let encoded_residue_dict: HashMap<String, usize> = encoded_residues
63 .iter()
64 .enumerate()
65 .map(|(i, s)| (s.clone(), i))
66 .collect();
67 let encoded_residue_dict_rev: HashMap<usize, String> = encoded_residues
68 .iter()
69 .enumerate()
70 .map(|(i, s)| (i, s.clone()))
71 .collect();
72 Ok((
73 encoded_residues,
74 encoded_residue_dict,
75 encoded_residue_dict_rev,
76 ))
77 }
78 pub fn get_encoded_tensor(&self, fixed_residues: String, device: &Device) -> Result<Tensor> {
81 let res_set: HashSet<String> = fixed_residues.split(' ').map(String::from).collect();
82 let (encoded_res, _, _) = &self.get_encoded()?;
83 Tensor::from_iter(
84 encoded_res
85 .iter()
86 .map(|item| u32::from(!res_set.contains(item))),
87 device,
88 )
89 }
90 pub fn get_chain_mask_tensor(
91 &self,
92 chains_to_design: Vec<String>,
93 device: &Device,
94 ) -> Result<Tensor> {
95 let mask_values: Vec<u32> = self
96 .chain_letters
97 .iter()
98 .map(|chain| u32::from(chains_to_design.contains(chain)))
99 .collect();
100 Tensor::from_iter(mask_values, device)
101 }
102 pub fn update_mask(&mut self, tensor: Tensor) -> Result<()> {
103 self.x_mask = match self.x_mask.as_ref() {
104 Some(mask) => Some(mask.mul(&tensor)?),
105 None => Some(tensor),
106 };
107 Ok(())
108 }
109 pub fn create_bias_tensor(&self, bias_aa: Option<String>) -> Result<Tensor> {
112 let device = self.s.device();
113 let dtype = self.s.dtype();
114 match bias_aa {
115 None => Tensor::zeros(21, dtype, device),
116 Some(bias_aa) => {
117 let mut bias_values = vec![0.0f32; 21];
118 for pair in bias_aa.split(',') {
119 if let Some((aa, value_str)) = pair.split_once(':') {
120 if let Ok(value) = value_str.parse::<f32>() {
121 if let Some(aa_char) = aa.chars().next() {
122 let idx = aa1to_int(aa_char) as usize;
123 bias_values[idx] = value;
124 }
125 }
126 }
127 }
128 Tensor::from_slice(&bias_values, 21, device)
129 }
130 }
131 }
132 pub fn save_to_safetensor(&self, path: &str) -> Result<()> {
133 let mut tensors: HashMap<String, Tensor> = HashMap::new();
134 tensors.insert("protein_atom_sequence".to_string(), self.s.clone());
136 tensors.insert("protein_atom_positions".to_string(), self.x.clone());
137 tensors.insert("ligand_atom_positions".to_string(), self.y.clone());
138 tensors.insert("ligand_atom_name".to_string(), self.y_t.clone());
139 candle_core::safetensors::save(&tensors, path)?;
140 Ok(())
141 }
142}