Skip to main content

ferritin_plms/ligandmpnn/
configs.rs

1//! PMPNN Core Config and Builder API
2//!
3//! This module provides configuration structs and builders for the PMPNN protein design system.
4//!
5//! # Core Configuration Types
6//!
7//! - `ModelTypes` - Enum of supported model architectures
8//! - `ProteinMPNNConfig` - Core model parameters
9//! - `AABiasConfig` - Amino acid biasing controls
10//! - `LigandMPNNConfig` - LigandMPNN specific settings
11//! - `MembraneMPNNConfig` - MembraneMPNN specific settings
12//! - `MultiPDBConfig` - Multi-PDB mode configuration
13//! - `ResidueControl` - Residue-level design controls
//! - `RunConfig` - Runtime execution parameters
//!
//! Core configs for handling CLI args and model params.
15
16use super::model::ProteinMPNN;
17use super::proteinfeatures::ProteinFeatures;
18use crate::StructureFeatures;
19use anyhow::Error;
20use candle_core::pickle::PthTensors;
21use candle_core::{DType, Device, Tensor};
22use candle_nn::VarBuilder;
23use clap::ValueEnum;
24use ferritin_core::load_structure;
25use ferritin_test_data::TestFile;
26
27/// Responsible for taking CLI args and returning the Features and Model
28///
29#[allow(dead_code)]
30pub struct MPNNExecConfig {
31    pub(crate) protein_inputs: String, // Todo: make this optionally plural
32    pub(crate) run_config: RunConfig,
33    pub(crate) aabias_config: Option<AABiasConfig>,
34    pub(crate) ligand_mpnn_config: Option<LigandMPNNConfig>,
35    pub(crate) membrane_mpnn_config: Option<MembraneMPNNConfig>,
36    pub(crate) multi_pdb_config: Option<MultiPDBConfig>,
37    pub(crate) residue_control_config: Option<ResidueControl>,
38    pub(crate) device: Device,
39}
40
41impl MPNNExecConfig {
42    pub fn new(
43        device: Device,
44        pdb_path: String,
45        run_config: RunConfig,
46        residue_config: Option<ResidueControl>,
47        aa_bias: Option<AABiasConfig>,
48        lig_mpnn_specific: Option<LigandMPNNConfig>,
49        membrane_mpnn_specific: Option<MembraneMPNNConfig>,
50        multi_pdb_specific: Option<MultiPDBConfig>,
51    ) -> Result<Self, Error> {
52        Ok(MPNNExecConfig {
53            protein_inputs: pdb_path,
54            run_config,
55            aabias_config: aa_bias,
56            ligand_mpnn_config: lig_mpnn_specific,
57            membrane_mpnn_config: membrane_mpnn_specific,
58            residue_control_config: residue_config,
59            multi_pdb_config: multi_pdb_specific,
60            device,
61        })
62    }
63    // Todo: refactor this to use loader.
64    pub fn load_model(&self, model_type: ModelTypes) -> Result<ProteinMPNN, Error> {
65        let default_dtype = DType::F32;
66        match model_type {
67            ModelTypes::ProteinMPNN => {
68                // this is a hidden dep....
69                // todo: use hf_hub
70                let (mpnn_file, _handle) = TestFile::ligmpnn_pmpnn_01().create_temp()?;
71                let pth = PthTensors::new(mpnn_file, Some("model_state_dict"))?;
72                let vb =
73                    VarBuilder::from_backend(Box::new(pth), default_dtype, self.device.clone());
74                let pconf = ProteinMPNNConfig::proteinmpnn();
75                Ok(ProteinMPNN::load(vb, &pconf).expect("Unable to load the PMPNN Model"))
76            }
77            _ => panic!("not implented!"),
78        }
79    }
    /// Placeholder for model generation; takes `self` by value.
    ///
    /// # Panics
    /// Always panics: the body is still `todo!()`.
    pub fn generate_model(self) {
        todo!()
    }
83    pub fn generate_protein_features(&self) -> Result<ProteinFeatures, Error> {
84        let device = self.device.clone();
85        let base_dtype = DType::F32;
86
87        // init the Protein Features
88        let ac = load_structure(self.protein_inputs.clone())?;
89
90        let s = ac
91            .encode_amino_acids(&device)
92            .expect("A complete convertion to locations");
93        let x_37 = ac.to_numeric_atom37(&device)?;
94        let x_37_mask = Tensor::ones((x_37.dim(0)?, x_37.dim(1)?), base_dtype, &device)?;
95        let (y, y_t, y_m) = ac.to_numeric_ligand_atoms(&device)?;
96        let res_idx = ac.get_res_index();
97        let res_idx_len = res_idx.len();
98        let res_idx_tensor = Tensor::from_vec(res_idx, (1, res_idx_len), &device)?;
99
100        // chain residues
101        let chain_letters: Vec<String> = ac
102            .iter_residues_aminoacid()
103            .map(|res| res.chain_id().to_string())
104            .collect();
105
106        // unique Chains
107        let chain_list: Vec<String> = chain_letters
108            .clone()
109            .into_iter()
110            .collect::<std::collections::HashSet<_>>()
111            .into_iter()
112            .collect();
113
114        // assert_eq!(true, false);
115
116        // update residue info
117        // residue_config: Option<ResidueControl>,
118        // handle these:
119        // pub fixed_residues: Option<String>,
120        // pub redesigned_residues: Option<String>,
121        // pub symmetry_residues: Option<String>,
122        // pub symmetry_weights: Option<String>,
123        // pub chains_to_design: Option<String>,
124        // pub parse_these_chains_only: Option<String>,
125
126        // update AA bias
127        // handle these:
128        // aa_bias: Option<AABiasConfig>,
129        // pub bias_aa: Option<String>,
130        // pub bias_aa_per_residue: Option<String>,
131        // pub omit_aa: Option<String>,
132        // pub omit_aa_per_residue: Option<String>,
133
134        // update LigmpnnConfif
135        // lig_mpnn_specific: Option<LigandMPNNConfig>,
136        // handle these:
137        // pub checkpoint_ligand_mpnn: Option<String>,
138        // pub ligand_mpnn_use_atom_context: Option<i32>,
139        // pub ligand_mpnn_use_side_chain_context: Option<i32>,
140        // pub ligand_mpnn_cutoff_for_score: Option<String>,
141
142        // update Membrane MPNN Config
143        // membrane_mpnn_specific: Option<MembraneMPNNConfig>,
144        // handle these:
145        // pub global_transmembrane_label: Option<i32>,
146        // pub transmembrane_buried: Option<String>,
147        // pub transmembrane_interface: Option<String>,
148
149        // update multipdb
150        // multi_pdb_specific: Option<MultiPDBConfig>,
151        // pub pdb_path_multi: Option<String>,
152        // pub fixed_residues_multi: Option<String>,
153        // pub redesigned_residues_multi: Option<String>,
154        // pub omit_aa_per_residue_multi: Option<String>,
155        // pub bias_aa_per_residue_multi: Option<String>,
156
157        // println!("Returning Protein Features....");
158        // return ligand MPNN.
159        Ok(ProteinFeatures {
160            s,                       // protein amino acids sequences as 1D Tensor of u32
161            x: x_37,                 // protein co-oords by residue [1, 37, 4]
162            x_mask: Some(x_37_mask), // protein mask by residue
163            y,                       // ligand coords
164            y_t,                     // encoded ligand atom names
165            y_m: Some(y_m),          // ligand mask
166            r_idx: res_idx_tensor,   // protein residue indices shape=[length]
167            chain_labels: None,      //  # protein chain letters shape=[length]
168            chain_letters,           // chain_letters: shape=[length]
169            mask_c: None,            // mask_c:  shape=[length]
170            chain_list,
171        })
172    }
173}
174
175#[derive(Debug, Clone, ValueEnum, Copy)]
176pub enum ModelTypes {
177    #[value(name = "protein_mpnn")]
178    ProteinMPNN,
179    #[value(name = "ligand_mpnn")]
180    LigandMPNN,
181}
182
/// Amino Acid Biasing
///
/// All options are optional raw strings; nothing in this file parses them.
#[derive(Debug)]
pub struct AABiasConfig {
    /// Amino-acid bias specification (raw, unparsed).
    pub bias_aa: Option<String>,
    /// Per-residue amino-acid bias specification (raw, unparsed).
    pub bias_aa_per_residue: Option<String>,
    /// Amino acids to omit (raw, unparsed).
    pub omit_aa: Option<String>,
    /// Per-residue amino acids to omit (raw, unparsed).
    pub omit_aa_per_residue: Option<String>,
}
191
/// LigandMPNN Specific
///
/// All options are optional and stored as given; nothing in this file
/// interprets them yet.
// `Debug` is derived for consistency with the other config structs
// (`AABiasConfig`, `ResidueControl`, `RunConfig`).
#[derive(Debug)]
pub struct LigandMPNNConfig {
    /// LigandMPNN checkpoint (raw string; presumably a path — TODO confirm).
    pub checkpoint_ligand_mpnn: Option<String>,
    /// Atom-context switch, kept as an integer.
    pub ligand_mpnn_use_atom_context: Option<i32>,
    /// Side-chain-context switch, kept as an integer.
    pub ligand_mpnn_use_side_chain_context: Option<i32>,
    /// Cutoff used for scoring (raw, unparsed string).
    pub ligand_mpnn_cutoff_for_score: Option<String>,
}
199
/// Membrane MPNN Specific
///
/// All options are optional and stored as given; nothing in this file
/// interprets them yet.
// `Debug` is derived for consistency with the other config structs
// (`AABiasConfig`, `ResidueControl`, `RunConfig`).
#[derive(Debug)]
pub struct MembraneMPNNConfig {
    /// Global transmembrane label, kept as an integer.
    pub global_transmembrane_label: Option<i32>,
    /// Buried transmembrane specification (raw, unparsed string).
    pub transmembrane_buried: Option<String>,
    /// Interface transmembrane specification (raw, unparsed string).
    pub transmembrane_interface: Option<String>,
}
206
/// Multi-PDB Related
///
/// All options are optional raw strings; nothing in this file parses them.
// `Debug` is derived for consistency with the other config structs
// (`AABiasConfig`, `ResidueControl`, `RunConfig`).
#[derive(Debug)]
pub struct MultiPDBConfig {
    /// Multi-PDB input specification (raw string; presumably a path — TODO confirm).
    pub pdb_path_multi: Option<String>,
    /// Fixed residues for multi-PDB mode (raw, unparsed).
    pub fixed_residues_multi: Option<String>,
    /// Redesigned residues for multi-PDB mode (raw, unparsed).
    pub redesigned_residues_multi: Option<String>,
    /// Per-residue omitted amino acids for multi-PDB mode (raw, unparsed).
    pub omit_aa_per_residue_multi: Option<String>,
    /// Per-residue amino-acid bias for multi-PDB mode (raw, unparsed).
    pub bias_aa_per_residue_multi: Option<String>,
}
215#[derive(Clone, Debug)]
216pub struct ProteinMPNNConfig {
217    pub atom_context_num: usize,
218    pub augment_eps: f32,
219    pub dropout_ratio: f32,
220    pub edge_features: i64,
221    pub hidden_dim: i64,
222    pub k_neighbors: i64,
223    pub ligand_mpnn_use_side_chain_context: bool,
224    pub model_type: ModelTypes,
225    pub node_features: i64,
226    pub num_decoder_layers: i64,
227    pub num_encoder_layers: i64,
228    pub num_letters: i64,
229    pub num_rbf: i64,
230    pub scale_factor: f64,
231    pub vocab: i64,
232}
233
234impl ProteinMPNNConfig {
235    pub fn proteinmpnn() -> Self {
236        Self {
237            atom_context_num: 0,
238            augment_eps: 0.0,
239            dropout_ratio: 0.1,
240            edge_features: 128,
241            hidden_dim: 128,
242            k_neighbors: 24,
243            ligand_mpnn_use_side_chain_context: false,
244            model_type: ModelTypes::ProteinMPNN,
245            node_features: 128,
246            num_decoder_layers: 3,
247            num_encoder_layers: 3,
248            num_letters: 21,
249            num_rbf: 16,
250            scale_factor: 1.0,
251            vocab: 21,
252        }
253    }
254    #[allow(dead_code)]
255    fn ligandmpnn() {
256        todo!()
257    }
258    #[allow(dead_code)]
259    fn membranempnn() {
260        todo!()
261    }
262}
263
/// Residue-level design controls.
///
/// All options are optional raw strings; nothing in this file parses them.
#[derive(Debug)]
pub struct ResidueControl {
    /// Residues to keep fixed (raw, unparsed).
    pub fixed_residues: Option<String>,
    /// Residues to redesign (raw, unparsed).
    pub redesigned_residues: Option<String>,
    /// Symmetry residue specification (raw, unparsed).
    pub symmetry_residues: Option<String>,
    /// Symmetry weight specification (raw, unparsed).
    pub symmetry_weights: Option<String>,
    /// Chains to design (raw, unparsed).
    pub chains_to_design: Option<String>,
    /// Chains to restrict parsing to (raw, unparsed).
    pub parse_these_chains_only: Option<String>,
}
273
274#[derive(Debug)]
275pub struct RunConfig {
276    pub model_type: Option<ModelTypes>,
277    pub seed: Option<i32>,
278    pub temperature: Option<f32>,
279    pub verbose: Option<i32>,
280    pub save_stats: Option<bool>,
281    pub batch_size: Option<i32>,
282    pub number_of_batches: Option<i32>,
283    pub file_ending: Option<String>,
284    pub zero_indexed: Option<i32>,
285    pub homo_oligomer: Option<i32>,
286    pub fasta_seq_separation: Option<String>,
287}