ferritin_onnx_models/models/esm2/mod.rs

use super::super::utilities::ndarray_to_tensor_f32;
use anyhow::{Result, anyhow};
use candle_core::{D, Tensor};
use candle_nn::ops;
use ferritin_plms::types::PseudoProbability;
use hf_hub::api::sync::Api;
use ndarray::Array2;
use ort::{
    execution_providers::CUDAExecutionProvider,
    session::{
        Session,
        builder::{GraphOptimizationLevel, SessionBuilder},
    },
    value::Tensor as OrtTensor,
};
use std::path::PathBuf;
use tokenizers::Tokenizer;

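/// ESM2 checkpoints available as ONNX exports, named by transformer depth
/// and parameter count: 6 layers / 8M, 12 layers / 35M, and 30 layers / 150M.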
pub enum ESM2Models {
    T6_8M,
    T12_35M,
    T30_150M,
}

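/// An ESM2 protein language model backed by ONNX Runtime: a reusable
/// session builder, the path to the downloaded ONNX graph, and the ESM2
/// tokenizer bundled with this crate.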
pub struct ESM2 {
    pub session: SessionBuilder,
    pub model_path: PathBuf,
    pub tokenizer: Tokenizer,
}

impl ESM2 {
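    /// Initialize the ONNX Runtime session builder, resolve the requested
    /// checkpoint on disk (downloading it if necessary), and load the tokenizer.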
    pub fn new(model: ESM2Models) -> Result<Self> {
        let session = Self::create_session()?;
        let model_path = Self::load_model_path(model)?;
        let tokenizer = Self::load_tokenizer()?;
        Ok(Self {
            session,
            model_path,
            tokenizer,
        })
    }

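    /// Resolve the ONNX export of the selected checkpoint via the Hugging
    /// Face Hub API; files are cached locally after the first download.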
    pub fn load_model_path(model: ESM2Models) -> Result<PathBuf> {
        let api = Api::new()?;
        let repo_id = match model {
            ESM2Models::T6_8M => "zcpbx/esm2-t6-8m-UR50D-onnx",
            ESM2Models::T12_35M => "zcpbx/esm2-t12-35M-UR50D-onnx",
            ESM2Models::T30_150M => "zcpbx/esm2-t30-150M-UR50D-onnx",
        }
        .to_string();
        let model_path = api.model(repo_id).get("model.onnx")?;
        Ok(model_path)
    }

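    /// Build the tokenizer from the `tokenizer.json` embedded in the crate
    /// at compile time, so no network access is required.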
    pub fn load_tokenizer() -> Result<Tokenizer> {
        let tokenizer_bytes = include_bytes!("tokenizer.json");
        Tokenizer::from_bytes(tokenizer_bytes)
            .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))
    }

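    /// Initialize the global ONNX Runtime environment with the CUDA
    /// execution provider registered (ort falls back to CPU when CUDA is
    /// unavailable) and return a configured session builder.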
    fn create_session() -> Result<SessionBuilder> {
        ort::init()
            .with_name("ESM2")
            .with_execution_providers([CUDAExecutionProvider::default().build()])
            .commit()?;
        Ok(Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level1)?
            .with_intra_threads(1)?)
    }

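    /// Tokenize `sequence`, run the ONNX graph, and return the raw logits
    /// as a candle tensor of shape `[batch, sequence_length, vocab_size]`.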
    pub fn run_model(&self, sequence: &str) -> Result<Tensor> {
        let mut model = self.session.clone().commit_from_file(&self.model_path)?;
        let tokens = self
            .tokenizer
            .encode(sequence, false)
            .map_err(|e| anyhow!("Tokenization failed: {}", e))?;
        let token_ids = tokens.get_ids();
        let shape = (1, tokens.len());

        // Attention mask of all ones: every position holds a real token.
        let mask_array: Array2<i64> = Array2::from_shape_vec(shape, vec![1; tokens.len()])?;
        let tokens_array: Array2<i64> = Array2::from_shape_vec(
            shape,
            token_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
        )?;

        let inputs = ort::inputs![
            "input_ids" => OrtTensor::from_array(tokens_array)?,
            "attention_mask" => OrtTensor::from_array(mask_array)?
        ];

        let outputs = model.run(inputs)?;
        let logits = outputs["logits"].try_extract_array::<f32>()?.to_owned();
        Ok(ndarray_to_tensor_f32(logits)?)
    }

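    /// Softmax the logits over the vocabulary dimension and flatten the
    /// result into one `PseudoProbability` per (position, token) pair.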
    pub fn extract_logits(&self, tensor: &Tensor) -> Result<Vec<PseudoProbability>> {
        let tensor = ops::softmax(tensor, D::Minus1)?;
        let data = tensor.to_vec3::<f32>()?;
        let shape = tensor.dims();
        let mut logit_positions = Vec::new();
        for seq_pos in 0..shape[1] {
            for vocab_idx in 0..shape[2] {
                let score = data[0][seq_pos][vocab_idx];
                let amino_acid_char = self
                    .tokenizer
                    .decode(&[vocab_idx as u32], false)
                    .map_err(|e| anyhow!("Failed to decode: {}", e))?
                    .chars()
                    .next()
                    .ok_or_else(|| anyhow!("Empty decoded string"))?;
                logit_positions.push(PseudoProbability {
                    position: seq_pos,
                    amino_acid: amino_acid_char,
                    pseudo_prob: score,
                });
            }
        }
        Ok(logit_positions)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_load() -> Result<()> {
        let tokenizer = ESM2::load_tokenizer()?;
        let text = "MLKLRV";
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| anyhow!("Failed to encode: {}", e))?;
        let tokens = encoding.get_tokens();
        assert_eq!(tokens.len(), 6);
        assert_eq!(tokens, &["M", "L", "K", "L", "R", "V"]);
        Ok(())
    }
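
    // A hedged end-to-end sketch (not part of the original tests): runs the
    // smallest checkpoint through `run_model` and `extract_logits`. It
    // downloads weights from the Hugging Face Hub, so it is `#[ignore]`d by
    // default; run with `cargo test -- --ignored` when network is available.
    #[test]
    #[ignore]
    fn test_run_model_end_to_end() -> Result<()> {
        let esm2 = ESM2::new(ESM2Models::T6_8M)?;
        let logits = esm2.run_model("MLKLRV")?;
        // Raw logits come back as [batch, sequence_length, vocab_size].
        assert_eq!(logits.dims().len(), 3);
        let pseudo_probs = esm2.extract_logits(&logits)?;
        assert!(!pseudo_probs.is_empty());
        Ok(())
    }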
}