ferritin_plms/featurize/
structure_features.rs

1//!  Protein->Tensor utilities useful for Machine Learning
2use super::utilities::{AAAtom, aa1to_int, aa3to1, get_nearest_neighbours};
3use crate::ligandmpnn::proteinfeatures::ProteinFeatures;
4use candle_core::{D, DType, Device, IndexOp, Result, Tensor};
5use ferritin_core::AtomCollection;
6use ferritin_core::info::elements::Element;
7use itertools::MultiUnzip;
8use std::collections::HashSet;
9use strum::IntoEnumIterator;
10
11// Helper Fns --------------------------------------
12fn is_heavy_atom(element: &Element) -> bool {
13    !matches!(element, Element::H | Element::He)
14}
15
16///. Trait defining Protein->Tensor utilities useful for Machine Learning
17pub trait StructureFeatures {
18    /// Convert amino acid sequence to numeric representation
19    fn decode_amino_acids(&self, device: &Device) -> Result<Tensor>;
20
21    /// Convert amino acid sequence to numeric representation
22    fn encode_amino_acids(&self, device: &Device) -> Result<Tensor>;
23
24    /// Convert amino acid sequence to numeric representation
25    fn create_cb(&self, device: &Device) -> Result<Tensor>;
26
27    /// Prepare for ProteinMPNN
28    fn featurize_lmpnn(&self, device: &Device) -> Result<ProteinFeatures>; // need more control over this featurization process
29
30    /// Get residue indices
31    fn get_res_index(&self) -> Vec<u32>;
32
33    /// Extract backbone atom coordinates (N, CA, C, O)
34    fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>;
35
36    /// Extract all atom coordinates in standard ordering
37    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor>;
38
39    /// Extract ligand atom coordinates and properties
40    fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)>;
41}
42
43impl StructureFeatures for AtomCollection {
44    /// Convert amino acid sequence to numeric representation
45    fn decode_amino_acids(&self, _device: &Device) -> Result<Tensor> {
46        todo!()
47    }
48
49    /// Convert amino acid sequence to numeric representation
50    fn encode_amino_acids(&self, device: &Device) -> Result<Tensor> {
51        let n = self.iter_residues_aminoacid().count();
52        let s = self
53            .iter_residues_aminoacid()
54            .map(|res| res.residue_name().to_string())
55            .map(|res| aa3to1(&res))
56            .map(|res| aa1to_int(res));
57
58        Ok(Tensor::from_iter(s, device)?.reshape((1, n))?)
59    }
60
61    /// Calculate CB for each residue
62    fn create_cb(&self, device: &Device) -> Result<Tensor> {
63        // N = input_dict["X"][:, 0, :]
64        //         CA = input_dict["X"][:, 1, :]
65        //         C = input_dict["X"][:, 2, :]
66        //         b = CA - N
67        //         c = C - CA
68        //         a = torch.cross(b, c, axis=-1)
69        //         CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
70        //
71        let backbone = self.to_numeric_backbone_atoms(device)?;
72        let backbone = backbone.squeeze(0)?; // remove batch dim for calc
73
74        // Extract N, CA, C coordinates
75        let n = backbone.i((.., 0, ..))?;
76        let ca = backbone.i((.., 1, ..))?;
77        let c = backbone.i((.., 2, ..))?;
78
79        // Constants for CB calculation
80        let a_coeff = -0.58273431_f64;
81        let b_coeff = 0.56802827_f64;
82        let c_coeff = -0.54067466_f64;
83
84        // Calculate vectors
85        let b = (&ca - &n)?; // CA - N
86        let c = (&c - &ca)?; // C - CA
87
88        // Manual cross product components
89        // a_x = b_y * c_z - b_z * c_y
90        // a_y = b_z * c_x - b_x * c_z
91        // a_z = b_x * c_y - b_y * c_x
92        let b_x = b.i((.., 0))?;
93        let b_y = b.i((.., 1))?;
94        let b_z = b.i((.., 2))?;
95        let c_x = c.i((.., 0))?;
96        let c_y = c.i((.., 1))?;
97        let c_z = c.i((.., 2))?;
98
99        let a_x = ((&b_y * &c_z)? - (&b_z * &c_y)?)?;
100        let a_y = ((&b_z * &c_x)? - (&b_x * &c_z)?)?;
101        let a_z = ((&b_x * &c_y)? - (&b_y * &c_x)?)?;
102        let a = Tensor::stack(&[&a_x, &a_y, &a_z], D::Minus1)?;
103
104        // Final CB calculation: -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
105        let cb = ((&a * a_coeff)? + (&b * b_coeff)? + (&c * c_coeff)? + &ca)?;
106        let cb = cb.unsqueeze(0)?;
107        Ok(cb)
108    }
109
110    // Convert AtomCollection to ProteinFeatures
111    fn featurize_lmpnn(&self, device: &Device) -> Result<ProteinFeatures> {
112        let x_37 = self.to_numeric_atom37(device)?;
113        let x_37_m = Tensor::zeros((x_37.dim(0)?, x_37.dim(1)?), DType::F32, device)?;
114        let (y, y_t, y_m) = self.to_numeric_ligand_atoms(device)?;
115        let cb = self.create_cb(device);
116        let chain_labels = self.get_resids(); //  <-- need to double-check shape. I think this is all-atom
117        let residue_ids = self.get_res_index();
118        let residue_length = residue_ids.len();
119        let r_idx = Tensor::from_iter(residue_ids, device)?.reshape((1, residue_length))?;
120        let chain_letters: Vec<String> = self
121            .iter_residues_aminoacid()
122            .map(|res| res.chain_id().to_string())
123            .collect();
124        let chain_list: Vec<String> = chain_letters
125            .clone()
126            .into_iter()
127            .collect::<HashSet<_>>()
128            .into_iter()
129            .collect();
130        // Numeric chain labels (optional)
131        let chain_labels: Option<Vec<f64>> = None; // Could populate if needed
132        let s = self.encode_amino_acids(device)?;
133        // coordinates of the backbone atoms
134        let indices = Tensor::from_slice(
135            &[0i64, 1i64, 2i64, 4i64], // index of N/CA/C/O as integers
136            (4,),
137            &device,
138        )?;
139        let x = x_37.index_select(&indices, 2)?;
140        Ok(ProteinFeatures {
141            s,
142            x,
143            x_mask: Some(x_37_m),
144            y,
145            y_t,
146            y_m: Some(y_m),
147            r_idx,
148            chain_labels,
149            chain_letters,
150            mask_c: None,
151            chain_list,
152        })
153    }
154    /// Get residue indices
155    fn get_res_index(&self) -> Vec<u32> {
156        self.iter_residues_aminoacid()
157            .map(|res| res.residue_id() as u32)
158            .collect()
159    }
160
161    /// create numeric Tensor of shape [1, <sequence-length>, 4, 3] where the 4 is N/CA/C/O
162    fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor> {
163        let res_count = self.iter_residues_aminoacid().count();
164        let mut backbone_data = vec![0f32; res_count * 4 * 3];
165        for residue in self.iter_residues_aminoacid() {
166            let resid = residue.residue_id() as usize;
167            let backbone_atoms = [
168                residue.find_atom_by_name("N"),
169                residue.find_atom_by_name("CA"),
170                residue.find_atom_by_name("C"),
171                residue.find_atom_by_name("O"),
172            ];
173            for (atom_idx, maybe_atom) in backbone_atoms.iter().enumerate() {
174                if let Some(atom) = maybe_atom {
175                    let [x, y, z] = atom.coords();
176                    let base_idx = (resid * 4 + atom_idx) * 3;
177                    backbone_data[base_idx] = *x;
178                    backbone_data[base_idx + 1] = *y;
179                    backbone_data[base_idx + 2] = *z;
180                }
181            }
182        }
183        // Create tensor with shape [1,residues, 4, 3]
184        Tensor::from_vec(backbone_data, (1, res_count, 4, 3), &device)
185    }
186
187    /// create numeric Tensor of shape [1, <sequence-length>, 37, 3]
188    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor> {
189        let res_count = self.iter_residues_aminoacid().count();
190        let mut atom37_data = vec![0f32; res_count * 37 * 3];
191        for (idx, residue) in self.iter_residues_aminoacid().enumerate() {
192            for atom_type in AAAtom::iter().filter(|&a| a != AAAtom::Unknown) {
193                if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) {
194                    let [x, y, z] = atom.coords();
195                    let base_idx = (idx * 37 + atom_type as usize) * 3;
196                    atom37_data[base_idx] = *x;
197                    atom37_data[base_idx + 1] = *y;
198                    atom37_data[base_idx + 2] = *z;
199                }
200            }
201        }
202        // Create tensor with shape [batch, residues, 37, 3]
203        Tensor::from_vec(atom37_data, (1, res_count, 37, 3), &device)
204    }
205
206    // The purpose of this function it to create 3 output tensors that relate
207    // key information about a protein sequence and ligands it interacts with.
208    //
209    // The outputs are:
210    //  - y: 4D tensor of dimensions (<batch=1>, <num_residues>, <number_of_ligand_atoms>, <coords=3>)
211    //  - y_t: 1D tensor of dimension = <num_residues>
212    //  - y_m: 3D tensor of dimensions: (<batch=1>, <num_residues>, <number_of_ligand_atoms>))
213    //
214    fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)> {
215        let (coords, elements): (Vec<[f32; 3]>, Vec<Element>) = self
216            .iter_residues()
217            .filter(|residue| {
218                let res_name = &residue.residue_name();
219                !residue.is_amino_acid() && *res_name != "HOH" && *res_name != "WAT"
220            })
221            .flat_map(|residue| {
222                residue
223                    .iter_atoms()
224                    .filter(|atom| is_heavy_atom(atom.element()))
225                    .map(|atom| (*atom.coords(), *atom.element()))
226                    .collect::<Vec<_>>()
227            })
228            .multiunzip();
229        let y = Tensor::from_slice(&coords.concat(), (coords.len(), 3), device)?;
230        let y_t = Tensor::from_slice(
231            &elements
232                .iter()
233                .map(|e| e.atomic_number() as f32)
234                .collect::<Vec<_>>(),
235            (elements.len(),),
236            device,
237        )?;
238        let y_m = Tensor::ones_like(&y)?;
239        Ok((y, y_t, y_m))
240    }
241}