import torch, torch.nn as nn, torch.nn.functional as F
import pretrainedmodels as ptm


def create_mlp(n_inputs, n_outputs, width, depth, dropout=0.2):
    """Build a ``depth``-layer MLP as an ``nn.Sequential``.

    The ``Linear`` layers map ``n_inputs -> width -> ... -> width -> n_outputs``.
    Every ``Linear`` except the last is followed by ``Dropout(dropout)`` and then
    ``ReLU``; the output layer has no activation. ``depth <= 1`` yields a single
    ``Linear(n_inputs, n_outputs)``.

    Args:
        n_inputs: Input feature size.
        n_outputs: Output feature size.
        width: Hidden layer size.
        depth: Number of ``Linear`` layers (values below 1 behave like 1).
        dropout: Dropout probability used after each hidden ``Linear``.

    Returns:
        ``nn.Sequential`` containing the layers in order.
    """
    # [width] * (depth - 1) is empty for depth <= 1, leaving one Linear layer.
    sizes = [n_inputs] + [width] * (depth - 1) + [n_outputs]
    n_linear = len(sizes) - 1

    layers = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers.append(nn.Linear(fan_in, fan_out))
        if idx < n_linear - 1:
            # Hidden layer: regularize and activate; the output layer stays linear.
            layers.extend((nn.Dropout(dropout), nn.ReLU()))

    return nn.Sequential(*layers)


"""============================================================="""
class MultiModalNetwork(torch.nn.Module):
    """Fuse an image encoder and a text encoder into a single embedding network.

    The encoder outputs of ``base_model_1`` (called with the image batch) and
    ``base_model_2`` (called with ``input_ids``/``token_type_ids``/
    ``attention_mask``) are concatenated, passed through a fusion stage
    (identity or an MLP), a ReLU, and finally projected to ``opt.embed_dim``
    by ``last_linear``.

    Args:
        opt: Option namespace. Attributes read here: ``arch`` (substring flags
            ``'frozen'`` and ``'normalize'``), ``fusion_depth``,
            ``multimodal_fc_dim`` and ``embed_dim``.
        base_model_1: Image encoder. Must expose ``enc_out_dim`` and return
            ``(_, (enc_out, no_avg_feat))`` from ``forward(x)``.
        base_model_2: Text encoder. Must expose ``enc_out_dim`` and return
            ``(_, (enc_out, no_avg_feat))`` from a keyword-argument ``forward``.
    """

    def __init__(self, opt, base_model_1, base_model_2):
        super().__init__()
        self.pars = opt
        self.model = nn.ModuleList([base_model_1, base_model_2])
        self.name = opt.arch + '_multimodal'

        concat_emb_dim = base_model_1.enc_out_dim + base_model_2.enc_out_dim

        # The fusion head is registered as *named attributes* of the ModuleList
        # (not as list entries), so its state_dict keys are ``model.fusion1.*``
        # and ``model.last_linear.*``. Keep this layout for checkpoint
        # compatibility.
        if opt.fusion_depth == 0:
            # An empty Sequential is the identity: no hidden fusion layers.
            self.model.fusion1 = nn.Sequential()
            self.model.last_linear = nn.Linear(concat_emb_dim, opt.embed_dim)
        else:
            self.model.fusion1 = create_mlp(concat_emb_dim, opt.multimodal_fc_dim,
                                            opt.multimodal_fc_dim, opt.fusion_depth)
            # Linear projection to the embedding space.
            self.model.last_linear = nn.Linear(opt.multimodal_fc_dim, opt.embed_dim)

        if 'frozen' in opt.arch:
            # Freeze BatchNorm2d layers. Each one is put in eval mode, and its
            # ``train`` is replaced by a no-op, so later ``.train()``/``.eval()``
            # calls on the parent network cannot switch it back.
            for module in self.model.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()
                    module.train = lambda mode=True: None

        # Optional callable applied to the embedding in eval mode only.
        self.out_adjust = None
        self.enc_out_dim = opt.embed_dim

    def forward(self, x, input_ids, token_type_ids, attention_mask, **kwargs):
        """Embed a batch of (image, text) inputs.

        Args:
            x: Image batch passed to ``base_model_1``. Its first dimension is
                used as the batch size.
            input_ids: Forwarded by keyword to ``base_model_2``.
            token_type_ids: Forwarded by keyword to ``base_model_2``.
            attention_mask: Forwarded by keyword to ``base_model_2``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``(embedding, ((enc_out_1, enc_out_2), (flat_feat_1, no_avg_feat_2)))``.
            ``embedding`` has last dimension ``opt.embed_dim``. It is
            L2-normalized when ``'normalize'`` is in ``opt.arch``, and passed
            through ``out_adjust`` in eval mode when that is set.
            ``flat_feat_1`` is ``no_avg_feat_1`` flattened to
            ``(batch_size, -1)``.
        """
        # Use the real batch size rather than the configured ``opt.bs``.
        # Otherwise a short final batch (or eval with another batch size) is
        # rejected or silently reshaped across sample boundaries.
        batch_size = x.shape[0]

        # The first return value of each base model is discarded here
        # (presumably the base model's own last_linear output — confirm).
        _, (enc_out_1, no_avg_feat_1) = self.model[0](x)
        _, (enc_out_2, no_avg_feat_2) = self.model[1](
            input_ids=input_ids, token_type_ids=token_type_ids,
            attention_mask=attention_mask)

        # Fusion layer. With fusion_depth == 0, fusion1 is the identity, so the
        # ReLU acts directly on the concatenated encoder outputs.
        fused = torch.cat((enc_out_1, enc_out_2), dim=-1)
        emb = F.relu(self.model.fusion1(fused))
        emb = self.model.last_linear(emb)

        if 'normalize' in self.pars.arch:
            emb = F.normalize(emb, dim=-1)
        # ``self.training`` is the mode flag. ``self.train`` is a bound method
        # and always truthy, so testing it would make this branch dead code.
        if self.out_adjust is not None and not self.training:
            emb = self.out_adjust(emb)

        return emb, ((enc_out_1, enc_out_2),
                     (no_avg_feat_1.view(batch_size, -1), no_avg_feat_2))