ferritin_plms/amplify/config.rs

use candle_nn::Activation;
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
/// Configuration struct for AMPLIFY.
///
/// Currently this only covers the hyperparameters matching the
/// pretrained weights published on GitHub: the 120M and 350M models.
///
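/// # Examples
///
/// A minimal sketch of picking a preset. The `use` path below is an
/// assumption and may not match the crate's actual re-exports, so the
/// example is marked `ignore`.
///
/// ```ignore
/// use ferritin_plms::AMPLIFYConfig;
///
/// // `Default` is the 120M preset; a 350M preset is also provided.
/// let small = AMPLIFYConfig::default();
/// let large = AMPLIFYConfig::amp_350m();
/// assert_eq!(small.hidden_size, 640);
/// assert_eq!(large.hidden_size, 960);
/// ```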
pub struct AMPLIFYConfig {
    /// Width of the hidden representations.
    pub hidden_size: usize,
    /// Number of transformer blocks.
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    /// Width of the feed-forward layers.
    pub intermediate_size: usize,
    pub dropout_prob: f64,
    /// Initializer range for the embedding weights.
    pub embedding_init_range: f64,
    /// Initializer range for the decoder weights.
    pub decoder_init_range: f64,
    /// Use RMSNorm rather than LayerNorm when true.
    pub rms_norm: bool,
    pub norm_eps: f64,
    pub hidden_act: Activation,
    pub layer_norm_after_embedding: bool,
    pub layer_norm_before_last_layer: bool,
    pub vocab_size: usize,
    /// Whether the feed-forward projections carry bias terms.
    pub ffn_bias: bool,
    /// Whether the attention projections carry bias terms.
    pub att_bias: bool,
    pub pad_token_id: usize,
    /// Maximum sequence length.
    pub max_length: usize,
}

impl Default for AMPLIFYConfig {
    fn default() -> Self {
        AMPLIFYConfig::amp_120m()
    }
}

impl AMPLIFYConfig {
    /// Hyperparameters for the 120M-parameter model.
    pub fn amp_120m() -> Self {
        Self {
            hidden_size: 640,
            num_hidden_layers: 24,
            num_attention_heads: 10,
            intermediate_size: 2560,
            dropout_prob: 0.0,
            embedding_init_range: 0.02,
            decoder_init_range: 0.02,
            rms_norm: true,
            norm_eps: 1e-5,
            hidden_act: Activation::Swiglu,
            layer_norm_after_embedding: false,
            layer_norm_before_last_layer: true,
            vocab_size: 27,
            ffn_bias: false,
            att_bias: false,
            pad_token_id: 0,
            max_length: 2048,
        }
    }

    /// Hyperparameters for the 350M-parameter model.
    pub fn amp_350m() -> Self {
        Self {
            hidden_size: 960,
            num_hidden_layers: 32,
            num_attention_heads: 15,
            intermediate_size: 3840,
            dropout_prob: 0.0,
            embedding_init_range: 0.02,
            decoder_init_range: 0.02,
            rms_norm: true,
            norm_eps: 1e-5,
            hidden_act: Activation::Swiglu,
            layer_norm_after_embedding: false,
            layer_norm_before_last_layer: true,
            vocab_size: 27,
            ffn_bias: false,
            att_bias: false,
            pad_token_id: 0,
            max_length: 2048,
        }
    }
}
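
// Illustrative sanity checks for the presets above: a minimal sketch, not
// part of the upstream crate. It only exercises code defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn presets_have_expected_shapes() {
        // The default configuration should match the 120M preset.
        let default = AMPLIFYConfig::default();
        let small = AMPLIFYConfig::amp_120m();
        assert_eq!(default.hidden_size, small.hidden_size);
        assert_eq!(default.num_hidden_layers, small.num_hidden_layers);

        for cfg in [AMPLIFYConfig::amp_120m(), AMPLIFYConfig::amp_350m()] {
            // The hidden size must split evenly across attention heads.
            assert_eq!(cfg.hidden_size % cfg.num_attention_heads, 0);
            // For both presets, the feed-forward width is 4x the hidden size.
            assert_eq!(cfg.intermediate_size, cfg.hidden_size * 4);
        }
    }
}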