ferritin_plms/esmc/layers/geom_attention.rs
1use crate::esmc::models::esmc::ESMCConfig;
2use candle_core::{Result, Tensor};
3use candle_nn::{self as nn, LayerNorm, LayerNormConfig, Linear, VarBuilder};
4
5#[allow(dead_code)]
6pub struct GeometricReasoningOriginalImpl {
7 c_s: usize,
8 v_heads: usize,
9 num_vector_messages: usize,
10 mask_and_zero_frameless: bool,
11 s_norm: LayerNorm,
12 proj: Linear,
13 out_proj: Linear,
14 distance_scale_per_head: Tensor,
15 rotation_scale_per_head: Tensor,
16}
17
18impl GeometricReasoningOriginalImpl {
19 // pub fn new(
20 // c_s: i64,
21 // v_heads: i64,
22 // num_vector_messages: i64,
23 // mask_and_zero_frameless: bool,
24 // _divide_residual_by_depth: bool,
25 // bias: bool,
26 // device: &Device,
27 // ) -> Result<Self> {
28 // let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
29 // let channels_out = v_heads * 3 * num_vector_messages;
30
31 // Ok(Self {
32 // c_s,
33 // v_heads,
34 // num_vector_messages,
35 // mask_and_zero_frameless,
36 // s_norm: LayerNorm::new(c_s, bias)?,
37 // proj: Linear::new(c_s, dim_proj, bias)?,
38 // out_proj: Linear::new(channels_out, c_s, bias)?,
39 // distance_scale_per_head: Tensor::zeros((v_heads,), device)?,
40 // rotation_scale_per_head: Tensor::zeros((v_heads,), device)?,
41 // })
42 // }
43 pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
44 let ESMCConfig {
45 d_model,
46 v_head_transformer,
47 mask_and_zero_frameless,
48 ..
49 } = config;
50
51 let num_vector_messages = 1usize;
52
53 // todo: this is a hidden param. Needs to be fixed
54 let v_heads = v_head_transformer.unwrap_or(128);
55
56 let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
57 let channels_out = v_heads * 3 * num_vector_messages;
58
59 let ln_conf = LayerNormConfig::from(1e-5);
60 let s_norm = nn::layer_norm(*d_model, ln_conf, vb.pp("layer_norm"))?;
61
62 let proj = nn::linear(*d_model, dim_proj, vb.pp("linear1"))?;
63 let out_proj = nn::linear(channels_out, *d_model, vb.pp("outproj"))?;
64 let distance_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;
65 let rotation_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;
66
67 Ok(Self {
68 c_s: *d_model,
69 v_heads,
70 num_vector_messages,
71 mask_and_zero_frameless: *mask_and_zero_frameless,
72 s_norm,
73 proj,
74 out_proj,
75 distance_scale_per_head,
76 rotation_scale_per_head,
77 })
78 }
79
80 // pub fn forward(
81 // &self,
82 // s: &Tensor,
83 // affine: &Affine,
84 // affine_mask: &Tensor,
85 // sequence_id: Option<&Tensor>,
86 // chain_id: &Tensor,
87 // ) -> Result<Tensor> {
88 // let sequence_id = match sequence_id {
89 // Some(sid) => sid.clone(),
90 // None => Tensor::zeros_like(&s.slice(s.dims()? - 1, 0, 1)?)?,
91 // };
92
93 // let attn_bias = sequence_id.unsqueeze(-1)?.eq(&sequence_id.unsqueeze(-2)?)?;
94 // let attn_bias = attn_bias.unsqueeze(1)?.to_dtype(s.dtype())?;
95 // let attn_bias = attn_bias.masked_fill(
96 // &affine_mask.broadcast_left(3)?.logical_not()?,
97 // f32::NEG_INFINITY,
98 // )?;
99
100 // let chain_id_mask = chain_id.unsqueeze(1)?.ne(&chain_id.unsqueeze(2)?)?;
101 // let attn_bias = attn_bias.masked_fill(&chain_id_mask.unsqueeze(1)?, f32::NEG_INFINITY)?;
102
103 // let ns = self.s_norm.forward(s)?;
104 // let proj_out = self.proj.forward(&ns)?;
105
106 // let (vec_rot, vec_dist) = proj_out.split_at(
107 // -1,
108 // &[
109 // self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
110 // self.v_heads * 2 * 3,
111 // ],
112 // )?;
113
114 // let vec_rot = rearrange(&vec_rot, "... (h c) -> ... h c", &[("c", 3)])?;
115 // let rot_out = affine.rot.broadcast_right(1)?.apply(&vec_rot)?;
116 // let (query_rot, key_rot, value) = rot_out.split_at(
117 // -2,
118 // &[
119 // self.v_heads,
120 // self.v_heads,
121 // self.v_heads * self.num_vector_messages,
122 // ],
123 // )?;
124
125 // let vec_dist = rearrange(&vec_dist, "... (h c) -> ... h c", &[("c", 3)])?;
126 // let (query_dist, key_dist) = affine.broadcast_right(1)?.apply(&vec_dist)?.chunk(2, -2)?;
127
128 // let query_dist = rearrange(&query_dist, "b s h d -> b h s 1 d", &[])?;
129 // let key_dist = rearrange(&key_dist, "b s h d -> b h 1 s d", &[])?;
130 // let query_rot = rearrange(&query_rot, "b s h d -> b h s d", &[])?;
131 // let key_rot = rearrange(&key_rot, "b s h d -> b h d s", &[])?;
132 // let value = rearrange(
133 // &value,
134 // "b s (h m) d -> b h s (m d)",
135 // &[("m", self.num_vector_messages)],
136 // )?;
137
138 // let distance_term = query_dist.sub(&key_dist)?.norm_dim(-1, true)?.div(SQRT_3)?;
139 // let rotation_term = query_rot.matmul(&key_rot)?.div(SQRT_3)?;
140
141 // let distance_term_weight =
142 // rearrange(&self.distance_scale_per_head.softplus()?, "h -> h 1 1", &[])?;
143 // let rotation_term_weight =
144 // rearrange(&self.rotation_scale_per_head.softplus()?, "h -> h 1 1", &[])?;
145
146 // let mut attn_weight = rotation_term
147 // .mul(&rotation_term_weight)?
148 // .sub(&distance_term.mul(&distance_term_weight)?)?;
149
150 // if let Some(bias) = attn_bias {
151 // let s_q = attn_weight.size(2)?;
152 // let s_k = attn_weight.size(3)?;
153 // let _s_q = (bias.size(2)? - s_q).max(0);
154 // let _s_k = (bias.size(3)? - s_k).max(0);
155 // let bias = bias.slice(_s_q..bias.size(2)?, _s_k..bias.size(3)?)?;
156 // attn_weight = attn_weight.add(&bias)?;
157 // }
158
159 // let attn_weight = attn_weight.softmax(-1)?;
160 // let mut attn_out = attn_weight.matmul(&value)?;
161
162 // attn_out = affine.rot.broadcast_right(1)?.invert()?.apply(&rearrange(
163 // &attn_out,
164 // "b h s (m d) -> b s (h m) d",
165 // &[("m", self.num_vector_messages)],
166 // )?)?;
167
168 // let mut attn_out = rearrange(
169 // &attn_out,
170 // "b s (h m) d -> b s (h m d)",
171 // &[("m", self.num_vector_messages)],
172 // )?;
173
174 // if self.mask_and_zero_frameless {
175 // attn_out =
176 // attn_out.masked_fill(&affine_mask.broadcast_right(1)?.logical_not()?, 0.0)?;
177 // }
178
179 // self.out_proj.forward(&attn_out)
180 // }
181}