import torch
import torch.nn as nn

from .bert import BertEncoder, BertClf
from .image import ImageEncoder, ImageClf
from torch.nn import functional as F

class IntermediateFusionClf(nn.Module):
    """Text/image classifier that fuses the two modalities at feature level.

    Each modality is encoded, projected to a shared ``LATENT_DIM``-sized
    space and classified by one shared linear layer. In addition, three
    Gaussian encoders (text, image, fused) each produce a mean and a
    standard deviation from which a latent code is sampled with the
    reparameterisation trick and classified by the same linear layer.

    When ``args.df`` is truthy the projected features are weighted by a
    per-sample confidence derived from the log-sum-exp of each modality's
    logits; otherwise the two projected features are averaged.
    """

    # Size of the shared projection space and of every latent code.
    LATENT_DIM = 1024
    # Flattened size of the image encoder output expected by the projection.
    IMG_FEATURE_DIM = 6144
    # Width of the hidden layer inside each Gaussian encoder.
    ENCODER_HIDDEN_DIM = 2048
    # Subtracted from the raw std output before softplus, which biases the
    # standard deviation towards small values when the raw output is near 0.
    STD_SHIFT = 5
    # Divisor applied to the log-sum-exp energy to obtain the confidence.
    ENERGY_SCALE = 10

    def __init__(self, args):
        """Build the encoders, projection heads and classification layer.

        Args:
            args: Configuration object. This class reads ``args.hidden_sz``
                (text feature size), ``args.n_classes`` and, in ``forward``,
                ``args.df``; it is also passed to ``BertEncoder`` and
                ``ImageEncoder``.
        """
        super().__init__()

        self.args = args

        self.text_encoder = BertEncoder(args)
        self.image_encoder = ImageEncoder(args)

        self.text_projection_head = nn.Linear(args.hidden_sz, self.LATENT_DIM)
        self.image_projection_head = nn.Linear(
            self.IMG_FEATURE_DIM, self.LATENT_DIM
        )

        # One linear classifier shared by raw features and sampled latents.
        self.classification_layer = nn.Linear(self.LATENT_DIM, args.n_classes)

        self.encoder_txt = self._build_gaussian_encoder()
        self.encoder_img = self._build_gaussian_encoder()
        self.encoder = self._build_gaussian_encoder()

    def _build_gaussian_encoder(self):
        """Return an MLP mapping LATENT_DIM features to 2 * LATENT_DIM values."""
        return nn.Sequential(
            nn.Linear(self.LATENT_DIM, self.ENCODER_HIDDEN_DIM),
            nn.ReLU(inplace=True),
            nn.Linear(self.ENCODER_HIDDEN_DIM, 2 * self.LATENT_DIM),
        )

    def _split_mu_std(self, x: torch.Tensor):
        """Split encoder output into a mean and a non-negative std."""
        mu = x[:, :self.LATENT_DIM]
        std = F.softplus(x[:, self.LATENT_DIM:] - self.STD_SHIFT, beta=1)
        return mu, std

    def encode(self, x: torch.Tensor):
        """Encode a fused feature into ``(mu, std)``.

        x : [batch_size, 1024]
        """
        return self._split_mu_std(self.encoder(x))

    def txt_encode(self, x: torch.Tensor):
        """Encode a projected text feature into ``(mu, std)``.

        x : [batch_size, 1024]
        """
        return self._split_mu_std(self.encoder_txt(x))

    def img_encode(self, x: torch.Tensor):
        """Encode a projected image feature into ``(mu, std)``.

        x : [batch_size, 1024]
        """
        return self._split_mu_std(self.encoder_img(x))

    def reparameterise(self, mu: torch.Tensor, std: torch.Tensor):
        """Draw ``mu + std * eps`` with ``eps`` sampled from N(0, I).

        mu : [batch_size, z_dim]
        std : [batch_size, z_dim]
        """
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, txt, mask, segment, img):
        """Classify a batch from each modality and from their fusion.

        Args:
            txt, mask, segment: Forwarded unchanged to the text encoder.
            img: Forwarded unchanged to the image encoder.

        Returns:
            A 10-tuple:
            ``txt_img_out`` (logits of the fused feature),
            ``txt_out``, ``img_out`` (logits of each projected feature),
            ``txt_conf``, ``img_conf`` (per-sample confidences, shape
            ``[batch, 1]``, still attached to the autograd graph),
            ``txt_feature``, ``img_feature`` (encoder outputs, the image one
            flattened), and three ``[logits, mu, std]`` lists for the text,
            image and fused latent codes respectively.
        """
        txt_feature = self.text_encoder(txt, mask, segment)
        img_feature = self.image_encoder(img)
        # The image projection head needs [batch, IMG_FEATURE_DIM].
        img_feature = torch.flatten(img_feature, start_dim=1)

        txt_feature_d = self.text_projection_head(txt_feature)
        img_feature_d = self.image_projection_head(img_feature)

        # Per-modality latent heads.
        mu_txt, std_txt = self.txt_encode(txt_feature_d)
        z_txt = self.reparameterise(mu_txt, std_txt)
        output_txt = self.classification_layer(z_txt)

        mu_img, std_img = self.img_encode(img_feature_d)
        z_img = self.reparameterise(mu_img, std_img)
        output_img = self.classification_layer(z_img)

        # Deterministic per-modality logits.
        txt_out = self.classification_layer(txt_feature_d)
        img_out = self.classification_layer(img_feature_d)

        # logsumexp equals log(sum(exp(.))) but does not overflow to inf
        # for large logits.
        txt_energy = torch.logsumexp(txt_out, dim=1)
        img_energy = torch.logsumexp(img_out, dim=1)

        txt_conf = torch.reshape(txt_energy / self.ENERGY_SCALE, (-1, 1))
        img_conf = torch.reshape(img_energy / self.ENERGY_SCALE, (-1, 1))

        if self.args.df:
            # Confidences act as fixed weights here: no gradient flows
            # through them from the fused branch.
            multimodal_feature = (
                txt_feature_d * txt_conf.detach()
                + img_feature_d * img_conf.detach()
            )
        else:
            multimodal_feature = txt_feature_d * 0.5 + img_feature_d * 0.5

        txt_img_out = self.classification_layer(multimodal_feature)

        # The fused latent head is computed in both modes so the returned
        # [output, mu, std] triple is always defined.
        mu, std = self.encode(multimodal_feature)
        z = self.reparameterise(mu, std)
        output = self.classification_layer(z)

        return (
            txt_img_out,
            txt_out,
            img_out,
            txt_conf,
            img_conf,
            txt_feature,
            img_feature,
            [output_txt, mu_txt, std_txt],
            [output_img, mu_img, std_img],
            [output, mu, std],
        )
