ferritin_onnx_models/models/ligandmpnn/
mod.rs

1//! Module for running Ligand- and Protein-MPNN Models
2//!
3//! This module provides functionality for running LigandMPNN and ProteinMPNN models
4//! to predict amino acid sequences given protein structure coordinates and ligand information.
5//!
6//! The models are loaded from the Hugging Face model hub and executed using ONNX Runtime.
7//!
8//!
9use crate::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64};
10use anyhow::Result;
11use candle_core::{Device, Tensor};
12use candle_nn::ops;
13use ferritin_core::AtomCollection;
14use ferritin_plms::featurize::StructureFeatures;
15use ferritin_plms::featurize::utilities::int_to_aa1;
16use ferritin_plms::types::PseudoProbability;
17use hf_hub::api::sync::Api;
18use ndarray::ArrayBase;
19use ort::{
20    execution_providers::CUDAExecutionProvider,
21    session::{
22        Session,
23        builder::{GraphOptimizationLevel, SessionBuilder},
24    },
25};
26use std::path::PathBuf;
27
28type NdArrayF32 = ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>;
29type NdArrayI64 = ArrayBase<ndarray::OwnedRepr<i64>, ndarray::Dim<ndarray::IxDynImpl>>;
30
/// Identifies which family of MPNN model files should be used.
///
/// Each variant maps to a Hugging Face repository plus an encoder and a
/// decoder ONNX file name; see `ModelType::get_paths`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// The ProteinMPNN model files.
    Protein,
    /// The LigandMPNN model files.
    Ligand,
}
35
36impl ModelType {
37    pub fn get_paths(&self) -> (&'static str, &'static str, &'static str) {
38        match self {
39            ModelType::Protein => (
40                "zcpbx/proteinmpnn-v48-030-onnx",
41                "protmpnn_encoder.onnx",
42                "protmpnn_decoder_step.onnx",
43            ),
44            ModelType::Ligand => (
45                "zcpbx/ligandmpnn-v32-030-25-onnx",
46                "ligand_encoder.onnx",
47                "ligand_decoder.onnx",
48            ),
49        }
50    }
51}
52
53pub struct LigandMPNN {
54    session: SessionBuilder,
55    encoder_path: PathBuf,
56    decoder_path: PathBuf,
57}
58
59impl LigandMPNN {
60    pub fn new() -> Result<Self> {
61        let session = Self::create_session()?;
62        let (encoder_path, decoder_path) = Self::load_model_paths(ModelType::Ligand)?;
63        Ok(Self {
64            session,
65            encoder_path,
66            decoder_path,
67        })
68    }
69    fn create_session() -> Result<SessionBuilder> {
70        ort::init()
71            .with_name("LigandMPNN")
72            .with_execution_providers([CUDAExecutionProvider::default().build()])
73            .commit()?;
74        Ok(Session::builder()?
75            .with_optimization_level(GraphOptimizationLevel::Level1)?
76            .with_intra_threads(1)?)
77    }
78    fn load_model_paths(model_type: ModelType) -> Result<(PathBuf, PathBuf)> {
79        let api = Api::new()?;
80        let (repo_id, encoder_name, decoder_name) = model_type.get_paths();
81        Ok((
82            api.model(repo_id.to_string()).get(&encoder_name)?,
83            api.model(repo_id.to_string()).get(&decoder_name)?,
84        ))
85    }
86    pub fn run_model(&self, ac: AtomCollection, position: i64, temperature: f32) -> Result<Tensor> {
87        let (h_v, h_e, e_idx) = self.run_encoder(&ac)?;
88        self.run_decoder(h_v, h_e, e_idx, temperature, position)
89    }
90    pub fn run_encoder(&self, ac: &AtomCollection) -> Result<(NdArrayF32, NdArrayF32, NdArrayI64)> {
91        let device = Device::Cpu;
92        let encoder_model = self.session.clone().commit_from_file(&self.encoder_path)?;
93        let x_bb = ac.to_numeric_backbone_atoms(&device)?;
94        let (lig_coords, lig_elements, lig_mask) = ac.to_numeric_ligand_atoms(&device)?;
95        let coords_nd = tensor_to_ndarray_f32(x_bb)?;
96        let lig_coords_nd = tensor_to_ndarray_f32(lig_coords)?;
97        let lig_types_nd = tensor_to_ndarray_i64(lig_elements)?;
98        let lig_mask_nd = tensor_to_ndarray_f32(lig_mask)?;
99        let encoder_inputs = ort::inputs![
100            "coords" => coords_nd,
101            "ligand_coords" => lig_coords_nd,
102            "ligand_types" => lig_types_nd,
103            "ligand_mask" => lig_mask_nd
104        ]?;
105        let encoder_outputs = encoder_model.run(encoder_inputs)?;
106        Ok((
107            encoder_outputs["h_V"]
108                .try_extract_tensor::<f32>()?
109                .to_owned(),
110            encoder_outputs["h_E"]
111                .try_extract_tensor::<f32>()?
112                .to_owned(),
113            encoder_outputs["E_idx"]
114                .try_extract_tensor::<i64>()?
115                .to_owned(),
116        ))
117    }
118    pub fn run_decoder(
119        &self,
120        h_v: NdArrayF32,
121        h_e: NdArrayF32,
122        e_idx: NdArrayI64,
123        temperature: f32,
124        position: i64,
125    ) -> Result<Tensor> {
126        let decoder_model = self.session.clone().commit_from_file(&self.decoder_path)?;
127        let position_tensor =
128            ort::value::Tensor::from_array(ndarray::Array::from_shape_vec([1], vec![position])?)?;
129        let temp_tensor = ort::value::Tensor::from_array(ndarray::Array::from_shape_vec(
130            [1],
131            vec![temperature],
132        )?)?;
133        let decoder_inputs = ort::inputs![
134            "h_v" => h_v,
135            "h_e" => h_e,
136            "e_idx" => e_idx,
137            "position" => position_tensor,
138            "temperature" => temp_tensor,
139        ]?;
140
141        let decoder_outputs = decoder_model.run(decoder_inputs)?;
142        let logits = decoder_outputs["logits"]
143            .try_extract_tensor::<f32>()?
144            .to_owned();
145        ndarray_to_tensor_f32(logits)
146    }
147    pub fn get_single_location(
148        &self,
149        ac: AtomCollection,
150        temp: f32,
151        position: i64,
152    ) -> Result<Vec<PseudoProbability>> {
153        let logits = self.run_model(ac, position, temp)?;
154        let logits = ops::softmax(&logits, 1)?;
155        let logits = logits.get(0)?.to_vec1()?;
156        let mut amino_acid_probs = Vec::new();
157        for i in 0..21 {
158            amino_acid_probs.push(PseudoProbability {
159                amino_acid: int_to_aa1(i),
160                pseudo_prob: logits[i as usize],
161                position: position as usize,
162            });
163        }
164        Ok(amino_acid_probs)
165    }
166    pub fn get_all_locations(&self, temp: f32) -> Result<Vec<PseudoProbability>> {
167        todo!()
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use ferritin_core::load_structure;
    use ferritin_test_data::TestFile;

    /// Loads the `protein_01` test structure into an `AtomCollection`.
    fn setup_test_data() -> AtomCollection {
        // The second tuple element stays bound until the function returns,
        // i.e. until after the structure has been loaded.
        let (structure_path, _temp_guard) = TestFile::protein_01().create_temp().unwrap();
        load_structure(structure_path).unwrap()
    }

    /// Constructing the model must yield encoder and decoder files on disk.
    #[test]
    fn test_model_initialization() {
        let model = LigandMPNN::new().unwrap();
        for model_file in [&model.encoder_path, &model.decoder_path] {
            assert!(model_file.exists());
        }
    }

    /// The encoder outputs must have the expected shapes for `protein_01`.
    #[test]
    fn test_encoder_output_dimensions() -> Result<()> {
        let model = LigandMPNN::new()?;
        let structure = setup_test_data();

        let (node_feats, edge_feats, neighbor_idx) = model.run_encoder(&structure)?;
        println!("h_v shape: {:?}", node_feats.shape());
        println!("h_e shape: {:?}", edge_feats.shape());
        println!("e_idx shape: {:?}", neighbor_idx.shape());

        // NOTE(review): an earlier note here read "getting: 4 ([1, 154, 4, 3])",
        // which suggests a rank-4 shape was observed for h_v at some point.
        // Confirm this assertion passes against the published model.
        assert_eq!(node_feats.shape(), &[1, 154, 128]);
        assert_eq!(edge_feats.shape(), &[1, 154, 16, 128]);
        assert_eq!(neighbor_idx.shape(), &[1, 154, 16]);
        Ok(())
    }

    // Disabled end-to-end test, kept for reference until it is re-enabled:
    //
    // #[test]
    // fn test_full_pipeline() -> Result<()> {
    //     let model = LigandMPNN::new().unwrap();
    //     let ac = setup_test_data();
    //     let logits = model.run_model(ac, 10, 0.1).unwrap();
    //     assert_eq!(logits.dims2().unwrap(), (1, 21));
    //     Ok(())
    // }
}