1use candle_core::{D, DType, Device, IndexOp, Result, Tensor};
2use candle_nn::encoding::one_hot;
3use strum::{Display, EnumIter, EnumString};
4
/// Converts a three-letter amino-acid name (e.g. `"ALA"`) into its one-letter code.
///
/// Matching is exact and case-sensitive; any name outside the 20 standard residues
/// (including lowercase input) yields `'X'`.
#[rustfmt::skip]
pub fn aa3to1(aa: &str) -> char {
    const THREE_TO_ONE: [(&str, char); 20] = [
        ("ALA", 'A'), ("CYS", 'C'), ("ASP", 'D'), ("GLU", 'E'), ("PHE", 'F'),
        ("GLY", 'G'), ("HIS", 'H'), ("ILE", 'I'), ("LYS", 'K'), ("LEU", 'L'),
        ("MET", 'M'), ("ASN", 'N'), ("PRO", 'P'), ("GLN", 'Q'), ("ARG", 'R'),
        ("SER", 'S'), ("THR", 'T'), ("VAL", 'V'), ("TRP", 'W'), ("TYR", 'Y'),
    ];
    THREE_TO_ONE
        .iter()
        .find(|(three, _)| *three == aa)
        .map_or('X', |&(_, one)| one)
}
18
/// Maps a one-letter amino-acid code to its integer token.
///
/// Tokens follow the order `ACDEFGHIKLMNPQRSTVWY` (0..=19). Every other character,
/// including `'X'` and lowercase letters, maps to 20.
#[rustfmt::skip]
pub fn aa1to_int(aa: char) -> u32 {
    // All entries are ASCII, so the byte offset returned by `find` equals the position.
    const ORDER: &str = "ACDEFGHIKLMNPQRSTVWY";
    ORDER.find(aa).map_or(20, |pos| pos as u32)
}
32
/// Maps an integer token back to its one-letter amino-acid code.
///
/// Tokens 0..=19 follow the order `ACDEFGHIKLMNPQRSTVWY`. Token 20, and any value
/// outside that range, maps to `'X'`.
#[rustfmt::skip]
pub fn int_to_aa1(aa_int: u32) -> char {
    const CODES: &[u8] = b"ACDEFGHIKLMNPQRSTVWY";
    usize::try_from(aa_int)
        .ok()
        .and_then(|i| CODES.get(i))
        .map_or('X', |&byte| char::from(byte))
}
47
// One-letter amino-acid codes indexed by integer token: position i holds the
// character that `int_to_aa1(i)` returns and that `aa1to_int` maps back to i,
// with 'X' (token 20) as the unknown residue.
// NOTE(review): no use of this constant is visible in this file outside
// possible tests; confirm it is still needed.
const ALPHABET: [char; 21] = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W',
    'Y', 'X',
];
52
// Chemical element symbols in periodic-table order: index 0 is "H" and index
// 117 is "Og", so the symbol at index i belongs to atomic number i + 1.
// NOTE(review): no use of this constant is visible in this file; confirm it is
// still needed (element symbols are also available via ferritin_core's Element).
const ELEMENT_LIST: [&str; 118] = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl",
    "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As",
    "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In",
    "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb",
    "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl",
    "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk",
    "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh",
    "Fl", "Mc", "Lv", "Ts", "Og",
];
63
64#[rustfmt::skip]
65#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, EnumIter)]
66pub enum AAAtom {
67 N = 0, CA = 1, C = 2, CB = 3, O = 4,
68 CG = 5, CG1 = 6, CG2 = 7, OG = 8, OG1 = 9,
69 SG = 10, CD = 11, CD1 = 12, CD2 = 13, ND1 = 14,
70 ND2 = 15, OD1 = 16, OD2 = 17, SD = 18, CE = 19,
71 CE1 = 20, CE2 = 21, CE3 = 22, NE = 23, NE1 = 24,
72 NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29,
73 NH2 = 30, OH = 31, CZ = 32, CZ2 = 33, CZ3 = 34,
74 NZ = 35, OXT = 36,
75 Unknown = -1,
76}
77impl AAAtom {
78 pub fn to_index(&self) -> usize {
80 *self as usize
81 }
82}
83
/// Generates the `Residue` enum and its lookup methods from a single table.
///
/// Each row has the form `NAME: "CODE3", 'C', index, [features], [atoms14]`.
/// The expansion creates one `Residue::NAME` variant per row, plus the
/// `code3`, `code1`, `atoms14`, `from_int` and `to_int` accessors driven by that row.
/// The `$features` column is matched but not used anywhere in the expansion.
/// The expansion refers to `Self::UNK`, so the table must contain an `UNK` row.
macro_rules! define_residues {
    ($($name:ident: $code3:expr_2021, $code1:expr_2021, $idx:expr_2021, $features:expr_2021, $atoms14:expr_2021),* $(,)?) => {
        /// Residue types generated by `define_residues!`.
        #[derive(Debug, Copy, Clone)]
        pub enum Residue {
            $($name),*
        }

        impl Residue {
            /// Three-letter code given for this residue in the table.
            pub const fn code3(&self) -> &'static str {
                match self {
                    $(Self::$name => $code3),*
                }
            }
            /// One-letter code given for this residue in the table.
            pub const fn code1(&self) -> char {
                match self {
                    $(Self::$name => $code1),*
                }
            }
            /// The 14-entry atom array given for this residue in the table.
            pub const fn atoms14(&self) -> [AAAtom; 14] {
                match self {
                    $(Self::$name => $atoms14),*
                }
            }
            /// Looks a residue up by its table index; values matching no row fall back to `UNK`.
            pub fn from_int(value: i32) -> Self {
                match value {
                    // Each `$idx` literal is used as a match pattern here.
                    $($idx => Self::$name,)*
                    _ => Self::UNK
                }
            }
            /// Table index of this residue (the inverse of `from_int` for listed indices).
            pub fn to_int(&self) -> i32 {
                match self {
                    $(Self::$name => $idx),*
                }
            }
        }
    }
}
121
// Residue table: NAME: three-letter code, one-letter code, integer index,
// two-element feature array, 14 atom slots.
// Indices 0..=19 follow the same order as `aa1to_int`, and UNK is 20.
// The atom14 lists start N, CA, C, O, then CB and the side chain. Note that this
// differs from atom37 order, where CB (3) precedes O (4). Unused slots are
// AAAtom::Unknown.
// NOTE(review): the feature column (ALA [1.0, 0.0], ARG [0.0, 1.0], all others
// zero) is ignored by the `define_residues!` expansion; its meaning is not
// documented here — confirm intent.
define_residues! {
    ALA: "ALA", 'A', 0, [1.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    CYS: "CYS", 'C', 1, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::SG, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ASP: "ASP", 'D', 2, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::OD1, AAAtom::OD2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLU: "GLU", 'E', 3, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::OE1, AAAtom::OE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    PHE: "PHE", 'F', 4, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE1, AAAtom::CE2, AAAtom::CZ, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLY: "GLY", 'G', 5, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    HIS: "HIS", 'H', 6, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::ND1, AAAtom::CD2, AAAtom::CE1, AAAtom::NE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ILE: "ILE", 'I', 7, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG1, AAAtom::CG2, AAAtom::CD1, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    LYS: "LYS", 'K', 8, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::CE, AAAtom::NZ, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    LEU: "LEU", 'L', 9, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    MET: "MET", 'M', 10, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::SD, AAAtom::CE, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ASN: "ASN", 'N', 11, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::OD1, AAAtom::ND2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    PRO: "PRO", 'P', 12, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLN: "GLN", 'Q', 13, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::OE1, AAAtom::NE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ARG: "ARG", 'R', 14, [0.0, 1.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::NE, AAAtom::CZ, AAAtom::NH1, AAAtom::NH2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    SER: "SER", 'S', 15, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::OG, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    THR: "THR", 'T', 16, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::OG1, AAAtom::CG2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    VAL: "VAL", 'V', 17, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG1, AAAtom::CG2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    TRP: "TRP", 'W', 18, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE2, AAAtom::CE3, AAAtom::NE1, AAAtom::CZ2, AAAtom::CZ3, AAAtom::CH2],
    TYR: "TYR", 'Y', 19, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE1, AAAtom::CE2, AAAtom::CZ, AAAtom::OH, AAAtom::Unknown, AAAtom::Unknown],
    UNK: "UNK", 'X', 20, [0.0, 0.0], [AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
}
145
146pub fn get_nearest_neighbours(
147 cb: &Tensor,
148 mask: &Tensor,
149 y: &Tensor,
150 y_t: &Tensor,
151 y_m: &Tensor,
152 number_of_ligand_atoms: i64,
153) -> Result<(Tensor, Tensor, Tensor, Tensor)> {
154 let cb = cb.squeeze(0)?;
156 let mask = mask.squeeze(0)?;
157 let y_m = if y_m.dims().len() > 1 {
158 y_m.sum_keepdim(1)?.squeeze(1)? } else {
160 y_m.clone()
161 };
162 let num_residues = cb.dim(0)?;
163 let mask_cby = mask.unsqueeze(1)?.matmul(&y_m.unsqueeze(0)?)?;
164 let cb_flat = cb.reshape((cb.dim(0)?, 1, 3))?; let y_flat = y.reshape((1, y.dim(0)?, 3))?; let cb_broadcast = cb_flat.broadcast_as((cb.dim(0)?, y.dim(0)?, 3))?; let y_broadcast = y_flat.broadcast_as((cb.dim(0)?, y.dim(0)?, 3))?; let diff = cb_broadcast.sub(&y_broadcast)?;
170 let l2_ab = diff.powf(2.0)?.sum(D::Minus1)?;
171 let complement_mask = (mask_cby.neg()? + 1.0)?;
172 let padding_value = Tensor::full(1000.0_f32, mask_cby.dims(), cb.device())?;
173 let masked_distances = l2_ab.mul(&mask_cby)?;
174 let padding_contribution = complement_mask.mul(&padding_value)?;
175 let l2_ab = masked_distances.add(&padding_contribution)?;
176
177 let nn_idx = l2_ab
179 .arg_sort_last_dim(false)?
180 .narrow(1, 0, number_of_ligand_atoms as usize)?
181 .contiguous()?;
182 let l2_ab_nn = l2_ab.contiguous()?.gather(&nn_idx, 1)?;
183 let d_ab_closest = l2_ab_nn.i((.., 0))?.sqrt()?;
184 let y_new = y
185 .unsqueeze(0)?
186 .expand((num_residues, y.dim(0)?, 3))?
187 .contiguous()?
188 .gather(
189 &nn_idx
190 .unsqueeze(2)?
191 .expand((num_residues, number_of_ligand_atoms as usize, 3))?
192 .contiguous()?,
193 1,
194 )?;
195
196 let y_t_new = y_t
197 .unsqueeze(0)?
198 .expand((num_residues, y_t.dim(0)?))?
199 .contiguous()?
200 .gather(&nn_idx, 1)?;
201
202 let y_m_new = y_m
203 .unsqueeze(0)?
204 .expand((num_residues, y_m.dim(0)?))?
205 .contiguous()?
206 .gather(&nn_idx, 1)?;
207
208 Ok((y_new, y_t_new, y_m_new, d_ab_closest))
209}
210
211pub fn cat_neighbors_nodes(
212 h_nodes: &Tensor,
213 h_neighbors: &Tensor,
214 e_idx: &Tensor,
215) -> Result<Tensor> {
216 let h_nodes_gathered = gather_nodes(h_nodes, e_idx)?;
217 let h_neighbors = h_neighbors.expand((
218 h_neighbors.dim(0)?, h_nodes.dim(1)?, h_neighbors.dim(2)?, h_neighbors.dim(3)?, ))?;
223
224 Tensor::cat(
225 &[h_neighbors, h_nodes_gathered.to_dtype(DType::F32)?],
226 D::Minus1,
227 )
228}
229
230pub fn compute_nearest_neighbors(
233 coords: &Tensor,
234 mask: &Tensor,
235 k: usize,
236 eps: f32,
237) -> Result<(Tensor, Tensor)> {
238 let (_batch_size, seq_len, _) = coords.dims3()?;
239 let mask_2d = mask
243 .unsqueeze(2)?
244 .broadcast_matmul(&mask.unsqueeze(1)?)?
245 .to_dtype(DType::F32)?;
246 let distances = (coords
249 .unsqueeze(2)?
250 .broadcast_sub(&coords.unsqueeze(1)?)?
251 .powf(2.)?
252 .sum(D::Minus1)?
253 + eps as f64)? .sqrt()?
255 .to_dtype(DType::F32)?;
256
257 let masked_distances = (&distances * &mask_2d.to_dtype(DType::F32)?)?;
260 let d_max = masked_distances.max_keepdim(D::Minus1)?;
262 let mask_term = ((&mask_2d.to_dtype(DType::F32)? * -1.0)? + 1.0)?;
263 let d_adjust = (&masked_distances + mask_term.broadcast_mul(&d_max)?)?;
264 let d_adjust = d_adjust.to_dtype(DType::F32)?;
265 Ok(topk_last_dim(&d_adjust, k.min(seq_len))?)
266}
267
268pub fn topk_last_dim(xs: &Tensor, topk: usize) -> Result<(Tensor, Tensor)> {
270 let sorted_indices = xs.arg_sort_last_dim(false)?.to_dtype(DType::U32)?;
271 let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
272 let gathered = xs.gather(&topk_indices, D::Minus1)?;
273 Ok((gathered, topk_indices))
274}
275
276pub fn create_backbone_mask_37(xyz_37: &Tensor) -> Result<Tensor> {
281 let (b, l, rescount, _) = xyz_37.dims4()?;
282 let mut values = vec![0f32; rescount];
284 for &idx in &[0, 1, 2, 4] {
285 if idx < rescount {
286 values[idx] = 1.0;
287 }
288 }
289
290 let base_mask = Tensor::new(values.as_slice(), xyz_37.device())?.to_dtype(DType::F32)?;
292
293 let mask = base_mask.unsqueeze(0)?.unsqueeze(0)?; let mask = mask.broadcast_as((b, l, rescount))?;
296
297 Ok(mask)
298}
299pub fn calculate_cb(xyz_37: &Tensor) -> Result<Tensor> {
301 let (_, dim37, dim3) = xyz_37.dims3()?;
303 assert_eq!(dim37, 37);
304 assert_eq!(dim3, 3);
305
306 let a_coeff = -0.58273431f64;
308 let b_coeff = 0.56802827f64;
309 let c_coeff = -0.54067466f64;
310
311 let n = xyz_37.i((.., 0, ..))?; let ca = xyz_37.i((.., 1, ..))?; let c = xyz_37.i((.., 2, ..))?; let b = (&ca - &n)?; let c = (&c - &ca)?; let b_x = b.i((.., 0))?;
325 let b_y = b.i((.., 1))?;
326 let b_z = b.i((.., 2))?;
327 let c_x = c.i((.., 0))?;
328 let c_y = c.i((.., 1))?;
329 let c_z = c.i((.., 2))?;
330
331 let a_x = ((&b_y * &c_z)? - (&b_z * &c_y)?)?;
332 let a_y = ((&b_z * &c_x)? - (&b_x * &c_z)?)?;
333 let a_z = ((&b_x * &c_y)? - (&b_y * &c_x)?)?;
334
335 let a = Tensor::stack(&[&a_x, &a_y, &a_z], 1)?;
337
338 let cb = (&a * a_coeff)? + (&b * b_coeff)? + (&c * c_coeff)? + &ca;
340
341 Ok(cb?)
342}
343
344pub fn cross_product(a: &Tensor, b: &Tensor) -> Result<Tensor> {
346 let last_dim = a.dims().len() - 1;
347
348 let a0 = a.narrow(last_dim, 0, 1)?;
350 let a1 = a.narrow(last_dim, 1, 1)?;
351 let a2 = a.narrow(last_dim, 2, 1)?;
352
353 let b0 = b.narrow(last_dim, 0, 1)?;
354 let b1 = b.narrow(last_dim, 1, 1)?;
355 let b2 = b.narrow(last_dim, 2, 1)?;
356
357 let c0 = ((&a1 * &b2)? - (&a2 * &b1)?)?;
359 let c1 = ((&a2 * &b0)? - (&a0 * &b2)?)?;
360 let c2 = ((&a0 * &b1)? - (&a1 * &b0)?)?;
361
362 Tensor::cat(&[&c0, &c1, &c2], last_dim)
364}
365
366pub fn gather_edges(edges: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
369 let (d1, d2, d3) = neighbor_idx.dims3()?;
370 let neighbors =
371 neighbor_idx
372 .unsqueeze(D::Minus1)?
373 .expand((d1, d2, d3, edges.dim(D::Minus1)?))?;
374 let edge_gather = edges.gather(&neighbors, 2)?;
375 Ok(edge_gather)
376}
377
378pub fn gather_nodes(nodes: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
383 let (batch_size, n_nodes, n_features) = nodes.dims3()?;
384 let (_, _, k_neighbors) = neighbor_idx.dims3()?;
385 let neighbors_flat = neighbor_idx.reshape((batch_size, n_nodes * k_neighbors))?;
387 let neighbors_flat = neighbors_flat
389 .unsqueeze(2)? .expand((batch_size, n_nodes * k_neighbors, n_features))?; let neighbors_flat = neighbors_flat.contiguous()?;
393 let neighbor_features = nodes.gather(&neighbors_flat, 1)?;
395 neighbor_features.reshape((batch_size, n_nodes, k_neighbors, n_features))
397}
398
399pub fn gather_nodes_t(nodes: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
400 let (d1, d2, d3) = nodes.dims3()?;
402 let idx_flat = neighbor_idx.unsqueeze(D::Minus1)?.expand((d1, d2, d3))?;
403 nodes.gather(&idx_flat, 1)
404}
405
406fn get_seq_rec(s: &Tensor, s_pred: &Tensor, mask: &Tensor) -> Result<Tensor> {
407 let match_tensor = s.eq(s_pred)?;
414 let match_f32 = match_tensor.to_dtype(DType::F32)?;
415 let numerator = (match_f32 * mask)?.sum_keepdim(1)?;
416 let denominator = mask.sum_keepdim(1)?;
417 let average = numerator.broadcast_div(&denominator)?;
418 average.squeeze(1)
420}
421
422fn get_score(s: &Tensor, log_probs: &Tensor, mask: &Tensor) -> Result<(Tensor, Tensor)> {
423 let s_one_hot = one_hot(s.clone(), 21, 1., 0.)?;
448 let loss_per_residue = s_one_hot.mul(&log_probs.neg()?)?.sum(D::Minus1)?;
449 let average_loss = loss_per_residue
450 .mul(&mask)?
451 .sum_keepdim(D::Minus1)?
452 .div(&(mask.sum_keepdim(D::Minus1)? + 1e-8f64)?)?
453 .squeeze(D::Minus1)?;
454
455 Ok((average_loss, loss_per_residue))
456}
457
458pub fn linspace(
459 start: f64,
460 stop: f64,
461 steps: usize,
462 device: &Device,
463 return_type: DType,
464) -> Result<Tensor> {
465 if steps == 0 {
466 Tensor::from_vec(Vec::<f64>::new(), steps, device)
467 } else if steps == 1 {
468 Tensor::from_vec(vec![start], steps, device)
469 } else {
470 let delta = (stop - start) / (steps - 1) as f64;
471 let vs = (0..steps)
472 .map(|step| start + step as f64 * delta)
473 .collect::<Vec<_>>();
474 Tensor::from_vec(vs, steps, device)?.to_dtype(return_type)
475 }
476}
477
478pub fn linspace_f32(start: f32, stop: f32, steps: usize, device: &Device) -> Result<Tensor> {
479 match steps {
480 0 => Tensor::from_vec(Vec::<f32>::new(), steps, device),
481 1 => Tensor::from_vec(vec![start], steps, device),
482 _ => {
483 let delta = (stop - start) / (steps - 1) as f32;
484 let vs = (0..steps)
485 .map(|step| start + step as f32 * delta)
486 .collect::<Vec<_>>();
487 Tensor::from_vec(vs, steps, device)
488 }
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::StructureFeatures;
    use anyhow::Result;
    use ferritin_core::info::elements::Element;
    use ferritin_core::load_structure;
    use ferritin_test_data::TestFile;

    #[test]
    fn test_residue_codes() {
        let ala = Residue::ALA;
        assert_eq!(ala.code3(), "ALA");
        assert_eq!(ala.code1(), 'A');
        assert_eq!(ala.to_int(), 0);
    }

    #[test]
    fn test_residue_from_int() {
        assert!(matches!(Residue::from_int(0), Residue::ALA));
        assert!(matches!(Residue::from_int(1), Residue::CYS));
        assert!(matches!(Residue::from_int(999), Residue::UNK));
    }

    #[test]
    fn test_residue_atoms() {
        let trp = Residue::TRP;
        let atoms = trp.atoms14();
        assert_eq!(atoms[0], AAAtom::N);
        assert_eq!(atoms[13], AAAtom::CH2);

        let gly = Residue::GLY;
        let atoms = gly.atoms14();
        assert_eq!(atoms[4], AAAtom::Unknown);
    }

    #[test]
    fn test_aa_code_round_trip() {
        // Token i <-> ALPHABET[i] must agree in both directions, including 'X' = 20.
        for (idx, &code) in ALPHABET.iter().enumerate() {
            let idx = idx as u32;
            assert_eq!(aa1to_int(code), idx);
            assert_eq!(int_to_aa1(idx), code);
        }
        assert_eq!(aa3to1("GLY"), 'G');
        assert_eq!(aa3to1("XYZ"), 'X');
    }

    #[test]
    fn test_cross_product_unit_vectors() -> Result<()> {
        let device = Device::Cpu;
        let x = Tensor::new(&[[1f32, 0., 0.]], &device)?;
        let y = Tensor::new(&[[0f32, 1., 0.]], &device)?;
        let z: Vec<Vec<f32>> = cross_product(&x, &y)?.to_vec2()?;
        assert_eq!(z, vec![vec![0.0, 0.0, 1.0]]);
        Ok(())
    }

    #[test]
    fn test_atom_backbone_tensor() -> Result<()> {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
        let ac = load_structure(pdb_file)?;
        let ac_backbone_tensor: Tensor = ac
            .to_numeric_backbone_atoms(&device)
            .expect("backbone tensor creation should succeed for protein_01");
        assert_eq!(ac_backbone_tensor.dims(), &[1, 154, 4, 3]);

        let backbone_coords = [
            ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
            ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
            ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
            ("O", (0, 0, 3, ..), vec![26.748, 9.469, -10.197]),
            ("N", (0, 1, 0, ..), vec![25.964, 11.453, -10.903]),
            ("CA", (0, 1, 1, ..), vec![27.263, 11.924, -11.359]),
            ("C", (0, 1, 2, ..), vec![27.392, 13.428, -11.115]),
            ("O", (0, 1, 3, ..), vec![26.443, 14.184, -11.327]),
            ("N", (0, 153, 0, ..), vec![23.474, -3.227, 5.994]),
            ("CA", (0, 153, 1, ..), vec![22.818, -2.798, 7.211]),
            ("C", (0, 153, 2, ..), vec![22.695, -1.282, 7.219]),
            ("O", (0, 153, 3, ..), vec![21.870, -0.745, 7.992]),
        ];

        for (atom_name, (b, i, j, k), expected) in backbone_coords {
            let actual: Vec<f32> = ac_backbone_tensor.i((b, i, j, k))?.to_vec1()?;
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }
        Ok(())
    }

    #[test]
    fn test_all_atom37_tensor() -> Result<()> {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
        let ac = load_structure(pdb_file)?;
        let ac_backbone_tensor: Tensor = ac
            .to_numeric_atom37(&device)
            .expect("atom37 tensor creation should succeed for protein_01");
        assert_eq!(ac_backbone_tensor.dims(), &[1, 154, 37, 3]);

        let allatom_coords = [
            ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
            ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
            ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
            ("CB", (0, 0, 3, ..), vec![24.070, 10.495, -8.596]),
            ("O", (0, 0, 4, ..), vec![26.748, 9.469, -10.197]),
            ("CG", (0, 0, 5, ..), vec![24.880, 9.939, -7.442]),
            ("CG1", (0, 0, 6, ..), vec![0.0, 0.0, 0.0]),
            ("CG2", (0, 0, 7, ..), vec![0.0, 0.0, 0.0]),
            ("OG", (0, 0, 8, ..), vec![0.0, 0.0, 0.0]),
            ("OG1", (0, 0, 9, ..), vec![0.0, 0.0, 0.0]),
            ("SG", (0, 0, 10, ..), vec![0.0, 0.0, 0.0]),
            ("CD", (0, 0, 11, ..), vec![0.0, 0.0, 0.0]),
            ("CD1", (0, 0, 12, ..), vec![0.0, 0.0, 0.0]),
            ("CD2", (0, 0, 13, ..), vec![0.0, 0.0, 0.0]),
            ("ND1", (0, 0, 14, ..), vec![0.0, 0.0, 0.0]),
            ("ND2", (0, 0, 15, ..), vec![0.0, 0.0, 0.0]),
            ("OD1", (0, 0, 16, ..), vec![0.0, 0.0, 0.0]),
            ("OD2", (0, 0, 17, ..), vec![0.0, 0.0, 0.0]),
            ("SD", (0, 0, 18, ..), vec![24.262, 10.555, -5.873]),
            ("CE", (0, 0, 19, ..), vec![24.822, 12.266, -5.967]),
            ("CE1", (0, 0, 20, ..), vec![0.0, 0.0, 0.0]),
            ("CE2", (0, 0, 21, ..), vec![0.0, 0.0, 0.0]),
            ("CE3", (0, 0, 22, ..), vec![0.0, 0.0, 0.0]),
            ("NE", (0, 0, 23, ..), vec![0.0, 0.0, 0.0]),
            ("NE1", (0, 0, 24, ..), vec![0.0, 0.0, 0.0]),
            ("NE2", (0, 0, 25, ..), vec![0.0, 0.0, 0.0]),
            ("OE1", (0, 0, 26, ..), vec![0.0, 0.0, 0.0]),
            ("OE2", (0, 0, 27, ..), vec![0.0, 0.0, 0.0]),
            ("CH2", (0, 0, 28, ..), vec![0.0, 0.0, 0.0]),
            ("NH1", (0, 0, 29, ..), vec![0.0, 0.0, 0.0]),
            ("NH2", (0, 0, 30, ..), vec![0.0, 0.0, 0.0]),
            ("OH", (0, 0, 31, ..), vec![0.0, 0.0, 0.0]),
            ("CZ", (0, 0, 32, ..), vec![0.0, 0.0, 0.0]),
            ("CZ2", (0, 0, 33, ..), vec![0.0, 0.0, 0.0]),
            ("CZ3", (0, 0, 34, ..), vec![0.0, 0.0, 0.0]),
            ("NZ", (0, 0, 35, ..), vec![0.0, 0.0, 0.0]),
            ("OXT", (0, 0, 36, ..), vec![0.0, 0.0, 0.0]),
        ];
        for (atom_name, (b, i, j, k), expected) in allatom_coords {
            let actual: Vec<f32> = ac_backbone_tensor.i((b, i, j, k))?.to_vec1()?;
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }
        Ok(())
    }

    #[test]
    fn test_ligand_tensor() -> Result<()> {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
        let ac = load_structure(pdb_file)?;
        let (ligand_coords, ligand_elements, _) = ac
            .to_numeric_ligand_atoms(&device)
            .expect("ligand tensor creation should succeed for protein_01");
        assert_eq!(ligand_coords.dims(), &[1, 154, 54, 3]);

        let allatom_coords = [
            ("S", (0, 0, 0, ..), vec![30.746, 18.706, 28.896]),
            ("O1", (0, 0, 1, ..), vec![30.697, 20.077, 28.620]),
            ("O2", (0, 0, 2, ..), vec![31.104, 18.021, 27.725]),
            ("O3", (0, 0, 3, ..), vec![29.468, 18.179, 29.331]),
            ("O4", (0, 0, 4, ..), vec![31.722, 18.578, 29.881]),
        ];
        // The trailing `..` of each index tuple is unused; the coordinate axis is always sliced fully.
        for (atom_name, (b, l, i, _), expected) in allatom_coords {
            let actual: Vec<f32> = ligand_coords.i((b, l, i, ..))?.to_vec1()?;
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }

        let elements: Vec<&str> = ligand_elements
            .i((0, 0, ..))?
            .to_vec1::<i64>()?
            .into_iter()
            .map(|elem| Element::new(elem as usize).unwrap().symbol())
            .collect();

        assert_eq!(elements[0], "S");
        assert_eq!(elements[1], "O");
        assert_eq!(elements[2], "O");
        assert_eq!(elements[3], "O");

        Ok(())
    }

    #[test]
    fn test_backbone_tensor() {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
        let ac = load_structure(pdb_file).unwrap();
        let xyz_37 = ac
            .to_numeric_atom37(&device)
            .expect("XYZ creation for all-atoms");
        assert_eq!(xyz_37.dims(), [1, 154, 37, 3]);

        let xyz_m = create_backbone_mask_37(&xyz_37).expect("masking procedure should work");
        assert_eq!(xyz_m.dims(), &[1, 154, 37]);
    }

    #[test]
    fn test_compute_nearest_neighbors() {
        let device = Device::Cpu;
        let test_dtype = DType::F32;

        let coords = Tensor::new(
            &[
                [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
                [[0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]],
            ],
            &device,
        )
        .unwrap()
        .to_dtype(test_dtype)
        .unwrap();

        let mask = Tensor::ones((2, 3), test_dtype, &device).unwrap();

        let (distances, indices) = compute_nearest_neighbors(&coords, &mask, 2, 1e-6).unwrap();

        assert_eq!(distances.dims(), &[2, 3, 2]);
        assert_eq!(indices.dims(), &[2, 3, 2]);

        let point_neighbors: Vec<u32> = indices.i((0, 1, ..)).unwrap().to_vec1().unwrap();
        assert_eq!(point_neighbors, vec![0, 2]);

        let point_distances: Vec<f32> = distances.i((0, 1, ..)).unwrap().to_vec1().unwrap();
        assert!((point_distances[0] - 1.0).abs() < 1e-5);
        assert!((point_distances[1] - 1.0).abs() < 1e-5);
    }
}