ferritin_plms/ligandmpnn/
configs.rs

//! PMPNN Core Config and Builder API
//!
//! This module provides configuration structs and builders for the PMPNN protein design system.
//! These are the core configs for handling CLI args and model params.
//!
//! # Core Configuration Types
//!
//! - `ModelTypes` - Enum of supported model architectures
//! - `ProteinMPNNConfig` - Core model parameters
//! - `AABiasConfig` - Amino acid biasing controls
//! - `LigandMPNNConfig` - LigandMPNN specific settings
//! - `MembraneMPNNConfig` - MembraneMPNN specific settings
//! - `MultiPDBConfig` - Multi-PDB mode configuration
//! - `ResidueControl` - Residue-level design controls
//! - `RunConfig` - Runtime execution parameters
15
16use super::model::ProteinMPNN;
17use super::proteinfeatures::ProteinFeatures;
18use crate::StructureFeatures;
19use anyhow::Error;
20use candle_core::pickle::PthTensors;
21use candle_core::{DType, Device, Tensor};
22use candle_nn::VarBuilder;
23use clap::ValueEnum;
24use ferritin_core::load_structure;
25use ferritin_test_data::TestFile;
26
27/// Responsible for taking CLI args and returning the Features and Model
28///
29pub struct MPNNExecConfig {
30    pub(crate) protein_inputs: String, // Todo: make this optionally plural
31    pub(crate) run_config: RunConfig,
32    pub(crate) aabias_config: Option<AABiasConfig>,
33    pub(crate) ligand_mpnn_config: Option<LigandMPNNConfig>,
34    pub(crate) membrane_mpnn_config: Option<MembraneMPNNConfig>,
35    pub(crate) multi_pdb_config: Option<MultiPDBConfig>,
36    pub(crate) residue_control_config: Option<ResidueControl>,
37    pub(crate) device: Device,
38}
39
40impl MPNNExecConfig {
41    pub fn new(
42        device: Device,
43        pdb_path: String,
44        run_config: RunConfig,
45        residue_config: Option<ResidueControl>,
46        aa_bias: Option<AABiasConfig>,
47        lig_mpnn_specific: Option<LigandMPNNConfig>,
48        membrane_mpnn_specific: Option<MembraneMPNNConfig>,
49        multi_pdb_specific: Option<MultiPDBConfig>,
50    ) -> Result<Self, Error> {
51        Ok(MPNNExecConfig {
52            protein_inputs: pdb_path,
53            run_config,
54            aabias_config: aa_bias,
55            ligand_mpnn_config: lig_mpnn_specific,
56            membrane_mpnn_config: membrane_mpnn_specific,
57            residue_control_config: residue_config,
58            multi_pdb_config: multi_pdb_specific,
59            device,
60        })
61    }
62    // Todo: refactor this to use loader.
63    pub fn load_model(&self, model_type: ModelTypes) -> Result<ProteinMPNN, Error> {
64        let default_dtype = DType::F32;
65        match model_type {
66            ModelTypes::ProteinMPNN => {
67                // this is a hidden dep....
68                // todo: use hf_hub
69                let (mpnn_file, _handle) = TestFile::ligmpnn_pmpnn_01().create_temp()?;
70                let pth = PthTensors::new(mpnn_file, Some("model_state_dict"))?;
71                let vb =
72                    VarBuilder::from_backend(Box::new(pth), default_dtype, self.device.clone());
73                let pconf = ProteinMPNNConfig::proteinmpnn();
74                Ok(ProteinMPNN::load(vb, &pconf).expect("Unable to load the PMPNN Model"))
75            }
76            _ => panic!("not implented!"),
77        }
78    }
79    pub fn generate_model(self) {
80        todo!()
81    }
82    pub fn generate_protein_features(&self) -> Result<ProteinFeatures, Error> {
83        let device = self.device.clone();
84        let base_dtype = DType::F32;
85
86        // init the Protein Features
87        let ac = load_structure(self.protein_inputs.clone())?;
88
89        let s = ac
90            .encode_amino_acids(&device)
91            .expect("A complete convertion to locations");
92        let x_37 = ac.to_numeric_atom37(&device)?;
93        let x_37_mask = Tensor::ones((x_37.dim(0)?, x_37.dim(1)?), base_dtype, &device)?;
94        let (y, y_t, y_m) = ac.to_numeric_ligand_atoms(&device)?;
95        let res_idx = ac.get_res_index();
96        let res_idx_len = res_idx.len();
97        let res_idx_tensor = Tensor::from_vec(res_idx, (1, res_idx_len), &device)?;
98
99        // chain residues
100        let chain_letters: Vec<String> = ac
101            .iter_residues_aminoacid()
102            .map(|res| res.chain_id().to_string())
103            .collect();
104
105        // unique Chains
106        let chain_list: Vec<String> = chain_letters
107            .clone()
108            .into_iter()
109            .collect::<std::collections::HashSet<_>>()
110            .into_iter()
111            .collect();
112
113        // assert_eq!(true, false);
114
115        // update residue info
116        // residue_config: Option<ResidueControl>,
117        // handle these:
118        // pub fixed_residues: Option<String>,
119        // pub redesigned_residues: Option<String>,
120        // pub symmetry_residues: Option<String>,
121        // pub symmetry_weights: Option<String>,
122        // pub chains_to_design: Option<String>,
123        // pub parse_these_chains_only: Option<String>,
124
125        // update AA bias
126        // handle these:
127        // aa_bias: Option<AABiasConfig>,
128        // pub bias_aa: Option<String>,
129        // pub bias_aa_per_residue: Option<String>,
130        // pub omit_aa: Option<String>,
131        // pub omit_aa_per_residue: Option<String>,
132
133        // update LigmpnnConfif
134        // lig_mpnn_specific: Option<LigandMPNNConfig>,
135        // handle these:
136        // pub checkpoint_ligand_mpnn: Option<String>,
137        // pub ligand_mpnn_use_atom_context: Option<i32>,
138        // pub ligand_mpnn_use_side_chain_context: Option<i32>,
139        // pub ligand_mpnn_cutoff_for_score: Option<String>,
140
141        // update Membrane MPNN Config
142        // membrane_mpnn_specific: Option<MembraneMPNNConfig>,
143        // handle these:
144        // pub global_transmembrane_label: Option<i32>,
145        // pub transmembrane_buried: Option<String>,
146        // pub transmembrane_interface: Option<String>,
147
148        // update multipdb
149        // multi_pdb_specific: Option<MultiPDBConfig>,
150        // pub pdb_path_multi: Option<String>,
151        // pub fixed_residues_multi: Option<String>,
152        // pub redesigned_residues_multi: Option<String>,
153        // pub omit_aa_per_residue_multi: Option<String>,
154        // pub bias_aa_per_residue_multi: Option<String>,
155
156        // println!("Returning Protein Features....");
157        // return ligand MPNN.
158        Ok(ProteinFeatures {
159            s,                       // protein amino acids sequences as 1D Tensor of u32
160            x: x_37,                 // protein co-oords by residue [1, 37, 4]
161            x_mask: Some(x_37_mask), // protein mask by residue
162            y,                       // ligand coords
163            y_t,                     // encoded ligand atom names
164            y_m: Some(y_m),          // ligand mask
165            r_idx: res_idx_tensor,   // protein residue indices shape=[length]
166            chain_labels: None,      //  # protein chain letters shape=[length]
167            chain_letters,           // chain_letters: shape=[length]
168            mask_c: None,            // mask_c:  shape=[length]
169            chain_list,
170        })
171    }
172}
173
174#[derive(Debug, Clone, ValueEnum, Copy)]
175pub enum ModelTypes {
176    #[value(name = "protein_mpnn")]
177    ProteinMPNN,
178    #[value(name = "ligand_mpnn")]
179    LigandMPNN,
180}
181
/// Amino Acid Biasing
///
/// Every field is optional, so `AABiasConfig::default()` (all `None`) plus
/// struct-update syntax can be used to set only the options of interest.
/// The values are kept as raw strings; this file does not parse them.
#[derive(Debug, Clone, Default)]
pub struct AABiasConfig {
    /// Global amino-acid bias specification (raw string, format not defined here).
    pub bias_aa: Option<String>,
    /// Per-residue amino-acid bias specification (raw string).
    pub bias_aa_per_residue: Option<String>,
    /// Amino acids to omit globally (raw string).
    pub omit_aa: Option<String>,
    /// Amino acids to omit per residue (raw string).
    pub omit_aa_per_residue: Option<String>,
}
190
/// LigandMPNN Specific
///
/// Every field is optional, so `LigandMPNNConfig::default()` yields a config
/// with nothing set. `Debug` is derived to match the other config structs in
/// this module.
#[derive(Debug, Clone, Default)]
pub struct LigandMPNNConfig {
    /// Checkpoint to use for LigandMPNN (raw string; presumably a path — confirm).
    pub checkpoint_ligand_mpnn: Option<String>,
    /// Integer switch for atom context (presumably 0/1 — confirm against the CLI).
    pub ligand_mpnn_use_atom_context: Option<i32>,
    /// Integer switch for side-chain context (presumably 0/1 — confirm against the CLI).
    pub ligand_mpnn_use_side_chain_context: Option<i32>,
    /// Cutoff used for scoring, kept as a raw string.
    pub ligand_mpnn_cutoff_for_score: Option<String>,
}
198
/// Membrane MPNN Specific
///
/// Every field is optional, so `MembraneMPNNConfig::default()` yields a config
/// with nothing set. `Debug` is derived to match the other config structs in
/// this module.
#[derive(Debug, Clone, Default)]
pub struct MembraneMPNNConfig {
    /// Global transmembrane label as an integer (meaning not defined in this file).
    pub global_transmembrane_label: Option<i32>,
    /// Buried transmembrane residues, kept as a raw string.
    pub transmembrane_buried: Option<String>,
    /// Interface transmembrane residues, kept as a raw string.
    pub transmembrane_interface: Option<String>,
}
205
/// Multi-PDB Related
///
/// Every field is optional, so `MultiPDBConfig::default()` yields a config
/// with nothing set. `Debug` is derived to match the other config structs in
/// this module. All values are kept as raw strings; this file does not parse them.
#[derive(Debug, Clone, Default)]
pub struct MultiPDBConfig {
    /// Multi-PDB input specification (presumably a path — confirm).
    pub pdb_path_multi: Option<String>,
    /// Fixed residues for the multi-PDB case.
    pub fixed_residues_multi: Option<String>,
    /// Redesigned residues for the multi-PDB case.
    pub redesigned_residues_multi: Option<String>,
    /// Per-residue amino acids to omit for the multi-PDB case.
    pub omit_aa_per_residue_multi: Option<String>,
    /// Per-residue amino-acid bias for the multi-PDB case.
    pub bias_aa_per_residue_multi: Option<String>,
}
214#[derive(Clone, Debug)]
215pub struct ProteinMPNNConfig {
216    pub atom_context_num: usize,
217    pub augment_eps: f32,
218    pub dropout_ratio: f32,
219    pub edge_features: i64,
220    pub hidden_dim: i64,
221    pub k_neighbors: i64,
222    pub ligand_mpnn_use_side_chain_context: bool,
223    pub model_type: ModelTypes,
224    pub node_features: i64,
225    pub num_decoder_layers: i64,
226    pub num_encoder_layers: i64,
227    pub num_letters: i64,
228    pub num_rbf: i64,
229    pub scale_factor: f64,
230    pub vocab: i64,
231}
232
233impl ProteinMPNNConfig {
234    pub fn proteinmpnn() -> Self {
235        Self {
236            atom_context_num: 0,
237            augment_eps: 0.0,
238            dropout_ratio: 0.1,
239            edge_features: 128,
240            hidden_dim: 128,
241            k_neighbors: 24,
242            ligand_mpnn_use_side_chain_context: false,
243            model_type: ModelTypes::ProteinMPNN,
244            node_features: 128,
245            num_decoder_layers: 3,
246            num_encoder_layers: 3,
247            num_letters: 21,
248            num_rbf: 16,
249            scale_factor: 1.0,
250            vocab: 21,
251        }
252    }
253    fn ligandmpnn() {
254        todo!()
255    }
256    fn membranempnn() {
257        todo!()
258    }
259}
260
/// Residue-level design controls.
///
/// Every field is optional, so `ResidueControl::default()` (all `None`) plus
/// struct-update syntax can be used to set only the options of interest.
/// All values are kept as raw strings; this file does not parse them.
#[derive(Debug, Clone, Default)]
pub struct ResidueControl {
    /// Residues to keep fixed.
    pub fixed_residues: Option<String>,
    /// Residues to redesign.
    pub redesigned_residues: Option<String>,
    /// Residues tied by symmetry.
    pub symmetry_residues: Option<String>,
    /// Weights for the symmetry residues.
    pub symmetry_weights: Option<String>,
    /// Chains to design.
    pub chains_to_design: Option<String>,
    /// Restrict parsing to these chains.
    pub parse_these_chains_only: Option<String>,
}
270
271#[derive(Debug)]
272pub struct RunConfig {
273    pub model_type: Option<ModelTypes>,
274    pub seed: Option<i32>,
275    pub temperature: Option<f32>,
276    pub verbose: Option<i32>,
277    pub save_stats: Option<bool>,
278    pub batch_size: Option<i32>,
279    pub number_of_batches: Option<i32>,
280    pub file_ending: Option<String>,
281    pub zero_indexed: Option<i32>,
282    pub homo_oligomer: Option<i32>,
283    pub fasta_seq_separation: Option<String>,
284}