Skip to main content

ferritin_plms/ligandmpnn/
proteinfeatures.rs

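//! Protein featurizer for ProteinMPNN/LigandMPNN.
//!
//! Extracts protein features for LigandMPNN.
//!
//! The featurizer is intended to return a set of features calculated from a
//! protein structure, including:
//! - residue-level features such as amino-acid type and secondary structure
//! - geometric features such as distances and angles
//! - chemical features such as hydrophobicity and charge
//! - evolutionary features from MSA profiles
//!
//! Not all of these are implemented yet: `ProteinFeatures` currently holds the
//! sequence, atom coordinates, ligand coordinates/atom names/mask, residue
//! indices and chain information.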
11use crate::featurize::utilities::aa1to_int;
12use candle_core::{Device, Result, Tensor};
13use std::collections::{HashMap, HashSet};
14
15#[allow(dead_code)]
16pub struct ProteinFeatures {
17    /// protein amino acids sequences as 1D Tensor of u32
18    pub(crate) s: Tensor,
19    /// protein co-oords by residue [batch, seqlength, 37, 3]
20    pub(crate) x: Tensor,
21    /// protein mask by residue
22    pub(crate) x_mask: Option<Tensor>,
23    /// ligand coords
24    pub(crate) y: Tensor,
25    /// encoded ligand atom names
26    pub(crate) y_t: Tensor,
27    /// ligand mask
28    pub(crate) y_m: Option<Tensor>,
29    /// R_idx:         Tensor dimensions: torch.Size([93])          # protein residue indices shape=[length]
30    pub(crate) r_idx: Tensor,
31    /// chain_labels:  Tensor dimensions: torch.Size([93])          # protein chain letters shape=[length]
32    pub(crate) chain_labels: Option<Vec<f64>>,
33    /// chain_letters: NumPy array dimensions: (93,)
34    pub(crate) chain_letters: Vec<String>,
35    /// mask_c:        Tensor dimensions: torch.Size([93])
36    pub(crate) mask_c: Option<Tensor>,
37    pub(crate) chain_list: Vec<String>,
38}
39impl ProteinFeatures {
40    pub fn get_coords(&self) -> &Tensor {
41        &self.x
42    }
43    pub fn get_sequence(&self) -> &Tensor {
44        &self.s
45    }
46    pub fn get_sequence_mask(&self) -> Option<&Tensor> {
47        self.x_mask.as_ref()
48    }
49    pub fn get_residue_index(&self) -> &Tensor {
50        &self.r_idx
51    }
52    pub fn get_encoded(
53        &self,
54    ) -> Result<(Vec<String>, HashMap<String, usize>, HashMap<usize, String>)> {
55        let r_idx_list = &self.r_idx.flatten_all()?.to_vec1::<u32>()?;
56        let chain_letters_list = &self.chain_letters;
57        let encoded_residues: Vec<String> = r_idx_list
58            .iter()
59            .enumerate()
60            .map(|(i, r_idx)| format!("{}{}", chain_letters_list[i], r_idx))
61            .collect();
62        let encoded_residue_dict: HashMap<String, usize> = encoded_residues
63            .iter()
64            .enumerate()
65            .map(|(i, s)| (s.clone(), i))
66            .collect();
67        let encoded_residue_dict_rev: HashMap<usize, String> = encoded_residues
68            .iter()
69            .enumerate()
70            .map(|(i, s)| (i, s.clone()))
71            .collect();
72        Ok((
73            encoded_residues,
74            encoded_residue_dict,
75            encoded_residue_dict_rev,
76        ))
77    }
78    // Fixed Residue List --> Tensor of 1/0
79    // Inputs: `"C1 C2 C3 C4 C5 C6 C7 C8 C9 C10`
80    pub fn get_encoded_tensor(&self, fixed_residues: String, device: &Device) -> Result<Tensor> {
81        let res_set: HashSet<String> = fixed_residues.split(' ').map(String::from).collect();
82        let (encoded_res, _, _) = &self.get_encoded()?;
83        Tensor::from_iter(
84            encoded_res
85                .iter()
86                .map(|item| u32::from(!res_set.contains(item))),
87            device,
88        )
89    }
90    pub fn get_chain_mask_tensor(
91        &self,
92        chains_to_design: Vec<String>,
93        device: &Device,
94    ) -> Result<Tensor> {
95        let mask_values: Vec<u32> = self
96            .chain_letters
97            .iter()
98            .map(|chain| u32::from(chains_to_design.contains(chain)))
99            .collect();
100        Tensor::from_iter(mask_values, device)
101    }
102    pub fn update_mask(&mut self, tensor: Tensor) -> Result<()> {
103        self.x_mask = match self.x_mask.as_ref() {
104            Some(mask) => Some(mask.mul(&tensor)?),
105            None => Some(tensor),
106        };
107        Ok(())
108    }
109    // Fixed Residue List --> Tensor of length 21
110    // Inputs: `A:10.0"`
111    pub fn create_bias_tensor(&self, bias_aa: Option<String>) -> Result<Tensor> {
112        let device = self.s.device();
113        let dtype = self.s.dtype();
114        match bias_aa {
115            None => Tensor::zeros(21, dtype, device),
116            Some(bias_aa) => {
117                let mut bias_values = vec![0.0f32; 21];
118                for pair in bias_aa.split(',') {
119                    if let Some((aa, value_str)) = pair.split_once(':') {
120                        if let Ok(value) = value_str.parse::<f32>() {
121                            if let Some(aa_char) = aa.chars().next() {
122                                let idx = aa1to_int(aa_char) as usize;
123                                bias_values[idx] = value;
124                            }
125                        }
126                    }
127                }
128                Tensor::from_slice(&bias_values, 21, device)
129            }
130        }
131    }
132    pub fn save_to_safetensor(&self, path: &str) -> Result<()> {
133        let mut tensors: HashMap<String, Tensor> = HashMap::new();
134        // this is only one field. need to do the rest of the fields
135        tensors.insert("protein_atom_sequence".to_string(), self.s.clone());
136        tensors.insert("protein_atom_positions".to_string(), self.x.clone());
137        tensors.insert("ligand_atom_positions".to_string(), self.y.clone());
138        tensors.insert("ligand_atom_name".to_string(), self.y_t.clone());
139        candle_core::safetensors::save(&tensors, path)?;
140        Ok(())
141    }
142}