Skip to main content

ferritin_plms/esmc/layers/
geom_attention.rs

1use crate::esmc::models::esmc::ESMCConfig;
2use candle_core::{Result, Tensor};
3use candle_nn::{self as nn, LayerNorm, LayerNormConfig, Linear, VarBuilder};
4
5#[allow(dead_code)]
6pub struct GeometricReasoningOriginalImpl {
7    c_s: usize,
8    v_heads: usize,
9    num_vector_messages: usize,
10    mask_and_zero_frameless: bool,
11    s_norm: LayerNorm,
12    proj: Linear,
13    out_proj: Linear,
14    distance_scale_per_head: Tensor,
15    rotation_scale_per_head: Tensor,
16}
17
18impl GeometricReasoningOriginalImpl {
19    // pub fn new(
20    //     c_s: i64,
21    //     v_heads: i64,
22    //     num_vector_messages: i64,
23    //     mask_and_zero_frameless: bool,
24    //     _divide_residual_by_depth: bool,
25    //     bias: bool,
26    //     device: &Device,
27    // ) -> Result<Self> {
28    //     let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
29    //     let channels_out = v_heads * 3 * num_vector_messages;
30
31    //     Ok(Self {
32    //         c_s,
33    //         v_heads,
34    //         num_vector_messages,
35    //         mask_and_zero_frameless,
36    //         s_norm: LayerNorm::new(c_s, bias)?,
37    //         proj: Linear::new(c_s, dim_proj, bias)?,
38    //         out_proj: Linear::new(channels_out, c_s, bias)?,
39    //         distance_scale_per_head: Tensor::zeros((v_heads,), device)?,
40    //         rotation_scale_per_head: Tensor::zeros((v_heads,), device)?,
41    //     })
42    // }
43    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
44        let ESMCConfig {
45            d_model,
46            v_head_transformer,
47            mask_and_zero_frameless,
48            ..
49        } = config;
50
51        let num_vector_messages = 1usize;
52
53        // todo: this is a hidden param. Needs to be fixed
54        let v_heads = v_head_transformer.unwrap_or(128);
55
56        let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
57        let channels_out = v_heads * 3 * num_vector_messages;
58
59        let ln_conf = LayerNormConfig::from(1e-5);
60        let s_norm = nn::layer_norm(*d_model, ln_conf, vb.pp("layer_norm"))?;
61
62        let proj = nn::linear(*d_model, dim_proj, vb.pp("linear1"))?;
63        let out_proj = nn::linear(channels_out, *d_model, vb.pp("outproj"))?;
64        let distance_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;
65        let rotation_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;
66
67        Ok(Self {
68            c_s: *d_model,
69            v_heads,
70            num_vector_messages,
71            mask_and_zero_frameless: *mask_and_zero_frameless,
72            s_norm,
73            proj,
74            out_proj,
75            distance_scale_per_head,
76            rotation_scale_per_head,
77        })
78    }
79
80    // pub fn forward(
81    //     &self,
82    //     s: &Tensor,
83    //     affine: &Affine,
84    //     affine_mask: &Tensor,
85    //     sequence_id: Option<&Tensor>,
86    //     chain_id: &Tensor,
87    // ) -> Result<Tensor> {
88    //     let sequence_id = match sequence_id {
89    //         Some(sid) => sid.clone(),
90    //         None => Tensor::zeros_like(&s.slice(s.dims()? - 1, 0, 1)?)?,
91    //     };
92
93    //     let attn_bias = sequence_id.unsqueeze(-1)?.eq(&sequence_id.unsqueeze(-2)?)?;
94    //     let attn_bias = attn_bias.unsqueeze(1)?.to_dtype(s.dtype())?;
95    //     let attn_bias = attn_bias.masked_fill(
96    //         &affine_mask.broadcast_left(3)?.logical_not()?,
97    //         f32::NEG_INFINITY,
98    //     )?;
99
100    //     let chain_id_mask = chain_id.unsqueeze(1)?.ne(&chain_id.unsqueeze(2)?)?;
101    //     let attn_bias = attn_bias.masked_fill(&chain_id_mask.unsqueeze(1)?, f32::NEG_INFINITY)?;
102
103    //     let ns = self.s_norm.forward(s)?;
104    //     let proj_out = self.proj.forward(&ns)?;
105
106    //     let (vec_rot, vec_dist) = proj_out.split_at(
107    //         -1,
108    //         &[
109    //             self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
110    //             self.v_heads * 2 * 3,
111    //         ],
112    //     )?;
113
114    //     let vec_rot = rearrange(&vec_rot, "... (h c) -> ... h c", &[("c", 3)])?;
115    //     let rot_out = affine.rot.broadcast_right(1)?.apply(&vec_rot)?;
116    //     let (query_rot, key_rot, value) = rot_out.split_at(
117    //         -2,
118    //         &[
119    //             self.v_heads,
120    //             self.v_heads,
121    //             self.v_heads * self.num_vector_messages,
122    //         ],
123    //     )?;
124
125    //     let vec_dist = rearrange(&vec_dist, "... (h c) -> ... h c", &[("c", 3)])?;
126    //     let (query_dist, key_dist) = affine.broadcast_right(1)?.apply(&vec_dist)?.chunk(2, -2)?;
127
128    //     let query_dist = rearrange(&query_dist, "b s h d -> b h s 1 d", &[])?;
129    //     let key_dist = rearrange(&key_dist, "b s h d -> b h 1 s d", &[])?;
130    //     let query_rot = rearrange(&query_rot, "b s h d -> b h s d", &[])?;
131    //     let key_rot = rearrange(&key_rot, "b s h d -> b h d s", &[])?;
132    //     let value = rearrange(
133    //         &value,
134    //         "b s (h m) d -> b h s (m d)",
135    //         &[("m", self.num_vector_messages)],
136    //     )?;
137
138    //     let distance_term = query_dist.sub(&key_dist)?.norm_dim(-1, true)?.div(SQRT_3)?;
139    //     let rotation_term = query_rot.matmul(&key_rot)?.div(SQRT_3)?;
140
141    //     let distance_term_weight =
142    //         rearrange(&self.distance_scale_per_head.softplus()?, "h -> h 1 1", &[])?;
143    //     let rotation_term_weight =
144    //         rearrange(&self.rotation_scale_per_head.softplus()?, "h -> h 1 1", &[])?;
145
146    //     let mut attn_weight = rotation_term
147    //         .mul(&rotation_term_weight)?
148    //         .sub(&distance_term.mul(&distance_term_weight)?)?;
149
150    //     if let Some(bias) = attn_bias {
151    //         let s_q = attn_weight.size(2)?;
152    //         let s_k = attn_weight.size(3)?;
153    //         let _s_q = (bias.size(2)? - s_q).max(0);
154    //         let _s_k = (bias.size(3)? - s_k).max(0);
155    //         let bias = bias.slice(_s_q..bias.size(2)?, _s_k..bias.size(3)?)?;
156    //         attn_weight = attn_weight.add(&bias)?;
157    //     }
158
159    //     let attn_weight = attn_weight.softmax(-1)?;
160    //     let mut attn_out = attn_weight.matmul(&value)?;
161
162    //     attn_out = affine.rot.broadcast_right(1)?.invert()?.apply(&rearrange(
163    //         &attn_out,
164    //         "b h s (m d) -> b s (h m) d",
165    //         &[("m", self.num_vector_messages)],
166    //     )?)?;
167
168    //     let mut attn_out = rearrange(
169    //         &attn_out,
170    //         "b s (h m) d -> b s (h m d)",
171    //         &[("m", self.num_vector_messages)],
172    //     )?;
173
174    //     if self.mask_and_zero_frameless {
175    //         attn_out =
176    //             attn_out.masked_fill(&affine_mask.broadcast_right(1)?.logical_not()?, 0.0)?;
177    //     }
178
179    //     self.out_proj.forward(&attn_out)
180    // }
181}