ferritin_plms/esm/layers/attention.rs
use crate::esm::layers::rotary::RotaryEmbedding;
use crate::esm::models::esmc::ESMCConfig;
use candle_core::{Module, Result};
use candle_nn::{self as nn, LayerNormConfig, VarBuilder};
// use scaled_dot_product_attention;

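/// Multi-head self-attention block for the ESMC transformer layers: a
/// pre-LayerNorm fused QKV projection, LayerNorm on the query and key
/// projections, per-head rotary position embeddings, and an output
/// projection back to `d_model`.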
pub struct MultiHeadAttention {
    d_model: usize,
    n_heads: usize,
    d_head: usize,
    layernorm_qkv: nn::Sequential,
    out_proj: nn::Linear,
    q_ln: Box<dyn Module>,
    k_ln: Box<dyn Module>,
    rotary: RotaryEmbedding,
}

impl MultiHeadAttention {
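    // Commented-out constructor kept for reference; in practice the module is
    // built from checkpoint weights via `load` below.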
    // pub fn new(d_model: usize, n_heads: usize, bias: bool, qk_layernorm: bool) -> Result<Self> {
    //     let d_head = d_model / n_heads;

    //     let layernorm = nn::LayerNorm::new(d_model)?;
    //     let linear = nn::linear(d_model, d_model * 3, bias)?;
    //     let layernorm_qkv = nn::seq().add(layernorm).add(linear);

    //     let out_proj = nn::linear(d_model, d_model, bias)?;

    //     let (q_ln, k_ln): (Box<dyn Module>, Box<dyn Module>) = if qk_layernorm {
    //         (
    //             Box::new(nn::LayerNorm::new(d_model)?),
    //             Box::new(nn::LayerNorm::new(d_model)?),
    //         )
    //     } else {
    //         (Box::new(nn::Identity), Box::new(nn::Identity))
    //     };

    //     Ok(Self {
    //         d_model,
    //         n_heads,
    //         d_head,
    //         layernorm_qkv,
    //         out_proj,
    //         q_ln,
    //         k_ln,
    //         rotary: RotaryEmbedding::new(d_model / n_heads)?,
    //     })
    // }
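
    /// Builds the attention block from checkpoint weights, reading the
    /// `layernorm_qkv.0` / `layernorm_qkv.1` (pre-norm + fused QKV projection),
    /// `q_ln` / `k_ln`, `rotary`, and `out_proj` prefixes under `vb`.
    ///
    /// ```ignore
    /// // Sketch only: `blocks.0.attn` is an assumed prefix, not necessarily
    /// // the actual checkpoint layout.
    /// let attn = MultiHeadAttention::load(vb.pp("blocks.0.attn"), &config)?;
    /// ```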
    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
        let ESMCConfig {
            d_model, n_heads, ..
        } = config;

        let d_head = d_model / n_heads;

        // Shared LayerNorm configuration for the QKV pre-norm and the Q/K norms.
        // let ln_conf = LayerNormConfig::from(1e-5);
        let ln_conf = LayerNormConfig {
            eps: 1e-5,
            remove_mean: true,
            affine: false,
        };

        // Fused QKV projection: pre-norm followed by a single bias-free linear
        // producing 3 * d_model features (query, key, value concatenated).
        let layernorm = nn::layer_norm(*d_model, ln_conf, vb.pp("layernorm_qkv.0"))?;
        let linear = nn::linear_no_bias(*d_model, d_model * 3, vb.pp("layernorm_qkv.1"))?;
        let layernorm_qkv = nn::seq().add(layernorm).add(linear);
        let out_proj = nn::linear_no_bias(*d_model, *d_model, vb.pp("out_proj"))?;

        // NOTE: only the `qk_layernorm == true` case is handled for now.
        let q_ln = Box::new(nn::layer_norm(*d_model, ln_conf, vb.pp("q_ln"))?);
        let k_ln = Box::new(nn::layer_norm(*d_model, ln_conf, vb.pp("k_ln"))?);

        let rotary = RotaryEmbedding::load(vb.pp("rotary"), config)?;

        Ok(Self {
            d_model: *d_model,
            n_heads: *n_heads,
            d_head,
            layernorm_qkv,
            out_proj,
            q_ln,
            k_ln,
            rotary,
        })
    }
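
    // NOTE: `apply_rotary` and `forward` below are still commented out; the
    // forward pass (rotary application followed by scaled dot-product
    // attention over the fused QKV output) depends on a
    // `scaled_dot_product_attention` helper that is not wired up yet (see the
    // commented import at the top of the file).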
    // fn apply_rotary(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
    //     let q = q.reshape((-1, self.n_heads, self.d_head))?;
    //     let k = k.reshape((-1, self.n_heads, self.d_head))?;
    //     let (q, k) = self.rotary.forward(&q, &k)?;
    //     let q = q.flatten_from(1)?;
    //     let k = k.flatten_from(1)?;
    //     Ok((q, k))
    // }

    // pub fn forward(&self, x: &Tensor, seq_id: Option<&Tensor>) -> Result<Tensor> {
    //     let qkv = self.layernorm_qkv.forward(x)?;
    //     let chunks = qkv.chunk(3, -1)?;
    //     let (query, key, value) = (&chunks[0], &chunks[1], &chunks[2]);

    //     let query = self.q_ln.forward(query)?;
    //     let key = self.k_ln.forward(key)?;
    //     let (query, key) = self.apply_rotary(&query, &key)?;

    //     let query = query.reshape((query.dims()[0], self.n_heads, -1, self.d_head))?;
    //     let key = key.reshape((key.dims()[0], self.n_heads, -1, self.d_head))?;
    //     let value = value.reshape((value.dims()[0], self.n_heads, -1, self.d_head))?;

    //     let context = if let Some(seq_id) = seq_id {
    //         let mask = seq_id.unsqueeze(-1)?.eq(&seq_id.unsqueeze(-2)?)?;
    //         let mask = mask.unsqueeze(1)?;
    //         scaled_dot_product_attention(&query, &key, &value, Some(&mask))?
    //     } else {
    //         scaled_dot_product_attention(&query, &key, &value, None)?
    //     };

    //     let context = context.flatten_from(2)?;
    //     self.out_proj.forward(&context)
    // }
}