ferritin_plms/esmc/layers/transformer_stack.rs
1use crate::esmc::layers::blocks::UnifiedTransformerBlock;
2use crate::esmc::models::esmc::ESMCConfig;
3// use crate::esmc::utils::structure::affine3d::Affine3D;
4use candle_core::{Module, Result, Tensor};
5use candle_nn::{self as nn, LayerNorm};
6
7pub struct TransformerStack {
8 /*
9 A stack of transformer blocks used in the ESM-3 model. Each block is a UnifiedTransformerBlock,
10 which can either be geometric attention or standard multi-head attention.
11
12 Args:
13 d_model (i64): The dimensionality of the input and output feature vectors.
14 n_heads (i64): The number of attention heads.
15 v_heads (Option<i64>): The number of voting heads.
16 n_layers (i64): The number of transformer blocks in the stack.
17 n_layers_geom (i64, optional): The number of transformer blocks that use geometric attention.
18 scale_residue (bool, optional): Whether to scale the residue connections in each transformer block.
19 mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless positions in the input.
20 Only applies in the geometric attention blocks, which is conditioned on the structure
21 */
22 blocks: Vec<UnifiedTransformerBlock>,
23 norm: LayerNorm,
24}
25
26impl TransformerStack {
27 pub fn load(vb: nn::VarBuilder, config: &ESMCConfig) -> Result<Self> {
28 let ESMCConfig {
29 d_model, n_layers, ..
30 } = config;
31
32 let mut blocks = Vec::with_capacity(*n_layers);
33 for i in 0..*n_layers {
34 blocks.push(UnifiedTransformerBlock::load(
35 vb.pp(format!("blocks.{}", i)),
36 config,
37 i,
38 )?);
39 }
40
41 // transformer.norm has weight but no bias in the checkpoint.
42 let norm_weight = vb.pp("norm").get((*d_model,), "weight")?;
43 let norm = LayerNorm::new_no_bias(norm_weight, 1e-5);
44
45 Ok(Self { blocks, norm })
46 }
47 /// Run the full transformer stack.
48 ///
49 /// - `x`: `(B, L, d_model)`
50 /// - `sequence_id`: optional `(B, L)` boolean mask (True = real token, False = pad)
51 /// - `output_hidden_states`: if true, returns hidden state after every block as a Vec
52 ///
53 /// Returns `(final_hidden, hidden_states_per_layer_if_requested)`.
54 pub fn forward(
55 &self,
56 x: &Tensor,
57 sequence_id: Option<&Tensor>,
58 output_hidden_states: bool,
59 ) -> Result<(Tensor, Option<Vec<Tensor>>)> {
60 let mut x = x.clone();
61 let mut hidden_states: Option<Vec<Tensor>> = if output_hidden_states {
62 Some(Vec::with_capacity(self.blocks.len()))
63 } else {
64 None
65 };
66
67 for block in &self.blocks {
68 x = block.forward(&x, sequence_id)?;
69 if let Some(ref mut hs) = hidden_states {
70 hs.push(x.clone());
71 }
72 }
73
74 let x = self.norm.forward(&x)?;
75 Ok((x, hidden_states))
76 }
77
78 // pub fn new(
79 // d_model: i64,
80 // n_heads: i64,
81 // v_heads: Option<i64>,
82 // n_layers: i64,
83 // n_layers_geom: i64,
84 // scale_residue: bool,
85 // mask_and_zero_frameless: bool,
86 // bias: bool,
87 // qk_layernorm: bool,
88 // ffn_type: &str,
89 // expansion_ratio: f64,
90 // ) -> Result<Self> {
91 // let mut blocks = Vec::with_capacity(n_layers as usize);
92 // for i in 0..n_layers {
93 // blocks.push(UnifiedTransformerBlock::new(
94 // d_model,
95 // n_heads,
96 // v_heads,
97 // i < n_layers_geom,
98 // if scale_residue {
99 // (n_layers as f64 / 36.0).sqrt()
100 // } else {
101 // 1.0
102 // },
103 // expansion_ratio,
104 // mask_and_zero_frameless,
105 // bias,
106 // qk_layernorm,
107 // ffn_type,
108 // )?);
109 // }
110
111 // let norm = nn::LayerNorm::new(d_model, 1e-5, false)?;
112
113 // Ok(Self { blocks, norm })
114 // }
115
116 // pub fn forward(
117 // &self,
118 // x: &Tensor,
119 // sequence_id: Option<&Tensor>,
120 // affine: Option<&Affine3D>,
121 // affine_mask: Option<&Tensor>,
122 // chain_id: Option<&Tensor>,
123 // ) -> Result<(Tensor, Tensor)> {
124 // let mut x = x.clone();
125
126 // let chain_id = if chain_id.is_none() {
127 // let batch_dims = x.shape().split_last().unwrap().1;
128 // Tensor::ones(batch_dims, (x.device(), DType::I64))?
129 // } else {
130 // chain_id.unwrap().clone()
131 // };
132
133 // for block in self.blocks.iter() {
134 // x = block.forward(&x, sequence_id, affine, affine_mask, &chain_id)?;
135 // }
136
137 // let normalized = self.norm.forward(&x)?;
138 // Ok((normalized, x))
139 // }
140}