1use super::utilities::{AAAtom, aa1to_int, aa3to1, get_nearest_neighbours};
3use crate::ligandmpnn::proteinfeatures::ProteinFeatures;
4use candle_core::{D, DType, Device, IndexOp, Result, Tensor};
5use ferritin_core::AtomCollection;
6use ferritin_core::info::elements::Element;
7use itertools::MultiUnzip;
8use std::collections::HashSet;
9use strum::IntoEnumIterator;
10
11fn is_heavy_atom(element: &Element) -> bool {
13 !matches!(element, Element::H | Element::He)
14}
15
16pub trait StructureFeatures {
18 fn decode_amino_acids(&self, device: &Device) -> Result<Tensor>;
20
21 fn encode_amino_acids(&self, device: &Device) -> Result<Tensor>;
23
24 fn create_cb(&self, device: &Device) -> Result<Tensor>;
26
27 fn featurize_lmpnn(&self, device: &Device) -> Result<ProteinFeatures>; fn get_res_index(&self) -> Vec<u32>;
32
33 fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>;
35
36 fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor>;
38
39 fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)>;
41}
42
43impl StructureFeatures for AtomCollection {
44 fn decode_amino_acids(&self, _device: &Device) -> Result<Tensor> {
46 todo!()
47 }
48
49 fn encode_amino_acids(&self, device: &Device) -> Result<Tensor> {
51 let n = self.iter_residues_aminoacid().count();
52 let s = self
53 .iter_residues_aminoacid()
54 .map(|res| res.residue_name().to_string())
55 .map(|res| aa3to1(&res))
56 .map(|res| aa1to_int(res));
57
58 Ok(Tensor::from_iter(s, device)?.reshape((1, n))?)
59 }
60
61 fn create_cb(&self, device: &Device) -> Result<Tensor> {
63 let backbone = self.to_numeric_backbone_atoms(device)?;
72 let backbone = backbone.squeeze(0)?; let n = backbone.i((.., 0, ..))?;
76 let ca = backbone.i((.., 1, ..))?;
77 let c = backbone.i((.., 2, ..))?;
78
79 let a_coeff = -0.58273431_f64;
81 let b_coeff = 0.56802827_f64;
82 let c_coeff = -0.54067466_f64;
83
84 let b = (&ca - &n)?; let c = (&c - &ca)?; let b_x = b.i((.., 0))?;
93 let b_y = b.i((.., 1))?;
94 let b_z = b.i((.., 2))?;
95 let c_x = c.i((.., 0))?;
96 let c_y = c.i((.., 1))?;
97 let c_z = c.i((.., 2))?;
98
99 let a_x = ((&b_y * &c_z)? - (&b_z * &c_y)?)?;
100 let a_y = ((&b_z * &c_x)? - (&b_x * &c_z)?)?;
101 let a_z = ((&b_x * &c_y)? - (&b_y * &c_x)?)?;
102 let a = Tensor::stack(&[&a_x, &a_y, &a_z], D::Minus1)?;
103
104 let cb = ((&a * a_coeff)? + (&b * b_coeff)? + (&c * c_coeff)? + &ca)?;
106 let cb = cb.unsqueeze(0)?;
107 Ok(cb)
108 }
109
110 fn featurize_lmpnn(&self, device: &Device) -> Result<ProteinFeatures> {
112 let x_37 = self.to_numeric_atom37(device)?;
113 let x_37_m = Tensor::zeros((x_37.dim(0)?, x_37.dim(1)?), DType::F32, device)?;
114 let (y, y_t, y_m) = self.to_numeric_ligand_atoms(device)?;
115 let cb = self.create_cb(device);
116 let chain_labels = self.get_resids(); let residue_ids = self.get_res_index();
118 let residue_length = residue_ids.len();
119 let r_idx = Tensor::from_iter(residue_ids, device)?.reshape((1, residue_length))?;
120 let chain_letters: Vec<String> = self
121 .iter_residues_aminoacid()
122 .map(|res| res.chain_id().to_string())
123 .collect();
124 let chain_list: Vec<String> = chain_letters
125 .clone()
126 .into_iter()
127 .collect::<HashSet<_>>()
128 .into_iter()
129 .collect();
130 let chain_labels: Option<Vec<f64>> = None; let s = self.encode_amino_acids(device)?;
133 let indices = Tensor::from_slice(
135 &[0i64, 1i64, 2i64, 4i64], (4,),
137 &device,
138 )?;
139 let x = x_37.index_select(&indices, 2)?;
140 Ok(ProteinFeatures {
141 s,
142 x,
143 x_mask: Some(x_37_m),
144 y,
145 y_t,
146 y_m: Some(y_m),
147 r_idx,
148 chain_labels,
149 chain_letters,
150 mask_c: None,
151 chain_list,
152 })
153 }
154 fn get_res_index(&self) -> Vec<u32> {
156 self.iter_residues_aminoacid()
157 .map(|res| res.residue_id() as u32)
158 .collect()
159 }
160
161 fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor> {
163 let res_count = self.iter_residues_aminoacid().count();
164 let mut backbone_data = vec![0f32; res_count * 4 * 3];
165 for residue in self.iter_residues_aminoacid() {
166 let resid = residue.residue_id() as usize;
167 let backbone_atoms = [
168 residue.find_atom_by_name("N"),
169 residue.find_atom_by_name("CA"),
170 residue.find_atom_by_name("C"),
171 residue.find_atom_by_name("O"),
172 ];
173 for (atom_idx, maybe_atom) in backbone_atoms.iter().enumerate() {
174 if let Some(atom) = maybe_atom {
175 let [x, y, z] = atom.coords();
176 let base_idx = (resid * 4 + atom_idx) * 3;
177 backbone_data[base_idx] = *x;
178 backbone_data[base_idx + 1] = *y;
179 backbone_data[base_idx + 2] = *z;
180 }
181 }
182 }
183 Tensor::from_vec(backbone_data, (1, res_count, 4, 3), &device)
185 }
186
187 fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor> {
189 let res_count = self.iter_residues_aminoacid().count();
190 let mut atom37_data = vec![0f32; res_count * 37 * 3];
191 for (idx, residue) in self.iter_residues_aminoacid().enumerate() {
192 for atom_type in AAAtom::iter().filter(|&a| a != AAAtom::Unknown) {
193 if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) {
194 let [x, y, z] = atom.coords();
195 let base_idx = (idx * 37 + atom_type as usize) * 3;
196 atom37_data[base_idx] = *x;
197 atom37_data[base_idx + 1] = *y;
198 atom37_data[base_idx + 2] = *z;
199 }
200 }
201 }
202 Tensor::from_vec(atom37_data, (1, res_count, 37, 3), &device)
204 }
205
206 fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)> {
215 let (coords, elements): (Vec<[f32; 3]>, Vec<Element>) = self
216 .iter_residues()
217 .filter(|residue| {
218 let res_name = &residue.residue_name();
219 !residue.is_amino_acid() && *res_name != "HOH" && *res_name != "WAT"
220 })
221 .flat_map(|residue| {
222 residue
223 .iter_atoms()
224 .filter(|atom| is_heavy_atom(atom.element()))
225 .map(|atom| (*atom.coords(), *atom.element()))
226 .collect::<Vec<_>>()
227 })
228 .multiunzip();
229 let y = Tensor::from_slice(&coords.concat(), (coords.len(), 3), device)?;
230 let y_t = Tensor::from_slice(
231 &elements
232 .iter()
233 .map(|e| e.atomic_number() as f32)
234 .collect::<Vec<_>>(),
235 (elements.len(),),
236 device,
237 )?;
238 let y_m = Tensor::ones_like(&y)?;
239 Ok((y, y_t, y_m))
240 }
241}