ferritin_plms/ligandmpnn/
configs.rs1use super::model::ProteinMPNN;
17use super::proteinfeatures::ProteinFeatures;
18use crate::StructureFeatures;
19use anyhow::Error;
20use candle_core::pickle::PthTensors;
21use candle_core::{DType, Device, Tensor};
22use candle_nn::VarBuilder;
23use clap::ValueEnum;
24use ferritin_core::load_structure;
25use ferritin_test_data::TestFile;
26
27pub struct MPNNExecConfig {
30 pub(crate) protein_inputs: String, pub(crate) run_config: RunConfig,
32 pub(crate) aabias_config: Option<AABiasConfig>,
33 pub(crate) ligand_mpnn_config: Option<LigandMPNNConfig>,
34 pub(crate) membrane_mpnn_config: Option<MembraneMPNNConfig>,
35 pub(crate) multi_pdb_config: Option<MultiPDBConfig>,
36 pub(crate) residue_control_config: Option<ResidueControl>,
37 pub(crate) device: Device,
38}
39
40impl MPNNExecConfig {
41 pub fn new(
42 device: Device,
43 pdb_path: String,
44 run_config: RunConfig,
45 residue_config: Option<ResidueControl>,
46 aa_bias: Option<AABiasConfig>,
47 lig_mpnn_specific: Option<LigandMPNNConfig>,
48 membrane_mpnn_specific: Option<MembraneMPNNConfig>,
49 multi_pdb_specific: Option<MultiPDBConfig>,
50 ) -> Result<Self, Error> {
51 Ok(MPNNExecConfig {
52 protein_inputs: pdb_path,
53 run_config,
54 aabias_config: aa_bias,
55 ligand_mpnn_config: lig_mpnn_specific,
56 membrane_mpnn_config: membrane_mpnn_specific,
57 residue_control_config: residue_config,
58 multi_pdb_config: multi_pdb_specific,
59 device,
60 })
61 }
62 pub fn load_model(&self, model_type: ModelTypes) -> Result<ProteinMPNN, Error> {
64 let default_dtype = DType::F32;
65 match model_type {
66 ModelTypes::ProteinMPNN => {
67 let (mpnn_file, _handle) = TestFile::ligmpnn_pmpnn_01().create_temp()?;
70 let pth = PthTensors::new(mpnn_file, Some("model_state_dict"))?;
71 let vb =
72 VarBuilder::from_backend(Box::new(pth), default_dtype, self.device.clone());
73 let pconf = ProteinMPNNConfig::proteinmpnn();
74 Ok(ProteinMPNN::load(vb, &pconf).expect("Unable to load the PMPNN Model"))
75 }
76 _ => panic!("not implented!"),
77 }
78 }
79 pub fn generate_model(self) {
80 todo!()
81 }
82 pub fn generate_protein_features(&self) -> Result<ProteinFeatures, Error> {
83 let device = self.device.clone();
84 let base_dtype = DType::F32;
85
86 let ac = load_structure(self.protein_inputs.clone())?;
88
89 let s = ac
90 .encode_amino_acids(&device)
91 .expect("A complete convertion to locations");
92 let x_37 = ac.to_numeric_atom37(&device)?;
93 let x_37_mask = Tensor::ones((x_37.dim(0)?, x_37.dim(1)?), base_dtype, &device)?;
94 let (y, y_t, y_m) = ac.to_numeric_ligand_atoms(&device)?;
95 let res_idx = ac.get_res_index();
96 let res_idx_len = res_idx.len();
97 let res_idx_tensor = Tensor::from_vec(res_idx, (1, res_idx_len), &device)?;
98
99 let chain_letters: Vec<String> = ac
101 .iter_residues_aminoacid()
102 .map(|res| res.chain_id().to_string())
103 .collect();
104
105 let chain_list: Vec<String> = chain_letters
107 .clone()
108 .into_iter()
109 .collect::<std::collections::HashSet<_>>()
110 .into_iter()
111 .collect();
112
113 Ok(ProteinFeatures {
159 s, x: x_37, x_mask: Some(x_37_mask), y, y_t, y_m: Some(y_m), r_idx: res_idx_tensor, chain_labels: None, chain_letters, mask_c: None, chain_list,
170 })
171 }
172}
173
174#[derive(Debug, Clone, ValueEnum, Copy)]
175pub enum ModelTypes {
176 #[value(name = "protein_mpnn")]
177 ProteinMPNN,
178 #[value(name = "ligand_mpnn")]
179 LigandMPNN,
180}
181
/// Optional amino-acid bias and omission settings.
///
/// Every value is kept as a raw, unparsed string; its format is not defined in this type.
#[derive(Debug)]
pub struct AABiasConfig {
    /// Global per-amino-acid bias specification.
    pub bias_aa: Option<String>,
    /// Per-residue bias specification (presumably a path or encoded mapping; verify at use site).
    pub bias_aa_per_residue: Option<String>,
    /// Amino acids to omit globally.
    pub omit_aa: Option<String>,
    /// Per-residue omission specification.
    pub omit_aa_per_residue: Option<String>,
}
190
/// Optional LigandMPNN-specific settings.
///
/// Derives `Debug` like the sibling option structs (`AABiasConfig`, `ResidueControl`,
/// `RunConfig`), so the whole configuration can be logged.
#[derive(Debug)]
pub struct LigandMPNNConfig {
    /// Checkpoint to load for LigandMPNN (presumably a file path; verify at use site).
    pub checkpoint_ligand_mpnn: Option<String>,
    /// Presumably a 0/1 flag enabling ligand atom context. TODO confirm.
    pub ligand_mpnn_use_atom_context: Option<i32>,
    /// Presumably a 0/1 flag enabling side-chain context. TODO confirm.
    pub ligand_mpnn_use_side_chain_context: Option<i32>,
    /// Scoring cutoff, kept as an unparsed string.
    pub ligand_mpnn_cutoff_for_score: Option<String>,
}
198
/// Optional MembraneMPNN-specific settings.
///
/// Derives `Debug` for consistency with the other option structs in this module.
#[derive(Debug)]
pub struct MembraneMPNNConfig {
    /// Global transmembrane label (integer code; meaning not defined in this type).
    pub global_transmembrane_label: Option<i32>,
    /// Buried-residue specification, kept as an unparsed string.
    pub transmembrane_buried: Option<String>,
    /// Interface-residue specification, kept as an unparsed string.
    pub transmembrane_interface: Option<String>,
}
205
/// Optional settings for runs over multiple input structures.
///
/// Derives `Debug` for consistency with the other option structs in this module.
#[derive(Debug)]
pub struct MultiPDBConfig {
    /// Multi-structure input specification (presumably a path; verify at use site).
    pub pdb_path_multi: Option<String>,
    /// Fixed residues per structure, kept as an unparsed string.
    pub fixed_residues_multi: Option<String>,
    /// Residues to redesign per structure, kept as an unparsed string.
    pub redesigned_residues_multi: Option<String>,
    /// Per-residue omissions per structure, kept as an unparsed string.
    pub omit_aa_per_residue_multi: Option<String>,
    /// Per-residue biases per structure, kept as an unparsed string.
    pub bias_aa_per_residue_multi: Option<String>,
}
214#[derive(Clone, Debug)]
215pub struct ProteinMPNNConfig {
216 pub atom_context_num: usize,
217 pub augment_eps: f32,
218 pub dropout_ratio: f32,
219 pub edge_features: i64,
220 pub hidden_dim: i64,
221 pub k_neighbors: i64,
222 pub ligand_mpnn_use_side_chain_context: bool,
223 pub model_type: ModelTypes,
224 pub node_features: i64,
225 pub num_decoder_layers: i64,
226 pub num_encoder_layers: i64,
227 pub num_letters: i64,
228 pub num_rbf: i64,
229 pub scale_factor: f64,
230 pub vocab: i64,
231}
232
233impl ProteinMPNNConfig {
234 pub fn proteinmpnn() -> Self {
235 Self {
236 atom_context_num: 0,
237 augment_eps: 0.0,
238 dropout_ratio: 0.1,
239 edge_features: 128,
240 hidden_dim: 128,
241 k_neighbors: 24,
242 ligand_mpnn_use_side_chain_context: false,
243 model_type: ModelTypes::ProteinMPNN,
244 node_features: 128,
245 num_decoder_layers: 3,
246 num_encoder_layers: 3,
247 num_letters: 21,
248 num_rbf: 16,
249 scale_factor: 1.0,
250 vocab: 21,
251 }
252 }
253 fn ligandmpnn() {
254 todo!()
255 }
256 fn membranempnn() {
257 todo!()
258 }
259}
260
/// Per-residue and per-chain design controls.
///
/// All values are kept as raw, unparsed strings.
#[derive(Debug)]
pub struct ResidueControl {
    /// Residues to keep fixed.
    pub fixed_residues: Option<String>,
    /// Residues to redesign.
    pub redesigned_residues: Option<String>,
    /// Residue groups tied by symmetry.
    pub symmetry_residues: Option<String>,
    /// Weights for the symmetry groups.
    pub symmetry_weights: Option<String>,
    /// Chains to design.
    pub chains_to_design: Option<String>,
    /// Restrict parsing to these chains.
    pub parse_these_chains_only: Option<String>,
}
270
271#[derive(Debug)]
272pub struct RunConfig {
273 pub model_type: Option<ModelTypes>,
274 pub seed: Option<i32>,
275 pub temperature: Option<f32>,
276 pub verbose: Option<i32>,
277 pub save_stats: Option<bool>,
278 pub batch_size: Option<i32>,
279 pub number_of_batches: Option<i32>,
280 pub file_ending: Option<String>,
281 pub zero_indexed: Option<i32>,
282 pub homo_oligomer: Option<i32>,
283 pub fasta_seq_separation: Option<String>,
284}