ferritin_plms/esm/layers/geom_attention.rs
use crate::esm::models::esmc::ESMCConfig;
use candle_core::{Result, Tensor};
use candle_nn::{self as nn, LayerNorm, LayerNormConfig, Linear, VarBuilder};

/// Geometric attention over per-residue rigid frames, ported from the
/// reference ESM implementation (see the commented-out `forward` below).
pub struct GeometricReasoningOriginalImpl {
    /// Model embedding width (`d_model`).
    c_s: usize,
    /// Number of geometric attention heads.
    v_heads: usize,
    num_vector_messages: usize,
    /// Zero out outputs at positions without a valid frame.
    mask_and_zero_frameless: bool,
    s_norm: LayerNorm,
    proj: Linear,
    out_proj: Linear,
    /// Learned per-head scales, passed through softplus in `forward`.
    distance_scale_per_head: Tensor,
    rotation_scale_per_head: Tensor,
}

impl GeometricReasoningOriginalImpl {
    // Direct port of the reference constructor, kept commented for reference;
    // superseded by `load` below.
    // pub fn new(
    //     c_s: usize,
    //     v_heads: usize,
    //     num_vector_messages: usize,
    //     mask_and_zero_frameless: bool,
    //     _divide_residual_by_depth: bool,
    //     bias: bool,
    //     device: &Device,
    // ) -> Result<Self> {
    //     let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
    //     let channels_out = v_heads * 3 * num_vector_messages;

    //     Ok(Self {
    //         c_s,
    //         v_heads,
    //         num_vector_messages,
    //         mask_and_zero_frameless,
    //         s_norm: LayerNorm::new(c_s, bias)?,
    //         proj: Linear::new(c_s, dim_proj, bias)?,
    //         out_proj: Linear::new(channels_out, c_s, bias)?,
    //         distance_scale_per_head: Tensor::zeros((v_heads,), device)?,
    //         rotation_scale_per_head: Tensor::zeros((v_heads,), device)?,
    //     })
    // }
    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model,
            v_head_transformer,
            mask_and_zero_frameless,
            ..
        } = config;

        // Fixed at 1, matching the reference implementation's default.
        let num_vector_messages = 1usize;

        // TODO: `v_heads` is a hidden hyperparameter; the hard-coded 128
        // fallback should come from the config instead.
        let v_heads = v_head_transformer.unwrap_or(128);

        // linear1 packs, per head: rotation q/k (2 * 3), value messages
        // (3 * num_vector_messages), and distance q/k (2 * 3).
        let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
        let channels_out = v_heads * 3 * num_vector_messages;

        let ln_conf = LayerNormConfig::from(1e-5);
        let s_norm = nn::layer_norm(*d_model, ln_conf, vb.pp("layer_norm"))?;

        let proj = nn::linear(*d_model, dim_proj, vb.pp("linear1"))?;
        let out_proj = nn::linear(channels_out, *d_model, vb.pp("outproj"))?;

        // NOTE: these per-head scales are learned parameters in the reference
        // implementation; zero-init matches the reference initialization, but
        // they should be fetched from `vb` once the checkpoint tensor names
        // are confirmed.
        let distance_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;
        let rotation_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;

        Ok(Self {
            c_s: *d_model,
            v_heads,
            num_vector_messages,
            mask_and_zero_frameless: *mask_and_zero_frameless,
            s_norm,
            proj,
            out_proj,
            distance_scale_per_head,
            rotation_scale_per_head,
        })
    }
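
    // A minimal loading sketch (untested), assuming a fresh `VarMap` supplies
    // freshly initialized weights; constructing an `ESMCConfig` is elided
    // here since its full field set lives in `esm::models::esmc`:
    //
    //     use candle_core::{DType, Device};
    //     use candle_nn::{VarBuilder, VarMap};
    //
    //     let varmap = VarMap::new();
    //     let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    //     let geom = GeometricReasoningOriginalImpl::load(vb, &config)?;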
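
    // The commented-out `forward` below sketches the reference algorithm.
    // Per head `h`, the pre-softmax score between residues `i` and `j` is
    //
    //     softplus(rotation_scale[h]) * <R_i q_rot_i, R_j k_rot_j> / sqrt(3)
    //   - softplus(distance_scale[h]) * ||T_i(q_dist_i) - T_j(k_dist_j)|| / sqrt(3)
    //
    // where `R_i` is residue i's frame rotation and `T_i` its full affine
    // (rotation plus translation), making the score invariant to global
    // rigid motion.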
    // pub fn forward(
    //     &self,
    //     s: &Tensor,
    //     affine: &Affine,
    //     affine_mask: &Tensor,
    //     sequence_id: Option<&Tensor>,
    //     chain_id: &Tensor,
    // ) -> Result<Tensor> {
    //     // Default to a single sequence when no per-token ids are given.
    //     let sequence_id = match sequence_id {
    //         Some(sid) => sid.clone(),
    //         None => s.narrow(s.rank() - 1, 0, 1)?.zeros_like()?,
    //     };

    //     // Tokens attend only within their own sequence; frameless
    //     // positions receive a -inf bias.
    //     let attn_bias = sequence_id.unsqueeze(-1)?.eq(&sequence_id.unsqueeze(-2)?)?;
    //     let attn_bias = attn_bias.unsqueeze(1)?.to_dtype(s.dtype())?;
    //     let attn_bias = attn_bias.masked_fill(
    //         &affine_mask.broadcast_left(3)?.logical_not()?,
    //         f32::NEG_INFINITY,
    //     )?;

    //     // Cross-chain pairs are masked out as well.
    //     let chain_id_mask = chain_id.unsqueeze(1)?.ne(&chain_id.unsqueeze(2)?)?;
    //     let attn_bias = attn_bias.masked_fill(&chain_id_mask.unsqueeze(1)?, f32::NEG_INFINITY)?;

    //     let ns = self.s_norm.forward(s)?;
    //     let proj_out = self.proj.forward(&ns)?;

    //     // Split the projection into rotation-frame q/k/value vectors and
    //     // distance-frame q/k vectors.
    //     let (vec_rot, vec_dist) = proj_out.split_at(
    //         -1,
    //         &[
    //             self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
    //             self.v_heads * 2 * 3,
    //         ],
    //     )?;

    //     // Rotate q/k/value into the global frame (rotation only).
    //     let vec_rot = rearrange(&vec_rot, "... (h c) -> ... h c", &[("c", 3)])?;
    //     let rot_out = affine.rot.broadcast_right(1)?.apply(&vec_rot)?;
    //     let (query_rot, key_rot, value) = rot_out.split_at(
    //         -2,
    //         &[
    //             self.v_heads,
    //             self.v_heads,
    //             self.v_heads * self.num_vector_messages,
    //         ],
    //     )?;

    //     // Apply the full affine (rotation + translation) for the distance terms.
    //     let vec_dist = rearrange(&vec_dist, "... (h c) -> ... h c", &[("c", 3)])?;
    //     let (query_dist, key_dist) = affine.broadcast_right(1)?.apply(&vec_dist)?.chunk(2, -2)?;

    //     let query_dist = rearrange(&query_dist, "b s h d -> b h s 1 d", &[])?;
    //     let key_dist = rearrange(&key_dist, "b s h d -> b h 1 s d", &[])?;
    //     let query_rot = rearrange(&query_rot, "b s h d -> b h s d", &[])?;
    //     let key_rot = rearrange(&key_rot, "b s h d -> b h d s", &[])?;
    //     let value = rearrange(
    //         &value,
    //         "b s (h m) d -> b h s (m d)",
    //         &[("m", self.num_vector_messages)],
    //     )?;

    //     // SQRT_3 normalizes by the vector dimension (3), analogous to
    //     // 1/sqrt(d_k) in scaled dot-product attention.
    //     let distance_term = query_dist.sub(&key_dist)?.norm_dim(-1, true)?.div(SQRT_3)?;
    //     let rotation_term = query_rot.matmul(&key_rot)?.div(SQRT_3)?;

    //     let distance_term_weight =
    //         rearrange(&self.distance_scale_per_head.softplus()?, "h -> h 1 1", &[])?;
    //     let rotation_term_weight =
    //         rearrange(&self.rotation_scale_per_head.softplus()?, "h -> h 1 1", &[])?;

    //     // Rotation agreement raises the score; distance lowers it.
    //     let mut attn_weight = rotation_term
    //         .mul(&rotation_term_weight)?
    //         .sub(&distance_term.mul(&distance_term_weight)?)?;

    //     // Align the bias to the attention window (the bias may be larger
    //     // than the current query/key lengths) and add it in.
    //     {
    //         let s_q = attn_weight.size(2)?;
    //         let s_k = attn_weight.size(3)?;
    //         let off_q = (attn_bias.size(2)? - s_q).max(0);
    //         let off_k = (attn_bias.size(3)? - s_k).max(0);
    //         let bias = attn_bias.slice(off_q..attn_bias.size(2)?, off_k..attn_bias.size(3)?)?;
    //         attn_weight = attn_weight.add(&bias)?;
    //     }

    //     let attn_weight = attn_weight.softmax(-1)?;
    //     let mut attn_out = attn_weight.matmul(&value)?;

    //     // Rotate the aggregated messages back into each residue's local frame.
    //     attn_out = affine.rot.broadcast_right(1)?.invert()?.apply(&rearrange(
    //         &attn_out,
    //         "b h s (m d) -> b s (h m) d",
    //         &[("m", self.num_vector_messages)],
    //     )?)?;

    //     let mut attn_out = rearrange(
    //         &attn_out,
    //         "b s (h m) d -> b s (h m d)",
    //         &[("m", self.num_vector_messages)],
    //     )?;

    //     // Zero outputs at positions that have no valid frame.
    //     if self.mask_and_zero_frameless {
    //         attn_out =
    //             attn_out.masked_fill(&affine_mask.broadcast_right(1)?.logical_not()?, 0.0)?;
    //     }

    //     self.out_proj.forward(&attn_out)
    // }
}
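
#[cfg(test)]
mod tests {
    // A minimal sketch sanity-checking the projection-width arithmetic used
    // in `load`, for the v_heads = 128, num_vector_messages = 1 defaults
    // above.
    #[test]
    fn projection_dims() {
        let (v_heads, m) = (128usize, 1usize);
        let dim_proj = 4 * v_heads * 3 + v_heads * 3 * m;
        let channels_out = v_heads * 3 * m;
        assert_eq!(dim_proj, 1920);
        assert_eq!(channels_out, 384);
    }
}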