ferritin_onnx_models/models/ligandmpnn/
mod.rs1use crate::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64};
10use anyhow::{Result, anyhow};
11use candle_core::Tensor;
12use candle_nn::ops;
13use ferritin_core::AtomCollection;
14use ferritin_plms::device;
15use ferritin_plms::featurize::StructureFeatures;
16use ferritin_plms::featurize::utilities::int_to_aa1;
17use ferritin_plms::types::PseudoProbability;
18use hf_hub::api::sync::Api;
19use ndarray::ArrayBase;
20use ort::{
21 execution_providers::CUDAExecutionProvider,
22 session::{
23 Session,
24 builder::{GraphOptimizationLevel, SessionBuilder},
25 },
26 value::Tensor as OrtTensor,
27};
28use std::path::PathBuf;
29
30type NdArrayF32 = ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>;
31type NdArrayI64 = ArrayBase<ndarray::OwnedRepr<i64>, ndarray::Dim<ndarray::IxDynImpl>>;
32
/// Selects which pretrained MPNN ONNX export to fetch from the Hugging Face Hub.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    /// ProteinMPNN export hosted at `zcpbx/proteinmpnn-v48-030-onnx`.
    Protein,
    /// LigandMPNN export hosted at `zcpbx/ligandmpnn-v32-030-25-onnx`.
    Ligand,
}

impl ModelType {
    /// Returns the Hub location of this model variant.
    ///
    /// The tuple is `(repo_id, encoder_file, decoder_file)`. The two file names
    /// are relative to the repository root.
    pub fn get_paths(&self) -> (&'static str, &'static str, &'static str) {
        match self {
            ModelType::Protein => (
                "zcpbx/proteinmpnn-v48-030-onnx",
                "protmpnn_encoder.onnx",
                "protmpnn_decoder_step.onnx",
            ),
            ModelType::Ligand => (
                "zcpbx/ligandmpnn-v32-030-25-onnx",
                "ligand_encoder.onnx",
                "ligand_decoder.onnx",
            ),
        }
    }
}
54
55pub struct LigandMPNN {
56 session: SessionBuilder,
57 encoder_path: PathBuf,
58 decoder_path: PathBuf,
59}
60
61impl LigandMPNN {
62 pub fn new() -> Result<Self> {
63 let session = Self::create_session()?;
64 let (encoder_path, decoder_path) = Self::load_model_paths(ModelType::Ligand)?;
65 Ok(Self {
66 session,
67 encoder_path,
68 decoder_path,
69 })
70 }
71
72 fn create_session() -> Result<SessionBuilder> {
73 ort::init()
74 .with_name("LigandMPNN")
75 .with_execution_providers([CUDAExecutionProvider::default().build()])
76 .commit();
77 Ok(Session::builder()?
78 .with_optimization_level(GraphOptimizationLevel::Level1)
79 .map_err(|e| anyhow!("ORT session config error: {}", e))?
80 .with_intra_threads(1)
81 .map_err(|e| anyhow!("ORT session config error: {}", e))?)
82 }
83
84 fn load_model_paths(model_type: ModelType) -> Result<(PathBuf, PathBuf)> {
85 let api = Api::new()?;
86 let (repo_id, encoder_name, decoder_name) = model_type.get_paths();
87 Ok((
88 api.model(repo_id.to_string()).get(&encoder_name)?,
89 api.model(repo_id.to_string()).get(&decoder_name)?,
90 ))
91 }
92
93 pub fn run_model(&self, ac: AtomCollection, position: i64, temperature: f32) -> Result<Tensor> {
94 let (h_v, h_e, e_idx) = self.run_encoder(&ac)?;
95 self.run_decoder(h_v, h_e, e_idx, temperature, position)
96 }
97
98 pub fn run_encoder(&self, ac: &AtomCollection) -> Result<(NdArrayF32, NdArrayF32, NdArrayI64)> {
99 let device = device(false)?;
100 let mut encoder_model = self.session.clone().commit_from_file(&self.encoder_path)?;
101 let x_bb = ac.to_numeric_backbone_atoms(&device)?;
102 let (lig_coords, lig_elements, lig_mask) = ac.to_numeric_ligand_atoms(&device)?;
103
104 let coords_nd = tensor_to_ndarray_f32(x_bb)?;
105 let lig_coords_nd = tensor_to_ndarray_f32(lig_coords)?;
106 let lig_types_nd = tensor_to_ndarray_i64(lig_elements)?;
107 let lig_mask_nd = tensor_to_ndarray_f32(lig_mask)?;
108
109 let encoder_inputs = ort::inputs![
110 "coords" => OrtTensor::from_array(coords_nd)?,
111 "ligand_coords" => OrtTensor::from_array(lig_coords_nd)?,
112 "ligand_types" => OrtTensor::from_array(lig_types_nd)?,
113 "ligand_mask" => OrtTensor::from_array(lig_mask_nd)?
114 ];
115
116 let encoder_outputs = encoder_model.run(encoder_inputs)?;
117 Ok((
118 encoder_outputs["h_V"]
119 .try_extract_array::<f32>()?
120 .to_owned(),
121 encoder_outputs["h_E"]
122 .try_extract_array::<f32>()?
123 .to_owned(),
124 encoder_outputs["E_idx"]
125 .try_extract_array::<i64>()?
126 .to_owned(),
127 ))
128 }
129
130 pub fn run_decoder(
131 &self,
132 h_v: NdArrayF32,
133 h_e: NdArrayF32,
134 e_idx: NdArrayI64,
135 temperature: f32,
136 position: i64,
137 ) -> Result<Tensor> {
138 let mut decoder_model = self.session.clone().commit_from_file(&self.decoder_path)?;
139
140 let position_tensor =
141 OrtTensor::from_array(ndarray::Array::from_shape_vec([1], vec![position])?)?;
142 let temp_tensor =
143 OrtTensor::from_array(ndarray::Array::from_shape_vec([1], vec![temperature])?)?;
144
145 let decoder_inputs = ort::inputs![
146 "h_v" => OrtTensor::from_array(h_v)?,
147 "h_e" => OrtTensor::from_array(h_e)?,
148 "e_idx" => OrtTensor::from_array(e_idx)?,
149 "position" => position_tensor,
150 "temperature" => temp_tensor
151 ];
152
153 let decoder_outputs = decoder_model.run(decoder_inputs)?;
154 let logits = decoder_outputs["logits"]
155 .try_extract_array::<f32>()? .to_owned();
157 ndarray_to_tensor_f32(logits)
158 }
159
160 pub fn get_single_location(
161 &self,
162 ac: AtomCollection,
163 temp: f32,
164 position: i64,
165 ) -> Result<Vec<PseudoProbability>> {
166 let logits = self.run_model(ac, position, temp)?;
167 let logits = ops::softmax(&logits, 1)?;
168 let logits = logits.get(0)?.to_vec1()?;
169 let mut amino_acid_probs = Vec::new();
170 for i in 0..21 {
171 amino_acid_probs.push(PseudoProbability {
172 amino_acid: int_to_aa1(i),
173 pseudo_prob: logits[i as usize],
174 position: position as usize,
175 });
176 }
177 Ok(amino_acid_probs)
178 }
179
180 pub fn get_all_locations(&self, temp: f32) -> Result<Vec<PseudoProbability>> {
181 todo!()
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use ferritin_core::load_structure;
    use ferritin_test_data::TestFile;

    /// Loads the `protein_01` fixture into an [`AtomCollection`].
    ///
    /// The guard returned alongside the temp-file path stays bound until the
    /// structure has been parsed. NOTE(review): presumably the guard removes
    /// the file on drop.
    fn setup_test_data() -> AtomCollection {
        let (structure_path, _temp_guard) = TestFile::protein_01().create_temp().unwrap();
        load_structure(structure_path).unwrap()
    }

    /// Construction should leave both model files present on disk.
    #[test]
    fn test_model_initialization() {
        let model = LigandMPNN::new().unwrap();
        assert!(model.encoder_path.exists());
        assert!(model.decoder_path.exists());
    }

    /// The encoder outputs for the 154-residue fixture have the expected shapes.
    #[test]
    fn test_encoder_output_dimensions() -> Result<()> {
        let model = LigandMPNN::new()?;
        let atoms = setup_test_data();
        println!("Data is setup");

        let (node_emb, edge_emb, neighbor_idx) = model.run_encoder(&atoms)?;
        for (label, shape) in [
            ("h_v", node_emb.shape()),
            ("h_e", edge_emb.shape()),
            ("e_idx", neighbor_idx.shape()),
        ] {
            println!("{label} shape: {shape:?}");
        }

        assert_eq!(node_emb.shape(), &[1, 154, 128]);
        assert_eq!(edge_emb.shape(), &[1, 154, 16, 128]);
        assert_eq!(neighbor_idx.shape(), &[1, 154, 16]);
        Ok(())
    }
}