import torch.nn as nn
import torch
import torch.nn.functional as F
import torchvision
from transformers import AutoTokenizer, BertForSequenceClassification

# Image Encoder
class ImageClf(nn.Module):
    """ResNet-152 image encoder that returns pooled feature vectors.

    The final child module (the ``fc`` classifier) of an ImageNet-pretrained
    ResNet-152 is removed, so the encoder ends at the global average pool.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in torchvision>=0.13 in
        # favour of `weights=ResNet152_Weights.IMAGENET1K_V1`; kept as-is for
        # compatibility with older torchvision installs.
        backbone = torchvision.models.resnet152(pretrained=True)
        # Keep every child except the last one (the `fc` classification layer).
        self.img_encoder = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch; presumably a float tensor of shape [B, 3, H, W]
                normalized for ImageNet — TODO confirm against the data loader.

        Returns:
            Tensor of shape [B, 2048] (ResNet-152's pooled feature size, matching
            the 2048-input image head in ``MultimodalLateFusionClf``).
        """
        x = self.img_encoder(x)  # [B, 2048, 1, 1] after the average pool
        return torch.flatten(x, start_dim=1)  # [B, 2048]


# Text Encoder
class BertClf(nn.Module):
    """BERT text encoder that returns the final-layer [CLS] embedding.

    Wraps the ``textattack/bert-base-uncased-yelp-polarity`` checkpoint. Only the
    hidden states are used; the classification logits that the wrapped model
    also computes are discarded.
    """

    def __init__(self):
        super().__init__()
        # output_hidden_states=True makes the model return every layer's hidden
        # states so the last one can be read in forward().
        self.model = BertForSequenceClassification.from_pretrained(
            "textattack/bert-base-uncased-yelp-polarity", output_hidden_states=True
        )

    def forward(self, x):
        """Encode a batch of tokenized text.

        Args:
            x: mapping of keyword arguments for the model, presumably tokenizer
                output (``input_ids``, ``attention_mask``, ...) — TODO confirm.

        Returns:
            Tensor of shape [B, 768]: the final-layer hidden state at sequence
            position 0 (the [CLS] token).
        """
        outputs = self.model(**x)  # model output object; its .logits are unused here
        # Embedding output plus one entry per encoder layer (13 in total for
        # bert-base); each entry is [B, seq_len, 768].
        last_hidden = outputs.hidden_states[-1]
        return last_hidden[:, 0, :]
    

class MultimodalLateFusionClf(nn.Module):
    """Late-fusion image/text classifier weighted by per-sample energy scores.

    Each modality is encoded and projected to ``num_classes`` logits
    independently. The fused prediction is a weighted sum of the two logit
    vectors, where each modality's weight is ``logsumexp(logits) / energy_scale``
    (its "energy" score). The weights are detached, so no gradient flows through
    them into the fused output.
    """

    # Class-level default so instances created before `energy_scale` became a
    # constructor argument (e.g. unpickled whole models) still work.
    energy_scale = 10.0

    def __init__(self, num_classes, energy_scale=10.0):
        """
        Args:
            num_classes: number of output classes for both modality heads.
            energy_scale: divisor applied to each modality's energy to obtain
                its fusion weight. Defaults to 10, the previously hard-coded value.
        """
        super().__init__()
        self.ifc = nn.Linear(2048, num_classes)  # on ResNet-152 pooled features
        self.ien = ImageClf()
        self.tfc = nn.Linear(768, num_classes)  # on the BERT-base [CLS] embedding
        self.ten = BertClf()
        self.energy_scale = energy_scale

    def forward(self, img, txt):
        """Run both branches and fuse their logits.

        Args:
            img: image batch, passed unchanged to ``self.ien``.
            txt: tokenized text, passed unchanged to ``self.ten``.

        Returns:
            Tuple ``(fused_logits, txt_logits, img_logits, txt_conf, img_conf)``.
            The logits have shape [B, num_classes]; the confidences have shape
            [B, 1] and are not detached (only their use in the fusion is).
        """
        img_out = self.ifc(self.ien(img))
        txt_out = self.tfc(self.ten(txt))

        # logsumexp is the numerically stable form of log(sum(exp(.))). The naive
        # form overflows to inf once a float32 logit exceeds ~88, and gives
        # log(0) = -inf when every logit is very negative.
        txt_conf = (torch.logsumexp(txt_out, dim=1) / self.energy_scale).unsqueeze(1)
        img_conf = (torch.logsumexp(img_out, dim=1) / self.energy_scale).unsqueeze(1)

        # Detached: the confidences scale the fused logits but are not trained
        # through this path.
        txt_img_out = txt_out * txt_conf.detach() + img_out * img_conf.detach()

        return txt_img_out, txt_out, img_out, txt_conf, img_conf

