ferritin_plms/ligandmpnn/model.rs

//! A message-passing protein design neural network
//! that samples sequences by autoregressively decoding conditional probabilities.
//!
//! - See the [LigandMPNN Repo](https://github.com/dauparas/LigandMPNN)
//!
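//! A minimal usage sketch (hypothetical file names; assumes pretrained weights in
//! safetensors format and an already-featurized `ProteinFeatures` input):
//!
//! ```ignore
//! use candle_core::{DType, Device};
//! use candle_nn::VarBuilder;
//!
//! let device = Device::Cpu;
//! let vb = unsafe {
//!     VarBuilder::from_mmaped_safetensors(&["proteinmpnn_v_48_020.safetensors"], DType::F32, &device)?
//! };
//! let config = ProteinMPNNConfig::default(); // assumed constructor
//! let model = ProteinMPNN::load(vb, &config)?;
//! let out = model.sample(&features, 0.1, 42)?; // temperature, seed
//! println!("{:?}", out.get_sequences()?);
//! ```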
use super::configs::{ModelTypes, ProteinMPNNConfig};
use super::proteinfeatures::ProteinFeatures;
use super::proteinfeaturesmodel::ProteinFeaturesModel;
use crate::featurize::utilities::{cat_neighbors_nodes, gather_nodes, int_to_aa1};
use crate::types::PseudoProbability;
use candle_core::safetensors;
use candle_core::{D, DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::encoding::one_hot;
use candle_nn::ops::{log_softmax, softmax};
use candle_nn::{Dropout, Embedding, LayerNorm, Linear, VarBuilder, embedding, layer_norm, linear};
use candle_transformers::generation::LogitsProcessor;
use std::collections::HashMap;

/// Shared helper: gather neighbor edge features and concatenate them with the
/// broadcast node features along the last dimension.
fn concat_node_tensors(h_v: &Tensor, h_e: &Tensor, e_idx: &Tensor) -> Result<Tensor> {
    let h_ev = cat_neighbors_nodes(h_v, h_e, e_idx)?;
    let h_v_expand = h_v.unsqueeze(D::Minus2)?;
    let expand_shape = [
        h_ev.dims()[0],
        h_ev.dims()[1],
        h_ev.dims()[2],
        h_v_expand.dims()[3],
    ];
    let h_v_expand = h_v_expand.expand(&expand_shape)?.to_dtype(h_ev.dtype())?;
    Tensor::cat(&[&h_v_expand, &h_ev], D::Minus1)?.contiguous()
}
/// Shared helper: residual update `norm(input + dropout(delta))`.
fn apply_dropout_and_norm(
    input: &Tensor,
    delta: &Tensor,
    dropout: &Dropout,
    norm: &LayerNorm,
    training: bool,
) -> Result<Tensor> {
    let delta_dropout = dropout.forward(delta, training)?;
    norm.forward(&(input + delta_dropout)?)
}

pub fn multinomial_sample(probs: &Tensor, temperature: f64, seed: u64) -> Result<Tensor> {
    let mut logits_processor = LogitsProcessor::new(
        seed,              // seed for reproducibility
        Some(temperature), // temperature scaling
        Some(0.95),        // top_p (nucleus sampling)
    );
    let idx = logits_processor.sample(probs)?;
    if idx >= 21 {
        println!("WARNING: Invalid index {} selected", idx);
    }
    Tensor::new(&[idx], probs.device())
}
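// Example (sketch): drawing one residue index from a 1-D probability vector.
// `probs` is assumed to already be normalized over the 20 canonical amino
// acids, as produced inside `ProteinMPNN::sample`:
//
//     let probs = Tensor::new(&[0.05f32; 20], &Device::Cpu)?;
//     let idx = multinomial_sample(&probs, 0.1, 42)?; // shape [1], dtype U32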

// Primary return object from the ProteinMPNN model
#[derive(Clone, Debug)]
pub struct ScoreOutput {
    // Sequence
    pub(crate) s: Tensor,
    pub(crate) log_probs: Tensor,
    pub(crate) logits: Tensor,
    pub(crate) decoding_order: Tensor,
}
/// Sequence dims are [batch, seq_length]
impl ScoreOutput {
    pub fn get_sequences(&self) -> Result<Vec<String>> {
        let (b, l) = self.s.dims2()?;
        let mut sequences = Vec::with_capacity(b);
        for batch_idx in 0..b {
            let batch = self.s.get(batch_idx)?;
            let mut sequence = String::with_capacity(l);
            for pos in 0..l {
                let aa_idx = batch.get(pos)?.to_vec0::<u32>()?;
                sequence.push(int_to_aa1(aa_idx));
            }
            sequences.push(sequence);
        }
        Ok(sequences)
    }
    pub fn get_decoding_order(&self) -> Result<Vec<u32>> {
        let values = self.decoding_order.flatten_all()?.to_vec1::<u32>()?;
        Ok(values)
    }
    pub fn get_log_probs(&self) -> &Tensor {
        &self.log_probs
    }
    pub fn get_pseudo_probabilities(&self) -> Result<Vec<PseudoProbability>> {
        let (batch_size, seq_len, _vocab_size) = self.logits.dims3()?;
        let mut all_probabilities = Vec::with_capacity(batch_size * seq_len);

        for batch_idx in 0..batch_size {
            let batch_logits = self.logits.get(batch_idx)?;

            for pos in 0..seq_len {
                let pos_logits = batch_logits.get(pos)?;
                let probs = softmax(&pos_logits, 0)?;
                let probs = probs.to_vec1::<f32>()?;

                // Get the decoding order to determine the actual position
                let actual_pos = if let Ok(order) = self
                    .decoding_order
                    .get(batch_idx)?
                    .get(pos)?
                    .to_scalar::<u32>()
                {
                    order as usize
                } else {
                    pos
                };

                for aa_idx in 0..probs.len() {
                    if probs[aa_idx] > 0.01 {
                        // Only include probabilities above threshold
                        all_probabilities.push(PseudoProbability {
                            position: actual_pos,
                            pseudo_prob: probs[aa_idx],
                            amino_acid: int_to_aa1(aa_idx as u32),
                        });
                    }
                }
            }
        }
        // Sort by position and then by probability (descending)
        all_probabilities.sort_by(|a, b| {
            a.position.cmp(&b.position).then(
                b.pseudo_prob
                    .partial_cmp(&a.pseudo_prob)
                    .unwrap_or(std::cmp::Ordering::Equal),
            )
        });
        Ok(all_probabilities)
    }
    pub fn save_as_safetensors(&self, filename: String) -> Result<()> {
        let mut tensors = HashMap::new();
        tensors.insert("S".to_string(), self.s.clone());
        tensors.insert("log_probs".to_string(), self.log_probs.clone());
        tensors.insert("logits".to_string(), self.logits.clone());
        tensors.insert("decoding_order".to_string(), self.decoding_order.clone());
        // Create directory if it doesn't exist
        if let Some(parent) = std::path::Path::new(&filename).parent() {
            std::fs::create_dir_all(parent)?;
        }
        safetensors::save(&tensors, &filename)?;
        Ok(())
    }
}
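// Example (sketch): consuming a `ScoreOutput` returned by `sample` or `score`.
// `out` is assumed to have been produced by one of those calls:
//
//     let seqs = out.get_sequences()?;       // one String per batch entry
//     let order = out.get_decoding_order()?; // flattened decoding order
//     out.save_as_safetensors("outputs/design.safetensors".to_string())?;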

#[derive(Clone, Debug)]
struct PositionWiseFeedForward {
    w1: Linear,
    w2: Linear,
}

impl PositionWiseFeedForward {
    fn new(vb: VarBuilder, dim_input: usize, dim_feedforward: usize) -> Result<Self> {
        let w1 = linear::linear(dim_input, dim_feedforward, vb.pp("W_in"))?;
        let w2 = linear::linear(dim_feedforward, dim_input, vb.pp("W_out"))?;
        Ok(Self { w1, w2 })
    }
}

impl Module for PositionWiseFeedForward {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.w1.forward(x)?.gelu().and_then(|x| self.w2.forward(&x))
    }
}

#[derive(Clone, Debug)]
pub struct EncLayer {
    num_hidden: usize,
    num_in: usize,
    scale: f64,
    dropout1: Dropout,
    dropout2: Dropout,
    dropout3: Dropout,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    w1: Linear,
    w2: Linear,
    w3: Linear,
    w11: Linear,
    w12: Linear,
    w13: Linear,
    dense: PositionWiseFeedForward,
}

impl EncLayer {
    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig, layer: i32) -> Result<Self> {
        let vb = vb.pp(layer); // handle the layer number here.
        let num_hidden = config.hidden_dim as usize;
        let augment_eps = config.augment_eps as f64;
        let num_in = (config.hidden_dim * 2) as usize;
        let dropout_ratio = config.dropout_ratio;

        // Create layer norms
        let norm1 = layer_norm(num_hidden, augment_eps, vb.pp("norm1"))?;
        let norm2 = layer_norm(num_hidden, augment_eps, vb.pp("norm2"))?;
        let norm3 = layer_norm(num_hidden, augment_eps, vb.pp("norm3"))?;

        // Create linear layers
        let w1 = linear(num_hidden + num_in, num_hidden, vb.pp("W1"))?;
        let w2 = linear(num_hidden, num_hidden, vb.pp("W2"))?;
        let w3 = linear(num_hidden, num_hidden, vb.pp("W3"))?;
        let w11 = linear(num_hidden + num_in, num_hidden, vb.pp("W11"))?;
        let w12 = linear(num_hidden, num_hidden, vb.pp("W12"))?;
        let w13 = linear(num_hidden, num_hidden, vb.pp("W13"))?;

        // Create dropouts with same ratio
        let dropout1 = Dropout::new(dropout_ratio);
        let dropout2 = Dropout::new(dropout_ratio);
        let dropout3 = Dropout::new(dropout_ratio);

        let dense = PositionWiseFeedForward::new(vb.pp("dense"), num_hidden, num_hidden * 4)?;

        Ok(Self {
            num_hidden,
            num_in,
            scale: config.scale_factor,
            dropout1,
            dropout2,
            dropout3,
            norm1,
            norm2,
            norm3,
            w1,
            w2,
            w3,
            w11,
            w12,
            w13,
            dense,
        })
    }
    fn forward(
        &self,
        h_v: &Tensor,
        h_e: &Tensor,
        e_idx: &Tensor,
        mask_v: Option<&Tensor>,
        mask_attend: Option<&Tensor>,
        training: Option<bool>,
    ) -> Result<(Tensor, Tensor)> {
        let training = training.unwrap_or(false);
        let h_v = h_v.to_dtype(DType::F32)?;
        let h_ev = concat_node_tensors(&h_v, h_e, e_idx)?;
        let h_message = self
            .w1
            .forward(&h_ev)?
            .gelu()?
            .apply(&self.w2)?
            .gelu()?
            .apply(&self.w3)?;

        let h_message = mask_attend
            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message))
            .transpose()?
            .unwrap_or(h_message);

        // Safe division with scale
        let scale = if self.scale == 0.0 { 1.0 } else { self.scale };
        let dh = (h_message.sum(D::Minus2)? / scale)?;

        let h_v = apply_dropout_and_norm(&h_v, &dh, &self.dropout1, &self.norm1, training)?;
        let dense_output = self.dense.forward(&h_v)?;
        let h_v =
            apply_dropout_and_norm(&h_v, &dense_output, &self.dropout2, &self.norm2, training)?;

        // Apply mask if provided
        let h_v = mask_v
            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_v))
            .transpose()?
            .unwrap_or(h_v);

        let h_ev = concat_node_tensors(&h_v, h_e, e_idx)?;
        let h_message = self
            .w11
            .forward(&h_ev)?
            .gelu()?
            .apply(&self.w12)?
            .gelu()?
            .apply(&self.w13)?;

        let h_e = apply_dropout_and_norm(h_e, &h_message, &self.dropout3, &self.norm3, training)?;
        Ok((h_v, h_e))
    }
}
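// A note on the encoder update (descriptive only): each `EncLayer` pass gathers, for
// every residue i, its K nearest-neighbor edge features h_E(i,j) together with the
// neighbor node features h_V(j), pushes the concatenation through the W1-W2-W3 MLP,
// sums the resulting messages over neighbors (divided by `scale`), and applies a
// dropout + LayerNorm residual update to h_V; the W11-W12-W13 branch performs the
// analogous residual update on the edge features h_E.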

#[derive(Clone, Debug)]
pub struct DecLayer {
    num_hidden: usize,
    num_in: usize,
    scale: f64,
    dropout1: Dropout,
    dropout2: Dropout,
    norm1: LayerNorm,
    norm2: LayerNorm,
    w1: Linear,
    w2: Linear,
    w3: Linear,
    dense: PositionWiseFeedForward,
}

impl DecLayer {
    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig, layer: i32) -> Result<Self> {
        let vb = vb.pp(layer); // handle the layer number here.
        let num_hidden = config.hidden_dim as usize;
        let augment_eps = config.augment_eps as f64;
        let num_in = (config.hidden_dim * 3) as usize;
        let dropout_ratio = config.dropout_ratio;

        let norm1 = layer_norm::layer_norm(num_hidden, augment_eps, vb.pp("norm1"))?;
        let norm2 = layer_norm::layer_norm(num_hidden, augment_eps, vb.pp("norm2"))?;

        let w1 = linear::linear(num_hidden + num_in, num_hidden, vb.pp("W1"))?;
        let w2 = linear::linear(num_hidden, num_hidden, vb.pp("W2"))?;
        let w3 = linear::linear(num_hidden, num_hidden, vb.pp("W3"))?;
        let dropout1 = Dropout::new(dropout_ratio);
        let dropout2 = Dropout::new(dropout_ratio);

        let dense = PositionWiseFeedForward::new(vb.pp("dense"), num_hidden, num_hidden * 4)?;

        Ok(Self {
            num_hidden,
            num_in,
            scale: config.scale_factor,
            dropout1,
            dropout2,
            norm1,
            norm2,
            w1,
            w2,
            w3,
            dense,
        })
    }
    pub fn forward(
        &self,
        h_v: &Tensor,
        h_e: &Tensor,
        mask_v: Option<&Tensor>,
        mask_attend: Option<&Tensor>,
        training: Option<bool>,
    ) -> Result<Tensor> {
        let training_bool = training.unwrap_or(false);

        // Expand node features to match edge dimensions
        let expand_shape = [
            h_e.dims()[0], // batch
            h_e.dims()[1], // sequence length
            h_e.dims()[2], // number of neighbors
            h_v.dims()[2], // keep original hidden dim
        ];

        let h_v_expand = h_v.unsqueeze(D::Minus2)?.expand(&expand_shape)?;
        let h_ev = Tensor::cat(&[&h_v_expand, h_e], D::Minus1)?.contiguous()?;

        let h_message = self
            .w1
            .forward(&h_ev)?
            .gelu()?
            .apply(&self.w2)?
            .gelu()?
            .apply(&self.w3)?;

        let h_message = self.dropout1.forward(&h_message, training_bool)?;

        let h_message = mask_attend
            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message))
            .transpose()?
            .unwrap_or(h_message);

        let dh = (h_message.sum(D::Minus2)? / self.scale)?;
        let h_v = self.norm1.forward(&(h_v + dh)?)?;
        let dh = self.dense.forward(&h_v)?;
        let dh_dropout = self.dropout2.forward(&dh, training_bool)?;
        let h_v = self.norm2.forward(&(h_v + dh_dropout)?)?;

        // Apply optional node mask
        let h_v = mask_v
            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_v))
            .transpose()?
            .unwrap_or(h_v);

        Ok(h_v)
    }
}
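// A note on the decoder input width (descriptive only): the `h_e` passed to
// `DecLayer::forward` is not a raw edge feature but a bundle built with
// `cat_neighbors_nodes`, packing the encoder edge h_E(i,j), the neighbor sequence
// embedding h_S(j), and the neighbor node state h_V(j); hence `num_in` is
// `3 * hidden_dim` here versus `2 * hidden_dim` in `EncLayer`.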

/// ProteinMPNN Model
/// - [link](https://github.com/dauparas/LigandMPNN/blob/main/model_utils.py#L10C7-L10C18)
pub struct ProteinMPNN {
    pub(crate) config: ProteinMPNNConfig,
    pub(crate) decoder_layers: Vec<DecLayer>,
    pub(crate) device: Device,
    pub(crate) encoder_layers: Vec<EncLayer>,
    pub(crate) features: ProteinFeaturesModel,
    pub(crate) w_e: Linear,
    pub(crate) w_out: Linear,
    pub(crate) w_s: Embedding,
}

impl ProteinMPNN {
    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig) -> Result<Self> {
        let hidden_dim = config.hidden_dim as usize;
        let edge_features = config.edge_features as usize;
        let num_letters = config.num_letters as usize;
        let vocab_size = config.vocab as usize;

        // Create encoder and decoder layers using iterators
        let encoder_layers = (0..config.num_encoder_layers)
            .map(|i| EncLayer::load(vb.pp("encoder_layers"), config, i as i32))
            .collect::<Result<Vec<_>>>()?;

        let decoder_layers = (0..config.num_decoder_layers)
            .map(|i| DecLayer::load(vb.pp("decoder_layers"), config, i as i32))
            .collect::<Result<Vec<_>>>()?;

        // Initialize weights
        let w_e = linear::linear(edge_features, hidden_dim, vb.pp("W_e"))?;
        let w_out = linear::linear(hidden_dim, num_letters, vb.pp("W_out"))?;
        let w_s = embedding(vocab_size, hidden_dim, vb.pp("W_s"))?;
        // Features
        let features = ProteinFeaturesModel::load(vb.pp("features"), config.clone())?;

        Ok(Self {
            config: config.clone(),
            decoder_layers,
            device: vb.device().clone(),
            encoder_layers,
            features,
            w_e,
            w_out,
            w_s,
        })
    }
    pub fn encode(&self, features: &ProteinFeatures) -> Result<(Tensor, Tensor, Tensor)> {
        let s_true = features.get_sequence();
        let base_dtype = DType::F32;
        let mask = match features.get_sequence_mask() {
            Some(m) => m,
            None => &Tensor::ones_like(s_true)?,
        };
        match self.config.model_type {
            ModelTypes::ProteinMPNN => {
                let (e, e_idx) = self.features.forward(features, &self.device)?;
                let h_v = Tensor::zeros(
                    (e.dim(0)?, e.dim(1)?, e.dim(D::Minus1)?),
                    base_dtype,
                    &self.device,
                )?;
                let h_e = self.w_e.forward(&e)?;
                let mask_attend = if let Some(seq_mask) = features.get_sequence_mask() {
                    let mask_expanded = seq_mask.unsqueeze(D::Minus1)?; // [B, L, 1]
                    let mask_gathered = gather_nodes(&mask_expanded, &e_idx)?.squeeze(D::Minus1)?;
                    let mask_unsqueezed = mask.unsqueeze(D::Minus1)?; // [B, L, 1]
                    mask_unsqueezed
                        .expand((
                            mask_gathered.dim(0)?, // batch
                            mask_gathered.dim(1)?, // sequence length
                            mask_gathered.dim(2)?, // number of neighbors
                        ))?
                        .mul(&mask_gathered)?
                } else {
                    let (b, l) = mask.dims2()?;
                    Tensor::ones((b, l, e_idx.dim(2)?), DType::F32, &self.device)?
                };
                println!("Beginning the Encoding...");
                // todo: dtype handling not ideal
                let mask_f32 = mask.to_dtype(base_dtype)?;
                let mask_attend_f32 = mask_attend.to_dtype(base_dtype)?;

                // Process through all encoder layers
                let (h_v, h_e) =
                    self.encoder_layers
                        .iter()
                        .fold(Ok((h_v, h_e)), |acc, layer| {
                            let (h_v, h_e) = acc?;
                            layer.forward(
                                &h_v,
                                &h_e,
                                &e_idx,
                                Some(&mask_f32),
                                Some(&mask_attend_f32),
                                Some(false),
                            )
                        })?;

                Ok((h_v, h_e, e_idx))
            }
            ModelTypes::LigandMPNN => {
                todo!()
                //     let (v, e, e_idx, y_nodes, y_edges, y_m) = self.features.forward(feature_dict)?;
                //     let mut h_v = Tensor::zeros((e.dim(0)?, e.dim(1)?, e.dim(-1)?), device)?;
                //     let mut h_e = self.w_e.forward(&e)?;
                //     let h_e_context = self.w_v.forward(&v)?;
                //     let mask_attend = gather_nodes(&mask.unsqueeze(-1)?, &e_idx)?.squeeze(-1)?;
                //     let mask_attend = mask.unsqueeze(-1)? * &mask_attend;
                //
                //     for layer in &self.encoder_layers {
                //         let (new_h_v, new_h_e) =
                //             layer.forward(&h_v, &h_e, &e_idx, &mask, &mask_attend)?;
                //         h_v = new_h_v;
                //         h_e = new_h_e;
                //     }
                //
                //     let mut h_v_c = self.w_c.forward(&h_v)?;
                //     let y_m_edges = &y_m.unsqueeze(-1)? * &y_m.unsqueeze(-2)?;
                //     let mut y_nodes = self.w_nodes_y.forward(&y_nodes)?;
                //     let y_edges = self.w_edges_y.forward(&y_edges)?;
                //
                //     for (y_layer, c_layer) in self
                //         .y_context_encoder_layers
                //         .iter()
                //         .zip(&self.context_encoder_layers)
                //     {
                //         y_nodes = y_layer.forward(&y_nodes, &y_edges, &y_m, &y_m_edges)?;
                //         let h_e_context_cat = Tensor::cat(&[&h_e_context, &y_nodes], -1)?;
                //         h_v_c = c_layer.forward(&h_v_c, &h_e_context_cat, &mask, &y_m)?;
                //     }
                //     h_v_c = self.v_c.forward(&h_v_c)?;
                //     h_v = &h_v + &self.v_c_norm.forward(&self.dropout.forward(&h_v_c)?)?;
                //     Ok((h_v, h_e, e_idx))
            }
        }
    }
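    // Shapes returned by `encode` (descriptive only): `h_v` is [B, L, hidden] node
    // embeddings, `h_e` is [B, L, K, hidden] embeddings for the K nearest-neighbor
    // edges of each residue, and `e_idx` is the [B, L, K] neighbor index map used to
    // gather them.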
    /// Decode using the encoded structure only: a single parallel pass through the
    /// decoder layers with no autoregressive sequence context.
    pub fn simple_decode(&self, features: &ProteinFeatures) -> Result<ScoreOutput> {
        // Create a batch size of 1 for simple decoding
        let b_decoder = 1;

        // Extract relevant features
        let ProteinFeatures { s, x_mask, .. } = features;
        let device = s.device();
        let (_, l) = s.dims2()?;

        // Encode the structure once
        let (h_v_enc, h_e_enc, e_idx_enc) = self.encode(features)?;

        // Process all positions at once with a simplified approach
        let s_true = s.clone();
        let mask = x_mask.clone().unwrap();

        // Create tensors for the decoder
        let zeros = Tensor::zeros((b_decoder, l, h_v_enc.dim(D::Minus1)?), DType::F32, device)?;
        let h_v = h_v_enc.clone();
        let h_e = h_e_enc.clone();
        let e_idx = e_idx_enc.clone();

        // Build encoder embeddings for neighbors
        let h_ex_encoder = cat_neighbors_nodes(&zeros, &h_e, &e_idx)?;
        let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;

        // Apply decoder layers using only structure information
        let h_v_final = self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
            layer.forward(&acc?, &h_exv_encoder, Some(&mask), None, None)
        })?;

        // Calculate logits and log probabilities
        let logits = self.w_out.forward(&h_v_final)?;
        let log_probs = log_softmax(&logits, D::Minus1)?;

        // For the decoding order, just use a sequential placeholder
        // (kept as U32 so that `ScoreOutput::get_decoding_order` can read it back)
        let decoding_order = Tensor::arange(0u32, l as u32, device)?
            .reshape((1, l))?
            .broadcast_as((b_decoder, l))?;

        // Return the output directly
        Ok(ScoreOutput {
            s: s_true,
            log_probs,
            logits,
            decoding_order,
        })
    }
    pub fn sample(
        &self,
        features: &ProteinFeatures,
        temperature: f64,
        seed: u64,
    ) -> Result<ScoreOutput> {
        let sample_dtype = DType::F32;
        let ProteinFeatures {
            s,
            x_mask,
            // symmetry_residues,
            // symmetry_weights,
            ..
        } = features;
        let s_true = s.to_dtype(sample_dtype)?;
        let device = s.device();
        let (b, l) = s.dims2()?;
        // Todo: this is a hack; we should be passing in encoded chains.
        // let chain_mask = Tensor::ones_like(&x_mask.as_ref().unwrap())?.to_dtype(sample_dtype)?;
        // let chain_mask = x_mask.as_ref().unwrap().mul(&chain_mask)?;
        let chain_mask = x_mask.as_ref().unwrap().to_dtype(sample_dtype)?;
        let (h_v, h_e, e_idx) = self.encode(features)?;
        let rand_tensor = Tensor::randn(0f32, 0.25f32, (b, l), device)?.to_dtype(sample_dtype)?;
        let decoding_order = (&chain_mask + 0.0001)?
            .mul(&rand_tensor.abs()?)?
            .arg_sort_last_dim(false)?;
        // Todo: add bias
        // [B, L, 21] - amino acid bias per position
        // (all-ones is a placeholder; a constant bias cancels out in the softmax below)
        let bias = Tensor::ones((b, l, 21), sample_dtype, device)?;
        let symmetry_residues: Option<Vec<i32>> = None;
        match symmetry_residues {
            None => {
                let e_idx = e_idx.repeat(&[b, 1, 1])?;
                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
                    .to_dtype(sample_dtype)?
                    .contiguous()?;
                let tril = Tensor::tril2(l, sample_dtype, device)?;
                let tril = tril.unsqueeze(0)?;
                let temp = tril
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?; // tensor of shape (b, i, q)
                let order_mask_backward = temp
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?; // This will give us a tensor of shape (b, q, p)
                let mask_attend = order_mask_backward
                    .gather(&e_idx, 2)?
                    .unsqueeze(D::Minus1)?;
                let mask_1d = x_mask.as_ref().unwrap().reshape((b, l, 1, 1))?;
                // Broadcast mask_1d to match mask_attend's shape
                let mask_1d = mask_1d
                    .broadcast_as(mask_attend.shape())?
                    .to_dtype(sample_dtype)?;
                let mask_bw = mask_1d.mul(&mask_attend)?;
                let mask_fw = mask_1d.mul(&(Tensor::ones_like(&mask_attend)? - mask_attend)?)?;
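                // Autoregressive masks (descriptive only): for each edge (i, j),
                // `mask_bw` is 1 where neighbor j is decoded *before* i in
                // `decoding_order` (so its sampled identity may be used), while
                // `mask_fw` covers the complementary, not-yet-decoded neighbors,
                // for which only encoder-derived structure features are visible.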
                // Note: `sample` begins to diverge from `score` here.
                // repeat for decoding
                let s_true = s_true.repeat((b, 1))?;
                let h_v = h_v.repeat((b, 1, 1))?;
                let h_e = h_e.repeat((b, 1, 1, 1))?;
                let mask = x_mask.as_ref().unwrap().repeat((b, 1))?.contiguous()?;
                let chain_mask = &chain_mask.repeat((b, 1))?;
                let bias = bias.repeat((b, 1, 1))?;
                let mut all_probs = Tensor::zeros((b, l, 20), sample_dtype, device)?;
                // 21 classes here (20 canonical AAs + `X`) to match `w_out`;
                // sampling below is restricted to the 20 canonical AAs.
                let mut all_log_probs = Tensor::zeros((b, l, 21), sample_dtype, device)?;
                let mut h_s = Tensor::zeros_like(&h_v)?;
                // note: index 20 is `X` (unknown); positions are overwritten below, not added to.
                let mut s = Tensor::full(20u32, (b, l), device)?;
                let mut h_v_stack = vec![h_v.clone()];

                for _ in 0..self.decoder_layers.len() {
                    let zeros = Tensor::zeros_like(&h_v)?;
                    h_v_stack.push(zeros);
                }
                let h_ex_encoder = cat_neighbors_nodes(&Tensor::zeros_like(&h_s)?, &h_e, &e_idx)?;
                let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
                let mask_fw = mask_fw
                    .broadcast_as(h_exv_encoder.shape())?
                    .to_dtype(h_exv_encoder.dtype())?;
                let h_exv_encoder_fw = mask_fw.mul(&h_exv_encoder)?;
                for t_ in 0..l {
                    let t = decoding_order.i((.., t_))?;
                    let t_gather = t.unsqueeze(1)?; // Shape [B, 1]
                    // Gather masks and bias
                    let chain_mask_t = chain_mask.gather(&t_gather, 1)?.squeeze(1)?;
                    let mask_t = mask.gather(&t_gather, 1)?.squeeze(1)?.contiguous()?;
                    let bias_t = bias
                        .gather(&t_gather.unsqueeze(2)?.expand((b, 1, 21))?.contiguous()?, 1)?
                        .squeeze(1)?;
                    // Gather edge and node indices/features
                    let e_idx_t = e_idx
                        .gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, e_idx.dim(2)?))?
                                .contiguous()?,
                            1,
                        )?
                        .contiguous()?;
                    let h_e_t = h_e.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, h_e.dim(2)?, h_e.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;
                    let n = e_idx_t.dim(2)?; // number of neighbors
                    let c = h_s.dim(2)?; // channels/features
                    let h_e_t = h_e_t
                        .squeeze(1)? // [B, N, C]
                        .unsqueeze(1)? // [B, 1, N, C]
                        .expand((b, l, n, c))? // [B, L, N, C]
                        .contiguous()?;
                    let e_idx_t = e_idx_t
                        .expand((b, l, n))? // [B, L, N]
                        .contiguous()?;
                    let h_es_t = cat_neighbors_nodes(&h_s, &h_e_t, &e_idx_t)?;
                    let h_exv_encoder_t = h_exv_encoder_fw.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, h_exv_encoder_fw.dim(2)?, h_exv_encoder_fw.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;
                    let mask_bw_t = mask_bw.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, mask_bw.dim(2)?, mask_bw.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;

                    // Decoder layers loop
                    for layer_idx in 0..self.decoder_layers.len() {
                        let h_v_stack_l = &h_v_stack[layer_idx];
                        let h_esv_decoder_t = cat_neighbors_nodes(h_v_stack_l, &h_es_t, &e_idx_t)?;
                        let h_v_t = h_v_stack_l.gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, h_v_stack_l.dim(2)?))?
                                .contiguous()?,
                            1,
                        )?;
                        let mask_bw_t = mask_bw_t.expand(h_esv_decoder_t.dims())?.contiguous()?;
                        let h_exv_encoder_t = h_exv_encoder_t
                            .expand(h_esv_decoder_t.dims())?
                            .contiguous()?
                            .to_dtype(sample_dtype)?;
                        let h_esv_t = mask_bw_t
                            .mul(&h_esv_decoder_t.to_dtype(sample_dtype)?)?
                            .add(&h_exv_encoder_t)?
                            .to_dtype(sample_dtype)?
                            .contiguous()?;
                        let h_v_t = h_v_t
                            .expand((
                                h_esv_t.dim(0)?, // batch size
                                h_esv_t.dim(1)?, // sequence length
                                h_v_t.dim(2)?,   // features (hidden dim)
                            ))?
                            .contiguous()?;
                        let decoder_output = self.decoder_layers[layer_idx].forward(
                            &h_v_t,
                            &h_esv_t,
                            Some(&mask_t),
                            None,
                            None,
                        )?;
                        let t_expanded = t_gather.reshape(&[b])?; // This will give us a 1D tensor of shape [b]
                        let decoder_output = decoder_output
                            .narrow(1, 0, 1)?
                            .squeeze(1)? // Now [1, 128]
                            .unsqueeze(1)?; // Now [1, 1, 128] - same rank as target
                        h_v_stack[layer_idx + 1] =
                            h_v_stack[layer_idx + 1].index_add(&t_expanded, &decoder_output, 1)?;
                    }
                    let h_v_t = h_v_stack
                        .last()
                        .unwrap()
                        .gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, h_v_stack.last().unwrap().dim(2)?))?
                                .contiguous()?,
                            1,
                        )?
                        .squeeze(1)?;
                    // Generate logits and probabilities
                    let logits = self.w_out.forward(&h_v_t)?;
                    let log_probs = log_softmax(&logits, D::Minus1)?;

                    // explicit block for order of operations
                    let probs = {
                        let biased_logits = logits.add(&bias_t)?; // (logits + bias_t)
                        let scaled_logits = (biased_logits / temperature)?; // (logits + bias_t) / temperature
                        softmax(&scaled_logits, D::Minus1)? // softmax((logits + bias_t) / temperature)
                    };

                    let probs_sample = probs
                        .narrow(1, 0, 20)?
                        .div(&probs.narrow(1, 0, 20)?.sum_keepdim(1)?.expand((b, 20))?)?;
                    // Sample new token
                    let sum = probs_sample.sum(1)?;
                    let probs_sample_1d = probs_sample
                        .squeeze(0)? // Remove batch dimension -> [20]
                        .clamp(1e-10, 1.0)?
                        .broadcast_div(&sum)?
                        .contiguous()?;

                    let s_t = multinomial_sample(&probs_sample_1d, temperature, seed)?;
                    let s_t = s_t.to_dtype(sample_dtype)?;
                    let s_true = s_true.to_dtype(sample_dtype)?;
                    let s_true_t = s_true.gather(&t_gather, 1)?.squeeze(1)?;
                    let s_t = s_t
                        .mul(&chain_mask_t)?
                        .add(&s_true_t.mul(&(&chain_mask_t.neg()? + 1.0)?)?)?
                        .to_dtype(DType::U32)?;

                    let s_t_idx = s_t.to_dtype(DType::U32)?;
                    let s_t_idx = s_t_idx.reshape(&[s_t_idx.dim(0)?])?;
                    let h_s_update = self.w_s.forward(&s_t_idx)?.unsqueeze(1)?;
                    let t_gather_expanded = t_gather.reshape(&[b])?;
                    let h_s_update = h_s_update.squeeze(0)?.unsqueeze(1)?;
                    h_s =
                        h_s.index_add(&t_gather_expanded, &Tensor::zeros_like(&h_s_update)?, 1)?;
                    h_s = h_s.index_add(&t_gather_expanded, &h_s_update, 1)?;

                    s = {
                        let dim = 1;
                        let start = t_gather.squeeze(0)?.squeeze(0)?.to_scalar::<u32>()? as usize;
                        let s_t_expanded = s_t.unsqueeze(1)?;
                        s.slice_scatter(&s_t_expanded, dim, start)?
                    };

                    let probs_update = chain_mask_t
                        .unsqueeze(1)?
                        .unsqueeze(2)?
                        .expand((b, 1, 20))?
                        .mul(&probs_sample.unsqueeze(1)?)?;
                    let t_expanded = t_gather.reshape(&[b])?;
                    let probs_update = probs_update
                        .squeeze(1)? // Remove extra dimension
                        .unsqueeze(1)?;
                    all_probs =
                        all_probs.index_add(&t_expanded, &Tensor::zeros_like(&probs_update)?, 1)?;
                    all_probs = all_probs.index_add(&t_expanded, &probs_update, 1)?;
                    let log_probs_update = chain_mask_t
                        .unsqueeze(1)?
                        .unsqueeze(2)?
                        .expand((b, 1, 21))?
                        .mul(&log_probs.unsqueeze(1)?)?
                        .squeeze(1)?
                        .unsqueeze(1)?;

                    all_log_probs = all_log_probs.index_add(
                        &t_expanded,
                        &Tensor::zeros_like(&log_probs_update)?,
                        1,
                    )?;
                    all_log_probs = all_log_probs.index_add(&t_expanded, &log_probs_update, 1)?;
                }
                // Note: the accumulated probabilities (`all_probs`) and
                // log-probabilities (`all_log_probs`) are returned in the
                // `log_probs` and `logits` slots of `ScoreOutput`.
                Ok(ScoreOutput {
                    s,
                    log_probs: all_probs,
                    logits: all_log_probs,
                    decoding_order,
                })
            }
            Some(_symmetry_residues) => {
                todo!()
            }
        }
    }
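    // Example (sketch): autoregressive design at low temperature. `features` is
    // assumed to be a populated `ProteinFeatures` for a single structure:
    //
    //     let out = model.sample(&features, 0.1, 42)?; // temperature = 0.1, seed = 42
    //     let designed = out.get_sequences()?;
    //     let per_position = out.get_pseudo_probabilities()?;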

    pub fn score(&self, features: &ProteinFeatures, use_sequence: bool) -> Result<ScoreOutput> {
        let ProteinFeatures { s, x_mask, .. } = &features;
        let sample_dtype = DType::F32;
        let s_true = &s.clone();
        let device = s_true.device();
        let (b, l) = s_true.dims2()?;
        let mask = &x_mask.as_ref().clone();
        let b_decoder: usize = b;

        // Todo: this is a hack; we should be passing in encoded chains.
        // Update chain_mask to include missing regions
        let chain_mask = Tensor::zeros_like(mask.unwrap())?.to_dtype(sample_dtype)?;
        let chain_mask = mask.unwrap().mul(&chain_mask)?; // all-zero chain_mask: the decoding order below is effectively random

        // encode ...
        let (h_v, h_e, e_idx) = self.encode(features)?;
        let rand_tensor = Tensor::randn(0f32, 1f32, (b, l), device)?.to_dtype(sample_dtype)?;
        // Compute decoding order
        let decoding_order = (chain_mask + 0.001)?
            .mul(&rand_tensor.abs()?)?
            .arg_sort_last_dim(false)?;

        let symmetry_residues: Option<Vec<i32>> = None;

        let (mask_fw, mask_bw, e_idx, decoding_order) = match symmetry_residues {
            Some(_symmetry_residues) => {
                todo!();
            }
            None => {
                let e_idx = e_idx.repeat(&[b_decoder, 1, 1])?;
                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
                    .to_dtype(sample_dtype)?
                    .contiguous()?;

                let tril = Tensor::tril2(l, sample_dtype, device)?.unsqueeze(0)?;
                let temp = tril
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?; // shape (b, i, q)
                let order_mask_backward = temp
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?; // shape (b, q, p)
                let mask_attend = order_mask_backward
                    .gather(&e_idx, 2)?
                    .unsqueeze(D::Minus1)?;

                // Broadcast mask_1d to match mask_attend's shape
                let mask_1d = mask
                    .unwrap()
                    .reshape((b, l, 1, 1))?
                    .broadcast_as(mask_attend.shape())?
                    .to_dtype(sample_dtype)?;

                let mask_bw = mask_1d.mul(&mask_attend)?;
                let mask_fw = mask_1d.mul(&(mask_attend - 1.0)?.neg()?)?;
                (mask_fw, mask_bw, e_idx, decoding_order)
            }
        };

        let s_true = s_true.repeat(&[b_decoder, 1])?;
        let h_v = h_v.repeat(&[b_decoder, 1, 1])?;
        let h_e = h_e.repeat(&[b_decoder, 1, 1, 1])?;
        let mask = mask.as_ref().unwrap().repeat(&[b_decoder, 1])?;

        let h_s = self.w_s.forward(&s_true)?; // embedding layer
        let h_es = cat_neighbors_nodes(&h_s, &h_e, &e_idx)?;

        // Build encoder embeddings
        let h_ex_encoder = cat_neighbors_nodes(&Tensor::zeros_like(&h_s)?, &h_e, &e_idx)?;
        let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
        let h_exv_encoder_fw = mask_fw
            .broadcast_as(h_exv_encoder.shape())?
            .to_dtype(h_exv_encoder.dtype())?
            .mul(&h_exv_encoder)?;

        // Apply decoder layers
        let h_v = if !use_sequence {
            // Simple forward pass through decoder layers
            self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
                layer.forward(&acc?, &h_exv_encoder_fw, Some(&mask), None, None)
            })?
        } else {
            // Forward pass with sequence-aware processing
            self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
                let current_h_v = acc?;
                let h_esv = cat_neighbors_nodes(&current_h_v, &h_es, &e_idx)?
                    .mul(&mask_bw)?
                    .add(&h_exv_encoder_fw)?;
                layer.forward(&current_h_v, &h_esv, Some(&mask), None, None)
            })?
        };

        let logits = self.w_out.forward(&h_v)?;
        let log_probs = log_softmax(&logits, D::Minus1)?;

        Ok(ScoreOutput {
            s: s_true,
            log_probs,
            logits,
            decoding_order,
        })
    }
}