ferritin_onnx_models/models/ligandmpnn/
mod.rs

1//! Module for running Ligand- and Protein-MPNN Models
2//!
3//! This module provides functionality for running LigandMPNN and ProteinMPNN models
4//! to predict amino acid sequences given protein structure coordinates and ligand information.
5//!
6//! The models are loaded from the Hugging Face model hub and executed using ONNX Runtime.
7//!
8//!
9use crate::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64};
10use anyhow::Result;
11use candle_core::Tensor;
12use candle_nn::ops;
13use ferritin_core::AtomCollection;
14use ferritin_plms::device;
15use ferritin_plms::featurize::StructureFeatures;
16use ferritin_plms::featurize::utilities::int_to_aa1;
17use ferritin_plms::types::PseudoProbability;
18use hf_hub::api::sync::Api;
19use ndarray::ArrayBase;
20use ort::{
21    execution_providers::CUDAExecutionProvider,
22    session::{
23        Session,
24        builder::{GraphOptimizationLevel, SessionBuilder},
25    },
26    value::Tensor as OrtTensor,
27};
28use std::path::PathBuf;
29
30type NdArrayF32 = ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>;
31type NdArrayI64 = ArrayBase<ndarray::OwnedRepr<i64>, ndarray::Dim<ndarray::IxDynImpl>>;
32
/// Selects which MPNN ONNX export to use: ProteinMPNN or LigandMPNN.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// ProteinMPNN export hosted at `zcpbx/proteinmpnn-v48-030-onnx`.
    Protein,
    /// LigandMPNN export hosted at `zcpbx/ligandmpnn-v32-030-25-onnx`.
    Ligand,
}

impl ModelType {
    /// Returns the Hugging Face location of this model's ONNX files.
    ///
    /// The tuple is `(repo_id, encoder_file_name, decoder_file_name)`; the file names are
    /// relative to the repository root.
    pub fn get_paths(&self) -> (&'static str, &'static str, &'static str) {
        match self {
            ModelType::Protein => (
                "zcpbx/proteinmpnn-v48-030-onnx",
                "protmpnn_encoder.onnx",
                "protmpnn_decoder_step.onnx",
            ),
            ModelType::Ligand => (
                "zcpbx/ligandmpnn-v32-030-25-onnx",
                "ligand_encoder.onnx",
                "ligand_decoder.onnx",
            ),
        }
    }
}
54
55pub struct LigandMPNN {
56    session: SessionBuilder,
57    encoder_path: PathBuf,
58    decoder_path: PathBuf,
59}
60
61impl LigandMPNN {
62    pub fn new() -> Result<Self> {
63        let session = Self::create_session()?;
64        let (encoder_path, decoder_path) = Self::load_model_paths(ModelType::Ligand)?;
65        Ok(Self {
66            session,
67            encoder_path,
68            decoder_path,
69        })
70    }
71
72    fn create_session() -> Result<SessionBuilder> {
73        ort::init()
74            .with_name("LigandMPNN")
75            .with_execution_providers([CUDAExecutionProvider::default().build()])
76            .commit()?;
77        Ok(Session::builder()?
78            .with_optimization_level(GraphOptimizationLevel::Level1)?
79            .with_intra_threads(1)?)
80    }
81
82    fn load_model_paths(model_type: ModelType) -> Result<(PathBuf, PathBuf)> {
83        let api = Api::new()?;
84        let (repo_id, encoder_name, decoder_name) = model_type.get_paths();
85        Ok((
86            api.model(repo_id.to_string()).get(&encoder_name)?,
87            api.model(repo_id.to_string()).get(&decoder_name)?,
88        ))
89    }
90
91    pub fn run_model(&self, ac: AtomCollection, position: i64, temperature: f32) -> Result<Tensor> {
92        let (h_v, h_e, e_idx) = self.run_encoder(&ac)?;
93        self.run_decoder(h_v, h_e, e_idx, temperature, position)
94    }
95
96    pub fn run_encoder(&self, ac: &AtomCollection) -> Result<(NdArrayF32, NdArrayF32, NdArrayI64)> {
97        let device = device()?;
98        let mut encoder_model = self.session.clone().commit_from_file(&self.encoder_path)?;
99        let x_bb = ac.to_numeric_backbone_atoms(&device)?;
100        let (lig_coords, lig_elements, lig_mask) = ac.to_numeric_ligand_atoms(&device)?;
101
102        let coords_nd = tensor_to_ndarray_f32(x_bb)?;
103        let lig_coords_nd = tensor_to_ndarray_f32(lig_coords)?;
104        let lig_types_nd = tensor_to_ndarray_i64(lig_elements)?;
105        let lig_mask_nd = tensor_to_ndarray_f32(lig_mask)?;
106
107        let encoder_inputs = ort::inputs![
108            "coords" => OrtTensor::from_array(coords_nd)?,
109            "ligand_coords" => OrtTensor::from_array(lig_coords_nd)?,
110            "ligand_types" => OrtTensor::from_array(lig_types_nd)?,
111            "ligand_mask" => OrtTensor::from_array(lig_mask_nd)?
112        ];
113
114        let encoder_outputs = encoder_model.run(encoder_inputs)?;
115        Ok((
116            encoder_outputs["h_V"]
117                .try_extract_array::<f32>()?
118                .to_owned(),
119            encoder_outputs["h_E"]
120                .try_extract_array::<f32>()?
121                .to_owned(),
122            encoder_outputs["E_idx"]
123                .try_extract_array::<i64>()?
124                .to_owned(),
125        ))
126    }
127
128    pub fn run_decoder(
129        &self,
130        h_v: NdArrayF32,
131        h_e: NdArrayF32,
132        e_idx: NdArrayI64,
133        temperature: f32,
134        position: i64,
135    ) -> Result<Tensor> {
136        let mut decoder_model = self.session.clone().commit_from_file(&self.decoder_path)?;
137
138        let position_tensor =
139            OrtTensor::from_array(ndarray::Array::from_shape_vec([1], vec![position])?)?;
140        let temp_tensor =
141            OrtTensor::from_array(ndarray::Array::from_shape_vec([1], vec![temperature])?)?;
142
143        let decoder_inputs = ort::inputs![
144            "h_v" => OrtTensor::from_array(h_v)?,
145            "h_e" => OrtTensor::from_array(h_e)?,
146            "e_idx" => OrtTensor::from_array(e_idx)?,
147            "position" => position_tensor,
148            "temperature" => temp_tensor
149        ];
150
151        let decoder_outputs = decoder_model.run(decoder_inputs)?;
152        let logits = decoder_outputs["logits"]
153            .try_extract_array::<f32>()? // Changed from try_extract_tensor
154            .to_owned();
155        ndarray_to_tensor_f32(logits)
156    }
157
158    pub fn get_single_location(
159        &self,
160        ac: AtomCollection,
161        temp: f32,
162        position: i64,
163    ) -> Result<Vec<PseudoProbability>> {
164        let logits = self.run_model(ac, position, temp)?;
165        let logits = ops::softmax(&logits, 1)?;
166        let logits = logits.get(0)?.to_vec1()?;
167        let mut amino_acid_probs = Vec::new();
168        for i in 0..21 {
169            amino_acid_probs.push(PseudoProbability {
170                amino_acid: int_to_aa1(i),
171                pseudo_prob: logits[i as usize],
172                position: position as usize,
173            });
174        }
175        Ok(amino_acid_probs)
176    }
177
178    pub fn get_all_locations(&self, temp: f32) -> Result<Vec<PseudoProbability>> {
179        todo!()
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use ferritin_core::load_structure;
    use ferritin_test_data::TestFile;

    /// Loads the `protein_01` test structure into an [`AtomCollection`].
    fn setup_test_data() -> AtomCollection {
        // `_handle` is a named binding, so the temp file stays alive until parsing is done.
        let (protfile, _handle) = TestFile::protein_01().create_temp().unwrap();
        load_structure(protfile).unwrap()
    }

    /// Constructing the runner resolves both ONNX files to existing local paths.
    #[test]
    fn test_model_initialization() {
        let model = LigandMPNN::new().unwrap();
        assert!(model.encoder_path.exists());
        assert!(model.decoder_path.exists());
    }

    /// The encoder outputs have the shapes expected for the 154-residue test protein.
    #[test]
    fn test_encoder_output_dimensions() -> Result<()> {
        let model = LigandMPNN::new()?;
        let ac = setup_test_data();
        println!("Data is setup");

        let (h_v, h_e, e_idx) = model.run_encoder(&ac)?;
        println!("h_v shape: {:?}", h_v.shape());
        println!("h_e shape: {:?}", h_e.shape());
        println!("e_idx shape: {:?}", e_idx.shape());

        // NOTE(review): the original author observed `[1, 154, 4, 3]` (rank 4) for this
        // output instead of the expected shape; confirm against the exported encoder graph.
        assert_eq!(h_v.shape(), &[1, 154, 128]);
        assert_eq!(h_e.shape(), &[1, 154, 16, 128]);
        assert_eq!(e_idx.shape(), &[1, 154, 16]);
        Ok(())
    }

    /// End-to-end encoder + decoder run at a single position.
    ///
    /// Kept compiling (rather than commented out) so API drift is caught; run explicitly
    /// with `cargo test -- --ignored`.
    #[test]
    #[ignore = "end-to-end pipeline not yet verified"]
    fn test_full_pipeline() -> Result<()> {
        let model = LigandMPNN::new()?;
        let ac = setup_test_data();
        let logits = model.run_model(ac, 10, 0.1)?;
        assert_eq!(logits.dims2()?, (1, 21));
        Ok(())
    }
}