ferritin_plms/esm/models/esmc.rs
use crate::esm::layers::regression_head::RegressionHead;
use crate::esm::layers::transformer_stack::TransformerStack;
// use crate::esm::pretrained::load_local_model;
// use crate::esm::sdk::api::ESMProtein;
// use crate::esm::sdk::api::ESMProteinTensor;
// use crate::esm::sdk::api::ForwardTrackData;
// use crate::esm::sdk::api::LogitsConfig;
// use crate::esm::sdk::api::LogitsOutput;
use crate::esm::tokenization::TokenizerCollection;
use crate::esm::tokenization::sequence_tokenizer::EsmSequenceTokenizer;
use candle_core::{Result, Tensor};
use candle_nn::{self as nn, VarBuilder};
// use crate::esm::utils::decoding::decode_sequence;
// use crate::esm::utils::encoding::tokenize_sequence;
// use crate::esm::utils::sampling::BatchedESMProteinTensor;

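/// Output of a forward pass: per-position sequence logits plus, optionally,
/// the final hidden-state embeddings.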
#[derive(Debug)]
struct ESMCOutput {
    sequence_logits: Tensor,
    embeddings: Option<Tensor>,
}

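/// Selects the tokenizer set for a pretrained checkpoint; only the
/// ESM3-open-small tokenizers are wired up so far.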
#[derive(Clone, Copy)]
pub enum ESMTokenizer {
    Esm3OpenSmall,
}

impl ESMTokenizer {
    pub fn get_model_tokenizers(&self) -> TokenizerCollection {
        match self {
            ESMTokenizer::Esm3OpenSmall => {
                let esm_tokenizer = EsmSequenceTokenizer::default();
                TokenizerCollection {
                    sequence: esm_tokenizer,
                }
            }
        }
    }
}

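/// Feed-forward network variant used in each transformer block.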
#[derive(Clone, Copy)]
pub enum FfnType {
    SWIGLU,
    GLU,
}

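/// Hyperparameters for an ESMC model; the first block of fields mirrors the
/// upstream config, the rest drive the transformer stack and regression head.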
#[derive(Clone)]
pub struct ESMCConfig {
    pub d_model: usize,
    pub n_heads: usize,
    pub n_layers: usize,
    pub v_head_transformer: Option<usize>,
    pub ffn_type: FfnType,
    pub tokenizer: ESMTokenizer,
    // Original upstream config fields above; transformer-stack options below.
    pub use_plain_attn: bool,
    pub n_layers_geom: usize,
    pub scale_residue: bool,
    pub residue_scaling_factor: f64,
    pub mask_and_zero_frameless: bool,
    pub bias: bool,
    pub qk_layernorm: bool,
    pub expansion_ratio: f64,
    // Regression head dimensions.
    pub regression_head_output_dim: usize,
    pub regression_head_hidden_dim: usize,
    pub embedding_dim: usize,
}

impl ESMCConfig {
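    /// Hyperparameters matching the 300M-parameter ESM-C checkpoint.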
    pub fn esmc_300m() -> Self {
        // When `scale_residue` is set, `residue_scaling_factor` is
        // (n_layers as f64 / 36.0).sqrt(); hard-coded below for n_layers = 30.
        Self {
            d_model: 960,
            n_heads: 15,
            n_layers: 30,
            v_head_transformer: None,
            ffn_type: FfnType::SWIGLU,
            tokenizer: ESMTokenizer::Esm3OpenSmall,
            use_plain_attn: true,
            n_layers_geom: 1,
            scale_residue: true,
            residue_scaling_factor: (30f64 / 36.).sqrt(),
            mask_and_zero_frameless: false,
            bias: false,
            qk_layernorm: true,
            expansion_ratio: 8.0 / 3.0,
            regression_head_output_dim: 64,
            regression_head_hidden_dim: 960, // d_model
            embedding_dim: 64,
        }
    }
}

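/// ESM-C protein language model: a token embedding feeding a transformer
/// stack, with a regression head that produces sequence logits.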
pub struct ESMC {
    embed: candle_nn::Embedding,
    transformer: TransformerStack,
    sequence_head: RegressionHead,
    tokenizer: EsmSequenceTokenizer,
}

impl ESMC {
    // pub fn new(
    //     d_model: usize,
    //     n_heads: usize,
    //     n_layers: usize,
    //     tokenizer: EsmSequenceTokenizer,
    //     vb: VarBuilder, // needed: candle's nn::embedding takes a VarBuilder
    // ) -> Result<Self> {
    //     Ok(Self {
    //         embed: nn::embedding(64, d_model, vb.pp("embed"))?,
    //         transformer: TransformerStack::new(d_model, n_heads, None, n_layers, 0)?,
    //         sequence_head: RegressionHead::new(d_model, 64)?,
    //         tokenizer,
    //     })
    // }

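    /// Builds the model from a `VarBuilder` rooted at the top of the
    /// checkpoint; the weight prefixes (`embed`, `transformer`,
    /// `sequence_head`) must match the checkpoint's tensor names.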
    pub fn load(vb: VarBuilder, config: ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model,
            tokenizer,
            embedding_dim,
            ..
        } = config;

        let tokenizer_collection = tokenizer.get_model_tokenizers();

        Ok(Self {
            embed: nn::embedding(embedding_dim, d_model, vb.pp("embed"))?,
            transformer: TransformerStack::load(vb.pp("transformer"), &config)?,
            sequence_head: RegressionHead::load(vb.pp("sequence_head"), &config)?,
            tokenizer: tokenizer_collection.sequence,
        })
    }

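    // Usage sketch (hypothetical file name; `VarBuilder::from_mmaped_safetensors`
    // is candle's loader for safetensors checkpoints):
    //
    //     use candle_core::{DType, Device};
    //     let device = Device::Cpu;
    //     let vb = unsafe {
    //         VarBuilder::from_mmaped_safetensors(&["esmc_300m.safetensors"], DType::F32, &device)?
    //     };
    //     let model = ESMC::load(vb, ESMCConfig::esmc_300m())?;
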
    // pub fn from_pretrained(model_name: impl Into<String>, device: Option<Device>) -> Result<Self> {
    //     let device = match device {
    //         Some(device) => device,
    //         None => Device::cuda_if_available(0)?,
    //     };
    //     let mut model = load_local_model(&model_name.into(), &device)?;
    //     if device.is_cuda() {
    //         // `to_dtype` returns a new value; rebind rather than drop it.
    //         model = model.to_dtype(DType::BF16)?;
    //     }
    //     Ok(model)
    // }

    // pub fn forward(
    //     &self,
    //     sequence_tokens: Option<&Tensor>,
    //     sequence_id: Option<&Tensor>,
    // ) -> Result<ESMCOutput> {
    //     let sequence_tokens = sequence_tokens.expect("sequence_tokens required");
    //     // When no id tensor is supplied, derive one from pad-token positions.
    //     let sequence_id = match sequence_id {
    //         Some(id) => id.clone(),
    //         None => sequence_tokens.eq(self.tokenizer.pad_token_id)?,
    //     };

    //     let x = self.embed.forward(sequence_tokens)?;
    //     let (x, _) = self.transformer.forward(&x, Some(&sequence_id))?;
    //     let sequence_logits = self.sequence_head.forward(&x)?;

    //     Ok(ESMCOutput {
    //         sequence_logits,
    //         embeddings: Some(x),
    //     })
    // }

    // pub fn encode(&self, input: &ESMProtein) -> Result<ESMProteinTensor> {
    //     let sequence_tokens = if let Some(seq) = &input.sequence {
    //         Some(tokenize_sequence(seq, &self.tokenizer, true)?)
    //     } else {
    //         None
    //     };

    //     Ok(ESMProteinTensor::new(sequence_tokens)?.to_device(&self.device())?)
    // }

    // pub fn decode(&self, input: &ESMProteinTensor) -> Result<ESMProtein> {
    //     let sequence = input
    //         .sequence
    //         .as_ref()
    //         .ok_or_else(|| candle_core::Error::Msg("missing sequence".into()))?;
    //     // Drop the leading and trailing special tokens before decoding.
    //     let inner = sequence.narrow(0, 1, sequence.dims1()? - 2)?;
    //     let sequence = decode_sequence(&inner, &self.tokenizer)?;
    //     Ok(ESMProtein::new(Some(sequence)))
    // }

    // pub fn logits(&self, input: &ESMProteinTensor, config: &LogitsConfig) -> Result<LogitsOutput> {
    //     let input = if !input.is_batched() {
    //         BatchedESMProteinTensor::from_protein_tensor(input)?
    //     } else {
    //         input.clone()
    //     };

    //     // candle does not track gradients by default, so no `no_grad`
    //     // guard is needed around inference.
    //     let output = self.forward(Some(&input.sequence), None)?;

    //     Ok(LogitsOutput {
    //         logits: ForwardTrackData {
    //             sequence: if config.sequence {
    //                 Some(output.sequence_logits)
    //             } else {
    //                 None
    //             },
    //         },
    //         embeddings: if config.return_embeddings {
    //             output.embeddings
    //         } else {
    //             None
    //         },
    //     })
    // }
}
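
// A minimal sanity check (a hypothetical test; it only exercises arithmetic
// already encoded in `esmc_300m`, no weights required):
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn esmc_300m_config_is_consistent() {
        let config = ESMCConfig::esmc_300m();
        // scale_residue implies residue_scaling_factor = sqrt(n_layers / 36).
        let expected = (config.n_layers as f64 / 36.0).sqrt();
        assert!((config.residue_scaling_factor - expected).abs() < f64::EPSILON);
        // d_model must split evenly across attention heads (960 / 15 = 64).
        assert_eq!(config.d_model % config.n_heads, 0);
    }
}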