Skip to main content

ferritin_plms/esmc/layers/
attention.rs

1use crate::esmc::layers::rotary::RotaryEmbedding;
2use crate::esmc::models::esmc::ESMCConfig;
3use candle_core::{Module, Result, Tensor};
4use candle_nn::{self as nn, LayerNorm, LayerNormConfig, VarBuilder};
5// use scaled_dot_product_attention;
6
7pub struct MultiHeadAttention {
8    d_model: usize,
9    n_heads: usize,
10    d_head: usize,
11    layernorm_qkv: nn::Sequential,
12    out_proj: nn::Linear,
13    q_ln: Box<dyn Module>,
14    k_ln: Box<dyn Module>,
15    rotary: RotaryEmbedding,
16}
17
18impl MultiHeadAttention {
19    // pub fn new(d_model: usize, n_heads: usize, bias: bool, qk_layernorm: bool) -> Result<Self> {
20    //     let d_head = d_model / n_heads;
21
22    //     let layernorm = nn::LayerNorm::new(d_model)?;
23    //     let linear = nn::linear(d_model, d_model * 3, bias)?;
24    //     let layernorm_qkv = nn::seq().add(layernorm).add(linear);
25
26    //     let out_proj = nn::linear(d_model, d_model, bias)?;
27
28    //     let (q_ln, k_ln): (Box<dyn Module>, Box<dyn Module>) = if qk_layernorm {
29    //         (
30    //             Box::new(nn::LayerNorm::new(d_model)?),
31    //             Box::new(nn::LayerNorm::new(d_model)?),
32    //         )
33    //     } else {
34    //         (Box::new(nn::Identity), Box::new(nn::Identity))
35    //     };
36
37    //     Ok(Self {
38    //         d_model,
39    //         n_heads,
40    //         d_head,
41    //         layernorm_qkv,
42    //         out_proj,
43    //         q_ln,
44    //         k_ln,
45    //         rotary: RotaryEmbedding::new(d_model / n_heads)?,
46    //     })
47    // }
48    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
49        let ESMCConfig {
50            d_model, n_heads, ..
51        } = config;
52
53        let d_head = d_model / n_heads;
54
55        // layernorm_qkv.0 has both weight and bias in the checkpoint.
56        let layernorm = nn::layer_norm(
57            *d_model,
58            LayerNormConfig::from(1e-5),
59            vb.pp("layernorm_qkv.0"),
60        )?;
61        let linear = nn::linear_no_bias(*d_model, d_model * 3, vb.pp("layernorm_qkv.1"))?;
62        let layernorm_qkv = nn::seq().add(layernorm).add(linear);
63        let out_proj = nn::linear_no_bias(*d_model, *d_model, vb.pp("out_proj"))?;
64
65        // q_ln / k_ln have weight but no bias in the checkpoint — use new_no_bias.
66        let q_ln: Box<dyn Module> = {
67            let w = vb.pp("q_ln").get((*d_model,), "weight")?;
68            Box::new(LayerNorm::new_no_bias(w, 1e-5))
69        };
70        let k_ln: Box<dyn Module> = {
71            let w = vb.pp("k_ln").get((*d_model,), "weight")?;
72            Box::new(LayerNorm::new_no_bias(w, 1e-5))
73        };
74
75        let rotary = RotaryEmbedding::load(vb.pp("rotary"), config)?;
76
77        Ok(Self {
78            d_model: *d_model,
79            n_heads: *n_heads,
80            d_head,
81            layernorm_qkv,
82            out_proj,
83            q_ln,
84            k_ln,
85            rotary,
86        })
87    }
88
89    pub fn forward(&self, x: &Tensor, sequence_id: Option<&Tensor>) -> Result<Tensor> {
90        let (b, l, _) = x.dims3()?;
91        // QKV projection: (B, L, d_model) → (B, L, 3*d_model) → split
92        let qkv = self.layernorm_qkv.forward(x)?;
93        let chunks = qkv.chunk(3, candle_core::D::Minus1)?;
94        let (q, k, v) = (&chunks[0], &chunks[1], &chunks[2]);
95
96        // Per-head layer norms
97        let q = self.q_ln.forward(q)?;
98        let k = self.k_ln.forward(k)?;
99
100        // Reshape to (B, n_heads, L, d_head) for rotary + SDPA
101        let q = q
102            .reshape((b, l, self.n_heads, self.d_head))?
103            .transpose(1, 2)?;
104        let k = k
105            .reshape((b, l, self.n_heads, self.d_head))?
106            .transpose(1, 2)?;
107        let v = v
108            .reshape((b, l, self.n_heads, self.d_head))?
109            .transpose(1, 2)?;
110
111        // Apply rotary positional embeddings
112        let (q, k) = self.rotary.forward(&q, &k)?;
113
114        // Scaled dot-product attention.
115        // k must be made contiguous after transpose — Metal (and some CPU paths) require
116        // contiguous tensors for batched matmul.
117        let scale = (self.d_head as f64).sqrt().recip();
118        let k_t = k
119            .transpose(candle_core::D::Minus1, candle_core::D::Minus2)?
120            .contiguous()?;
121        let attn = (q.contiguous()?.matmul(&k_t)? * scale)?;
122
123        // Optional key-padding mask from sequence_id (True = real token, False = pad)
124        let attn = if let Some(seq_id) = sequence_id {
125            // seq_id: (B, L) bool-like; build (B, 1, 1, L) mask so padded keys are masked out
126            let mask = seq_id
127                .unsqueeze(1)?
128                .unsqueeze(1)?
129                .broadcast_as(attn.shape())?;
130            let neg_inf = (Tensor::ones_like(&attn)? * f64::NEG_INFINITY)?;
131            mask.where_cond(&attn, &neg_inf)?
132        } else {
133            attn
134        };
135
136        let attn = candle_nn::ops::softmax(&attn, candle_core::D::Minus1)?;
137
138        // Weighted sum over values, reshape back to (B, L, d_model)
139        let context = attn.matmul(&v.contiguous()?)?; // (B, n_heads, L, d_head)
140        let context = context
141            .transpose(1, 2)?
142            .contiguous()?
143            .reshape((b, l, self.d_model))?;
144
145        self.out_proj.forward(&context)
146    }
147
148    // fn apply_rotary(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
149    //     let q = q.reshape((-1, self.n_heads, self.d_head))?;
150    //     let k = k.reshape((-1, self.n_heads, self.d_head))?;
151    //     let (q, k) = self.rotary.forward(&q, &k)?;
152    //     let q = q.flatten_from(1)?;
153    //     let k = k.flatten_from(1)?;
154    //     Ok((q, k))
155    // }
156
157    // pub fn forward(&self, x: &Tensor, seq_id: Option<&Tensor>) -> Result<Tensor> {
158    //     let qkv = self.layernorm_qkv.forward(x)?;
159    //     let chunks = qkv.chunk(3, -1)?;
160    //     let (query, key, value) = (&chunks[0], &chunks[1], &chunks[2]);
161
162    //     let query = self.q_ln.forward(query)?;
163    //     let key = self.k_ln.forward(key)?;
164    //     let (query, key) = self.apply_rotary(&query, &key)?;
165
166    //     let query = query.reshape((query.dims()[0], self.n_heads, -1, self.d_head))?;
167    //     let key = key.reshape((key.dims()[0], self.n_heads, -1, self.d_head))?;
168    //     let value = value.reshape((value.dims()[0], self.n_heads, -1, self.d_head))?;
169
170    //     let context = if let Some(seq_id) = seq_id {
171    //         let mask = seq_id.unsqueeze(-1)?.eq(&seq_id.unsqueeze(-2)?)?;
172    //         let mask = mask.unsqueeze(1)?;
173    //         scaled_dot_product_attention(&query, &key, &value, Some(&mask))?
174    //     } else {
175    //         scaled_dot_product_attention(&query, &key, &value, None)?
176    //     };
177
178    //     let context = context.flatten_from(2)?;
179    //     self.out_proj.forward(&context)
180    // }
181}