use super::configs::{ModelTypes, ProteinMPNNConfig};
use super::proteinfeatures::ProteinFeatures;
use super::proteinfeaturesmodel::ProteinFeaturesModel;
use crate::featurize::utilities::{cat_neighbors_nodes, gather_nodes, int_to_aa1};
use crate::types::PseudoProbability;
use candle_core::safetensors;
use candle_core::{D, DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn::encoding::one_hot;
use candle_nn::ops::{log_softmax, softmax};
use candle_nn::{Dropout, Embedding, LayerNorm, Linear, VarBuilder, embedding, layer_norm, linear};
use candle_transformers::generation::LogitsProcessor;
use std::collections::HashMap;

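/// Builds per-edge message inputs: gathers neighbor features via
/// `cat_neighbors_nodes`, then expands each node's own features along the
/// neighbor dimension and prepends them to the gathered features.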
fn concat_node_tensors(h_v: &Tensor, h_e: &Tensor, e_idx: &Tensor) -> Result<Tensor> {
    let h_ev = cat_neighbors_nodes(h_v, h_e, e_idx)?;
    let h_v_expand = h_v.unsqueeze(D::Minus2)?;
    let expand_shape = [
        h_ev.dims()[0],
        h_ev.dims()[1],
        h_ev.dims()[2],
        h_v_expand.dims()[3],
    ];
    let h_v_expand = h_v_expand.expand(&expand_shape)?.to_dtype(h_ev.dtype())?;
    Tensor::cat(&[&h_v_expand, &h_ev], D::Minus1)?.contiguous()
}
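
/// Residual update used throughout the layers: `norm(input + dropout(delta))`,
/// with dropout active only when `training` is true.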
fn apply_dropout_and_norm(
    input: &Tensor,
    delta: &Tensor,
    dropout: &Dropout,
    norm: &LayerNorm,
    training: bool,
) -> Result<Tensor> {
    let delta_dropout = dropout.forward(delta, training)?;
    norm.forward(&(input + delta_dropout)?)
}

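/// Samples one token index from `probs` using `candle_transformers`'
/// `LogitsProcessor` with temperature scaling and top-p (0.95) filtering,
/// returning it as a one-element `u32` tensor. Note that
/// `LogitsProcessor::sample` treats its input as raw logits (it applies
/// temperature and softmax internally), so a probability vector passed in
/// here is effectively re-softmaxed.
///
/// A minimal sketch (illustrative values only):
///
/// ```ignore
/// let probs = Tensor::new(&[0.05f32; 21], &Device::Cpu)?; // flat 21-way distribution
/// let token = multinomial_sample(&probs, 0.1, 42)?; // shape (1,), dtype u32
/// ```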
pub fn multinomial_sample(probs: &Tensor, temperature: f64, seed: u64) -> Result<Tensor> {
    let mut logits_processor = LogitsProcessor::new(seed, Some(temperature), Some(0.95));
    let idx = logits_processor.sample(probs)?;
    if idx >= 21 {
        eprintln!("WARNING: invalid index {} selected", idx);
    }
    Tensor::new(&[idx], probs.device())
}

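/// Result of a scoring or sampling pass: sequence tokens `s` of shape
/// `(B, L)`, per-position `log_probs` and `logits` over the amino-acid
/// vocabulary, and the `decoding_order` that was used.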
#[derive(Clone, Debug)]
pub struct ScoreOutput {
    pub(crate) s: Tensor,
    pub(crate) log_probs: Tensor,
    pub(crate) logits: Tensor,
    pub(crate) decoding_order: Tensor,
}
impl ScoreOutput {
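    /// Decodes the token tensor `s` into one-letter amino-acid strings, one
    /// per batch element.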
    pub fn get_sequences(&self) -> Result<Vec<String>> {
        let (b, l) = self.s.dims2()?;
        let mut sequences = Vec::with_capacity(b);
        for batch_idx in 0..b {
            let batch = self.s.get(batch_idx)?;
            let mut sequence = String::with_capacity(l);
            for pos in 0..l {
                let aa_idx = batch.get(pos)?.to_vec0::<u32>()?;
                sequence.push(int_to_aa1(aa_idx));
            }
            sequences.push(sequence);
        }
        Ok(sequences)
    }
    pub fn get_decoding_order(&self) -> Result<Vec<u32>> {
        let values = self.decoding_order.flatten_all()?.to_vec1::<u32>()?;
        Ok(values)
    }
    pub fn get_log_probs(&self) -> &Tensor {
        &self.log_probs
    }
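    /// Softmaxes the per-position logits and returns every
    /// (position, amino-acid, probability) entry above a 0.01 cutoff, sorted
    /// by position and then by descending probability. Positions are mapped
    /// through `decoding_order` when it holds integer indices.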
    pub fn get_pseudo_probabilities(&self) -> Result<Vec<PseudoProbability>> {
        let (batch_size, seq_len, _vocab_size) = self.logits.dims3()?;
        let mut all_probabilities = Vec::with_capacity(batch_size * seq_len);

        for batch_idx in 0..batch_size {
            let batch_logits = self.logits.get(batch_idx)?;

            for pos in 0..seq_len {
                let pos_logits = batch_logits.get(pos)?;
                let probs = softmax(&pos_logits, 0)?;
                let probs = probs.to_vec1::<f32>()?;

                let actual_pos = if let Ok(order) = self
                    .decoding_order
                    .get(batch_idx)?
                    .get(pos)?
                    .to_scalar::<u32>()
                {
                    order as usize
                } else {
                    pos
                };

                for aa_idx in 0..probs.len() {
                    if probs[aa_idx] > 0.01 {
                        all_probabilities.push(PseudoProbability {
                            position: actual_pos,
                            pseudo_prob: probs[aa_idx],
                            amino_acid: int_to_aa1(aa_idx as u32),
                        });
                    }
                }
            }
        }
        all_probabilities.sort_by(|a, b| {
            a.position.cmp(&b.position).then(
                b.pseudo_prob
                    .partial_cmp(&a.pseudo_prob)
                    .unwrap_or(std::cmp::Ordering::Equal),
            )
        });
        Ok(all_probabilities)
    }
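    /// Writes `S`, `log_probs`, `logits`, and `decoding_order` to a
    /// safetensors file, creating parent directories as needed.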
    pub fn save_as_safetensors(&self, filename: String) -> Result<()> {
        let mut tensors = HashMap::new();
        tensors.insert("S".to_string(), self.s.clone());
        tensors.insert("log_probs".to_string(), self.log_probs.clone());
        tensors.insert("logits".to_string(), self.logits.clone());
        tensors.insert("decoding_order".to_string(), self.decoding_order.clone());
        if let Some(parent) = std::path::Path::new(&filename).parent() {
            std::fs::create_dir_all(parent)?;
        }
        safetensors::save(&tensors, &filename)?;
        Ok(())
    }
}

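/// Position-wise feed-forward block: `W_in` -> GELU -> `W_out`, applied
/// independently at each node.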
#[derive(Clone, Debug)]
struct PositionWiseFeedForward {
    w1: Linear,
    w2: Linear,
}

impl PositionWiseFeedForward {
    fn new(vb: VarBuilder, dim_input: usize, dim_feedforward: usize) -> Result<Self> {
        let w1 = linear(dim_input, dim_feedforward, vb.pp("W_in"))?;
        let w2 = linear(dim_feedforward, dim_input, vb.pp("W_out"))?;
        Ok(Self { w1, w2 })
    }
}

impl Module for PositionWiseFeedForward {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.w1.forward(x)?.gelu().and_then(|x| self.w2.forward(&x))
    }
}

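/// Encoder layer: a node update driven by gathered neighbor messages
/// (`W1`..`W3`), a position-wise feed-forward block, and an edge update
/// (`W11`..`W13`), each followed by dropout and layer norm.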
#[derive(Clone, Debug)]
pub struct EncLayer {
    num_hidden: usize,
    num_in: usize,
    scale: f64,
    dropout1: Dropout,
    dropout2: Dropout,
    dropout3: Dropout,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    w1: Linear,
    w2: Linear,
    w3: Linear,
    w11: Linear,
    w12: Linear,
    w13: Linear,
    dense: PositionWiseFeedForward,
}

impl EncLayer {
    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig, layer: i32) -> Result<Self> {
        let vb = vb.pp(layer);
        let num_hidden = config.hidden_dim as usize;
        let augment_eps = config.augment_eps as f64;
        let num_in = (config.hidden_dim * 2) as usize;
        let dropout_ratio = config.dropout_ratio;

        let norm1 = layer_norm(num_hidden, augment_eps, vb.pp("norm1"))?;
        let norm2 = layer_norm(num_hidden, augment_eps, vb.pp("norm2"))?;
        let norm3 = layer_norm(num_hidden, augment_eps, vb.pp("norm3"))?;

        let w1 = linear(num_hidden + num_in, num_hidden, vb.pp("W1"))?;
        let w2 = linear(num_hidden, num_hidden, vb.pp("W2"))?;
        let w3 = linear(num_hidden, num_hidden, vb.pp("W3"))?;
        let w11 = linear(num_hidden + num_in, num_hidden, vb.pp("W11"))?;
        let w12 = linear(num_hidden, num_hidden, vb.pp("W12"))?;
        let w13 = linear(num_hidden, num_hidden, vb.pp("W13"))?;

        let dropout1 = Dropout::new(dropout_ratio);
        let dropout2 = Dropout::new(dropout_ratio);
        let dropout3 = Dropout::new(dropout_ratio);

        let dense = PositionWiseFeedForward::new(vb.pp("dense"), num_hidden, num_hidden * 4)?;

        Ok(Self {
            num_hidden,
            num_in,
            scale: config.scale_factor,
            dropout1,
            dropout2,
            dropout3,
            norm1,
            norm2,
            norm3,
            w1,
            w2,
            w3,
            w11,
            w12,
            w13,
            dense,
        })
    }
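    /// One round of message passing: updates node features from neighbor
    /// messages, applies the feed-forward block, then updates edge features.
    /// `mask_v` zeroes padded nodes and `mask_attend` zeroes messages from
    /// padded neighbors.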
    fn forward(
        &self,
        h_v: &Tensor,
        h_e: &Tensor,
        e_idx: &Tensor,
        mask_v: Option<&Tensor>,
        mask_attend: Option<&Tensor>,
        training: Option<bool>,
    ) -> Result<(Tensor, Tensor)> {
        let training = training.unwrap_or(false);
        let h_v = h_v.to_dtype(DType::F32)?;
        let h_ev = concat_node_tensors(&h_v, h_e, e_idx)?;
        let h_message = self
            .w1
            .forward(&h_ev)?
            .gelu()?
            .apply(&self.w2)?
            .gelu()?
            .apply(&self.w3)?;

        let h_message = mask_attend
            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message))
            .transpose()?
            .unwrap_or(h_message);

        let scale = if self.scale == 0.0 { 1.0 } else { self.scale };
        let dh = (h_message.sum(D::Minus2)? / scale)?;

        let h_v = apply_dropout_and_norm(&h_v, &dh, &self.dropout1, &self.norm1, training)?;
        let dense_output = self.dense.forward(&h_v)?;
        let h_v =
            apply_dropout_and_norm(&h_v, &dense_output, &self.dropout2, &self.norm2, training)?;

        let h_v = mask_v
            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_v))
            .transpose()?
            .unwrap_or(h_v);

        let h_ev = concat_node_tensors(&h_v, h_e, e_idx)?;
        let h_message = self
            .w11
            .forward(&h_ev)?
            .gelu()?
            .apply(&self.w12)?
            .gelu()?
            .apply(&self.w13)?;

        let h_e = apply_dropout_and_norm(h_e, &h_message, &self.dropout3, &self.norm3, training)?;
        Ok((h_v, h_e))
    }
}

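/// Decoder layer: the same message-passing node update as the encoder, but
/// the neighbor context is precomputed by the caller and edge features are
/// left unchanged.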
#[derive(Clone, Debug)]
pub struct DecLayer {
    num_hidden: usize,
    num_in: usize,
    scale: f64,
    dropout1: Dropout,
    dropout2: Dropout,
    norm1: LayerNorm,
    norm2: LayerNorm,
    w1: Linear,
    w2: Linear,
    w3: Linear,
    dense: PositionWiseFeedForward,
}

impl DecLayer {
    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig, layer: i32) -> Result<Self> {
        let vb = vb.pp(layer);
        let num_hidden = config.hidden_dim as usize;
        let augment_eps = config.augment_eps as f64;
        let num_in = (config.hidden_dim * 3) as usize;
        let dropout_ratio = config.dropout_ratio;

        let norm1 = layer_norm(num_hidden, augment_eps, vb.pp("norm1"))?;
        let norm2 = layer_norm(num_hidden, augment_eps, vb.pp("norm2"))?;

        let w1 = linear(num_hidden + num_in, num_hidden, vb.pp("W1"))?;
        let w2 = linear(num_hidden, num_hidden, vb.pp("W2"))?;
        let w3 = linear(num_hidden, num_hidden, vb.pp("W3"))?;
        let dropout1 = Dropout::new(dropout_ratio);
        let dropout2 = Dropout::new(dropout_ratio);

        let dense = PositionWiseFeedForward::new(vb.pp("dense"), num_hidden, num_hidden * 4)?;

        Ok(Self {
            num_hidden,
            num_in,
            scale: config.scale_factor,
            dropout1,
            dropout2,
            norm1,
            norm2,
            w1,
            w2,
            w3,
            dense,
        })
    }
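    /// Node-only update: `h_e` here is the precomputed neighbor context
    /// (sequence plus encoder embeddings), so unlike the encoder no edge
    /// update is performed.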
    pub fn forward(
        &self,
        h_v: &Tensor,
        h_e: &Tensor,
        mask_v: Option<&Tensor>,
        mask_attend: Option<&Tensor>,
        training: Option<bool>,
    ) -> Result<Tensor> {
        let training_bool = training.unwrap_or(false);

        let expand_shape = [
            h_e.dims()[0],
            h_e.dims()[1],
            h_e.dims()[2],
            h_v.dims()[2],
        ];

        let h_v_expand = h_v.unsqueeze(D::Minus2)?.expand(&expand_shape)?;
        let h_ev = Tensor::cat(&[&h_v_expand, h_e], D::Minus1)?.contiguous()?;

        let h_message = self
            .w1
            .forward(&h_ev)?
            .gelu()?
            .apply(&self.w2)?
            .gelu()?
            .apply(&self.w3)?;

        let h_message = self.dropout1.forward(&h_message, training_bool)?;

        let h_message = mask_attend
            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message))
            .transpose()?
            .unwrap_or(h_message);

        let scale = if self.scale == 0.0 { 1.0 } else { self.scale };
        let dh = (h_message.sum(D::Minus2)? / scale)?;
        let h_v = self.norm1.forward(&(h_v + dh)?)?;
        let dh = self.dense.forward(&h_v)?;
        let dh_dropout = self.dropout2.forward(&dh, training_bool)?;
        let h_v = self.norm2.forward(&(h_v + dh_dropout)?)?;

        let h_v = mask_v
            .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_v))
            .transpose()?
            .unwrap_or(h_v);

        Ok(h_v)
    }
}

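/// ProteinMPNN: edge featurizer plus encoder stack, autoregressive decoder
/// stack, sequence embedding `w_s`, and output projection `w_out` to
/// amino-acid logits.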
pub struct ProteinMPNN {
    pub(crate) config: ProteinMPNNConfig,
    pub(crate) decoder_layers: Vec<DecLayer>,
    pub(crate) device: Device,
    pub(crate) encoder_layers: Vec<EncLayer>,
    pub(crate) features: ProteinFeaturesModel,
    pub(crate) w_e: Linear,
    pub(crate) w_out: Linear,
    pub(crate) w_s: Embedding,
}

impl ProteinMPNN {
    pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig) -> Result<Self> {
        let hidden_dim = config.hidden_dim as usize;
        let edge_features = config.edge_features as usize;
        let num_letters = config.num_letters as usize;
        let vocab_size = config.vocab as usize;

        let encoder_layers = (0..config.num_encoder_layers)
            .map(|i| EncLayer::load(vb.pp("encoder_layers"), config, i as i32))
            .collect::<Result<Vec<_>>>()?;

        let decoder_layers = (0..config.num_decoder_layers)
            .map(|i| DecLayer::load(vb.pp("decoder_layers"), config, i as i32))
            .collect::<Result<Vec<_>>>()?;

        let w_e = linear(edge_features, hidden_dim, vb.pp("W_e"))?;
        let w_out = linear(hidden_dim, num_letters, vb.pp("W_out"))?;
        let w_s = embedding(vocab_size, hidden_dim, vb.pp("W_s"))?;
        let features = ProteinFeaturesModel::load(vb.pp("features"), config.clone())?;

        Ok(Self {
            config: config.clone(),
            decoder_layers,
            device: vb.device().clone(),
            encoder_layers,
            features,
            w_e,
            w_out,
            w_s,
        })
    }
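    /// Featurizes the input structure and runs the encoder stack, returning
    /// node embeddings `h_v`, edge embeddings `h_e`, and neighbor indices
    /// `e_idx`. Only `ModelTypes::ProteinMPNN` is implemented; `LigandMPNN`
    /// is still a `todo!()`.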
    pub fn encode(&self, features: &ProteinFeatures) -> Result<(Tensor, Tensor, Tensor)> {
        let s_true = features.get_sequence();
        let base_dtype = DType::F32;
        let mask = match features.get_sequence_mask() {
            Some(m) => m.clone(),
            None => Tensor::ones_like(s_true)?,
        };
        match self.config.model_type {
            ModelTypes::ProteinMPNN => {
                let (e, e_idx) = self.features.forward(features, &self.device)?;
                let h_v = Tensor::zeros(
                    (e.dim(0)?, e.dim(1)?, e.dim(D::Minus1)?),
                    base_dtype,
                    &self.device,
                )?;
                let h_e = self.w_e.forward(&e)?;
                let mask_attend = if let Some(seq_mask) = features.get_sequence_mask() {
                    let mask_expanded = seq_mask.unsqueeze(D::Minus1)?;
                    let mask_gathered = gather_nodes(&mask_expanded, &e_idx)?.squeeze(D::Minus1)?;
                    let mask_unsqueezed = mask.unsqueeze(D::Minus1)?;
                    mask_unsqueezed
                        .expand((
                            mask_gathered.dim(0)?,
                            mask_gathered.dim(1)?,
                            mask_gathered.dim(2)?,
                        ))?
                        .mul(&mask_gathered)?
                } else {
                    let (b, l) = mask.dims2()?;
                    Tensor::ones((b, l, e_idx.dim(2)?), DType::F32, &self.device)?
                };
                let mask_f32 = mask.to_dtype(base_dtype)?;
                let mask_attend_f32 = mask_attend.to_dtype(base_dtype)?;

                let (h_v, h_e) =
                    self.encoder_layers
                        .iter()
                        .fold(Ok((h_v, h_e)), |acc, layer| {
                            let (h_v, h_e) = acc?;
                            layer.forward(
                                &h_v,
                                &h_e,
                                &e_idx,
                                Some(&mask_f32),
                                Some(&mask_attend_f32),
                                Some(false),
                            )
                        })?;

                Ok((h_v, h_e, e_idx))
            }
            ModelTypes::LigandMPNN => {
                todo!()
            }
        }
    }
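    /// One-shot scoring pass: the decoder runs with zeroed sequence context,
    /// so every position is predicted from structure alone and the decoding
    /// order is simply `0..L`.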
    pub fn simple_decode(&self, features: &ProteinFeatures) -> Result<ScoreOutput> {
        let b_decoder = 1;

        let ProteinFeatures { s, x_mask, .. } = features;
        let device = s.device();
        let (_, l) = s.dims2()?;

        let (h_v_enc, h_e_enc, e_idx_enc) = self.encode(features)?;

        let s_true = s.clone();
        let mask = x_mask.clone().unwrap();

        let zeros = Tensor::zeros((b_decoder, l, h_v_enc.dim(D::Minus1)?), DType::F32, device)?;
        let h_v = h_v_enc.clone();
        let h_e = h_e_enc.clone();
        let e_idx = e_idx_enc.clone();

        let h_ex_encoder = cat_neighbors_nodes(&zeros, &h_e, &e_idx)?;
        let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;

        let h_v_final = self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
            layer.forward(&acc?, &h_exv_encoder, Some(&mask), None, None)
        })?;

        let logits = self.w_out.forward(&h_v_final)?;
        let log_probs = log_softmax(&logits, D::Minus1)?;

        let decoding_order = Tensor::arange(0u32, l as u32, device)?
            .reshape((1, l))?
            .broadcast_as((b_decoder, l))?
            .contiguous()?;

        Ok(ScoreOutput {
            s: s_true,
            log_probs,
            logits,
            decoding_order,
        })
    }
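    /// Autoregressive sampling over the encoded structure. Positions are
    /// visited in a randomized decoding order; at each step the decoder stack
    /// is advanced at that position and the next residue is drawn from the
    /// temperature-scaled distribution (see [`multinomial_sample`]).
    ///
    /// A minimal sketch of the intended call pattern (setup is assumed, not
    /// provided by this module):
    ///
    /// ```ignore
    /// let model = ProteinMPNN::load(vb, &config)?;
    /// let out = model.sample(&features, 0.1, 42)?; // temperature 0.1, seed 42
    /// println!("{}", out.get_sequences()?[0]);
    /// ```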
    pub fn sample(
        &self,
        features: &ProteinFeatures,
        temperature: f64,
        seed: u64,
    ) -> Result<ScoreOutput> {
        let sample_dtype = DType::F32;
        let ProteinFeatures { s, x_mask, .. } = features;
        let s_true = s.to_dtype(sample_dtype)?;
        let device = s.device();
        let (b, l) = s.dims2()?;
        let chain_mask = x_mask.as_ref().unwrap().to_dtype(sample_dtype)?;
        let (h_v, h_e, e_idx) = self.encode(features)?;
        let rand_tensor = Tensor::randn(0f32, 0.25f32, (b, l), device)?.to_dtype(sample_dtype)?;
        let decoding_order = (&chain_mask + 0.0001)?
            .mul(&rand_tensor.abs()?)?
            .arg_sort_last_dim(false)?;
        let bias = Tensor::ones((b, l, 21), sample_dtype, device)?;
        let symmetry_residues: Option<Vec<i32>> = None;
        match symmetry_residues {
            None => {
                let e_idx = e_idx.repeat(&[b, 1, 1])?;
                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
                    .to_dtype(sample_dtype)?
                    .contiguous()?;
                let tril = Tensor::tril2(l, sample_dtype, device)?;
                let tril = tril.unsqueeze(0)?;
                let temp = tril
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?;
                let order_mask_backward = temp
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?;
                let mask_attend = order_mask_backward
                    .gather(&e_idx, 2)?
                    .unsqueeze(D::Minus1)?;
                let mask_1d = x_mask.as_ref().unwrap().reshape((b, l, 1, 1))?;
                let mask_1d = mask_1d
                    .broadcast_as(mask_attend.shape())?
                    .to_dtype(sample_dtype)?;
                let mask_bw = mask_1d.mul(&mask_attend)?;
                let mask_fw = mask_1d.mul(&(Tensor::ones_like(&mask_attend)? - mask_attend)?)?;
                let s_true = s_true.repeat((b, 1))?;
                let h_v = h_v.repeat((b, 1, 1))?;
                let h_e = h_e.repeat((b, 1, 1, 1))?;
                let mask = x_mask.as_ref().unwrap().repeat((b, 1))?.contiguous()?;
                let chain_mask = chain_mask.repeat((b, 1))?;
                let bias = bias.repeat((b, 1, 1))?;
                let mut all_probs = Tensor::zeros((b, l, 20), sample_dtype, device)?;
                let mut all_log_probs = Tensor::zeros((b, l, 21), sample_dtype, device)?;
                let mut h_s = Tensor::zeros_like(&h_v)?;
                let mut s = Tensor::full(20u32, (b, l), device)?;
                let mut h_v_stack = vec![h_v.clone()];

                for _ in 0..self.decoder_layers.len() {
                    let zeros = Tensor::zeros_like(&h_v)?;
                    h_v_stack.push(zeros);
                }
                let h_ex_encoder = cat_neighbors_nodes(&Tensor::zeros_like(&h_s)?, &h_e, &e_idx)?;
                let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
                let mask_fw = mask_fw
                    .broadcast_as(h_exv_encoder.shape())?
                    .to_dtype(h_exv_encoder.dtype())?;
                let h_exv_encoder_fw = mask_fw.mul(&h_exv_encoder)?;
                for t_ in 0..l {
                    let t = decoding_order.i((.., t_))?;
                    let t_gather = t.unsqueeze(1)?;
                    let chain_mask_t = chain_mask.gather(&t_gather, 1)?.squeeze(1)?;
                    let mask_t = mask.gather(&t_gather, 1)?.squeeze(1)?.contiguous()?;
                    let bias_t = bias
                        .gather(&t_gather.unsqueeze(2)?.expand((b, 1, 21))?.contiguous()?, 1)?
                        .squeeze(1)?;
                    let e_idx_t = e_idx
                        .gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, e_idx.dim(2)?))?
                                .contiguous()?,
                            1,
                        )?
                        .contiguous()?;
                    let h_e_t = h_e.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, h_e.dim(2)?, h_e.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;
                    let n = e_idx_t.dim(2)?;
                    let c = h_s.dim(2)?;
                    let h_e_t = h_e_t.expand((b, l, n, c))?.contiguous()?;
                    let e_idx_t = e_idx_t.expand((b, l, n))?.contiguous()?;
                    let h_es_t = cat_neighbors_nodes(&h_s, &h_e_t, &e_idx_t)?;
                    let h_exv_encoder_t = h_exv_encoder_fw.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, h_exv_encoder_fw.dim(2)?, h_exv_encoder_fw.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;
                    let mask_bw_t = mask_bw.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, mask_bw.dim(2)?, mask_bw.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;
                    for layer_idx in 0..self.decoder_layers.len() {
                        let h_v_stack_l = &h_v_stack[layer_idx];
                        let h_esv_decoder_t = cat_neighbors_nodes(h_v_stack_l, &h_es_t, &e_idx_t)?;
                        let h_v_t = h_v_stack_l.gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, h_v_stack_l.dim(2)?))?
                                .contiguous()?,
                            1,
                        )?;
                        let mask_bw_t = mask_bw_t.expand(h_esv_decoder_t.dims())?.contiguous()?;
                        let h_exv_encoder_t = h_exv_encoder_t
                            .expand(h_esv_decoder_t.dims())?
                            .contiguous()?
                            .to_dtype(sample_dtype)?;
                        let h_esv_t = mask_bw_t
                            .mul(&h_esv_decoder_t.to_dtype(sample_dtype)?)?
                            .add(&h_exv_encoder_t)?
                            .to_dtype(sample_dtype)?
                            .contiguous()?;
                        let h_v_t = h_v_t
                            .expand((
                                h_esv_t.dim(0)?,
                                h_esv_t.dim(1)?,
                                h_v_t.dim(2)?,
                            ))?
                            .contiguous()?;
                        let decoder_output = self.decoder_layers[layer_idx].forward(
                            &h_v_t,
                            &h_esv_t,
                            Some(&mask_t),
                            None,
                            None,
                        )?;
                        let t_expanded = t_gather.reshape(&[b])?;
                        let decoder_output = decoder_output.narrow(1, 0, 1)?;
                        h_v_stack[layer_idx + 1] =
                            h_v_stack[layer_idx + 1].index_add(&t_expanded, &decoder_output, 1)?;
                    }
                    let h_v_t = h_v_stack
                        .last()
                        .unwrap()
                        .gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, h_v_stack.last().unwrap().dim(2)?))?
                                .contiguous()?,
                            1,
                        )?
                        .squeeze(1)?;
                    let logits = self.w_out.forward(&h_v_t)?;
                    let log_probs = log_softmax(&logits, D::Minus1)?;

                    let probs = {
                        let biased_logits = logits.add(&bias_t)?;
                        let scaled_logits = (biased_logits / temperature)?;
                        softmax(&scaled_logits, D::Minus1)?
                    };

                    let probs_sample = probs
                        .narrow(1, 0, 20)?
                        .div(&probs.narrow(1, 0, 20)?.sum_keepdim(1)?.expand((b, 20))?)?;
                    let sum = probs_sample.sum(1)?;
                    let probs_sample_1d = probs_sample
                        .squeeze(0)?
                        .clamp(1e-10, 1.0)?
                        .broadcast_div(&sum)?
                        .contiguous()?;

                    let s_t = multinomial_sample(&probs_sample_1d, temperature, seed)?;
                    let s_t = s_t.to_dtype(sample_dtype)?;
                    let s_true = s_true.to_dtype(sample_dtype)?;
                    let s_true_t = s_true.gather(&t_gather, 1)?.squeeze(1)?;
                    let s_t = s_t
                        .mul(&chain_mask_t)?
                        .add(&s_true_t.mul(&(&chain_mask_t.neg()? + 1.0)?)?)?
                        .to_dtype(DType::U32)?;

                    let s_t_idx = s_t.reshape(&[s_t.dim(0)?])?;
                    let h_s_update = self.w_s.forward(&s_t_idx)?.unsqueeze(1)?;
                    let t_gather_expanded = t_gather.reshape(&[b])?;
                    h_s = h_s.index_add(&t_gather_expanded, &h_s_update, 1)?;

                    s = {
                        let dim = 1;
                        let start = t_gather.squeeze(0)?.squeeze(0)?.to_scalar::<u32>()? as usize;
                        let s_t_expanded = s_t.unsqueeze(1)?;
                        s.slice_scatter(&s_t_expanded, dim, start)?
                    };

                    let probs_update = chain_mask_t
                        .unsqueeze(1)?
                        .unsqueeze(2)?
                        .expand((b, 1, 20))?
                        .mul(&probs_sample.unsqueeze(1)?)?;
                    let t_expanded = t_gather.reshape(&[b])?;
                    all_probs = all_probs.index_add(&t_expanded, &probs_update, 1)?;
                    let log_probs_update = chain_mask_t
                        .unsqueeze(1)?
                        .unsqueeze(2)?
                        .expand((b, 1, 21))?
                        .mul(&log_probs.unsqueeze(1)?)?;
                    all_log_probs = all_log_probs.index_add(&t_expanded, &log_probs_update, 1)?;
                }
                Ok(ScoreOutput {
                    s,
                    log_probs: all_probs,
                    logits: all_log_probs,
                    decoding_order,
                })
            }
            Some(_) => {
                todo!()
            }
        }
    }

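    /// Scores the input sequence under a random decoding order. When
    /// `use_sequence` is true the decoder attends to previously "decoded"
    /// residues through the backward mask; when false it sees structure only,
    /// giving an unconditional per-position distribution.
    ///
    /// A minimal sketch (setup is assumed):
    ///
    /// ```ignore
    /// let out = model.score(&features, true)?;
    /// let log_probs = out.get_log_probs(); // (B, L, vocab) log-probabilities
    /// ```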
    pub fn score(&self, features: &ProteinFeatures, use_sequence: bool) -> Result<ScoreOutput> {
        let ProteinFeatures { s, x_mask, .. } = &features;
        let sample_dtype = DType::F32;
        let s_true = s;
        let device = s_true.device();
        let (b, l) = s_true.dims2()?;
        let mask = x_mask.as_ref().unwrap();
        let b_decoder: usize = b;

        let chain_mask = Tensor::zeros_like(mask)?.to_dtype(sample_dtype)?;
        let chain_mask = mask.mul(&chain_mask)?;
        let (h_v, h_e, e_idx) = self.encode(features)?;
        let rand_tensor = Tensor::randn(0f32, 1f32, (b, l), device)?.to_dtype(sample_dtype)?;
        let decoding_order = (chain_mask + 0.001)?
            .mul(&rand_tensor.abs()?)?
            .arg_sort_last_dim(false)?;

        let symmetry_residues: Option<Vec<i32>> = None;

        let (mask_fw, mask_bw, e_idx, decoding_order) = match symmetry_residues {
            Some(_) => {
                todo!();
            }
            None => {
                let e_idx = e_idx.repeat(&[b_decoder, 1, 1])?;
                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
                    .to_dtype(sample_dtype)?
                    .contiguous()?;

                let tril = Tensor::tril2(l, sample_dtype, device)?.unsqueeze(0)?;
                let temp = tril
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?;
                let order_mask_backward = temp
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?;
                let mask_attend = order_mask_backward
                    .gather(&e_idx, 2)?
                    .unsqueeze(D::Minus1)?;

                let mask_1d = mask
                    .reshape((b, l, 1, 1))?
                    .broadcast_as(mask_attend.shape())?
                    .to_dtype(sample_dtype)?;

                let mask_bw = mask_1d.mul(&mask_attend)?;
                let mask_fw = mask_1d.mul(&(mask_attend - 1.0)?.neg()?)?;
                (mask_fw, mask_bw, e_idx, decoding_order)
            }
        };

        let s_true = s_true.repeat(&[b_decoder, 1])?;
        let h_v = h_v.repeat(&[b_decoder, 1, 1])?;
        let h_e = h_e.repeat(&[b_decoder, 1, 1, 1])?;
        let mask = mask.repeat(&[b_decoder, 1])?;

        let h_s = self.w_s.forward(&s_true)?;
        let h_es = cat_neighbors_nodes(&h_s, &h_e, &e_idx)?;

        let h_ex_encoder = cat_neighbors_nodes(&Tensor::zeros_like(&h_s)?, &h_e, &e_idx)?;
        let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
        let h_exv_encoder_fw = mask_fw
            .broadcast_as(h_exv_encoder.shape())?
            .to_dtype(h_exv_encoder.dtype())?
            .mul(&h_exv_encoder)?;

        let h_v = if !use_sequence {
            self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
                layer.forward(&acc?, &h_exv_encoder_fw, Some(&mask), None, None)
            })?
        } else {
            self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
                let current_h_v = acc?;
                let h_esv = cat_neighbors_nodes(&current_h_v, &h_es, &e_idx)?
                    .mul(&mask_bw)?
                    .add(&h_exv_encoder_fw)?;
                layer.forward(&current_h_v, &h_esv, Some(&mask), None, None)
            })?
        };

        let logits = self.w_out.forward(&h_v)?;
        let log_probs = log_softmax(&logits, D::Minus1)?;

        Ok(ScoreOutput {
            s: s_true,
            log_probs,
            logits,
            decoding_order,
        })
    }
}