import torch.nn as nn
import torch
import torch.nn.functional as F
import torchvision
from transformers import AutoTokenizer, BertForSequenceClassification

# Image Encoder
class ImageClf(nn.Module):
    """Image encoder: torchvision ResNet-152 with its last child module removed.

    Dropping the final child (the ``fc`` classifier head) leaves the network
    ending at its pooling stage, so ``forward`` returns one flat feature vector
    per image. For ResNet-152 that vector is 2048-d, the image dimension that
    ``PMR([2048, 768])`` in ``PMRInFood101`` is built for.
    """

    def __init__(self, pretrained=True):
        """Build the backbone.

        Args:
            pretrained: forwarded unchanged to ``torchvision.models.resnet152``.
                Defaults to ``True`` (the previous hard-coded behaviour); pass
                ``False`` to build the architecture without pretrained weights.
        """
        super().__init__()
        backbone = torchvision.models.resnet152(pretrained=pretrained)
        # Keep every child except the last one (the classification head).
        self.img_encoder = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        """Encode an image batch into flat feature vectors.

        Args:
            x: image batch; presumably (B, 3, H, W) -- TODO confirm with the loader.

        Returns:
            Tensor of shape (B, C), with C == 2048 for ResNet-152.
        """
        pooled = self.img_encoder(x)
        # Collapse every non-batch axis of the pooled map into one feature axis.
        return torch.flatten(pooled, start_dim=1)


# Text Encoder
class BertClf(nn.Module):
    """Text encoder built on a fine-tuned BERT sequence classifier.

    The wrapped model is loaded with ``output_hidden_states=True`` so that its
    per-layer hidden states are available; only the first-token ([CLS])
    vector of the last hidden-state entry is returned. The classifier's own
    logits are not used.
    """

    def __init__(self):
        super().__init__()
        self.model = BertForSequenceClassification.from_pretrained(
            "textattack/bert-base-uncased-yelp-polarity",
            output_hidden_states=True,
        )

    def forward(self, x):
        """Return the [CLS] embedding of the final layer.

        Args:
            x: mapping of keyword arguments for the model (presumably the
                tokenizer output, e.g. ``input_ids``/``attention_mask``);
                it is unpacked as ``**x``.

        Returns:
            Tensor of shape (B, hidden); hidden is 768 for this bert-base
            checkpoint (the text dimension used by the fusion model).
        """
        outputs = self.model(**x)
        final_layer = outputs.hidden_states[-1]  # (B, seq_len, hidden)
        return final_layer[:, 0, :]
    
class PMRInFood101(nn.Module):
    """Multimodal (image + text) classifier built around the ``PMR`` fusion module.

    Pipeline: ``ImageClf`` produces 2048-d image features and ``BertClf``
    produces 768-d text features. ``PMR`` splits each into an invariant part
    and a modality-specific part. The two invariant parts are averaged,
    concatenated with both specific parts, batch-normalised, and mapped to
    class logits.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        # Zero-argument super(): the previous explicit form named
        # ``GMFInFood101`` (a stale name, not this class), so constructing
        # the model always failed.
        super().__init__()
        img_dim, txt_dim = 2048, 768  # ImageClf / BertClf feature sizes
        # Fused layout produced in forward():
        # [averaged invariant (min dim) | image-specific | text-specific]
        fused_dim = min(img_dim, txt_dim) + img_dim + txt_dim
        self.t_enc = BertClf()
        self.i_enc = ImageClf()
        self.fusion = PMR([img_dim, txt_dim])
        self.fc = nn.Linear(fused_dim, num_classes)
        # Attribute name ``nn`` is kept so existing state_dict keys still load.
        self.nn = nn.BatchNorm1d(fused_dim)

    def get_pmr(self):
        """Return the PMR fusion submodule."""
        return self.fusion

    def forward(self, txt, img):
        """Classify a batch of (text, image) pairs.

        Args:
            txt: keyword-argument mapping forwarded to the BERT encoder.
            img: image batch forwarded to the ResNet encoder; presumably
                (B, 3, H, W) -- TODO confirm.

        Returns:
            ``(logits, fi, ft, x1_re, x2_re)``. ``fi``/``ft`` are the raw
            image/text features. ``x1_re``/``x2_re`` are PMR's reconstructions
            of ``fi``/``ft``, intended as reconstruction-loss inputs.
        """
        fi = self.i_enc(img)
        ft = self.t_enc(txt)
        x1_inv, x2_inv, x1_spec, x2_spec, x1_re, x2_re = self.fusion(fi, ft)
        shared = (x1_inv + x2_inv) / 2
        fused = torch.concat((shared, x1_spec, x2_spec), dim=-1)
        logits = self.fc(self.nn(fused))
        return logits, fi, ft, x1_re, x2_re
    

class PMR(nn.Module):
    """Split two modalities into invariant/specific parts and cross-reconstruct them.

    Each modality is passed through its own ``ElementSplit``, which yields
    an "invariant" part of size ``min(dims)`` and a "specific" part of the
    modality's own size. Each modality is then reconstructed from the *other*
    modality's invariant part concatenated with its *own* specific part.

    Args:
        dims: feature sizes ``[dim_A, dim_B]`` of the two inputs.
        multiple: expansion factor forwarded to ``ElementSplit``.
        boundary: split point forwarded to ``ElementSplit``.

    Usage sketch::

        x1, x2 = batch
        x1_inv, x2_inv, x1_spec, x2_spec, x1_re, x2_re = pmr(x1, x2)
        recon(x1, x1_re)   # reconstruction loss per modality, then backward()
        main_loss.backward()
    """

    def __init__(self, dims, multiple=4, boundary=0.5):
        # Zero-argument super(): the previous explicit form named ``GMF``
        # (a stale name, not this class), so construction always failed.
        super().__init__()
        shared_dim = min(dims)
        self.modalA = ElementSplit(dims[0], shared_dim, multiple, boundary)
        self.modalB = ElementSplit(dims[1], shared_dim, multiple, boundary)
        # Input to each reconstructor: [other invariant (shared_dim) | own specific].
        self.P_reconA = nn.Linear(dims[0] + shared_dim, dims[0])
        self.P_reconB = nn.Linear(dims[1] + shared_dim, dims[1])

    def forward(self, x1, x2):
        """Return ``(x1_inv, x2_inv, x1_spec, x2_spec, x1_re, x2_re)``.

        ``x1_re`` has ``x1``'s feature size and is built from ``x2_inv`` and
        ``x1_spec``; ``x2_re`` symmetrically from ``x1_inv`` and ``x2_spec``.
        Inputs are assumed to be (B, dim) -- concatenation is along dim 1.
        """
        x1_inv, x1_spec = self.modalA(x1)
        x2_inv, x2_spec = self.modalB(x2)
        x1_mix = torch.concat([x2_inv, x1_spec], dim=1)
        x2_mix = torch.concat([x1_inv, x2_spec], dim=1)
        x1_re = self.P_reconA(x1_mix)
        x2_re = self.P_reconB(x2_mix)
        return x1_inv, x2_inv, x1_spec, x2_spec, x1_re, x2_re
    
class ElementSplit(nn.Module):
    """Expand a feature vector, then split it into "invariant" and "specific" parts.

    The input (``dim`` features on the last axis) is linearly expanded to
    ``dislen = int(dim * multiple)`` features. The first ``boundary`` of those
    are projected to ``min_len`` (invariant part). The remaining
    ``dislen - boundary`` are projected back to ``dim`` (specific part).

    Args:
        dim: input feature size.
        min_len: output size of the invariant projection.
        multiple: expansion factor for the intermediate projection.
        boundary: split point. A value > 1 is an absolute index into the
            expanded features; a value <= 1 is a fraction of them.

    Raises:
        ValueError: if the resolved split point is not strictly between 0
            and ``dislen`` (either part would be empty).
    """

    def __init__(self, dim, min_len, multiple=4, boundary=0.5):
        super().__init__()
        if boundary > 1.:
            # Absolute index; cast so float inputs (e.g. 512.0) work as sizes.
            self.boundary = int(boundary)
        else:
            self.boundary = int(boundary * multiple * dim)
        self.dislen = int(dim * multiple)
        if not 0 < self.boundary < self.dislen:
            raise ValueError(
                f"boundary resolves to {self.boundary}; it must lie strictly "
                f"between 0 and dislen={self.dislen}"
            )
        self.dim = dim
        # Use the integer dislen so a fractional ``multiple`` stays valid.
        self.P_dis = nn.Linear(dim, self.dislen)
        self.P_con_inv = nn.Linear(self.boundary, min_len)
        self.P_cov_spec = nn.Linear(self.dislen - self.boundary, dim)

    def forward(self, x):
        """Return ``(x_inv, x_spec)`` with last-axis sizes ``min_len`` and ``dim``.

        Slicing is on the last axis, so any leading batch dimensions are kept.
        """
        x = self.P_dis(x)
        x_inv = self.P_con_inv(x[..., :self.boundary])
        x_spec = self.P_cov_spec(x[..., self.boundary:self.dislen])
        return x_inv, x_spec


class ReconstructionLoss(nn.Module):
    """Cosine reconstruction loss: ``1 - mean(cosine_similarity(x_recon, x_original))``.

    An alternative to an MSE reconstruction objective. Similarity is taken
    along the last axis and averaged over all remaining positions. The result
    is 0 when every reconstruction points in the same direction as its
    target, and up to 2 when they point in opposite directions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x_recon, x_original):
        mean_similarity = F.cosine_similarity(x_recon, x_original, dim=-1).mean()
        return 1 - mean_similarity

