ferritin_plms/esm/layers/blocks.rs
use super::attention::MultiHeadAttention;
use super::geom_attention::GeometricReasoningOriginalImpl;
use crate::esm::models::esmc::{ESMCConfig, FfnType};
// use crate::esm::utils::structure::affine3d::Affine3D;
use candle_core::{D, Module, Result, Tensor};
use candle_nn::{self as nn, VarBuilder};

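/// SwiGLU feed-forward block: LayerNorm -> Linear (to twice the corrected
/// hidden width) -> SiLU-gated product of the two halves -> Linear back to
/// `d_model`, corresponding to the `swiglu_ln_ffn` stack referenced in the
/// commented-out constructor below.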
pub struct SwiGLU {
    layer_norm: nn::LayerNorm,
    linear1: nn::Linear,
    linear2: nn::Linear,
}

impl SwiGLU {
    /// Round the expanded hidden dimension (`expansion_ratio * d_model`) up to
    /// the nearest multiple of 256.
    fn swiglu_correction_fn(expansion_ratio: f64, d_model: usize) -> usize {
        ((expansion_ratio * d_model as f64 + 255.0) / 256.0).floor() as usize * 256
    }

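    /// Builds the block from pretrained weights. The sub-paths "0", "1", and "3"
    /// are assumed to mirror the indices of the original `nn.Sequential`
    /// (LayerNorm, first Linear, SwiGLU activation with no weights, second Linear).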
    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model,
            expansion_ratio,
            ..
        } = config;

        let hidden_dim = Self::swiglu_correction_fn(*expansion_ratio, *d_model);

        Ok(Self {
            layer_norm: nn::layer_norm(*d_model, 1e-5, vb.pp("0"))?,
            linear1: nn::linear_no_bias(*d_model, hidden_dim * 2, vb.pp("1"))?,
            linear2: nn::linear_no_bias(hidden_dim, *d_model, vb.pp("3"))?,
        })
    }
}

impl Module for SwiGLU {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.layer_norm.forward(x)?;
        let x = self.linear1.forward(&x)?;
        // Split the doubled hidden dimension into gate and value halves.
        let chunks = x.chunk(2, D::Minus1)?;
        let x1 = &chunks[0];
        let x2 = &chunks[1];
        // SwiGLU gating: silu(gate) * value.
        let x = (x1.silu()? * x2)?;
        self.linear2.forward(&x)
    }
}

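/// Transformer block pairing standard multi-head self-attention (and, for the
/// first `n_layers_geom` layers, geometric attention) with a SwiGLU
/// feed-forward network; each residual update is divided by `scaling_factor`.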
pub struct UnifiedTransformerBlock {
    use_plain_attn: bool,
    attn: Option<MultiHeadAttention>,
    use_geom_attn: bool,
    geom_attn: Option<GeometricReasoningOriginalImpl>,
    ffn: SwiGLU,
    scaling_factor: f64,
}

impl UnifiedTransformerBlock {
    // Original constructor (kept commented out below): creates a new
    // UnifiedTransformerBlock.
    //
    // Parameters:
    // - d_model: The dimensionality of the input and output features
    // - n_heads: The number of attention heads
    // - use_geom_attn: Whether to use geometric attention
    // - use_plain_attn: Whether to use plain attention
    // - v_heads: Number of heads for geometric attention
    // pub fn new(
    //     d_model: i64,
    //     n_heads: i64,
    //     use_geom_attn: bool,
    //     use_plain_attn: bool,
    //     v_heads: Option<i64>,
    //     bias: bool,
    //     expansion_ratio: f64,
    //     residue_scaling_factor: f64,
    //     mask_and_zero_frameless: bool,
    //     qk_layernorm: bool,
    //     ffn_type: &str,
    // ) -> Result<Self> {
    //     let attn = if use_plain_attn {
    //         Some(MultiHeadAttention::new(
    //             d_model,
    //             n_heads,
    //             bias,
    //             qk_layernorm,
    //         )?)
    //     } else {
    //         None
    //     };

    //     let geom_attn = if use_geom_attn {
    //         match v_heads {
    //             Some(vh) => Some(GeometricReasoningOriginalImpl::new(
    //                 d_model,
    //                 vh,
    //                 bias,
    //                 mask_and_zero_frameless,
    //             )?),
    //             None => {
    //                 return Err(candle_core::Error::Msg(
    //                     "v_heads must be specified when use_geom_attn is True".into(),
    //                 ))
    //             }
    //         }
    //     } else {
    //         None
    //     };

    //     let ffn = match ffn_type {
    //         "swiglu" => swiglu_ln_ffn(d_model, expansion_ratio, bias)?,
    //         "gelu" => gelu_ln_ffn(d_model, expansion_ratio, bias)?,
    //         _ => {
    //             return Err(candle_core::Error::Msg(format!(
    //                 "Unknown ffn_type: {}",
    //                 ffn_type
    //             )))
    //         }
    //     };

    //     Ok(Self {
    //         use_plain_attn,
    //         attn,
    //         use_geom_attn,
    //         geom_attn,
    //         ffn,
    //         scaling_factor: residue_scaling_factor,
    //     })
    // }
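    /// Loads the block for `layer` from pretrained weights. Geometric attention
    /// is only flagged for the first `n_layers_geom` layers, and the sub-paths
    /// "attn", "geometric", and "ffn" are assumed to match the parameter names
    /// of the reference checkpoint.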
    pub fn load(vb: VarBuilder, config: &ESMCConfig, layer: usize) -> Result<Self> {
        let ESMCConfig {
            ffn_type,
            v_head_transformer,
            use_plain_attn,
            n_layers_geom,
            residue_scaling_factor,
            ..
        } = config;

        let attn = match use_plain_attn {
            false => None,
            true => Some(MultiHeadAttention::load(vb.pp("attn"), config)?),
        };

        // println!("LAYER; GEOM: {}, {}", layer, n_layers_geom);
        let use_geom_attn: bool = layer < *n_layers_geom;
        // println!("Geom ATTN {}", use_geom_attn);
        // let geom_attn = match use_geom_attn {
        //     false => None,
        //     true => Some(GeometricReasoningOriginalImpl::load(
        //         vb.pp("geometric"),
        //         config,
        //     )?),
        // };

        // Geometric attention loading is not wired up yet; see the commented
        // loader above.
        let geom_attn = None;

        let ffn = match ffn_type {
            FfnType::SWIGLU => SwiGLU::load(vb.pp("ffn"), config)?,
            _ => unimplemented!(), // FfnType::GLU => unimplemented!(),
        };

        Ok(Self {
            use_plain_attn: *use_plain_attn,
            attn,
            use_geom_attn,
            geom_attn,
            ffn,
            scaling_factor: *residue_scaling_factor,
        })
    }
}

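// Forward pass: residual plain/geometric attention followed by the SwiGLU FFN,
// with each update divided by `scaling_factor`. Kept commented out below,
// presumably until the remaining attention plumbing is in place.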
// impl Module for UnifiedTransformerBlock {
//     fn forward(&self, x: &Tensor) -> Result<Tensor> {
//         let mut x = x.clone();

//         if self.use_plain_attn {
//             if let Some(attn) = &self.attn {
//                 let r1 = attn.forward(&x)?;
//                 x = (&x + &(&r1 / self.scaling_factor)?)?;
//             }
//         }

//         if self.use_geom_attn {
//             if let Some(geom_attn) = &self.geom_attn {
//                 let r2 = geom_attn.forward(&x)?;
//                 x = (&x + &(&r2 / self.scaling_factor)?)?;
//             }
//         }

//         let r3 = self.ffn.forward(&x)?;
//         let r3 = (&r3 / self.scaling_factor)?;
//         &x + &r3
//     }
// }
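
// A minimal sanity check (added as a sketch) for the rounding behaviour of
// `swiglu_correction_fn`; the expected values below were worked out by hand from
// the "round up to the nearest multiple of 256" rule rather than taken from a
// reference checkpoint.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn swiglu_correction_rounds_up_to_multiple_of_256() {
        // 4.0 * 100 = 400, which rounds up to 512.
        assert_eq!(SwiGLU::swiglu_correction_fn(4.0, 100), 512);
        // (8 / 3) * 960 = 2560, already a multiple of 256.
        assert_eq!(SwiGLU::swiglu_correction_fn(8.0 / 3.0, 960), 2560);
    }
}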