ferritin_plms/esmc/layers/rotary.rs
1use crate::esmc::models::esmc::ESMCConfig;
2use candle_core::{D, Device, Result, Tensor};
3use candle_nn::VarBuilder;
4
5// NOTE: This implementation is based on LLaMA 2's rotary embeddings
6// fn rotate_half(x: &Tensor, interleaved: bool) -> Result<Tensor> {
7// if !interleaved {
8// let (x1, x2) = x.chunk(2, -1)?;
9// let neg_x2 = x2.neg();
10// Tensor::cat(&[&neg_x2, &x1], -1)
11// } else {
12// let x1 = x.index_select_along_dim(x.ndim() - 1, 0, 2)?;
13// let x2 = x.index_select_along_dim(x.ndim() - 1, 1, 2)?;
14// let neg_x2 = x2.neg();
15// let stacked = Tensor::stack(&[&neg_x2, &x1], -1)?;
16// stacked.flatten_from(-2)
17// }
18// }
19
20// fn apply_rotary_emb(x: &Tensor, cos: &Tensor, sin: &Tensor, interleaved: bool) -> Result<Tensor> {
21// let ro_dim = cos.dim(1)? * 2;
22// let (d1, d2, d3, d4) = x.dims4()?;
23// assert!(ro_dim <= d4);
24
25// let seqlen = d2;
26// let cos = cos.narrow(0, 0, seqlen)?;
27// let sin = sin.narrow(0, 0, seqlen)?;
28
29// let cos = cos.unsqueeze(1)?.repeat((1, 1, 2))?;
30// let sin = sin.unsqueeze(1)?.repeat((1, 1, 2))?;
31
32// let x_rot = x.narrow(-1, 0, ro_dim)?;
33// let x_pass = x.narrow(-1, ro_dim, d4 - ro_dim)?;
34
35// let x_rotated = rotate_half(&x_rot, interleaved)?;
36// let x_rot_out = (x_rot * &cos)? + (x_rotated * &sin)?;
37
38// Tensor::cat(&[&x_rot_out, &x_pass], -1)
39// }
40
41#[allow(dead_code)]
42pub struct RotaryEmbedding {
43 dim: usize,
44 base: f64,
45 interleaved: bool,
46 // scale_base: Option<f64>,
47 scaling_factor: f64,
48 seq_len_cached: usize,
49 cos_cached: Option<Tensor>,
50 sin_cached: Option<Tensor>,
51 cos_k_cached: Option<Tensor>,
52 sin_k_cached: Option<Tensor>,
53 inv_freq: Tensor,
54 scale: Option<Tensor>,
55}
56
57impl RotaryEmbedding {
58 // pub fn new(
59 // dim: usize,
60 // device: &Device,
61 // base: f64,
62 // interleaved: bool,
63 // scale_base: Option<f64>,
64 // scaling_factor: f64,
65 // ) -> Result<Self> {
66 // // self,
67 // // dim: int,
68 // // base=10000.0,
69 // // interleaved=False,
70 // // scale_base=None,
71 // // scaling_factor=1.0,
72 // // pos_idx_in_fp32=True,
73 // // device=None,
74
75 // let inv_freq = Self::compute_inv_freq(dim, base, device)?;
76
77 // let scale = if let Some(scale_base) = scale_base {
78 // let arange = Tensor::arange(0., dim as f64, 2., device)?;
79 // let scale = (arange + 0.4 * dim as f64) / (1.4 * dim as f64);
80 // Some(scale)
81 // } else {
82 // None
83 // };
84
85 // Ok(Self {
86 // dim,
87 // base,
88 // interleaved,
89 // scale_base,
90 // scaling_factor,
91 // seq_len_cached: 0,
92 // cos_cached: None,
93 // sin_cached: None,
94 // cos_k_cached: None,
95 // sin_k_cached: None,
96 // inv_freq,
97 // scale,
98 // })
99 // }
100 pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
101 let ESMCConfig {
102 d_model, n_heads, ..
103 } = config;
104
105 let rotary_dims = d_model / n_heads;
106 let base = 10000.0;
107 let device = vb.device();
108 let interleaved = false;
109 let scaling_factor = 1.0;
110 // scale_base=None,
111 // scaling_factor=1.0,
112 // pos_idx_in_fp32=True,
113
114 let inv_freq = Self::compute_inv_freq(rotary_dims, base, device)?;
115 // Build scale tensor in F32; candle operator overloads only accept f64 scalars.
116 let arange = Tensor::arange(0u32, (rotary_dims / 2) as u32, device)?
117 .to_dtype(candle_core::DType::F32)?
118 * 2.0f64;
119 let scale = {
120 let numerator = (&arange? + (0.4 * rotary_dims as f64))?;
121 let denominator = 1.4 * rotary_dims as f64;
122 numerator / denominator
123 };
124
125 Ok(Self {
126 dim: rotary_dims,
127 base,
128 interleaved,
129 // scale_base,
130 scaling_factor,
131 seq_len_cached: 0,
132 cos_cached: None,
133 sin_cached: None,
134 cos_k_cached: None,
135 sin_k_cached: None,
136 inv_freq,
137 scale: Some(scale?),
138 })
139 }
140
141 /// Compute cos/sin for `seqlen` positions using `self.inv_freq`.
142 /// Returns `(cos, sin)` each of shape `(seqlen, d_head)`.
143 fn compute_cos_sin(&self, seqlen: usize) -> Result<(Tensor, Tensor)> {
144 let device = self.inv_freq.device();
145 // positions: (seqlen,)
146 let t = Tensor::arange(0u32, seqlen as u32, device)?.to_dtype(candle_core::DType::F32)?;
147 // freqs: (seqlen, d_head/2)
148 let freqs = t.unsqueeze(1)?.matmul(&self.inv_freq.unsqueeze(0)?)?;
149 // repeat to (seqlen, d_head) by cat([freqs, freqs], dim=1)
150 let emb = Tensor::cat(&[&freqs, &freqs], 1)?;
151 Ok((emb.cos()?, emb.sin()?))
152 }
153
154 /// Apply rotary embeddings to query and key tensors.
155 /// q, k: `(B, n_heads, L, d_head)`
156 /// Returns rotated `(q, k)` of the same shapes.
157 pub fn forward(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
158 let seqlen = q.dim(2)?;
159 let (cos, sin) = self.compute_cos_sin(seqlen)?;
160 // cos/sin: (seqlen, d_head) → broadcast to (1, 1, seqlen, d_head)
161 let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
162 let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
163 let q_rot = Self::apply_rotary(q, &cos, &sin)?;
164 let k_rot = Self::apply_rotary(k, &cos, &sin)?;
165 Ok((q_rot, k_rot))
166 }
167
168 /// Rotate x by cos/sin: x_rot = x * cos + rotate_half(x) * sin
169 /// For non-interleaved: rotate_half([x1, x2]) = [-x2, x1]
170 fn apply_rotary(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
171 let d = x.dim(D::Minus1)?;
172 let half = d / 2;
173 let x1 = x.narrow(D::Minus1, 0, half)?;
174 let x2 = x.narrow(D::Minus1, half, half)?;
175 let x_rotated = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
176 x.broadcast_mul(cos)? + x_rotated.broadcast_mul(sin)?
177 }
178
179 fn compute_inv_freq(rotary_dims: usize, base: f64, device: &Device) -> Result<Tensor> {
180 // Emit f32 values so inv_freq stays F32 and avoids dtype
181 // mismatches when matmul'd with the F32 position tensor.
182 Tensor::from_iter(
183 (0..rotary_dims)
184 .step_by(2)
185 .map(|i| i as f32 / rotary_dims as f32)
186 .map(|theta| base.powf(-theta as f64) as f32),
187 device,
188 )
189 }
190
191 // fn update_cos_sin_cache(&mut self, seqlen: usize) -> Result<()> {
192 // if seqlen > self.seq_len_cached || self.cos_cached.is_none() {
193 // self.seq_len_cached = seqlen;
194
195 // let t = (Tensor::arange(0., seqlen as f64, 1., self.inv_freq.device())?)
196 // / self.scaling_factor;
197 // let freqs = t.outer(&self.inv_freq)?;
198
199 // if self.scale.is_none() {
200 // self.cos_cached = Some(freqs.cos()?);
201 // self.sin_cached = Some(freqs.sin()?);
202 // } else {
203 // let scale = self.scale.as_ref().unwrap();
204 // let power = ((Tensor::arange(0., seqlen as f64, 1., scale.device())?
205 // - (seqlen / 2) as f64)
206 // / self.scale_base.unwrap())?;
207 // let scale = scale.pow(&power.unsqueeze(-1)?)?;
208
209 // let cos = freqs.cos()?;
210 // let sin = freqs.sin()?;
211
212 // self.cos_cached = Some((&cos * &scale)?);
213 // self.sin_cached = Some((&sin * &scale)?);
214 // self.cos_k_cached = Some((&cos / &scale)?);
215 // self.sin_k_cached = Some((&sin / &scale)?);
216 // }
217 // }
218 // Ok(())
219 // }
220
221 // pub fn forward(
222 // &mut self,
223 // q: &Tensor,
224 // k: &Tensor,
225 // seqlen_offset: usize,
226 // ) -> Result<(Tensor, Tensor)> {
227 // let seqlen = q.dim(1)? + seqlen_offset;
228 // self.update_cos_sin_cache(seqlen)?;
229
230 // if self.scale.is_none() {
231 // let cos = self
232 // .cos_cached
233 // .as_ref()
234 // .unwrap()
235 // .narrow(0, seqlen_offset, q.dim(1)?)?;
236 // let sin = self
237 // .sin_cached
238 // .as_ref()
239 // .unwrap()
240 // .narrow(0, seqlen_offset, q.dim(1)?)?;
241
242 // let q_out = apply_rotary_emb(q, &cos, &sin, self.interleaved)?;
243 // let k_out = apply_rotary_emb(k, &cos, &sin, self.interleaved)?;
244
245 // Ok((q_out, k_out))
246 // } else {
247 // panic!("Scaled rotary embeddings not implemented");
248 // }
249 // }
250}