ferritin_plms/esm/tokenization/mod.rs
1pub mod sequence_tokenizer;
2use crate::esm::utils::constants::models::{ESM3_OPEN_SMALL, normalize_model_name};
3use anyhow::{Result, anyhow};
4use sequence_tokenizer::EsmSequenceTokenizer;
5
6pub struct TokenizerCollection {
7 pub sequence: EsmSequenceTokenizer,
8 // pub structure: structure_tokenizer::StructureTokenizer,
9 // pub secondary_structure: ss_tokenizer::SecondaryStructureTokenizer,
10 // pub sasa: sasa_tokenizer::SASADiscretizingTokenizer,
11 // pub function: function_tokenizer::InterProQuantizedTokenizer,
12 // pub residue_annotations: residue_tokenizer::ResidueAnnotationsTokenizer,
13}
14
15pub fn get_model_tokenizers(model: &str) -> Result<TokenizerCollection> {
16 if normalize_model_name(model) == ESM3_OPEN_SMALL {
17 Ok(TokenizerCollection {
18 sequence: EsmSequenceTokenizer::default(),
19 // structure: structure_tokenizer::StructureTokenizer::new()?,
20 // secondary_structure: ss_tokenizer::SecondaryStructureTokenizer::new("ss8")?,
21 // sasa: sasa_tokenizer::SASADiscretizingTokenizer::new()?,
22 // function: function_tokenizer::InterProQuantizedTokenizer::new()?,
23 // residue_annotations: residue_tokenizer::ResidueAnnotationsTokenizer::new()?,
24 })
25 } else {
26 Err(anyhow!("Unknown model: {}", model))
27 }
28}
29
// NOTE(review): `get_invalid_tokenizer_ids` is disabled. It depends on an
// `EsmTokenizerBase` trait (with `is_sequence_tokenizer` and the special-token-id
// accessors) that this module does not import. Re-enable it once that trait exists.
//
// pub fn get_invalid_tokenizer_ids(tokenizer: &impl EsmTokenizerBase) -> Vec<i64> {
//     if tokenizer.is_sequence_tokenizer() {
//         vec![
//             tokenizer.mask_token_id(),
//             tokenizer.pad_token_id(),
//             tokenizer.cls_token_id(),
//             tokenizer.eos_token_id(),
//         ]
//     } else {
//         vec![
//             tokenizer.mask_token_id(),
//             tokenizer.pad_token_id(),
//             tokenizer.bos_token_id(),
//             tokenizer.eos_token_id(),
//         ]
//     }
// }