// ferritin_plms/esm2/esm2_runner.rs
use super::esm2::{ESM2, ESM2Config, ESM2Output};

use anyhow::{Context, Error as E, Result, anyhow};
use candle_core::{D, DType, Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{Repo, RepoType, api::sync::Api};
use tokenizers::Tokenizer;
/// Element type the safetensors weights are loaded as when `ESM2Runner::load_model`
/// memory-maps them through `VarBuilder::from_mmaped_safetensors`.
const ESM2_DTYPE: DType = DType::F32;
/// Selects one of the pretrained ESM2 checkpoints hosted on the Hugging Face Hub.
///
/// Variant names mirror the checkpoint names (`T6_8M` ↔ `facebook/esm2_t6_8M_UR50D`, …);
/// `ESM2Models::get_model_files` maps each variant to its repository, revision and config.
///
/// The type is a plain fieldless selector, so it is `Copy`: callers can keep, compare, hash and
/// log the chosen variant even after passing it by value to `get_model_files`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ESM2Models {
    T6_8M,
    T12_35M,
    T30_150M,
    T33_650M,
    T36_3B,
    T48_15B,
}
21impl ESM2Models {
22 pub fn get_model_files(model: Self) -> (&'static str, &'static str, ESM2Config) {
23 match model {
24 Self::T6_8M => ("facebook/esm2_t6_8M_UR50D", "main", ESM2Config::t6_8m()),
25 Self::T12_35M => ("facebook/esm2_t12_35M_UR50D", "main", ESM2Config::t12_35m()),
26 Self::T30_150M => (
27 "facebook/esm2_t30_150M_UR50D",
28 "main",
29 ESM2Config::t30_150m(),
30 ),
31 Self::T33_650M => (
32 "facebook/esm2_t33_650M_UR50D",
33 "main",
34 ESM2Config::t33_650m(),
35 ),
36 Self::T36_3B => ("facebook/esm2_t36_3B_UR50D", "main", ESM2Config::t36_3b()),
37 Self::T48_15B => ("facebook/esm2_t48_15B_UR50D", "main", ESM2Config::t48_15b()),
38 }
39 }
40}
41
42pub struct ESM2Runner {
43 model: ESM2,
44 tokenizer: Tokenizer,
45}
46impl ESM2Runner {
47 pub fn load_model(modeltype: ESM2Models, device: Device) -> Result<ESM2Runner> {
49 let (model_id, revision, config) = ESM2Models::get_model_files(modeltype);
50 let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
51 let api = Api::new()?;
52 let api = api.repo(repo);
53 let weights_filename = api.get("model.safetensors")?;
54 let vb = unsafe {
55 VarBuilder::from_mmaped_safetensors(&[weights_filename], ESM2_DTYPE, &device)?
56 };
57 let model = ESM2::load(vb, config)?;
58 let tokenizer = ESM2::load_tokenizer()?;
59 Ok(ESM2Runner { model, tokenizer })
60 }
61 pub fn run_forward(&self, prot_sequence: &str) -> Result<ESM2Output> {
62 let device = self.model.get_device();
63 let tokens = self
64 .tokenizer
65 .encode(prot_sequence.to_string(), false)
66 .map_err(E::msg)?
67 .get_ids()
68 .to_vec();
69 let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
70 let encoded = self.model.forward(&token_ids)?;
71 Ok(encoded)
72 }
73 pub fn decode_logits(&self, output: ESM2Output) -> Result<String> {
74 let predicted_token_ids = output.logits.argmax(2)?;
76 let predicted_token_ids = if predicted_token_ids.dims().len() > 1 {
77 predicted_token_ids.squeeze(0)?
78 } else {
79 predicted_token_ids
80 };
81 let token_ids: Vec<u32> = predicted_token_ids.to_vec1::<u32>()?;
82 let decoded_sequence = self
83 .tokenizer
84 .decode(&token_ids, true) .map_err(|e| anyhow!("Failed to decode tokens: {}", e))?
86 .replace(" ", "");
87 Ok(decoded_sequence)
88 }
89}