ferritin_plms/ligandmpnn/
proteinfeatures.rs

//! Protein featurizer for ProteinMPNN/LigandMPNN.
//!
//! Extracts protein (and ligand) features for LigandMPNN.
//!
//! The features calculated from a protein structure include:
//! - Residue-level features, such as amino acid type and secondary structure
//! - Geometric features, such as distances and angles
//! - Chemical features, such as hydrophobicity and charge
//! - Evolutionary features from MSA profiles
//!
//! Not every one of these feature groups is represented by a field of
//! [`ProteinFeatures`] in this module.
11use crate::featurize::utilities::aa1to_int;
12use candle_core::{Device, Result, Tensor};
13use std::collections::{HashMap, HashSet};
14
15pub struct ProteinFeatures {
16    /// protein amino acids sequences as 1D Tensor of u32
17    pub(crate) s: Tensor,
18    /// protein co-oords by residue [batch, seqlength, 37, 3]
19    pub(crate) x: Tensor,
20    /// protein mask by residue
21    pub(crate) x_mask: Option<Tensor>,
22    /// ligand coords
23    pub(crate) y: Tensor,
24    /// encoded ligand atom names
25    pub(crate) y_t: Tensor,
26    /// ligand mask
27    pub(crate) y_m: Option<Tensor>,
28    /// R_idx:         Tensor dimensions: torch.Size([93])          # protein residue indices shape=[length]
29    pub(crate) r_idx: Tensor,
30    /// chain_labels:  Tensor dimensions: torch.Size([93])          # protein chain letters shape=[length]
31    pub(crate) chain_labels: Option<Vec<f64>>,
32    /// chain_letters: NumPy array dimensions: (93,)
33    pub(crate) chain_letters: Vec<String>,
34    /// mask_c:        Tensor dimensions: torch.Size([93])
35    pub(crate) mask_c: Option<Tensor>,
36    pub(crate) chain_list: Vec<String>,
37}
38impl ProteinFeatures {
39    pub fn get_coords(&self) -> &Tensor {
40        &self.x
41    }
42    pub fn get_sequence(&self) -> &Tensor {
43        &self.s
44    }
45    pub fn get_sequence_mask(&self) -> Option<&Tensor> {
46        self.x_mask.as_ref()
47    }
48    pub fn get_residue_index(&self) -> &Tensor {
49        &self.r_idx
50    }
51    pub fn get_encoded(
52        &self,
53    ) -> Result<(Vec<String>, HashMap<String, usize>, HashMap<usize, String>)> {
54        let r_idx_list = &self.r_idx.flatten_all()?.to_vec1::<u32>()?;
55        let chain_letters_list = &self.chain_letters;
56        let encoded_residues: Vec<String> = r_idx_list
57            .iter()
58            .enumerate()
59            .map(|(i, r_idx)| format!("{}{}", chain_letters_list[i], r_idx))
60            .collect();
61        let encoded_residue_dict: HashMap<String, usize> = encoded_residues
62            .iter()
63            .enumerate()
64            .map(|(i, s)| (s.clone(), i))
65            .collect();
66        let encoded_residue_dict_rev: HashMap<usize, String> = encoded_residues
67            .iter()
68            .enumerate()
69            .map(|(i, s)| (i, s.clone()))
70            .collect();
71        Ok((
72            encoded_residues,
73            encoded_residue_dict,
74            encoded_residue_dict_rev,
75        ))
76    }
77    // Fixed Residue List --> Tensor of 1/0
78    // Inputs: `"C1 C2 C3 C4 C5 C6 C7 C8 C9 C10`
79    pub fn get_encoded_tensor(&self, fixed_residues: String, device: &Device) -> Result<Tensor> {
80        let res_set: HashSet<String> = fixed_residues.split(' ').map(String::from).collect();
81        let (encoded_res, _, _) = &self.get_encoded()?;
82        Tensor::from_iter(
83            encoded_res
84                .iter()
85                .map(|item| u32::from(!res_set.contains(item))),
86            device,
87        )
88    }
89    pub fn get_chain_mask_tensor(
90        &self,
91        chains_to_design: Vec<String>,
92        device: &Device,
93    ) -> Result<Tensor> {
94        let mask_values: Vec<u32> = self
95            .chain_letters
96            .iter()
97            .map(|chain| u32::from(chains_to_design.contains(chain)))
98            .collect();
99        Tensor::from_iter(mask_values, device)
100    }
101    pub fn update_mask(&mut self, tensor: Tensor) -> Result<()> {
102        self.x_mask = match self.x_mask.as_ref() {
103            Some(mask) => Some(mask.mul(&tensor)?),
104            None => Some(tensor),
105        };
106        Ok(())
107    }
108    // Fixed Residue List --> Tensor of length 21
109    // Inputs: `A:10.0"`
110    pub fn create_bias_tensor(&self, bias_aa: Option<String>) -> Result<Tensor> {
111        let device = self.s.device();
112        let dtype = self.s.dtype();
113        match bias_aa {
114            None => Tensor::zeros(21, dtype, device),
115            Some(bias_aa) => {
116                let mut bias_values = vec![0.0f32; 21];
117                for pair in bias_aa.split(',') {
118                    if let Some((aa, value_str)) = pair.split_once(':') {
119                        if let Ok(value) = value_str.parse::<f32>() {
120                            if let Some(aa_char) = aa.chars().next() {
121                                let idx = aa1to_int(aa_char) as usize;
122                                bias_values[idx] = value;
123                            }
124                        }
125                    }
126                }
127                Tensor::from_slice(&bias_values, 21, device)
128            }
129        }
130    }
131    pub fn save_to_safetensor(&self, path: &str) -> Result<()> {
132        let mut tensors: HashMap<String, Tensor> = HashMap::new();
133        // this is only one field. need to do the rest of the fields
134        tensors.insert("protein_atom_sequence".to_string(), self.s.clone());
135        tensors.insert("protein_atom_positions".to_string(), self.x.clone());
136        tensors.insert("ligand_atom_positions".to_string(), self.y.clone());
137        tensors.insert("ligand_atom_name".to_string(), self.y_t.clone());
138        candle_core::safetensors::save(&tensors, path)?;
139        Ok(())
140    }
141}