import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append('../ecog-multimodal/vil_embeds/SLIP')
from models import SIMCLR_VITB16

import transformers

class FusionModel(nn.Module):
    """Fuse frozen text and image encoder outputs into one joint embedding.

    The pooled text representation and the image representation are
    concatenated, passed through dropout, projected to ``output_dim`` by a
    linear layer and finally layer-normalized. Both encoders are invoked
    under ``torch.no_grad()``, so only the projection and the layer norm
    receive gradients.
    """

    def __init__(self, image_encoder, text_encoder, output_dim: int):
        """Store the encoders and build the projection head.

        Args:
            image_encoder: Object exposing ``encode_image(images)``.
            text_encoder: Callable taking ``input_ids`` whose result exposes
                a ``pooler_output`` attribute.
            output_dim: Width of the fused embedding returned by ``forward``.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

        self.dropout = nn.Dropout(p=0.1)
        # The projection expects the concatenated representation to be
        # 2 * 768 wide (presumably 768 per encoder -- confirm if other
        # encoders are plugged in).
        self.linear1 = nn.Linear(2 * 768, output_dim)
        # Normalize over the projected width; a fixed 768 here would fail
        # for every output_dim other than 768.
        self.layer_norm = nn.LayerNorm(output_dim)

    def train(self, mode: bool = True):
        """Set the training mode of the head while keeping encoders in eval.

        The encoders are only ever run under ``torch.no_grad()`` as fixed
        feature extractors, so train-time behaviour inside them (e.g.
        dropout) must stay disabled even when a parent module calls
        ``.train()``.
        """
        super().train(mode)
        for encoder in (self.image_encoder, self.text_encoder):
            # Plain callables (non-Module encoders) have no mode to reset.
            if isinstance(encoder, nn.Module):
                encoder.eval()
        return self

    def forward(self, images, input_ids):
        """Return the fused embedding for a batch of images and token ids.

        Args:
            images: Input forwarded unchanged to ``image_encoder.encode_image``.
            input_ids: Input forwarded unchanged to ``text_encoder``.

        Returns:
            Tensor whose last dimension is ``output_dim``.

        Raises:
            AssertionError: If either encoder representation is not 2-D.
        """
        # The encoders are frozen: no graph is built through them.
        with torch.no_grad():
            text_output = self.text_encoder(input_ids)
            image_reps = self.image_encoder.encode_image(images)

        text_reps = text_output.pooler_output
        assert len(text_reps.shape) == 2, 'Text representations must be only 2-dimensional'
        assert len(image_reps.shape) == 2, 'Image representations must be only 2-dimensional'

        fused_output = torch.cat([text_reps, image_reps], dim=-1)
        fused_output = self.dropout(fused_output)
        fused_output = self.layer_norm(self.linear1(fused_output))
        return fused_output

class NLVRModel(nn.Module):
    """Two-way classification head over a pair of fused embeddings.

    Each of the two images is fused with the same sentence by
    ``fusion_encoder``; the two embeddings are concatenated and passed
    through linear -> GELU -> layer norm -> linear, yielding two scores.
    """

    # From LXMERT paper
    def __init__(self, fusion_encoder, embed_dim=None):
        """Build the classification head.

        Args:
            fusion_encoder: Callable taking ``(image, sent)`` and returning
                one embedding per example.
            embed_dim: Width of a single fused embedding. When ``None`` it is
                read from ``fusion_encoder.linear1.out_features`` if that
                attribute exists, otherwise the previous fixed value of 768
                is used.
        """
        super().__init__()
        self.fusion_encoder = fusion_encoder
        if embed_dim is None:
            # Follow the encoder's projection width so that encoders built
            # with an output size other than 768 do not cause a shape
            # mismatch in linear1.
            projection = getattr(fusion_encoder, 'linear1', None)
            embed_dim = getattr(projection, 'out_features', 768)
        self.linear1 = nn.Linear(2 * embed_dim, 768)
        self.gelu = nn.GELU()
        self.layer_norm = nn.LayerNorm(768)
        self.linear2 = nn.Linear(768, 2)

    def forward(self, image1, image2, sent):
        """Score an (image1, image2, sentence) triple.

        Args:
            image1: First image input, passed to ``fusion_encoder``.
            image2: Second image input, passed to ``fusion_encoder``.
            sent: Sentence input shared by both fusion calls.

        Returns:
            Tensor whose last dimension has size 2 (raw, unnormalized
            outputs of the final linear layer).
        """
        mul_embed1 = self.fusion_encoder(image1, sent)
        mul_embed2 = self.fusion_encoder(image2, sent)
        z_0 = self.linear1(torch.cat([mul_embed1, mul_embed2], dim=-1))
        z_1 = self.layer_norm(self.gelu(z_0))
        return self.linear2(z_1)
    
def load_fusion_model(args, text_model_str):
    """Build a FusionModel from a pretrained text encoder and SimCLR ViT-B/16.

    The text encoder is loaded through ``transformers.AutoModel``; the image
    encoder weights are read from ``simclr_base_25ep.pt`` under a path that
    is relative to the current working directory. Both encoders are put in
    eval mode and moved to the GPU when CUDA is available.

    Args:
        args: Object providing ``output_dim``, the width of the fused
            embedding.
        text_model_str: Model name or path passed to
            ``transformers.AutoModel.from_pretrained``.

    Returns:
        A ``FusionModel`` wrapping the two loaded encoders.
    """
    text_encoder = transformers.AutoModel.from_pretrained(text_model_str)
    text_encoder = text_encoder.eval()

    weights_path = '../ecog-multimodal/vil_embeds/pretrained_models'
    image_encoder = SIMCLR_VITB16()
    # Load onto the CPU first so a checkpoint saved from a GPU process can
    # still be read on CPU-only hosts; the encoder is moved to CUDA below.
    # NOTE(review): torch.load unpickles the file -- only use trusted
    # checkpoints.
    checkpoint = torch.load(
        os.path.join(weights_path, 'simclr_base_25ep.pt'),
        map_location='cpu',
    )

    # Drop only a leading 'module.' from each key; a blanket str.replace
    # would also mangle keys that merely contain that substring (for
    # example 'submodule.weight').
    prefix = 'module.'
    model_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }
    image_encoder.load_state_dict(model_state_dict)
    image_encoder = image_encoder.eval()

    if torch.cuda.is_available():
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()

    fusion_model = FusionModel(image_encoder, text_encoder, args.output_dim)
    return fusion_model

def load_nlvr_model(args, text_model_str):
    """Create an NLVRModel backed by a newly loaded fusion model.

    Args:
        args: Forwarded unchanged to ``load_fusion_model``.
        text_model_str: Forwarded unchanged to ``load_fusion_model``.

    Returns:
        An ``NLVRModel`` wrapping the loaded fusion encoder.
    """
    fusion_encoder = load_fusion_model(args, text_model_str)
    return NLVRModel(fusion_encoder)
