ferritin_plms/esmc/layers/attention.rs
1use crate::esmc::layers::rotary::RotaryEmbedding;
2use crate::esmc::models::esmc::ESMCConfig;
3use candle_core::{Module, Result, Tensor};
4use candle_nn::{self as nn, LayerNorm, LayerNormConfig, VarBuilder};
5// use scaled_dot_product_attention;
6
7pub struct MultiHeadAttention {
8 d_model: usize,
9 n_heads: usize,
10 d_head: usize,
11 layernorm_qkv: nn::Sequential,
12 out_proj: nn::Linear,
13 q_ln: Box<dyn Module>,
14 k_ln: Box<dyn Module>,
15 rotary: RotaryEmbedding,
16}
17
18impl MultiHeadAttention {
19 // pub fn new(d_model: usize, n_heads: usize, bias: bool, qk_layernorm: bool) -> Result<Self> {
20 // let d_head = d_model / n_heads;
21
22 // let layernorm = nn::LayerNorm::new(d_model)?;
23 // let linear = nn::linear(d_model, d_model * 3, bias)?;
24 // let layernorm_qkv = nn::seq().add(layernorm).add(linear);
25
26 // let out_proj = nn::linear(d_model, d_model, bias)?;
27
28 // let (q_ln, k_ln): (Box<dyn Module>, Box<dyn Module>) = if qk_layernorm {
29 // (
30 // Box::new(nn::LayerNorm::new(d_model)?),
31 // Box::new(nn::LayerNorm::new(d_model)?),
32 // )
33 // } else {
34 // (Box::new(nn::Identity), Box::new(nn::Identity))
35 // };
36
37 // Ok(Self {
38 // d_model,
39 // n_heads,
40 // d_head,
41 // layernorm_qkv,
42 // out_proj,
43 // q_ln,
44 // k_ln,
45 // rotary: RotaryEmbedding::new(d_model / n_heads)?,
46 // })
47 // }
48 pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
49 let ESMCConfig {
50 d_model, n_heads, ..
51 } = config;
52
53 let d_head = d_model / n_heads;
54
55 // layernorm_qkv.0 has both weight and bias in the checkpoint.
56 let layernorm = nn::layer_norm(
57 *d_model,
58 LayerNormConfig::from(1e-5),
59 vb.pp("layernorm_qkv.0"),
60 )?;
61 let linear = nn::linear_no_bias(*d_model, d_model * 3, vb.pp("layernorm_qkv.1"))?;
62 let layernorm_qkv = nn::seq().add(layernorm).add(linear);
63 let out_proj = nn::linear_no_bias(*d_model, *d_model, vb.pp("out_proj"))?;
64
65 // q_ln / k_ln have weight but no bias in the checkpoint — use new_no_bias.
66 let q_ln: Box<dyn Module> = {
67 let w = vb.pp("q_ln").get((*d_model,), "weight")?;
68 Box::new(LayerNorm::new_no_bias(w, 1e-5))
69 };
70 let k_ln: Box<dyn Module> = {
71 let w = vb.pp("k_ln").get((*d_model,), "weight")?;
72 Box::new(LayerNorm::new_no_bias(w, 1e-5))
73 };
74
75 let rotary = RotaryEmbedding::load(vb.pp("rotary"), config)?;
76
77 Ok(Self {
78 d_model: *d_model,
79 n_heads: *n_heads,
80 d_head,
81 layernorm_qkv,
82 out_proj,
83 q_ln,
84 k_ln,
85 rotary,
86 })
87 }
88
89 pub fn forward(&self, x: &Tensor, sequence_id: Option<&Tensor>) -> Result<Tensor> {
90 let (b, l, _) = x.dims3()?;
91 // QKV projection: (B, L, d_model) → (B, L, 3*d_model) → split
92 let qkv = self.layernorm_qkv.forward(x)?;
93 let chunks = qkv.chunk(3, candle_core::D::Minus1)?;
94 let (q, k, v) = (&chunks[0], &chunks[1], &chunks[2]);
95
96 // Per-head layer norms
97 let q = self.q_ln.forward(q)?;
98 let k = self.k_ln.forward(k)?;
99
100 // Reshape to (B, n_heads, L, d_head) for rotary + SDPA
101 let q = q
102 .reshape((b, l, self.n_heads, self.d_head))?
103 .transpose(1, 2)?;
104 let k = k
105 .reshape((b, l, self.n_heads, self.d_head))?
106 .transpose(1, 2)?;
107 let v = v
108 .reshape((b, l, self.n_heads, self.d_head))?
109 .transpose(1, 2)?;
110
111 // Apply rotary positional embeddings
112 let (q, k) = self.rotary.forward(&q, &k)?;
113
114 // Scaled dot-product attention.
115 // k must be made contiguous after transpose — Metal (and some CPU paths) require
116 // contiguous tensors for batched matmul.
117 let scale = (self.d_head as f64).sqrt().recip();
118 let k_t = k
119 .transpose(candle_core::D::Minus1, candle_core::D::Minus2)?
120 .contiguous()?;
121 let attn = (q.contiguous()?.matmul(&k_t)? * scale)?;
122
123 // Optional key-padding mask from sequence_id (True = real token, False = pad)
124 let attn = if let Some(seq_id) = sequence_id {
125 // seq_id: (B, L) bool-like; build (B, 1, 1, L) mask so padded keys are masked out
126 let mask = seq_id
127 .unsqueeze(1)?
128 .unsqueeze(1)?
129 .broadcast_as(attn.shape())?;
130 let neg_inf = (Tensor::ones_like(&attn)? * f64::NEG_INFINITY)?;
131 mask.where_cond(&attn, &neg_inf)?
132 } else {
133 attn
134 };
135
136 let attn = candle_nn::ops::softmax(&attn, candle_core::D::Minus1)?;
137
138 // Weighted sum over values, reshape back to (B, L, d_model)
139 let context = attn.matmul(&v.contiguous()?)?; // (B, n_heads, L, d_head)
140 let context = context
141 .transpose(1, 2)?
142 .contiguous()?
143 .reshape((b, l, self.d_model))?;
144
145 self.out_proj.forward(&context)
146 }
147
148 // fn apply_rotary(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
149 // let q = q.reshape((-1, self.n_heads, self.d_head))?;
150 // let k = k.reshape((-1, self.n_heads, self.d_head))?;
151 // let (q, k) = self.rotary.forward(&q, &k)?;
152 // let q = q.flatten_from(1)?;
153 // let k = k.flatten_from(1)?;
154 // Ok((q, k))
155 // }
156
157 // pub fn forward(&self, x: &Tensor, seq_id: Option<&Tensor>) -> Result<Tensor> {
158 // let qkv = self.layernorm_qkv.forward(x)?;
159 // let chunks = qkv.chunk(3, -1)?;
160 // let (query, key, value) = (&chunks[0], &chunks[1], &chunks[2]);
161
162 // let query = self.q_ln.forward(query)?;
163 // let key = self.k_ln.forward(key)?;
164 // let (query, key) = self.apply_rotary(&query, &key)?;
165
166 // let query = query.reshape((query.dims()[0], self.n_heads, -1, self.d_head))?;
167 // let key = key.reshape((key.dims()[0], self.n_heads, -1, self.d_head))?;
168 // let value = value.reshape((value.dims()[0], self.n_heads, -1, self.d_head))?;
169
170 // let context = if let Some(seq_id) = seq_id {
171 // let mask = seq_id.unsqueeze(-1)?.eq(&seq_id.unsqueeze(-2)?)?;
172 // let mask = mask.unsqueeze(1)?;
173 // scaled_dot_product_attention(&query, &key, &value, Some(&mask))?
174 // } else {
175 // scaled_dot_product_attention(&query, &key, &value, None)?
176 // };
177
178 // let context = context.flatten_from(2)?;
179 // self.out_proj.forward(&context)
180 // }
181}