Skip to main content

ferritin_plms/ligandmpnn/
model.rs

1//! A message passing protein design neural network
2//! that samples sequences diffusing conditional probabilities.
3//!
4//! - See the [LigandMPNN Repo](https://github.com/dauparas/LigandMPNN)
5//!
6use super::configs::{ModelTypes, ProteinMPNNConfig};
7use super::proteinfeatures::ProteinFeatures;
8use super::proteinfeaturesmodel::ProteinFeaturesModel;
9use crate::featurize::utilities::{cat_neighbors_nodes, gather_nodes, int_to_aa1};
10use crate::types::PseudoProbability;
11use candle_core::safetensors;
12use candle_core::{D, DType, Device, IndexOp, Module, Result, Tensor};
13use candle_nn::encoding::one_hot;
14use candle_nn::ops::{log_softmax, softmax};
15use candle_nn::{Dropout, Embedding, LayerNorm, Linear, VarBuilder, embedding, layer_norm, linear};
16use candle_transformers::generation::LogitsProcessor;
17use std::collections::HashMap;
18
19// refactoring common fn
20fn concat_node_tensors(h_v: &Tensor, h_e: &Tensor, e_idx: &Tensor) -> Result<Tensor> {
21    let h_ev = cat_neighbors_nodes(h_v, h_e, e_idx)?;
22    let h_v_expand = h_v.unsqueeze(D::Minus2)?;
23    let expand_shape = [
24        h_ev.dims()[0],
25        h_ev.dims()[1],
26        h_ev.dims()[2],
27        h_v_expand.dims()[3],
28    ];
29    let h_v_expand = h_v_expand.expand(&expand_shape)?.to_dtype(h_ev.dtype())?;
30    Tensor::cat(&[&h_v_expand, &h_ev], D::Minus1)?.contiguous()
31}
32// refactoring common fn
33fn apply_dropout_and_norm(
34    input: &Tensor,
35    delta: &Tensor,
36    dropout: &Dropout,
37    norm: &LayerNorm,
38    training: bool,
39) -> Result<Tensor> {
40    let delta_dropout = dropout.forward(delta, training)?;
41    norm.forward(&(input + delta_dropout)?)
42}
43
44pub fn multinomial_sample(probs: &Tensor, temperature: f64, seed: u64) -> Result<Tensor> {
45    let mut logits_processor = LogitsProcessor::new(
46        seed,              // seed for reproducibility
47        Some(temperature), // temperature scaling
48        // None,              // top_p (nucleus sampling), we don't need this
49        Some(0.95), // top_p (nucleus sampling), we don't need this
50    );
51    let idx = logits_processor.sample(probs)?;
52    // println!("Selected index: {}", idx);
53    if idx >= 21 {
54        println!("WARNING: Invalid index {} selected", idx);
55    }
56    Tensor::new(&[idx], probs.device())
57}
58
59// Primary Return Object from the ProtMPNN Model
60#[derive(Clone, Debug)]
61pub struct ScoreOutput {
62    // Sequence
63    pub(crate) s: Tensor,
64    pub(crate) log_probs: Tensor,
65    pub(crate) logits: Tensor,
66    pub(crate) decoding_order: Tensor,
67}
68///  Score dims are [Batch, seqlength]
69impl ScoreOutput {
70    pub fn get_sequences(&self) -> Result<Vec<String>> {
71        let (b, l) = self.s.dims2()?;
72        let mut sequences = Vec::with_capacity(b);
73        for batch_idx in 0..b {
74            let batch = self.s.get(batch_idx)?;
75            let mut sequence = String::with_capacity(l);
76            for pos in 0..l {
77                let aa_idx = batch.get(pos)?.to_vec0::<u32>()?;
78                sequence.push(int_to_aa1(aa_idx));
79            }
80            sequences.push(sequence);
81        }
82        Ok(sequences)
83    }
84    pub fn get_decoding_order(&self) -> Result<Vec<u32>> {
85        let values = self.decoding_order.flatten_all()?.to_vec1::<u32>()?;
86        Ok(values)
87    }
88    pub fn get_log_probs(&self) -> &Tensor {
89        &self.log_probs
90    }
91    pub fn get_pseudo_probabilities(&self) -> Result<Vec<PseudoProbability>> {
92        let (batch_size, seq_len, _vocab_size) = self.logits.dims3()?;
93        let mut all_probabilities = Vec::with_capacity(batch_size * seq_len);
94
95        for batch_idx in 0..batch_size {
96            let batch_logits = self.logits.get(batch_idx)?;
97
98            for pos in 0..seq_len {
99                let pos_logits = batch_logits.get(pos)?;
100                let probs = softmax(&pos_logits, 0)?;
101                let probs = probs.to_vec1::<f32>()?;
102
103                // Get the decoding order to determine the actual position
104                let actual_pos = if let Ok(order) = self
105                    .decoding_order
106                    .get(batch_idx)?
107                    .get(pos)?
108                    .to_scalar::<u32>()
109                {
110                    order as usize
111                } else {
112                    pos
113                };
114
115                for aa_idx in 0..probs.len() {
116                    if probs[aa_idx] > 0.01 {
117                        // Only include probabilities above threshold
118                        all_probabilities.push(PseudoProbability {
119                            position: actual_pos,
120                            pseudo_prob: probs[aa_idx],
121                            amino_acid: int_to_aa1(aa_idx as u32),
122                        });
123                    }
124                }
125            }
126        }
127        // Sort by position and then by probability (descending)
128        all_probabilities.sort_by(|a, b| {
129            a.position.cmp(&b.position).then(
130                b.pseudo_prob
131                    .partial_cmp(&a.pseudo_prob)
132                    .unwrap_or(std::cmp::Ordering::Equal),
133            )
134        });
135        Ok(all_probabilities)
136    }
137    pub fn save_as_safetensors(&self, filename: String) -> Result<()> {
138        let mut tensors = HashMap::new();
139        tensors.insert("S".to_string(), self.s.clone());
140        tensors.insert("log_probs".to_string(), self.log_probs.clone());
141        tensors.insert("logits".to_string(), self.logits.clone());
142        tensors.insert("decoding_order".to_string(), self.decoding_order.clone());
143        // Create directory if it doesn't exist
144        if let Some(parent) = std::path::Path::new(&filename).parent() {
145            std::fs::create_dir_all(parent)?;
146        }
147        let _ = safetensors::save(&tensors, &filename);
148        Ok(())
149    }
150}
151
152#[derive(Clone, Debug)]
153struct PositionWiseFeedForward {
154    w1: Linear,
155    w2: Linear,
156}
157
158impl PositionWiseFeedForward {
159    fn new(vb: VarBuilder, dim_input: usize, dim_feedforward: usize) -> Result<Self> {
160        let w1 = linear::linear(dim_input, dim_feedforward, vb.pp("W_in"))?;
161        let w2 = linear::linear(dim_feedforward, dim_input, vb.pp("W_out"))?;
162        Ok(Self { w1, w2 })
163    }
164}
165
166impl Module for PositionWiseFeedForward {
167    fn forward(&self, x: &Tensor) -> Result<Tensor> {
168        self.w1.forward(x)?.gelu().and_then(|x| self.w2.forward(&x))
169    }
170}
171
172#[derive(Clone, Debug)]
173#[allow(dead_code)]
174pub struct EncLayer {
175    num_hidden: usize,
176    num_in: usize,
177    scale: f64,
178    dropout1: Dropout,
179    dropout2: Dropout,
180    dropout3: Dropout,
181    norm1: LayerNorm,
182    norm2: LayerNorm,
183    norm3: LayerNorm,
184    w1: Linear,
185    w2: Linear,
186    w3: Linear,
187    w11: Linear,
188    w12: Linear,
189    w13: Linear,
190    dense: PositionWiseFeedForward,
191}
192
193impl EncLayer {
194    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig, layer: i32) -> Result<Self> {
195        let vb = vb.pp(layer); // handle the layer number here.
196        let num_hidden = config.hidden_dim as usize;
197        let augment_eps = config.augment_eps as f64;
198        let num_in = (config.hidden_dim * 2) as usize;
199        let dropout_ratio = config.dropout_ratio;
200
201        // Create layer norms
202        let norm1 = layer_norm(num_hidden, augment_eps, vb.pp("norm1"))?;
203        let norm2 = layer_norm(num_hidden, augment_eps, vb.pp("norm2"))?;
204        let norm3 = layer_norm(num_hidden, augment_eps, vb.pp("norm3"))?;
205
206        // Create linear layers
207        let w1 = linear(num_hidden + num_in, num_hidden, vb.pp("W1"))?;
208        let w2 = linear(num_hidden, num_hidden, vb.pp("W2"))?;
209        let w3 = linear(num_hidden, num_hidden, vb.pp("W3"))?;
210        let w11 = linear(num_hidden + num_in, num_hidden, vb.pp("W11"))?;
211        let w12 = linear(num_hidden, num_hidden, vb.pp("W12"))?;
212        let w13 = linear(num_hidden, num_hidden, vb.pp("W13"))?;
213
214        // Create dropouts with same ratio
215        let dropout1 = Dropout::new(dropout_ratio);
216        let dropout2 = Dropout::new(dropout_ratio);
217        let dropout3 = Dropout::new(dropout_ratio);
218
219        let dense = PositionWiseFeedForward::new(vb.pp("dense"), num_hidden, num_hidden * 4)?;
220
221        Ok(Self {
222            num_hidden,
223            num_in,
224            scale: config.scale_factor,
225            dropout1,
226            dropout2,
227            dropout3,
228            norm1,
229            norm2,
230            norm3,
231            w1,
232            w2,
233            w3,
234            w11,
235            w12,
236            w13,
237            dense,
238        })
239    }
240    fn forward(
241        &self,
242        h_v: &Tensor,
243        h_e: &Tensor,
244        e_idx: &Tensor,
245        mask_v: Option<&Tensor>,
246        mask_attend: Option<&Tensor>,
247        training: Option<bool>,
248    ) -> Result<(Tensor, Tensor)> {
249        let training = training.unwrap_or(false);
250        let h_v = h_v.to_dtype(DType::F32)?;
251        let h_ev = concat_node_tensors(&h_v, h_e, e_idx)?;
252        let h_message = self
253            .w1
254            .forward(&h_ev)?
255            .gelu()?
256            .apply(&self.w2)?
257            .gelu()?
258            .apply(&self.w3)?;
259
260        let h_message = mask_attend
261            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message))
262            .transpose()?
263            .unwrap_or(h_message);
264
265        // Safe division with scale
266        let scale = if self.scale == 0.0 { 1.0 } else { self.scale };
267        let dh = (h_message.sum(D::Minus2)? / scale)?;
268
269        let h_v = apply_dropout_and_norm(&h_v, &dh, &self.dropout1, &self.norm1, training)?;
270        let dense_output = self.dense.forward(&h_v)?;
271        let h_v =
272            apply_dropout_and_norm(&h_v, &dense_output, &self.dropout2, &self.norm2, training)?;
273
274        // Apply mask if provided
275        let h_v = mask_v
276            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_v))
277            .transpose()?
278            .unwrap_or(h_v);
279
280        let h_ev = concat_node_tensors(&h_v, h_e, e_idx)?;
281        let h_message = self
282            .w11
283            .forward(&h_ev)?
284            .gelu()?
285            .apply(&self.w12)?
286            .gelu()?
287            .apply(&self.w13)?;
288
289        let h_e = apply_dropout_and_norm(h_e, &h_message, &self.dropout3, &self.norm3, training)?;
290        Ok((h_v, h_e))
291    }
292}
293
294#[derive(Clone, Debug)]
295#[allow(dead_code)]
296pub struct DecLayer {
297    num_hidden: usize,
298    num_in: usize,
299    scale: f64,
300    dropout1: Dropout,
301    dropout2: Dropout,
302    norm1: LayerNorm,
303    norm2: LayerNorm,
304    w1: Linear,
305    w2: Linear,
306    w3: Linear,
307    dense: PositionWiseFeedForward,
308}
309
310impl DecLayer {
311    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig, layer: i32) -> Result<Self> {
312        let vb = vb.pp(layer); // handle the layer number here.
313        let num_hidden = config.hidden_dim as usize;
314        let augment_eps = config.augment_eps as f64;
315        let num_in = (config.hidden_dim * 3) as usize;
316        let dropout_ratio = config.dropout_ratio;
317
318        let norm1 = layer_norm::layer_norm(num_hidden, augment_eps, vb.pp("norm1"))?;
319        let norm2 = layer_norm::layer_norm(num_hidden, augment_eps, vb.pp("norm2"))?;
320
321        let w1 = linear::linear(num_hidden + num_in, num_hidden, vb.pp("W1"))?;
322        let w2 = linear::linear(num_hidden, num_hidden, vb.pp("W2"))?;
323        let w3 = linear::linear(num_hidden, num_hidden, vb.pp("W3"))?;
324        let dropout1 = Dropout::new(dropout_ratio);
325        let dropout2 = Dropout::new(dropout_ratio);
326
327        let dense = PositionWiseFeedForward::new(vb.pp("dense"), num_hidden, num_hidden * 4)?;
328
329        Ok(Self {
330            num_hidden,
331            num_in,
332            scale: config.scale_factor,
333            dropout1,
334            dropout2,
335            norm1,
336            norm2,
337            w1,
338            w2,
339            w3,
340            dense,
341        })
342    }
343    pub fn forward(
344        &self,
345        h_v: &Tensor,
346        h_e: &Tensor,
347        mask_v: Option<&Tensor>,
348        mask_attend: Option<&Tensor>,
349        training: Option<bool>,
350    ) -> Result<Tensor> {
351        let training_bool = training.unwrap_or(false);
352
353        // Expand node features to match edge dimensions
354        let expand_shape = [
355            h_e.dims()[0], // batch (1)
356            h_e.dims()[1], // sequence length (93)
357            h_e.dims()[2], // number of neighbors (24)
358            h_v.dims()[2], // keep original hidden dim (128)
359        ];
360
361        let h_v_expand = h_v.unsqueeze(D::Minus2)?.expand(&expand_shape)?;
362        let h_ev = Tensor::cat(&[&h_v_expand, h_e], D::Minus1)?.contiguous()?;
363
364        let h_message = self
365            .w1
366            .forward(&h_ev)?
367            .gelu()?
368            .apply(&self.w2)?
369            .gelu()?
370            .apply(&self.w3)?;
371
372        let h_message = self.dropout1.forward(&h_message, training_bool)?;
373
374        let h_message = mask_attend
375            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message))
376            .transpose()?
377            .unwrap_or(h_message);
378
379        let dh = (h_message.sum(D::Minus2)? / self.scale)?;
380        let h_v = self.norm1.forward(&(h_v + dh)?)?;
381        let dh = self.dense.forward(&h_v)?;
382        let dh_dropout = self.dropout2.forward(&dh, training_bool)?;
383        let h_v = self.norm2.forward(&(h_v + dh_dropout)?)?;
384
385        // Apply optional node mask
386        let h_v = mask_v
387            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_v))
388            .transpose()?
389            .unwrap_or(h_v);
390
391        Ok(h_v)
392    }
393}
394
/// ProteinMPNN Model
/// - [link](https://github.com/dauparas/LigandMPNN/blob/main/model_utils.py#L10C7-L10C18)
///
/// Holds the configuration together with every sub-module that
/// `ProteinMPNN::load` reads from a `VarBuilder`.
pub struct ProteinMPNN {
    /// Copy of the configuration the model was loaded with.
    pub(crate) config: ProteinMPNNConfig,
    /// Decoder stack, loaded from the `decoder_layers` weight prefix.
    pub(crate) decoder_layers: Vec<DecLayer>,
    /// Device of the `VarBuilder` the weights were loaded from.
    pub(crate) device: Device,
    /// Encoder stack, loaded from the `encoder_layers` weight prefix.
    pub(crate) encoder_layers: Vec<EncLayer>,
    /// Feature model, loaded from the `features` weight prefix.
    pub(crate) features: ProteinFeaturesModel,
    /// Linear map from `edge_features` to `hidden_dim` (`W_e`).
    pub(crate) w_e: Linear,
    /// Linear map from `hidden_dim` to `num_letters` (`W_out`); produces the logits.
    pub(crate) w_out: Linear,
    /// Embedding with `vocab` entries of size `hidden_dim` (`W_s`).
    pub(crate) w_s: Embedding,
}
407
408impl ProteinMPNN {
409    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig) -> Result<Self> {
410        let hidden_dim = config.hidden_dim as usize;
411        let edge_features = config.edge_features as usize;
412        let num_letters = config.num_letters as usize;
413        let vocab_size = config.vocab as usize;
414
415        // Create encoder and decoder layers using iterators
416        let encoder_layers = (0..config.num_encoder_layers)
417            .map(|i| EncLayer::load(vb.pp("encoder_layers"), config, i as i32))
418            .collect::<Result<Vec<_>>>()?;
419
420        let decoder_layers = (0..config.num_decoder_layers)
421            .map(|i| DecLayer::load(vb.pp("decoder_layers"), config, i as i32))
422            .collect::<Result<Vec<_>>>()?;
423
424        // Initialize weights
425        let w_e = linear::linear(edge_features, hidden_dim, vb.pp("W_e"))?;
426        let w_out = linear::linear(hidden_dim, num_letters, vb.pp("W_out"))?;
427        let w_s = embedding(vocab_size, hidden_dim, vb.pp("W_s"))?;
428        // Features
429        let features = ProteinFeaturesModel::load(vb.pp("features"), config.clone())?;
430
431        Ok(Self {
432            config: config.clone(),
433            decoder_layers,
434            device: vb.device().clone(),
435            encoder_layers,
436            features,
437            w_e,
438            w_out,
439            w_s,
440        })
441    }
442    pub fn encode(&self, features: &ProteinFeatures) -> Result<(Tensor, Tensor, Tensor)> {
443        let s_true = features.get_sequence();
444        let base_dtype = DType::F32;
445        let mask = match features.get_sequence_mask() {
446            Some(m) => m,
447            None => &Tensor::ones_like(s_true)?,
448        };
449        match self.config.model_type {
450            ModelTypes::ProteinMPNN => {
451                let (e, e_idx) = self.features.forward(features, &self.device)?;
452                let h_v = Tensor::zeros(
453                    (e.dim(0)?, e.dim(1)?, e.dim(D::Minus1)?),
454                    base_dtype,
455                    &self.device,
456                )?;
457                let h_e = self.w_e.forward(&e)?;
458                let mask_attend = if let Some(seq_mask) = features.get_sequence_mask() {
459                    let mask_expanded = seq_mask.unsqueeze(D::Minus1)?; // [B, L, 1]
460                    let mask_gathered = gather_nodes(&mask_expanded, &e_idx)?.squeeze(D::Minus1)?;
461                    let mask_unsqueezed = mask.unsqueeze(D::Minus1)?; // [B, L, 1]
462                    mask_unsqueezed
463                        .expand((
464                            mask_gathered.dim(0)?, // batch
465                            mask_gathered.dim(1)?, // sequence length
466                            mask_gathered.dim(2)?, // number of neighbors
467                        ))?
468                        .mul(&mask_gathered)?
469                } else {
470                    let (b, l) = mask.dims2()?;
471                    Tensor::ones((b, l, e_idx.dim(2)?), DType::F32, &self.device)?
472                };
473                println!("Beginning the Encoding...");
474                // todo: dtype handling not ideal
475                let mask_f32 = mask.to_dtype(base_dtype)?;
476                let mask_attend_f32 = mask_attend.to_dtype(base_dtype)?;
477
478                // Process through all encoder layers
479                let (h_v, h_e) =
480                    self.encoder_layers
481                        .iter()
482                        .fold(Ok((h_v, h_e)), |acc, layer| {
483                            let (h_v, h_e) = acc?;
484                            layer.forward(
485                                &h_v,
486                                &h_e,
487                                &e_idx,
488                                Some(&mask_f32),
489                                Some(&mask_attend_f32),
490                                Some(false),
491                            )
492                        })?;
493
494                Ok((h_v, h_e, e_idx))
495            }
496            ModelTypes::LigandMPNN => {
497                todo!()
498                //     let (v, e, e_idx, y_nodes, y_edges, y_m) = self.features.forward(feature_dict)?;
499                //     let mut h_v = Tensor::zeros((e.dim(0)?, e.dim(1)?, e.dim(-1)?), device)?;
500                //     let mut h_e = self.w_e.forward(&e)?;
501                //     let h_e_context = self.w_v.forward(&v)?;
502                //     let mask_attend = gather_nodes(&mask.unsqueeze(-1)?, &e_idx)?.squeeze(-1)?;
503                //     let mask_attend = mask.unsqueeze(-1)? * &mask_attend;
504                //
505                //     for layer in &self.encoder_layers {
506                //         let (new_h_v, new_h_e) =
507                //             layer.forward(&h_v, &h_e, &e_idx, &mask, &mask_attend)?;
508                //         h_v = new_h_v;
509                //         h_e = new_h_e;
510                //     }
511                //
512                //     let mut h_v_c = self.w_c.forward(&h_v)?;
513                //     let y_m_edges = &y_m.unsqueeze(-1)? * &y_m.unsqueeze(-2)?;
514                //     let mut y_nodes = self.w_nodes_y.forward(&y_nodes)?;
515                //     let y_edges = self.w_edges_y.forward(&y_edges)?;
516                //
517                //     for (y_layer, c_layer) in self
518                //         .y_context_encoder_layers
519                //         .iter()
520                //         .zip(&self.context_encoder_layers)
521                //     {
522                //         y_nodes = y_layer.forward(&y_nodes, &y_edges, &y_m, &y_m_edges)?;
523                //         let h_e_context_cat = Tensor::cat(&[&h_e_context, &y_nodes], -1)?;
524                //         h_v_c = c_layer.forward(&h_v_c, &h_e_context_cat, &mask, &y_m)?;
525                //     }
526                //     h_v_c = self.v_c.forward(&h_v_c)?;
527                //     h_v = &h_v + &self.v_c_norm.forward(&self.dropout.forward(&h_v_c)?)?;
528                //     Ok((h_v, h_e, e_idx))
529            }
530        }
531    }
532    // Removed unused decode methods
533    pub fn simple_decode(&self, features: &ProteinFeatures) -> Result<ScoreOutput> {
534        // Create a batch size of 1 for simple decoding
535        let b_decoder = 1;
536
537        // Extract relevant features
538        let ProteinFeatures { s, x_mask, .. } = features;
539        let device = s.device();
540        let (_, l) = s.dims2()?;
541
542        // Encode the structure once
543        let (h_v_enc, h_e_enc, e_idx_enc) = self.encode(features)?;
544
545        // Process all positions at once with a simplified approach
546        let s_true = s.clone();
547        let mask = x_mask.clone().unwrap();
548
549        // Create tensors for the decoder
550        let zeros = Tensor::zeros((b_decoder, l, h_v_enc.dim(D::Minus1)?), DType::F32, device)?;
551        let h_v = h_v_enc.clone();
552        let h_e = h_e_enc.clone();
553        let e_idx = e_idx_enc.clone();
554
555        // Build encoder embeddings for neighbors
556        let h_ex_encoder = cat_neighbors_nodes(&zeros, &h_e, &e_idx)?;
557        let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
558
559        // Apply decoder layers using only structure information
560        let h_v_final = self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
561            layer.forward(&acc?, &h_exv_encoder, Some(&mask), None, None)
562        })?;
563
564        // Calculate logits and log probabilities
565        let logits = self.w_out.forward(&h_v_final)?;
566        let log_probs = log_softmax(&logits, D::Minus1)?;
567
568        // For the decoding order, just use a placeholder
569        let decoding_order = Tensor::arange(0, l as i64, device)?
570            .reshape((1, l))?
571            .broadcast_as((b_decoder, l))?
572            .to_dtype(DType::F32)?;
573
574        // Return the output directly
575        Ok(ScoreOutput {
576            s: s_true,
577            log_probs,
578            logits,
579            decoding_order,
580        })
581    }
582    pub fn sample(
583        &self,
584        features: &ProteinFeatures,
585        temperature: f64,
586        seed: u64,
587    ) -> Result<ScoreOutput> {
588        let sample_dtype = DType::F32;
589        let ProteinFeatures {
590            s,
591            x_mask,
592            // symmetry_residues,
593            // symmetry_weights,
594            ..
595        } = features;
596        let s_true = s.to_dtype(sample_dtype)?;
597        let device = s.device();
598        let (b, l) = s.dims2()?;
599        // Todo: This is a hack. we should be passing in encoded chains.
600        // let chain_mask = Tensor::ones_like(&x_mask.as_ref().unwrap())?.to_dtype(sample_dtype)?;
601        // let chain_mask = x_mask.as_ref().unwrap().mul(&chain_mask)?;
602        let chain_mask = x_mask.as_ref().unwrap().to_dtype(sample_dtype)?;
603        let (h_v, h_e, e_idx) = self.encode(features)?;
604        let rand_tensor = Tensor::randn(0f32, 0.25f32, (b, l), device)?.to_dtype(sample_dtype)?;
605        let decoding_order = (&chain_mask + 0.0001)?
606            .mul(&rand_tensor.abs()?)?
607            .arg_sort_last_dim(false)?;
608        // TodoL add  bias
609        // # [B,L,21] - amino acid bias per position
610        let bias = Tensor::ones((b, l, 21), sample_dtype, device)?;
611        let symmetry_residues: Option<Vec<i32>> = None;
612        match symmetry_residues {
613            None => {
614                let e_idx = e_idx.repeat(&[b, 1, 1])?;
615                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
616                    .to_dtype(sample_dtype)?
617                    .contiguous()?;
618                let tril = Tensor::tril2(l, sample_dtype, device)?;
619                let tril = tril.unsqueeze(0)?;
620                let temp = tril
621                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
622                    .contiguous()?; //tensor of shape (b, i, q)
623                let order_mask_backward = temp
624                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
625                    .contiguous()?; // This will give us a tensor of shape (b, q, p)
626                let mask_attend = order_mask_backward
627                    .gather(&e_idx, 2)?
628                    .unsqueeze(D::Minus1)?;
629                let mask_1d = x_mask.as_ref().unwrap().reshape((b, l, 1, 1))?;
630                // Broadcast mask_1d to match mask_attend's shape
631                let mask_1d = mask_1d
632                    .broadcast_as(mask_attend.shape())?
633                    .to_dtype(sample_dtype)?;
634                let mask_bw = mask_1d.mul(&mask_attend)?;
635                let mask_fw = mask_1d.mul(&(Tensor::ones_like(&mask_attend)? - mask_attend)?)?;
636                // Note: `sample` begins to diverge from the `score` here.
637                // repeat for decoding
638                let s_true = s_true.repeat((b, 1))?;
639                let h_v = h_v.repeat((b, 1, 1))?;
640                let h_e = h_e.repeat((b, 1, 1, 1))?;
641                let mask = x_mask.as_ref().unwrap().repeat((b, 1))?.contiguous()?;
642                let chain_mask = &chain_mask.repeat((b, 1))?;
643                let bias = bias.repeat((b, 1, 1))?;
644                let mut all_probs = Tensor::zeros((b, l, 20), sample_dtype, device)?;
645                // why is this one 21 and the others are 20?
646                let mut all_log_probs = Tensor::zeros((b, l, 21), sample_dtype, device)?;
647                let mut h_s = Tensor::zeros_like(&h_v)?;
648                // note: we this value of 20 is `X`. We will need to replace the values below, not add them
649                let mut s = Tensor::full(20u32, (b, l), device)?;
650                let mut h_v_stack = vec![h_v.clone()];
651
652                for _ in 0..self.decoder_layers.len() {
653                    let zeros = Tensor::zeros_like(&h_v)?;
654                    h_v_stack.push(zeros);
655                }
656                let h_ex_encoder = cat_neighbors_nodes(&Tensor::zeros_like(&h_s)?, &h_e, &e_idx)?;
657                let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
658                let mask_fw = mask_fw
659                    .broadcast_as(h_exv_encoder.shape())?
660                    .to_dtype(h_exv_encoder.dtype())?;
661                let h_exv_encoder_fw = mask_fw.mul(&h_exv_encoder)?;
662                for t_ in 0..l {
663                    let t = decoding_order.i((.., t_))?;
664                    let t_gather = t.unsqueeze(1)?; // Shape [B, 1]
665                    // Gather masks and bias
666                    let chain_mask_t = chain_mask.gather(&t_gather, 1)?.squeeze(1)?;
667                    let mask_t = mask.gather(&t_gather, 1)?.squeeze(1)?.contiguous()?;
668                    let bias_t = bias
669                        .gather(&t_gather.unsqueeze(2)?.expand((b, 1, 21))?.contiguous()?, 1)?
670                        .squeeze(1)?;
671                    // Gather edge and node indices/features
672                    let e_idx_t = e_idx
673                        .gather(
674                            &t_gather
675                                .unsqueeze(2)?
676                                .expand((b, 1, e_idx.dim(2)?))?
677                                .contiguous()?,
678                            1,
679                        )?
680                        .contiguous()?;
681                    let h_e_t = h_e.gather(
682                        &t_gather
683                            .unsqueeze(2)?
684                            .unsqueeze(3)?
685                            .expand((b, 1, h_e.dim(2)?, h_e.dim(3)?))?
686                            .contiguous()?,
687                        1,
688                    )?;
689                    let n = e_idx_t.dim(2)?; // number of neighbors
690                    let c = h_s.dim(2)?; // channels/features
691                    let h_e_t = h_e_t
692                        .squeeze(1)? // [B, N, C]
693                        .unsqueeze(1)? // [B, 1, N, C]
694                        .expand((b, l, n, c))? // [B, L, N, C]
695                        .contiguous()?;
696                    let e_idx_t = e_idx_t
697                        .expand((b, l, n))? // [B, L, N]
698                        .contiguous()?;
699                    let h_es_t = cat_neighbors_nodes(&h_s, &h_e_t, &e_idx_t)?;
700                    let h_exv_encoder_t = h_exv_encoder_fw.gather(
701                        &t_gather
702                            .unsqueeze(2)?
703                            .unsqueeze(3)?
704                            .expand((b, 1, h_exv_encoder_fw.dim(2)?, h_exv_encoder_fw.dim(3)?))?
705                            .contiguous()?,
706                        1,
707                    )?;
708                    let mask_bw_t = mask_bw.gather(
709                        &t_gather
710                            .unsqueeze(2)?
711                            .unsqueeze(3)?
712                            .expand((b, 1, mask_bw.dim(2)?, mask_bw.dim(3)?))?
713                            .contiguous()?,
714                        1,
715                    )?;
716
717                    // Decoder layers loop
718                    for l in 0..self.decoder_layers.len() {
719                        let h_v_stack_l = &h_v_stack[l];
720                        let h_esv_decoder_t = cat_neighbors_nodes(h_v_stack_l, &h_es_t, &e_idx_t)?;
721                        let h_v_t = h_v_stack_l.gather(
722                            &t_gather
723                                .unsqueeze(2)?
724                                .expand((b, 1, h_v_stack_l.dim(2)?))?
725                                .contiguous()?,
726                            1,
727                        )?;
728                        let mask_bw_t = mask_bw_t.expand(h_esv_decoder_t.dims())?.contiguous()?;
729                        let h_exv_encoder_t = h_exv_encoder_t
730                            .expand(h_esv_decoder_t.dims())?
731                            .contiguous()?
732                            .to_dtype(sample_dtype)?;
733                        let h_esv_t = mask_bw_t
734                            .mul(&h_esv_decoder_t.to_dtype(sample_dtype)?)?
735                            .add(&h_exv_encoder_t)?
736                            .to_dtype(sample_dtype)?
737                            .contiguous()?;
738                        let h_v_t = h_v_t
739                            .expand((
740                                h_esv_t.dim(0)?, // batch size
741                                h_esv_t.dim(1)?, // sequence length (93)
742                                h_v_t.dim(2)?,   // features (128)
743                            ))?
744                            .contiguous()?;
745                        let decoder_output = self.decoder_layers[l].forward(
746                            &h_v_t,
747                            &h_esv_t,
748                            Some(&mask_t),
749                            None,
750                            None,
751                        )?;
752                        let t_expanded = t_gather.reshape(&[b])?; // This will give us a 1D tensor of shape [b]
753                        let decoder_output = decoder_output
754                            .narrow(1, 0, 1)?
755                            .squeeze(1)? // Now [1, 128]
756                            .unsqueeze(1)?; // Now [1, 1, 128] - same rank as target
757                        h_v_stack[l + 1] =
758                            h_v_stack[l + 1].index_add(&t_expanded, &decoder_output, 1)?;
759                        // h_v_stack[l + 1] =
760                        //     h_v_stack[l + 1].index_add(&t_expanded, &decoder_output, 1)?;
761                    }
762                    let h_v_t = h_v_stack
763                        .last()
764                        .unwrap()
765                        .gather(
766                            &t_gather
767                                .unsqueeze(2)?
768                                .expand((b, 1, h_v_stack.last().unwrap().dim(2)?))?
769                                .contiguous()?,
770                            1,
771                        )?
772                        .squeeze(1)?;
773                    // Generate logits and probabilities
774                    let logits = self.w_out.forward(&h_v_t)?;
775                    let log_probs = log_softmax(&logits, D::Minus1)?;
776
777                    // explicit for OoO
778                    let probs = {
779                        let biased_logits = logits.add(&bias_t)?; // (logits + bias_t)
780                        let scaled_logits = (biased_logits / temperature)?; // (logits + bias_t) / temperature
781                        softmax(&scaled_logits, D::Minus1)? // softmax((logits + bias_t) / temperature)
782                    };
783
784                    let probs_sample = probs
785                        .narrow(1, 0, 20)?
786                        .div(&probs.narrow(1, 0, 20)?.sum_keepdim(1)?.expand((b, 20))?)?;
787                    // Sample new token
788                    let sum = probs_sample.sum(1)?;
789                    let probs_sample_1d = probs_sample
790                        .squeeze(0)? // Remove batch dimension -> [20]
791                        .clamp(1e-10, 1.0)?
792                        .broadcast_div(&sum)?
793                        .contiguous()?;
794
795                    let s_t = multinomial_sample(&probs_sample_1d, temperature, seed)?;
796                    let s_t = s_t.to_dtype(sample_dtype)?;
797                    let s_true = s_true.to_dtype(sample_dtype)?;
798                    let s_true_t = s_true.gather(&t_gather, 1)?.squeeze(1)?;
799                    let s_t = s_t
800                        .mul(&chain_mask_t)?
801                        .add(&s_true_t.mul(&(&chain_mask_t.neg()? + 1.0)?)?)?
802                        .to_dtype(DType::U32)?;
803
804                    let s_t_idx = s_t.to_dtype(DType::U32)?;
805                    let s_t_idx = s_t_idx.reshape(&[s_t_idx.dim(0)?])?;
806                    let h_s_update = self.w_s.forward(&s_t_idx)?.unsqueeze(1)?;
807                    let t_gather_expanded = t_gather.reshape(&[b])?;
808                    let h_s_update = h_s_update.squeeze(0)?.unsqueeze(1)?;
809                    h_s =
810                        h_s.index_add(&t_gather_expanded, &Tensor::zeros_like(&h_s_update)?, 1)?;
811                    h_s = h_s.index_add(&t_gather_expanded, &h_s_update, 1)?;
812
813                    s = {
814                        let dim = 1;
815                        let start = t_gather.squeeze(0)?.squeeze(0)?.to_scalar::<u32>()? as usize;
816                        let s_t_expanded = s_t.unsqueeze(1)?;
817                        s.slice_scatter(&s_t_expanded, dim, start)?
818                    };
819
820                    let probs_update = chain_mask_t
821                        .unsqueeze(1)?
822                        .unsqueeze(2)?
823                        .expand((b, 1, 20))?
824                        .mul(&probs_sample.unsqueeze(1)?)?;
825                    let t_expanded = t_gather.reshape(&[b])?;
826                    let probs_update = probs_update
827                        .squeeze(1)? // Remove extra dimension
828                        .unsqueeze(1)?;
829                    all_probs =
830                        all_probs.index_add(&t_expanded, &Tensor::zeros_like(&probs_update)?, 1)?;
831                    all_probs = all_probs.index_add(&t_expanded, &probs_update, 1)?;
832                    let log_probs_update = chain_mask_t
833                        .unsqueeze(1)?
834                        .unsqueeze(2)?
835                        .expand((b, 1, 21))?
836                        .mul(&log_probs.unsqueeze(1)?)?
837                        .squeeze(1)?
838                        .unsqueeze(1)?;
839
840                    all_log_probs = all_log_probs.index_add(
841                        &t_expanded,
842                        &Tensor::zeros_like(&log_probs_update)?,
843                        1,
844                    )?;
845                    all_log_probs = all_log_probs.index_add(&t_expanded, &log_probs_update, 1)?;
846                }
847                Ok(ScoreOutput {
848                    s,
849                    log_probs: all_probs,
850                    logits: all_log_probs,
851                    decoding_order,
852                })
853            }
854            Some(_symmetry_residues) => {
855                todo!()
856            }
857        }
858    }
859
860    pub fn score(&self, features: &ProteinFeatures, use_sequence: bool) -> Result<ScoreOutput> {
861        let ProteinFeatures { s, x_mask, .. } = &features;
862        let sample_dtype = DType::F32;
863        let s_true = &s.clone();
864        let device = s_true.device();
865        let (b, l) = s_true.dims2()?;
866        let mask = &x_mask.as_ref().clone();
867        let b_decoder: usize = b;
868
869        // Todo: This is a hack. we should be passing in encoded chains.
870        // Update chain_mask to include missing regions
871        let chain_mask = Tensor::zeros_like(mask.unwrap())?.to_dtype(sample_dtype)?;
872        let chain_mask = mask.unwrap().mul(&chain_mask)?; // does the order count here?
873
874        // encode ...
875        let (h_v, h_e, e_idx) = self.encode(features)?;
876        let rand_tensor = Tensor::randn(0f32, 1f32, (b, l), device)?.to_dtype(sample_dtype)?;
877        // Compute decoding order
878        let decoding_order = (chain_mask + 0.001)?
879            .mul(&rand_tensor.abs()?)?
880            .arg_sort_last_dim(false)?;
881
882        let symmetry_residues: Option<Vec<i32>> = None;
883
884        let (mask_fw, mask_bw, e_idx, decoding_order) = match symmetry_residues {
885            Some(_symmetry_residues) => {
886                todo!();
887            }
888            None => {
889                let e_idx = e_idx.repeat(&[b_decoder, 1, 1])?;
890                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
891                    .to_dtype(sample_dtype)?
892                    .contiguous()?;
893
894                let tril = Tensor::tril2(l, sample_dtype, device)?.unsqueeze(0)?;
895                let temp = tril
896                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
897                    .contiguous()?; // shape (b, i, q)
898                let order_mask_backward = temp
899                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
900                    .contiguous()?; // shape (b, q, p)
901                let mask_attend = order_mask_backward
902                    .gather(&e_idx, 2)?
903                    .unsqueeze(D::Minus1)?;
904
905                // Broadcast mask_1d to match mask_attend's shape
906                let mask_1d = mask
907                    .unwrap()
908                    .reshape((b, l, 1, 1))?
909                    .broadcast_as(mask_attend.shape())?
910                    .to_dtype(sample_dtype)?;
911
912                let mask_bw = mask_1d.mul(&mask_attend)?;
913                let mask_fw = mask_1d.mul(&(mask_attend - 1.0)?.neg()?)?;
914                (mask_fw, mask_bw, e_idx, decoding_order)
915            }
916        };
917
918        let s_true = s_true.repeat(&[b_decoder, 1])?;
919        let h_v = h_v.repeat(&[b_decoder, 1, 1])?;
920        let h_e = h_e.repeat(&[b_decoder, 1, 1, 1])?;
921        let mask = mask.as_ref().unwrap().repeat(&[b_decoder, 1])?;
922
923        let h_s = self.w_s.forward(&s_true)?; // embedding layer
924        let h_es = cat_neighbors_nodes(&h_s, &h_e, &e_idx)?;
925
926        // Build encoder embeddings
927        let h_ex_encoder = cat_neighbors_nodes(&Tensor::zeros_like(&h_s)?, &h_e, &e_idx)?;
928        let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
929        let h_exv_encoder_fw = mask_fw
930            .broadcast_as(h_exv_encoder.shape())?
931            .to_dtype(h_exv_encoder.dtype())?
932            .mul(&h_exv_encoder)?;
933
934        // Apply decoder layers
935        let h_v = if !use_sequence {
936            // Simple forward pass through decoder layers
937            self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
938                layer.forward(&acc?, &h_exv_encoder_fw, Some(&mask), None, None)
939            })?
940        } else {
941            // Forward pass with sequence-aware processing
942            self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
943                let current_h_v = acc?;
944                let h_esv = cat_neighbors_nodes(&current_h_v, &h_es, &e_idx)?
945                    .mul(&mask_bw)?
946                    .add(&h_exv_encoder_fw)?;
947                layer.forward(&current_h_v, &h_esv, Some(&mask), None, None)
948            })?
949        };
950
951        let logits = self.w_out.forward(&h_v)?;
952        let log_probs = log_softmax(&logits, D::Minus1)?;
953
954        Ok(ScoreOutput {
955            s: s_true,
956            log_probs,
957            logits,
958            decoding_order,
959        })
960    }
961}