1use candle_core::{D, DType, Device, IndexOp, Result, Tensor};
2use candle_nn::encoding::one_hot;
3use strum::{Display, EnumIter, EnumString};
4
/// Converts a three-letter amino-acid code (e.g. `"ALA"`) to its one-letter code.
///
/// Only the 20 standard, upper-case residue names are recognised; any other input
/// (lower-case names, `"UNK"`, the empty string, ...) yields `'X'`.
pub fn aa3to1(aa: &str) -> char {
    #[rustfmt::skip]
    const THREE_TO_ONE: [(&str, char); 20] = [
        ("ALA", 'A'), ("CYS", 'C'), ("ASP", 'D'), ("GLU", 'E'), ("PHE", 'F'),
        ("GLY", 'G'), ("HIS", 'H'), ("ILE", 'I'), ("LYS", 'K'), ("LEU", 'L'),
        ("MET", 'M'), ("ASN", 'N'), ("PRO", 'P'), ("GLN", 'Q'), ("ARG", 'R'),
        ("SER", 'S'), ("THR", 'T'), ("VAL", 'V'), ("TRP", 'W'), ("TYR", 'Y'),
    ];

    THREE_TO_ONE
        .iter()
        .find(|(three, _)| *three == aa)
        .map_or('X', |&(_, one)| one)
}
18
/// Maps a one-letter amino-acid code to its integer id.
///
/// The 20 standard residues map to `0..=19` in the order `ACDEFGHIKLMNPQRSTVWY`.
/// Every other character (including `'X'` and lower-case letters) maps to `20`.
pub fn aa1to_int(aa: char) -> u32 {
    const ORDER: &[u8] = b"ACDEFGHIKLMNPQRSTVWY";

    ORDER
        .iter()
        .position(|&code| char::from(code) == aa)
        .map_or(20, |pos| pos as u32)
}
32
/// Maps an integer residue id back to its one-letter amino-acid code.
///
/// Ids `0..=19` follow the order `ACDEFGHIKLMNPQRSTVWY`. Id `20` and every
/// out-of-range id map to `'X'`.
pub fn int_to_aa1(aa_int: u32) -> char {
    const ORDER: &[u8] = b"ACDEFGHIKLMNPQRSTVWYX";

    usize::try_from(aa_int)
        .ok()
        .and_then(|idx| ORDER.get(idx))
        .map_or('X', |&code| char::from(code))
}
47
/// One-letter amino-acid alphabet in integer-id order.
///
/// The character at index `i` is the code `int_to_aa1(i)` returns for `i` in `0..=20`.
/// `'X'` (index 20) stands for any unknown residue.
// `dead_code` is allowed, presumably because nothing references this table yet.
#[allow(dead_code)]
const ALPHABET: [char; 21] = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W',
    'Y', 'X',
];
53
/// Chemical element symbols ordered by atomic number.
///
/// Index `i` holds the symbol of the element with atomic number `i + 1`
/// (`"H"` at 0 through `"Og"` at 117).
// `dead_code` is allowed, presumably because nothing references this table yet.
#[allow(dead_code)]
const ELEMENT_LIST: [&str; 118] = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl",
    "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As",
    "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In",
    "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb",
    "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl",
    "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk",
    "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh",
    "Fl", "Mc", "Lv", "Ts", "Og",
];
65
/// Heavy-atom names of the 37-slot per-residue atom layout.
///
/// Each named variant's discriminant (0..=36) is its slot index in that layout.
/// `Unknown` (-1) marks an empty slot, e.g. the padding entries in the `atoms14`
/// tables of `Residue`. `Display`, `EnumString` (string parsing) and `EnumIter`
/// are derived via `strum`.
#[rustfmt::skip]
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, EnumIter)]
pub enum AAAtom {
    N = 0, CA = 1, C = 2, CB = 3, O = 4,
    CG = 5, CG1 = 6, CG2 = 7, OG = 8, OG1 = 9,
    SG = 10, CD = 11, CD1 = 12, CD2 = 13, ND1 = 14,
    ND2 = 15, OD1 = 16, OD2 = 17, SD = 18, CE = 19,
    CE1 = 20, CE2 = 21, CE3 = 22, NE = 23, NE1 = 24,
    NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29,
    NH2 = 30, OH = 31, CZ = 32, CZ2 = 33, CZ3 = 34,
    NZ = 35, OXT = 36,
    // Sentinel for "no atom in this slot"; it is not a valid slot index.
    Unknown = -1,
}
79impl AAAtom {
80 pub fn to_index(&self) -> usize {
82 *self as usize
83 }
84}
85
// Generates the `Residue` enum and its lookup methods from a single table.
//
// Each row is `NAME: code3, code1, idx, features, atoms14`:
// - `code3`    three-letter name returned by `code3()`;
// - `code1`    one-letter code returned by `code1()`;
// - `idx`      integer id used by `to_int()` / `from_int()` (also used as a match
//              pattern, so it must be a literal);
// - `features` accepted by the matcher but not emitted anywhere in the expansion;
// - `atoms14`  the 14-slot atom list returned by `atoms14()`.
//
// The table MUST contain a `UNK` row: `from_int` falls back to `Self::UNK`.
macro_rules! define_residues {
    ($($name:ident: $code3:expr_2021, $code1:expr_2021, $idx:expr_2021, $features:expr_2021, $atoms14:expr_2021),* $(,)?) => {
        /// Residue types generated by `define_residues!`.
        #[derive(Debug, Copy, Clone)]
        pub enum Residue {
            $($name),*
        }

        impl Residue {
            /// Three-letter residue code from the table row.
            pub const fn code3(&self) -> &'static str {
                match self {
                    $(Self::$name => $code3),*
                }
            }
            /// One-letter residue code from the table row.
            pub const fn code1(&self) -> char {
                match self {
                    $(Self::$name => $code1),*
                }
            }
            /// The 14 atom slots listed for this residue (padding uses `AAAtom::Unknown`).
            pub const fn atoms14(&self) -> [AAAtom; 14] {
                match self {
                    $(Self::$name => $atoms14),*
                }
            }
            /// Looks up a residue by integer id; ids without a row map to `UNK`.
            pub fn from_int(value: i32) -> Self {
                match value {
                    $($idx => Self::$name,)*
                    _ => Self::UNK
                }
            }
            /// Integer id of this residue from the table row.
            pub fn to_int(&self) -> i32 {
                match self {
                    $(Self::$name => $idx),*
                }
            }
        }
    }
}
123
// Residue table: name, 3-letter code, 1-letter code, integer id, feature pair
// (currently not emitted by the macro), and the 14 atom slots padded with
// `AAAtom::Unknown`. Ids and 1-letter codes follow the same order as `aa1to_int`.
define_residues! {
    ALA: "ALA", 'A', 0, [1.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    CYS: "CYS", 'C', 1, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::SG, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ASP: "ASP", 'D', 2, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::OD1, AAAtom::OD2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLU: "GLU", 'E', 3, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::OE1, AAAtom::OE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    PHE: "PHE", 'F', 4, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE1, AAAtom::CE2, AAAtom::CZ, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLY: "GLY", 'G', 5, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    HIS: "HIS", 'H', 6, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::ND1, AAAtom::CD2, AAAtom::CE1, AAAtom::NE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ILE: "ILE", 'I', 7, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG1, AAAtom::CG2, AAAtom::CD1, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    LYS: "LYS", 'K', 8, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::CE, AAAtom::NZ, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    LEU: "LEU", 'L', 9, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    MET: "MET", 'M', 10, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::SD, AAAtom::CE, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ASN: "ASN", 'N', 11, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::OD1, AAAtom::ND2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    PRO: "PRO", 'P', 12, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLN: "GLN", 'Q', 13, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::OE1, AAAtom::NE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ARG: "ARG", 'R', 14, [0.0, 1.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::NE, AAAtom::CZ, AAAtom::NH1, AAAtom::NH2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    SER: "SER", 'S', 15, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::OG, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    THR: "THR", 'T', 16, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::OG1, AAAtom::CG2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    VAL: "VAL", 'V', 17, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG1, AAAtom::CG2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    TRP: "TRP", 'W', 18, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE2, AAAtom::CE3, AAAtom::NE1, AAAtom::CZ2, AAAtom::CZ3, AAAtom::CH2],
    TYR: "TYR", 'Y', 19, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE1, AAAtom::CE2, AAAtom::CZ, AAAtom::OH, AAAtom::Unknown, AAAtom::Unknown],
    // Fallback row required by `Residue::from_int`.
    UNK: "UNK", 'X', 20, [0.0, 0.0], [AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
}
147
148pub fn get_nearest_neighbours(
149 cb: &Tensor,
150 mask: &Tensor,
151 y: &Tensor,
152 y_t: &Tensor,
153 y_m: &Tensor,
154 number_of_ligand_atoms: i64,
155) -> Result<(Tensor, Tensor, Tensor, Tensor)> {
156 let cb = cb.squeeze(0)?;
158 let mask = mask.squeeze(0)?;
159 let y_m = if y_m.dims().len() > 1 {
160 y_m.sum_keepdim(1)?.squeeze(1)? } else {
162 y_m.clone()
163 };
164 let num_residues = cb.dim(0)?;
165 let mask_cby = mask.unsqueeze(1)?.matmul(&y_m.unsqueeze(0)?)?;
166 let cb_flat = cb.reshape((cb.dim(0)?, 1, 3))?; let y_flat = y.reshape((1, y.dim(0)?, 3))?; let cb_broadcast = cb_flat.broadcast_as((cb.dim(0)?, y.dim(0)?, 3))?; let y_broadcast = y_flat.broadcast_as((cb.dim(0)?, y.dim(0)?, 3))?; let diff = cb_broadcast.sub(&y_broadcast)?;
172 let l2_ab = diff.powf(2.0)?.sum(D::Minus1)?;
173 let complement_mask = (mask_cby.neg()? + 1.0)?;
174 let padding_value = Tensor::full(1000.0_f32, mask_cby.dims(), cb.device())?;
175 let masked_distances = l2_ab.mul(&mask_cby)?;
176 let padding_contribution = complement_mask.mul(&padding_value)?;
177 let l2_ab = masked_distances.add(&padding_contribution)?;
178
179 let nn_idx = l2_ab
181 .arg_sort_last_dim(false)?
182 .narrow(1, 0, number_of_ligand_atoms as usize)?
183 .contiguous()?;
184 let l2_ab_nn = l2_ab.contiguous()?.gather(&nn_idx, 1)?;
185 let d_ab_closest = l2_ab_nn.i((.., 0))?.sqrt()?;
186 let y_new = y
187 .unsqueeze(0)?
188 .expand((num_residues, y.dim(0)?, 3))?
189 .contiguous()?
190 .gather(
191 &nn_idx
192 .unsqueeze(2)?
193 .expand((num_residues, number_of_ligand_atoms as usize, 3))?
194 .contiguous()?,
195 1,
196 )?;
197
198 let y_t_new = y_t
199 .unsqueeze(0)?
200 .expand((num_residues, y_t.dim(0)?))?
201 .contiguous()?
202 .gather(&nn_idx, 1)?;
203
204 let y_m_new = y_m
205 .unsqueeze(0)?
206 .expand((num_residues, y_m.dim(0)?))?
207 .contiguous()?
208 .gather(&nn_idx, 1)?;
209
210 Ok((y_new, y_t_new, y_m_new, d_ab_closest))
211}
212
213pub fn cat_neighbors_nodes(
214 h_nodes: &Tensor,
215 h_neighbors: &Tensor,
216 e_idx: &Tensor,
217) -> Result<Tensor> {
218 let h_nodes_gathered = gather_nodes(h_nodes, e_idx)?;
219 let h_neighbors = h_neighbors.expand((
220 h_neighbors.dim(0)?, h_nodes.dim(1)?, h_neighbors.dim(2)?, h_neighbors.dim(3)?, ))?;
225
226 Tensor::cat(
227 &[h_neighbors, h_nodes_gathered.to_dtype(DType::F32)?],
228 D::Minus1,
229 )
230}
231
232pub fn compute_nearest_neighbors(
235 coords: &Tensor,
236 mask: &Tensor,
237 k: usize,
238 eps: f32,
239) -> Result<(Tensor, Tensor)> {
240 let (_batch_size, seq_len, _) = coords.dims3()?;
241 let mask_2d = mask
245 .unsqueeze(2)?
246 .broadcast_matmul(&mask.unsqueeze(1)?)?
247 .to_dtype(DType::F32)?;
248 let distances = (coords
251 .unsqueeze(2)?
252 .broadcast_sub(&coords.unsqueeze(1)?)?
253 .powf(2.)?
254 .sum(D::Minus1)?
255 + eps as f64)? .sqrt()?
257 .to_dtype(DType::F32)?;
258
259 let masked_distances = (&distances * &mask_2d.to_dtype(DType::F32)?)?;
262 let d_max = masked_distances.max_keepdim(D::Minus1)?;
264 let mask_term = ((&mask_2d.to_dtype(DType::F32)? * -1.0)? + 1.0)?;
265 let d_adjust = (&masked_distances + mask_term.broadcast_mul(&d_max)?)?;
266 let d_adjust = d_adjust.to_dtype(DType::F32)?;
267 Ok(topk_last_dim(&d_adjust, k.min(seq_len))?)
268}
269
270pub fn topk_last_dim(xs: &Tensor, topk: usize) -> Result<(Tensor, Tensor)> {
272 let sorted_indices = xs.arg_sort_last_dim(false)?.to_dtype(DType::U32)?;
273 let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
274 let gathered = xs.gather(&topk_indices, D::Minus1)?;
275 Ok((gathered, topk_indices))
276}
277
278pub fn create_backbone_mask_37(xyz_37: &Tensor) -> Result<Tensor> {
283 let (b, l, rescount, _) = xyz_37.dims4()?;
284 let mut values = vec![0f32; rescount];
286 for &idx in &[0, 1, 2, 4] {
287 if idx < rescount {
288 values[idx] = 1.0;
289 }
290 }
291
292 let base_mask = Tensor::new(values.as_slice(), xyz_37.device())?.to_dtype(DType::F32)?;
294
295 let mask = base_mask.unsqueeze(0)?.unsqueeze(0)?; let mask = mask.broadcast_as((b, l, rescount))?;
298
299 Ok(mask)
300}
301pub fn calculate_cb(xyz_37: &Tensor) -> Result<Tensor> {
303 let (_, dim37, dim3) = xyz_37.dims3()?;
305 assert_eq!(dim37, 37);
306 assert_eq!(dim3, 3);
307
308 let a_coeff = -0.58273431f64;
310 let b_coeff = 0.56802827f64;
311 let c_coeff = -0.54067466f64;
312
313 let n = xyz_37.i((.., 0, ..))?; let ca = xyz_37.i((.., 1, ..))?; let c = xyz_37.i((.., 2, ..))?; let b = (&ca - &n)?; let c = (&c - &ca)?; let b_x = b.i((.., 0))?;
327 let b_y = b.i((.., 1))?;
328 let b_z = b.i((.., 2))?;
329 let c_x = c.i((.., 0))?;
330 let c_y = c.i((.., 1))?;
331 let c_z = c.i((.., 2))?;
332
333 let a_x = ((&b_y * &c_z)? - (&b_z * &c_y)?)?;
334 let a_y = ((&b_z * &c_x)? - (&b_x * &c_z)?)?;
335 let a_z = ((&b_x * &c_y)? - (&b_y * &c_x)?)?;
336
337 let a = Tensor::stack(&[&a_x, &a_y, &a_z], 1)?;
339
340 let cb = (&a * a_coeff)? + (&b * b_coeff)? + (&c * c_coeff)? + &ca;
342
343 Ok(cb?)
344}
345
346pub fn cross_product(a: &Tensor, b: &Tensor) -> Result<Tensor> {
348 let last_dim = a.dims().len() - 1;
349
350 let a0 = a.narrow(last_dim, 0, 1)?;
352 let a1 = a.narrow(last_dim, 1, 1)?;
353 let a2 = a.narrow(last_dim, 2, 1)?;
354
355 let b0 = b.narrow(last_dim, 0, 1)?;
356 let b1 = b.narrow(last_dim, 1, 1)?;
357 let b2 = b.narrow(last_dim, 2, 1)?;
358
359 let c0 = ((&a1 * &b2)? - (&a2 * &b1)?)?;
361 let c1 = ((&a2 * &b0)? - (&a0 * &b2)?)?;
362 let c2 = ((&a0 * &b1)? - (&a1 * &b0)?)?;
363
364 Tensor::cat(&[&c0, &c1, &c2], last_dim)
366}
367
368pub fn gather_edges(edges: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
371 let (d1, d2, d3) = neighbor_idx.dims3()?;
372 let neighbors =
373 neighbor_idx
374 .unsqueeze(D::Minus1)?
375 .expand((d1, d2, d3, edges.dim(D::Minus1)?))?;
376 let edge_gather = edges.gather(&neighbors, 2)?;
377 Ok(edge_gather)
378}
379
380pub fn gather_nodes(nodes: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
385 let (batch_size, n_nodes, n_features) = nodes.dims3()?;
386 let (_, _, k_neighbors) = neighbor_idx.dims3()?;
387 let neighbors_flat = neighbor_idx.reshape((batch_size, n_nodes * k_neighbors))?;
389 let neighbors_flat = neighbors_flat
391 .unsqueeze(2)? .expand((batch_size, n_nodes * k_neighbors, n_features))?; let neighbors_flat = neighbors_flat.contiguous()?;
395 let neighbor_features = nodes.gather(&neighbors_flat, 1)?;
397 neighbor_features.reshape((batch_size, n_nodes, k_neighbors, n_features))
399}
400
401pub fn gather_nodes_t(nodes: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
402 let (d1, d2, d3) = nodes.dims3()?;
404 let idx_flat = neighbor_idx.unsqueeze(D::Minus1)?.expand((d1, d2, d3))?;
405 nodes.gather(&idx_flat, 1)
406}
407
408#[allow(dead_code)]
409fn get_seq_rec(s: &Tensor, s_pred: &Tensor, mask: &Tensor) -> Result<Tensor> {
410 let match_tensor = s.eq(s_pred)?;
417 let match_f32 = match_tensor.to_dtype(DType::F32)?;
418 let numerator = (match_f32 * mask)?.sum_keepdim(1)?;
419 let denominator = mask.sum_keepdim(1)?;
420 let average = numerator.broadcast_div(&denominator)?;
421 average.squeeze(1)
423}
424
425#[allow(dead_code)]
426fn get_score(s: &Tensor, log_probs: &Tensor, mask: &Tensor) -> Result<(Tensor, Tensor)> {
427 let s_one_hot = one_hot(s.clone(), 21, 1., 0.)?;
452 let loss_per_residue = s_one_hot.mul(&log_probs.neg()?)?.sum(D::Minus1)?;
453 let average_loss = loss_per_residue
454 .mul(&mask)?
455 .sum_keepdim(D::Minus1)?
456 .div(&(mask.sum_keepdim(D::Minus1)? + 1e-8f64)?)?
457 .squeeze(D::Minus1)?;
458
459 Ok((average_loss, loss_per_residue))
460}
461
462pub fn linspace(
463 start: f64,
464 stop: f64,
465 steps: usize,
466 device: &Device,
467 return_type: DType,
468) -> Result<Tensor> {
469 if steps == 0 {
470 Tensor::from_vec(Vec::<f64>::new(), steps, device)
471 } else if steps == 1 {
472 Tensor::from_vec(vec![start], steps, device)
473 } else {
474 let delta = (stop - start) / (steps - 1) as f64;
475 let vs = (0..steps)
476 .map(|step| start + step as f64 * delta)
477 .collect::<Vec<_>>();
478 Tensor::from_vec(vs, steps, device)?.to_dtype(return_type)
479 }
480}
481
482pub fn linspace_f32(start: f32, stop: f32, steps: usize, device: &Device) -> Result<Tensor> {
483 match steps {
484 0 => Tensor::from_vec(Vec::<f32>::new(), steps, device),
485 1 => Tensor::from_vec(vec![start], steps, device),
486 _ => {
487 let delta = (stop - start) / (steps - 1) as f32;
488 let vs = (0..steps)
489 .map(|step| start + step as f32 * delta)
490 .collect::<Vec<_>>();
491 Tensor::from_vec(vs, steps, device)
492 }
493 }
494}
495
// Unit tests for the residue tables and tensor featurisation helpers. Structure-based
// tests load the `protein_01` fixture from `ferritin_test_data`.
#[cfg(test)]
mod tests {
    use super::*;
    // Presumably the trait providing the `to_numeric_*` methods used below.
    use crate::StructureFeatures;
    use anyhow::Result;
    use ferritin_core::info::elements::Element;
    use ferritin_core::load_structure;
    use ferritin_test_data::TestFile;

    #[test]
    fn test_residue_codes() {
        let ala = Residue::ALA;
        assert_eq!(ala.code3(), "ALA");
        assert_eq!(ala.code1(), 'A');
        assert_eq!(ala.to_int(), 0);
    }

    #[test]
    fn test_residue_from_int() {
        assert!(matches!(Residue::from_int(0), Residue::ALA));
        assert!(matches!(Residue::from_int(1), Residue::CYS));
        // Ids without a table row fall back to UNK.
        assert!(matches!(Residue::from_int(999), Residue::UNK));
    }

    #[test]
    fn test_residue_atoms() {
        // TRP is the only residue that fills all 14 slots.
        let trp = Residue::TRP;
        let atoms = trp.atoms14();
        assert_eq!(atoms[0], AAAtom::N);
        assert_eq!(atoms[13], AAAtom::CH2);

        // GLY has no side chain, so the CB slot (index 4) is padding.
        let gly = Residue::GLY;
        let atoms = gly.atoms14();
        assert_eq!(atoms[4], AAAtom::Unknown);
    }

    #[test]
    fn test_atom_backbone_tensor() -> Result<()> {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
        let ac = load_structure(pdb_file)?;
        let ac_backbone_tensor: Tensor = ac.to_numeric_backbone_atoms(&device).expect("REASON");
        // (batch, residues, backbone atoms N/CA/C/O, xyz)
        assert_eq!(ac_backbone_tensor.dims(), &[1, 154, 4, 3]);

        // Spot-check first, second and last residues against the fixture coordinates.
        let backbone_coords = [
            ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
            ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
            ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
            ("O", (0, 0, 3, ..), vec![26.748, 9.469, -10.197]),
            ("N", (0, 1, 0, ..), vec![25.964, 11.453, -10.903]),
            ("CA", (0, 1, 1, ..), vec![27.263, 11.924, -11.359]),
            ("C", (0, 1, 2, ..), vec![27.392, 13.428, -11.115]),
            ("O", (0, 1, 3, ..), vec![26.443, 14.184, -11.327]),
            ("N", (0, 153, 0, ..), vec![23.474, -3.227, 5.994]),
            ("CA", (0, 153, 1, ..), vec![22.818, -2.798, 7.211]),
            ("C", (0, 153, 2, ..), vec![22.695, -1.282, 7.219]),
            ("O", (0, 153, 3, ..), vec![21.870, -0.745, 7.992]),
        ];

        for (atom_name, (b, i, j, k), expected) in backbone_coords {
            let actual: Vec<f32> = ac_backbone_tensor.i((b, i, j, k))?.to_vec1()?;
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }
        Ok(())
    }

    #[test]
    fn test_all_atom37_tensor() -> Result<()> {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
        let ac = load_structure(pdb_file)?;
        let ac_backbone_tensor: Tensor = ac.to_numeric_atom37(&device).expect("REASON");
        // (batch, residues, 37 atom slots, xyz)
        assert_eq!(ac_backbone_tensor.dims(), &[1, 154, 37, 3]);

        // All 37 slots of residue 0; slots for atoms it does not have are zero.
        let allatom_coords = [
            ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
            ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
            ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
            ("CB", (0, 0, 3, ..), vec![24.070, 10.495, -8.596]),
            ("O", (0, 0, 4, ..), vec![26.748, 9.469, -10.197]),
            ("CG", (0, 0, 5, ..), vec![24.880, 9.939, -7.442]),
            ("CG1", (0, 0, 6, ..), vec![0.0, 0.0, 0.0]),
            ("CG2", (0, 0, 7, ..), vec![0.0, 0.0, 0.0]),
            ("OG", (0, 0, 8, ..), vec![0.0, 0.0, 0.0]),
            ("OG1", (0, 0, 9, ..), vec![0.0, 0.0, 0.0]),
            ("SG", (0, 0, 10, ..), vec![0.0, 0.0, 0.0]),
            ("CD", (0, 0, 11, ..), vec![0.0, 0.0, 0.0]),
            ("CD1", (0, 0, 12, ..), vec![0.0, 0.0, 0.0]),
            ("CD2", (0, 0, 13, ..), vec![0.0, 0.0, 0.0]),
            ("ND1", (0, 0, 14, ..), vec![0.0, 0.0, 0.0]),
            ("ND2", (0, 0, 15, ..), vec![0.0, 0.0, 0.0]),
            ("OD1", (0, 0, 16, ..), vec![0.0, 0.0, 0.0]),
            ("OD2", (0, 0, 17, ..), vec![0.0, 0.0, 0.0]),
            ("SD", (0, 0, 18, ..), vec![24.262, 10.555, -5.873]),
            ("CE", (0, 0, 19, ..), vec![24.822, 12.266, -5.967]),
            ("CE1", (0, 0, 20, ..), vec![0.0, 0.0, 0.0]),
            ("CE2", (0, 0, 21, ..), vec![0.0, 0.0, 0.0]),
            ("CE3", (0, 0, 22, ..), vec![0.0, 0.0, 0.0]),
            ("NE", (0, 0, 23, ..), vec![0.0, 0.0, 0.0]),
            ("NE1", (0, 0, 24, ..), vec![0.0, 0.0, 0.0]),
            ("NE2", (0, 0, 25, ..), vec![0.0, 0.0, 0.0]),
            ("OE1", (0, 0, 26, ..), vec![0.0, 0.0, 0.0]),
            ("OE2", (0, 0, 27, ..), vec![0.0, 0.0, 0.0]),
            ("CH2", (0, 0, 28, ..), vec![0.0, 0.0, 0.0]),
            ("NH1", (0, 0, 29, ..), vec![0.0, 0.0, 0.0]),
            ("NH2", (0, 0, 30, ..), vec![0.0, 0.0, 0.0]),
            ("OH", (0, 0, 31, ..), vec![0.0, 0.0, 0.0]),
            ("CZ", (0, 0, 32, ..), vec![0.0, 0.0, 0.0]),
            ("CZ2", (0, 0, 33, ..), vec![0.0, 0.0, 0.0]),
            ("CZ3", (0, 0, 34, ..), vec![0.0, 0.0, 0.0]),
            ("NZ", (0, 0, 35, ..), vec![0.0, 0.0, 0.0]),
            ("OXT", (0, 0, 36, ..), vec![0.0, 0.0, 0.0]),
        ];
        for (atom_name, (b, i, j, k), expected) in allatom_coords {
            let actual: Vec<f32> = ac_backbone_tensor.i((b, i, j, k))?.to_vec1()?;
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }
        Ok(())
    }

    #[test]
    fn test_ligand_tensor() -> Result<()> {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
        let ac = load_structure(pdb_file)?;
        let (ligand_coords, ligand_elements, _) =
            ac.to_numeric_ligand_atoms(&device).expect("REASON");
        assert_eq!(ligand_coords.dims(), &[1, 154, 54, 3]);
        // First five ligand atoms as seen from residue 0.
        let allatom_coords = [
            ("S", (0, 0, 0, ..), vec![30.746, 18.706, 28.896]),
            ("O1", (0, 0, 1, ..), vec![30.697, 20.077, 28.620]),
            ("O2", (0, 0, 2, ..), vec![31.104, 18.021, 27.725]),
            ("O3", (0, 0, 3, ..), vec![29.468, 18.179, 29.331]),
            ("O4", (0, 0, 4, ..), vec![31.722, 18.578, 29.881]),
        ];
        // `j` (the trailing `..`) is unused; the index below spells `..` directly.
        for (atom_name, (b, l, i, j), expected) in allatom_coords {
            let actual: Vec<f32> = ligand_coords.i((b, l, i, ..))?.to_vec1()?;
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }

        // Element ids are decoded back to symbols via `Element::new`.
        let elements: Vec<&str> = ligand_elements
            .i((0, 0, ..))?
            .to_vec1::<i64>()?
            .into_iter()
            .map(|elem| Element::new(elem as usize).unwrap().symbol())
            .collect();

        assert_eq!(elements[0], "S");
        assert_eq!(elements[1], "O");
        assert_eq!(elements[2], "O");
        assert_eq!(elements[3], "O");

        Ok(())
    }

    #[test]
    fn test_backbone_tensor() {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
        let ac = load_structure(pdb_file).unwrap();
        let xyz_37 = ac
            .to_numeric_atom37(&device)
            .expect("XYZ creation for all-atoms");
        assert_eq!(xyz_37.dims(), [1, 154, 37, 3]);

        // The mask drops the trailing xyz axis.
        let xyz_m = create_backbone_mask_37(&xyz_37).expect("masking procedure should work");
        assert_eq!(xyz_m.dims(), &[1, 154, 37]);
    }
    #[test]
    fn test_compute_nearest_neighbors() {
        let device = Device::Cpu;
        let test_dtype = DType::F32;

        // Two batches of three collinear points spaced 1.0 apart.
        let coords = Tensor::new(
            &[
                [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
                [[0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]],
            ],
            &device,
        )
        .unwrap()
        .to_dtype(test_dtype)
        .unwrap();

        let mask = Tensor::ones((2, 3), test_dtype, &device).unwrap();

        let (distances, indices) = compute_nearest_neighbors(&coords, &mask, 2, 1e-6).unwrap();

        assert_eq!(distances.dims(), &[2, 3, 2]);
        assert_eq!(indices.dims(), &[2, 3, 2]);
        // NOTE(review): point 1's distance to itself is sqrt(eps) ≈ 1e-3, smaller than
        // its distance (1.0) to points 0 and 2. These expectations hold because
        // `topk_last_dim` keeps the LARGEST distances. If `compute_nearest_neighbors` is
        // meant to return the closest points (self first), update these assertions
        // together with it.
        let point_neighbors: Vec<u32> = indices.i((0, 1, ..)).unwrap().to_vec1().unwrap();
        assert_eq!(point_neighbors, vec![0, 2]);

        let point_distances: Vec<f32> = distances.i((0, 1, ..)).unwrap().to_vec1().unwrap();
        assert!((point_distances[0] - 1.0).abs() < 1e-5);
        assert!((point_distances[1] - 1.0).abs() < 1e-5);
    }
}