// ferritin_plms/esm/layers/geom_attention.rs

use crate::esm::models::esmc::ESMCConfig;
use candle_core::{Result, Tensor};
use candle_nn::{self as nn, LayerNorm, LayerNormConfig, Linear, VarBuilder};

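/// Geometric attention over per-residue rigid frames, ported from ESM's
/// `GeometricReasoningOriginalImpl`. `c_s` is the model width (`d_model`),
/// `v_heads` is the number of vector heads, and the two per-head scale
/// tensors weight the distance and rotation attention terms (see the
/// commented-out `forward` below).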
pub struct GeometricReasoningOriginalImpl {
    c_s: usize,
    v_heads: usize,
    num_vector_messages: usize,
    mask_and_zero_frameless: bool,
    s_norm: LayerNorm,
    proj: Linear,
    out_proj: Linear,
    distance_scale_per_head: Tensor,
    rotation_scale_per_head: Tensor,
}

impl GeometricReasoningOriginalImpl {
    // Commented-out constructor kept for reference: it predates the
    // `VarBuilder`-based `load` below and does not match candle's
    // constructor signatures (candle's `LayerNorm::new` / `Linear::new`
    // take weight tensors, not dimensions).
    // pub fn new(
    //     c_s: i64,
    //     v_heads: i64,
    //     num_vector_messages: i64,
    //     mask_and_zero_frameless: bool,
    //     _divide_residual_by_depth: bool,
    //     bias: bool,
    //     device: &Device,
    // ) -> Result<Self> {
    //     let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
    //     let channels_out = v_heads * 3 * num_vector_messages;

    //     Ok(Self {
    //         c_s,
    //         v_heads,
    //         num_vector_messages,
    //         mask_and_zero_frameless,
    //         s_norm: LayerNorm::new(c_s, bias)?,
    //         proj: Linear::new(c_s, dim_proj, bias)?,
    //         out_proj: Linear::new(channels_out, c_s, bias)?,
    //         distance_scale_per_head: Tensor::zeros((v_heads,), device)?,
    //         rotation_scale_per_head: Tensor::zeros((v_heads,), device)?,
    //     })
    // }
    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model,
            v_head_transformer,
            mask_and_zero_frameless,
            ..
        } = config;

        let num_vector_messages = 1usize;

        // TODO: `v_heads` is a hidden parameter; it should be surfaced in
        // `ESMCConfig` rather than defaulted here.
        let v_heads = v_head_transformer.unwrap_or(128);

        // The projection packs, along the last axis: rotation queries/keys
        // (2 * v_heads * 3), distance queries/keys (2 * v_heads * 3), and
        // the vector-message values (v_heads * 3 * num_vector_messages).
        let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
        let channels_out = v_heads * 3 * num_vector_messages;

        let ln_conf = LayerNormConfig::from(1e-5);
        let s_norm = nn::layer_norm(*d_model, ln_conf, vb.pp("layer_norm"))?;

        let proj = nn::linear(*d_model, dim_proj, vb.pp("linear1"))?;
        let out_proj = nn::linear(channels_out, *d_model, vb.pp("outproj"))?;
        // Note: these per-head scales are created as zeros rather than read
        // from `vb`; if the checkpoint stores them, they should be fetched
        // with `vb.get(..)` instead.
        let distance_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;
        let rotation_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;

        Ok(Self {
            c_s: *d_model,
            v_heads,
            num_vector_messages,
            mask_and_zero_frameless: *mask_and_zero_frameless,
            s_norm,
            proj,
            out_proj,
            distance_scale_per_head,
            rotation_scale_per_head,
        })
    }
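
    // A minimal loading sketch (hypothetical helper, not part of the
    // original module), assuming an `ESMCConfig` value is already in hand.
    // `VarBuilder::zeros` yields zero tensors for any requested name/shape,
    // which is handy for shape-checking the layer without a checkpoint.
    #[allow(dead_code)]
    fn zeros_for_shape_check(config: &ESMCConfig) -> Result<Self> {
        let device = candle_core::Device::Cpu;
        let vb = VarBuilder::zeros(candle_core::DType::F32, &device);
        Self::load(vb, config)
    }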

    // The forward pass below is retained from the port but commented out:
    // it depends on an `Affine` rigid-frame type, an einops-style
    // `rearrange` helper, and a `SQRT_3` constant, none of which are
    // available in this crate yet.
    // pub fn forward(
    //     &self,
    //     s: &Tensor,
    //     affine: &Affine,
    //     affine_mask: &Tensor,
    //     sequence_id: Option<&Tensor>,
    //     chain_id: &Tensor,
    // ) -> Result<Tensor> {
    //     let sequence_id = match sequence_id {
    //         Some(sid) => sid.clone(),
    //         None => Tensor::zeros_like(&s.slice(s.dims()? - 1, 0, 1)?)?,
    //     };

    //     // Positions attend only within the same sequence and chain;
    //     // frameless positions are masked out entirely.
    //     let attn_bias = sequence_id.unsqueeze(-1)?.eq(&sequence_id.unsqueeze(-2)?)?;
    //     let attn_bias = attn_bias.unsqueeze(1)?.to_dtype(s.dtype())?;
    //     let attn_bias = attn_bias.masked_fill(
    //         &affine_mask.broadcast_left(3)?.logical_not()?,
    //         f32::NEG_INFINITY,
    //     )?;

    //     let chain_id_mask = chain_id.unsqueeze(1)?.ne(&chain_id.unsqueeze(2)?)?;
    //     let attn_bias = attn_bias.masked_fill(&chain_id_mask.unsqueeze(1)?, f32::NEG_INFINITY)?;

    //     let ns = self.s_norm.forward(s)?;
    //     let proj_out = self.proj.forward(&ns)?;

    //     // Split the packed projection into rotation/value vectors and
    //     // distance vectors (see `dim_proj` in `load`).
    //     let (vec_rot, vec_dist) = proj_out.split_at(
    //         -1,
    //         &[
    //             self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
    //             self.v_heads * 2 * 3,
    //         ],
    //     )?;

    //     let vec_rot = rearrange(&vec_rot, "... (h c) -> ... h c", &[("c", 3)])?;
    //     let rot_out = affine.rot.broadcast_right(1)?.apply(&vec_rot)?;
    //     let (query_rot, key_rot, value) = rot_out.split_at(
    //         -2,
    //         &[
    //             self.v_heads,
    //             self.v_heads,
    //             self.v_heads * self.num_vector_messages,
    //         ],
    //     )?;

    //     let vec_dist = rearrange(&vec_dist, "... (h c) -> ... h c", &[("c", 3)])?;
    //     let (query_dist, key_dist) = affine.broadcast_right(1)?.apply(&vec_dist)?.chunk(2, -2)?;

    //     let query_dist = rearrange(&query_dist, "b s h d -> b h s 1 d", &[])?;
    //     let key_dist = rearrange(&key_dist, "b s h d -> b h 1 s d", &[])?;
    //     let query_rot = rearrange(&query_rot, "b s h d -> b h s d", &[])?;
    //     let key_rot = rearrange(&key_rot, "b s h d -> b h d s", &[])?;
    //     let value = rearrange(
    //         &value,
    //         "b s (h m) d -> b h s (m d)",
    //         &[("m", self.num_vector_messages)],
    //     )?;

    //     // Logits combine a pairwise-distance penalty with a rotation
    //     // similarity, each weighted per head through a softplus.
    //     let distance_term = query_dist.sub(&key_dist)?.norm_dim(-1, true)?.div(SQRT_3)?;
    //     let rotation_term = query_rot.matmul(&key_rot)?.div(SQRT_3)?;

    //     let distance_term_weight =
    //         rearrange(&self.distance_scale_per_head.softplus()?, "h -> h 1 1", &[])?;
    //     let rotation_term_weight =
    //         rearrange(&self.rotation_scale_per_head.softplus()?, "h -> h 1 1", &[])?;

    //     let mut attn_weight = rotation_term
    //         .mul(&rotation_term_weight)?
    //         .sub(&distance_term.mul(&distance_term_weight)?)?;

    //     // `attn_bias` is always built above, so apply it directly (the
    //     // original guarded on an Option here).
    //     let s_q = attn_weight.size(2)?;
    //     let s_k = attn_weight.size(3)?;
    //     let _s_q = (attn_bias.size(2)? - s_q).max(0);
    //     let _s_k = (attn_bias.size(3)? - s_k).max(0);
    //     let attn_bias = attn_bias.slice(_s_q..attn_bias.size(2)?, _s_k..attn_bias.size(3)?)?;
    //     attn_weight = attn_weight.add(&attn_bias)?;

    //     let attn_weight = attn_weight.softmax(-1)?;
    //     let mut attn_out = attn_weight.matmul(&value)?;

    //     // Rotate the messages back out of the local frames and flatten
    //     // the head/message axes for the output projection.
    //     attn_out = affine.rot.broadcast_right(1)?.invert()?.apply(&rearrange(
    //         &attn_out,
    //         "b h s (m d) -> b s (h m) d",
    //         &[("m", self.num_vector_messages)],
    //     )?)?;

    //     let mut attn_out = rearrange(
    //         &attn_out,
    //         "b s (h m) d -> b s (h m d)",
    //         &[("m", self.num_vector_messages)],
    //     )?;

    //     if self.mask_and_zero_frameless {
    //         attn_out =
    //             attn_out.masked_fill(&affine_mask.broadcast_right(1)?.logical_not()?, 0.0)?;
    //     }

    //     self.out_proj.forward(&attn_out)
    // }
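
    // A minimal candle sketch of the logit combination in the commented-out
    // `forward` above: a per-head softplus-weighted rotation similarity
    // minus a softplus-weighted pairwise distance, each scaled by 1/sqrt(3).
    // Shapes are assumed as produced by the rearranges above: `query_dist`
    // is (b, h, s, 1, 3), `key_dist` (b, h, 1, s, 3), `query_rot`
    // (b, h, s, 3), and `key_rot` (b, h, 3, s). Illustration only, not the
    // original implementation.
    #[allow(dead_code)]
    fn attn_logits_sketch(
        &self,
        query_dist: &Tensor,
        key_dist: &Tensor,
        query_rot: &Tensor,
        key_rot: &Tensor,
    ) -> Result<Tensor> {
        let inv_sqrt3 = 1.0 / 3f64.sqrt();
        // ||q_i - k_j|| over the xyz axis -> (b, h, s, s)
        let distance_term = query_dist
            .broadcast_sub(key_dist)?
            .sqr()?
            .sum(candle_core::D::Minus1)?
            .sqrt()?
            .affine(inv_sqrt3, 0.0)?;
        // <q_i, k_j> rotation similarity -> (b, h, s, s)
        let rotation_term = query_rot.matmul(key_rot)?.affine(inv_sqrt3, 0.0)?;
        // softplus(x) = ln(1 + e^x); candle has no built-in softplus op.
        let softplus = |t: &Tensor| -> Result<Tensor> { t.exp()?.affine(1.0, 1.0)?.log() };
        let w_dist = softplus(&self.distance_scale_per_head)?.reshape((self.v_heads, 1, 1))?;
        let w_rot = softplus(&self.rotation_scale_per_head)?.reshape((self.v_heads, 1, 1))?;
        rotation_term
            .broadcast_mul(&w_rot)?
            .broadcast_sub(&distance_term.broadcast_mul(&w_dist)?)
    }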
}