// ferritin_onnx_models/models/esm2/mod.rs
use super::super::utilities::ndarray_to_tensor_f32;
12use anyhow::{Result, anyhow};
13use candle_core::{D, Tensor};
14use candle_nn::ops;
15use ferritin_plms::types::PseudoProbability;
16use hf_hub::api::sync::Api;
17use ndarray::Array2;
18use ort::{
19 execution_providers::CUDAExecutionProvider,
20 session::{
21 Session,
22 builder::{GraphOptimizationLevel, SessionBuilder},
23 },
24};
25use std::path::PathBuf;
26use tokenizers::Tokenizer;
27
/// The available ESM2 checkpoints, named by layer count and parameter count
/// (e.g. `T6_8M` = 6 transformer layers, 8M parameters). Each variant maps to
/// an ONNX export on the Hugging Face Hub (see `ESM2::load_model_path`).
// Variant names mirror the upstream checkpoint names; renaming them to
// camel case would break existing callers, so the lint is silenced instead.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ESM2Models {
    T6_8M,
    T12_35M,
    T30_150M,
}
34
/// Handle to an ESM2 protein language model run through ONNX Runtime.
///
/// Construct with [`ESM2::new`]. Holds a configured session *builder* (the
/// actual `Session` is committed from `model_path` on each `run_model` call),
/// the local path of the downloaded ONNX weights, and the bundled tokenizer.
pub struct ESM2 {
    // Pre-configured ort session builder; cloned and committed per inference.
    pub session: SessionBuilder,
    // Filesystem path to `model.onnx` fetched from the Hugging Face Hub.
    pub model_path: PathBuf,
    // Tokenizer embedded in the binary via `include_bytes!("tokenizer.json")`.
    pub tokenizer: Tokenizer,
}
40
41impl ESM2 {
42 pub fn new(model: ESM2Models) -> Result<Self> {
43 let session = Self::create_session()?;
44 let model_path = Self::load_model_path(model)?;
45 let tokenizer = Self::load_tokenizer()?;
46 Ok(Self {
47 session,
48 model_path,
49 tokenizer,
50 })
51 }
52 pub fn load_model_path(model: ESM2Models) -> Result<PathBuf> {
53 let api = Api::new()?;
54 let repo_id = match model {
55 ESM2Models::T6_8M => "zcpbx/esm2-t6-8m-UR50D-onnx",
56 ESM2Models::T12_35M => "zcpbx/esm2-t12-35M-UR50D-onnx",
57 ESM2Models::T30_150M => "zcpbx/esm2-t30-150M-UR50D-onnx",
58 }
60 .to_string();
61 let model_path = api.model(repo_id).get("model.onnx")?;
62 Ok(model_path)
63 }
64 pub fn load_tokenizer() -> Result<Tokenizer> {
65 let tokenizer_bytes = include_bytes!("tokenizer.json");
66 Tokenizer::from_bytes(tokenizer_bytes)
67 .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))
68 }
69 fn create_session() -> Result<SessionBuilder> {
70 ort::init()
71 .with_name("ESM2")
72 .with_execution_providers([CUDAExecutionProvider::default().build()])
73 .commit()?;
74 Ok(Session::builder()?
75 .with_optimization_level(GraphOptimizationLevel::Level1)?
76 .with_intra_threads(1)?)
77 }
78 pub fn run_model(&self, sequence: &str) -> Result<Tensor> {
79 let model = self.session.clone().commit_from_file(&self.model_path)?;
80 let tokens = self
81 .tokenizer
82 .encode(sequence, false)
83 .map_err(|e| anyhow!("Tokenization failed: {}", e))?;
84 let token_ids = tokens.get_ids();
85 let shape = (1, tokens.len());
86 let mask_array: Array2<i64> = Array2::from_shape_vec(shape, vec![1; tokens.len()])?;
88 let tokens_array: Array2<i64> = Array2::from_shape_vec(
89 shape,
90 token_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
91 )?;
92 let outputs =
93 model.run(ort::inputs!["input_ids" => tokens_array,"attention_mask" => mask_array]?)?;
94 let logits = outputs["logits"].try_extract_tensor::<f32>()?.to_owned();
95 Ok(ndarray_to_tensor_f32(logits)?)
96 }
97 pub fn extract_logits(&self, tensor: &Tensor) -> Result<Vec<PseudoProbability>> {
99 let tensor = ops::softmax(tensor, D::Minus1)?;
100 let data = tensor.to_vec3::<f32>()?;
101 println!("Data: {:?}", data);
102 let shape = tensor.dims();
103 let mut logit_positions = Vec::new();
104 for seq_pos in 0..shape[1] {
105 for vocab_idx in 0..shape[2] {
106 let score = data[0][seq_pos][vocab_idx];
107 let amino_acid_char = self
108 .tokenizer
109 .decode(&[vocab_idx as u32], false)
110 .map_err(|e| anyhow!("Failed to decode: {}", e))?
111 .chars()
112 .next()
113 .ok_or_else(|| anyhow!("Empty decoded string"))?;
114 logit_positions.push(PseudoProbability {
115 position: seq_pos,
116 amino_acid: amino_acid_char,
117 pseudo_prob: score,
118 });
119 }
120 }
121 Ok(logit_positions)
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// The bundled tokenizer must split a short peptide into exactly one
    /// token per residue (no special tokens when `add_special_tokens=false`).
    // Bug fix: the `#[test]` attribute was missing, so this test was never
    // executed by `cargo test` (and was flagged as dead code).
    #[test]
    fn test_tokenizer_load() -> Result<()> {
        let tokenizer = ESM2::load_tokenizer()?;
        let text = "MLKLRV";
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| anyhow!("Failed to encode: {}", e))?;
        let tokens = encoding.get_tokens();
        assert_eq!(tokens.len(), 6);
        assert_eq!(tokens, &["M", "L", "K", "L", "R", "V"]);
        Ok(())
    }
}