from models import register_model
import torch.nn as nn
import torch
from models.base_model import BaseModel
from models.transformer_encoder_input import TransformerEncoderInput

@register_model("linear_model")
class LinearModel(BaseModel):
    """Linear baseline: flatten every non-batch dim and apply one ``nn.Linear``.

    Submodules are created in ``build_model`` rather than ``__init__``. The
    ``forward`` signature mirrors the transformer models' forward pass so this
    model can be used in their place.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, src_key_mask=None, positions=None, rep_from_layer=-1):
        """Compute the linear read-out for a batch.

        Args:
            inputs: tensor whose dim 0 is the batch and dim 1 indexes sensors
                (the axis masked by sensor dropout). All dims after the first
                are flattened, so their product must equal ``cfg.input_dim``.
            src_key_mask: not used; only accepted for compatibility with the
                transformer models' forward pass.
            positions: not used (compatibility only).
            rep_from_layer: not used (compatibility only).

        Returns:
            Output of ``linear_out``, shape ``(batch, output_dim)``.
        """
        # Sensor dropout: in training mode, and only when there are more than
        # two sensors, zero out whole sensors, each chosen independently with
        # probability cfg["sensor_dropout"]. Kept out of place: writing into
        # ``inputs`` would modify the caller's tensor and fails for leaf
        # tensors that require grad.
        if self.training and inputs.size(1) > 2:
            drop_prob = self.cfg.get("sensor_dropout", 0)
            # Sampled on CPU on every call so the RNG stream is independent of
            # the input device and of whether anything ends up being dropped.
            drop = torch.rand(inputs.size(1)) < drop_prob  # True -> zero that sensor
            if drop.any():
                # Shape (1, n_sensors, 1, ...) broadcasts over batch and features.
                mask_shape = (1, -1) + (1,) * (inputs.dim() - 2)
                inputs = inputs.masked_fill(drop.to(inputs.device).view(mask_shape), 0)

        features = inputs.flatten(start_dim=1)
        # batch_norm only exists when build_model was given batch_norm=True.
        if hasattr(self, "batch_norm"):
            features = self.batch_norm(features)
        return self.linear_out(self.dropout(features))

    def build_model(self, cfg):
        """Store ``cfg`` and create the submodules used by ``forward``.

        Reads ``cfg.input_dim`` (size of the flattened input features) and
        the optional keys ``batch_norm`` (default False), ``dropout``
        (default 0) and ``output_dim`` (default 1). ``forward`` additionally
        reads ``sensor_dropout`` (default 0) from the stored config.
        """
        self.cfg = cfg
        if cfg.get("batch_norm", False):
            self.batch_norm = nn.BatchNorm1d(cfg.input_dim)
        self.dropout = nn.Dropout(p=cfg.get("dropout", 0))
        self.linear_out = nn.Linear(in_features=cfg.input_dim, out_features=cfg.get("output_dim", 1))

