1use super::configs::{ModelTypes, ProteinMPNNConfig};
7use super::proteinfeatures::ProteinFeatures;
8use super::proteinfeaturesmodel::ProteinFeaturesModel;
9use crate::featurize::utilities::{cat_neighbors_nodes, gather_nodes, int_to_aa1};
10use crate::types::PseudoProbability;
11use candle_core::safetensors;
12use candle_core::{D, DType, Device, IndexOp, Module, Result, Tensor};
13use candle_nn::encoding::one_hot;
14use candle_nn::ops::{log_softmax, softmax};
15use candle_nn::{Dropout, Embedding, LayerNorm, Linear, VarBuilder, embedding, layer_norm, linear};
16use candle_transformers::generation::LogitsProcessor;
17use std::collections::HashMap;
18
19fn concat_node_tensors(h_v: &Tensor, h_e: &Tensor, e_idx: &Tensor) -> Result<Tensor> {
21 let h_ev = cat_neighbors_nodes(h_v, h_e, e_idx)?;
22 let h_v_expand = h_v.unsqueeze(D::Minus2)?;
23 let expand_shape = [
24 h_ev.dims()[0],
25 h_ev.dims()[1],
26 h_ev.dims()[2],
27 h_v_expand.dims()[3],
28 ];
29 let h_v_expand = h_v_expand.expand(&expand_shape)?.to_dtype(h_ev.dtype())?;
30 Tensor::cat(&[&h_v_expand, &h_ev], D::Minus1)?.contiguous()
31}
32fn apply_dropout_and_norm(
34 input: &Tensor,
35 delta: &Tensor,
36 dropout: &Dropout,
37 norm: &LayerNorm,
38 training: bool,
39) -> Result<Tensor> {
40 let delta_dropout = dropout.forward(delta, training)?;
41 norm.forward(&(input + delta_dropout)?)
42}
43
44pub fn multinomial_sample(probs: &Tensor, temperature: f64, seed: u64) -> Result<Tensor> {
45 let mut logits_processor = LogitsProcessor::new(
46 seed, Some(temperature), Some(0.95), );
51 let idx = logits_processor.sample(probs)?;
52 if idx >= 21 {
54 println!("WARNING: Invalid index {} selected", idx);
55 }
56 Tensor::new(&[idx], probs.device())
57}
58
/// Tensors produced by one decoding/scoring pass of the model.
#[derive(Clone, Debug)]
pub struct ScoreOutput {
    /// Residue indices; `get_sequences` reads this as a rank-2 tensor of `u32`.
    pub(crate) s: Tensor,
    /// Returned as-is by `get_log_probs`.
    pub(crate) log_probs: Tensor,
    /// Rank-3 tensor; `get_pseudo_probabilities` applies a softmax to the
    /// vector stored at every (batch, position).
    pub(crate) logits: Tensor,
    /// Decoding order; the accessors on this type read it as `u32` values.
    pub(crate) decoding_order: Tensor,
}
68impl ScoreOutput {
70 pub fn get_sequences(&self) -> Result<Vec<String>> {
71 let (b, l) = self.s.dims2()?;
72 let mut sequences = Vec::with_capacity(b);
73 for batch_idx in 0..b {
74 let batch = self.s.get(batch_idx)?;
75 let mut sequence = String::with_capacity(l);
76 for pos in 0..l {
77 let aa_idx = batch.get(pos)?.to_vec0::<u32>()?;
78 sequence.push(int_to_aa1(aa_idx));
79 }
80 sequences.push(sequence);
81 }
82 Ok(sequences)
83 }
84 pub fn get_decoding_order(&self) -> Result<Vec<u32>> {
85 let values = self.decoding_order.flatten_all()?.to_vec1::<u32>()?;
86 Ok(values)
87 }
88 pub fn get_log_probs(&self) -> &Tensor {
89 &self.log_probs
90 }
91 pub fn get_pseudo_probabilities(&self) -> Result<Vec<PseudoProbability>> {
92 let (batch_size, seq_len, _vocab_size) = self.logits.dims3()?;
93 let mut all_probabilities = Vec::with_capacity(batch_size * seq_len);
94
95 for batch_idx in 0..batch_size {
96 let batch_logits = self.logits.get(batch_idx)?;
97
98 for pos in 0..seq_len {
99 let pos_logits = batch_logits.get(pos)?;
100 let probs = softmax(&pos_logits, 0)?;
101 let probs = probs.to_vec1::<f32>()?;
102
103 let actual_pos = if let Ok(order) = self
105 .decoding_order
106 .get(batch_idx)?
107 .get(pos)?
108 .to_scalar::<u32>()
109 {
110 order as usize
111 } else {
112 pos
113 };
114
115 for aa_idx in 0..probs.len() {
116 if probs[aa_idx] > 0.01 {
117 all_probabilities.push(PseudoProbability {
119 position: actual_pos,
120 pseudo_prob: probs[aa_idx],
121 amino_acid: int_to_aa1(aa_idx as u32),
122 });
123 }
124 }
125 }
126 }
127 all_probabilities.sort_by(|a, b| {
129 a.position.cmp(&b.position).then(
130 b.pseudo_prob
131 .partial_cmp(&a.pseudo_prob)
132 .unwrap_or(std::cmp::Ordering::Equal),
133 )
134 });
135 Ok(all_probabilities)
136 }
137 pub fn save_as_safetensors(&self, filename: String) -> Result<()> {
138 let mut tensors = HashMap::new();
139 tensors.insert("S".to_string(), self.s.clone());
140 tensors.insert("log_probs".to_string(), self.log_probs.clone());
141 tensors.insert("logits".to_string(), self.logits.clone());
142 tensors.insert("decoding_order".to_string(), self.decoding_order.clone());
143 if let Some(parent) = std::path::Path::new(&filename).parent() {
145 std::fs::create_dir_all(parent)?;
146 }
147 let _ = safetensors::save(&tensors, &filename);
148 Ok(())
149 }
150}
151
/// Two-layer feed-forward block: `w1`, GELU, then `w2`.
#[derive(Clone, Debug)]
struct PositionWiseFeedForward {
    /// Input projection, loaded from the `W_in` weights.
    w1: Linear,
    /// Output projection, loaded from the `W_out` weights.
    w2: Linear,
}
157
158impl PositionWiseFeedForward {
159 fn new(vb: VarBuilder, dim_input: usize, dim_feedforward: usize) -> Result<Self> {
160 let w1 = linear::linear(dim_input, dim_feedforward, vb.pp("W_in"))?;
161 let w2 = linear::linear(dim_feedforward, dim_input, vb.pp("W_out"))?;
162 Ok(Self { w1, w2 })
163 }
164}
165
166impl Module for PositionWiseFeedForward {
167 fn forward(&self, x: &Tensor) -> Result<Tensor> {
168 self.w1.forward(x)?.gelu().and_then(|x| self.w2.forward(&x))
169 }
170}
171
/// One encoder layer: a node update followed by an edge update.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct EncLayer {
    /// Hidden width (`config.hidden_dim`).
    num_hidden: usize,
    /// Set to `2 * hidden_dim` by `load`.
    num_in: usize,
    /// Divisor applied to the summed node messages (`config.scale_factor`).
    scale: f64,
    /// Dropout on the node message update.
    dropout1: Dropout,
    /// Dropout on the feed-forward update.
    dropout2: Dropout,
    /// Dropout on the edge update.
    dropout3: Dropout,
    /// Layer norm after the node message update.
    norm1: LayerNorm,
    /// Layer norm after the feed-forward update.
    norm2: LayerNorm,
    /// Layer norm after the edge update.
    norm3: LayerNorm,
    /// `w1` -> `w2` -> `w3`: the node-message MLP.
    w1: Linear,
    w2: Linear,
    w3: Linear,
    /// `w11` -> `w12` -> `w13`: the edge-update MLP.
    w11: Linear,
    w12: Linear,
    w13: Linear,
    /// Position-wise feed-forward block applied to the nodes.
    dense: PositionWiseFeedForward,
}
192
193impl EncLayer {
194 pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig, layer: i32) -> Result<Self> {
195 let vb = vb.pp(layer); let num_hidden = config.hidden_dim as usize;
197 let augment_eps = config.augment_eps as f64;
198 let num_in = (config.hidden_dim * 2) as usize;
199 let dropout_ratio = config.dropout_ratio;
200
201 let norm1 = layer_norm(num_hidden, augment_eps, vb.pp("norm1"))?;
203 let norm2 = layer_norm(num_hidden, augment_eps, vb.pp("norm2"))?;
204 let norm3 = layer_norm(num_hidden, augment_eps, vb.pp("norm3"))?;
205
206 let w1 = linear(num_hidden + num_in, num_hidden, vb.pp("W1"))?;
208 let w2 = linear(num_hidden, num_hidden, vb.pp("W2"))?;
209 let w3 = linear(num_hidden, num_hidden, vb.pp("W3"))?;
210 let w11 = linear(num_hidden + num_in, num_hidden, vb.pp("W11"))?;
211 let w12 = linear(num_hidden, num_hidden, vb.pp("W12"))?;
212 let w13 = linear(num_hidden, num_hidden, vb.pp("W13"))?;
213
214 let dropout1 = Dropout::new(dropout_ratio);
216 let dropout2 = Dropout::new(dropout_ratio);
217 let dropout3 = Dropout::new(dropout_ratio);
218
219 let dense = PositionWiseFeedForward::new(vb.pp("dense"), num_hidden, num_hidden * 4)?;
220
221 Ok(Self {
222 num_hidden,
223 num_in,
224 scale: config.scale_factor,
225 dropout1,
226 dropout2,
227 dropout3,
228 norm1,
229 norm2,
230 norm3,
231 w1,
232 w2,
233 w3,
234 w11,
235 w12,
236 w13,
237 dense,
238 })
239 }
240 fn forward(
241 &self,
242 h_v: &Tensor,
243 h_e: &Tensor,
244 e_idx: &Tensor,
245 mask_v: Option<&Tensor>,
246 mask_attend: Option<&Tensor>,
247 training: Option<bool>,
248 ) -> Result<(Tensor, Tensor)> {
249 let training = training.unwrap_or(false);
250 let h_v = h_v.to_dtype(DType::F32)?;
251 let h_ev = concat_node_tensors(&h_v, h_e, e_idx)?;
252 let h_message = self
253 .w1
254 .forward(&h_ev)?
255 .gelu()?
256 .apply(&self.w2)?
257 .gelu()?
258 .apply(&self.w3)?;
259
260 let h_message = mask_attend
261 .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message))
262 .transpose()?
263 .unwrap_or(h_message);
264
265 let scale = if self.scale == 0.0 { 1.0 } else { self.scale };
267 let dh = (h_message.sum(D::Minus2)? / scale)?;
268
269 let h_v = apply_dropout_and_norm(&h_v, &dh, &self.dropout1, &self.norm1, training)?;
270 let dense_output = self.dense.forward(&h_v)?;
271 let h_v =
272 apply_dropout_and_norm(&h_v, &dense_output, &self.dropout2, &self.norm2, training)?;
273
274 let h_v = mask_v
276 .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_v))
277 .transpose()?
278 .unwrap_or(h_v);
279
280 let h_ev = concat_node_tensors(&h_v, h_e, e_idx)?;
281 let h_message = self
282 .w11
283 .forward(&h_ev)?
284 .gelu()?
285 .apply(&self.w12)?
286 .gelu()?
287 .apply(&self.w13)?;
288
289 let h_e = apply_dropout_and_norm(h_e, &h_message, &self.dropout3, &self.norm3, training)?;
290 Ok((h_v, h_e))
291 }
292}
293
/// One decoder layer: a node update driven by per-edge messages.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct DecLayer {
    /// Hidden width (`config.hidden_dim`).
    num_hidden: usize,
    /// Set to `3 * hidden_dim` by `load`.
    num_in: usize,
    /// Divisor applied to the summed messages (`config.scale_factor`).
    scale: f64,
    /// Dropout on the per-edge messages.
    dropout1: Dropout,
    /// Dropout on the feed-forward update.
    dropout2: Dropout,
    /// Layer norm after the message update.
    norm1: LayerNorm,
    /// Layer norm after the feed-forward update.
    norm2: LayerNorm,
    /// `w1` -> `w2` -> `w3`: the message MLP.
    w1: Linear,
    w2: Linear,
    w3: Linear,
    /// Position-wise feed-forward block applied to the nodes.
    dense: PositionWiseFeedForward,
}
309
310impl DecLayer {
311 pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig, layer: i32) -> Result<Self> {
312 let vb = vb.pp(layer); let num_hidden = config.hidden_dim as usize;
314 let augment_eps = config.augment_eps as f64;
315 let num_in = (config.hidden_dim * 3) as usize;
316 let dropout_ratio = config.dropout_ratio;
317
318 let norm1 = layer_norm::layer_norm(num_hidden, augment_eps, vb.pp("norm1"))?;
319 let norm2 = layer_norm::layer_norm(num_hidden, augment_eps, vb.pp("norm2"))?;
320
321 let w1 = linear::linear(num_hidden + num_in, num_hidden, vb.pp("W1"))?;
322 let w2 = linear::linear(num_hidden, num_hidden, vb.pp("W2"))?;
323 let w3 = linear::linear(num_hidden, num_hidden, vb.pp("W3"))?;
324 let dropout1 = Dropout::new(dropout_ratio);
325 let dropout2 = Dropout::new(dropout_ratio);
326
327 let dense = PositionWiseFeedForward::new(vb.pp("dense"), num_hidden, num_hidden * 4)?;
328
329 Ok(Self {
330 num_hidden,
331 num_in,
332 scale: config.scale_factor,
333 dropout1,
334 dropout2,
335 norm1,
336 norm2,
337 w1,
338 w2,
339 w3,
340 dense,
341 })
342 }
343 pub fn forward(
344 &self,
345 h_v: &Tensor,
346 h_e: &Tensor,
347 mask_v: Option<&Tensor>,
348 mask_attend: Option<&Tensor>,
349 training: Option<bool>,
350 ) -> Result<Tensor> {
351 let training_bool = training.unwrap_or(false);
352
353 let expand_shape = [
355 h_e.dims()[0], h_e.dims()[1], h_e.dims()[2], h_v.dims()[2], ];
360
361 let h_v_expand = h_v.unsqueeze(D::Minus2)?.expand(&expand_shape)?;
362 let h_ev = Tensor::cat(&[&h_v_expand, h_e], D::Minus1)?.contiguous()?;
363
364 let h_message = self
365 .w1
366 .forward(&h_ev)?
367 .gelu()?
368 .apply(&self.w2)?
369 .gelu()?
370 .apply(&self.w3)?;
371
372 let h_message = self.dropout1.forward(&h_message, training_bool)?;
373
374 let h_message = mask_attend
375 .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message))
376 .transpose()?
377 .unwrap_or(h_message);
378
379 let dh = (h_message.sum(D::Minus2)? / self.scale)?;
380 let h_v = self.norm1.forward(&(h_v + dh)?)?;
381 let dh = self.dense.forward(&h_v)?;
382 let dh_dropout = self.dropout2.forward(&dh, training_bool)?;
383 let h_v = self.norm2.forward(&(h_v + dh_dropout)?)?;
384
385 let h_v = mask_v
387 .map(|mask| mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_v))
388 .transpose()?
389 .unwrap_or(h_v);
390
391 Ok(h_v)
392 }
393}
394
/// The ProteinMPNN model: feature model, encoder and decoder stacks, and the
/// input/output projections.
pub struct ProteinMPNN {
    /// Copy of the configuration the model was loaded with.
    pub(crate) config: ProteinMPNNConfig,
    /// Decoder layers, loaded from the `decoder_layers` weights.
    pub(crate) decoder_layers: Vec<DecLayer>,
    /// Device taken from the `VarBuilder` at load time.
    pub(crate) device: Device,
    /// Encoder layers, loaded from the `encoder_layers` weights.
    pub(crate) encoder_layers: Vec<EncLayer>,
    /// Feature model that `encode` calls to obtain `(e, e_idx)`.
    pub(crate) features: ProteinFeaturesModel,
    /// Edge projection (`W_e`): `edge_features` -> `hidden_dim`.
    pub(crate) w_e: Linear,
    /// Output projection (`W_out`): `hidden_dim` -> `num_letters`.
    pub(crate) w_out: Linear,
    /// Sequence embedding (`W_s`): `vocab` entries of width `hidden_dim`.
    pub(crate) w_s: Embedding,
}
407
408impl ProteinMPNN {
409 pub fn load(vb: VarBuilder, config: &ProteinMPNNConfig) -> Result<Self> {
410 let hidden_dim = config.hidden_dim as usize;
411 let edge_features = config.edge_features as usize;
412 let num_letters = config.num_letters as usize;
413 let vocab_size = config.vocab as usize;
414
415 let encoder_layers = (0..config.num_encoder_layers)
417 .map(|i| EncLayer::load(vb.pp("encoder_layers"), config, i as i32))
418 .collect::<Result<Vec<_>>>()?;
419
420 let decoder_layers = (0..config.num_decoder_layers)
421 .map(|i| DecLayer::load(vb.pp("decoder_layers"), config, i as i32))
422 .collect::<Result<Vec<_>>>()?;
423
424 let w_e = linear::linear(edge_features, hidden_dim, vb.pp("W_e"))?;
426 let w_out = linear::linear(hidden_dim, num_letters, vb.pp("W_out"))?;
427 let w_s = embedding(vocab_size, hidden_dim, vb.pp("W_s"))?;
428 let features = ProteinFeaturesModel::load(vb.pp("features"), config.clone())?;
430
431 Ok(Self {
432 config: config.clone(),
433 decoder_layers,
434 device: vb.device().clone(),
435 encoder_layers,
436 features,
437 w_e,
438 w_out,
439 w_s,
440 })
441 }
442 pub fn encode(&self, features: &ProteinFeatures) -> Result<(Tensor, Tensor, Tensor)> {
443 let s_true = features.get_sequence();
444 let base_dtype = DType::F32;
445 let mask = match features.get_sequence_mask() {
446 Some(m) => m,
447 None => &Tensor::ones_like(s_true)?,
448 };
449 match self.config.model_type {
450 ModelTypes::ProteinMPNN => {
451 let (e, e_idx) = self.features.forward(features, &self.device)?;
452 let h_v = Tensor::zeros(
453 (e.dim(0)?, e.dim(1)?, e.dim(D::Minus1)?),
454 base_dtype,
455 &self.device,
456 )?;
457 let h_e = self.w_e.forward(&e)?;
458 let mask_attend = if let Some(seq_mask) = features.get_sequence_mask() {
459 let mask_expanded = seq_mask.unsqueeze(D::Minus1)?; let mask_gathered = gather_nodes(&mask_expanded, &e_idx)?.squeeze(D::Minus1)?;
461 let mask_unsqueezed = mask.unsqueeze(D::Minus1)?; mask_unsqueezed
463 .expand((
464 mask_gathered.dim(0)?, mask_gathered.dim(1)?, mask_gathered.dim(2)?, ))?
468 .mul(&mask_gathered)?
469 } else {
470 let (b, l) = mask.dims2()?;
471 Tensor::ones((b, l, e_idx.dim(2)?), DType::F32, &self.device)?
472 };
473 println!("Beginning the Encoding...");
474 let mask_f32 = mask.to_dtype(base_dtype)?;
476 let mask_attend_f32 = mask_attend.to_dtype(base_dtype)?;
477
478 let (h_v, h_e) =
480 self.encoder_layers
481 .iter()
482 .fold(Ok((h_v, h_e)), |acc, layer| {
483 let (h_v, h_e) = acc?;
484 layer.forward(
485 &h_v,
486 &h_e,
487 &e_idx,
488 Some(&mask_f32),
489 Some(&mask_attend_f32),
490 Some(false),
491 )
492 })?;
493
494 Ok((h_v, h_e, e_idx))
495 }
496 ModelTypes::LigandMPNN => {
497 todo!()
498 }
530 }
531 }
532 pub fn simple_decode(&self, features: &ProteinFeatures) -> Result<ScoreOutput> {
534 let b_decoder = 1;
536
537 let ProteinFeatures { s, x_mask, .. } = features;
539 let device = s.device();
540 let (_, l) = s.dims2()?;
541
542 let (h_v_enc, h_e_enc, e_idx_enc) = self.encode(features)?;
544
545 let s_true = s.clone();
547 let mask = x_mask.clone().unwrap();
548
549 let zeros = Tensor::zeros((b_decoder, l, h_v_enc.dim(D::Minus1)?), DType::F32, device)?;
551 let h_v = h_v_enc.clone();
552 let h_e = h_e_enc.clone();
553 let e_idx = e_idx_enc.clone();
554
555 let h_ex_encoder = cat_neighbors_nodes(&zeros, &h_e, &e_idx)?;
557 let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
558
559 let h_v_final = self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
561 layer.forward(&acc?, &h_exv_encoder, Some(&mask), None, None)
562 })?;
563
564 let logits = self.w_out.forward(&h_v_final)?;
566 let log_probs = log_softmax(&logits, D::Minus1)?;
567
568 let decoding_order = Tensor::arange(0, l as i64, device)?
570 .reshape((1, l))?
571 .broadcast_as((b_decoder, l))?
572 .to_dtype(DType::F32)?;
573
574 Ok(ScoreOutput {
576 s: s_true,
577 log_probs,
578 logits,
579 decoding_order,
580 })
581 }
    /// Samples a sequence one position at a time in a random decoding order.
    ///
    /// The structure is encoded once; for each step `t_` of the decoding order
    /// the decoder layers are run for the selected position, a residue index
    /// is drawn with `multinomial_sample`, and the running sequence embedding
    /// (`h_s`), the sampled sequence (`s`) and the per-position probability
    /// tensors are updated at that position.
    ///
    /// The returned `ScoreOutput` stores the sampled sequence in `s`, the
    /// renormalised 20-way sampling probabilities in `log_probs`, and the
    /// 21-way log-softmax values in `logits`.
    ///
    /// NOTE(review): several steps (`squeeze(0)` on the probabilities, reading
    /// `t_gather` as a scalar) appear to assume a batch size of 1 — confirm.
    /// NOTE(review): the same `seed` is passed to `multinomial_sample` at every
    /// step — confirm this is intended.
    ///
    /// # Panics
    /// Panics if `features.x_mask` is `None`.
    pub fn sample(
        &self,
        features: &ProteinFeatures,
        temperature: f64,
        seed: u64,
    ) -> Result<ScoreOutput> {
        let sample_dtype = DType::F32;
        let ProteinFeatures {
            s,
            x_mask,
            ..
        } = features;
        let s_true = s.to_dtype(sample_dtype)?;
        let device = s.device();
        let (b, l) = s.dims2()?;
        let chain_mask = x_mask.as_ref().unwrap().to_dtype(sample_dtype)?;
        let (h_v, h_e, e_idx) = self.encode(features)?;
        // Random decoding order: argsort of (mask + 0.0001) * |noise|.
        let rand_tensor = Tensor::randn(0f32, 0.25f32, (b, l), device)?.to_dtype(sample_dtype)?;
        let decoding_order = (&chain_mask + 0.0001)?
            .mul(&rand_tensor.abs()?)?
            .arg_sort_last_dim(false)?;
        // All-ones bias that is added to the logits before sampling.
        let bias = Tensor::ones((b, l, 21), sample_dtype, device)?;
        // Hard-coded to `None`, so only the first match arm runs.
        let symmetry_residues: Option<Vec<i32>> = None;
        match symmetry_residues {
            None => {
                let e_idx = e_idx.repeat(&[b, 1, 1])?;
                // One-hot encoding of the decoding order, used as a permutation matrix.
                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
                    .to_dtype(sample_dtype)?
                    .contiguous()?;
                let tril = Tensor::tril2(l, sample_dtype, device)?;
                let tril = tril.unsqueeze(0)?;
                let temp = tril
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?;
                let order_mask_backward = temp
                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
                    .contiguous()?;
                // Order mask gathered at each neighbour index, with a trailing unit axis.
                let mask_attend = order_mask_backward
                    .gather(&e_idx, 2)?
                    .unsqueeze(D::Minus1)?;
                let mask_1d = x_mask.as_ref().unwrap().reshape((b, l, 1, 1))?;
                let mask_1d = mask_1d
                    .broadcast_as(mask_attend.shape())?
                    .to_dtype(sample_dtype)?;
                // `mask_bw` keeps entries where `mask_attend` is set, `mask_fw` the complement.
                let mask_bw = mask_1d.mul(&mask_attend)?;
                let mask_fw = mask_1d.mul(&(Tensor::ones_like(&mask_attend)? - mask_attend)?)?;
                // Every tensor is repeated `b` times along the batch axis.
                let s_true = s_true.repeat((b, 1))?;
                let h_v = h_v.repeat((b, 1, 1))?;
                let h_e = h_e.repeat((b, 1, 1, 1))?;
                let mask = x_mask.as_ref().unwrap().repeat((b, 1))?.contiguous()?;
                let chain_mask = &chain_mask.repeat((b, 1))?;
                let bias = bias.repeat((b, 1, 1))?;
                // Accumulators filled in position by position.
                let mut all_probs = Tensor::zeros((b, l, 20), sample_dtype, device)?;
                let mut all_log_probs = Tensor::zeros((b, l, 21), sample_dtype, device)?;
                let mut h_s = Tensor::zeros_like(&h_v)?;
                let mut s = Tensor::full(20u32, (b, l), device)?;
                // h_v_stack[0] is the encoder output; one zero tensor follows per decoder layer.
                let mut h_v_stack = vec![h_v.clone()];

                for _ in 0..self.decoder_layers.len() {
                    let zeros = Tensor::zeros_like(&h_v)?;
                    h_v_stack.push(zeros);
                }
                // Encoder context with zeros in place of the sequence embedding.
                let h_ex_encoder = cat_neighbors_nodes(&Tensor::zeros_like(&h_s)?, &h_e, &e_idx)?;
                let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
                let mask_fw = mask_fw
                    .broadcast_as(h_exv_encoder.shape())?
                    .to_dtype(h_exv_encoder.dtype())?;
                let h_exv_encoder_fw = mask_fw.mul(&h_exv_encoder)?;
                for t_ in 0..l {
                    // `t` is the position decoded at step `t_`; `t_gather` is `t` as a column.
                    let t = decoding_order.i((.., t_))?;
                    let t_gather = t.unsqueeze(1)?;
                    let chain_mask_t = chain_mask.gather(&t_gather, 1)?.squeeze(1)?;
                    let mask_t = mask.gather(&t_gather, 1)?.squeeze(1)?.contiguous()?;
                    let bias_t = bias
                        .gather(&t_gather.unsqueeze(2)?.expand((b, 1, 21))?.contiguous()?, 1)?
                        .squeeze(1)?;
                    // Slices of the neighbour indices and edge features at position `t`.
                    let e_idx_t = e_idx
                        .gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, e_idx.dim(2)?))?
                                .contiguous()?,
                            1,
                        )?
                        .contiguous()?;
                    let h_e_t = h_e.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, h_e.dim(2)?, h_e.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;
                    let n = e_idx_t.dim(2)?;
                    let c = h_s.dim(2)?;
                    // The single-position slices are expanded back over all `l` positions.
                    let h_e_t = h_e_t
                        .squeeze(1)?
                        .unsqueeze(1)?
                        .expand((b, l, n, c))?
                        .contiguous()?;
                    let e_idx_t = e_idx_t
                        .expand((b, l, n))?
                        .contiguous()?;
                    let h_es_t = cat_neighbors_nodes(&h_s, &h_e_t, &e_idx_t)?;
                    let h_exv_encoder_t = h_exv_encoder_fw.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, h_exv_encoder_fw.dim(2)?, h_exv_encoder_fw.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;
                    let mask_bw_t = mask_bw.gather(
                        &t_gather
                            .unsqueeze(2)?
                            .unsqueeze(3)?
                            .expand((b, 1, mask_bw.dim(2)?, mask_bw.dim(3)?))?
                            .contiguous()?,
                        1,
                    )?;

                    // NOTE: this `l` is the layer index and shadows the sequence
                    // length for the duration of the inner loop.
                    for l in 0..self.decoder_layers.len() {
                        let h_v_stack_l = &h_v_stack[l];
                        let h_esv_decoder_t = cat_neighbors_nodes(h_v_stack_l, &h_es_t, &e_idx_t)?;
                        let h_v_t = h_v_stack_l.gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, h_v_stack_l.dim(2)?))?
                                .contiguous()?,
                            1,
                        )?;
                        let mask_bw_t = mask_bw_t.expand(h_esv_decoder_t.dims())?.contiguous()?;
                        let h_exv_encoder_t = h_exv_encoder_t
                            .expand(h_esv_decoder_t.dims())?
                            .contiguous()?
                            .to_dtype(sample_dtype)?;
                        // Masked decoder context plus the forward-masked encoder context.
                        let h_esv_t = mask_bw_t
                            .mul(&h_esv_decoder_t.to_dtype(sample_dtype)?)?
                            .add(&h_exv_encoder_t)?
                            .to_dtype(sample_dtype)?
                            .contiguous()?;
                        let h_v_t = h_v_t
                            .expand((
                                h_esv_t.dim(0)?,
                                h_esv_t.dim(1)?,
                                h_v_t.dim(2)?,
                            ))?
                            .contiguous()?;
                        let decoder_output = self.decoder_layers[l].forward(
                            &h_v_t,
                            &h_esv_t,
                            Some(&mask_t),
                            None,
                            None,
                        )?;
                        let t_expanded = t_gather.reshape(&[b])?;
                        // Only the first row along axis 1 is kept and added into
                        // the next stack level at index `t`.
                        let decoder_output = decoder_output
                            .narrow(1, 0, 1)?
                            .squeeze(1)?
                            .unsqueeze(1)?;
                        h_v_stack[l + 1] =
                            h_v_stack[l + 1].index_add(&t_expanded, &decoder_output, 1)?;
                    }
                    // Final-layer node state at position `t`.
                    let h_v_t = h_v_stack
                        .last()
                        .unwrap()
                        .gather(
                            &t_gather
                                .unsqueeze(2)?
                                .expand((b, 1, h_v_stack.last().unwrap().dim(2)?))?
                                .contiguous()?,
                            1,
                        )?
                        .squeeze(1)?;
                    let logits = self.w_out.forward(&h_v_t)?;
                    let log_probs = log_softmax(&logits, D::Minus1)?;

                    // Softmax of (logits + bias) / temperature.
                    let probs = {
                        let biased_logits = logits.add(&bias_t)?;
                        let scaled_logits = (biased_logits / temperature)?;
                        softmax(&scaled_logits, D::Minus1)?
                    };

                    // Keep the first 20 classes and renormalise them to sum to 1.
                    let probs_sample = probs
                        .narrow(1, 0, 20)?
                        .div(&probs.narrow(1, 0, 20)?.sum_keepdim(1)?.expand((b, 20))?)?;
                    let sum = probs_sample.sum(1)?;
                    let probs_sample_1d = probs_sample
                        .squeeze(0)?
                        .clamp(1e-10, 1.0)?
                        .broadcast_div(&sum)?
                        .contiguous()?;

                    let s_t = multinomial_sample(&probs_sample_1d, temperature, seed)?;
                    let s_t = s_t.to_dtype(sample_dtype)?;
                    let s_true = s_true.to_dtype(sample_dtype)?;
                    let s_true_t = s_true.gather(&t_gather, 1)?.squeeze(1)?;
                    // Sampled residue where chain_mask is 1, input residue where it is 0.
                    let s_t = s_t
                        .mul(&chain_mask_t)?
                        .add(&s_true_t.mul(&(&chain_mask_t.neg()? + 1.0)?)?)?
                        .to_dtype(DType::U32)?;

                    let s_t_idx = s_t.to_dtype(DType::U32)?;
                    let s_t_idx = s_t_idx.reshape(&[s_t_idx.dim(0)?])?;
                    let h_s_update = self.w_s.forward(&s_t_idx)?.unsqueeze(1)?;
                    let t_gather_expanded = t_gather.reshape(&[b])?;
                    let h_s_update = h_s_update.squeeze(0)?.unsqueeze(1)?;
                    // NOTE(review): the first index_add adds zeros and looks redundant — confirm.
                    h_s =
                        h_s.index_add(&t_gather_expanded, &Tensor::zeros_like(&h_s_update)?, 1)?;
                    h_s = h_s.index_add(&t_gather_expanded, &h_s_update, 1)?;

                    // Write the chosen residue into `s` at column `start`.
                    s = {
                        let dim = 1;
                        let start = t_gather.squeeze(0)?.squeeze(0)?.to_scalar::<u32>()? as usize;
                        let s_t_expanded = s_t.unsqueeze(1)?;
                        s.slice_scatter(&s_t_expanded, dim, start)?
                    };

                    let probs_update = chain_mask_t
                        .unsqueeze(1)?
                        .unsqueeze(2)?
                        .expand((b, 1, 20))?
                        .mul(&probs_sample.unsqueeze(1)?)?;
                    let t_expanded = t_gather.reshape(&[b])?;
                    let probs_update = probs_update
                        .squeeze(1)?
                        .unsqueeze(1)?;
                    all_probs =
                        all_probs.index_add(&t_expanded, &Tensor::zeros_like(&probs_update)?, 1)?;
                    all_probs = all_probs.index_add(&t_expanded, &probs_update, 1)?;
                    let log_probs_update = chain_mask_t
                        .unsqueeze(1)?
                        .unsqueeze(2)?
                        .expand((b, 1, 21))?
                        .mul(&log_probs.unsqueeze(1)?)?
                        .squeeze(1)?
                        .unsqueeze(1)?;

                    all_log_probs = all_log_probs.index_add(
                        &t_expanded,
                        &Tensor::zeros_like(&log_probs_update)?,
                        1,
                    )?;
                    all_log_probs = all_log_probs.index_add(&t_expanded, &log_probs_update, 1)?;
                }
                // NOTE(review): `log_probs` receives the sampling probabilities and
                // `logits` receives the log-softmax values — confirm the field mapping.
                Ok(ScoreOutput {
                    s,
                    log_probs: all_probs,
                    logits: all_log_probs,
                    decoding_order,
                })
            }
            Some(_symmetry_residues) => {
                todo!()
            }
        }
    }
859
860 pub fn score(&self, features: &ProteinFeatures, use_sequence: bool) -> Result<ScoreOutput> {
861 let ProteinFeatures { s, x_mask, .. } = &features;
862 let sample_dtype = DType::F32;
863 let s_true = &s.clone();
864 let device = s_true.device();
865 let (b, l) = s_true.dims2()?;
866 let mask = &x_mask.as_ref().clone();
867 let b_decoder: usize = b;
868
869 let chain_mask = Tensor::zeros_like(mask.unwrap())?.to_dtype(sample_dtype)?;
872 let chain_mask = mask.unwrap().mul(&chain_mask)?; let (h_v, h_e, e_idx) = self.encode(features)?;
876 let rand_tensor = Tensor::randn(0f32, 1f32, (b, l), device)?.to_dtype(sample_dtype)?;
877 let decoding_order = (chain_mask + 0.001)?
879 .mul(&rand_tensor.abs()?)?
880 .arg_sort_last_dim(false)?;
881
882 let symmetry_residues: Option<Vec<i32>> = None;
883
884 let (mask_fw, mask_bw, e_idx, decoding_order) = match symmetry_residues {
885 Some(_symmetry_residues) => {
886 todo!();
887 }
888 None => {
889 let e_idx = e_idx.repeat(&[b_decoder, 1, 1])?;
890 let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
891 .to_dtype(sample_dtype)?
892 .contiguous()?;
893
894 let tril = Tensor::tril2(l, sample_dtype, device)?.unsqueeze(0)?;
895 let temp = tril
896 .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
897 .contiguous()?; let order_mask_backward = temp
899 .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
900 .contiguous()?; let mask_attend = order_mask_backward
902 .gather(&e_idx, 2)?
903 .unsqueeze(D::Minus1)?;
904
905 let mask_1d = mask
907 .unwrap()
908 .reshape((b, l, 1, 1))?
909 .broadcast_as(mask_attend.shape())?
910 .to_dtype(sample_dtype)?;
911
912 let mask_bw = mask_1d.mul(&mask_attend)?;
913 let mask_fw = mask_1d.mul(&(mask_attend - 1.0)?.neg()?)?;
914 (mask_fw, mask_bw, e_idx, decoding_order)
915 }
916 };
917
918 let s_true = s_true.repeat(&[b_decoder, 1])?;
919 let h_v = h_v.repeat(&[b_decoder, 1, 1])?;
920 let h_e = h_e.repeat(&[b_decoder, 1, 1, 1])?;
921 let mask = mask.as_ref().unwrap().repeat(&[b_decoder, 1])?;
922
923 let h_s = self.w_s.forward(&s_true)?; let h_es = cat_neighbors_nodes(&h_s, &h_e, &e_idx)?;
925
926 let h_ex_encoder = cat_neighbors_nodes(&Tensor::zeros_like(&h_s)?, &h_e, &e_idx)?;
928 let h_exv_encoder = cat_neighbors_nodes(&h_v, &h_ex_encoder, &e_idx)?;
929 let h_exv_encoder_fw = mask_fw
930 .broadcast_as(h_exv_encoder.shape())?
931 .to_dtype(h_exv_encoder.dtype())?
932 .mul(&h_exv_encoder)?;
933
934 let h_v = if !use_sequence {
936 self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
938 layer.forward(&acc?, &h_exv_encoder_fw, Some(&mask), None, None)
939 })?
940 } else {
941 self.decoder_layers.iter().fold(Ok(h_v), |acc, layer| {
943 let current_h_v = acc?;
944 let h_esv = cat_neighbors_nodes(¤t_h_v, &h_es, &e_idx)?
945 .mul(&mask_bw)?
946 .add(&h_exv_encoder_fw)?;
947 layer.forward(¤t_h_v, &h_esv, Some(&mask), None, None)
948 })?
949 };
950
951 let logits = self.w_out.forward(&h_v)?;
952 let log_probs = log_softmax(&logits, D::Minus1)?;
953
954 Ok(ScoreOutput {
955 s: s_true,
956 log_probs,
957 logits,
958 decoding_order,
959 })
960 }
961}