ferritin_onnx_models/models/esm2/
mod.rs

//! ESM2 tokenizer and ONNX inference. Models were converted to ONNX format from
//! [ESM2](https://github.com/facebookresearch/esm) and uploaded to the HuggingFace
//! hub. The tokenizer is included in this crate and loaded from memory using
//! `tokenizer.json`. This wrapper is fairly minimal - for the full set of ESM2
//! models, please see the ESM2 repository and the HuggingFace hub.
//!
//! # Models:
//! * T6_8M - small 6-layer protein language model
//! * T12_35M - medium 12-layer protein language model
//! * T30_150M - large 30-layer protein language model
//!
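//! # Example
//! A minimal usage sketch (marked `no_run`: it downloads the ONNX weights from
//! the HuggingFace hub). The import path is an assumption - adjust it to match
//! where the crate actually re-exports `ESM2` and `ESM2Models`.
//! ```no_run
//! use ferritin_onnx_models::{ESM2, ESM2Models}; // assumed re-export path
//!
//! fn main() -> anyhow::Result<()> {
//!     // Build the session, resolve the weights, and load the tokenizer.
//!     let esm2 = ESM2::new(ESM2Models::T6_8M)?;
//!     // Raw logits for the sequence, shape [1, seq_len, vocab_size].
//!     let logits = esm2.run_model("MLKLRV")?;
//!     // Softmaxed per-position pseudo-probabilities.
//!     let probs = esm2.extract_logits(&logits)?;
//!     println!("{} (position, amino acid) scores", probs.len());
//!     Ok(())
//! }
//! ```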
use super::super::utilities::ndarray_to_tensor_f32;
use anyhow::{Result, anyhow};
use candle_core::{D, Tensor};
use candle_nn::ops;
use ferritin_plms::types::PseudoProbability;
use hf_hub::api::sync::Api;
use ndarray::Array2;
use ort::{
    execution_providers::CUDAExecutionProvider,
    session::{
        Session,
        builder::{GraphOptimizationLevel, SessionBuilder},
    },
    value::Tensor as OrtTensor,
};
use std::path::PathBuf;
use tokenizers::Tokenizer;

/// ESM2 checkpoints with ONNX conversions available on the HuggingFace hub.
pub enum ESM2Models {
    T6_8M,
    T12_35M,
    T30_150M,
    // ESM2_T33_650M,
}

/// ESM2 wrapper holding the ONNX session builder, the path to the downloaded
/// weights, and the bundled tokenizer.
pub struct ESM2 {
    pub session: SessionBuilder,
    pub model_path: PathBuf,
    pub tokenizer: Tokenizer,
}

impl ESM2 {
    pub fn new(model: ESM2Models) -> Result<Self> {
        let session = Self::create_session()?;
        let model_path = Self::load_model_path(model)?;
        let tokenizer = Self::load_tokenizer()?;
        Ok(Self {
            session,
            model_path,
            tokenizer,
        })
    }

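    /// Resolve the ONNX weights for the requested model, downloading them from
    /// the HuggingFace hub (or reusing the local hf-hub cache) as needed.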
    pub fn load_model_path(model: ESM2Models) -> Result<PathBuf> {
        let api = Api::new()?;
        let repo_id = match model {
            ESM2Models::T6_8M => "zcpbx/esm2-t6-8m-UR50D-onnx",
            ESM2Models::T12_35M => "zcpbx/esm2-t12-35M-UR50D-onnx",
            ESM2Models::T30_150M => "zcpbx/esm2-t30-150M-UR50D-onnx",
            // ESM2Models::ESM2_T33_650M => "zcpbx/esm2-t33-650M-UR50D-onnx",
        }
        .to_string();
        let model_path = api.model(repo_id).get("model.onnx")?;
        Ok(model_path)
    }

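    /// Load the tokenizer from `tokenizer.json`, which is embedded in the
    /// binary at compile time via `include_bytes!`.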
    pub fn load_tokenizer() -> Result<Tokenizer> {
        let tokenizer_bytes = include_bytes!("tokenizer.json");
        Tokenizer::from_bytes(tokenizer_bytes)
            .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))
    }

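    /// Initialize the global ort environment with the CUDA execution provider
    /// (ort falls back to CPU if CUDA is unavailable) and return a configured
    /// session builder.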
    fn create_session() -> Result<SessionBuilder> {
        ort::init()
            .with_name("ESM2")
            .with_execution_providers([CUDAExecutionProvider::default().build()])
            .commit()?;
        Ok(Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level1)?
            .with_intra_threads(1)?)
    }

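    /// Tokenize `sequence`, run it through the ONNX model, and return the raw
    /// logits as a candle tensor of shape `[1, seq_len, vocab_size]`.
    ///
    /// Note that the session is committed from `model_path` on every call;
    /// callers running many sequences may want to cache the committed session.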
    pub fn run_model(&self, sequence: &str) -> Result<Tensor> {
        let mut model = self.session.clone().commit_from_file(&self.model_path)?;
        let tokens = self
            .tokenizer
            .encode(sequence, false)
            .map_err(|e| anyhow!("Tokenization failed: {}", e))?;
        let token_ids = tokens.get_ids();
        let shape = (1, tokens.len());

        // TODO: confirm this masking is correct. An all-ones mask attends to
        // every position, which is right for a single unpadded sequence.
        let mask_array: Array2<i64> = Array2::from_shape_vec(shape, vec![1; tokens.len()])?;
        let tokens_array: Array2<i64> = Array2::from_shape_vec(
            shape,
            token_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
        )?;

        let inputs = ort::inputs![
            "input_ids" => OrtTensor::from_array(tokens_array)?,
            "attention_mask" => OrtTensor::from_array(mask_array)?
        ];

        let outputs = model.run(inputs)?;
        let logits = outputs["logits"].try_extract_array::<f32>()?.to_owned();
        Ok(ndarray_to_tensor_f32(logits)?)
    }

    /// Softmax the logits over the vocabulary dimension and flatten them into
    /// one `PseudoProbability` per (position, amino acid) pair.
    pub fn extract_logits(&self, tensor: &Tensor) -> Result<Vec<PseudoProbability>> {
        let tensor = ops::softmax(tensor, D::Minus1)?;
        let data = tensor.to_vec3::<f32>()?;
        let shape = tensor.dims();
        let mut logit_positions = Vec::new();
        for seq_pos in 0..shape[1] {
            for vocab_idx in 0..shape[2] {
                let score = data[0][seq_pos][vocab_idx];
                let amino_acid_char = self
                    .tokenizer
                    .decode(&[vocab_idx as u32], false)
                    .map_err(|e| anyhow!("Failed to decode: {}", e))?
                    .chars()
                    .next()
                    .ok_or_else(|| anyhow!("Empty decoded string"))?;
                logit_positions.push(PseudoProbability {
                    position: seq_pos,
                    amino_acid: amino_acid_char,
                    pseudo_prob: score,
                });
            }
        }
        Ok(logit_positions)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_load() -> Result<()> {
        let tokenizer = ESM2::load_tokenizer()?;
        let text = "MLKLRV";
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| anyhow!("Failed to encode: {}", e))?;
        let tokens = encoding.get_tokens();
        assert_eq!(tokens.len(), 6);
        assert_eq!(tokens, &["M", "L", "K", "L", "R", "V"]);
        Ok(())
    }
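
    // An end-to-end sketch: this downloads the T6_8M ONNX weights from the
    // HuggingFace hub, so it is `#[ignore]`d by default and assumes network
    // access when run explicitly.
    #[test]
    #[ignore]
    fn test_run_model_end_to_end() -> Result<()> {
        let esm2 = ESM2::new(ESM2Models::T6_8M)?;
        let logits = esm2.run_model("MLKLRV")?;
        let probs = esm2.extract_logits(&logits)?;
        assert!(!probs.is_empty());
        Ok(())
    }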
}