ferritin_core/
atomcollection.rs

//! AtomCollection
//!
//! An AtomCollection is primarily a group of atoms with per-atom properties such as coordinates,
//! element type, and residue information. Additional data, such as bonds, can be added after
//! construction. The residues and chains within a collection can be iterated over. Queries such
//! as inter-atomic distances are planned but not yet implemented.
7use super::bonds::Bond;
8use super::info::constants::get_bonds_canonical20;
9use super::views::chain::ChainView;
10use super::views::residue::ResidueView;
11use crate::info::elements::Element;
12use itertools::{Itertools, izip};
13
14/// Atom Collection
15///
16/// The core data structure of ferritin-core.
17///
18/// it strives to be simple, high performance, and extensible using
19/// traits.
20#[derive(Clone)]
21pub struct AtomCollection {
22    size: usize,
23    coords: Vec<[f32; 3]>,
24    res_ids: Vec<i32>,
25    res_names: Vec<String>,
26    is_hetero: Vec<bool>,
27    elements: Vec<Element>,
28    atom_names: Vec<String>,
29    chain_ids: Vec<String>,
30    bonds: Option<Vec<Bond>>,
31    residue_start_indices: Option<Vec<i32>>,
32    chain_start_indices: Option<Vec<i32>>,
33}
34
35impl AtomCollection {
36    pub fn new(
37        size: usize,
38        coords: Vec<[f32; 3]>,
39        res_ids: Vec<i32>,
40        res_names: Vec<String>,
41        is_hetero: Vec<bool>,
42        elements: Vec<Element>,
43        atom_names: Vec<String>,
44        chain_ids: Vec<String>,
45        bonds: Option<Vec<Bond>>,
46    ) -> Self {
47        let mut ac = AtomCollection {
48            size,
49            coords,
50            res_ids,
51            res_names,
52            is_hetero,
53            elements,
54            atom_names,
55            chain_ids,
56            bonds,
57            residue_start_indices: None,
58            chain_start_indices: None,
59        };
60        ac.calculate_chain_indices();
61        ac
62    }
63    // Calculate and cache chain start indices
64    pub fn calculate_chain_indices(&mut self) {
65        if self.chain_start_indices.is_none() {
66            if self.residue_start_indices.is_none() {
67                let residue_starts = self.get_residue_starts();
68                self.residue_start_indices =
69                    Some(residue_starts.iter().map(|&idx| idx as i32).collect());
70            }
71
72            // Get chain starts as residue indices
73            let residue_starts = self.residue_start_indices.as_ref().unwrap();
74            let chain_starts: Vec<i32> = self
75                .get_chain_starts()
76                .iter()
77                .map(|&atom_idx| {
78                    // Find the residue index that contains this atom
79                    let residue_idx = residue_starts
80                        .iter()
81                        .enumerate()
82                        .filter(|&(_, &res_start)| res_start as usize <= atom_idx)
83                        .last()
84                        .map(|(i, _)| i as i32)
85                        .unwrap_or(0);
86                    residue_idx
87                })
88                .collect();
89
90            self.chain_start_indices = Some(chain_starts);
91        }
92    }
93    pub fn calculate_displacement(&self) {
94        // Measure the displacement vector, i.e. the vector difference, from
95        // one array of atom coordinates to another array of coordinates.
96        unimplemented!()
97    }
98    pub fn calculate_distance(&self, _atoms: AtomCollection) {
99        // def distance(atoms1, atoms2, box=None):
100        // """
101        // Measure the euclidian distance between atoms.
102
103        // Parameters
104        // ----------
105        // atoms1, atoms2 : ndarray or Atom or AtomArray or AtomArrayStack
106        //     The atoms to measure the distances between.
107        //     The dimensions may vary.
108        //     Alternatively, a ndarray containing the coordinates can be
109        //     provided.
110        //     Usual *NumPy* broadcasting rules apply.
111        // box : ndarray, shape=(3,3) or shape=(m,3,3), optional
112        //     If this parameter is set, periodic boundary conditions are
113        //     taken into account (minimum-image convention), based on
114        //     the box vectors given with this parameter.
115        //     The shape *(m,3,3)* is only allowed, when the input coordinates
116        //     comprise multiple models.
117
118        // Returns
119        // -------
120        // dist : float or ndarray
121        //     The atom distances.
122        //     The shape is equal to the shape of the input `atoms` with the
123        //     highest dimensionality minus the last axis.
124
125        // See also
126        // --------
127        // index_distance
128        // """
129        // diff = displacement(atoms1, atoms2, box)
130        // return np.sqrt(vector_dot(diff, diff))
131        unimplemented!()
132    }
133    pub fn connect_via_residue_names(&mut self) {
134        if self.bonds.is_some() {
135            println!("Bonds already in place. Not overwriting.");
136            return;
137        }
138        let aa_bond_info = get_bonds_canonical20();
139        let residue_starts = self.get_residue_starts();
140        let mut bonds = Vec::new();
141        for res_i in 0..residue_starts.len() - 1 {
142            let curr_start_i = residue_starts[res_i] as usize;
143            let next_start_i = residue_starts[res_i + 1] as usize;
144            if let Some(bond_dict_for_res) =
145                aa_bond_info.get(&self.res_names[curr_start_i].as_str())
146            {
147                for &(atom_name1, atom_name2, bond_type) in bond_dict_for_res {
148                    let atom_indices1: Vec<usize> = (curr_start_i..next_start_i)
149                        .filter(|&i| self.atom_names[i] == atom_name1)
150                        .collect();
151                    let atom_indices2: Vec<usize> = (curr_start_i..next_start_i)
152                        .filter(|&i| self.atom_names[i] == atom_name2)
153                        .collect();
154                    // Create all possible bond combinations
155                    for &i in &atom_indices1 {
156                        for &j in &atom_indices2 {
157                            bonds.push(Bond::new(i as i32, j as i32, bond_type));
158                        }
159                    }
160                }
161            }
162        }
163        self.bonds = Some(bonds);
164    }
165    pub fn connect_via_distance(&self) -> Vec<Bond> {
166        // note: was intending to follow Biotite's algo
167        unimplemented!()
168    }
169    pub fn get_size(&self) -> usize {
170        self.size
171    }
172    pub fn get_atom_name(&self, idx: usize) -> &String {
173        &self.atom_names[idx]
174    }
175    pub fn get_bonds(&self) -> Option<&Vec<Bond>> {
176        self.bonds.as_ref()
177    }
178    pub fn get_chain_id(&self, idx: usize) -> &String {
179        &self.chain_ids[idx]
180    }
181    pub fn get_coord(&self, idx: usize) -> &[f32; 3] {
182        &self.coords[idx]
183    }
184    pub fn get_coords(&self) -> &Vec<[f32; 3]> {
185        self.coords.as_ref()
186    }
187    pub fn get_element(&self, idx: usize) -> &Element {
188        &self.elements[idx]
189    }
190    pub fn get_elements(&self) -> &Vec<Element> {
191        self.elements.as_ref()
192    }
193    pub fn get_is_hetero(&self, idx: usize) -> bool {
194        self.is_hetero[idx]
195    }
196    pub fn get_resnames(&self) -> &Vec<String> {
197        self.res_names.as_ref()
198    }
199    pub fn get_res_id(&self, idx: usize) -> &i32 {
200        &self.res_ids[idx]
201    }
202    pub fn get_resids(&self) -> &Vec<i32> {
203        self.res_ids.as_ref()
204    }
205    pub fn get_res_name(&self, idx: usize) -> &String {
206        &self.res_names[idx]
207    }
208    /// A new residue starts, either when the chain ID, residue ID,
209    /// insertion code or residue name changes from one to the next atom.
210    fn get_residue_starts(&self) -> Vec<i64> {
211        let mut starts = vec![0];
212
213        starts.extend(
214            izip!(&self.res_ids, &self.res_names, &self.chain_ids)
215                .tuple_windows()
216                .enumerate()
217                .filter_map(
218                    |(i, ((res_id1, name1, chain1), (res_id2, name2, chain2)))| {
219                        if res_id1 != res_id2 || name1 != name2 || chain1 != chain2 {
220                            Some((i + 1) as i64)
221                        } else {
222                            None
223                        }
224                    },
225                ),
226        );
227        starts
228    }
229    pub fn get_residue_start_indices(&self) -> Option<&Vec<i32>> {
230        self.residue_start_indices.as_ref()
231    }
232    /// A new chain starts when the chain ID changes from one atom to the next.
233    fn get_chain_starts(&self) -> Vec<usize> {
234        let mut starts = vec![0];
235        starts.extend(
236            self.chain_ids
237                .iter()
238                .tuple_windows()
239                .enumerate()
240                .filter_map(
241                    |(i, (chain1, chain2))| {
242                        if chain1 != chain2 { Some(i + 1) } else { None }
243                    },
244                ),
245        );
246        starts
247    }
248
249    pub fn iter_coords_and_elements(&self) -> impl Iterator<Item = (&[f32; 3], &Element)> {
250        izip!(&self.coords, &self.elements)
251    }
252
253    pub fn iter_chains(&self) -> impl Iterator<Item = ChainView<'_>> {
254        // Make sure indices are calculated
255        let chain_starts = match &self.chain_start_indices {
256            Some(indices) => indices.clone(),
257            None => Vec::new(),
258        };
259
260        (0..chain_starts.len()).map(move |i| {
261            let start_residue_idx = chain_starts[i] as usize;
262            let end_residue_idx = if i + 1 < chain_starts.len() {
263                chain_starts[i + 1] as usize
264            } else {
265                // If it's the last chain, go to the end of the structure
266                match &self.residue_start_indices {
267                    Some(indices) => indices.len(),
268                    None => self.size,
269                }
270            };
271
272            ChainView {
273                data: self,
274                start_residue_idx,
275                end_residue_idx,
276            }
277        })
278    }
279    pub fn iter_residues(&self) -> impl Iterator<Item = ResidueView<'_>> {
280        let residue_starts = self.get_residue_starts();
281        let atom_starts: Vec<usize> = residue_starts.iter().map(|&idx| idx as usize).collect();
282        let atom_size = self.get_size();
283        // Create a copy of the last element if it exists
284        // Generate pairs for all residues
285        let last_atom_idx = atom_starts.last().copied();
286        (0..atom_starts.len().saturating_sub(1))
287            .map(move |i| ResidueView::new(self, atom_starts[i], atom_starts[i + 1]))
288            .chain(
289                last_atom_idx
290                    .map(|idx| ResidueView::new(self, idx, atom_size))
291                    .into_iter(),
292            )
293    }
294    /// Iterates over amino acid residues in the collection
295    ///
296    /// Returns a filtered iterator that only includes standard amino acid residues
297    pub fn iter_residues_aminoacid(&self) -> impl Iterator<Item = ResidueView<'_>> {
298        self.iter_residues()
299            .filter(|residue| residue.is_amino_acid())
300    }
301}