ferritin_plms/esm/tokenization/
mod.rs

1pub mod sequence_tokenizer;
2use crate::esm::utils::constants::models::{ESM3_OPEN_SMALL, normalize_model_name};
3use anyhow::{Result, anyhow};
4use sequence_tokenizer::EsmSequenceTokenizer;
5
6pub struct TokenizerCollection {
7    pub sequence: EsmSequenceTokenizer,
8    // pub structure: structure_tokenizer::StructureTokenizer,
9    // pub secondary_structure: ss_tokenizer::SecondaryStructureTokenizer,
10    // pub sasa: sasa_tokenizer::SASADiscretizingTokenizer,
11    // pub function: function_tokenizer::InterProQuantizedTokenizer,
12    // pub residue_annotations: residue_tokenizer::ResidueAnnotationsTokenizer,
13}
14
15pub fn get_model_tokenizers(model: &str) -> Result<TokenizerCollection> {
16    if normalize_model_name(model) == ESM3_OPEN_SMALL {
17        Ok(TokenizerCollection {
18            sequence: EsmSequenceTokenizer::default(),
19            // structure: structure_tokenizer::StructureTokenizer::new()?,
20            // secondary_structure: ss_tokenizer::SecondaryStructureTokenizer::new("ss8")?,
21            // sasa: sasa_tokenizer::SASADiscretizingTokenizer::new()?,
22            // function: function_tokenizer::InterProQuantizedTokenizer::new()?,
23            // residue_annotations: residue_tokenizer::ResidueAnnotationsTokenizer::new()?,
24        })
25    } else {
26        Err(anyhow!("Unknown model: {}", model))
27    }
28}
29
30// pub fn get_invalid_tokenizer_ids(tokenizer: &impl EsmTokenizerBase) -> Vec<i64> {
31//     if tokenizer.is_sequence_tokenizer() {
32//         vec![
33//             tokenizer.mask_token_id(),
34//             tokenizer.pad_token_id(),
35//             tokenizer.cls_token_id(),
36//             tokenizer.eos_token_id(),
37//         ]
38//     } else {
39//         vec![
40//             tokenizer.mask_token_id(),
41//             tokenizer.pad_token_id(),
42//             tokenizer.bos_token_id(),
43//             tokenizer.eos_token_id(),
44//         ]
45//     }
46// }