ferritin_onnx_models/models/ligandmpnn/
mod.rs1use crate::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64};
10use anyhow::Result;
11use candle_core::{Device, Tensor};
12use candle_nn::ops;
13use ferritin_core::AtomCollection;
14use ferritin_plms::featurize::StructureFeatures;
15use ferritin_plms::featurize::utilities::int_to_aa1;
16use ferritin_plms::types::PseudoProbability;
17use hf_hub::api::sync::Api;
18use ndarray::ArrayBase;
19use ort::{
20 execution_providers::CUDAExecutionProvider,
21 session::{
22 Session,
23 builder::{GraphOptimizationLevel, SessionBuilder},
24 },
25};
26use std::path::PathBuf;
27
28type NdArrayF32 = ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>;
29type NdArrayI64 = ArrayBase<ndarray::OwnedRepr<i64>, ndarray::Dim<ndarray::IxDynImpl>>;
30
/// Selects which set of exported ONNX weights to work with.
///
/// The enum is a plain tag, so it derives the usual value-type traits; this
/// lets callers log it, compare it and pass it around by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// The `zcpbx/proteinmpnn-v48-030-onnx` export.
    Protein,
    /// The `zcpbx/ligandmpnn-v32-030-25-onnx` export.
    Ligand,
}

impl ModelType {
    /// Returns `(repo_id, encoder_file, decoder_file)` for this model type.
    ///
    /// `repo_id` is the repository identifier given to the hub API, and the
    /// two file names are the encoder and decoder ONNX files requested from
    /// that repository.
    pub fn get_paths(&self) -> (&'static str, &'static str, &'static str) {
        match self {
            ModelType::Protein => (
                "zcpbx/proteinmpnn-v48-030-onnx",
                "protmpnn_encoder.onnx",
                "protmpnn_decoder_step.onnx",
            ),
            ModelType::Ligand => (
                "zcpbx/ligandmpnn-v32-030-25-onnx",
                "ligand_encoder.onnx",
                "ligand_decoder.onnx",
            ),
        }
    }
}
52
53pub struct LigandMPNN {
54 session: SessionBuilder,
55 encoder_path: PathBuf,
56 decoder_path: PathBuf,
57}
58
59impl LigandMPNN {
60 pub fn new() -> Result<Self> {
61 let session = Self::create_session()?;
62 let (encoder_path, decoder_path) = Self::load_model_paths(ModelType::Ligand)?;
63 Ok(Self {
64 session,
65 encoder_path,
66 decoder_path,
67 })
68 }
69 fn create_session() -> Result<SessionBuilder> {
70 ort::init()
71 .with_name("LigandMPNN")
72 .with_execution_providers([CUDAExecutionProvider::default().build()])
73 .commit()?;
74 Ok(Session::builder()?
75 .with_optimization_level(GraphOptimizationLevel::Level1)?
76 .with_intra_threads(1)?)
77 }
78 fn load_model_paths(model_type: ModelType) -> Result<(PathBuf, PathBuf)> {
79 let api = Api::new()?;
80 let (repo_id, encoder_name, decoder_name) = model_type.get_paths();
81 Ok((
82 api.model(repo_id.to_string()).get(&encoder_name)?,
83 api.model(repo_id.to_string()).get(&decoder_name)?,
84 ))
85 }
86 pub fn run_model(&self, ac: AtomCollection, position: i64, temperature: f32) -> Result<Tensor> {
87 let (h_v, h_e, e_idx) = self.run_encoder(&ac)?;
88 self.run_decoder(h_v, h_e, e_idx, temperature, position)
89 }
90 pub fn run_encoder(&self, ac: &AtomCollection) -> Result<(NdArrayF32, NdArrayF32, NdArrayI64)> {
91 let device = Device::Cpu;
92 let encoder_model = self.session.clone().commit_from_file(&self.encoder_path)?;
93 let x_bb = ac.to_numeric_backbone_atoms(&device)?;
94 let (lig_coords, lig_elements, lig_mask) = ac.to_numeric_ligand_atoms(&device)?;
95 let coords_nd = tensor_to_ndarray_f32(x_bb)?;
96 let lig_coords_nd = tensor_to_ndarray_f32(lig_coords)?;
97 let lig_types_nd = tensor_to_ndarray_i64(lig_elements)?;
98 let lig_mask_nd = tensor_to_ndarray_f32(lig_mask)?;
99 let encoder_inputs = ort::inputs![
100 "coords" => coords_nd,
101 "ligand_coords" => lig_coords_nd,
102 "ligand_types" => lig_types_nd,
103 "ligand_mask" => lig_mask_nd
104 ]?;
105 let encoder_outputs = encoder_model.run(encoder_inputs)?;
106 Ok((
107 encoder_outputs["h_V"]
108 .try_extract_tensor::<f32>()?
109 .to_owned(),
110 encoder_outputs["h_E"]
111 .try_extract_tensor::<f32>()?
112 .to_owned(),
113 encoder_outputs["E_idx"]
114 .try_extract_tensor::<i64>()?
115 .to_owned(),
116 ))
117 }
118 pub fn run_decoder(
119 &self,
120 h_v: NdArrayF32,
121 h_e: NdArrayF32,
122 e_idx: NdArrayI64,
123 temperature: f32,
124 position: i64,
125 ) -> Result<Tensor> {
126 let decoder_model = self.session.clone().commit_from_file(&self.decoder_path)?;
127 let position_tensor =
128 ort::value::Tensor::from_array(ndarray::Array::from_shape_vec([1], vec![position])?)?;
129 let temp_tensor = ort::value::Tensor::from_array(ndarray::Array::from_shape_vec(
130 [1],
131 vec![temperature],
132 )?)?;
133 let decoder_inputs = ort::inputs![
134 "h_v" => h_v,
135 "h_e" => h_e,
136 "e_idx" => e_idx,
137 "position" => position_tensor,
138 "temperature" => temp_tensor,
139 ]?;
140
141 let decoder_outputs = decoder_model.run(decoder_inputs)?;
142 let logits = decoder_outputs["logits"]
143 .try_extract_tensor::<f32>()?
144 .to_owned();
145 ndarray_to_tensor_f32(logits)
146 }
147 pub fn get_single_location(
148 &self,
149 ac: AtomCollection,
150 temp: f32,
151 position: i64,
152 ) -> Result<Vec<PseudoProbability>> {
153 let logits = self.run_model(ac, position, temp)?;
154 let logits = ops::softmax(&logits, 1)?;
155 let logits = logits.get(0)?.to_vec1()?;
156 let mut amino_acid_probs = Vec::new();
157 for i in 0..21 {
158 amino_acid_probs.push(PseudoProbability {
159 amino_acid: int_to_aa1(i),
160 pseudo_prob: logits[i as usize],
161 position: position as usize,
162 });
163 }
164 Ok(amino_acid_probs)
165 }
166 pub fn get_all_locations(&self, temp: f32) -> Result<Vec<PseudoProbability>> {
167 todo!()
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use ferritin_core::load_structure;
    use ferritin_test_data::TestFile;

    /// Writes the `protein_01` test file to a temporary location and loads it.
    fn setup_test_data() -> AtomCollection {
        // `_handle` stays bound until the function returns, i.e. until after
        // the structure has been loaded.
        let (structure_path, _handle) = TestFile::protein_01().create_temp().unwrap();
        load_structure(structure_path).unwrap()
    }

    /// Constructing the model must leave both ONNX files present on disk.
    #[test]
    fn test_model_initialization() {
        let mpnn = LigandMPNN::new().unwrap();
        assert!(mpnn.encoder_path.exists());
        assert!(mpnn.decoder_path.exists());
    }

    /// The encoder outputs for the test structure must have the expected shapes.
    #[test]
    fn test_encoder_output_dimensions() -> Result<()> {
        let mpnn = LigandMPNN::new()?;
        let atoms = setup_test_data();

        let (h_v, h_e, e_idx) = mpnn.run_encoder(&atoms)?;
        let h_v_shape = h_v.shape();
        let h_e_shape = h_e.shape();
        let e_idx_shape = e_idx.shape();

        println!("h_v shape: {:?}", h_v_shape);
        println!("h_e shape: {:?}", h_e_shape);
        println!("e_idx shape: {:?}", e_idx_shape);

        assert_eq!(h_v_shape, &[1, 154, 128]);
        assert_eq!(h_e_shape, &[1, 154, 16, 128]);
        assert_eq!(e_idx_shape, &[1, 154, 16]);
        Ok(())
    }
}