ferritin_onnx_models/models/ligandmpnn/
mod.rs1use crate::{ndarray_to_tensor_f32, tensor_to_ndarray_f32, tensor_to_ndarray_i64};
10use anyhow::Result;
11use candle_core::Tensor;
12use candle_nn::ops;
13use ferritin_core::AtomCollection;
14use ferritin_plms::device;
15use ferritin_plms::featurize::StructureFeatures;
16use ferritin_plms::featurize::utilities::int_to_aa1;
17use ferritin_plms::types::PseudoProbability;
18use hf_hub::api::sync::Api;
19use ndarray::ArrayBase;
20use ort::{
21 execution_providers::CUDAExecutionProvider,
22 session::{
23 Session,
24 builder::{GraphOptimizationLevel, SessionBuilder},
25 },
26 value::Tensor as OrtTensor,
27};
28use std::path::PathBuf;
29
30type NdArrayF32 = ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>;
31type NdArrayI64 = ArrayBase<ndarray::OwnedRepr<i64>, ndarray::Dim<ndarray::IxDynImpl>>;
32
/// Selects which pre-exported ONNX model family to load.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// The ProteinMPNN export.
    Protein,
    /// The LigandMPNN export.
    Ligand,
}

impl ModelType {
    /// Returns `(repo_id, encoder_file, decoder_file)` for this model type.
    ///
    /// The first element is the model repository identifier; the other two are
    /// the file names of the encoder and decoder ONNX graphs inside that
    /// repository.
    pub fn get_paths(&self) -> (&'static str, &'static str, &'static str) {
        match *self {
            Self::Protein => (
                "zcpbx/proteinmpnn-v48-030-onnx",
                "protmpnn_encoder.onnx",
                "protmpnn_decoder_step.onnx",
            ),
            Self::Ligand => (
                "zcpbx/ligandmpnn-v32-030-25-onnx",
                "ligand_encoder.onnx",
                "ligand_decoder.onnx",
            ),
        }
    }
}
54
55pub struct LigandMPNN {
56 session: SessionBuilder,
57 encoder_path: PathBuf,
58 decoder_path: PathBuf,
59}
60
61impl LigandMPNN {
62 pub fn new() -> Result<Self> {
63 let session = Self::create_session()?;
64 let (encoder_path, decoder_path) = Self::load_model_paths(ModelType::Ligand)?;
65 Ok(Self {
66 session,
67 encoder_path,
68 decoder_path,
69 })
70 }
71
72 fn create_session() -> Result<SessionBuilder> {
73 ort::init()
74 .with_name("LigandMPNN")
75 .with_execution_providers([CUDAExecutionProvider::default().build()])
76 .commit()?;
77 Ok(Session::builder()?
78 .with_optimization_level(GraphOptimizationLevel::Level1)?
79 .with_intra_threads(1)?)
80 }
81
82 fn load_model_paths(model_type: ModelType) -> Result<(PathBuf, PathBuf)> {
83 let api = Api::new()?;
84 let (repo_id, encoder_name, decoder_name) = model_type.get_paths();
85 Ok((
86 api.model(repo_id.to_string()).get(&encoder_name)?,
87 api.model(repo_id.to_string()).get(&decoder_name)?,
88 ))
89 }
90
91 pub fn run_model(&self, ac: AtomCollection, position: i64, temperature: f32) -> Result<Tensor> {
92 let (h_v, h_e, e_idx) = self.run_encoder(&ac)?;
93 self.run_decoder(h_v, h_e, e_idx, temperature, position)
94 }
95
96 pub fn run_encoder(&self, ac: &AtomCollection) -> Result<(NdArrayF32, NdArrayF32, NdArrayI64)> {
97 let device = device()?;
98 let mut encoder_model = self.session.clone().commit_from_file(&self.encoder_path)?;
99 let x_bb = ac.to_numeric_backbone_atoms(&device)?;
100 let (lig_coords, lig_elements, lig_mask) = ac.to_numeric_ligand_atoms(&device)?;
101
102 let coords_nd = tensor_to_ndarray_f32(x_bb)?;
103 let lig_coords_nd = tensor_to_ndarray_f32(lig_coords)?;
104 let lig_types_nd = tensor_to_ndarray_i64(lig_elements)?;
105 let lig_mask_nd = tensor_to_ndarray_f32(lig_mask)?;
106
107 let encoder_inputs = ort::inputs![
108 "coords" => OrtTensor::from_array(coords_nd)?,
109 "ligand_coords" => OrtTensor::from_array(lig_coords_nd)?,
110 "ligand_types" => OrtTensor::from_array(lig_types_nd)?,
111 "ligand_mask" => OrtTensor::from_array(lig_mask_nd)?
112 ];
113
114 let encoder_outputs = encoder_model.run(encoder_inputs)?;
115 Ok((
116 encoder_outputs["h_V"]
117 .try_extract_array::<f32>()?
118 .to_owned(),
119 encoder_outputs["h_E"]
120 .try_extract_array::<f32>()?
121 .to_owned(),
122 encoder_outputs["E_idx"]
123 .try_extract_array::<i64>()?
124 .to_owned(),
125 ))
126 }
127
128 pub fn run_decoder(
129 &self,
130 h_v: NdArrayF32,
131 h_e: NdArrayF32,
132 e_idx: NdArrayI64,
133 temperature: f32,
134 position: i64,
135 ) -> Result<Tensor> {
136 let mut decoder_model = self.session.clone().commit_from_file(&self.decoder_path)?;
137
138 let position_tensor =
139 OrtTensor::from_array(ndarray::Array::from_shape_vec([1], vec![position])?)?;
140 let temp_tensor =
141 OrtTensor::from_array(ndarray::Array::from_shape_vec([1], vec![temperature])?)?;
142
143 let decoder_inputs = ort::inputs![
144 "h_v" => OrtTensor::from_array(h_v)?,
145 "h_e" => OrtTensor::from_array(h_e)?,
146 "e_idx" => OrtTensor::from_array(e_idx)?,
147 "position" => position_tensor,
148 "temperature" => temp_tensor
149 ];
150
151 let decoder_outputs = decoder_model.run(decoder_inputs)?;
152 let logits = decoder_outputs["logits"]
153 .try_extract_array::<f32>()? .to_owned();
155 ndarray_to_tensor_f32(logits)
156 }
157
158 pub fn get_single_location(
159 &self,
160 ac: AtomCollection,
161 temp: f32,
162 position: i64,
163 ) -> Result<Vec<PseudoProbability>> {
164 let logits = self.run_model(ac, position, temp)?;
165 let logits = ops::softmax(&logits, 1)?;
166 let logits = logits.get(0)?.to_vec1()?;
167 let mut amino_acid_probs = Vec::new();
168 for i in 0..21 {
169 amino_acid_probs.push(PseudoProbability {
170 amino_acid: int_to_aa1(i),
171 pseudo_prob: logits[i as usize],
172 position: position as usize,
173 });
174 }
175 Ok(amino_acid_probs)
176 }
177
178 pub fn get_all_locations(&self, temp: f32) -> Result<Vec<PseudoProbability>> {
179 todo!()
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use ferritin_core::load_structure;
    use ferritin_test_data::TestFile;

    /// Loads the `protein_01` test structure from a temporary file.
    fn setup_test_data() -> AtomCollection {
        // The second tuple element is kept bound until the structure is loaded.
        let (structure_path, _guard) = TestFile::protein_01().create_temp().unwrap();
        load_structure(structure_path).unwrap()
    }

    /// Construction succeeds and both resolved model files exist on disk.
    #[test]
    fn test_model_initialization() {
        let model = LigandMPNN::new().unwrap();
        assert!(model.encoder_path.exists());
        assert!(model.decoder_path.exists());
    }

    /// The encoder outputs have the expected shapes for the test structure.
    #[test]
    fn test_encoder_output_dimensions() -> Result<()> {
        let model = LigandMPNN::new()?;
        let structure = setup_test_data();
        println!("Data is setup");

        let (h_v, h_e, e_idx) = model.run_encoder(&structure)?;
        println!("h_v shape: {:?}", h_v.shape());
        println!("h_e shape: {:?}", h_e.shape());
        println!("e_idx shape: {:?}", e_idx.shape());

        assert_eq!(h_v.shape(), &[1, 154, 128]);
        assert_eq!(h_e.shape(), &[1, 154, 16, 128]);
        assert_eq!(e_idx.shape(), &[1, 154, 16]);
        Ok(())
    }
}