ferritin_plms/esm/models/esmc.rs

use crate::esm::layers::regression_head::RegressionHead;
use crate::esm::layers::transformer_stack::TransformerStack;
// use crate::esm::pretrained::load_local_model;
// use crate::esm::sdk::api::ESMProtein;
// use crate::esm::sdk::api::ESMProteinTensor;
// use crate::esm::sdk::api::ForwardTrackData;
// use crate::esm::sdk::api::LogitsConfig;
// use crate::esm::sdk::api::LogitsOutput;
use crate::esm::tokenization::TokenizerCollection;
use crate::esm::tokenization::sequence_tokenizer::EsmSequenceTokenizer;
use candle_core::{Result, Tensor};
use candle_nn::{self as nn, VarBuilder};
// use crate::esm::utils::decoding::decode_sequence;
// use crate::esm::utils::encoding::tokenize_sequence;
// use crate::esm::utils::sampling::BatchedESMProteinTensor;

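/// Output of an ESMC forward pass: per-position sequence logits and,
/// optionally, the final hidden-state embeddings.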
#[derive(Debug)]
struct ESMCOutput {
    sequence_logits: Tensor,
    embeddings: Option<Tensor>,
}

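/// Selects the tokenizer set matching a pretrained checkpoint; only the
/// ESM3-open-small sequence tokenizer is currently wired up.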
#[derive(Clone, Copy)]
pub enum ESMTokenizer {
    Esm3OpenSmall,
}

impl ESMTokenizer {
    pub fn get_model_tokenizers(&self) -> TokenizerCollection {
        match self {
            ESMTokenizer::Esm3OpenSmall => {
                let esm_tokenizer = EsmSequenceTokenizer::default();
                TokenizerCollection {
                    sequence: esm_tokenizer,
                }
            }
        }
    }
}

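/// Feed-forward network variant used in the transformer blocks.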
#[derive(Clone, Copy)]
pub enum FfnType {
    SWIGLU,
    GLU,
}

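/// Hyperparameters for an ESMC model; see [`ESMCConfig::esmc_300m`] for the
/// 300M preset.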
#[derive(Clone)]
pub struct ESMCConfig {
    pub d_model: usize,
    pub n_heads: usize,
    pub n_layers: usize,
    pub v_head_transformer: Option<usize>,
    pub ffn_type: FfnType,
    pub tokenizer: ESMTokenizer,
    // Fields above mirror the original ESMC config; those below are additions.
    pub use_plain_attn: bool,
    pub n_layers_geom: usize,
    pub scale_residue: bool,
    pub residue_scaling_factor: f64,
    pub mask_and_zero_frameless: bool,
    pub bias: bool,
    pub qk_layernorm: bool,
    pub expansion_ratio: f64,
    // Regression head.
    pub regression_head_output_dim: usize,
    pub regression_head_hidden_dim: usize,
    pub embedding_dim: usize,
}

impl ESMCConfig {
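    /// Preset matching the 300M ESMC model: d_model = 960, 30 layers,
    /// 15 attention heads, SwiGLU feed-forward, 64-entry token embedding table.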
    pub fn esmc_300m() -> Self {
        // residue_scaling_factor is sqrt(n_layers / 36) when scale_residue is
        // set, and 1.0 otherwise; for 30 layers that is sqrt(30 / 36).
        Self {
            d_model: 960,
            n_heads: 15,
            n_layers: 30,
            v_head_transformer: None,
            ffn_type: FfnType::SWIGLU,
            tokenizer: ESMTokenizer::Esm3OpenSmall,
            use_plain_attn: true,
            n_layers_geom: 1,
            scale_residue: true,
            residue_scaling_factor: (30f64 / 36.).sqrt(),
            mask_and_zero_frameless: false,
            bias: false,
            qk_layernorm: true,
            expansion_ratio: 8.0 / 3.0,
            regression_head_output_dim: 64,
            regression_head_hidden_dim: 960, // d_model
            embedding_dim: 64,
        }
    }
}

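/// ESMC protein language model: a token embedding, a transformer stack, and a
/// regression head projecting hidden states back to sequence logits.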
pub struct ESMC {
    embed: nn::Embedding,
    transformer: TransformerStack,
    sequence_head: RegressionHead,
    tokenizer: EsmSequenceTokenizer,
}

impl ESMC {
    // pub fn new(
    //     vb: VarBuilder,
    //     d_model: usize,
    //     n_heads: usize,
    //     n_layers: usize,
    //     tokenizer: EsmSequenceTokenizer,
    // ) -> Result<Self> {
    //     Ok(Self {
    //         embed: nn::embedding(64, d_model, vb.pp("embed"))?,
    //         transformer: TransformerStack::new(d_model, n_heads, None, n_layers, 0)?,
    //         sequence_head: RegressionHead::new(d_model, 64)?,
    //         tokenizer,
    //     })
    // }

    pub fn load(vb: VarBuilder, config: ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model,
            tokenizer,
            embedding_dim,
            ..
        } = config;

        let tokenizer_collection = tokenizer.get_model_tokenizers();

        Ok(Self {
            embed: nn::embedding(embedding_dim, d_model, vb.pp("embed"))?,
            transformer: TransformerStack::load(vb.pp("transformer"), &config)?,
            sequence_head: RegressionHead::load(vb.pp("sequence_head"), &config)?,
            tokenizer: tokenizer_collection.sequence,
        })
    }

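    // A minimal usage sketch for `load`, assuming a local safetensors file whose
    // tensor names match this module's weight paths ("embed", "transformer",
    // "sequence_head"). The file name and dtype here are illustrative, not part
    // of this crate's API:
    //
    //     use candle_core::{DType, Device};
    //     let device = Device::Cpu;
    //     let vb = unsafe {
    //         VarBuilder::from_mmaped_safetensors(&["esmc_300m.safetensors"], DType::F32, &device)?
    //     };
    //     let model = ESMC::load(vb, ESMCConfig::esmc_300m())?;
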
    // pub fn from_pretrained(model_name: impl Into<String>, device: Option<Device>) -> Result<Self> {
    //     let device = device.unwrap_or(Device::cuda_if_available(0)?);
    //     let model = load_local_model(&model_name.into(), &device)?;
    //     let model = if device.is_cuda() {
    //         model.to_dtype(DType::BF16)?
    //     } else {
    //         model
    //     };
    //     Ok(model)
    // }

    // pub fn forward(
    //     &self,
    //     sequence_tokens: Option<&Tensor>,
    //     sequence_id: Option<&Tensor>,
    // ) -> Result<ESMCOutput> {
    //     let sequence_tokens = sequence_tokens.expect("sequence_tokens are required");
    //     // Derive the mask from pad-token positions when none is supplied.
    //     let sequence_id = match sequence_id {
    //         Some(id) => id.clone(),
    //         None => sequence_tokens.eq(self.tokenizer.pad_token_id)?,
    //     };

    //     let x = self.embed.forward(sequence_tokens)?;
    //     let (x, _) = self.transformer.forward(&x, Some(&sequence_id))?;
    //     let sequence_logits = self.sequence_head.forward(&x)?;

    //     Ok(ESMCOutput {
    //         sequence_logits,
    //         embeddings: Some(x),
    //     })
    // }

    // pub fn encode(&self, input: &ESMProtein) -> Result<ESMProteinTensor> {
    //     let sequence_tokens = if let Some(seq) = &input.sequence {
    //         Some(tokenize_sequence(seq, &self.tokenizer, true)?)
    //     } else {
    //         None
    //     };

    //     Ok(ESMProteinTensor::new(sequence_tokens)?.to_device(&self.device())?)
    // }

    // pub fn decode(&self, input: &ESMProteinTensor) -> Result<ESMProtein> {
    //     let sequence = input
    //         .sequence
    //         .as_ref()
    //         .ok_or_else(|| candle_core::Error::Msg("missing sequence".to_string()))?;
    //     // Drop the leading and trailing special tokens before decoding.
    //     let sequence = decode_sequence(&sequence.narrow(0, 1, sequence.dim(0)? - 2)?, &self.tokenizer)?;
    //     Ok(ESMProtein::new(Some(sequence)))
    // }

    // pub fn logits(&self, input: &ESMProteinTensor, config: &LogitsConfig) -> Result<LogitsOutput> {
    //     let input = if !input.is_batched() {
    //         BatchedESMProteinTensor::from_protein_tensor(input)?
    //     } else {
    //         input.clone()
    //     };

    //     // candle does not record gradients unless `backward` is called, so no
    //     // `no_grad` guard is needed here.
    //     let output = self.forward(Some(&input.sequence), None)?;

    //     Ok(LogitsOutput {
    //         logits: ForwardTrackData {
    //             sequence: if config.sequence {
    //                 Some(output.sequence_logits)
    //             } else {
    //                 None
    //             },
    //         },
    //         embeddings: if config.return_embeddings {
    //             output.embeddings
    //         } else {
    //             None
    //         },
    //     })
    // }
}