import torch, torch.nn as nn

"""============================================================="""
class Network(torch.nn.Module):
    """GRU-based sequence encoder that maps a sequence to one embedding vector.

    The input sequence is run through a unidirectional ``nn.GRU``; the output
    of the top layer at the final time step is projected by a linear layer
    (stored as ``self.model.last_linear``) to produce the embedding.

    Attributes read from ``opt``:
        arch:          architecture string; the substrings ``'frozen'`` and
                       ``'normalize'`` toggle the behaviors described below.
        ts_input_size: number of features per time step.
        embed_dim:     GRU hidden size and embedding dimensionality.
        gru_n_layers:  number of stacked GRU layers.
    """

    def __init__(self, opt):
        super().__init__()

        self.pars = opt
        self.name = opt.arch

        # nn.GRU applies dropout only *between* stacked layers; asking for it
        # with a single layer has no effect and just triggers a UserWarning.
        dropout = 0.2 if opt.gru_n_layers > 1 else 0.0
        self.model = nn.GRU(
            opt.ts_input_size,
            opt.embed_dim,
            opt.gru_n_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
        )

        # Registered on the GRU module itself, so its parameters live under
        # the ``model.last_linear.*`` state-dict prefix.
        self.model.last_linear = nn.Linear(opt.embed_dim, opt.embed_dim)

        if 'frozen' in opt.arch:
            # Pin BatchNorm2d layers to eval mode and make later ``train()``
            # calls on them a no-op. The GRU and Linear built above contain no
            # BatchNorm2d layers, so this loop currently matches nothing.
            for module in self.model.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()
                    # Accept the optional ``mode`` argument and return the
                    # module, mirroring nn.Module.train's calling convention.
                    module.train = lambda mode=True, _module=module: _module

        self.enc_out_dim = opt.embed_dim
        # Optional callable applied to the embedding in eval mode only.
        self.out_adjust = None

    def forward(self, x, **kwargs):
        """Embed a batch of sequences.

        Args:
            x: tensor of shape ``(batch, seq_len, ts_input_size)``
               (the GRU is built with ``batch_first=True``).
            **kwargs: accepted and ignored.

        Returns:
            A tuple ``(embedding, (enc_out, enc_out))`` where ``embedding`` is
            the linearly projected (and, if ``'normalize'`` is in the arch
            string, L2-normalized) vector of shape ``(batch, embed_dim)`` and
            ``enc_out`` is the GRU output at the final time step, before the
            projection.
        """
        seq_out, _ = self.model(x)
        emb_out = seq_out[:, -1, :]

        x = self.model.last_linear(emb_out)

        if 'normalize' in self.pars.arch:
            # Normalize the projected embedding, not the raw GRU output, so
            # the linear projection is not discarded.
            x = torch.nn.functional.normalize(x, dim=-1)
        # ``self.training`` is the mode flag; ``self.train`` is a bound method
        # and therefore always truthy.
        if self.out_adjust is not None and not self.training:
            x = self.out_adjust(x)

        return x, (emb_out, emb_out)
