ferritin_plms/esm2/
esm2_runner.rs

1//! ESM2 Runner
2//!
3//! Class for loading and running the ESM2 models
4use super::esm2::{ESM2, ESM2Config, ESM2Output};
5use anyhow::{Error as E, Result, anyhow};
6use candle_core::{DType, Device, Tensor};
7use candle_nn::VarBuilder;
8use hf_hub::{Repo, RepoType, api::sync::Api};
9use tokenizers::Tokenizer;
10
11const ESM2_DTYPE: DType = DType::F32;
12
/// Selects which ESM2 checkpoint to load.
///
/// Each variant corresponds to one `facebook/esm2_*_UR50D` repository; the
/// mapping to repository id, revision and configuration lives in
/// [`ESM2Models::get_model_files`].
///
/// The enum carries no data, so it is cheap to copy, compare, hash and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ESM2Models {
    /// `facebook/esm2_t6_8M_UR50D`
    T6_8M,
    /// `facebook/esm2_t12_35M_UR50D`
    T12_35M,
    /// `facebook/esm2_t30_150M_UR50D`
    T30_150M,
    /// `facebook/esm2_t33_650M_UR50D`
    T33_650M,
    /// `facebook/esm2_t36_3B_UR50D`
    T36_3B,
    /// `facebook/esm2_t48_15B_UR50D`
    T48_15B,
}
21impl ESM2Models {
22    pub fn get_model_files(model: Self) -> (&'static str, &'static str, ESM2Config) {
23        match model {
24            Self::T6_8M => ("facebook/esm2_t6_8M_UR50D", "main", ESM2Config::t6_8m()),
25            Self::T12_35M => ("facebook/esm2_t12_35M_UR50D", "main", ESM2Config::t12_35m()),
26            Self::T30_150M => (
27                "facebook/esm2_t30_150M_UR50D",
28                "main",
29                ESM2Config::t30_150m(),
30            ),
31            Self::T33_650M => (
32                "facebook/esm2_t33_650M_UR50D",
33                "main",
34                ESM2Config::t33_650m(),
35            ),
36            Self::T36_3B => ("facebook/esm2_t36_3B_UR50D", "main", ESM2Config::t36_3b()),
37            Self::T48_15B => ("facebook/esm2_t48_15B_UR50D", "main", ESM2Config::t48_15b()),
38        }
39    }
40}
41
42pub struct ESM2Runner {
43    model: ESM2,
44    tokenizer: Tokenizer,
45}
46impl ESM2Runner {
47    // from hf-hub
48    pub fn load_model(modeltype: ESM2Models, device: Device) -> Result<ESM2Runner> {
49        let (model_id, revision, config) = ESM2Models::get_model_files(modeltype);
50        let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
51        let api = Api::new()?;
52        let api = api.repo(repo);
53        let weights_filename = api.get("model.safetensors")?;
54        let vb = unsafe {
55            VarBuilder::from_mmaped_safetensors(&[weights_filename], ESM2_DTYPE, &device)?
56        };
57        let model = ESM2::load(vb, config)?;
58        let tokenizer = ESM2::load_tokenizer()?;
59        Ok(ESM2Runner { model, tokenizer })
60    }
61    pub fn run_forward(&self, prot_sequence: &str) -> Result<ESM2Output> {
62        let device = self.model.get_device();
63        let tokens = self
64            .tokenizer
65            .encode(prot_sequence.to_string(), false)
66            .map_err(E::msg)?
67            .get_ids()
68            .to_vec();
69        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
70        let encoded = self.model.forward(&token_ids)?;
71        Ok(encoded)
72    }
73    pub fn decode_logits(&self, output: ESM2Output) -> Result<String> {
74        // Get the predicted token IDs by taking argmax along the vocabulary dimension
75        let predicted_token_ids = output.logits.argmax(2)?;
76        let predicted_token_ids = if predicted_token_ids.dims().len() > 1 {
77            predicted_token_ids.squeeze(0)?
78        } else {
79            predicted_token_ids
80        };
81        let token_ids: Vec<u32> = predicted_token_ids.to_vec1::<u32>()?;
82        let decoded_sequence = self
83            .tokenizer
84            .decode(&token_ids, true) // set skip_special_tokens to true
85            .map_err(|e| anyhow!("Failed to decode tokens: {}", e))?
86            .replace(" ", "");
87        Ok(decoded_sequence)
88    }
89}