ferritin_plms/esm/layers/regression_head.rs

use crate::esm::models::esmc::ESMCConfig;
use candle_core::Tensor;
use candle_nn::{self as nn, LayerNormConfig, Module, Sequential, VarBuilder};

pub struct RegressionHead {
    model: Sequential,
}

impl RegressionHead {
    /// Build a head from explicit dimensions. This mirrors `load` but does
    /// not require an `ESMCConfig`; the `VarBuilder` supplies the parameters
    /// (initializing them when backed by a `VarMap`). `hidden_dim` defaults
    /// to `d_model`.
    pub fn new(
        vb: VarBuilder,
        d_model: usize,
        output_dim: usize,
        hidden_dim: Option<usize>,
    ) -> candle_core::Result<Self> {
        let hidden_dim = hidden_dim.unwrap_or(d_model);
        let linear1 = nn::linear(d_model, hidden_dim, vb.pp("0"))?;
        let norm = nn::layer_norm(hidden_dim, LayerNormConfig::from(1e-5), vb.pp("2"))?;
        let linear2 = nn::linear(hidden_dim, output_dim, vb.pp("3"))?;
        let model = nn::seq()
            .add(linear1)
            .add(candle_nn::Activation::Gelu)
            .add(norm)
            .add(linear2);
        Ok(Self { model })
    }

    /// Load the head's parameters through a `VarBuilder`, e.g. one backed by
    /// a safetensors checkpoint. The sub-paths "0", "2", and "3" follow
    /// PyTorch's `nn.Sequential` index naming; index "1" is the GELU, which
    /// has no parameters of its own.
    pub fn load(vb: VarBuilder, config: &ESMCConfig) -> candle_core::Result<Self> {
        let ESMCConfig {
            d_model,
            regression_head_output_dim,
            regression_head_hidden_dim,
            ..
        } = config;

        // Linear -> GELU -> LayerNorm -> Linear.
        let linear1 = nn::linear(*d_model, *regression_head_hidden_dim, vb.pp("0"))?;
        let gelu = candle_nn::Activation::Gelu;
        let ln_conf = LayerNormConfig::from(1e-5);
        let norm = nn::layer_norm(*regression_head_hidden_dim, ln_conf, vb.pp("2"))?;
        let linear2 = nn::linear(
            *regression_head_hidden_dim,
            *regression_head_output_dim,
            vb.pp("3"),
        )?;

        let model = nn::seq().add(linear1).add(gelu).add(norm).add(linear2);

        Ok(Self { model })
    }
}

impl Module for RegressionHead {
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        self.model.forward(x)
    }
}
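
// A minimal usage sketch (an addition, not part of the original file): it
// initializes fresh weights through a `VarMap`-backed `VarBuilder` via `new`
// and checks the forward shape. The dimensions are illustrative, not taken
// from any ESMC checkpoint.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};
    use candle_nn::VarMap;

    #[test]
    fn forward_maps_d_model_to_output_dim() -> candle_core::Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        // d_model = 8, output_dim = 4; hidden_dim defaults to d_model.
        let head = RegressionHead::new(vb, 8, 4, None)?;

        let x = Tensor::zeros((2, 8), DType::F32, &device)?;
        let y = head.forward(&x)?;
        assert_eq!(y.dims(), &[2, 4]);
        Ok(())
    }
}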