// ferritin_plms/amplify/amplify_runner.rs
use super::super::types::{ContactMap, PseudoProbability};
use super::amplify::{AMPLIFY, ModelOutput};
use super::config::AMPLIFYConfig;
use anyhow::{Error as E, Result, anyhow};
use candle_core::{D, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_nn::ops;
use hf_hub::{Repo, RepoType, api::sync::Api};
use tokenizers::Tokenizer;

const AMPLIFY_DTYPE: DType = DType::F32;

/// The published AMPLIFY checkpoints this runner knows how to download.
pub enum AmplifyModels {
    /// 120M-parameter checkpoint.
    AMP120M,
    /// 350M-parameter checkpoint.
    AMP350M,
}

impl AmplifyModels {
    /// Map a model variant to its Hugging Face Hub coordinates:
    /// `(repo_id, revision)`.
    pub fn get_model_files(model: Self) -> (&'static str, &'static str) {
        match model {
            Self::AMP120M => ("chandar-lab/AMPLIFY_120M", "main"),
            Self::AMP350M => ("chandar-lab/AMPLIFY_350M", "main"),
        }
    }
}
29
/// Bundles a loaded AMPLIFY model with the tokenizer that matches its
/// vocabulary, so callers can tokenize a protein sequence and run a forward
/// pass through one object (see `load_model` for construction).
pub struct AmplifyRunner {
    // The loaded AMPLIFY transformer.
    model: AMPLIFY,
    // Hugging Face `tokenizers` tokenizer for this checkpoint.
    tokenizer: Tokenizer,
}
34impl AmplifyRunner {
35 pub fn load_model(modeltype: AmplifyModels, device: Device) -> Result<AmplifyRunner> {
36 let (model_id, revision) = AmplifyModels::get_model_files(modeltype);
37 let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
38 let (config_filename, tokenizer_filename, weights_filename) = {
39 let api = Api::new()?;
40 let api = api.repo(repo);
41 let config = api.get("config.json")?;
42 let tokenizer = api.get("tokenizer.json")?;
43 let weights = api.get("model.safetensors")?;
44 (config, tokenizer, weights)
45 };
46 let config_str = std::fs::read_to_string(config_filename)?;
47 let config_str = config_str
48 .replace("SwiGLU", "swiglu")
49 .replace("Swiglu", "swiglu");
50 let config: AMPLIFYConfig = serde_json::from_str(&config_str)?;
51 let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
52 let vb = unsafe {
53 VarBuilder::from_mmaped_safetensors(&[weights_filename], AMPLIFY_DTYPE, &device)?
54 };
55 let model = AMPLIFY::load(vb, &config)?;
56 Ok(AmplifyRunner { model, tokenizer })
57 }
58 pub fn run_forward(&self, prot_sequence: &str) -> Result<ModelOutput> {
59 let device = self.model.get_device();
60 let tokens = self
61 .tokenizer
62 .encode(prot_sequence.to_string(), false)
63 .map_err(E::msg)?
64 .get_ids()
65 .to_vec();
66 let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
67 let encoded = self.model.forward(&token_ids, None, false, true)?;
68 Ok(encoded)
69 }
70 pub fn get_best_prediction(
71 &self,
72 prot_sequence: &str,
73 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
74 let model_output = self.run_forward(prot_sequence)?;
75 let predictions = model_output.logits.argmax(D::Minus1)?;
76 let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
77 let decoded = self.tokenizer.decode(indices.as_slice(), true)?;
78 let decoded = decoded.replace(" ", "");
79 Ok(decoded)
80 }
81 pub fn get_pseudo_probabilities(&self, prot_sequence: &str) -> Result<Vec<PseudoProbability>> {
82 let model_output = self.run_forward(prot_sequence)?;
83 let predictions = model_output.logits;
84 let outputs = self.extract_logits(&predictions)?;
85 Ok(outputs)
86 }
87 pub fn get_contact_map(&self, prot_sequence: &str) -> Result<Vec<ContactMap>> {
88 let model_output = self.run_forward(prot_sequence)?;
89 let contact_map_tensor = model_output.get_contact_map()?;
90 let averaged = contact_map_tensor.clone().unwrap().max_keepdim(D::Minus1)?;
91 let (position1, position2, val) = averaged.dims3()?;
92 let data = averaged.to_vec3::<f32>()?;
93
94 let mut contacts = Vec::new();
95 for i in 0..position1 {
96 for j in 0..position2 {
97 for k in 0..val {
98 contacts.push(ContactMap {
99 position_1: i,
100 amino_acid_1: self
101 .tokenizer
102 .decode(&[i as u32], true)
103 .ok()
104 .and_then(|s| s.chars().next())
105 .unwrap_or('?'),
106 position_2: j,
107 amino_acid_2: self
108 .tokenizer
109 .decode(&[i as u32], true)
110 .ok()
111 .and_then(|s| s.chars().next())
112 .unwrap_or('?'),
113 contact_estimate: data[i][j][k],
114 layer: 1,
115 });
116 }
117 }
118 }
119 Ok(contacts)
120 }
121 fn extract_logits(&self, tensor: &Tensor) -> Result<Vec<PseudoProbability>> {
123 let tensor = ops::softmax(tensor, D::Minus1)?;
124 let data = tensor.to_vec3::<f32>()?;
125 let (_, seq_len, vocab_size) = tensor.dims3()?;
126 let mut logit_positions = Vec::with_capacity(seq_len * vocab_size);
127 for seq_pos in 0..seq_len {
128 for vocab_idx in 0..vocab_size {
129 let score = data[0][seq_pos][vocab_idx];
130 let amino_acid_char = self
131 .tokenizer
132 .decode(&[vocab_idx as u32], false)
133 .map_err(|e| anyhow!("Failed to decode: {}", e))?
134 .chars()
135 .next()
136 .ok_or_else(|| anyhow!("Empty decoded string"))?;
137 logit_positions.push(PseudoProbability {
138 position: seq_pos,
139 amino_acid: amino_acid_char,
140 pseudo_prob: score,
141 });
142 }
143 }
144 Ok(logit_positions)
145 }
146}