from models import register_model
import torch.nn as nn
import torch
from models.base_model import BaseModel
from models.transformer_encoder_input import TransformerEncoderInput

@register_model("deep_nn_baseline")
class DebugModel(BaseModel):
    """Plain MLP baseline, registered under ``"deep_nn_baseline"``.

    Each sample's input is flattened and optionally batch-normalised. It then
    passes through four ``Linear`` + ``GELU`` hidden layers and a final
    ``Linear`` with a single output unit.

    The layer attribute names (``linear_out``, ``linear_out_1`` ...
    ``linear_out_4``, ``act_0`` ... ``act_3``, ``batch_norm``) are kept as-is
    on purpose. They determine the ``state_dict`` keys, so renaming them (or
    wrapping the layers in a container) would break loading saved checkpoints.
    """

    # Hidden width used when the config does not set ``mlp_hidden_dim``.
    _DEFAULT_HIDDEN_DIM = 512

    def __init__(self):
        super().__init__()

    def forward(self, inputs, src_key_mask=None, positions=None, rep_from_layer=-1):
        """Run the MLP on ``inputs`` and return one value per sample.

        ``src_key_mask``, ``positions`` and ``rep_from_layer`` are accepted
        only so this model can be called with the same signature as the
        transformer models. They are ignored.

        Args:
            inputs: Tensor whose dimensions after the first are flattened
                together. The flattened size must equal ``cfg.input_dim``
                passed to ``build_model``. Presumably shaped
                ``(batch, ...)``; TODO confirm against callers.
            src_key_mask: Unused.
            positions: Unused.
            rep_from_layer: Unused.

        Returns:
            Tensor of shape ``(batch, 1)``, the output of the final linear
            layer (no activation applied).
        """
        x = inputs.flatten(start_dim=1)
        # ``batch_norm`` is only created when ``cfg.batch_norm`` is truthy,
        # hence the attribute check rather than an unconditional call.
        if hasattr(self, "batch_norm"):
            x = self.batch_norm(x)
        x = self.act_0(self.linear_out(x))
        x = self.act_1(self.linear_out_1(x))
        x = self.act_2(self.linear_out_2(x))
        x = self.act_3(self.linear_out_3(x))
        return self.linear_out_4(x)

    def build_model(self, cfg):
        """Create the network's layers from ``cfg`` and keep a reference to it.

        Args:
            cfg: Config object supporting attribute access and ``.get``.
                Reads ``input_dim`` (the flattened input size, required),
                ``batch_norm`` (optional, default ``False``) and
                ``mlp_hidden_dim`` (optional hidden width, default 512).
        """
        self.cfg = cfg
        hidden_dim = cfg.get("mlp_hidden_dim", self._DEFAULT_HIDDEN_DIM)
        if cfg.get("batch_norm", False):
            self.batch_norm = nn.BatchNorm1d(cfg.input_dim)
        # Registration order and names match the original layout so existing
        # checkpoints keep loading.
        self.linear_out = nn.Linear(in_features=cfg.input_dim, out_features=hidden_dim)
        self.act_0 = nn.GELU()
        self.linear_out_1 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.act_1 = nn.GELU()
        self.linear_out_2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.act_2 = nn.GELU()
        self.linear_out_3 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.act_3 = nn.GELU()
        self.linear_out_4 = nn.Linear(in_features=hidden_dim, out_features=1)


