Skip to main content

ferritin_plms/esmc/layers/
transformer_stack.rs

1use crate::esmc::layers::blocks::UnifiedTransformerBlock;
2use crate::esmc::models::esmc::ESMCConfig;
3// use crate::esmc::utils::structure::affine3d::Affine3D;
4use candle_core::{Module, Result, Tensor};
5use candle_nn::{self as nn, LayerNorm};
6
7pub struct TransformerStack {
8    /*
9    A stack of transformer blocks used in the ESM-3 model. Each block is a UnifiedTransformerBlock,
10    which can either be geometric attention or standard multi-head attention.
11
12    Args:
13        d_model (i64): The dimensionality of the input and output feature vectors.
14        n_heads (i64): The number of attention heads.
15        v_heads (Option<i64>): The number of voting heads.
16        n_layers (i64): The number of transformer blocks in the stack.
17        n_layers_geom (i64, optional): The number of transformer blocks that use geometric attention.
18        scale_residue (bool, optional): Whether to scale the residue connections in each transformer block.
19        mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless positions in the input.
20            Only applies in the geometric attention blocks, which is conditioned on the structure
21    */
22    blocks: Vec<UnifiedTransformerBlock>,
23    norm: LayerNorm,
24}
25
26impl TransformerStack {
27    pub fn load(vb: nn::VarBuilder, config: &ESMCConfig) -> Result<Self> {
28        let ESMCConfig {
29            d_model, n_layers, ..
30        } = config;
31
32        let mut blocks = Vec::with_capacity(*n_layers);
33        for i in 0..*n_layers {
34            blocks.push(UnifiedTransformerBlock::load(
35                vb.pp(format!("blocks.{}", i)),
36                config,
37                i,
38            )?);
39        }
40
41        // transformer.norm has weight but no bias in the checkpoint.
42        let norm_weight = vb.pp("norm").get((*d_model,), "weight")?;
43        let norm = LayerNorm::new_no_bias(norm_weight, 1e-5);
44
45        Ok(Self { blocks, norm })
46    }
47    /// Run the full transformer stack.
48    ///
49    /// - `x`: `(B, L, d_model)`
50    /// - `sequence_id`: optional `(B, L)` boolean mask (True = real token, False = pad)
51    /// - `output_hidden_states`: if true, returns hidden state after every block as a Vec
52    ///
53    /// Returns `(final_hidden, hidden_states_per_layer_if_requested)`.
54    pub fn forward(
55        &self,
56        x: &Tensor,
57        sequence_id: Option<&Tensor>,
58        output_hidden_states: bool,
59    ) -> Result<(Tensor, Option<Vec<Tensor>>)> {
60        let mut x = x.clone();
61        let mut hidden_states: Option<Vec<Tensor>> = if output_hidden_states {
62            Some(Vec::with_capacity(self.blocks.len()))
63        } else {
64            None
65        };
66
67        for block in &self.blocks {
68            x = block.forward(&x, sequence_id)?;
69            if let Some(ref mut hs) = hidden_states {
70                hs.push(x.clone());
71            }
72        }
73
74        let x = self.norm.forward(&x)?;
75        Ok((x, hidden_states))
76    }
77
78    // pub fn new(
79    //     d_model: i64,
80    //     n_heads: i64,
81    //     v_heads: Option<i64>,
82    //     n_layers: i64,
83    //     n_layers_geom: i64,
84    //     scale_residue: bool,
85    //     mask_and_zero_frameless: bool,
86    //     bias: bool,
87    //     qk_layernorm: bool,
88    //     ffn_type: &str,
89    //     expansion_ratio: f64,
90    // ) -> Result<Self> {
91    //     let mut blocks = Vec::with_capacity(n_layers as usize);
92    //     for i in 0..n_layers {
93    //         blocks.push(UnifiedTransformerBlock::new(
94    //             d_model,
95    //             n_heads,
96    //             v_heads,
97    //             i < n_layers_geom,
98    //             if scale_residue {
99    //                 (n_layers as f64 / 36.0).sqrt()
100    //             } else {
101    //                 1.0
102    //             },
103    //             expansion_ratio,
104    //             mask_and_zero_frameless,
105    //             bias,
106    //             qk_layernorm,
107    //             ffn_type,
108    //         )?);
109    //     }
110
111    //     let norm = nn::LayerNorm::new(d_model, 1e-5, false)?;
112
113    //     Ok(Self { blocks, norm })
114    // }
115
116    // pub fn forward(
117    //     &self,
118    //     x: &Tensor,
119    //     sequence_id: Option<&Tensor>,
120    //     affine: Option<&Affine3D>,
121    //     affine_mask: Option<&Tensor>,
122    //     chain_id: Option<&Tensor>,
123    // ) -> Result<(Tensor, Tensor)> {
124    //     let mut x = x.clone();
125
126    //     let chain_id = if chain_id.is_none() {
127    //         let batch_dims = x.shape().split_last().unwrap().1;
128    //         Tensor::ones(batch_dims, (x.device(), DType::I64))?
129    //     } else {
130    //         chain_id.unwrap().clone()
131    //     };
132
133    //     for block in self.blocks.iter() {
134    //         x = block.forward(&x, sequence_id, affine, affine_mask, &chain_id)?;
135    //     }
136
137    //     let normalized = self.norm.forward(&x)?;
138    //     Ok((normalized, x))
139    // }
140}