ferritin_plms/esm/tokenization/sequence_tokenizer.rs
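
//! Tokenizer for ESM-3 amino-acid sequences. Wraps a merge-free BPE model
//! built over `SEQUENCE_VOCAB` and frames every encoded sequence with
//! `<cls>`/`<eos>` via a template post-processor.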

use crate::esm::utils::constants::esm3::SEQUENCE_VOCAB;
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tokenizers::models::bpe::{BPE, BpeBuilder};
use tokenizers::processors::PostProcessorWrapper;
use tokenizers::processors::template::{Template, TemplateProcessing};
use tokenizers::{AddedToken, Encoding, Tokenizer};

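/// Common interface for ESM tokenizers: sequence encoding/decoding plus
/// access to the special tokens and their ids.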
pub trait EsmTokenizerBase {
    fn encode(&self, sequence: &str) -> Result<Encoding>;
    fn decode(&self, ids: &[u32]) -> Result<String>;
    fn mask_token(&self) -> &str;
    fn mask_token_id(&self) -> u32;
    fn bos_token(&self) -> &str;
    fn bos_token_id(&self) -> u32;
    fn eos_token(&self) -> &str;
    fn eos_token_id(&self) -> u32;
    fn pad_token(&self) -> &str;
    fn pad_token_id(&self) -> u32;
    fn chain_break_token(&self) -> &str;
    fn chain_break_token_id(&self) -> u32;
    fn all_token_ids(&self) -> Vec<u32>;
    fn special_token_ids(&self) -> Vec<u32>;
}

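/// Character-level tokenizer for amino-acid sequences: each residue is a
/// single token from `SEQUENCE_VOCAB`, and a chain-break token ("|" by
/// default) separates the chains of a multimer. The special-token strings
/// are stored so they can be returned without re-querying the tokenizer.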
pub struct EsmSequenceTokenizer {
    tokenizer: Arc<Tokenizer>,
    cls_token: String,
    pad_token: String,
    mask_token: String,
    eos_token: String,
    cb_token: String,
}

impl EsmSequenceTokenizer {
    pub fn new(
        unk_token: &str,
        cls_token: &str,
        pad_token: &str,
        mask_token: &str,
        eos_token: &str,
        chain_break_token: &str,
    ) -> Result<Self> {
        // Build the base vocabulary from the fixed ESM-3 sequence alphabet.
        let mut token_to_id = HashMap::new();
        for (i, tok) in SEQUENCE_VOCAB.iter().enumerate() {
            token_to_id.insert(tok.to_string(), i as u32);
        }

        // With no merges, the BPE model degenerates to a per-character lookup
        // over the vocabulary, which is exactly what a per-residue tokenizer
        // needs. Note the vocabulary must actually be wired into the builder;
        // otherwise every residue would encode as `unk_token`.
        let bpe: BPE = BpeBuilder::new()
            .vocab_and_merges(token_to_id, vec![])
            .unk_token(unk_token.to_string())
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build BPE tokenizer: {}", e))?;

        let mut tokenizer = Tokenizer::new(bpe);
        let special_tokens = vec![
            AddedToken::from(cls_token, true),
            AddedToken::from(pad_token, true),
            AddedToken::from(mask_token, true),
            AddedToken::from(eos_token, true),
            AddedToken::from(chain_break_token, true),
        ];
        tokenizer.add_special_tokens(&special_tokens);

        // Frame every encoded sequence as `<cls> $A <eos>`.
        let post_processor = TemplateProcessing::builder()
            .try_single(
                Template::try_from(format!("{} $A {}", cls_token, eos_token))
                    .map_err(|e| anyhow::anyhow!("Invalid template: {}", e))?,
            )?
            .special_tokens(vec![
                (cls_token, tokenizer.token_to_id(cls_token).unwrap()),
                (eos_token, tokenizer.token_to_id(eos_token).unwrap()),
            ])
            .build()?;
        tokenizer.with_post_processor(Some(PostProcessorWrapper::Template(post_processor)));

        Ok(Self {
            tokenizer: Arc::new(tokenizer),
            cls_token: cls_token.to_string(),
            pad_token: pad_token.to_string(),
            mask_token: mask_token.to_string(),
            eos_token: eos_token.to_string(),
            cb_token: chain_break_token.to_string(),
        })
    }
}

impl Default for EsmSequenceTokenizer {
    fn default() -> Self {
        Self::new("<unk>", "<cls>", "<pad>", "<mask>", "<eos>", "|")
            .expect("Failed to create default tokenizer")
    }
}

impl EsmTokenizerBase for EsmSequenceTokenizer {
    fn encode(&self, sequence: &str) -> Result<Encoding> {
        self.tokenizer
            .encode(sequence, true)
            .map_err(|e| anyhow::anyhow!("Failed to encode sequence: {}", e))
    }

    fn decode(&self, ids: &[u32]) -> Result<String> {
        self.tokenizer
            .decode(ids, true)
            .map_err(|e| anyhow::anyhow!("Failed to decode ids: {}", e))
    }

    fn mask_token(&self) -> &str {
        &self.mask_token
    }

    fn mask_token_id(&self) -> u32 {
        self.tokenizer.token_to_id(&self.mask_token).unwrap_or(0)
    }

    // ESM uses the CLS token as its beginning-of-sequence token.
    fn bos_token(&self) -> &str {
        &self.cls_token
    }

    fn bos_token_id(&self) -> u32 {
        self.tokenizer.token_to_id(&self.cls_token).unwrap_or(0)
    }

    fn eos_token(&self) -> &str {
        &self.eos_token
    }

    fn eos_token_id(&self) -> u32 {
        self.tokenizer.token_to_id(&self.eos_token).unwrap_or(0)
    }

    fn pad_token(&self) -> &str {
        &self.pad_token
    }

    fn pad_token_id(&self) -> u32 {
        self.tokenizer.token_to_id(&self.pad_token).unwrap_or(0)
    }

    fn chain_break_token(&self) -> &str {
        &self.cb_token
    }

    fn chain_break_token_id(&self) -> u32 {
        self.tokenizer.token_to_id(&self.cb_token).unwrap_or(0)
    }

    fn all_token_ids(&self) -> Vec<u32> {
        (0..self.tokenizer.get_vocab_size(true) as u32).collect()
    }

    fn special_token_ids(&self) -> Vec<u32> {
        // `tokenizers` exposes no direct accessor for special-token ids, so
        // look up each registered special token explicitly.
        [
            self.cls_token.as_str(),
            self.pad_token.as_str(),
            self.mask_token.as_str(),
            self.eos_token.as_str(),
            self.cb_token.as_str(),
        ]
        .into_iter()
        .filter_map(|tok| self.tokenizer.token_to_id(tok))
        .collect()
    }
}
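
// A minimal round-trip sketch, assuming `SEQUENCE_VOCAB` contains the
// standard one-letter amino-acid codes as single-character tokens; the
// sequence below is illustrative, not taken from any real protein.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encodes_with_cls_and_eos() -> Result<()> {
        let tokenizer = EsmSequenceTokenizer::default();
        let encoding = tokenizer.encode("MKTAYIAK")?;
        // The template post-processor frames the residues as <cls> ... <eos>,
        // so the output is two tokens longer than the input sequence.
        assert_eq!(encoding.get_ids().len(), "MKTAYIAK".len() + 2);
        assert_eq!(encoding.get_ids()[0], tokenizer.bos_token_id());
        assert_eq!(*encoding.get_ids().last().unwrap(), tokenizer.eos_token_id());
        Ok(())
    }
}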