import torch
import torch.nn as nn

from .bert import BertEncoder, BertClf
from .image import ImageEncoder, ImageClf


class Ours(nn.Module):
    """Two-branch text/image classifier with feature-level fusion.

    Each modality is encoded and projected into a shared space of width
    ``PROJECTION_DIM``; a single classification layer, shared by both
    modalities, turns the projected features into logits. A per-sample
    confidence is computed for each modality from the log-sum-exp of its
    logits. When ``args.df`` is truthy the projected features are weighted
    by these confidences before the fused classification; otherwise the two
    feature vectors are simply averaged.

    ``args`` must provide ``hidden_sz``, ``n_classes`` and ``df``. It is also
    passed unchanged to ``BertEncoder`` and ``ImageEncoder``.
    """

    # Width of the shared space both modalities are projected into.
    PROJECTION_DIM = 1024
    # Per-sample size of the flattened ImageEncoder output that the image
    # projection head accepts.
    IMAGE_FEATURE_DIM = 6144
    # Divisor that turns the log-sum-exp energy into the confidence value.
    ENERGY_SCALE = 10

    def __init__(self, args):
        """Build the encoders, the two projection heads and the shared head.

        Args:
            args: Configuration object; ``hidden_sz`` sizes the input of the
                text projection head and ``n_classes`` the output of the
                classification layer.
        """
        super().__init__()

        self.args = args

        self.text_encoder = BertEncoder(args)
        self.image_encoder = ImageEncoder(args)

        self.text_projection_head = nn.Linear(
            args.hidden_sz, self.PROJECTION_DIM
        )
        self.image_projection_head = nn.Linear(
            self.IMAGE_FEATURE_DIM, self.PROJECTION_DIM
        )

        # One head classifies the text, image and fused features alike.
        self.classification_layer = nn.Linear(
            self.PROJECTION_DIM, args.n_classes
        )

    def _confidence(self, logits):
        """Return the scaled log-sum-exp of ``logits`` as an ``(N, 1)`` column."""
        # torch.logsumexp is the overflow-safe form of log(sum(exp(x))); the
        # naive expression evaluates to inf as soon as one logit is too large
        # for exp() in the working dtype, which then poisons the fused output.
        energy = torch.logsumexp(logits, dim=1)
        return torch.reshape(energy / self.ENERGY_SCALE, (-1, 1))

    def forward(self, txt, mask, segment, img):
        """Classify a batch from text alone, image alone and both fused.

        Args:
            txt: Text input, forwarded to the text encoder.
            mask: Forwarded to the text encoder together with ``txt``.
            segment: Forwarded to the text encoder together with ``txt``.
            img: Image input, forwarded to the image encoder. The encoder
                output is flattened from dimension 1 onwards and must then
                have ``IMAGE_FEATURE_DIM`` values per sample.

        Returns:
            A tuple ``(txt_img_out, txt_out, img_out, txt_conf, img_conf)``:
            fused logits, text-only logits, image-only logits, and the text
            and image confidences, each reshaped to ``(-1, 1)``. The returned
            confidences are not detached from the autograd graph.
        """
        txt_feature = self.text_encoder(txt, mask, segment)
        img_feature = self.image_encoder(img)
        img_feature = torch.flatten(img_feature, start_dim=1)

        txt_feature = self.text_projection_head(txt_feature)
        img_feature = self.image_projection_head(img_feature)

        txt_out = self.classification_layer(txt_feature)
        img_out = self.classification_layer(img_feature)

        txt_conf = self._confidence(txt_out)
        img_conf = self._confidence(img_out)

        if self.args.df:
            # Detach the weights so the fused prediction does not
            # back-propagate through the confidence computation; the
            # confidences are used as given, without normalisation.
            fused_feature = (
                txt_feature * txt_conf.detach()
                + img_feature * img_conf.detach()
            )
        else:
            # Plain average. The original code also called ``.detach()`` on
            # both confidences here and discarded the result, which had no
            # effect; those statements were dropped.
            fused_feature = txt_feature * 0.5 + img_feature * 0.5

        txt_img_out = self.classification_layer(fused_feature)

        return txt_img_out, txt_out, img_out, txt_conf, img_conf