1use super::utilities::{AAAtom, aa1to_int, aa3to1, get_nearest_neighbours};
3use crate::ligandmpnn::proteinfeatures::ProteinFeatures;
4use candle_core::{D, DType, Device, IndexOp, Result, Tensor};
5use ferritin_core::AtomCollection;
6use ferritin_core::info::elements::Element;
7use std::collections::HashSet;
8use strum::IntoEnumIterator;
9
/// Threshold compared against the `d_xy` tensor returned by `get_nearest_neighbours`;
/// entries strictly below it are flagged in the ligand distance mask.
/// NOTE(review): presumably a distance in angstrom -- confirm against the reference model.
const LIGAND_CUTOFF_SCORE: f32 = 5.;
11
12fn is_heavy_atom(element: &Element) -> bool {
14 !matches!(element, Element::H | Element::He)
15}
16
/// Conversions from a molecular structure into the tensors used by the LigandMPNN featurizer.
///
/// The shapes quoted below are those produced by the `AtomCollection` implementation
/// in this file.
pub trait StructureFeatures {
    /// Counterpart of [`StructureFeatures::encode_amino_acids`].
    /// NOTE(review): intent inferred from the name only; the `AtomCollection`
    /// implementation is still a `todo!()`.
    fn decode_amino_acids(&self, device: &Device) -> Result<Tensor>;

    /// Encodes the amino-acid residues as integer codes in a tensor of shape
    /// `(1, n_residues)`.
    fn encode_amino_acids(&self, device: &Device) -> Result<Tensor>;

    /// Builds one virtual position per residue from the backbone N, CA and C
    /// coordinates; the result has shape `(1, n_residues, 3)`.
    fn create_cb(&self, device: &Device) -> Result<Tensor>;

    /// Assembles a [`ProteinFeatures`] value from the tensors produced by the
    /// other methods of this trait.
    fn featurize_lmpnn(&self, device: &Device) -> Result<ProteinFeatures>;

    /// Returns the residue id of every amino-acid residue, in iteration order.
    fn get_res_index(&self) -> Vec<u32>;

    /// Returns backbone coordinates (atoms `N`, `CA`, `C`, `O`) with shape
    /// `(1, n_residues, 4, 3)`.
    fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>;

    /// Returns per-residue coordinates in a 37-slot layout with shape
    /// `(1, n_residues, 37, 3)`.
    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor>;

    /// Returns `(y, y_t, y_m)` for the ligand atoms: coordinates, atomic numbers
    /// (as `I64`) and a mask, each with a leading batch dimension.
    fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)>;
}
43
44impl StructureFeatures for AtomCollection {
/// Placeholder: decoding is not implemented for `AtomCollection` yet.
///
/// # Panics
///
/// Always panics, because the body is `todo!()`.
fn decode_amino_acids(&self, _device: &Device) -> Result<Tensor> {
    todo!()
}
49
50 fn encode_amino_acids(&self, device: &Device) -> Result<Tensor> {
52 let n = self.iter_residues_aminoacid().count();
53 let s = self
54 .iter_residues_aminoacid()
55 .map(|res| res.residue_name().to_string())
56 .map(|res| aa3to1(&res))
57 .map(|res| aa1to_int(res));
58
59 Ok(Tensor::from_iter(s, device)?.reshape((1, n))?)
60 }
61
62 fn create_cb(&self, device: &Device) -> Result<Tensor> {
64 let backbone = self.to_numeric_backbone_atoms(device)?.squeeze(0)?;
65
66 let n = backbone.i((.., 0, ..))?;
68 let ca = backbone.i((.., 1, ..))?;
69 let c = backbone.i((.., 2, ..))?;
70
71 let a_coeff = -0.58273431_f64;
73 let b_coeff = 0.56802827_f64;
74 let c_coeff = -0.54067466_f64;
75
76 let b = (&ca - &n)?;
78 let c = (&c - &ca)?;
79
80 let b_x = b.i((.., 0))?;
85 let b_y = b.i((.., 1))?;
86 let b_z = b.i((.., 2))?;
87 let c_x = c.i((.., 0))?;
88 let c_y = c.i((.., 1))?;
89 let c_z = c.i((.., 2))?;
90
91 let a_x = ((&b_y * &c_z)? - (&b_z * &c_y)?)?;
92 let a_y = ((&b_z * &c_x)? - (&b_x * &c_z)?)?;
93 let a_z = ((&b_x * &c_y)? - (&b_y * &c_x)?)?;
94 let a = Tensor::stack(&[&a_x, &a_y, &a_z], D::Minus1)?;
95
96 let cb = ((&a * a_coeff)? + (&b * b_coeff)? + (&c * c_coeff)? + &ca)?;
98 let cb = cb.unsqueeze(0)?;
99 Ok(cb)
100 }
101
102 fn featurize_lmpnn(&self, device: &Device) -> Result<ProteinFeatures> {
104 let x_37 = self.to_numeric_atom37(device)?;
105 let x_37_m = Tensor::zeros((x_37.dim(0)?, x_37.dim(1)?), DType::F32, device)?;
106 let (y, y_t, y_m) = self.to_numeric_ligand_atoms(device)?;
107 let cb = self.create_cb(device);
108 let chain_labels = self.get_resids(); let residue_ids = self.get_res_index();
110 let residue_length = residue_ids.len();
111 let r_idx = Tensor::from_iter(residue_ids, device)?.reshape((1, residue_length))?;
112 let chain_letters: Vec<String> = self
113 .iter_residues_aminoacid()
114 .map(|res| res.chain_id().to_string())
115 .collect();
116 let chain_list: Vec<String> = self
117 .iter_residues_aminoacid()
118 .map(|res| res.chain_id().to_string())
119 .collect::<HashSet<_>>()
120 .into_iter()
121 .collect();
122 let chain_labels: Option<Vec<f64>> = None; let s = self.encode_amino_acids(device)?;
125 let indices = Tensor::from_slice(
127 &[0i64, 1i64, 2i64, 4i64], (4,),
129 &device,
130 )?;
131 let x = x_37.index_select(&indices, 2)?;
132 Ok(ProteinFeatures {
133 s,
134 x,
135 x_mask: Some(x_37_m),
136 y,
137 y_t,
138 y_m: Some(y_m),
139 r_idx,
140 chain_labels,
141 chain_letters,
142 mask_c: None,
143 chain_list,
144 })
145 }
146 fn get_res_index(&self) -> Vec<u32> {
148 self.iter_residues_aminoacid()
149 .map(|res| res.residue_id() as u32)
150 .collect()
151 }
152
153 fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor> {
155 let res_count = self.iter_residues_aminoacid().count();
156 let mut backbone_data = Vec::with_capacity(res_count * 4 * 3);
157
158 for residue in self.iter_residues_aminoacid() {
159 for atom_name in ["N", "CA", "C", "O"] {
160 if let Some(atom) = residue.find_atom_by_name(atom_name) {
161 let [x, y, z] = atom.coords();
162 backbone_data.extend_from_slice(&[*x, *y, *z]);
163 } else {
164 backbone_data.extend_from_slice(&[0.0, 0.0, 0.0]);
165 }
166 }
167 }
168 Tensor::from_vec(backbone_data, (1, res_count, 4, 3), &device)
169 }
170
171 fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor> {
173 let res_count = self.iter_residues_aminoacid().count();
174 let mut atom37_data = vec![0.0; res_count * 37 * 3];
175 for (res_idx, residue) in self.iter_residues_aminoacid().enumerate() {
176 for atom_type in AAAtom::iter().filter(|&a| a != AAAtom::Unknown) {
177 if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) {
178 let [x, y, z] = atom.coords();
179 let base_idx = (res_idx * 37 + atom_type as usize) * 3;
180 atom37_data[base_idx..base_idx + 3].copy_from_slice(&[*x, *y, *z]);
181 }
182 }
183 }
184 Tensor::from_vec(atom37_data, (1, res_count, 37, 3), &device)
185 }
186
187 fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)> {
196 let mut coords = Vec::new();
197 let mut elements = Vec::new();
198 for residue in self.iter_residues() {
199 let res_name = residue.residue_name();
200 if residue.is_amino_acid() || res_name == "HOH" || res_name == "WAT" {
201 continue;
202 }
203 let atoms: Vec<_> = residue
204 .iter_atoms()
205 .filter(|atom| is_heavy_atom(atom.element()))
206 .collect();
207 for atom in atoms {
208 coords.push(*atom.coords());
209 elements.push(*atom.element());
210 }
211 }
212
213 let y = Tensor::from_slice(&coords.concat(), (coords.len(), 3), device)?;
215 let y_m = Tensor::ones_like(&y)?;
216 let y_t = Tensor::from_slice(
217 &elements
218 .iter()
219 .map(|e| e.atomic_number() as f32)
220 .collect::<Vec<_>>(),
221 (elements.len(),),
222 device,
223 )?;
224 let cb = self.create_cb(device)?;
225 let (batch, res_num, _coords) = cb.dims3()?;
226 let (number_of_ligand_atoms, _coords) = y.dims2()?;
227 let mask = Tensor::zeros((batch, res_num), DType::F32, device)?;
228 let (y, y_t, y_m, d_xy) =
229 get_nearest_neighbours(&cb, &mask, &y, &y_t, &y_m, number_of_ligand_atoms as i64)?;
230 let distance_mask = d_xy.lt(LIGAND_CUTOFF_SCORE)?.to_dtype(DType::F32)?;
231 let y_m_first = y_m.i((.., 0))?;
232 let mask = mask.squeeze(0)?;
233 let _mask_xy = distance_mask.mul(&mask)?.mul(&y_m_first)?;
234 let y = y.unsqueeze(0)?;
235 let y_t = y_t.to_dtype(DType::I64)?.unsqueeze(0)?;
236 let y_m = y_m.unsqueeze(0)?;
237 Ok((y, y_t, y_m))
238 }
239}