// ferritin_plms/esm/layers/rotary.rs

1use crate::esm::models::esmc::ESMCConfig;
2use candle_core::{Device, Result, Tensor};
3use candle_nn::VarBuilder;
4
5// NOTE: This implementation is based on LLaMA 2's rotary embeddings
6// fn rotate_half(x: &Tensor, interleaved: bool) -> Result<Tensor> {
7//     if !interleaved {
8//         let (x1, x2) = x.chunk(2, -1)?;
9//         let neg_x2 = x2.neg();
10//         Tensor::cat(&[&neg_x2, &x1], -1)
11//     } else {
12//         let x1 = x.index_select_along_dim(x.ndim() - 1, 0, 2)?;
13//         let x2 = x.index_select_along_dim(x.ndim() - 1, 1, 2)?;
14//         let neg_x2 = x2.neg();
15//         let stacked = Tensor::stack(&[&neg_x2, &x1], -1)?;
16//         stacked.flatten_from(-2)
17//     }
18// }
19
20// fn apply_rotary_emb(x: &Tensor, cos: &Tensor, sin: &Tensor, interleaved: bool) -> Result<Tensor> {
21//     let ro_dim = cos.dim(1)? * 2;
22//     let (d1, d2, d3, d4) = x.dims4()?;
23//     assert!(ro_dim <= d4);
24
25//     let seqlen = d2;
26//     let cos = cos.narrow(0, 0, seqlen)?;
27//     let sin = sin.narrow(0, 0, seqlen)?;
28
29//     let cos = cos.unsqueeze(1)?.repeat((1, 1, 2))?;
30//     let sin = sin.unsqueeze(1)?.repeat((1, 1, 2))?;
31
32//     let x_rot = x.narrow(-1, 0, ro_dim)?;
33//     let x_pass = x.narrow(-1, ro_dim, d4 - ro_dim)?;
34
35//     let x_rotated = rotate_half(&x_rot, interleaved)?;
36//     let x_rot_out = (x_rot * &cos)? + (x_rotated * &sin)?;
37
38//     Tensor::cat(&[&x_rot_out, &x_pass], -1)
39// }
40
41pub struct RotaryEmbedding {
42    dim: usize,
43    base: f64,
44    interleaved: bool,
45    // scale_base: Option<f64>,
46    scaling_factor: f64,
47    seq_len_cached: usize,
48    cos_cached: Option<Tensor>,
49    sin_cached: Option<Tensor>,
50    cos_k_cached: Option<Tensor>,
51    sin_k_cached: Option<Tensor>,
52    inv_freq: Tensor,
53    scale: Option<Tensor>,
54}
55
56impl RotaryEmbedding {
57    // pub fn new(
58    //     dim: usize,
59    //     device: &Device,
60    //     base: f64,
61    //     interleaved: bool,
62    //     scale_base: Option<f64>,
63    //     scaling_factor: f64,
64    // ) -> Result<Self> {
65    //     // self,
66    //     // dim: int,
67    //     // base=10000.0,
68    //     // interleaved=False,
69    //     // scale_base=None,
70    //     // scaling_factor=1.0,
71    //     // pos_idx_in_fp32=True,
72    //     // device=None,
73
74    //     let inv_freq = Self::compute_inv_freq(dim, base, device)?;
75
76    //     let scale = if let Some(scale_base) = scale_base {
77    //         let arange = Tensor::arange(0., dim as f64, 2., device)?;
78    //         let scale = (arange + 0.4 * dim as f64) / (1.4 * dim as f64);
79    //         Some(scale)
80    //     } else {
81    //         None
82    //     };
83
84    //     Ok(Self {
85    //         dim,
86    //         base,
87    //         interleaved,
88    //         scale_base,
89    //         scaling_factor,
90    //         seq_len_cached: 0,
91    //         cos_cached: None,
92    //         sin_cached: None,
93    //         cos_k_cached: None,
94    //         sin_k_cached: None,
95    //         inv_freq,
96    //         scale,
97    //     })
98    // }
99    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
100        let ESMCConfig {
101            d_model, n_heads, ..
102        } = config;
103
104        let rotary_dims = d_model / n_heads;
105        let base = 10000.0;
106        let device = vb.device();
107        let interleaved = false;
108        let scaling_factor = 1.0;
109        // scale_base=None,
110        // scaling_factor=1.0,
111        // pos_idx_in_fp32=True,
112
113        let inv_freq = Self::compute_inv_freq(rotary_dims, base, device)?;
114        let arange = Tensor::arange(0., (rotary_dims as f64) / 2., device)? * 2.;
115        let scale = {
116            let numerator = (&arange? + (0.4 * rotary_dims as f64))?;
117            let denominator = 1.4 * rotary_dims as f64;
118            numerator / denominator
119        };
120
121        Ok(Self {
122            dim: rotary_dims,
123            base,
124            interleaved,
125            // scale_base,
126            scaling_factor,
127            seq_len_cached: 0,
128            cos_cached: None,
129            sin_cached: None,
130            cos_k_cached: None,
131            sin_k_cached: None,
132            inv_freq,
133            scale: Some(scale?),
134        })
135    }
136
137    fn compute_inv_freq(rotary_dims: usize, base: f64, device: &Device) -> Result<Tensor> {
138        Tensor::from_iter(
139            (0..rotary_dims)
140                .step_by(2)
141                .map(|i| i as f32 / rotary_dims as f32)
142                .map(|theta| base.powf(-theta as f64)),
143            device,
144        )
145    }
146
147    // fn update_cos_sin_cache(&mut self, seqlen: usize) -> Result<()> {
148    //     if seqlen > self.seq_len_cached || self.cos_cached.is_none() {
149    //         self.seq_len_cached = seqlen;
150
151    //         let t = (Tensor::arange(0., seqlen as f64, 1., self.inv_freq.device())?)
152    //             / self.scaling_factor;
153    //         let freqs = t.outer(&self.inv_freq)?;
154
155    //         if self.scale.is_none() {
156    //             self.cos_cached = Some(freqs.cos()?);
157    //             self.sin_cached = Some(freqs.sin()?);
158    //         } else {
159    //             let scale = self.scale.as_ref().unwrap();
160    //             let power = ((Tensor::arange(0., seqlen as f64, 1., scale.device())?
161    //                 - (seqlen / 2) as f64)
162    //                 / self.scale_base.unwrap())?;
163    //             let scale = scale.pow(&power.unsqueeze(-1)?)?;
164
165    //             let cos = freqs.cos()?;
166    //             let sin = freqs.sin()?;
167
168    //             self.cos_cached = Some((&cos * &scale)?);
169    //             self.sin_cached = Some((&sin * &scale)?);
170    //             self.cos_k_cached = Some((&cos / &scale)?);
171    //             self.sin_k_cached = Some((&sin / &scale)?);
172    //         }
173    //     }
174    //     Ok(())
175    // }
176
177    // pub fn forward(
178    //     &mut self,
179    //     q: &Tensor,
180    //     k: &Tensor,
181    //     seqlen_offset: usize,
182    // ) -> Result<(Tensor, Tensor)> {
183    //     let seqlen = q.dim(1)? + seqlen_offset;
184    //     self.update_cos_sin_cache(seqlen)?;
185
186    //     if self.scale.is_none() {
187    //         let cos = self
188    //             .cos_cached
189    //             .as_ref()
190    //             .unwrap()
191    //             .narrow(0, seqlen_offset, q.dim(1)?)?;
192    //         let sin = self
193    //             .sin_cached
194    //             .as_ref()
195    //             .unwrap()
196    //             .narrow(0, seqlen_offset, q.dim(1)?)?;
197
198    //         let q_out = apply_rotary_emb(q, &cos, &sin, self.interleaved)?;
199    //         let k_out = apply_rotary_emb(k, &cos, &sin, self.interleaved)?;
200
201    //         Ok((q_out, k_out))
202    //     } else {
203    //         panic!("Scaled rotary embeddings not implemented");
204    //     }
205    // }
206}