// ferritin_plms/amplify/amplify_runner.rs

//! Amplify Runner
//!
//! Loading and running the AMPLIFY protein language models.
4
5use super::super::types::{ContactMap, PseudoProbability};
6use super::amplify::{AMPLIFY, ModelOutput};
7use super::config::AMPLIFYConfig;
8use anyhow::{Error as E, Result, anyhow};
9use candle_core::{D, DType, Device, Tensor};
10use candle_nn::VarBuilder;
11use candle_nn::ops;
12use hf_hub::{Repo, RepoType, api::sync::Api};
13use tokenizers::Tokenizer;
14
15const AMPLIFY_DTYPE: DType = DType::F32;
16
/// Available AMPLIFY model checkpoints hosted on the Hugging Face Hub.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AmplifyModels {
    /// 120M-parameter AMPLIFY model.
    AMP120M,
    /// 350M-parameter AMPLIFY model.
    AMP350M,
}
impl AmplifyModels {
    /// Returns the Hugging Face Hub `(repo_id, revision)` pair for `model`.
    pub fn get_model_files(model: Self) -> (&'static str, &'static str) {
        match model {
            AmplifyModels::AMP120M => ("chandar-lab/AMPLIFY_120M", "main"),
            AmplifyModels::AMP350M => ("chandar-lab/AMPLIFY_350M", "main"),
        }
    }
}
29
/// Holds a loaded AMPLIFY model together with its tokenizer so protein
/// sequences can be encoded and run through the network.
pub struct AmplifyRunner {
    /// The loaded AMPLIFY network.
    model: AMPLIFY,
    /// Tokenizer matching the model's vocabulary (from `tokenizer.json`).
    tokenizer: Tokenizer,
}
34impl AmplifyRunner {
    /// Downloads (or reads from the local HF cache) the config, tokenizer and
    /// safetensors weights for `modeltype`, then constructs an [`AmplifyRunner`]
    /// with the model placed on `device`.
    ///
    /// # Errors
    /// Returns an error if the Hub download fails, the config or tokenizer
    /// cannot be parsed, or the weights cannot be loaded.
    pub fn load_model(modeltype: AmplifyModels, device: Device) -> Result<AmplifyRunner> {
        let (model_id, revision) = AmplifyModels::get_model_files(modeltype);
        let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
        let (config_filename, tokenizer_filename, weights_filename) = {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = api.get("model.safetensors")?;
            (config, tokenizer, weights)
        };
        let config_str = std::fs::read_to_string(config_filename)?;
        // Normalize the activation-name casing so serde can deserialize the
        // `swiglu` variant regardless of how the upstream config spells it.
        let config_str = config_str
            .replace("SwiGLU", "swiglu")
            .replace("Swiglu", "swiglu");
        let config: AMPLIFYConfig = serde_json::from_str(&config_str)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        // SAFETY: mmap-ing the safetensors file is sound provided the file is
        // not mutated while mapped; it was just fetched into the HF cache and
        // is treated as read-only here.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_filename], AMPLIFY_DTYPE, &device)?
        };
        let model = AMPLIFY::load(vb, &config)?;
        Ok(AmplifyRunner { model, tokenizer })
    }
58    pub fn run_forward(&self, prot_sequence: &str) -> Result<ModelOutput> {
59        let device = self.model.get_device();
60        let tokens = self
61            .tokenizer
62            .encode(prot_sequence.to_string(), false)
63            .map_err(E::msg)?
64            .get_ids()
65            .to_vec();
66        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
67        let encoded = self.model.forward(&token_ids, None, false, true)?;
68        Ok(encoded)
69    }
70    pub fn get_best_prediction(
71        &self,
72        prot_sequence: &str,
73    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
74        let model_output = self.run_forward(prot_sequence)?;
75        let predictions = model_output.logits.argmax(D::Minus1)?;
76        let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
77        let decoded = self.tokenizer.decode(indices.as_slice(), true)?;
78        let decoded = decoded.replace(" ", "");
79        Ok(decoded)
80    }
81    pub fn get_pseudo_probabilities(&self, prot_sequence: &str) -> Result<Vec<PseudoProbability>> {
82        let model_output = self.run_forward(prot_sequence)?;
83        let predictions = model_output.logits;
84        let outputs = self.extract_logits(&predictions)?;
85        Ok(outputs)
86    }
87    pub fn get_contact_map(&self, prot_sequence: &str) -> Result<Vec<ContactMap>> {
88        let model_output = self.run_forward(prot_sequence)?;
89        let contact_map_tensor = model_output.get_contact_map()?;
90        let averaged = contact_map_tensor.clone().unwrap().max_keepdim(D::Minus1)?;
91        let (position1, position2, val) = averaged.dims3()?;
92        let data = averaged.to_vec3::<f32>()?;
93
94        let mut contacts = Vec::new();
95        for i in 0..position1 {
96            for j in 0..position2 {
97                for k in 0..val {
98                    contacts.push(ContactMap {
99                        position_1: i,
100                        amino_acid_1: self
101                            .tokenizer
102                            .decode(&[i as u32], true)
103                            .ok()
104                            .and_then(|s| s.chars().next())
105                            .unwrap_or('?'),
106                        position_2: j,
107                        amino_acid_2: self
108                            .tokenizer
109                            .decode(&[i as u32], true)
110                            .ok()
111                            .and_then(|s| s.chars().next())
112                            .unwrap_or('?'),
113                        contact_estimate: data[i][j][k],
114                        layer: 1,
115                    });
116                }
117            }
118        }
119        Ok(contacts)
120    }
121    // Softmax and simplify
122    fn extract_logits(&self, tensor: &Tensor) -> Result<Vec<PseudoProbability>> {
123        let tensor = ops::softmax(tensor, D::Minus1)?;
124        let data = tensor.to_vec3::<f32>()?;
125        let (_, seq_len, vocab_size) = tensor.dims3()?;
126        let mut logit_positions = Vec::with_capacity(seq_len * vocab_size);
127        for seq_pos in 0..seq_len {
128            for vocab_idx in 0..vocab_size {
129                let score = data[0][seq_pos][vocab_idx];
130                let amino_acid_char = self
131                    .tokenizer
132                    .decode(&[vocab_idx as u32], false)
133                    .map_err(|e| anyhow!("Failed to decode: {}", e))?
134                    .chars()
135                    .next()
136                    .ok_or_else(|| anyhow!("Empty decoded string"))?;
137                logit_positions.push(PseudoProbability {
138                    position: seq_pos,
139                    amino_acid: amino_acid_char,
140                    pseudo_prob: score,
141                });
142            }
143        }
144        Ok(logit_positions)
145    }
146}