ferritin_plms/esm/layers/attention.rs

use crate::esm::layers::rotary::RotaryEmbedding;
use crate::esm::models::esmc::ESMCConfig;
use candle_core::{Module, Result, Tensor, D};
use candle_nn::{self as nn, LayerNormConfig, VarBuilder};
// `scaled_dot_product_attention` is defined at the bottom of this file.

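/// Multi-head self-attention for the ESMC transformer blocks: a fused
/// LayerNorm + QKV projection, Q/K layer norms, rotary position
/// embeddings, and an output projection.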
pub struct MultiHeadAttention {
    d_model: usize,
    n_heads: usize,
    d_head: usize,
    layernorm_qkv: nn::Sequential,
    out_proj: nn::Linear,
    q_ln: Box<dyn Module>,
    k_ln: Box<dyn Module>,
    rotary: RotaryEmbedding,
}

impl MultiHeadAttention {
    // Constructor sketch retained for reference: candle modules are created
    // from a `VarBuilder`, so construction goes through `load` below.
    // pub fn new(d_model: usize, n_heads: usize, bias: bool, qk_layernorm: bool) -> Result<Self> {
    //     let d_head = d_model / n_heads;

    //     let layernorm = nn::LayerNorm::new(d_model)?;
    //     let linear = nn::linear(d_model, d_model * 3, bias)?;
    //     let layernorm_qkv = nn::seq().add(layernorm).add(linear);

    //     let out_proj = nn::linear(d_model, d_model, bias)?;

    //     let (q_ln, k_ln): (Box<dyn Module>, Box<dyn Module>) = if qk_layernorm {
    //         (
    //             Box::new(nn::LayerNorm::new(d_model)?),
    //             Box::new(nn::LayerNorm::new(d_model)?),
    //         )
    //     } else {
    //         (Box::new(nn::Identity), Box::new(nn::Identity))
    //     };

    //     Ok(Self {
    //         d_model,
    //         n_heads,
    //         d_head,
    //         layernorm_qkv,
    //         out_proj,
    //         q_ln,
    //         k_ln,
    //         rotary: RotaryEmbedding::new(d_model / n_heads)?,
    //     })
    // }
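
    /// Loads the attention weights from `vb` using the checkpoint layout
    /// (`layernorm_qkv.{0,1}`, `out_proj`, `q_ln`, `k_ln`, `rotary`).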
    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model, n_heads, ..
        } = config;

        let d_head = d_model / n_heads;
        // Non-affine layer norm (no learned scale/shift), eps matching the
        // reference implementation.
        let ln_conf = LayerNormConfig {
            eps: 1e-5,
            remove_mean: true,
            affine: false,
        };
        let layernorm = nn::layer_norm(*d_model, ln_conf, vb.pp("layernorm_qkv.0"))?;
        let linear = nn::linear_no_bias(*d_model, d_model * 3, vb.pp("layernorm_qkv.1"))?;
        let layernorm_qkv = nn::seq().add(layernorm).add(linear);
        let out_proj = nn::linear_no_bias(*d_model, *d_model, vb.pp("out_proj"))?;
        // Note: only the `qk_layernorm == true` case is handled for now.
        let q_ln = Box::new(nn::layer_norm(*d_model, ln_conf, vb.pp("q_ln"))?);
        let k_ln = Box::new(nn::layer_norm(*d_model, ln_conf, vb.pp("k_ln"))?);

        let rotary = RotaryEmbedding::load(vb.pp("rotary"), config)?;

        Ok(Self {
            d_model: *d_model,
            n_heads: *n_heads,
            d_head,
            layernorm_qkv,
            out_proj,
            q_ln,
            k_ln,
            rotary,
        })
    }
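
    // Hypothetical usage (names and variable paths are illustrative, not from
    // this crate):
    //   let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    //   let attn = MultiHeadAttention::load(vb.pp("blocks.0.attn"), &config)?;
    //   let y = attn.forward(&x, None)?; // x: (batch, seq_len, d_model)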

    /// Applies rotary position embeddings to the per-head Q/K projections.
    /// Assumes `RotaryEmbedding::forward` maps `(B, L, H, Dh)` query/key
    /// pairs to rotated tensors of the same shape.
    fn apply_rotary(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
        let (b, l, _d) = q.dims3()?;
        // Unflatten the channel dim per head: (B, L, D) -> (B, L, H, Dh).
        let q = q.reshape((b, l, self.n_heads, self.d_head))?;
        let k = k.reshape((b, l, self.n_heads, self.d_head))?;
        let (q, k) = self.rotary.forward(&q, &k)?;
        // Merge the heads back: (B, L, H, Dh) -> (B, L, D).
        let q = q.flatten_from(2)?;
        let k = k.flatten_from(2)?;
        Ok((q, k))
    }

    pub fn forward(&self, x: &Tensor, seq_id: Option<&Tensor>) -> Result<Tensor> {
        // Fused LayerNorm + projection to 3 * d_model, split into Q, K, V.
        let qkv = self.layernorm_qkv.forward(x)?;
        let chunks = qkv.chunk(3, D::Minus1)?;
        let (query, key, value) = (&chunks[0], &chunks[1], &chunks[2]);

        let query = self.q_ln.forward(query)?;
        let key = self.k_ln.forward(key)?;
        let (query, key) = self.apply_rotary(&query, &key)?;

        // (B, L, D) -> (B, H, L, Dh) for batched attention.
        let (b, l, _d) = x.dims3()?;
        let query = query
            .reshape((b, l, self.n_heads, self.d_head))?
            .transpose(1, 2)?
            .contiguous()?;
        let key = key
            .reshape((b, l, self.n_heads, self.d_head))?
            .transpose(1, 2)?
            .contiguous()?;
        let value = value
            .reshape((b, l, self.n_heads, self.d_head))?
            .transpose(1, 2)?
            .contiguous()?;

        let context = if let Some(seq_id) = seq_id {
            // Tokens may only attend to positions sharing their sequence id.
            let mask = seq_id
                .unsqueeze(D::Minus1)?
                .broadcast_eq(&seq_id.unsqueeze(D::Minus2)?)?
                .unsqueeze(1)?; // add a head dim for broadcasting
            scaled_dot_product_attention(&query, &key, &value, Some(&mask))?
        } else {
            scaled_dot_product_attention(&query, &key, &value, None)?
        };

        // (B, H, L, Dh) -> (B, L, D), then the output projection.
        let context = context.transpose(1, 2)?.contiguous()?.flatten_from(2)?;
        self.out_proj.forward(&context)
    }
}
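
// The commented-out import at the top of this file names a
// `scaled_dot_product_attention` that is never defined. What follows is a
// minimal sketch in plain candle ops, in the spirit of PyTorch's
// `F.scaled_dot_product_attention`; it assumes `q`/`k`/`v` are (B, H, L, Dh)
// and `mask` is an optional boolean tensor broadcastable to (B, H, L, L),
// with `true` marking positions that may be attended to.
fn scaled_dot_product_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
) -> Result<Tensor> {
    let scale = (q.dim(D::Minus1)? as f64).sqrt();
    // (B, H, L, Dh) @ (B, H, Dh, L) -> (B, H, L, L) attention scores.
    let scores = (q
        .contiguous()?
        .matmul(&k.transpose(D::Minus2, D::Minus1)?.contiguous()?)?
        / scale)?;
    let scores = match mask {
        Some(mask) => {
            // Push masked-out scores to -inf so softmax zeroes them.
            let neg_inf = Tensor::full(f32::NEG_INFINITY, scores.shape(), scores.device())?
                .to_dtype(scores.dtype())?;
            mask.broadcast_as(scores.shape())?
                .where_cond(&scores, &neg_inf)?
        }
        None => scores,
    };
    let weights = nn::ops::softmax_last_dim(&scores)?;
    // (B, H, L, L) @ (B, H, L, Dh) -> (B, H, L, Dh).
    weights.matmul(v)
}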