// ferritin_plms/esm/layers/regression_head.rs
use crate::esm::models::esmc::ESMCConfig;
use candle_core::Tensor;
use candle_nn::{self as nn, LayerNormConfig, Module, Sequential, VarBuilder};
4
5pub struct RegressionHead {
6 model: Sequential,
7}
8
9impl RegressionHead {
10 pub fn load(vb: VarBuilder, config: &ESMCConfig) -> candle_core::Result<Self> {
23 let ESMCConfig {
24 d_model,
25 regression_head_output_dim,
26 regression_head_hidden_dim,
27 ..
28 } = config;
29
30 let linear1 = nn::linear(*d_model, *regression_head_hidden_dim, vb.pp("0"))?;
31 let gelu = candle_nn::Activation::Gelu;
32 let ln_conf = LayerNormConfig::from(1e-5);
33 let norm = nn::layer_norm(*regression_head_hidden_dim, ln_conf, vb.pp("2"))?;
34 let linear2 = nn::linear(
35 *regression_head_hidden_dim,
36 *regression_head_output_dim,
37 vb.pp("3"),
38 )?;
39
40 let model = nn::seq().add(linear1).add(gelu).add(norm).add(linear2);
41
42 Ok(Self { model })
43 }
44}
45
46impl Module for RegressionHead {
47 fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
48 self.model.forward(x)
49 }
50}