// ferritin_plms/amplify/config.rs
use candle_nn::Activation;
use serde::Deserialize;

/// Hyperparameter configuration for an AMPLIFY protein language model.
///
/// Deserializable (via serde) from a model's JSON config; `Default`
/// resolves to the 120M-parameter preset (see `AMPLIFYConfig::amp_120m`).
#[derive(Debug, Clone, Deserialize)]
pub struct AMPLIFYConfig {
    /// Width of the token embeddings / residual stream.
    pub hidden_size: usize,
    /// Number of transformer blocks.
    pub num_hidden_layers: usize,
    /// Attention heads per layer; presumably divides `hidden_size` — TODO confirm in the attention impl.
    pub num_attention_heads: usize,
    /// Width of the feed-forward inner projection.
    pub intermediate_size: usize,
    /// Dropout probability (0.0 in both presets, i.e. disabled for inference).
    pub dropout_prob: f64,
    /// Init range for the embedding weights (uniform/normal scale — set by the loader).
    pub embedding_init_range: f64,
    /// Init range for the decoder/LM-head weights.
    pub decoder_init_range: f64,
    /// Use RMSNorm instead of LayerNorm when true.
    pub rms_norm: bool,
    /// Epsilon added inside the normalization for numerical stability.
    pub norm_eps: f64,
    /// Activation used in the feed-forward network (SwiGLU in both presets).
    pub hidden_act: Activation,
    /// Apply a norm layer immediately after the embedding lookup.
    pub layer_norm_after_embedding: bool,
    /// Apply a norm layer before the final (decoder) layer.
    pub layer_norm_before_last_layer: bool,
    /// Token vocabulary size (27: amino acids plus special tokens).
    pub vocab_size: usize,
    /// Include bias terms in the feed-forward projections.
    pub ffn_bias: bool,
    /// Include bias terms in the attention projections.
    pub att_bias: bool,
    /// Vocabulary id of the padding token.
    pub pad_token_id: usize,
    /// Maximum sequence length supported by the model.
    pub max_length: usize,
}
29
30impl Default for AMPLIFYConfig {
31 fn default() -> Self {
32 AMPLIFYConfig::amp_120m()
33 }
34}
35impl AMPLIFYConfig {
36 pub fn amp_120m() -> Self {
37 Self {
38 hidden_size: 640,
39 num_hidden_layers: 24,
40 num_attention_heads: 10,
41 intermediate_size: 2560,
42 dropout_prob: 0.0,
43 embedding_init_range: 0.02,
44 decoder_init_range: 0.02,
45 rms_norm: true,
46 norm_eps: 1e-5,
47 hidden_act: Activation::Swiglu,
48 layer_norm_after_embedding: false,
49 layer_norm_before_last_layer: true,
50 vocab_size: 27,
51 ffn_bias: false,
52 att_bias: false,
53 pad_token_id: 0,
54 max_length: 2048,
55 }
56 }
57 pub fn amp_350m() -> Self {
58 Self {
59 hidden_size: 960,
60 num_hidden_layers: 32,
61 num_attention_heads: 15,
62 intermediate_size: 3840,
63 dropout_prob: 0.0,
64 embedding_init_range: 0.02,
65 decoder_init_range: 0.02,
66 rms_norm: true,
67 norm_eps: 1e-5,
68 hidden_act: Activation::Swiglu,
69 layer_norm_after_embedding: false,
70 layer_norm_before_last_layer: true,
71 vocab_size: 27,
72 ffn_bias: false,
73 att_bias: false,
74 pad_token_id: 0,
75 max_length: 2048,
76 }
77 }
78}