ferritin_plms/esm/layers/rotary.rs
1use crate::esm::models::esmc::ESMCConfig;
2use candle_core::{Device, Result, Tensor};
3use candle_nn::VarBuilder;
4
5// NOTE: This implementation is based on LLaMA 2's rotary embeddings
6// fn rotate_half(x: &Tensor, interleaved: bool) -> Result<Tensor> {
7// if !interleaved {
8// let (x1, x2) = x.chunk(2, -1)?;
9// let neg_x2 = x2.neg();
10// Tensor::cat(&[&neg_x2, &x1], -1)
11// } else {
12// let x1 = x.index_select_along_dim(x.ndim() - 1, 0, 2)?;
13// let x2 = x.index_select_along_dim(x.ndim() - 1, 1, 2)?;
14// let neg_x2 = x2.neg();
15// let stacked = Tensor::stack(&[&neg_x2, &x1], -1)?;
16// stacked.flatten_from(-2)
17// }
18// }
19
20// fn apply_rotary_emb(x: &Tensor, cos: &Tensor, sin: &Tensor, interleaved: bool) -> Result<Tensor> {
21// let ro_dim = cos.dim(1)? * 2;
22// let (d1, d2, d3, d4) = x.dims4()?;
23// assert!(ro_dim <= d4);
24
25// let seqlen = d2;
26// let cos = cos.narrow(0, 0, seqlen)?;
27// let sin = sin.narrow(0, 0, seqlen)?;
28
29// let cos = cos.unsqueeze(1)?.repeat((1, 1, 2))?;
30// let sin = sin.unsqueeze(1)?.repeat((1, 1, 2))?;
31
32// let x_rot = x.narrow(-1, 0, ro_dim)?;
33// let x_pass = x.narrow(-1, ro_dim, d4 - ro_dim)?;
34
35// let x_rotated = rotate_half(&x_rot, interleaved)?;
36// let x_rot_out = (x_rot * &cos)? + (x_rotated * &sin)?;
37
38// Tensor::cat(&[&x_rot_out, &x_pass], -1)
39// }
40
41pub struct RotaryEmbedding {
42 dim: usize,
43 base: f64,
44 interleaved: bool,
45 // scale_base: Option<f64>,
46 scaling_factor: f64,
47 seq_len_cached: usize,
48 cos_cached: Option<Tensor>,
49 sin_cached: Option<Tensor>,
50 cos_k_cached: Option<Tensor>,
51 sin_k_cached: Option<Tensor>,
52 inv_freq: Tensor,
53 scale: Option<Tensor>,
54}
55
56impl RotaryEmbedding {
57 // pub fn new(
58 // dim: usize,
59 // device: &Device,
60 // base: f64,
61 // interleaved: bool,
62 // scale_base: Option<f64>,
63 // scaling_factor: f64,
64 // ) -> Result<Self> {
65 // // self,
66 // // dim: int,
67 // // base=10000.0,
68 // // interleaved=False,
69 // // scale_base=None,
70 // // scaling_factor=1.0,
71 // // pos_idx_in_fp32=True,
72 // // device=None,
73
74 // let inv_freq = Self::compute_inv_freq(dim, base, device)?;
75
76 // let scale = if let Some(scale_base) = scale_base {
77 // let arange = Tensor::arange(0., dim as f64, 2., device)?;
78 // let scale = (arange + 0.4 * dim as f64) / (1.4 * dim as f64);
79 // Some(scale)
80 // } else {
81 // None
82 // };
83
84 // Ok(Self {
85 // dim,
86 // base,
87 // interleaved,
88 // scale_base,
89 // scaling_factor,
90 // seq_len_cached: 0,
91 // cos_cached: None,
92 // sin_cached: None,
93 // cos_k_cached: None,
94 // sin_k_cached: None,
95 // inv_freq,
96 // scale,
97 // })
98 // }
99 pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
100 let ESMCConfig {
101 d_model, n_heads, ..
102 } = config;
103
104 let rotary_dims = d_model / n_heads;
105 let base = 10000.0;
106 let device = vb.device();
107 let interleaved = false;
108 let scaling_factor = 1.0;
109 // scale_base=None,
110 // scaling_factor=1.0,
111 // pos_idx_in_fp32=True,
112
113 let inv_freq = Self::compute_inv_freq(rotary_dims, base, device)?;
114 let arange = Tensor::arange(0., (rotary_dims as f64) / 2., device)? * 2.;
115 let scale = {
116 let numerator = (&arange? + (0.4 * rotary_dims as f64))?;
117 let denominator = 1.4 * rotary_dims as f64;
118 numerator / denominator
119 };
120
121 Ok(Self {
122 dim: rotary_dims,
123 base,
124 interleaved,
125 // scale_base,
126 scaling_factor,
127 seq_len_cached: 0,
128 cos_cached: None,
129 sin_cached: None,
130 cos_k_cached: None,
131 sin_k_cached: None,
132 inv_freq,
133 scale: Some(scale?),
134 })
135 }
136
137 fn compute_inv_freq(rotary_dims: usize, base: f64, device: &Device) -> Result<Tensor> {
138 Tensor::from_iter(
139 (0..rotary_dims)
140 .step_by(2)
141 .map(|i| i as f32 / rotary_dims as f32)
142 .map(|theta| base.powf(-theta as f64)),
143 device,
144 )
145 }
146
147 // fn update_cos_sin_cache(&mut self, seqlen: usize) -> Result<()> {
148 // if seqlen > self.seq_len_cached || self.cos_cached.is_none() {
149 // self.seq_len_cached = seqlen;
150
151 // let t = (Tensor::arange(0., seqlen as f64, 1., self.inv_freq.device())?)
152 // / self.scaling_factor;
153 // let freqs = t.outer(&self.inv_freq)?;
154
155 // if self.scale.is_none() {
156 // self.cos_cached = Some(freqs.cos()?);
157 // self.sin_cached = Some(freqs.sin()?);
158 // } else {
159 // let scale = self.scale.as_ref().unwrap();
160 // let power = ((Tensor::arange(0., seqlen as f64, 1., scale.device())?
161 // - (seqlen / 2) as f64)
162 // / self.scale_base.unwrap())?;
163 // let scale = scale.pow(&power.unsqueeze(-1)?)?;
164
165 // let cos = freqs.cos()?;
166 // let sin = freqs.sin()?;
167
168 // self.cos_cached = Some((&cos * &scale)?);
169 // self.sin_cached = Some((&sin * &scale)?);
170 // self.cos_k_cached = Some((&cos / &scale)?);
171 // self.sin_k_cached = Some((&sin / &scale)?);
172 // }
173 // }
174 // Ok(())
175 // }
176
177 // pub fn forward(
178 // &mut self,
179 // q: &Tensor,
180 // k: &Tensor,
181 // seqlen_offset: usize,
182 // ) -> Result<(Tensor, Tensor)> {
183 // let seqlen = q.dim(1)? + seqlen_offset;
184 // self.update_cos_sin_cache(seqlen)?;
185
186 // if self.scale.is_none() {
187 // let cos = self
188 // .cos_cached
189 // .as_ref()
190 // .unwrap()
191 // .narrow(0, seqlen_offset, q.dim(1)?)?;
192 // let sin = self
193 // .sin_cached
194 // .as_ref()
195 // .unwrap()
196 // .narrow(0, seqlen_offset, q.dim(1)?)?;
197
198 // let q_out = apply_rotary_emb(q, &cos, &sin, self.interleaved)?;
199 // let k_out = apply_rotary_emb(k, &cos, &sin, self.interleaved)?;
200
201 // Ok((q_out, k_out))
202 // } else {
203 // panic!("Scaled rotary embeddings not implemented");
204 // }
205 // }
206}