ferritin_plms/esm/layers/transformer_stack.rs
use crate::esm::layers::blocks::UnifiedTransformerBlock;
use crate::esm::models::esmc::ESMCConfig;
// use crate::esm::utils::structure::affine3d::Affine3D;
use candle_core::Result;
use candle_nn::{self as nn, LayerNorm, LayerNormConfig};

pub struct TransformerStack {
    /*
    A stack of transformer blocks used in the ESM-3 model. Each block is a UnifiedTransformerBlock,
    which can use either geometric attention or standard multi-head attention.

    Args:
        d_model (i64): The dimensionality of the input and output feature vectors.
        n_heads (i64): The number of attention heads.
        v_heads (Option<i64>): The number of voting heads.
        n_layers (i64): The number of transformer blocks in the stack.
        n_layers_geom (i64, optional): The number of transformer blocks that use geometric attention.
        scale_residue (bool, optional): Whether to scale the residual connections in each transformer block.
        mask_and_zero_frameless (bool, optional): Whether to mask and zero frameless positions in the input.
            Only applies in the geometric attention blocks, which are conditioned on the structure.
    */
    blocks: Vec<UnifiedTransformerBlock>,
    norm: LayerNorm,
}

impl TransformerStack {
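    // Builds the stack from pretrained weights: one UnifiedTransformerBlock per layer
    // (loaded from "blocks.{i}") plus the final LayerNorm (loaded from "norm").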
    pub fn load(vb: nn::VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model, n_layers, ..
        } = config;

        let mut blocks = Vec::with_capacity(*n_layers);
        for i in 0..*n_layers {
            blocks.push(UnifiedTransformerBlock::load(
                vb.pp(format!("blocks.{}", i)),
                config,
                i,
            )?);
        }

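        // Final layer norm applied to the stack's output (eps = 1e-5); see the
        // (currently commented-out) forward pass below.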
        // let ln_conf = LayerNormConfig::from(1e-5);
        let ln_conf = LayerNormConfig {
            eps: 1e-5,
            remove_mean: true,
            affine: false,
        };
        let norm = nn::layer_norm(*d_model, ln_conf, vb.pp("norm"))?;

        Ok(Self { blocks, norm })
    }
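
    // The commented-out constructor below builds the stack from explicit
    // hyperparameters rather than an ESMCConfig; it is kept for reference.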
    // pub fn new(
    //     d_model: i64,
    //     n_heads: i64,
    //     v_heads: Option<i64>,
    //     n_layers: i64,
    //     n_layers_geom: i64,
    //     scale_residue: bool,
    //     mask_and_zero_frameless: bool,
    //     bias: bool,
    //     qk_layernorm: bool,
    //     ffn_type: &str,
    //     expansion_ratio: f64,
    // ) -> Result<Self> {
    //     let mut blocks = Vec::with_capacity(n_layers as usize);
    //     for i in 0..n_layers {
    //         blocks.push(UnifiedTransformerBlock::new(
    //             d_model,
    //             n_heads,
    //             v_heads,
    //             i < n_layers_geom,
    //             if scale_residue {
    //                 (n_layers as f64 / 36.0).sqrt()
    //             } else {
    //                 1.0
    //             },
    //             expansion_ratio,
    //             mask_and_zero_frameless,
    //             bias,
    //             qk_layernorm,
    //             ffn_type,
    //         )?);
    //     }

    //     let norm = nn::LayerNorm::new(d_model, 1e-5, false)?;

    //     Ok(Self { blocks, norm })
    // }

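    // Reference forward pass (not yet enabled): runs every block in order, then applies
    // the final norm and returns (post-norm output, pre-norm hidden states). Enabling it
    // would also need `candle_core::{Tensor, DType}`, `candle_nn::Module`, and the
    // Affine3D import above.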
    // pub fn forward(
    //     &self,
    //     x: &Tensor,
    //     sequence_id: Option<&Tensor>,
    //     affine: Option<&Affine3D>,
    //     affine_mask: Option<&Tensor>,
    //     chain_id: Option<&Tensor>,
    // ) -> Result<(Tensor, Tensor)> {
    //     let mut x = x.clone();

    //     // Default chain_id to ones over the batch dimensions when not provided.
    //     let chain_id = if chain_id.is_none() {
    //         let batch_dims = x.dims().split_last().unwrap().1;
    //         Tensor::ones(batch_dims, DType::I64, x.device())?
    //     } else {
    //         chain_id.unwrap().clone()
    //     };

    //     for block in self.blocks.iter() {
    //         x = block.forward(&x, sequence_id, affine, affine_mask, &chain_id)?;
    //     }

    //     // Return both the post-norm output and the pre-norm hidden states.
    //     let normalized = self.norm.forward(&x)?;
    //     Ok((normalized, x))
    // }
}