ferritin_onnx_models/models/esm2/
mod.rs1use super::super::utilities::ndarray_to_tensor_f32;
12use anyhow::{Result, anyhow};
13use candle_core::{D, Tensor};
14use candle_nn::ops;
15use ferritin_plms::types::PseudoProbability;
16use hf_hub::api::sync::Api;
17use ndarray::Array2;
18use ort::{
19 execution_providers::CUDAExecutionProvider,
20 session::{
21 Session,
22 builder::{GraphOptimizationLevel, SessionBuilder},
23 },
24 value::Tensor as OrtTensor,
25};
26use std::path::PathBuf;
27use tokenizers::Tokenizer;
28
/// ESM2 checkpoints that can be selected when constructing the model wrapper.
///
/// Each variant is mapped to a Hugging Face Hub repository that hosts the
/// corresponding `model.onnx` file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ESM2Models {
    /// Resolved from the `zcpbx/esm2-t6-8m-UR50D-onnx` repository.
    T6_8M,
    /// Resolved from the `zcpbx/esm2-t12-35M-UR50D-onnx` repository.
    T12_35M,
    /// Resolved from the `zcpbx/esm2-t30-150M-UR50D-onnx` repository.
    T30_150M,
}
35
36pub struct ESM2 {
37 pub session: SessionBuilder,
38 pub model_path: PathBuf,
39 pub tokenizer: Tokenizer,
40}
41
42impl ESM2 {
43 pub fn new(model: ESM2Models) -> Result<Self> {
44 let session = Self::create_session()?;
45 let model_path = Self::load_model_path(model)?;
46 let tokenizer = Self::load_tokenizer()?;
47 Ok(Self {
48 session,
49 model_path,
50 tokenizer,
51 })
52 }
53
54 pub fn load_model_path(model: ESM2Models) -> Result<PathBuf> {
55 let api = Api::new()?;
56 let repo_id = match model {
57 ESM2Models::T6_8M => "zcpbx/esm2-t6-8m-UR50D-onnx",
58 ESM2Models::T12_35M => "zcpbx/esm2-t12-35M-UR50D-onnx",
59 ESM2Models::T30_150M => "zcpbx/esm2-t30-150M-UR50D-onnx",
60 }
62 .to_string();
63 let model_path = api.model(repo_id).get("model.onnx")?;
64 Ok(model_path)
65 }
66
67 pub fn load_tokenizer() -> Result<Tokenizer> {
68 let tokenizer_bytes = include_bytes!("tokenizer.json");
69 Tokenizer::from_bytes(tokenizer_bytes)
70 .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))
71 }
72
73 fn create_session() -> Result<SessionBuilder> {
74 ort::init()
75 .with_name("ESM2")
76 .with_execution_providers([CUDAExecutionProvider::default().build()])
77 .commit();
78 Ok(Session::builder()?
79 .with_optimization_level(GraphOptimizationLevel::Level1)
80 .map_err(|e| anyhow!("ORT session config error: {}", e))?
81 .with_intra_threads(1)
82 .map_err(|e| anyhow!("ORT session config error: {}", e))?)
83 }
84
85 pub fn run_model(&self, sequence: &str) -> Result<Tensor> {
86 let mut model = self.session.clone().commit_from_file(&self.model_path)?;
87 let tokens = self
88 .tokenizer
89 .encode(sequence, false)
90 .map_err(|e| anyhow!("Tokenization failed: {}", e))?;
91 let token_ids = tokens.get_ids();
92 let shape = (1, tokens.len());
93
94 let mask_array: Array2<i64> = Array2::from_shape_vec(shape, vec![1; tokens.len()])?;
96 let tokens_array: Array2<i64> = Array2::from_shape_vec(
97 shape,
98 token_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
99 )?;
100
101 let inputs = ort::inputs![
102 "input_ids" => OrtTensor::from_array(tokens_array)?,
103 "attention_mask" => OrtTensor::from_array(mask_array)?
104 ];
105
106 let outputs = model.run(inputs)?;
107 let logits = outputs["logits"].try_extract_array::<f32>()?.to_owned();
108 Ok(ndarray_to_tensor_f32(logits)?)
109 }
110
111 pub fn extract_logits(&self, tensor: &Tensor) -> Result<Vec<PseudoProbability>> {
113 let tensor = ops::softmax(tensor, D::Minus1)?;
114 let data = tensor.to_vec3::<f32>()?;
115 println!("Data: {:?}", data);
116 let shape = tensor.dims();
117 let mut logit_positions = Vec::new();
118 for seq_pos in 0..shape[1] {
119 for vocab_idx in 0..shape[2] {
120 let score = data[0][seq_pos][vocab_idx];
121 let amino_acid_char = self
122 .tokenizer
123 .decode(&[vocab_idx as u32], false)
124 .map_err(|e| anyhow!("Failed to decode: {}", e))?
125 .chars()
126 .next()
127 .ok_or_else(|| anyhow!("Empty decoded string"))?;
128 logit_positions.push(PseudoProbability {
129 position: seq_pos,
130 amino_acid: amino_acid_char,
131 pseudo_prob: score,
132 });
133 }
134 }
135 Ok(logit_positions)
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded tokenizer must load and split a short peptide into one
    /// token per residue letter.
    #[test]
    fn test_tokenizer_load() -> Result<()> {
        let sequence = "MLKLRV";
        let encoded = ESM2::load_tokenizer()?
            .encode(sequence, false)
            .map_err(|e| anyhow!("Failed to encode: {}", e))?;
        let pieces = encoded.get_tokens();
        assert_eq!(pieces.len(), 6);
        assert_eq!(pieces, &["M", "L", "K", "L", "R", "V"]);
        Ok(())
    }
}