ferritin_plms/esmc/tokenization/
sequence_tokenizer.rs1use crate::esmc::utils::constants::esm3::SEQUENCE_VOCAB;
2use anyhow::Result;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokenizers::models::bpe::{BPE, BpeBuilder};
6use tokenizers::processors::PostProcessorWrapper;
7use tokenizers::processors::template::{Template, TemplateProcessing};
8use tokenizers::{AddedToken, Tokenizer};
9
10pub trait EsmTokenizerBase {
11 fn encode(&self) -> Result<()>;
12 fn decode(&self) -> Result<()>;
13 fn mask_token(&self) -> &str;
14 fn mask_token_id(&self) -> u32;
15 fn bos_token(&self) -> &str;
16 fn bos_token_id(&self) -> u32;
17 fn eos_token(&self) -> &str;
18 fn eos_token_id(&self) -> u32;
19 fn pad_token(&self) -> &str;
20 fn pad_token_id(&self) -> u32;
21 fn chain_break_token(&self) -> &str;
22 fn chain_break_token_id(&self) -> u32;
23 fn all_token_ids(&self) -> Vec<u32>;
24 fn special_token_ids(&self) -> Vec<u32>;
25}
26
27pub struct EsmSequenceTokenizer {
28 tokenizer: Arc<Tokenizer>,
29 cb_token: String,
30}
31
32impl EsmSequenceTokenizer {
33 pub fn new(
34 unk_token: &str,
35 cls_token: &str,
36 pad_token: &str,
37 mask_token: &str,
38 eos_token: &str,
39 chain_break_token: &str,
40 ) -> Result<Self> {
41 let mut token_to_id = HashMap::new();
42 for (i, tok) in SEQUENCE_VOCAB.iter().enumerate() {
43 token_to_id.insert(tok.to_string(), i);
44 }
45 let bpe_builder = BpeBuilder::new();
46 let bpe: BPE = bpe_builder
47 .unk_token(unk_token.to_string())
48 .build()
49 .map_err(|e| anyhow::anyhow!("Failed to build BPE tokenizer: {}", e))?;
50
51 let mut tokenizer = Tokenizer::new(bpe);
52 let special_tokens = vec![
53 AddedToken::from(cls_token, true),
54 AddedToken::from(pad_token, true),
55 AddedToken::from(mask_token, true),
56 AddedToken::from(eos_token, true),
57 AddedToken::from(chain_break_token, true),
58 ];
59
60 let _ = tokenizer.add_special_tokens(special_tokens);
61
62 let post_processor = TemplateProcessing::builder()
63 .try_single(Template::try_from(format!("{} $A {}", cls_token, eos_token)).unwrap())?
64 .special_tokens(vec![
65 (cls_token, tokenizer.token_to_id(cls_token).unwrap()),
66 (eos_token, tokenizer.token_to_id(eos_token).unwrap()),
67 ])
68 .build()?;
69
70 tokenizer.with_post_processor(Some(PostProcessorWrapper::Template(post_processor)));
71
72 Ok(Self {
73 tokenizer: Arc::new(tokenizer),
74 cb_token: chain_break_token.to_string(),
75 })
76 }
77}
78impl Default for EsmSequenceTokenizer {
79 fn default() -> Self {
80 Self::new("<unk>", "<cls>", "<pad>", "<mask>", "<eos>", "|")
81 .expect("Failed to create default tokenizer")
82 }
83}
84
85impl EsmSequenceTokenizer {
86 pub fn tokenize_sequence(&self, sequence: &str, add_special_tokens: bool) -> Vec<u32> {
92 use std::collections::HashMap;
93 let vocab: HashMap<&str, u32> = SEQUENCE_VOCAB
94 .iter()
95 .enumerate()
96 .map(|(i, s)| (*s, i as u32))
97 .collect();
98 let unk_id = *vocab.get("<unk>").unwrap_or(&3);
99
100 let mut tokens = Vec::with_capacity(sequence.len() + 2);
101 if add_special_tokens {
102 tokens.push(*vocab.get("<cls>").unwrap_or(&0));
103 }
104 for ch in sequence.chars() {
105 let s = ch.to_string();
106 let id = vocab.get(s.as_str()).copied().unwrap_or(unk_id);
107 tokens.push(id);
108 }
109 if add_special_tokens {
110 tokens.push(*vocab.get("<eos>").unwrap_or(&2));
111 }
112 tokens
113 }
114
115 pub fn decode_sequence(&self, token_ids: &[u32]) -> String {
120 const SPECIAL: [u32; 4] = [0, 1, 2, 32]; let mut result = String::new();
122 for &id in token_ids {
123 if SPECIAL.contains(&id) {
124 continue;
125 }
126 if let Some(tok) = SEQUENCE_VOCAB.get(id as usize) {
127 result.push_str(tok);
128 }
129 }
130 result
131 }
132}
133
134impl EsmTokenizerBase for EsmSequenceTokenizer {
135 fn encode(&self) -> Result<()> {
136 todo!()
137 }
138
139 fn decode(&self) -> Result<()> {
140 todo!()
141 }
142
143 fn mask_token(&self) -> &str {
144 "mask"
145 }
146
147 fn mask_token_id(&self) -> u32 {
148 self.tokenizer.token_to_id("mask").unwrap_or(0)
149 }
150
151 fn bos_token(&self) -> &str {
152 unimplemented!()
153 }
155
156 fn bos_token_id(&self) -> u32 {
157 unimplemented!()
158 }
160
161 fn eos_token(&self) -> &str {
162 "eos"
163 }
164
165 fn eos_token_id(&self) -> u32 {
166 self.tokenizer.token_to_id("eos").unwrap_or(0)
167 }
168
169 fn pad_token(&self) -> &str {
170 "pad"
171 }
172
173 fn pad_token_id(&self) -> u32 {
174 self.tokenizer.token_to_id("pad").unwrap_or(0)
175 }
176
177 fn chain_break_token(&self) -> &str {
178 &self.cb_token
179 }
180
181 fn chain_break_token_id(&self) -> u32 {
182 self.tokenizer.token_to_id(&self.cb_token).unwrap_or(0)
183 }
184
185 fn all_token_ids(&self) -> Vec<u32> {
186 unimplemented!()
187 }
189
190 fn special_token_ids(&self) -> Vec<u32> {
191 unimplemented!()
192 }
194}