ferritin_plms/esm/layers/
blocks.rs

use super::attention::MultiHeadAttention;
use super::geom_attention::GeometricReasoningOriginalImpl;
use crate::esm::models::esmc::{ESMCConfig, FfnType};
// use crate::esm::utils::structure::affine3d::Affine3D;
use candle_core::{D, Module, Result, Tensor};
use candle_nn::{self as nn, VarBuilder};

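/// SwiGLU feed-forward network: LayerNorm, an up-projection to `2 * hidden_dim`,
/// a SiLU-gated elementwise product, and a down-projection back to `d_model`.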
pub struct SwiGLU {
    layer_norm: nn::LayerNorm,
    linear1: nn::Linear,
    linear2: nn::Linear,
}

impl SwiGLU {
    fn swiglu_correction_fn(expansion_ratio: f64, d_model: usize) -> usize {
        // Round the expanded hidden dimension up to the nearest multiple of 256.
        ((expansion_ratio * d_model as f64 + 255.0) / 256.0).floor() as usize * 256
    }

    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model,
            expansion_ratio,
            ..
        } = config;

        let hidden_dim = Self::swiglu_correction_fn(*expansion_ratio, *d_model);

        // Weight names "0", "1" and "3" correspond to the LayerNorm, up-projection
        // and down-projection of the original sequential FFN (the SwiGLU activation
        // at index 2 has no weights).
        Ok(Self {
            layer_norm: nn::layer_norm(*d_model, 1e-5, vb.pp("0"))?,
            linear1: nn::linear_no_bias(*d_model, hidden_dim * 2, vb.pp("1"))?,
            linear2: nn::linear_no_bias(hidden_dim, *d_model, vb.pp("3"))?,
        })
    }
}

impl Module for SwiGLU {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.layer_norm.forward(x)?;
        // Up-project to 2 * hidden_dim, then split into gate and value halves.
        let x = self.linear1.forward(&x)?;
        let chunks = x.chunk(2, D::Minus1)?;
        let x1 = &chunks[0];
        let x2 = &chunks[1];
        // SwiGLU gating: silu(x1) * x2, followed by the down-projection.
        let x = (x1.silu()? * x2)?;
        self.linear2.forward(&x)
    }
}

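/// A transformer block combining optional multi-head self-attention, optional
/// geometric attention, and a SwiGLU feed-forward network; each residual
/// branch is divided by `scaling_factor` before being added back.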
pub struct UnifiedTransformerBlock {
    use_plain_attn: bool,
    attn: Option<MultiHeadAttention>,
    use_geom_attn: bool,
    geom_attn: Option<GeometricReasoningOriginalImpl>,
    ffn: SwiGLU,
    scaling_factor: f64,
}

impl UnifiedTransformerBlock {
    // The hand-rolled constructor below is kept for reference only; blocks are
    // built from pretrained weights via `load`.
    //
    // Parameters of the disabled `new`:
    // - d_model: the dimensionality of the input and output features
    // - n_heads: the number of attention heads
    // - use_geom_attn: whether to use geometric attention
    // - use_plain_attn: whether to use plain attention
    // - v_heads: number of heads for geometric attention
    // pub fn new(
    //     d_model: i64,
    //     n_heads: i64,
    //     use_geom_attn: bool,
    //     use_plain_attn: bool,
    //     v_heads: Option<i64>,
    //     bias: bool,
    //     expansion_ratio: f64,
    //     residue_scaling_factor: f64,
    //     mask_and_zero_frameless: bool,
    //     qk_layernorm: bool,
    //     ffn_type: &str,
    // ) -> Result<Self> {
    //     let attn = if use_plain_attn {
    //         Some(MultiHeadAttention::new(
    //             d_model,
    //             n_heads,
    //             bias,
    //             qk_layernorm,
    //         )?)
    //     } else {
    //         None
    //     };

    //     let geom_attn = if use_geom_attn {
    //         match v_heads {
    //             Some(vh) => Some(GeometricReasoningOriginalImpl::new(
    //                 d_model,
    //                 vh,
    //                 bias,
    //                 mask_and_zero_frameless,
    //             )?),
    //             None => {
    //                 return Err(candle_core::Error::Msg(
    //                     "v_heads must be specified when use_geom_attn is True".into(),
    //                 ))
    //             }
    //         }
    //     } else {
    //         None
    //     };

    //     let ffn = match ffn_type {
    //         "swiglu" => swiglu_ln_ffn(d_model, expansion_ratio, bias)?,
    //         "gelu" => gelu_ln_ffn(d_model, expansion_ratio, bias)?,
    //         _ => {
    //             return Err(candle_core::Error::Msg(format!(
    //                 "Unknown ffn_type: {}",
    //                 ffn_type
    //             )))
    //         }
    //     };

    //     Ok(Self {
    //         use_plain_attn,
    //         attn,
    //         use_geom_attn,
    //         geom_attn,
    //         ffn,
    //         scaling_factor: residue_scaling_factor,
    //     })
    // }
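    /// Loads a `UnifiedTransformerBlock` from pretrained weights rooted at `vb`.
    ///
    /// `layer` is the index of this block within the stack; geometric attention
    /// is flagged for the first `n_layers_geom` layers, although its weights are
    /// not loaded yet (see below).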
    pub fn load(vb: VarBuilder, config: &ESMCConfig, layer: usize) -> Result<Self> {
        let ESMCConfig {
            ffn_type,
            use_plain_attn,
            n_layers_geom,
            residue_scaling_factor,
            ..
        } = config;

        let attn = if *use_plain_attn {
            Some(MultiHeadAttention::load(vb.pp("attn"), config)?)
        } else {
            None
        };

        // Geometric attention is only used in the first `n_layers_geom` layers.
        let use_geom_attn: bool = layer < *n_layers_geom;
        // Loading of the geometric attention weights is currently disabled, so
        // `geom_attn` stays `None`; the loader is kept below for reference.
        // let geom_attn = match use_geom_attn {
        //     false => None,
        //     true => Some(GeometricReasoningOriginalImpl::load(
        //         vb.pp("geometric"),
        //         config,
        //     )?),
        // };
        let geom_attn = None;

        let ffn = match ffn_type {
            FfnType::SWIGLU => SwiGLU::load(vb.pp("ffn"), config)?,
            _ => unimplemented!("only the SwiGLU FFN is supported"),
        };

        Ok(Self {
            use_plain_attn: *use_plain_attn,
            attn,
            use_geom_attn,
            geom_attn,
            ffn,
            scaling_factor: *residue_scaling_factor,
        })
    }
}

// impl Module for UnifiedTransformerBlock {
//     fn forward(&self, x: &Tensor) -> Result<Tensor> {
//         let mut x = x.clone();

//         if self.use_plain_attn {
//             if let Some(attn) = &self.attn {
//                 let r1 = attn.forward(&x)?;
//                 x = (&x + (&r1 / self.scaling_factor)?)?;
//             }
//         }

//         if self.use_geom_attn {
//             if let Some(geom_attn) = &self.geom_attn {
//                 let r2 = geom_attn.forward(&x)?;
//                 x = (&x + (&r2 / self.scaling_factor)?)?;
//             }
//         }

//         let r3 = self.ffn.forward(&x)?;
//         let r3 = (&r3 / self.scaling_factor)?;
//         &x + &r3
//     }
// }
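
// Illustrative check (not part of the reference implementation) of the
// hidden-dimension rounding used by `SwiGLU::load`: the expanded dimension is
// rounded up to the nearest multiple of 256. The (expansion_ratio, d_model)
// pairs below are arbitrary example values, not tied to any particular ESMC
// checkpoint.
#[cfg(test)]
mod tests {
    use super::SwiGLU;

    #[test]
    fn swiglu_correction_rounds_up_to_multiple_of_256() {
        // 8/3 * 960 = 2560, already a multiple of 256, so it is unchanged.
        assert_eq!(SwiGLU::swiglu_correction_fn(8.0 / 3.0, 960), 2560);
        // 4.0 * 100 = 400, rounded up to the next multiple of 256 (512).
        assert_eq!(SwiGLU::swiglu_correction_fn(4.0, 100), 512);
    }
}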