# CHolderplus model definition (subclass of the base CMVAE class)
from itertools import combinations
import torch
from utils import get_mean
from .cmvae import CMVAE


class CHolderplus(CMVAE):
    """
    CHolderplus model definition. Multimodal VAE with clustering in the latent space.

    On top of the base ``CMVAE`` outputs, this model builds, for every
    unordered pair of modalities, a fused posterior over the shared latent
    ``z``. It also computes the Bhattacharyya coefficient of the two unimodal
    ``z`` posteriors and the decoder outputs of every modality, using samples
    from the fused posterior.
    """

    def bhattacharyya_coefficient(self, mu1, logvar1, mu2, logvar2, eps=1e-12):
        """Compute the Bhattacharyya coefficient BC(q1, q2) of two diagonal Gaussians.

        For q1 = N(mu1, diag(v1)), q2 = N(mu2, diag(v2)) and vbar = (v1 + v2) / 2::

            log BC = -1/8 * sum((mu1 - mu2)^2 / vbar)
                     - 1/2 * sum(log vbar)
                     + 1/4 * (sum(log v1) + sum(log v2))

        Args:
            mu1, mu2: Means; the last dimension is the event dimension that is
                summed over.
            logvar1, logvar2: Log-variances, same shape as the means.
            eps: Variance floor added to ``exp(logvar)`` for numerical stability.

        Returns:
            Tensor with the last dimension reduced. Values are mathematically
            in (0, 1]; they may underflow towards 0 for very distant or
            high-dimensional inputs.
        """
        v1 = torch.exp(logvar1) + eps
        v2 = torch.exp(logvar2) + eps
        vbar = 0.5 * (v1 + v2)

        # Every term uses the same eps-floored variances. Identical inputs then
        # give BC == 1 (and BC <= 1 by AM-GM) even when the raw variances are
        # below eps. Since vbar >= eps already, no extra eps is needed below.
        term1 = -0.125 * ((mu1 - mu2) ** 2 / vbar).sum(-1)
        term2 = -0.5 * torch.log(vbar).sum(-1)
        term3 = 0.25 * (torch.log(v1).sum(-1) + torch.log(v2).sum(-1))

        return torch.exp(term1 + term2 + term3)

    def _build_pairwise(self, qu_xs, uss, K):
        """Build fused shared-latent posteriors and decodings for every modality pair.

        Args:
            qu_xs: Per-modality posteriors. Without ``params.use_disen``, each
                entry exposes ``.loc``/``.scale`` over the concatenated
                ``[w, z]`` latent. With it, each entry is indexed by ``"z"`` to
                obtain the posterior over ``z``.
            uss: Per-modality latent samples whose last dimension splits into
                ``[latent_dim_w, latent_dim_z]``.
            K: Number of samples drawn from each fused posterior.

        Returns:
            dict mapping ``(mi, mj)`` (with ``mi < mj``) to a dict with keys:

            - ``"z_ij"``: K samples from the fused posterior.
            - ``"q_ij"``: the fused posterior built with ``self.post_dist``.
            - ``"bc_ij"``: Bhattacharyya coefficient of the two unimodal z
              posteriors.
            - ``"recon"``: dict mapping modality index to ``vae.px_u(...)`` of
              the decoded ``[w, z_ij]``.
        """
        dim_w = self.params.latent_dim_w
        dim_z = self.params.latent_dim_z
        split_sizes = [dim_w, dim_z]
        use_disen = getattr(self.params, "use_disen", False)

        def z_posterior(m):
            # Return (loc, scale) of modality m's posterior restricted to z.
            if use_disen:
                qz = qu_xs[m]["z"]
                return qz.loc, qz.scale
            qu = qu_xs[m]
            _, mu_z = torch.split(qu.loc, split_sizes, dim=-1)
            _, std_z = torch.split(qu.scale, split_sizes, dim=-1)
            return mu_z, std_z

        pairwise = {}
        for mi, mj in combinations(range(len(self.vaes)), 2):
            mu_i_z, std_i_z = z_posterior(mi)
            mu_j_z, std_j_z = z_posterior(mj)

            vi = std_i_z.pow(2).clamp_min(1e-6)
            vj = std_j_z.pow(2).clamp_min(1e-6)

            # Precision-weighted mean with variance 2 / (1/vi + 1/vj). This is
            # the normalised geometric mean sqrt(q_i * q_j) of the two diagonal
            # Gaussians; its normalising constant is the Bhattacharyya
            # coefficient computed below.
            inv_sum = 1.0 / vi + 1.0 / vj
            mu_ij = (mu_i_z / vi + mu_j_z / vj) / inv_sum
            std_ij = torch.sqrt(2.0 * (1.0 / inv_sum)).clamp_min(1e-6)

            qz_ij = self.post_dist(mu_ij, std_ij)
            z_ij = qz_ij.rsample(torch.Size([K]))

            bc_ij = self.bhattacharyya_coefficient(
                mu_i_z, vi.log(), mu_j_z, vj.log()
            )

            pair_rec = {}
            for d, vae in enumerate(self.vaes):
                if d in (mi, mj):
                    # Paired modalities keep the w part of their own latent sample.
                    latents_w, _ = torch.split(uss[d], split_sizes, dim=-1)
                else:
                    # Other modalities draw w from their prior, with one draw per
                    # position of z_ij's first two dims (presumably K x batch --
                    # TODO confirm). squeeze(2) assumes the prior parameters carry
                    # a singleton leading dim -- TODO confirm.
                    pw = vae.pw(*vae.pw_params_aux)
                    latents_w = pw.rsample(z_ij.shape[:2]).squeeze(2)
                    # Follow z_ij's device instead of hard-coding .cuda(), so the
                    # concatenation below works on any device (CPU, cuda:N, ...).
                    latents_w = latents_w.to(z_ij.device)

                us_combined = torch.cat((latents_w, z_ij), dim=-1)
                pair_rec[d] = vae.px_u(*vae.dec(us_combined))

            pairwise[(mi, mj)] = {
                "z_ij": z_ij,
                "q_ij": qz_ij,
                "bc_ij": bc_ij,
                "recon": pair_rec,
            }

        return pairwise

    def forward(self, x, K=1):
        """Run the base CMVAE forward pass and append the pairwise outputs.

        Returns:
            ``(qu_xs, px_us, uss, pairwise)``: the first three come from
            ``CMVAE.forward`` and ``pairwise`` from ``_build_pairwise``.
        """
        qu_xs, px_us, uss = super().forward(x, K)
        pairwise = self._build_pairwise(qu_xs, uss, K)
        return qu_xs, px_us, uss, pairwise

    def self_and_cross_modal_generation_forward(self, x, K=1):
        """Run the base self/cross-modal generation pass and append the pairwise outputs.

        Returns:
            ``(qu_xs, px_us, uss, pairwise)``: the first three come from
            ``CMVAE.self_and_cross_modal_generation_forward`` and ``pairwise``
            from ``_build_pairwise``.
        """
        qu_xs, px_us, uss = super().self_and_cross_modal_generation_forward(x, K)
        pairwise = self._build_pairwise(qu_xs, uss, K)
        return qu_xs, px_us, uss, pairwise

    def self_and_cross_modal_generation(self, data):
        """Return mean self- and cross-modal reconstructions, computed without gradients.

        ``recons[i][j]`` is ``get_mean(px_us[i][j])``, where ``px_us`` comes from
        ``self_and_cross_modal_generation_forward``. The pairwise outputs are
        computed but discarded.
        """
        with torch.no_grad():
            _, px_us, _, _ = self.self_and_cross_modal_generation_forward(data)
            recons = [[get_mean(px_u) for px_u in r] for r in px_us]
        return recons
