Skip to main content

ferritin_plms/esmc/tokenization/
sequence_tokenizer.rs

1use crate::esmc::utils::constants::esm3::SEQUENCE_VOCAB;
2use anyhow::Result;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokenizers::models::bpe::{BPE, BpeBuilder};
6use tokenizers::processors::PostProcessorWrapper;
7use tokenizers::processors::template::{Template, TemplateProcessing};
8use tokenizers::{AddedToken, Tokenizer};
9
10pub trait EsmTokenizerBase {
11    fn encode(&self) -> Result<()>;
12    fn decode(&self) -> Result<()>;
13    fn mask_token(&self) -> &str;
14    fn mask_token_id(&self) -> u32;
15    fn bos_token(&self) -> &str;
16    fn bos_token_id(&self) -> u32;
17    fn eos_token(&self) -> &str;
18    fn eos_token_id(&self) -> u32;
19    fn pad_token(&self) -> &str;
20    fn pad_token_id(&self) -> u32;
21    fn chain_break_token(&self) -> &str;
22    fn chain_break_token_id(&self) -> u32;
23    fn all_token_ids(&self) -> Vec<u32>;
24    fn special_token_ids(&self) -> Vec<u32>;
25}
26
27pub struct EsmSequenceTokenizer {
28    tokenizer: Arc<Tokenizer>,
29    cb_token: String,
30}
31
32impl EsmSequenceTokenizer {
33    pub fn new(
34        unk_token: &str,
35        cls_token: &str,
36        pad_token: &str,
37        mask_token: &str,
38        eos_token: &str,
39        chain_break_token: &str,
40    ) -> Result<Self> {
41        let mut token_to_id = HashMap::new();
42        for (i, tok) in SEQUENCE_VOCAB.iter().enumerate() {
43            token_to_id.insert(tok.to_string(), i);
44        }
45        let bpe_builder = BpeBuilder::new();
46        let bpe: BPE = bpe_builder
47            .unk_token(unk_token.to_string())
48            .build()
49            .map_err(|e| anyhow::anyhow!("Failed to build BPE tokenizer: {}", e))?;
50
51        let mut tokenizer = Tokenizer::new(bpe);
52        let special_tokens = vec![
53            AddedToken::from(cls_token, true),
54            AddedToken::from(pad_token, true),
55            AddedToken::from(mask_token, true),
56            AddedToken::from(eos_token, true),
57            AddedToken::from(chain_break_token, true),
58        ];
59
60        let _ = tokenizer.add_special_tokens(special_tokens);
61
62        let post_processor = TemplateProcessing::builder()
63            .try_single(Template::try_from(format!("{} $A {}", cls_token, eos_token)).unwrap())?
64            .special_tokens(vec![
65                (cls_token, tokenizer.token_to_id(cls_token).unwrap()),
66                (eos_token, tokenizer.token_to_id(eos_token).unwrap()),
67            ])
68            .build()?;
69
70        tokenizer.with_post_processor(Some(PostProcessorWrapper::Template(post_processor)));
71
72        Ok(Self {
73            tokenizer: Arc::new(tokenizer),
74            cb_token: chain_break_token.to_string(),
75        })
76    }
77}
78impl Default for EsmSequenceTokenizer {
79    fn default() -> Self {
80        Self::new("<unk>", "<cls>", "<pad>", "<mask>", "<eos>", "|")
81            .expect("Failed to create default tokenizer")
82    }
83}
84
85impl EsmSequenceTokenizer {
86    /// Tokenize an amino-acid sequence string into token IDs.
87    ///
88    /// Looks up each character in `SEQUENCE_VOCAB`. Unknown characters map to
89    /// the `<unk>` token (index 3). When `add_special_tokens` is true, prepends
90    /// BOS (`<cls>` = 0) and appends EOS (`<eos>` = 2).
91    pub fn tokenize_sequence(&self, sequence: &str, add_special_tokens: bool) -> Vec<u32> {
92        use std::collections::HashMap;
93        let vocab: HashMap<&str, u32> = SEQUENCE_VOCAB
94            .iter()
95            .enumerate()
96            .map(|(i, s)| (*s, i as u32))
97            .collect();
98        let unk_id = *vocab.get("<unk>").unwrap_or(&3);
99
100        let mut tokens = Vec::with_capacity(sequence.len() + 2);
101        if add_special_tokens {
102            tokens.push(*vocab.get("<cls>").unwrap_or(&0));
103        }
104        for ch in sequence.chars() {
105            let s = ch.to_string();
106            let id = vocab.get(s.as_str()).copied().unwrap_or(unk_id);
107            tokens.push(id);
108        }
109        if add_special_tokens {
110            tokens.push(*vocab.get("<eos>").unwrap_or(&2));
111        }
112        tokens
113    }
114
115    /// Decode token IDs back to an amino-acid sequence string.
116    ///
117    /// Skips the standard special tokens (BOS=0, PAD=1, EOS=2, MASK=32) and
118    /// concatenates the remaining vocabulary entries.
119    pub fn decode_sequence(&self, token_ids: &[u32]) -> String {
120        const SPECIAL: [u32; 4] = [0, 1, 2, 32]; // cls, pad, eos, mask
121        let mut result = String::new();
122        for &id in token_ids {
123            if SPECIAL.contains(&id) {
124                continue;
125            }
126            if let Some(tok) = SEQUENCE_VOCAB.get(id as usize) {
127                result.push_str(tok);
128            }
129        }
130        result
131    }
132}
133
134impl EsmTokenizerBase for EsmSequenceTokenizer {
135    fn encode(&self) -> Result<()> {
136        todo!()
137    }
138
139    fn decode(&self) -> Result<()> {
140        todo!()
141    }
142
143    fn mask_token(&self) -> &str {
144        "mask"
145    }
146
147    fn mask_token_id(&self) -> u32 {
148        self.tokenizer.token_to_id("mask").unwrap_or(0)
149    }
150
151    fn bos_token(&self) -> &str {
152        unimplemented!()
153        // self.cls_token()
154    }
155
156    fn bos_token_id(&self) -> u32 {
157        unimplemented!()
158        // self.cls_token_id()
159    }
160
161    fn eos_token(&self) -> &str {
162        "eos"
163    }
164
165    fn eos_token_id(&self) -> u32 {
166        self.tokenizer.token_to_id("eos").unwrap_or(0)
167    }
168
169    fn pad_token(&self) -> &str {
170        "pad"
171    }
172
173    fn pad_token_id(&self) -> u32 {
174        self.tokenizer.token_to_id("pad").unwrap_or(0)
175    }
176
177    fn chain_break_token(&self) -> &str {
178        &self.cb_token
179    }
180
181    fn chain_break_token_id(&self) -> u32 {
182        self.tokenizer.token_to_id(&self.cb_token).unwrap_or(0)
183    }
184
185    fn all_token_ids(&self) -> Vec<u32> {
186        unimplemented!()
187        // (0..self.vocab_size()).collect()
188    }
189
190    fn special_token_ids(&self) -> Vec<u32> {
191        unimplemented!()
192        // self.tokenizer.get_special_token_ids()
193    }
194}