ferritin_plms/ligandmpnn/
configs.rs1use super::model::ProteinMPNN;
17use super::proteinfeatures::ProteinFeatures;
18use crate::StructureFeatures;
19use anyhow::Error;
20use candle_core::pickle::PthTensors;
21use candle_core::{DType, Device, Tensor};
22use candle_nn::VarBuilder;
23use clap::ValueEnum;
24use ferritin_core::load_structure;
25use ferritin_test_data::TestFile;
26
27#[allow(dead_code)]
30pub struct MPNNExecConfig {
31 pub(crate) protein_inputs: String, pub(crate) run_config: RunConfig,
33 pub(crate) aabias_config: Option<AABiasConfig>,
34 pub(crate) ligand_mpnn_config: Option<LigandMPNNConfig>,
35 pub(crate) membrane_mpnn_config: Option<MembraneMPNNConfig>,
36 pub(crate) multi_pdb_config: Option<MultiPDBConfig>,
37 pub(crate) residue_control_config: Option<ResidueControl>,
38 pub(crate) device: Device,
39}
40
41impl MPNNExecConfig {
42 pub fn new(
43 device: Device,
44 pdb_path: String,
45 run_config: RunConfig,
46 residue_config: Option<ResidueControl>,
47 aa_bias: Option<AABiasConfig>,
48 lig_mpnn_specific: Option<LigandMPNNConfig>,
49 membrane_mpnn_specific: Option<MembraneMPNNConfig>,
50 multi_pdb_specific: Option<MultiPDBConfig>,
51 ) -> Result<Self, Error> {
52 Ok(MPNNExecConfig {
53 protein_inputs: pdb_path,
54 run_config,
55 aabias_config: aa_bias,
56 ligand_mpnn_config: lig_mpnn_specific,
57 membrane_mpnn_config: membrane_mpnn_specific,
58 residue_control_config: residue_config,
59 multi_pdb_config: multi_pdb_specific,
60 device,
61 })
62 }
63 pub fn load_model(&self, model_type: ModelTypes) -> Result<ProteinMPNN, Error> {
65 let default_dtype = DType::F32;
66 match model_type {
67 ModelTypes::ProteinMPNN => {
68 let (mpnn_file, _handle) = TestFile::ligmpnn_pmpnn_01().create_temp()?;
71 let pth = PthTensors::new(mpnn_file, Some("model_state_dict"))?;
72 let vb =
73 VarBuilder::from_backend(Box::new(pth), default_dtype, self.device.clone());
74 let pconf = ProteinMPNNConfig::proteinmpnn();
75 Ok(ProteinMPNN::load(vb, &pconf).expect("Unable to load the PMPNN Model"))
76 }
77 _ => panic!("not implented!"),
78 }
79 }
80 pub fn generate_model(self) {
81 todo!()
82 }
83 pub fn generate_protein_features(&self) -> Result<ProteinFeatures, Error> {
84 let device = self.device.clone();
85 let base_dtype = DType::F32;
86
87 let ac = load_structure(self.protein_inputs.clone())?;
89
90 let s = ac
91 .encode_amino_acids(&device)
92 .expect("A complete convertion to locations");
93 let x_37 = ac.to_numeric_atom37(&device)?;
94 let x_37_mask = Tensor::ones((x_37.dim(0)?, x_37.dim(1)?), base_dtype, &device)?;
95 let (y, y_t, y_m) = ac.to_numeric_ligand_atoms(&device)?;
96 let res_idx = ac.get_res_index();
97 let res_idx_len = res_idx.len();
98 let res_idx_tensor = Tensor::from_vec(res_idx, (1, res_idx_len), &device)?;
99
100 let chain_letters: Vec<String> = ac
102 .iter_residues_aminoacid()
103 .map(|res| res.chain_id().to_string())
104 .collect();
105
106 let chain_list: Vec<String> = chain_letters
108 .clone()
109 .into_iter()
110 .collect::<std::collections::HashSet<_>>()
111 .into_iter()
112 .collect();
113
114 Ok(ProteinFeatures {
160 s, x: x_37, x_mask: Some(x_37_mask), y, y_t, y_m: Some(y_m), r_idx: res_idx_tensor, chain_labels: None, chain_letters, mask_c: None, chain_list,
171 })
172 }
173}
174
175#[derive(Debug, Clone, ValueEnum, Copy)]
176pub enum ModelTypes {
177 #[value(name = "protein_mpnn")]
178 ProteinMPNN,
179 #[value(name = "ligand_mpnn")]
180 LigandMPNN,
181}
182
/// Amino-acid bias and omission options.
///
/// Every value is kept as an unparsed string; nothing in this struct validates them.
#[derive(Debug)]
pub struct AABiasConfig {
    /// Global amino-acid bias specification (format not defined here).
    pub bias_aa: Option<String>,
    /// Per-residue amino-acid bias specification (format not defined here).
    pub bias_aa_per_residue: Option<String>,
    /// Amino acids to omit globally (format not defined here).
    pub omit_aa: Option<String>,
    /// Per-residue amino-acid omissions (format not defined here).
    pub omit_aa_per_residue: Option<String>,
}
191
/// Options specific to the LigandMPNN variant.
///
/// Values are stored as given; nothing here parses or validates them.
// `Debug` is derived to match the sibling config structs (`AABiasConfig`, `RunConfig`, ...).
#[derive(Debug)]
pub struct LigandMPNNConfig {
    /// Checkpoint location for LigandMPNN (presumably a file path; not read in this file).
    pub checkpoint_ligand_mpnn: Option<String>,
    /// Atom-context toggle. NOTE(review): looks like a 0/1 flag; confirm.
    pub ligand_mpnn_use_atom_context: Option<i32>,
    /// Side-chain-context toggle. NOTE(review): looks like a 0/1 flag; confirm.
    pub ligand_mpnn_use_side_chain_context: Option<i32>,
    /// Scoring cutoff, kept as an unparsed string.
    pub ligand_mpnn_cutoff_for_score: Option<String>,
}
199
/// Options specific to the membrane MPNN variant.
///
/// Values are stored as given; nothing here parses or validates them.
// `Debug` is derived to match the sibling config structs (`AABiasConfig`, `RunConfig`, ...).
#[derive(Debug)]
pub struct MembraneMPNNConfig {
    /// Global transmembrane label. NOTE(review): integer encoding not defined here; confirm.
    pub global_transmembrane_label: Option<i32>,
    /// Residues marked as buried, kept as an unparsed string.
    pub transmembrane_buried: Option<String>,
    /// Residues marked as interface, kept as an unparsed string.
    pub transmembrane_interface: Option<String>,
}
206
/// Options for running over several input structures at once.
///
/// Values are stored as given; nothing here parses or validates them.
// `Debug` is derived to match the sibling config structs (`AABiasConfig`, `RunConfig`, ...).
#[derive(Debug)]
pub struct MultiPDBConfig {
    /// Multi-structure input spec (presumably a path to a list of structures; confirm).
    pub pdb_path_multi: Option<String>,
    /// Per-structure fixed residues, kept as an unparsed string.
    pub fixed_residues_multi: Option<String>,
    /// Per-structure redesigned residues, kept as an unparsed string.
    pub redesigned_residues_multi: Option<String>,
    /// Per-structure residue-level amino-acid omissions, kept as an unparsed string.
    pub omit_aa_per_residue_multi: Option<String>,
    /// Per-structure residue-level amino-acid biases, kept as an unparsed string.
    pub bias_aa_per_residue_multi: Option<String>,
}
215#[derive(Clone, Debug)]
216pub struct ProteinMPNNConfig {
217 pub atom_context_num: usize,
218 pub augment_eps: f32,
219 pub dropout_ratio: f32,
220 pub edge_features: i64,
221 pub hidden_dim: i64,
222 pub k_neighbors: i64,
223 pub ligand_mpnn_use_side_chain_context: bool,
224 pub model_type: ModelTypes,
225 pub node_features: i64,
226 pub num_decoder_layers: i64,
227 pub num_encoder_layers: i64,
228 pub num_letters: i64,
229 pub num_rbf: i64,
230 pub scale_factor: f64,
231 pub vocab: i64,
232}
233
234impl ProteinMPNNConfig {
235 pub fn proteinmpnn() -> Self {
236 Self {
237 atom_context_num: 0,
238 augment_eps: 0.0,
239 dropout_ratio: 0.1,
240 edge_features: 128,
241 hidden_dim: 128,
242 k_neighbors: 24,
243 ligand_mpnn_use_side_chain_context: false,
244 model_type: ModelTypes::ProteinMPNN,
245 node_features: 128,
246 num_decoder_layers: 3,
247 num_encoder_layers: 3,
248 num_letters: 21,
249 num_rbf: 16,
250 scale_factor: 1.0,
251 vocab: 21,
252 }
253 }
254 #[allow(dead_code)]
255 fn ligandmpnn() {
256 todo!()
257 }
258 #[allow(dead_code)]
259 fn membranempnn() {
260 todo!()
261 }
262}
263
/// Residue-selection and symmetry options.
///
/// Each value is held as an unparsed string; nothing in this struct validates them.
#[derive(Debug)]
pub struct ResidueControl {
    /// Residues to keep fixed during design (format not defined here).
    pub fixed_residues: Option<String>,
    /// Residues to redesign (format not defined here).
    pub redesigned_residues: Option<String>,
    /// Residue groups tied by symmetry (format not defined here).
    pub symmetry_residues: Option<String>,
    /// Weights for the symmetry groups (format not defined here).
    pub symmetry_weights: Option<String>,
    /// Chains to design (format not defined here).
    pub chains_to_design: Option<String>,
    /// Restrict parsing to these chains (format not defined here).
    pub parse_these_chains_only: Option<String>,
}
273
274#[derive(Debug)]
275pub struct RunConfig {
276 pub model_type: Option<ModelTypes>,
277 pub seed: Option<i32>,
278 pub temperature: Option<f32>,
279 pub verbose: Option<i32>,
280 pub save_stats: Option<bool>,
281 pub batch_size: Option<i32>,
282 pub number_of_batches: Option<i32>,
283 pub file_ending: Option<String>,
284 pub zero_indexed: Option<i32>,
285 pub homo_oligomer: Option<i32>,
286 pub fasta_seq_separation: Option<String>,
287}