// ferritin_plms/featurize/utilities.rs

1use candle_core::{D, DType, Device, IndexOp, Result, Tensor};
2use candle_nn::encoding::one_hot;
3use strum::{Display, EnumIter, EnumString};
4
/// Converts an upper-case three-letter residue code (e.g. `"ALA"`) to its one-letter code.
///
/// The comparison is exact and case-sensitive; any string that is not one of the
/// 20 codes listed below yields `'X'`.
// todo: better utility library
#[rustfmt::skip]
pub fn aa3to1(aa: &str) -> char {
    const CODE_TABLE: [(&str, char); 20] = [
        ("ALA", 'A'), ("CYS", 'C'), ("ASP", 'D'),
        ("GLU", 'E'), ("PHE", 'F'), ("GLY", 'G'),
        ("HIS", 'H'), ("ILE", 'I'), ("LYS", 'K'),
        ("LEU", 'L'), ("MET", 'M'), ("ASN", 'N'),
        ("PRO", 'P'), ("GLN", 'Q'), ("ARG", 'R'),
        ("SER", 'S'), ("THR", 'T'), ("VAL", 'V'),
        ("TRP", 'W'), ("TYR", 'Y'),
    ];
    CODE_TABLE
        .iter()
        .find(|(code3, _)| *code3 == aa)
        .map_or('X', |&(_, code1)| code1)
}
18
/// Maps a one-letter amino-acid code to its integer index.
///
/// The 20 upper-case codes `ACDEFGHIKLMNPQRSTVWY` map to `0..=19` in that order;
/// every other character (including `'X'` and lower-case letters) maps to `20`.
// todo: better utility library
#[rustfmt::skip]
pub fn aa1to_int(aa: char) -> u32 {
    // Position in this string is the integer code.
    const ORDERED_CODES: &str = "ACDEFGHIKLMNPQRSTVWY";
    ORDERED_CODES
        .chars()
        .zip(0u32..)
        .find_map(|(code, idx)| (code == aa).then_some(idx))
        .unwrap_or(20)
}
32
/// Maps an integer index back to its one-letter amino-acid code.
///
/// Indices `0..=19` map to `ACDEFGHIKLMNPQRSTVWY` in that order; `20` and every
/// larger value map to `'X'`.
#[rustfmt::skip]
pub fn int_to_aa1(aa_int: u32) -> char {
    const CODES: [char; 20] = [
        'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
        'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
    ];
    usize::try_from(aa_int)
        .ok()
        .and_then(|idx| CODES.get(idx))
        .copied()
        .unwrap_or('X')
}
47
/// One-letter amino-acid codes; the position of each letter equals the integer code
/// used by `aa1to_int` / `int_to_aa1` (index 20 is the unknown residue `'X'`).
///
/// Not referenced by any other item visible in this file.
const ALPHABET: [char; 21] = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W',
    'Y', 'X',
];
52
/// Symbols of the 118 chemical elements ordered by atomic number: entry `i` is the
/// element with atomic number `i + 1` (`"H"` at index 0, `"Og"` at index 117).
///
/// Not referenced by any other item visible in this file.
const ELEMENT_LIST: [&str; 118] = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl",
    "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As",
    "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In",
    "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb",
    "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl",
    "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk",
    "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh",
    "Fl", "Mc", "Lv", "Ts", "Og",
];
63
/// Named atoms of amino-acid residues.
///
/// The 37 named variants carry explicit discriminants `0..=36`; `Unknown` is `-1`
/// and is used as the filler for unused slots in the `atoms14` tables.
/// Presumably the discriminant is the atom's slot in the 37-atom per-residue
/// coordinate layout (note `CB = 3` precedes `O = 4`) — confirm against the
/// atom37 tensor builder.
///
/// `Display`, `EnumString` and `EnumIter` are derived via `strum`.
#[rustfmt::skip]
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, EnumIter)]
pub enum AAAtom {
    N = 0,    CA = 1,   C = 2,    CB = 3,   O = 4,
    CG = 5,   CG1 = 6,  CG2 = 7,  OG = 8,   OG1 = 9,
    SG = 10,  CD = 11,  CD1 = 12, CD2 = 13, ND1 = 14,
    ND2 = 15, OD1 = 16, OD2 = 17, SD = 18,  CE = 19,
    CE1 = 20, CE2 = 21, CE3 = 22, NE = 23,  NE1 = 24,
    NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29,
    NH2 = 30, OH = 31,  CZ = 32,  CZ2 = 33, CZ3 = 34,
    NZ = 35,  OXT = 36,
    // Sentinel for "no atom"; see the caveat on `to_index`.
    Unknown = -1,
}
77impl AAAtom {
78    // Get numeric value (might still be useful in some contexts)
79    pub fn to_index(&self) -> usize {
80        *self as usize
81    }
82}
83
// Generates the `Residue` enum and its lookup methods from a table of rows of the form
// `NAME: code3, code1, idx, features, atoms14`.
//
// * `NAME`     - enum variant identifier.
// * `code3`    - three-letter code returned by `code3()`.
// * `code1`    - one-letter code returned by `code1()`.
// * `idx`      - integer code used by `from_int()` / `to_int()`; it is spliced into a
//                `match` pattern, so it must be a literal.
// * `features` - accepted by the matcher but not used by any generated item.
// * `atoms14`  - `[AAAtom; 14]` returned by `atoms14()`.
//
// The table must contain a row named `UNK`, because `from_int` falls back to `Self::UNK`.
macro_rules! define_residues {
    ($($name:ident: $code3:expr_2021, $code1:expr_2021, $idx:expr_2021, $features:expr_2021, $atoms14:expr_2021),* $(,)?) => {
        // One variant per table row, in table order.
        #[derive(Debug, Copy, Clone)]
        pub enum Residue {
            $($name),*
        }

        impl Residue {
            // Three-letter code of the residue.
            pub const fn code3(&self) -> &'static str {
                match self {
                    $(Self::$name => $code3),*
                }
            }
            // One-letter code of the residue.
            pub const fn code1(&self) -> char {
                match self {
                    $(Self::$name => $code1),*
                }
            }
            // The row's 14 atom slots.
            pub const fn atoms14(&self) -> [AAAtom; 14] {
                match self {
                    $(Self::$name => $atoms14),*
                }
            }
            // Integer code -> residue; any value not listed in the table yields `UNK`.
            pub fn from_int(value: i32) -> Self {
                match value {
                    $($idx => Self::$name,)*
                    _ => Self::UNK
                }
            }
            // Residue -> integer code (the row's `idx`).
            pub fn to_int(&self) -> i32 {
                match self {
                    $(Self::$name => $idx),*
                }
            }
        }
    }
}
121
// Residue table. Columns: variant, three-letter code, one-letter code, integer code,
// features, 14 atom slots.
//
// * The integer codes follow the same order as `aa1to_int` (A=0 … Y=19, X/UNK=20).
// * Atom slots start with the backbone in the order N, CA, C, O, followed by CB and the
//   side-chain atoms; unused trailing slots are `AAAtom::Unknown`. This differs from the
//   `AAAtom` discriminant order, where CB (3) precedes O (4).
// * The features column is not consumed by the macro; only ALA and ARG carry non-zero
//   entries — presumably placeholders, confirm before relying on them.
// * NOTE(review): TRP lists CE2, CE3 before NE1; the usual PDB atom order for tryptophan
//   is NE1, CE2, CE3 — confirm which order downstream code expects.
define_residues! {
    ALA: "ALA", 'A', 0,  [1.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    CYS: "CYS", 'C', 1,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::SG, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ASP: "ASP", 'D', 2,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::OD1, AAAtom::OD2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLU: "GLU", 'E', 3,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::OE1, AAAtom::OE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    PHE: "PHE", 'F', 4,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE1, AAAtom::CE2, AAAtom::CZ, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLY: "GLY", 'G', 5,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    HIS: "HIS", 'H', 6,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::ND1, AAAtom::CD2, AAAtom::CE1, AAAtom::NE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ILE: "ILE", 'I', 7,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG1, AAAtom::CG2, AAAtom::CD1, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    LYS: "LYS", 'K', 8,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::CE, AAAtom::NZ, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    LEU: "LEU", 'L', 9,  [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    MET: "MET", 'M', 10, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::SD, AAAtom::CE, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ASN: "ASN", 'N', 11, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::OD1, AAAtom::ND2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    PRO: "PRO", 'P', 12, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    GLN: "GLN", 'Q', 13, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::OE1, AAAtom::NE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    ARG: "ARG", 'R', 14, [0.0, 1.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::NE, AAAtom::CZ, AAAtom::NH1, AAAtom::NH2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    SER: "SER", 'S', 15, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::OG, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    THR: "THR", 'T', 16, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::OG1, AAAtom::CG2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    VAL: "VAL", 'V', 17, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG1, AAAtom::CG2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
    TRP: "TRP", 'W', 18, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE2, AAAtom::CE3, AAAtom::NE1, AAAtom::CZ2, AAAtom::CZ3, AAAtom::CH2],
    TYR: "TYR", 'Y', 19, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE1, AAAtom::CE2, AAAtom::CZ, AAAtom::OH, AAAtom::Unknown, AAAtom::Unknown],
    UNK: "UNK", 'X', 20, [0.0, 0.0], [AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown],
}
145
146pub fn get_nearest_neighbours(
147    cb: &Tensor,
148    mask: &Tensor,
149    y: &Tensor,
150    y_t: &Tensor,
151    y_m: &Tensor,
152    number_of_ligand_atoms: i64,
153) -> Result<(Tensor, Tensor, Tensor, Tensor)> {
154    // First, remove batch dimension if present using squeeze(0)
155    let cb = cb.squeeze(0)?;
156    let mask = mask.squeeze(0)?;
157    let y_m = if y_m.dims().len() > 1 {
158        y_m.sum_keepdim(1)?.squeeze(1)? // or .any(1)? depending on your needs
159    } else {
160        y_m.clone()
161    };
162    let num_residues = cb.dim(0)?;
163    let mask_cby = mask.unsqueeze(1)?.matmul(&y_m.unsqueeze(0)?)?;
164    let cb_flat = cb.reshape((cb.dim(0)?, 1, 3))?; // [154, 1, 3]
165    let y_flat = y.reshape((1, y.dim(0)?, 3))?; // [1, 54, 3]
166    // Try broadcasting manually if needed
167    let cb_broadcast = cb_flat.broadcast_as((cb.dim(0)?, y.dim(0)?, 3))?; // [154, 54, 3]
168    let y_broadcast = y_flat.broadcast_as((cb.dim(0)?, y.dim(0)?, 3))?; // [154, 54, 3]
169    let diff = cb_broadcast.sub(&y_broadcast)?;
170    let l2_ab = diff.powf(2.0)?.sum(D::Minus1)?;
171    let complement_mask = (mask_cby.neg()? + 1.0)?;
172    let padding_value = Tensor::full(1000.0_f32, mask_cby.dims(), cb.device())?;
173    let masked_distances = l2_ab.mul(&mask_cby)?;
174    let padding_contribution = complement_mask.mul(&padding_value)?;
175    let l2_ab = masked_distances.add(&padding_contribution)?;
176
177    // Get nearest neighbors
178    let nn_idx = l2_ab
179        .arg_sort_last_dim(false)?
180        .narrow(1, 0, number_of_ligand_atoms as usize)?
181        .contiguous()?;
182    let l2_ab_nn = l2_ab.contiguous()?.gather(&nn_idx, 1)?;
183    let d_ab_closest = l2_ab_nn.i((.., 0))?.sqrt()?;
184    let y_new = y
185        .unsqueeze(0)?
186        .expand((num_residues, y.dim(0)?, 3))?
187        .contiguous()?
188        .gather(
189            &nn_idx
190                .unsqueeze(2)?
191                .expand((num_residues, number_of_ligand_atoms as usize, 3))?
192                .contiguous()?,
193            1,
194        )?;
195
196    let y_t_new = y_t
197        .unsqueeze(0)?
198        .expand((num_residues, y_t.dim(0)?))?
199        .contiguous()?
200        .gather(&nn_idx, 1)?;
201
202    let y_m_new = y_m
203        .unsqueeze(0)?
204        .expand((num_residues, y_m.dim(0)?))?
205        .contiguous()?
206        .gather(&nn_idx, 1)?;
207
208    Ok((y_new, y_t_new, y_m_new, d_ab_closest))
209}
210
211pub fn cat_neighbors_nodes(
212    h_nodes: &Tensor,
213    h_neighbors: &Tensor,
214    e_idx: &Tensor,
215) -> Result<Tensor> {
216    let h_nodes_gathered = gather_nodes(h_nodes, e_idx)?;
217    let h_neighbors = h_neighbors.expand((
218        h_neighbors.dim(0)?, // 1
219        h_nodes.dim(1)?,     // 93
220        h_neighbors.dim(2)?, // 24
221        h_neighbors.dim(3)?, // 128
222    ))?;
223
224    Tensor::cat(
225        &[h_neighbors, h_nodes_gathered.to_dtype(DType::F32)?],
226        D::Minus1,
227    )
228}
229
230/// Retrieve the nearest Neighbor of a set of coordinates.
231/// Usually used for CA carbon distance.
232pub fn compute_nearest_neighbors(
233    coords: &Tensor,
234    mask: &Tensor,
235    k: usize,
236    eps: f32,
237) -> Result<(Tensor, Tensor)> {
238    let (_batch_size, seq_len, _) = coords.dims3()?;
239    // broadcast_matmul handles broadcasting automatically
240    // [2, 3, 1] × [2, 1, 3] -> [2, 3, 3]
241
242    let mask_2d = mask
243        .unsqueeze(2)?
244        .broadcast_matmul(&mask.unsqueeze(1)?)?
245        .to_dtype(DType::F32)?;
246    // Compute pairwise distances with broadcasting
247
248    let distances = (coords
249        .unsqueeze(2)?
250        .broadcast_sub(&coords.unsqueeze(1)?)?
251        .powf(2.)?
252        .sum(D::Minus1)?
253        + eps as f64)? // also  doesn't have add
254        .sqrt()?
255        .to_dtype(DType::F32)?;
256
257    // Apply mask
258    // Get max values for adjustment
259    let masked_distances = (&distances * &mask_2d.to_dtype(DType::F32)?)?;
260    // println!("after masked_distances");
261    let d_max = masked_distances.max_keepdim(D::Minus1)?;
262    let mask_term = ((&mask_2d.to_dtype(DType::F32)? * -1.0)? + 1.0)?;
263    let d_adjust = (&masked_distances + mask_term.broadcast_mul(&d_max)?)?;
264    let d_adjust = d_adjust.to_dtype(DType::F32)?;
265    Ok(topk_last_dim(&d_adjust, k.min(seq_len))?)
266}
267
268// https://github.com/huggingface/candle/pull/2375/files#diff-e4d52a71060a80ac8c549f2daffcee77f9bf4de8252ad067c47b1c383c3ac828R957
269pub fn topk_last_dim(xs: &Tensor, topk: usize) -> Result<(Tensor, Tensor)> {
270    let sorted_indices = xs.arg_sort_last_dim(false)?.to_dtype(DType::U32)?;
271    let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
272    let gathered = xs.gather(&topk_indices, D::Minus1)?;
273    Ok((gathered, topk_indices))
274}
275
276/// Input coords. Output 1 <batch  x 1 > Tensor
277/// representing whether each residue has all 4 backbone atoms.
278/// note that the internal ordering is different between
279/// backbone only [N/CA/C/O] and all-atom [N/CA/C/CB/O]....
280pub fn create_backbone_mask_37(xyz_37: &Tensor) -> Result<Tensor> {
281    let (b, l, rescount, _) = xyz_37.dims4()?;
282    // Create a vector with 1s at positions 0,1,2,4 and 0s elsewhere
283    let mut values = vec![0f32; rescount];
284    for &idx in &[0, 1, 2, 4] {
285        if idx < rescount {
286            values[idx] = 1.0;
287        }
288    }
289
290    // Create the base mask for one sequence, explicitly specifying the data type
291    let base_mask = Tensor::new(values.as_slice(), xyz_37.device())?.to_dtype(DType::F32)?;
292
293    // Create the full mask by repeating it for each batch and length
294    let mask = base_mask.unsqueeze(0)?.unsqueeze(0)?; // Add batch and length dimensions
295    let mask = mask.broadcast_as((b, l, rescount))?;
296
297    Ok(mask)
298}
299/// Get Pseudo CB
300pub fn calculate_cb(xyz_37: &Tensor) -> Result<Tensor> {
301    // make sure we are dealing with
302    let (_, dim37, dim3) = xyz_37.dims3()?;
303    assert_eq!(dim37, 37);
304    assert_eq!(dim3, 3);
305
306    // Constants for CB calculation
307    let a_coeff = -0.58273431f64;
308    let b_coeff = 0.56802827f64;
309    let c_coeff = -0.54067466f64;
310
311    // Get N, CA, C coordinates
312    let n = xyz_37.i((.., 0, ..))?; // N  at index 0
313    let ca = xyz_37.i((.., 1, ..))?; // CA at index 1
314    let c = xyz_37.i((.., 2, ..))?; // C  at index 2
315
316    // Calculate vectors
317    let b = (&ca - &n)?; // CA - N
318    let c = (&c - &ca)?; // C - CA
319
320    // Manual cross product components
321    // a_x = b_y * c_z - b_z * c_y
322    // a_y = b_z * c_x - b_x * c_z
323    // a_z = b_x * c_y - b_y * c_x
324    let b_x = b.i((.., 0))?;
325    let b_y = b.i((.., 1))?;
326    let b_z = b.i((.., 2))?;
327    let c_x = c.i((.., 0))?;
328    let c_y = c.i((.., 1))?;
329    let c_z = c.i((.., 2))?;
330
331    let a_x = ((&b_y * &c_z)? - (&b_z * &c_y)?)?;
332    let a_y = ((&b_z * &c_x)? - (&b_x * &c_z)?)?;
333    let a_z = ((&b_x * &c_y)? - (&b_y * &c_x)?)?;
334
335    // Stack the cross product components back together
336    let a = Tensor::stack(&[&a_x, &a_y, &a_z], 1)?;
337
338    // Final CB calculation: -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
339    let cb = (&a * a_coeff)? + (&b * b_coeff)? + (&c * c_coeff)? + &ca;
340
341    Ok(cb?)
342}
343
344/// Custom Cross-Product Fn.
345pub fn cross_product(a: &Tensor, b: &Tensor) -> Result<Tensor> {
346    let last_dim = a.dims().len() - 1;
347
348    // Extract components
349    let a0 = a.narrow(last_dim, 0, 1)?;
350    let a1 = a.narrow(last_dim, 1, 1)?;
351    let a2 = a.narrow(last_dim, 2, 1)?;
352
353    let b0 = b.narrow(last_dim, 0, 1)?;
354    let b1 = b.narrow(last_dim, 1, 1)?;
355    let b2 = b.narrow(last_dim, 2, 1)?;
356
357    // Compute cross produAAct components
358    let c0 = ((&a1 * &b2)? - (&a2 * &b1)?)?;
359    let c1 = ((&a2 * &b0)? - (&a0 * &b2)?)?;
360    let c2 = ((&a0 * &b1)? - (&a1 * &b0)?)?;
361
362    // Stack the results
363    Tensor::cat(&[&c0, &c1, &c2], last_dim)
364}
365
366/// Gather_edges
367/// Features [B,N,N,C] at Neighbor indices [B,N,K] => Neighbor features [B,N,K,C]
368pub fn gather_edges(edges: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
369    let (d1, d2, d3) = neighbor_idx.dims3()?;
370    let neighbors =
371        neighbor_idx
372            .unsqueeze(D::Minus1)?
373            .expand((d1, d2, d3, edges.dim(D::Minus1)?))?;
374    let edge_gather = edges.gather(&neighbors, 2)?;
375    Ok(edge_gather)
376}
377
378/// Gather Nodes
379///
380/// Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
381/// Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
382pub fn gather_nodes(nodes: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
383    let (batch_size, n_nodes, n_features) = nodes.dims3()?;
384    let (_, _, k_neighbors) = neighbor_idx.dims3()?;
385    // Reshape neighbor_idx to [B, N*K]
386    let neighbors_flat = neighbor_idx.reshape((batch_size, n_nodes * k_neighbors))?;
387    // Add feature dimension and expand
388    let neighbors_flat = neighbors_flat
389        .unsqueeze(2)? // Add feature dimension [B, N*K, 1]
390        .expand((batch_size, n_nodes * k_neighbors, n_features))?; // Expand to [B, N*K, C]
391    // make contiguous for the gather.
392    let neighbors_flat = neighbors_flat.contiguous()?;
393    // Gather features
394    let neighbor_features = nodes.gather(&neighbors_flat, 1)?;
395    // Reshape back to [B, N, K, C]
396    neighbor_features.reshape((batch_size, n_nodes, k_neighbors, n_features))
397}
398
399pub fn gather_nodes_t(nodes: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
400    // Features [B,N,C] at Neighbor index [B,K] => Neighbor features[B,K,C]
401    let (d1, d2, d3) = nodes.dims3()?;
402    let idx_flat = neighbor_idx.unsqueeze(D::Minus1)?.expand((d1, d2, d3))?;
403    nodes.gather(&idx_flat, 1)
404}
405
406fn get_seq_rec(s: &Tensor, s_pred: &Tensor, mask: &Tensor) -> Result<Tensor> {
407    // S: true sequence shape=[batch, length]
408    // S_pred: predicted sequence shape=[batch, length]
409    // mask: mask to compute average over the region shape=[batch, length]
410    // Returns: averaged sequence recovery shape=[batch]
411    //
412    // Compute the match tensor
413    let match_tensor = s.eq(s_pred)?;
414    let match_f32 = match_tensor.to_dtype(DType::F32)?;
415    let numerator = (match_f32 * mask)?.sum_keepdim(1)?;
416    let denominator = mask.sum_keepdim(1)?;
417    let average = numerator.broadcast_div(&denominator)?;
418    // Remove the last dimension to get shape=[batch]
419    average.squeeze(1)
420}
421
422fn get_score(s: &Tensor, log_probs: &Tensor, mask: &Tensor) -> Result<(Tensor, Tensor)> {
423    //     S : true sequence shape=[batch, length]
424    //     log_probs : predicted sequence shape=[batch, length]
425    //     mask : mask to compute average over the region shape=[batch, length]
426    //     average_loss : averaged categorical cross entropy (CCE) [batch]
427    //     loss_per_resdue : per position CCE [batch, length]
428
429    //     """
430    //     S_one_hot = torch.nn.functional.one_hot(S, 21)
431    //     loss_per_residue = -(S_one_hot * log_probs).sum(-1)  # [B, L]
432    //     average_loss = torch.sum(loss_per_residue * mask, dim=-1) / (
433    //         torch.sum(mask, dim=-1) + 1e-8
434    //     )
435    //     return average_loss, loss_per_residue
436
437    // S: true sequence shape=[batch, length]
438    // log_probs: predicted sequence shape=[batch, length, 21]
439    // mask: mask to compute average over the region shape=[batch, length]
440    // Returns:
441    //   - average_loss: averaged categorical cross entropy (CCE) [batch]
442    //   - loss_per_residue: per position CCE [batch, length]
443
444    // Create one-hot encoding of S.
445    // see https://docs.rs/candle-nn/0.7.2/candle_nn/encoding/fn.one_hot.html
446    // this could be wrong...
447    let s_one_hot = one_hot(s.clone(), 21, 1., 0.)?;
448    let loss_per_residue = s_one_hot.mul(&log_probs.neg()?)?.sum(D::Minus1)?;
449    let average_loss = loss_per_residue
450        .mul(&mask)?
451        .sum_keepdim(D::Minus1)?
452        .div(&(mask.sum_keepdim(D::Minus1)? + 1e-8f64)?)?
453        .squeeze(D::Minus1)?;
454
455    Ok((average_loss, loss_per_residue))
456}
457
458pub fn linspace(
459    start: f64,
460    stop: f64,
461    steps: usize,
462    device: &Device,
463    return_type: DType,
464) -> Result<Tensor> {
465    if steps == 0 {
466        Tensor::from_vec(Vec::<f64>::new(), steps, device)
467    } else if steps == 1 {
468        Tensor::from_vec(vec![start], steps, device)
469    } else {
470        let delta = (stop - start) / (steps - 1) as f64;
471        let vs = (0..steps)
472            .map(|step| start + step as f64 * delta)
473            .collect::<Vec<_>>();
474        Tensor::from_vec(vs, steps, device)?.to_dtype(return_type)
475    }
476}
477
478pub fn linspace_f32(start: f32, stop: f32, steps: usize, device: &Device) -> Result<Tensor> {
479    match steps {
480        0 => Tensor::from_vec(Vec::<f32>::new(), steps, device),
481        1 => Tensor::from_vec(vec![start], steps, device),
482        _ => {
483            let delta = (stop - start) / (steps - 1) as f32;
484            let vs = (0..steps)
485                .map(|step| start + step as f32 * delta)
486                .collect::<Vec<_>>();
487            Tensor::from_vec(vs, steps, device)
488        }
489    }
490}
491
492#[cfg(test)]
493mod tests {
494    use super::*;
495    use crate::StructureFeatures;
496    use anyhow::Result;
497    use ferritin_core::info::elements::Element;
498    use ferritin_core::load_structure;
499    use ferritin_test_data::TestFile;
500
501    #[test]
502    fn test_residue_codes() {
503        let ala = Residue::ALA;
504        assert_eq!(ala.code3(), "ALA");
505        assert_eq!(ala.code1(), 'A');
506        assert_eq!(ala.to_int(), 0);
507    }
508
509    #[test]
510    fn test_residue_from_int() {
511        assert!(matches!(Residue::from_int(0), Residue::ALA));
512        assert!(matches!(Residue::from_int(1), Residue::CYS));
513        assert!(matches!(Residue::from_int(999), Residue::UNK));
514    }
515
516    #[test]
517    fn test_residue_atoms() {
518        let trp = Residue::TRP;
519        let atoms = trp.atoms14();
520        assert_eq!(atoms[0], AAAtom::N);
521        assert_eq!(atoms[13], AAAtom::CH2);
522
523        let gly = Residue::GLY;
524        let atoms = gly.atoms14();
525        assert_eq!(atoms[4], AAAtom::Unknown);
526    }
527
    /// Loads the `protein_01` test structure and checks the backbone tensor's shape and
    /// the coordinates of the first, second and last residue.
    #[test]
    fn test_atom_backbone_tensor() -> Result<()> {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
        let ac = load_structure(pdb_file)?;
        let ac_backbone_tensor: Tensor = ac.to_numeric_backbone_atoms(&device).expect("REASON");
        // batch size of 1; 154 residues; 4 backbone atoms (N/CA/C/O); xyz positions
        assert_eq!(ac_backbone_tensor.dims(), &[1, 154, 4, 3]);

        // Check the residue coordinates in the tensor against the source records:
        // ATOM   1    N  N   . MET A 1 1   ? 24.277 8.374   -9.854  1.00 38.41  ? 0   MET A N   1
        // ATOM   2    C  CA  . MET A 1 1   ? 24.404 9.859   -9.939  1.00 37.90  ? 0   MET A CA  1
        // ATOM   3    C  C   . MET A 1 1   ? 25.814 10.249  -10.359 1.00 36.65  ? 0   MET A C   1
        // ATOM   4    O  O   . MET A 1 1   ? 26.748 9.469   -10.197 1.00 37.13  ? 0   MET A O   1
        //
        // Each entry is (atom name, tensor index, expected xyz).
        let backbone_coords = [
            // Methionine - first residue (index 0)
            ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
            ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
            ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
            ("O", (0, 0, 3, ..), vec![26.748, 9.469, -10.197]),
            // Valine - second residue (index 1)
            ("N", (0, 1, 0, ..), vec![25.964, 11.453, -10.903]),
            ("CA", (0, 1, 1, ..), vec![27.263, 11.924, -11.359]),
            ("C", (0, 1, 2, ..), vec![27.392, 13.428, -11.115]),
            ("O", (0, 1, 3, ..), vec![26.443, 14.184, -11.327]),
            // Glycine - last residue (index 153)
            ("N", (0, 153, 0, ..), vec![23.474, -3.227, 5.994]),
            ("CA", (0, 153, 1, ..), vec![22.818, -2.798, 7.211]),
            ("C", (0, 153, 2, ..), vec![22.695, -1.282, 7.219]),
            ("O", (0, 153, 3, ..), vec![21.870, -0.745, 7.992]),
        ];

        // The comparison is exact: the expected literals are inferred as f32, like the
        // values read back from the tensor.
        for (atom_name, (b, i, j, k), expected) in backbone_coords {
            let actual: Vec<f32> = ac_backbone_tensor.i((b, i, j, k))?.to_vec1()?;
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }
        Ok(())
    }
567
    /// Loads the `protein_01` test structure and checks every one of the 37 atom slots of
    /// the first residue (a methionine) in the all-atom tensor.
    #[test]
    fn test_all_atom37_tensor() -> Result<()> {
        let device = Device::Cpu;
        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
        let ac = load_structure(pdb_file)?;
        // Despite its name, this holds the all-atom (37 slots per residue) tensor.
        let ac_backbone_tensor: Tensor = ac.to_numeric_atom37(&device).expect("REASON");
        // batch size of 1; 154 residues; 37 atom slots; xyz positions
        assert_eq!(ac_backbone_tensor.dims(), &[1, 154, 37, 3]);

        // Check the residue coordinates in the tensor against the source records:
        // ATOM   1    N  N   . MET A 1 1   ? 24.277 8.374   -9.854  1.00 38.41  ? 0   MET A N   1
        // ATOM   2    C  CA  . MET A 1 1   ? 24.404 9.859   -9.939  1.00 37.90  ? 0   MET A CA  1
        // ATOM   3    C  C   . MET A 1 1   ? 25.814 10.249  -10.359 1.00 36.65  ? 0   MET A C   1
        // ATOM   4    O  O   . MET A 1 1   ? 26.748 9.469   -10.197 1.00 37.13  ? 0   MET A O   1
        // ATOM   5    C  CB  . MET A 1 1   ? 24.070 10.495  -8.596  1.00 39.58  ? 0   MET A CB  1
        // ATOM   6    C  CG  . MET A 1 1   ? 24.880 9.939   -7.442  1.00 41.49  ? 0   MET A CG  1
        // ATOM   7    S  SD  . MET A 1 1   ? 24.262 10.555  -5.873  1.00 44.70  ? 0   MET A SD  1
        // ATOM   8    C  CE  . MET A 1 1   ? 24.822 12.266  -5.967  1.00 41.59  ? 0   MET A CE  1
        //
        // The slot index of each atom is its `AAAtom` discriminant:
        //     N = 0,    CA = 1,   C = 2,    CB = 3,   O = 4,
        //     CG = 5,   CG1 = 6,  CG2 = 7,  OG = 8,   OG1 = 9,
        //     SG = 10,  CD = 11,  CD1 = 12, CD2 = 13, ND1 = 14,
        //     ND2 = 15, OD1 = 16, OD2 = 17, SD = 18,  CE = 19,
        //     CE1 = 20, CE2 = 21, CE3 = 22, NE = 23,  NE1 = 24,
        //     NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29,
        //     NH2 = 30, OH = 31,  CZ = 32,  CZ2 = 33, CZ3 = 34,
        //     NZ = 35,  OXT = 36
        let allatom_coords = [
            // Methionine - first residue (index 0).
            // We iterate through all 37 slots. Not every amino acid has every atom;
            // slots for atoms that methionine lacks are expected to be all zeros.
            ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
            ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
            ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
            ("CB", (0, 0, 3, ..), vec![24.070, 10.495, -8.596]),
            ("O", (0, 0, 4, ..), vec![26.748, 9.469, -10.197]),
            ("CG", (0, 0, 5, ..), vec![24.880, 9.939, -7.442]),
            ("CG1", (0, 0, 6, ..), vec![0.0, 0.0, 0.0]),
            ("CG2", (0, 0, 7, ..), vec![0.0, 0.0, 0.0]),
            ("OG", (0, 0, 8, ..), vec![0.0, 0.0, 0.0]),
            ("OG1", (0, 0, 9, ..), vec![0.0, 0.0, 0.0]),
            ("SG", (0, 0, 10, ..), vec![0.0, 0.0, 0.0]),
            ("CD", (0, 0, 11, ..), vec![0.0, 0.0, 0.0]),
            ("CD1", (0, 0, 12, ..), vec![0.0, 0.0, 0.0]),
            ("CD2", (0, 0, 13, ..), vec![0.0, 0.0, 0.0]),
            ("ND1", (0, 0, 14, ..), vec![0.0, 0.0, 0.0]),
            ("ND2", (0, 0, 15, ..), vec![0.0, 0.0, 0.0]),
            ("OD1", (0, 0, 16, ..), vec![0.0, 0.0, 0.0]),
            ("OD2", (0, 0, 17, ..), vec![0.0, 0.0, 0.0]),
            ("SD", (0, 0, 18, ..), vec![24.262, 10.555, -5.873]),
            ("CE", (0, 0, 19, ..), vec![24.822, 12.266, -5.967]),
            ("CE1", (0, 0, 20, ..), vec![0.0, 0.0, 0.0]),
            ("CE2", (0, 0, 21, ..), vec![0.0, 0.0, 0.0]),
            ("CE3", (0, 0, 22, ..), vec![0.0, 0.0, 0.0]),
            ("NE", (0, 0, 23, ..), vec![0.0, 0.0, 0.0]),
            ("NE1", (0, 0, 24, ..), vec![0.0, 0.0, 0.0]),
            ("NE2", (0, 0, 25, ..), vec![0.0, 0.0, 0.0]),
            ("OE1", (0, 0, 26, ..), vec![0.0, 0.0, 0.0]),
            ("OE2", (0, 0, 27, ..), vec![0.0, 0.0, 0.0]),
            ("CH2", (0, 0, 28, ..), vec![0.0, 0.0, 0.0]),
            ("NH1", (0, 0, 29, ..), vec![0.0, 0.0, 0.0]),
            ("NH2", (0, 0, 30, ..), vec![0.0, 0.0, 0.0]),
            ("OH", (0, 0, 31, ..), vec![0.0, 0.0, 0.0]),
            ("CZ", (0, 0, 32, ..), vec![0.0, 0.0, 0.0]),
            ("CZ2", (0, 0, 33, ..), vec![0.0, 0.0, 0.0]),
            ("CZ3", (0, 0, 34, ..), vec![0.0, 0.0, 0.0]),
            ("NZ", (0, 0, 35, ..), vec![0.0, 0.0, 0.0]),
            ("OXT", (0, 0, 36, ..), vec![0.0, 0.0, 0.0]),
        ];
        // Exact comparison of each slot's xyz triple.
        for (atom_name, (b, i, j, k), expected) in allatom_coords {
            let actual: Vec<f32> = ac_backbone_tensor
                .i((b, i, j, k))?
                .to_vec1()?;
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }
        Ok(())
    }
646
647    #[test]
648    fn test_ligand_tensor() -> Result<()> {
649        let device = Device::Cpu;
650        let (pdb_file, _temp) = TestFile::protein_01().create_temp()?;
651        let ac = load_structure(pdb_file)?;
652        let (ligand_coords, ligand_elements, _) =
653            ac.to_numeric_ligand_atoms(&device).expect("REASON");
654        // 154 residues; 54 other atoms.
655        assert_eq!(ligand_coords.dims(), &[1, 154, 54, 3]);
656        // Check my residue coords in the Tensor
657        //
658        // HETATM 1222 S  S   . SO4 B 2 .   ? 30.746 18.706  28.896  1.00 47.98  ? 157 SO4 A S   1
659        // HETATM 1223 O  O1  . SO4 B 2 .   ? 30.697 20.077  28.620  1.00 48.06  ? 157 SO4 A O1  1
660        // HETATM 1224 O  O2  . SO4 B 2 .   ? 31.104 18.021  27.725  1.00 47.52  ? 157 SO4 A O2  1
661        // HETATM 1225 O  O3  . SO4 B 2 .   ? 29.468 18.179  29.331  1.00 47.79  ? 157 SO4 A O3  1
662        // HETATM 1226 O  O4  . SO4 B 2 .   ? 31.722 18.578  29.881  1.00 47.85  ? 157 SO4 A O4  1
663        let allatom_coords = [
664            ("S", (0, 0, 0, ..), vec![30.746, 18.706, 28.896]),
665            ("O1", (0, 0, 1, ..), vec![30.697, 20.077, 28.620]),
666            ("O2", (0, 0, 2, ..), vec![31.104, 18.021, 27.725]),
667            ("O3", (0, 0, 3, ..), vec![29.468, 18.179, 29.331]),
668            ("O4", (0, 0, 4, ..), vec![31.722, 18.578, 29.881]),
669        ];
670        for (atom_name, (b, l, i, j), expected) in allatom_coords {
671            let actual: Vec<f32> = ligand_coords.i((b, l, i, ..))?.to_vec1()?;
672            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
673        }
674
675        // Now check the elements
676        let elements: Vec<&str> = ligand_elements
677            .i((0, 0, ..))?
678            .to_vec1::<i64>()?
679            .into_iter()
680            .map(|elem| Element::new(elem as usize).unwrap().symbol())
681            .collect();
682
683        assert_eq!(elements[0], "S");
684        assert_eq!(elements[1], "O");
685        assert_eq!(elements[2], "O");
686        assert_eq!(elements[3], "O");
687
688        Ok(())
689    }
690
691    #[test]
692    fn test_backbone_tensor() {
693        let device = Device::Cpu;
694        let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
695        let ac = load_structure(pdb_file).unwrap();
696        let xyz_37 = ac
697            .to_numeric_atom37(&device)
698            .expect("XYZ creation for all-atoms");
699        assert_eq!(xyz_37.dims(), [1, 154, 37, 3]);
700
701        // # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords
702        let xyz_m = create_backbone_mask_37(&xyz_37).expect("masking procedure should work");
703        assert_eq!(xyz_m.dims(), &[1, 154, 37]);
704    }
    /// Checks output shapes, indices and distances of `compute_nearest_neighbors` on two
    /// short collinear point sequences with an all-ones mask.
    #[test]
    fn test_compute_nearest_neighbors() {
        let device = Device::Cpu;
        let test_dtype = DType::F32;

        // Create a simple 2x3x3 tensor representing 2 sequences of 3 points in 3D space
        let coords = Tensor::new(
            &[
                [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]], // First sequence
                [[0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]], // Second sequence
            ],
            &device,
        )
        .unwrap()
        .to_dtype(test_dtype)
        .unwrap();

        // Create mask indicating all points are valid
        let mask = Tensor::ones((2, 3), test_dtype, &device).unwrap();

        // Request 2 neighbors for each point
        let (distances, indices) = compute_nearest_neighbors(&coords, &mask, 2, 1e-6).unwrap();

        // Check shapes
        assert_eq!(distances.dims(), &[2, 3, 2]); // [batch, seq_len, k]
        assert_eq!(indices.dims(), &[2, 3, 2]); // [batch, seq_len, k]

        // For first sequence, point [1,0,0] should have [0,0,0] and [2,0,0] as neighbors.
        // NOTE(review): the point itself (index 1, distance sqrt(eps)) is absent from the
        // expected result, so these are the two *largest* distances in the row. Both
        // are equal, and the expected order [0, 2] relies on how the underlying sort
        // breaks ties — confirm the intended selection direction.
        let point_neighbors: Vec<u32> = indices.i((0, 1, ..)).unwrap().to_vec1().unwrap();
        assert_eq!(point_neighbors, vec![0, 2]);

        // Check distances are correct (both neighbors are 1.0 away, up to eps)
        let point_distances: Vec<f32> = distances.i((0, 1, ..)).unwrap().to_vec1().unwrap();
        assert!((point_distances[0] - 1.0).abs() < 1e-5);
        assert!((point_distances[1] - 1.0).abs() < 1e-5);
    }
741}