ferritin_plms/ligandmpnn/
proteinfeatures.rs1use crate::featurize::utilities::aa1to_int;
12use candle_core::{Device, Result, Tensor};
13use std::collections::{HashMap, HashSet};
14
15pub struct ProteinFeatures {
16 pub(crate) s: Tensor,
18 pub(crate) x: Tensor,
20 pub(crate) x_mask: Option<Tensor>,
22 pub(crate) y: Tensor,
24 pub(crate) y_t: Tensor,
26 pub(crate) y_m: Option<Tensor>,
28 pub(crate) r_idx: Tensor,
30 pub(crate) chain_labels: Option<Vec<f64>>,
32 pub(crate) chain_letters: Vec<String>,
34 pub(crate) mask_c: Option<Tensor>,
36 pub(crate) chain_list: Vec<String>,
37}
38impl ProteinFeatures {
39 pub fn get_coords(&self) -> &Tensor {
40 &self.x
41 }
42 pub fn get_sequence(&self) -> &Tensor {
43 &self.s
44 }
45 pub fn get_sequence_mask(&self) -> Option<&Tensor> {
46 self.x_mask.as_ref()
47 }
48 pub fn get_residue_index(&self) -> &Tensor {
49 &self.r_idx
50 }
51 pub fn get_encoded(
52 &self,
53 ) -> Result<(Vec<String>, HashMap<String, usize>, HashMap<usize, String>)> {
54 let r_idx_list = &self.r_idx.flatten_all()?.to_vec1::<u32>()?;
55 let chain_letters_list = &self.chain_letters;
56 let encoded_residues: Vec<String> = r_idx_list
57 .iter()
58 .enumerate()
59 .map(|(i, r_idx)| format!("{}{}", chain_letters_list[i], r_idx))
60 .collect();
61 let encoded_residue_dict: HashMap<String, usize> = encoded_residues
62 .iter()
63 .enumerate()
64 .map(|(i, s)| (s.clone(), i))
65 .collect();
66 let encoded_residue_dict_rev: HashMap<usize, String> = encoded_residues
67 .iter()
68 .enumerate()
69 .map(|(i, s)| (i, s.clone()))
70 .collect();
71 Ok((
72 encoded_residues,
73 encoded_residue_dict,
74 encoded_residue_dict_rev,
75 ))
76 }
77 pub fn get_encoded_tensor(&self, fixed_residues: String, device: &Device) -> Result<Tensor> {
80 let res_set: HashSet<String> = fixed_residues.split(' ').map(String::from).collect();
81 let (encoded_res, _, _) = &self.get_encoded()?;
82 Tensor::from_iter(
83 encoded_res
84 .iter()
85 .map(|item| u32::from(!res_set.contains(item))),
86 device,
87 )
88 }
89 pub fn get_chain_mask_tensor(
90 &self,
91 chains_to_design: Vec<String>,
92 device: &Device,
93 ) -> Result<Tensor> {
94 let mask_values: Vec<u32> = self
95 .chain_letters
96 .iter()
97 .map(|chain| u32::from(chains_to_design.contains(chain)))
98 .collect();
99 Tensor::from_iter(mask_values, device)
100 }
101 pub fn update_mask(&mut self, tensor: Tensor) -> Result<()> {
102 self.x_mask = match self.x_mask.as_ref() {
103 Some(mask) => Some(mask.mul(&tensor)?),
104 None => Some(tensor),
105 };
106 Ok(())
107 }
108 pub fn create_bias_tensor(&self, bias_aa: Option<String>) -> Result<Tensor> {
111 let device = self.s.device();
112 let dtype = self.s.dtype();
113 match bias_aa {
114 None => Tensor::zeros(21, dtype, device),
115 Some(bias_aa) => {
116 let mut bias_values = vec![0.0f32; 21];
117 for pair in bias_aa.split(',') {
118 if let Some((aa, value_str)) = pair.split_once(':') {
119 if let Ok(value) = value_str.parse::<f32>() {
120 if let Some(aa_char) = aa.chars().next() {
121 let idx = aa1to_int(aa_char) as usize;
122 bias_values[idx] = value;
123 }
124 }
125 }
126 }
127 Tensor::from_slice(&bias_values, 21, device)
128 }
129 }
130 }
131 pub fn save_to_safetensor(&self, path: &str) -> Result<()> {
132 let mut tensors: HashMap<String, Tensor> = HashMap::new();
133 tensors.insert("protein_atom_sequence".to_string(), self.s.clone());
135 tensors.insert("protein_atom_positions".to_string(), self.x.clone());
136 tensors.insert("ligand_atom_positions".to_string(), self.y.clone());
137 tensors.insert("ligand_atom_name".to_string(), self.y_t.clone());
138 candle_core::safetensors::save(&tensors, path)?;
139 Ok(())
140 }
141}