Skip to main content

ferritin_plms/esm2/
esm2_runner.rs

//! ESM2 Runner
//!
//! Runner type for loading the ESM2 models from the HuggingFace hub and running them.
4use super::esm2::{ESM2, ESM2Config, ESM2Output};
5use anyhow::{Error as E, Result, anyhow};
6use candle_core::{DType, Device, Tensor};
7use candle_nn::VarBuilder;
8use hf_hub::{Repo, RepoType, api::sync::Api};
9use serde_json;
10use tokenizers::Tokenizer;
11
12const ESM2_DTYPE: DType = DType::F32;
13
/// The ESM2 checkpoints that can be loaded by this module.
///
/// Variant names mirror the checkpoint naming used in the HuggingFace
/// repository ids (`T<layers>_<parameter count>`), hence the lint allowance.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ESM2Models {
    /// `facebook/esm2_t6_8M_UR50D`
    T6_8M,
    /// `facebook/esm2_t12_35M_UR50D`
    T12_35M,
    /// `facebook/esm2_t30_150M_UR50D`
    T30_150M,
    /// `facebook/esm2_t33_650M_UR50D`
    T33_650M,
    /// `facebook/esm2_t36_3B_UR50D`
    T36_3B,
    /// `facebook/esm2_t48_15B_UR50D`
    T48_15B,
}
22impl ESM2Models {
23    pub fn get_model_files(model: Self) -> (&'static str, &'static str, ESM2Config) {
24        match model {
25            Self::T6_8M => ("facebook/esm2_t6_8M_UR50D", "main", ESM2Config::t6_8m()),
26            Self::T12_35M => ("facebook/esm2_t12_35M_UR50D", "main", ESM2Config::t12_35m()),
27            Self::T30_150M => (
28                "facebook/esm2_t30_150M_UR50D",
29                "main",
30                ESM2Config::t30_150m(),
31            ),
32            Self::T33_650M => (
33                "facebook/esm2_t33_650M_UR50D",
34                "main",
35                ESM2Config::t33_650m(),
36            ),
37            Self::T36_3B => ("facebook/esm2_t36_3B_UR50D", "main", ESM2Config::t36_3b()),
38            Self::T48_15B => ("facebook/esm2_t48_15B_UR50D", "main", ESM2Config::t48_15b()),
39        }
40    }
41}
42
43pub struct ESM2Runner {
44    model: ESM2,
45    tokenizer: Tokenizer,
46}
47impl ESM2Runner {
48    /// Load model from HuggingFace hub, downloading config.json, tokenizer files, and weights.
49    pub fn load_model(modeltype: ESM2Models, device: Device) -> Result<ESM2Runner> {
50        let (model_id, revision, fallback_config) = ESM2Models::get_model_files(modeltype);
51        let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
52        let api = Api::new()?;
53        let api = api.repo(repo);
54        // Try to load config from HF hub; fall back to hardcoded config if unavailable.
55        let config = match api.get("config.json") {
56            Ok(config_path) => {
57                let config_str = std::fs::read_to_string(config_path)?;
58                serde_json::from_str::<ESM2Config>(&config_str).unwrap_or(fallback_config)
59            }
60            Err(_) => fallback_config,
61        };
62        let weights_filename = api.get("model.safetensors")?;
63        let vb = unsafe {
64            VarBuilder::from_mmaped_safetensors(&[weights_filename], ESM2_DTYPE, &device)?
65        };
66        let model = ESM2::load(vb, config)?;
67        let tokenizer = ESM2::load_tokenizer()?;
68        Ok(ESM2Runner { model, tokenizer })
69    }
70    pub fn run_forward(&self, prot_sequence: &str) -> Result<ESM2Output> {
71        let device = self.model.get_device();
72        let tokens = self
73            .tokenizer
74            .encode(prot_sequence.to_string(), false)
75            .map_err(E::msg)?
76            .get_ids()
77            .to_vec();
78        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
79        let encoded = self.model.forward(&token_ids, None)?;
80        Ok(encoded)
81    }
82    /// Predict residue-residue contact probabilities for a single protein sequence.
83    ///
84    /// Returns a `(seq_len, seq_len)` contact probability matrix (BOS/EOS stripped,
85    /// so dimensions equal the number of amino acids in `prot_sequence`).
86    pub fn predict_contacts(&self, prot_sequence: &str) -> Result<Tensor> {
87        let device = self.model.get_device();
88        let tokens = self
89            .tokenizer
90            .encode(prot_sequence.to_string(), false)
91            .map_err(E::msg)?
92            .get_ids()
93            .to_vec();
94        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
95        // squeeze batch dim: (1, L, L) → (L, L)
96        self.model
97            .predict_contacts(&token_ids, None)
98            .map_err(E::msg)?
99            .squeeze(0)
100            .map_err(E::msg)
101    }
102
103    pub fn decode_logits(&self, output: ESM2Output) -> Result<String> {
104        // Get the predicted token IDs by taking argmax along the vocabulary dimension
105        let predicted_token_ids = output.logits.argmax(2)?;
106        let predicted_token_ids = if predicted_token_ids.dims().len() > 1 {
107            predicted_token_ids.squeeze(0)?
108        } else {
109            predicted_token_ids
110        };
111        let token_ids: Vec<u32> = predicted_token_ids.to_vec1::<u32>()?;
112        let decoded_sequence = self
113            .tokenizer
114            .decode(&token_ids, true) // set skip_special_tokens to true
115            .map_err(|e| anyhow!("Failed to decode tokens: {}", e))?
116            .replace(" ", "");
117        Ok(decoded_sequence)
118    }
119}