// ferritin_plms/esm/tokenization/sequence_tokenizer.rs
use crate::esm::utils::constants::esm3::SEQUENCE_VOCAB;
2use anyhow::Result;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokenizers::models::bpe::{BPE, BpeBuilder};
6use tokenizers::processors::PostProcessorWrapper;
7use tokenizers::processors::template::{Template, TemplateProcessing};
8use tokenizers::{AddedToken, Tokenizer};
9
/// Common interface for ESM tokenizers: special-token accessors plus
/// encode/decode entry points.
pub trait EsmTokenizerBase {
    /// Encode input into token ids.
    /// NOTE(review): takes no input and returns no ids — the signature looks
    /// like a placeholder; confirm the intended shape before implementing.
    fn encode(&self) -> Result<()>;
    /// Decode token ids back into text.
    /// NOTE(review): placeholder signature as well — see `encode`.
    fn decode(&self) -> Result<()>;
    /// The mask token string (e.g. "<mask>" in the default vocabulary).
    fn mask_token(&self) -> &str;
    /// Vocabulary id of the mask token.
    fn mask_token_id(&self) -> u32;
    /// The beginning-of-sequence token string.
    fn bos_token(&self) -> &str;
    /// Vocabulary id of the beginning-of-sequence token.
    fn bos_token_id(&self) -> u32;
    /// The end-of-sequence token string.
    fn eos_token(&self) -> &str;
    /// Vocabulary id of the end-of-sequence token.
    fn eos_token_id(&self) -> u32;
    /// The padding token string.
    fn pad_token(&self) -> &str;
    /// Vocabulary id of the padding token.
    fn pad_token_id(&self) -> u32;
    /// The chain-break token string (separates protein chains).
    fn chain_break_token(&self) -> &str;
    /// Vocabulary id of the chain-break token.
    fn chain_break_token_id(&self) -> u32;
    /// Ids of every token in the vocabulary.
    fn all_token_ids(&self) -> Vec<u32>;
    /// Ids of the special tokens only.
    fn special_token_ids(&self) -> Vec<u32>;
}
26
/// Tokenizer for ESM3 amino-acid sequences, backed by a `tokenizers` BPE
/// model built over `SEQUENCE_VOCAB`.
pub struct EsmSequenceTokenizer {
    // Shared handle to the underlying HuggingFace tokenizer.
    tokenizer: Arc<Tokenizer>,
    // Chain-break token string (e.g. "|"), kept so id lookups can use it.
    cb_token: String,
}
31
32impl EsmSequenceTokenizer {
33 pub fn new(
34 unk_token: &str,
35 cls_token: &str,
36 pad_token: &str,
37 mask_token: &str,
38 eos_token: &str,
39 chain_break_token: &str,
40 ) -> Result<Self> {
41 let mut token_to_id = HashMap::new();
42 for (i, tok) in SEQUENCE_VOCAB.iter().enumerate() {
43 token_to_id.insert(tok.to_string(), i);
44 }
45 let bpe_builder = BpeBuilder::new();
46 let bpe: BPE = bpe_builder
47 .unk_token(unk_token.to_string())
48 .build()
49 .map_err(|e| anyhow::anyhow!("Failed to build BPE tokenizer: {}", e))?;
50
51 let mut tokenizer = Tokenizer::new(bpe);
52 let special_tokens = vec![
53 AddedToken::from(cls_token, true),
54 AddedToken::from(pad_token, true),
55 AddedToken::from(mask_token, true),
56 AddedToken::from(eos_token, true),
57 AddedToken::from(chain_break_token, true),
58 ];
59
60 tokenizer.add_special_tokens(&special_tokens);
61
62 let post_processor = TemplateProcessing::builder()
63 .try_single(Template::try_from(format!("{} $A {}", cls_token, eos_token)).unwrap())?
64 .special_tokens(vec![
65 (cls_token, tokenizer.token_to_id(cls_token).unwrap()),
66 (eos_token, tokenizer.token_to_id(eos_token).unwrap()),
67 ])
68 .build()?;
69
70 tokenizer.with_post_processor(Some(PostProcessorWrapper::Template(post_processor)));
71
72 Ok(Self {
73 tokenizer: Arc::new(tokenizer),
74 cb_token: chain_break_token.to_string(),
75 })
76 }
77}
78impl Default for EsmSequenceTokenizer {
79 fn default() -> Self {
80 Self::new("<unk>", "<cls>", "<pad>", "<mask>", "<eos>", "|")
81 .expect("Failed to create default tokenizer")
82 }
83}
84
85impl EsmTokenizerBase for EsmSequenceTokenizer {
86 fn encode(&self) -> Result<()> {
87 todo!()
88 }
89
90 fn decode(&self) -> Result<()> {
91 todo!()
92 }
93
94 fn mask_token(&self) -> &str {
95 "mask"
96 }
97
98 fn mask_token_id(&self) -> u32 {
99 self.tokenizer.token_to_id("mask").unwrap_or(0)
100 }
101
102 fn bos_token(&self) -> &str {
103 unimplemented!()
104 }
106
107 fn bos_token_id(&self) -> u32 {
108 unimplemented!()
109 }
111
112 fn eos_token(&self) -> &str {
113 "eos"
114 }
115
116 fn eos_token_id(&self) -> u32 {
117 self.tokenizer.token_to_id("eos").unwrap_or(0)
118 }
119
120 fn pad_token(&self) -> &str {
121 "pad"
122 }
123
124 fn pad_token_id(&self) -> u32 {
125 self.tokenizer.token_to_id("pad").unwrap_or(0)
126 }
127
128 fn chain_break_token(&self) -> &str {
129 &self.cb_token
130 }
131
132 fn chain_break_token_id(&self) -> u32 {
133 self.tokenizer.token_to_id(&self.cb_token).unwrap_or(0)
134 }
135
136 fn all_token_ids(&self) -> Vec<u32> {
137 unimplemented!()
138 }
140
141 fn special_token_ids(&self) -> Vec<u32> {
142 unimplemented!()
143 }
145}