//! Rotary positional embeddings for the ESMC model (`ferritin_plms/esmc/layers/rotary.rs`).

1use crate::esmc::models::esmc::ESMCConfig;
2use candle_core::{D, Device, Result, Tensor};
3use candle_nn::VarBuilder;

// NOTE: This implementation is based on LLaMA 2's rotary embeddings
// fn rotate_half(x: &Tensor, interleaved: bool) -> Result<Tensor> {
//     if !interleaved {
//         let (x1, x2) = x.chunk(2, -1)?;
//         let neg_x2 = x2.neg();
//         Tensor::cat(&[&neg_x2, &x1], -1)
//     } else {
//         let x1 = x.index_select_along_dim(x.ndim() - 1, 0, 2)?;
//         let x2 = x.index_select_along_dim(x.ndim() - 1, 1, 2)?;
//         let neg_x2 = x2.neg();
//         let stacked = Tensor::stack(&[&neg_x2, &x1], -1)?;
//         stacked.flatten_from(-2)
//     }
// }

// fn apply_rotary_emb(x: &Tensor, cos: &Tensor, sin: &Tensor, interleaved: bool) -> Result<Tensor> {
//     let ro_dim = cos.dim(1)? * 2;
//     let (d1, d2, d3, d4) = x.dims4()?;
//     assert!(ro_dim <= d4);

//     let seqlen = d2;
//     let cos = cos.narrow(0, 0, seqlen)?;
//     let sin = sin.narrow(0, 0, seqlen)?;

//     let cos = cos.unsqueeze(1)?.repeat((1, 1, 2))?;
//     let sin = sin.unsqueeze(1)?.repeat((1, 1, 2))?;

//     let x_rot = x.narrow(-1, 0, ro_dim)?;
//     let x_pass = x.narrow(-1, ro_dim, d4 - ro_dim)?;

//     let x_rotated = rotate_half(&x_rot, interleaved)?;
//     let x_rot_out = (x_rot * &cos)? + (x_rotated * &sin)?;

//     Tensor::cat(&[&x_rot_out, &x_pass], -1)
// }

41#[allow(dead_code)]
42pub struct RotaryEmbedding {
43    dim: usize,
44    base: f64,
45    interleaved: bool,
46    // scale_base: Option<f64>,
47    scaling_factor: f64,
48    seq_len_cached: usize,
49    cos_cached: Option<Tensor>,
50    sin_cached: Option<Tensor>,
51    cos_k_cached: Option<Tensor>,
52    sin_k_cached: Option<Tensor>,
53    inv_freq: Tensor,
54    scale: Option<Tensor>,
55}
56
57impl RotaryEmbedding {
58    // pub fn new(
59    //     dim: usize,
60    //     device: &Device,
61    //     base: f64,
62    //     interleaved: bool,
63    //     scale_base: Option<f64>,
64    //     scaling_factor: f64,
65    // ) -> Result<Self> {
66    //     // self,
67    //     // dim: int,
68    //     // base=10000.0,
69    //     // interleaved=False,
70    //     // scale_base=None,
71    //     // scaling_factor=1.0,
72    //     // pos_idx_in_fp32=True,
73    //     // device=None,
74
75    //     let inv_freq = Self::compute_inv_freq(dim, base, device)?;
76
77    //     let scale = if let Some(scale_base) = scale_base {
78    //         let arange = Tensor::arange(0., dim as f64, 2., device)?;
79    //         let scale = (arange + 0.4 * dim as f64) / (1.4 * dim as f64);
80    //         Some(scale)
81    //     } else {
82    //         None
83    //     };
84
85    //     Ok(Self {
86    //         dim,
87    //         base,
88    //         interleaved,
89    //         scale_base,
90    //         scaling_factor,
91    //         seq_len_cached: 0,
92    //         cos_cached: None,
93    //         sin_cached: None,
94    //         cos_k_cached: None,
95    //         sin_k_cached: None,
96    //         inv_freq,
97    //         scale,
98    //     })
99    // }
100    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
101        let ESMCConfig {
102            d_model, n_heads, ..
103        } = config;
104
105        let rotary_dims = d_model / n_heads;
106        let base = 10000.0;
107        let device = vb.device();
108        let interleaved = false;
109        let scaling_factor = 1.0;
110        // scale_base=None,
111        // scaling_factor=1.0,
112        // pos_idx_in_fp32=True,
113
114        let inv_freq = Self::compute_inv_freq(rotary_dims, base, device)?;
115        // Build scale tensor in F32; candle operator overloads only accept f64 scalars.
116        let arange = Tensor::arange(0u32, (rotary_dims / 2) as u32, device)?
117            .to_dtype(candle_core::DType::F32)?
118            * 2.0f64;
119        let scale = {
120            let numerator = (&arange? + (0.4 * rotary_dims as f64))?;
121            let denominator = 1.4 * rotary_dims as f64;
122            numerator / denominator
123        };
124
125        Ok(Self {
126            dim: rotary_dims,
127            base,
128            interleaved,
129            // scale_base,
130            scaling_factor,
131            seq_len_cached: 0,
132            cos_cached: None,
133            sin_cached: None,
134            cos_k_cached: None,
135            sin_k_cached: None,
136            inv_freq,
137            scale: Some(scale?),
138        })
139    }
140
141    /// Compute cos/sin for `seqlen` positions using `self.inv_freq`.
142    /// Returns `(cos, sin)` each of shape `(seqlen, d_head)`.
143    fn compute_cos_sin(&self, seqlen: usize) -> Result<(Tensor, Tensor)> {
144        let device = self.inv_freq.device();
145        // positions: (seqlen,)
146        let t = Tensor::arange(0u32, seqlen as u32, device)?.to_dtype(candle_core::DType::F32)?;
147        // freqs: (seqlen, d_head/2)
148        let freqs = t.unsqueeze(1)?.matmul(&self.inv_freq.unsqueeze(0)?)?;
149        // repeat to (seqlen, d_head) by cat([freqs, freqs], dim=1)
150        let emb = Tensor::cat(&[&freqs, &freqs], 1)?;
151        Ok((emb.cos()?, emb.sin()?))
152    }
153
154    /// Apply rotary embeddings to query and key tensors.
155    /// q, k: `(B, n_heads, L, d_head)`
156    /// Returns rotated `(q, k)` of the same shapes.
157    pub fn forward(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
158        let seqlen = q.dim(2)?;
159        let (cos, sin) = self.compute_cos_sin(seqlen)?;
160        // cos/sin: (seqlen, d_head) → broadcast to (1, 1, seqlen, d_head)
161        let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
162        let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
163        let q_rot = Self::apply_rotary(q, &cos, &sin)?;
164        let k_rot = Self::apply_rotary(k, &cos, &sin)?;
165        Ok((q_rot, k_rot))
166    }
167
168    /// Rotate x by cos/sin: x_rot = x * cos + rotate_half(x) * sin
169    /// For non-interleaved: rotate_half([x1, x2]) = [-x2, x1]
170    fn apply_rotary(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
171        let d = x.dim(D::Minus1)?;
172        let half = d / 2;
173        let x1 = x.narrow(D::Minus1, 0, half)?;
174        let x2 = x.narrow(D::Minus1, half, half)?;
175        let x_rotated = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
176        x.broadcast_mul(cos)? + x_rotated.broadcast_mul(sin)?
177    }
178
179    fn compute_inv_freq(rotary_dims: usize, base: f64, device: &Device) -> Result<Tensor> {
180        // Emit f32 values so inv_freq stays F32 and avoids dtype
181        // mismatches when matmul'd with the F32 position tensor.
182        Tensor::from_iter(
183            (0..rotary_dims)
184                .step_by(2)
185                .map(|i| i as f32 / rotary_dims as f32)
186                .map(|theta| base.powf(-theta as f64) as f32),
187            device,
188        )
189    }
190
191    // fn update_cos_sin_cache(&mut self, seqlen: usize) -> Result<()> {
192    //     if seqlen > self.seq_len_cached || self.cos_cached.is_none() {
193    //         self.seq_len_cached = seqlen;
194
195    //         let t = (Tensor::arange(0., seqlen as f64, 1., self.inv_freq.device())?)
196    //             / self.scaling_factor;
197    //         let freqs = t.outer(&self.inv_freq)?;
198
199    //         if self.scale.is_none() {
200    //             self.cos_cached = Some(freqs.cos()?);
201    //             self.sin_cached = Some(freqs.sin()?);
202    //         } else {
203    //             let scale = self.scale.as_ref().unwrap();
204    //             let power = ((Tensor::arange(0., seqlen as f64, 1., scale.device())?
205    //                 - (seqlen / 2) as f64)
206    //                 / self.scale_base.unwrap())?;
207    //             let scale = scale.pow(&power.unsqueeze(-1)?)?;
208
209    //             let cos = freqs.cos()?;
210    //             let sin = freqs.sin()?;
211
212    //             self.cos_cached = Some((&cos * &scale)?);
213    //             self.sin_cached = Some((&sin * &scale)?);
214    //             self.cos_k_cached = Some((&cos / &scale)?);
215    //             self.sin_k_cached = Some((&sin / &scale)?);
216    //         }
217    //     }
218    //     Ok(())
219    // }
220
221    // pub fn forward(
222    //     &mut self,
223    //     q: &Tensor,
224    //     k: &Tensor,
225    //     seqlen_offset: usize,
226    // ) -> Result<(Tensor, Tensor)> {
227    //     let seqlen = q.dim(1)? + seqlen_offset;
228    //     self.update_cos_sin_cache(seqlen)?;
229
230    //     if self.scale.is_none() {
231    //         let cos = self
232    //             .cos_cached
233    //             .as_ref()
234    //             .unwrap()
235    //             .narrow(0, seqlen_offset, q.dim(1)?)?;
236    //         let sin = self
237    //             .sin_cached
238    //             .as_ref()
239    //             .unwrap()
240    //             .narrow(0, seqlen_offset, q.dim(1)?)?;
241
242    //         let q_out = apply_rotary_emb(q, &cos, &sin, self.interleaved)?;
243    //         let k_out = apply_rotary_emb(k, &cos, &sin, self.interleaved)?;
244
245    //         Ok((q_out, k_out))
246    //     } else {
247    //         panic!("Scaled rotary embeddings not implemented");
248    //     }
249    // }
250}