// ferritin_plms/featurize/structure_features.rs

//! Protein -> Tensor utilities useful for machine learning.
2use super::utilities::{AAAtom, aa1to_int, aa3to1, get_nearest_neighbours};
3use crate::ligandmpnn::proteinfeatures::ProteinFeatures;
4use candle_core::{D, DType, Device, IndexOp, Result, Tensor};
5use ferritin_core::AtomCollection;
6use ferritin_core::info::elements::Element;
7use std::collections::HashSet;
8use strum::IntoEnumIterator;

/// Distance threshold compared (via `lt`) against the per-residue ligand distances
/// returned by `get_nearest_neighbours` in `to_numeric_ligand_atoms`.
///
/// Despite the `SCORE` suffix this is a distance cutoff; presumably in Ångström,
/// matching the coordinate units — TODO confirm.
const LIGAND_CUTOFF_SCORE: f32 = 5.0;

12// Helper Fns --------------------------------------
13fn is_heavy_atom(element: &Element) -> bool {
14    !matches!(element, Element::H | Element::He)
15}
16
17///. Trait defining Protein->Tensor utilities useful for Machine Learning
18pub trait StructureFeatures {
19    /// Convert amino acid sequence to numeric representation
20    fn decode_amino_acids(&self, device: &Device) -> Result<Tensor>;
21
22    /// Convert amino acid sequence to numeric representation
23    fn encode_amino_acids(&self, device: &Device) -> Result<Tensor>;
24
25    /// Convert amino acid sequence to numeric representation
26    fn create_cb(&self, device: &Device) -> Result<Tensor>;
27
28    /// Prepare for ProteinMPNN
29    fn featurize_lmpnn(&self, device: &Device) -> Result<ProteinFeatures>; // need more control over this featurization process
30
31    /// Get residue indices
32    fn get_res_index(&self) -> Vec<u32>;
33
34    /// Extract backbone atom coordinates (N, CA, C, O)
35    fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>;
36
37    /// Extract all atom coordinates in standard ordering
38    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor>;
39
40    /// Extract ligand atom coordinates and properties
41    fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)>;
42}
43
44impl StructureFeatures for AtomCollection {
45    /// Convert amino acid sequence to numeric representation
46    fn decode_amino_acids(&self, _device: &Device) -> Result<Tensor> {
47        todo!()
48    }
49
50    /// Convert amino acid sequence to numeric representation
51    fn encode_amino_acids(&self, device: &Device) -> Result<Tensor> {
52        let n = self.iter_residues_aminoacid().count();
53        let s = self
54            .iter_residues_aminoacid()
55            .map(|res| res.residue_name().to_string())
56            .map(|res| aa3to1(&res))
57            .map(|res| aa1to_int(res));
58
59        Ok(Tensor::from_iter(s, device)?.reshape((1, n))?)
60    }
61
62    /// Calculate CB for each residue
63    fn create_cb(&self, device: &Device) -> Result<Tensor> {
64        let backbone = self.to_numeric_backbone_atoms(device)?.squeeze(0)?;
65
66        // Extract N, CA, C coordinates
67        let n = backbone.i((.., 0, ..))?;
68        let ca = backbone.i((.., 1, ..))?;
69        let c = backbone.i((.., 2, ..))?;
70
71        // Constants for CB calculation
72        let a_coeff = -0.58273431_f64;
73        let b_coeff = 0.56802827_f64;
74        let c_coeff = -0.54067466_f64;
75
76        // Calculate vectors
77        let b = (&ca - &n)?;
78        let c = (&c - &ca)?;
79
80        // Manual cross product components
81        // a_x = b_y * c_z - b_z * c_y
82        // a_y = b_z * c_x - b_x * c_z
83        // a_z = b_x * c_y - b_y * c_x
84        let b_x = b.i((.., 0))?;
85        let b_y = b.i((.., 1))?;
86        let b_z = b.i((.., 2))?;
87        let c_x = c.i((.., 0))?;
88        let c_y = c.i((.., 1))?;
89        let c_z = c.i((.., 2))?;
90
91        let a_x = ((&b_y * &c_z)? - (&b_z * &c_y)?)?;
92        let a_y = ((&b_z * &c_x)? - (&b_x * &c_z)?)?;
93        let a_z = ((&b_x * &c_y)? - (&b_y * &c_x)?)?;
94        let a = Tensor::stack(&[&a_x, &a_y, &a_z], D::Minus1)?;
95
96        // Final CB calculation: -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
97        let cb = ((&a * a_coeff)? + (&b * b_coeff)? + (&c * c_coeff)? + &ca)?;
98        let cb = cb.unsqueeze(0)?;
99        Ok(cb)
100    }
101
102    // Convert AtomCollection to ProteinFeatures
103    fn featurize_lmpnn(&self, device: &Device) -> Result<ProteinFeatures> {
104        let x_37 = self.to_numeric_atom37(device)?;
105        let x_37_m = Tensor::zeros((x_37.dim(0)?, x_37.dim(1)?), DType::F32, device)?;
106        let (y, y_t, y_m) = self.to_numeric_ligand_atoms(device)?;
107        let cb = self.create_cb(device);
108        let chain_labels = self.get_resids(); //  <-- need to double-check shape. I think this is all-atom
109        let residue_ids = self.get_res_index();
110        let residue_length = residue_ids.len();
111        let r_idx = Tensor::from_iter(residue_ids, device)?.reshape((1, residue_length))?;
112        let chain_letters: Vec<String> = self
113            .iter_residues_aminoacid()
114            .map(|res| res.chain_id().to_string())
115            .collect();
116        let chain_list: Vec<String> = self
117            .iter_residues_aminoacid()
118            .map(|res| res.chain_id().to_string())
119            .collect::<HashSet<_>>()
120            .into_iter()
121            .collect();
122        // Numeric chain labels (optional)
123        let chain_labels: Option<Vec<f64>> = None; // Could populate if needed
124        let s = self.encode_amino_acids(device)?;
125        // coordinates of the backbone atoms
126        let indices = Tensor::from_slice(
127            &[0i64, 1i64, 2i64, 4i64], // index of N/CA/C/O as integers
128            (4,),
129            &device,
130        )?;
131        let x = x_37.index_select(&indices, 2)?;
132        Ok(ProteinFeatures {
133            s,
134            x,
135            x_mask: Some(x_37_m),
136            y,
137            y_t,
138            y_m: Some(y_m),
139            r_idx,
140            chain_labels,
141            chain_letters,
142            mask_c: None,
143            chain_list,
144        })
145    }
146    /// Get residue indices
147    fn get_res_index(&self) -> Vec<u32> {
148        self.iter_residues_aminoacid()
149            .map(|res| res.residue_id() as u32)
150            .collect()
151    }
152
153    /// create numeric Tensor of shape [1, <sequence-length>, 4, 3] where the 4 is N/CA/C/O
154    fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor> {
155        let res_count = self.iter_residues_aminoacid().count();
156        let mut backbone_data = Vec::with_capacity(res_count * 4 * 3);
157
158        for residue in self.iter_residues_aminoacid() {
159            for atom_name in ["N", "CA", "C", "O"] {
160                if let Some(atom) = residue.find_atom_by_name(atom_name) {
161                    let [x, y, z] = atom.coords();
162                    backbone_data.extend_from_slice(&[*x, *y, *z]);
163                } else {
164                    backbone_data.extend_from_slice(&[0.0, 0.0, 0.0]);
165                }
166            }
167        }
168        Tensor::from_vec(backbone_data, (1, res_count, 4, 3), &device)
169    }
170
171    /// create numeric Tensor of shape [1, <sequence-length>, 37, 3]
172    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor> {
173        let res_count = self.iter_residues_aminoacid().count();
174        let mut atom37_data = vec![0.0; res_count * 37 * 3];
175        for (res_idx, residue) in self.iter_residues_aminoacid().enumerate() {
176            for atom_type in AAAtom::iter().filter(|&a| a != AAAtom::Unknown) {
177                if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) {
178                    let [x, y, z] = atom.coords();
179                    let base_idx = (res_idx * 37 + atom_type as usize) * 3;
180                    atom37_data[base_idx..base_idx + 3].copy_from_slice(&[*x, *y, *z]);
181                }
182            }
183        }
184        Tensor::from_vec(atom37_data, (1, res_count, 37, 3), &device)
185    }
186
187    // The purpose of this function it to create 3 output tensors that relate
188    // key information about a protein sequence and ligands it interacts with.
189    //
190    // The outputs are:
191    //  - y: 4D tensor of dimensions (<batch=1>, <num_residues>, <number_of_ligand_atoms>, <coords=3>)
192    //  - y_t: 1D tensor of dimension = <num_residues>
193    //  - y_m: 3D tensor of dimensions: (<batch=1>, <num_residues>, <number_of_ligand_atoms>))
194    //
195    fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)> {
196        let mut coords = Vec::new();
197        let mut elements = Vec::new();
198        for residue in self.iter_residues() {
199            let res_name = residue.residue_name();
200            if residue.is_amino_acid() || res_name == "HOH" || res_name == "WAT" {
201                continue;
202            }
203            let atoms: Vec<_> = residue
204                .iter_atoms()
205                .filter(|atom| is_heavy_atom(atom.element()))
206                .collect();
207            for atom in atoms {
208                coords.push(*atom.coords());
209                elements.push(*atom.element());
210            }
211        }
212
213        // raw starting tensors
214        let y = Tensor::from_slice(&coords.concat(), (coords.len(), 3), device)?;
215        let y_m = Tensor::ones_like(&y)?;
216        let y_t = Tensor::from_slice(
217            &elements
218                .iter()
219                .map(|e| e.atomic_number() as f32)
220                .collect::<Vec<_>>(),
221            (elements.len(),),
222            device,
223        )?;
224        let cb = self.create_cb(device)?;
225        let (batch, res_num, _coords) = cb.dims3()?;
226        let (number_of_ligand_atoms, _coords) = y.dims2()?;
227        let mask = Tensor::zeros((batch, res_num), DType::F32, device)?;
228        let (y, y_t, y_m, d_xy) =
229            get_nearest_neighbours(&cb, &mask, &y, &y_t, &y_m, number_of_ligand_atoms as i64)?;
230        let distance_mask = d_xy.lt(LIGAND_CUTOFF_SCORE)?.to_dtype(DType::F32)?;
231        let y_m_first = y_m.i((.., 0))?;
232        let mask = mask.squeeze(0)?;
233        let _mask_xy = distance_mask.mul(&mask)?.mul(&y_m_first)?;
234        let y = y.unsqueeze(0)?;
235        let y_t = y_t.to_dtype(DType::I64)?.unsqueeze(0)?;
236        let y_m = y_m.unsqueeze(0)?;
237        Ok((y, y_t, y_m))
238    }
239}