ferritin_plms/esm/layers/
transformer_stack.rs

use crate::esm::layers::blocks::UnifiedTransformerBlock;
use crate::esm::models::esmc::ESMCConfig;
// use crate::esm::utils::structure::affine3d::Affine3D;
use candle_core::Result;
use candle_nn::{self as nn, LayerNorm, LayerNormConfig};

/// A stack of transformer blocks used in the ESM-3 model. Each block is a
/// UnifiedTransformerBlock, which can use either geometric attention or standard
/// multi-head attention.
///
/// Args:
///     d_model (i64): The dimensionality of the input and output feature vectors.
///     n_heads (i64): The number of attention heads.
///     v_heads (Option<i64>): The number of voting heads.
///     n_layers (i64): The number of transformer blocks in the stack.
///     n_layers_geom (i64, optional): The number of transformer blocks that use geometric attention.
///     scale_residue (bool, optional): Whether to scale the residual connections in each transformer block.
///     mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless positions in the input.
///         Only applies in the geometric attention blocks, which are conditioned on the structure.
pub struct TransformerStack {
    blocks: Vec<UnifiedTransformerBlock>,
    norm: LayerNorm,
}

impl TransformerStack {
    pub fn load(vb: nn::VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model, n_layers, ..
        } = config;

        // Load each transformer block from its "blocks.{i}" weight prefix.
        let mut blocks = Vec::with_capacity(*n_layers);
        for i in 0..*n_layers {
            blocks.push(UnifiedTransformerBlock::load(
                vb.pp(format!("blocks.{}", i)),
                config,
                i,
            )?);
        }

        // let ln_conf = LayerNormConfig::from(1e-5);
        // Final layer norm over d_model: eps 1e-5, mean removal enabled, affine disabled.
        let ln_conf = LayerNormConfig {
            eps: 1e-5,
            remove_mean: true,
            affine: false,
        };
        let norm = nn::layer_norm(*d_model, ln_conf, vb.pp("norm"))?;

        Ok(Self { blocks, norm })
    }
    // pub fn new(
    //     d_model: i64,
    //     n_heads: i64,
    //     v_heads: Option<i64>,
    //     n_layers: i64,
    //     n_layers_geom: i64,
    //     scale_residue: bool,
    //     mask_and_zero_frameless: bool,
    //     bias: bool,
    //     qk_layernorm: bool,
    //     ffn_type: &str,
    //     expansion_ratio: f64,
    // ) -> Result<Self> {
    //     let mut blocks = Vec::with_capacity(n_layers as usize);
    //     for i in 0..n_layers {
    //         blocks.push(UnifiedTransformerBlock::new(
    //             d_model,
    //             n_heads,
    //             v_heads,
    //             i < n_layers_geom,
    //             if scale_residue {
    //                 (n_layers as f64 / 36.0).sqrt()
    //             } else {
    //                 1.0
    //             },
    //             expansion_ratio,
    //             mask_and_zero_frameless,
    //             bias,
    //             qk_layernorm,
    //             ffn_type,
    //         )?);
    //     }

    //     let norm = nn::LayerNorm::new(d_model, 1e-5, false)?;

    //     Ok(Self { blocks, norm })
    // }

    // pub fn forward(
    //     &self,
    //     x: &Tensor,
    //     sequence_id: Option<&Tensor>,
    //     affine: Option<&Affine3D>,
    //     affine_mask: Option<&Tensor>,
    //     chain_id: Option<&Tensor>,
    // ) -> Result<(Tensor, Tensor)> {
    //     let mut x = x.clone();

    //     let chain_id = if chain_id.is_none() {
    //         let batch_dims = x.shape().split_last().unwrap().1;
    //         Tensor::ones(batch_dims, (x.device(), DType::I64))?
    //     } else {
    //         chain_id.unwrap().clone()
    //     };

    //     for block in self.blocks.iter() {
    //         x = block.forward(&x, sequence_id, affine, affine_mask, &chain_id)?;
    //     }

    //     let normalized = self.norm.forward(&x)?;
    //     Ok((normalized, x))
    // }
}
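
// Usage sketch (not part of the original file): a minimal, hedged example of how
// `TransformerStack::load` might be wired up once a `VarBuilder` has been created
// over the checkpoint tensors. The "transformer" weight prefix below is an
// assumption about the checkpoint layout, not something this module guarantees.
#[allow(dead_code)]
fn load_stack_example(vb: nn::VarBuilder, config: &ESMCConfig) -> Result<TransformerStack> {
    // Scope the builder to the sub-tree that is assumed to hold the stack's
    // weights, then delegate to the loader defined above.
    TransformerStack::load(vb.pp("transformer"), config)
}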