from .modules import BayesCap_MLP
from .loss import TempCombLoss
import torch
from torch import nn


class ProbVLM(nn.Module):
    """Probabilistic adapter over image and text embeddings.

    Holds one ``BayesCap_MLP`` head per modality. Each head receives an
    L2-normalized embedding and returns three tensors that this class names
    ``(mu, 1/alpha, beta)``. Their exact parameterization is defined by
    ``BayesCap_MLP`` and ``TempCombLoss``, which are not visible here
    (NOTE(review): confirm there).

    Args:
        emb_dim: Dimensionality of the input embeddings; also the output
            dimensionality of both heads.
        *args: Accepted and ignored. Kept only so existing positional call
            sites keep working.
        hid_dim: Hidden width of each head.
        num_layers: Number of layers of each head.
        p_drop: Dropout probability of each head.
        num_fw: Number of train-mode forward passes averaged by
            ``adapt_text`` / ``adapt_image``.
    """

    def __init__(self, emb_dim=512, *args, hid_dim=1024, num_layers=3,
                 p_drop=0.1, num_fw=15):
        super().__init__()
        self.img_BayesCap = BayesCap_MLP(
            inp_dim=emb_dim, out_dim=emb_dim, hid_dim=hid_dim,
            num_layers=num_layers, p_drop=p_drop)
        self.txt_BayesCap = BayesCap_MLP(
            inp_dim=emb_dim, out_dim=emb_dim, hid_dim=hid_dim,
            num_layers=num_layers, p_drop=p_drop)
        self.Cri = TempCombLoss()
        self.num_fw = num_fw

    def forward(self, i_features, t_features):
        """Run both heads on L2-normalized image and text features.

        Returns:
            ``(ret_i, ret_t)`` for the image and text head respectively. Each
            is a 4-tuple ``(mu, one_over_alpha, beta, normalized_input)``.
            The normalized input is included so that ``loss`` can compare
            each head's output against it.
        """
        i_features = self.normalize(i_features)
        t_features = self.normalize(t_features)
        img_mu, img_1alpha, img_beta = self.img_BayesCap(i_features)
        txt_mu, txt_1alpha, txt_beta = self.txt_BayesCap(t_features)
        return ((img_mu, img_1alpha, img_beta, i_features),
                (txt_mu, txt_1alpha, txt_beta, t_features))

    def loss(self, z_i_prime, z_t_prime, T1=1.0, T2=5e-2,
             cross_modal_lambda=1e-4):
        """Combined training loss for both heads.

        The loss is the sum of two ``TempCombLoss`` terms, each comparing a
        head with its own normalized input. To that it adds
        ``cross_modal_lambda`` times two cross-modal terms: the image head
        against the text input, and the text head against the image input.
        NOTE(review): the cross-modal terms presumably assume row i of the
        image batch is paired with row i of the text batch — confirm with
        callers.

        Args:
            z_i_prime: Image tuple returned by ``forward``.
            z_t_prime: Text tuple returned by ``forward``.
            T1: Hyper-parameter forwarded unchanged to ``TempCombLoss``.
            T2: Hyper-parameter forwarded unchanged to ``TempCombLoss``.
            cross_modal_lambda: Weight of the two cross-modal terms.

        Returns:
            Whatever summing the ``TempCombLoss`` outputs yields (presumably
            a scalar tensor).
        """
        img_mu, img_1alpha, img_beta, xfI = z_i_prime
        txt_mu, txt_1alpha, txt_beta, xfT = z_t_prime

        loss_i = self.Cri(img_mu, img_1alpha, img_beta, xfI, T1=T1, T2=T2)
        loss_t = self.Cri(txt_mu, txt_1alpha, txt_beta, xfT, T1=T1, T2=T2)
        # Cross-modal terms: each head is also scored against the other
        # modality's normalized input.
        loss_i4t = self.Cri(img_mu, img_1alpha, img_beta, xfT, T1=T1, T2=T2)
        loss_t4i = self.Cri(txt_mu, txt_1alpha, txt_beta, xfI, T1=T1, T2=T2)
        return loss_i + loss_t + cross_modal_lambda * (loss_i4t + loss_t4i)

    def _mc_forward(self, head, z):
        """Average ``self.num_fw`` train-mode passes of ``head`` over ``z``.

        Before the passes, ``z`` is L2-normalized and the whole module is put
        in train mode. This presumably keeps the head's dropout active
        (Monte Carlo dropout). The previous train/eval mode is restored
        afterwards, even if a pass raises. Autograd is not disabled here;
        wrap the call in ``torch.no_grad()`` for pure inference.

        Returns:
            ``(mu_mean, uncer)``. ``mu_mean`` is the mean of the heads'
            first output across passes. ``uncer`` is the population variance
            across passes, averaged over the last dimension.
        """
        z = self.normalize(z)
        was_training = self.training
        self.train()
        try:
            mus = torch.stack([head(z)[0] for _ in range(self.num_fw)], dim=0)
        finally:
            # Do not leak train mode (active dropout) into later forward calls.
            self.train(was_training)
        mu_mean = mus.mean(dim=0)
        uncer = (mus - mu_mean).pow(2).mean(dim=0).mean(dim=-1)
        return mu_mean, uncer

    def adapt_text(self, z_t):
        """Stochastic text-head estimate: returns ``(mu_mean, uncertainty)``.

        See ``_mc_forward``. The module's train/eval mode is unchanged on
        return.
        """
        return self._mc_forward(self.txt_BayesCap, z_t)

    def adapt_image(self, z_i):
        """Stochastic image-head estimate: returns ``(mu_mean, uncertainty)``.

        See ``_mc_forward``. The module's train/eval mode is unchanged on
        return.
        """
        return self._mc_forward(self.img_BayesCap, z_i)

    @staticmethod
    def normalize(z, eps=1e-12):
        """L2-normalize ``z`` along its last dimension.

        The norm is clamped to at least ``eps``, so an all-zero row maps to
        zeros instead of NaN. This matches the convention of
        ``torch.nn.functional.normalize``.
        """
        return z / z.norm(dim=-1, keepdim=True).clamp_min(eps)

    @staticmethod
    def _get_GGuncer(x_alpha, x_beta):
        """Generalized-Gaussian variance computed from raw head outputs.

        Returns ``alpha**2 * Gamma(3/beta) / Gamma(1/beta)``, the closed-form
        variance of a generalized Gaussian with scale ``alpha`` and shape
        ``beta``. Here ``alpha = 1 / (x_alpha + 1e-5)`` clipped to
        [1e-4, 5], and ``beta = x_beta + 0.1`` clipped to [0.1, 5].
        ``x_alpha`` is presumably a head's ``1/alpha`` output (cf.
        ``*_1alpha`` in ``forward``).
        """
        alpha = torch.clip(1 / (x_alpha + 1e-5), min=1e-4, max=5)
        beta = torch.clip(x_beta + 0.1, min=0.1, max=5)
        # Take the Gamma ratio in log space: exp(lgamma(3/beta)) on its own
        # reaches exp(lgamma(30)) ~ 9e30 at beta=0.1, overflowing float16.
        log_ratio = torch.lgamma(3 / beta) - torch.lgamma(1 / beta)
        return alpha ** 2 * torch.exp(log_ratio)