ferritin_onnx_models/models/esm2/
mod.rs

1//! ESM2 Tokenizer. Models converted to ONNX format from [ESM2](https://github.com/facebookresearch/esm)
2//! and uploaded to HuggingFace hub. The tokenizer is included in this crate and loaded from
3//! memory using `tokenizer.json`. This is fairly minimal - for the full set of ESM2 models
4//! please see the ESM2 repository and the HuggingFace hub.
5//!
6//! # Models:
7//! * T6_8M - small 6-layer protein language model
8//! * T12_35M - medium 12-layer protein language model
9//! * T30_150M - large 30-layer protein language model
10//!
11use super::super::utilities::ndarray_to_tensor_f32;
12use anyhow::{Result, anyhow};
13use candle_core::{D, Tensor};
14use candle_nn::ops;
15use ferritin_plms::types::PseudoProbability;
16use hf_hub::api::sync::Api;
17use ndarray::Array2;
18use ort::{
19    execution_providers::CUDAExecutionProvider,
20    session::{
21        Session,
22        builder::{GraphOptimizationLevel, SessionBuilder},
23    },
24};
25use std::path::PathBuf;
26use tokenizers::Tokenizer;
27
/// The ESM2 checkpoints available as ONNX conversions on the HuggingFace hub.
///
/// Each variant maps to a `zcpbx/esm2-*-onnx` repo id (see `ESM2::load_model_path`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ESM2Models {
    /// Small 6-layer, 8M-parameter protein language model.
    T6_8M,
    /// Medium 12-layer, 35M-parameter protein language model.
    T12_35M,
    /// Large 30-layer, 150M-parameter protein language model.
    T30_150M,
    // ESM2_T33_650M,
}
34
/// Bundles everything needed to run an ESM2 ONNX model: a configured ort
/// session builder, the local path to the downloaded ONNX weights, and the
/// protein-sequence tokenizer loaded from the crate-bundled `tokenizer.json`.
pub struct ESM2 {
    // Session *builder*, not a committed session: `run_model` clones this and
    // commits from `model_path` on every call.
    pub session: SessionBuilder,
    // Local filesystem path to `model.onnx`, resolved via the HuggingFace hub.
    pub model_path: PathBuf,
    // Tokenizer used both to encode input sequences and to decode vocab ids.
    pub tokenizer: Tokenizer,
}
40
41impl ESM2 {
42    pub fn new(model: ESM2Models) -> Result<Self> {
43        let session = Self::create_session()?;
44        let model_path = Self::load_model_path(model)?;
45        let tokenizer = Self::load_tokenizer()?;
46        Ok(Self {
47            session,
48            model_path,
49            tokenizer,
50        })
51    }
52    pub fn load_model_path(model: ESM2Models) -> Result<PathBuf> {
53        let api = Api::new()?;
54        let repo_id = match model {
55            ESM2Models::T6_8M => "zcpbx/esm2-t6-8m-UR50D-onnx",
56            ESM2Models::T12_35M => "zcpbx/esm2-t12-35M-UR50D-onnx",
57            ESM2Models::T30_150M => "zcpbx/esm2-t30-150M-UR50D-onnx",
58            // ESM2Models::ESM2_T33_650M => "zcpbx/esm2-t33-650M-UR50D-onnx",
59        }
60        .to_string();
61        let model_path = api.model(repo_id).get("model.onnx")?;
62        Ok(model_path)
63    }
64    pub fn load_tokenizer() -> Result<Tokenizer> {
65        let tokenizer_bytes = include_bytes!("tokenizer.json");
66        Tokenizer::from_bytes(tokenizer_bytes)
67            .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))
68    }
69    fn create_session() -> Result<SessionBuilder> {
70        ort::init()
71            .with_name("ESM2")
72            .with_execution_providers([CUDAExecutionProvider::default().build()])
73            .commit()?;
74        Ok(Session::builder()?
75            .with_optimization_level(GraphOptimizationLevel::Level1)?
76            .with_intra_threads(1)?)
77    }
78    pub fn run_model(&self, sequence: &str) -> Result<Tensor> {
79        let model = self.session.clone().commit_from_file(&self.model_path)?;
80        let tokens = self
81            .tokenizer
82            .encode(sequence, false)
83            .map_err(|e| anyhow!("Tokenization failed: {}", e))?;
84        let token_ids = tokens.get_ids();
85        let shape = (1, tokens.len());
86        // Todo: Are we masking this correctly?
87        let mask_array: Array2<i64> = Array2::from_shape_vec(shape, vec![1; tokens.len()])?;
88        let tokens_array: Array2<i64> = Array2::from_shape_vec(
89            shape,
90            token_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
91        )?;
92        let outputs =
93            model.run(ort::inputs!["input_ids" => tokens_array,"attention_mask" => mask_array]?)?;
94        let logits = outputs["logits"].try_extract_tensor::<f32>()?.to_owned();
95        Ok(ndarray_to_tensor_f32(logits)?)
96    }
97    // Softmax and simplify
98    pub fn extract_logits(&self, tensor: &Tensor) -> Result<Vec<PseudoProbability>> {
99        let tensor = ops::softmax(tensor, D::Minus1)?;
100        let data = tensor.to_vec3::<f32>()?;
101        println!("Data: {:?}", data);
102        let shape = tensor.dims();
103        let mut logit_positions = Vec::new();
104        for seq_pos in 0..shape[1] {
105            for vocab_idx in 0..shape[2] {
106                let score = data[0][seq_pos][vocab_idx];
107                let amino_acid_char = self
108                    .tokenizer
109                    .decode(&[vocab_idx as u32], false)
110                    .map_err(|e| anyhow!("Failed to decode: {}", e))?
111                    .chars()
112                    .next()
113                    .ok_or_else(|| anyhow!("Empty decoded string"))?;
114                logit_positions.push(PseudoProbability {
115                    position: seq_pos,
116                    amino_acid: amino_acid_char,
117                    pseudo_prob: score,
118                });
119            }
120        }
121        Ok(logit_positions)
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    // Bug fix: the `#[test]` attribute was missing, so the harness never ran
    // this function and the assertions were dead code.
    #[test]
    fn test_tokenizer_load() -> Result<()> {
        let tokenizer = ESM2::load_tokenizer()?;
        let text = "MLKLRV";
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| anyhow!("Failed to encode: {}", e))?;
        let tokens = encoding.get_tokens();
        assert_eq!(tokens.len(), 6);
        assert_eq!(tokens, &["M", "L", "K", "L", "R", "V"]);
        Ok(())
    }
}