ferritin_plms/esm2/
esm2_runner.rs1use super::esm2::{ESM2, ESM2Config, ESM2Output};
5use anyhow::{Error as E, Result, anyhow};
6use candle_core::{DType, Device, Tensor};
7use candle_nn::VarBuilder;
8use hf_hub::{Repo, RepoType, api::sync::Api};
9use serde_json;
10use tokenizers::Tokenizer;
11
/// Dtype passed to `VarBuilder::from_mmaped_safetensors` when loading ESM2 weights
/// in `ESM2Runner::load_model`.
const ESM2_DTYPE: DType = DType::F32;
13
/// The pretrained ESM2 checkpoints that `ESM2Runner::load_model` can load.
///
/// Each variant maps to a `facebook/esm2_*_UR50D` repository on the Hugging Face Hub
/// (see `ESM2Models::get_model_files`). The enum is `Copy`, so a selection can be reused
/// after being passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ESM2Models {
    T6_8M,
    T12_35M,
    T30_150M,
    T33_650M,
    T36_3B,
    T48_15B,
}
22impl ESM2Models {
23 pub fn get_model_files(model: Self) -> (&'static str, &'static str, ESM2Config) {
24 match model {
25 Self::T6_8M => ("facebook/esm2_t6_8M_UR50D", "main", ESM2Config::t6_8m()),
26 Self::T12_35M => ("facebook/esm2_t12_35M_UR50D", "main", ESM2Config::t12_35m()),
27 Self::T30_150M => (
28 "facebook/esm2_t30_150M_UR50D",
29 "main",
30 ESM2Config::t30_150m(),
31 ),
32 Self::T33_650M => (
33 "facebook/esm2_t33_650M_UR50D",
34 "main",
35 ESM2Config::t33_650m(),
36 ),
37 Self::T36_3B => ("facebook/esm2_t36_3B_UR50D", "main", ESM2Config::t36_3b()),
38 Self::T48_15B => ("facebook/esm2_t48_15B_UR50D", "main", ESM2Config::t48_15b()),
39 }
40 }
41}
42
43pub struct ESM2Runner {
44 model: ESM2,
45 tokenizer: Tokenizer,
46}
47impl ESM2Runner {
48 pub fn load_model(modeltype: ESM2Models, device: Device) -> Result<ESM2Runner> {
50 let (model_id, revision, fallback_config) = ESM2Models::get_model_files(modeltype);
51 let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
52 let api = Api::new()?;
53 let api = api.repo(repo);
54 let config = match api.get("config.json") {
56 Ok(config_path) => {
57 let config_str = std::fs::read_to_string(config_path)?;
58 serde_json::from_str::<ESM2Config>(&config_str).unwrap_or(fallback_config)
59 }
60 Err(_) => fallback_config,
61 };
62 let weights_filename = api.get("model.safetensors")?;
63 let vb = unsafe {
64 VarBuilder::from_mmaped_safetensors(&[weights_filename], ESM2_DTYPE, &device)?
65 };
66 let model = ESM2::load(vb, config)?;
67 let tokenizer = ESM2::load_tokenizer()?;
68 Ok(ESM2Runner { model, tokenizer })
69 }
70 pub fn run_forward(&self, prot_sequence: &str) -> Result<ESM2Output> {
71 let device = self.model.get_device();
72 let tokens = self
73 .tokenizer
74 .encode(prot_sequence.to_string(), false)
75 .map_err(E::msg)?
76 .get_ids()
77 .to_vec();
78 let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
79 let encoded = self.model.forward(&token_ids, None)?;
80 Ok(encoded)
81 }
82 pub fn predict_contacts(&self, prot_sequence: &str) -> Result<Tensor> {
87 let device = self.model.get_device();
88 let tokens = self
89 .tokenizer
90 .encode(prot_sequence.to_string(), false)
91 .map_err(E::msg)?
92 .get_ids()
93 .to_vec();
94 let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
95 self.model
97 .predict_contacts(&token_ids, None)
98 .map_err(E::msg)?
99 .squeeze(0)
100 .map_err(E::msg)
101 }
102
103 pub fn decode_logits(&self, output: ESM2Output) -> Result<String> {
104 let predicted_token_ids = output.logits.argmax(2)?;
106 let predicted_token_ids = if predicted_token_ids.dims().len() > 1 {
107 predicted_token_ids.squeeze(0)?
108 } else {
109 predicted_token_ids
110 };
111 let token_ids: Vec<u32> = predicted_token_ids.to_vec1::<u32>()?;
112 let decoded_sequence = self
113 .tokenizer
114 .decode(&token_ids, true) .map_err(|e| anyhow!("Failed to decode tokens: {}", e))?
116 .replace(" ", "");
117 Ok(decoded_sequence)
118 }
119}