Skip to main content

ferritin_onnx_models/models/ligandmpnn/
mod.rs

1//! Module for running Ligand- and Protein-MPNN Models
2//!
3//! This module provides functionality for running LigandMPNN and ProteinMPNN models
4//! to predict amino acid sequences given protein structure coordinates and ligand information.
5//!
6//! The models are loaded from the Hugging Face model hub and executed using ONNX Runtime.
7//!
8//!
9use crate::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64};
10use anyhow::{Result, anyhow};
11use candle_core::Tensor;
12use candle_nn::ops;
13use ferritin_core::AtomCollection;
14use ferritin_plms::device;
15use ferritin_plms::featurize::StructureFeatures;
16use ferritin_plms::featurize::utilities::int_to_aa1;
17use ferritin_plms::types::PseudoProbability;
18use hf_hub::api::sync::Api;
19use ndarray::ArrayBase;
20use ort::{
21    execution_providers::CUDAExecutionProvider,
22    session::{
23        Session,
24        builder::{GraphOptimizationLevel, SessionBuilder},
25    },
26    value::Tensor as OrtTensor,
27};
28use std::path::PathBuf;
29
30type NdArrayF32 = ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>;
31type NdArrayI64 = ArrayBase<ndarray::OwnedRepr<i64>, ndarray::Dim<ndarray::IxDynImpl>>;
32
/// Selects which MPNN model family to work with.
///
/// Each variant maps to a model repository id and a pair of ONNX file names
/// through [`ModelType::get_paths`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// The ProteinMPNN encoder/decoder pair.
    Protein,
    /// The LigandMPNN encoder/decoder pair.
    Ligand,
}
37
38impl ModelType {
39    pub fn get_paths(&self) -> (&'static str, &'static str, &'static str) {
40        match self {
41            ModelType::Protein => (
42                "zcpbx/proteinmpnn-v48-030-onnx",
43                "protmpnn_encoder.onnx",
44                "protmpnn_decoder_step.onnx",
45            ),
46            ModelType::Ligand => (
47                "zcpbx/ligandmpnn-v32-030-25-onnx",
48                "ligand_encoder.onnx",
49                "ligand_decoder.onnx",
50            ),
51        }
52    }
53}
54
55pub struct LigandMPNN {
56    session: SessionBuilder,
57    encoder_path: PathBuf,
58    decoder_path: PathBuf,
59}
60
61impl LigandMPNN {
62    pub fn new() -> Result<Self> {
63        let session = Self::create_session()?;
64        let (encoder_path, decoder_path) = Self::load_model_paths(ModelType::Ligand)?;
65        Ok(Self {
66            session,
67            encoder_path,
68            decoder_path,
69        })
70    }
71
72    fn create_session() -> Result<SessionBuilder> {
73        ort::init()
74            .with_name("LigandMPNN")
75            .with_execution_providers([CUDAExecutionProvider::default().build()])
76            .commit();
77        Ok(Session::builder()?
78            .with_optimization_level(GraphOptimizationLevel::Level1)
79            .map_err(|e| anyhow!("ORT session config error: {}", e))?
80            .with_intra_threads(1)
81            .map_err(|e| anyhow!("ORT session config error: {}", e))?)
82    }
83
84    fn load_model_paths(model_type: ModelType) -> Result<(PathBuf, PathBuf)> {
85        let api = Api::new()?;
86        let (repo_id, encoder_name, decoder_name) = model_type.get_paths();
87        Ok((
88            api.model(repo_id.to_string()).get(&encoder_name)?,
89            api.model(repo_id.to_string()).get(&decoder_name)?,
90        ))
91    }
92
93    pub fn run_model(&self, ac: AtomCollection, position: i64, temperature: f32) -> Result<Tensor> {
94        let (h_v, h_e, e_idx) = self.run_encoder(&ac)?;
95        self.run_decoder(h_v, h_e, e_idx, temperature, position)
96    }
97
98    pub fn run_encoder(&self, ac: &AtomCollection) -> Result<(NdArrayF32, NdArrayF32, NdArrayI64)> {
99        let device = device(false)?;
100        let mut encoder_model = self.session.clone().commit_from_file(&self.encoder_path)?;
101        let x_bb = ac.to_numeric_backbone_atoms(&device)?;
102        let (lig_coords, lig_elements, lig_mask) = ac.to_numeric_ligand_atoms(&device)?;
103
104        let coords_nd = tensor_to_ndarray_f32(x_bb)?;
105        let lig_coords_nd = tensor_to_ndarray_f32(lig_coords)?;
106        let lig_types_nd = tensor_to_ndarray_i64(lig_elements)?;
107        let lig_mask_nd = tensor_to_ndarray_f32(lig_mask)?;
108
109        let encoder_inputs = ort::inputs![
110            "coords" => OrtTensor::from_array(coords_nd)?,
111            "ligand_coords" => OrtTensor::from_array(lig_coords_nd)?,
112            "ligand_types" => OrtTensor::from_array(lig_types_nd)?,
113            "ligand_mask" => OrtTensor::from_array(lig_mask_nd)?
114        ];
115
116        let encoder_outputs = encoder_model.run(encoder_inputs)?;
117        Ok((
118            encoder_outputs["h_V"]
119                .try_extract_array::<f32>()?
120                .to_owned(),
121            encoder_outputs["h_E"]
122                .try_extract_array::<f32>()?
123                .to_owned(),
124            encoder_outputs["E_idx"]
125                .try_extract_array::<i64>()?
126                .to_owned(),
127        ))
128    }
129
130    pub fn run_decoder(
131        &self,
132        h_v: NdArrayF32,
133        h_e: NdArrayF32,
134        e_idx: NdArrayI64,
135        temperature: f32,
136        position: i64,
137    ) -> Result<Tensor> {
138        let mut decoder_model = self.session.clone().commit_from_file(&self.decoder_path)?;
139
140        let position_tensor =
141            OrtTensor::from_array(ndarray::Array::from_shape_vec([1], vec![position])?)?;
142        let temp_tensor =
143            OrtTensor::from_array(ndarray::Array::from_shape_vec([1], vec![temperature])?)?;
144
145        let decoder_inputs = ort::inputs![
146            "h_v" => OrtTensor::from_array(h_v)?,
147            "h_e" => OrtTensor::from_array(h_e)?,
148            "e_idx" => OrtTensor::from_array(e_idx)?,
149            "position" => position_tensor,
150            "temperature" => temp_tensor
151        ];
152
153        let decoder_outputs = decoder_model.run(decoder_inputs)?;
154        let logits = decoder_outputs["logits"]
155            .try_extract_array::<f32>()? // Changed from try_extract_tensor
156            .to_owned();
157        ndarray_to_tensor_f32(logits)
158    }
159
160    pub fn get_single_location(
161        &self,
162        ac: AtomCollection,
163        temp: f32,
164        position: i64,
165    ) -> Result<Vec<PseudoProbability>> {
166        let logits = self.run_model(ac, position, temp)?;
167        let logits = ops::softmax(&logits, 1)?;
168        let logits = logits.get(0)?.to_vec1()?;
169        let mut amino_acid_probs = Vec::new();
170        for i in 0..21 {
171            amino_acid_probs.push(PseudoProbability {
172                amino_acid: int_to_aa1(i),
173                pseudo_prob: logits[i as usize],
174                position: position as usize,
175            });
176        }
177        Ok(amino_acid_probs)
178    }
179
180    pub fn get_all_locations(&self, temp: f32) -> Result<Vec<PseudoProbability>> {
181        todo!()
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use ferritin_core::load_structure;
    use ferritin_test_data::TestFile;

    /// Loads the `protein_01` test fixture into an [`AtomCollection`].
    fn setup_test_data() -> AtomCollection {
        // The second tuple element stays bound until the function returns;
        // presumably it keeps the temporary file alive while it is parsed.
        let (structure_path, _temp_guard) = TestFile::protein_01().create_temp().unwrap();
        load_structure(structure_path).unwrap()
    }

    /// Both model files must exist locally once the runner has been constructed.
    #[test]
    fn test_model_initialization() {
        let runner = LigandMPNN::new().unwrap();
        assert!(runner.encoder_path.exists());
        assert!(runner.decoder_path.exists());
    }

    /// Checks the shapes of the three encoder outputs for the test structure.
    #[test]
    fn test_encoder_output_dimensions() -> Result<()> {
        let runner = LigandMPNN::new()?;
        let atoms = setup_test_data();
        println!("Data is setup");

        let (node_features, edge_features, neighbor_idx) = runner.run_encoder(&atoms)?;
        println!("h_v shape: {:?}", node_features.shape());
        println!("h_e shape: {:?}", edge_features.shape());
        println!("e_idx shape: {:?}", neighbor_idx.shape());

        // NOTE(review): an earlier remark on the first assertion recorded an
        // observed shape of [1, 154, 4, 3]; confirm the encoder output against
        // the expectations below.
        assert_eq!(node_features.shape(), &[1, 154, 128]);
        assert_eq!(edge_features.shape(), &[1, 154, 16, 128]);
        assert_eq!(neighbor_idx.shape(), &[1, 154, 16]);
        Ok(())
    }

    // End-to-end test, currently disabled:
    //
    // #[test]
    // fn test_full_pipeline() -> Result<()> {
    //     let model = LigandMPNN::new().unwrap();
    //     let ac = setup_test_data();
    //     let logits = model.run_model(ac, 10, 0.1).unwrap();
    //     assert_eq!(logits.dims2().unwrap(), (1, 21));
    //     Ok(())
    // }
}