# objectives of choice
import torch
from numpy import prod
from utils import log_mean_exp, is_multidata
import torch.nn.functional as F
import math

# Small floor added to the softmax cluster posterior p(c | z) before its log is
# taken (pc_z = softmax(...) + eta; lpc_z = log(pc_z)), so log() never sees 0.
eta = 1.0e-20

# helper to split a batch into microbatches that fit in memory
def compute_microbatch_split(x, K):
    """Checks if batch needs to be broken down further to fit in memory.

    Uses a fixed float-count heuristic (budget ``1e8``, tuned for a 12 GB CUDA
    device) where one data point is treated as costing about
    ``K * prod(feature_shape)`` floats.

    Args:
        x: a single data tensor or, when ``is_multidata(x)``, a sequence of
            per-modality tensors whose first dimension is the batch.
        K: number of importance samples drawn per data point.

    Returns:
        ``min(B, S)``: the batch size ``B`` capped by the heuristic number of
        data points ``S`` that fit in memory.

    Raises:
        ValueError: if not even one data point fits (``S == 0``).
    """
    multi = is_multidata(x)
    B = x[0].size(0) if multi else x.size(0)
    if multi:
        # NOTE(review): this sums per-modality reciprocals, so adding a
        # modality *raises* S; a per-sample cost of K * sum(dims) would lower
        # it. Kept unchanged to preserve existing microbatch sizes.
        S = sum(1.0 / (K * prod(_x.size()[1:])) for _x in x)
    else:
        S = 1.0 / (K * prod(x.size()[1:]))
    S = int(1e8 * S)  # float heuristic for 12Gb cuda memory
    if S <= 0:
        # Explicit raise: an ``assert`` is stripped under ``python -O`` and this
        # function would then hand callers a split size of 0.
        raise ValueError("Cannot fit individual data in memory, consider smaller K")
    return min(B, S)


# MULTIMODAL OBJECTIVES

def _cmvae_log_likelihood(model, x, K=1, detach=False):
    """Compute per-modality importance log-weights for the CMVAE objective.

    ``model(x, K)`` returns ``(qu_xs, px_us, uss)``:

    - ``qu_xs[r]`` is the posterior for modality ``r``.
    - ``px_us[r][d]`` is the decoder distribution scoring ``x[d]`` from
      modality ``r``'s latents.
    - ``uss[r]`` holds latent samples whose last dimension is ``[w, z]`` of
      sizes ``latent_dim_w`` and ``latent_dim_z``.

    For modality ``r`` the log-weight is::

        sum_d llik_scaling_d * log p(x_d | u_r)
          + beta * (sum_c pc_z * (log p(z|c) + log p(c) - log pc_z) - log q_mix(z))
          + beta * (log p(w) - log q(w | x_r))

    Here ``pc_z = softmax(log p(c) + log p(z|c)) + eta``, and ``log q_mix`` is
    ``log_mean_exp`` over the stacked log-densities of all unimodal
    z-posteriors.

    Args:
        model: multimodal model following the forward contract above.
        x: sequence of per-modality data tensors.
        K: number of importance samples.
        detach: if True, every posterior used to *score* the samples is
            rebuilt from detached parameters, so ``log q`` only passes gradient
            through the samples themselves (as the DReG estimator requires).

    Returns:
        ``(torch.stack(lws), torch.stack(uss))``. The stacked latents are a
        new tensor created after the log-weights, so the log-weights do not
        depend on it through autograd.
    """
    qu_xs, px_us, uss = model(x, K)
    split_sizes = [model.params.latent_dim_w, model.params.latent_dim_z]
    use_disen = getattr(model.params, "use_disen", False)

    # Build q(w | x_r) and q(z | x_r) for every modality.
    qz_xs, qw_xs = [], []
    for r, qu_x in enumerate(qu_xs):
        vae = model.vaes[r]
        if use_disen:
            # disentangled encoders already return separate z / w posteriors
            qz_x, qw_x = qu_x["z"], qu_x["w"]
            if detach:
                qz_x = vae.qu_x(qz_x.loc.detach(), qz_x.scale.detach())
                qw_x = vae.qu_x(qw_x.loc.detach(), qw_x.scale.detach())
        else:
            # joint posterior over u = [w, z]: split its parameters per factor
            u_mean, u_scale = vae.qu_x_params
            if detach:
                # stop-gradient on the posterior parameters *before* splitting,
                # so the w and z posteriors below are really detached
                u_mean, u_scale = u_mean.detach(), u_scale.detach()
            qw_mean, qz_mean = torch.split(u_mean, split_sizes, dim=-1)
            qw_scale, qz_scale = torch.split(u_scale, split_sizes, dim=-1)
            qw_x = vae.qu_x(qw_mean, qw_scale)
            qz_x = vae.qu_x(qz_mean, qz_scale)
        qz_xs.append(qz_x)
        qw_xs.append(qw_x)

    # Loop-invariant priors: cluster prior p(c), per-cluster p(z | c), and p(w).
    beta = getattr(model.params, "beta", 1.0)
    log_pc = torch.log(model.pc_params)
    pz_cs = [model.pz(*model.pz_params(idx)) for idx in range(model.params.latent_dim_c)]
    pw = model.pw(*model.pw_params())

    lws = []
    for r, qw_x in enumerate(qw_xs):
        ws, zs = torch.split(uss[r], split_sizes, dim=-1)

        lpc = log_pc.unsqueeze(1).repeat((1, zs.size()[1], 1))
        lpz_c = torch.stack([pz_c.log_prob(zs).sum(-1) for pz_c in pz_cs], dim=-1)
        pc_z = F.softmax(lpc + lpz_c, dim=-1) + eta  # eta keeps the log finite
        lpc_z = torch.log(pc_z)
        lpw = pw.log_prob(ws).sum(-1)

        # mixture-of-experts style z-posterior over all modalities
        lqz_x = log_mean_exp(torch.stack([q.log_prob(zs).sum(-1) for q in qz_xs]))
        lqw_x = qw_x.log_prob(ws).sum(-1)
        lpx_u = torch.stack([
            px_u.log_prob(x[d]).view(*px_u.batch_shape[:2], -1)
                .mul(model.vaes[d].llik_scaling).sum(-1)
            for d, px_u in enumerate(px_us[r])
        ]).sum(0)

        cluster_term = ((lpz_c + lpc - lpc_z) * pc_z).sum(-1)
        lws.append(lpx_u + beta * (cluster_term - lqz_x) + beta * (lpw - lqw_x))

    return torch.stack(lws), torch.stack(uss)


def _cmvae_iwae(model, x, K=1):
    """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised
    This version is the looser bound---with the average over modalities outside the log
    """
    lws, _ = _cmvae_log_likelihood(model, x, K, detach=False)
    return lws  # (n_modality * n_samples) x batch_size, batch_size

def cmvae_iwae(model, x, K=1):
    r"""Computes iwae estimate for log p_\theta(x) for multi-modal vae.

    This version is the looser bound---with the average over modalities outside
    the log. The batch is processed in microbatches (``compute_microbatch_split``).
    Their log-weights are concatenated on the batch dimension (dim 2) before
    the reduction.

    Args:
        model: multimodal model consumed by ``_cmvae_iwae``.
        x: sequence of per-modality data tensors.
        K: number of importance samples.

    Returns:
        Scalar: ``log_mean_exp`` over the samples (dim 1), mean over
        modalities (dim 0), summed over the remaining batch dimension.
    """
    S = compute_microbatch_split(x, K)
    x_split = zip(*[_x.split(S) for _x in x])
    lw = torch.cat([_cmvae_iwae(model, _x, K) for _x in x_split], 2)  # concat on batch
    return log_mean_exp(lw, dim=1).mean(0).sum()


def _cmvae_dreg(model, x, K=1):
    """Return ``(log_weights, latents)`` for DReG on one microbatch.

    Delegates to ``_cmvae_log_likelihood`` with ``detach=True``, so the
    posteriors scoring the samples carry no direct parameter gradient.
    """
    result = _cmvae_log_likelihood(model, x, K, detach=True)
    log_weights, latents = result  # enforce the 2-tuple contract
    return log_weights, latents

def cmvae_dreg(model, x, K=1):
    r"""Computes dreg estimate for log p_\theta(x) for multi-modal vae.

    This version is the looser bound---with the average over modalities outside
    the log. The self-normalised importance weights over the K samples are
    computed without gradient. They are used as loss weights and, via a
    gradient hook, to rescale gradients with respect to the latent samples.

    Args:
        model: multimodal model consumed by ``_cmvae_dreg``.
        x: sequence of per-modality data tensors.
        K: number of importance samples.

    Returns:
        Scalar surrogate loss ``(grad_wt * lw).mean(0).sum()``.
    """
    S = compute_microbatch_split(x, K)
    x_split = zip(*[_x.split(S) for _x in x])
    lw = []
    uss = []
    for _x in x_split:
        lw_mb, uss_mb = _cmvae_dreg(model, _x, K)
        lw.append(lw_mb)
        uss.append(uss_mb)
    lw = torch.cat(lw, 2)  # concat on batch
    uss = torch.cat(uss, 2)  # concat on batch
    with torch.no_grad():
        # normalised importance weights over the K samples (dim 1)
        grad_wt = (lw - torch.logsumexp(lw, 1, keepdim=True)).exp()
        # NOTE(review): ``uss`` is a fresh tensor produced by ``torch.cat``
        # (and upstream by ``torch.stack``) after ``lw`` was computed, so the
        # loss does not depend on it and this hook should never fire. The
        # intended DReG rescaling needs the hook on the latent tensors that
        # actually feed ``lw``. TODO: fix together with the likelihood helper.
        if uss.requires_grad:
            uss.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad)
    loss = (grad_wt * lw).mean(0).sum()
    return loss


# MULTIMODAL OBJECTIVES FOR CHOLDERPLUS (Holder-plus mixture over unimodal and pairwise components)

def _cholderplus_log_likelihood(model, x, K=1, detach=False):
    """Compute importance log-weights for the Holder-plus CMVAE objective.

    There is one unimodal component per modality, scored as in the CMVAE
    objective. In addition, each entry of the model's ``pairwise`` dict is a
    pairwise component ``(mi, mj)`` with a shared sample ``z_ij`` and its own
    reconstructions. Shared latents are scored under the Holder-plus mixture::

        q_H(z | X) = (sum_m q_m(z) + sum_{i<j} 2 * sqrt(q_i(z) q_j(z))) / Z
        Z          = M + 2 * sum_{(i, j) in pairwise} bc_ij   (Z = M if none)

    Here ``bc_ij = pairwise[(i, j)]["bc_ij"]``.

    Args:
        model: multimodal model; ``model(x, K)`` returns
            ``(qu_xs, px_us, uss, pairwise)``.
        x: sequence of per-modality data tensors.
        K: number of importance samples.
        detach: if True, the per-modality posteriors used to score samples are
            rebuilt from detached parameters (for DReG).

    Returns:
        A 5-tuple ``(lws, uss, log_w_single, log_w_pairs, pairwise)``:

        - ``lws`` maps modality index ``r`` and pair key ``(mi, mj)`` to
          log-weights.
        - ``uss`` is ``torch.stack`` of the unimodal latents.
        - ``log_w_single[r] = -log Z`` and
          ``log_w_pairs[(mi, mj)] = log(2 * bc_ij / Z)`` are the log mixture
          weights.
        - ``pairwise`` is the model's dict. Each entry is additionally given
          clones ``"u_i"`` / ``"u_j"`` of ``uss[mi]`` / ``uss[mj]``.
    """
    qu_xs, px_us, uss, pairwise = model(x, K)

    n_mods_sample = len(model.vaes)
    eps = 1e-12
    split_sizes = [model.params.latent_dim_w, model.params.latent_dim_z]

    # Pairwise overlaps supplied by the model. Presumably the Bhattacharyya
    # coefficients int sqrt(q_i q_j) dz, which is what makes Z normalise q_H.
    BCs = {}
    if pairwise:
        for (mi, mj), d in pairwise.items():
            BCs[(mi, mj)] = d["bc_ij"]

    # log Z of the Holder-plus mixture.
    if BCs:
        bc_stack = torch.stack(list(BCs.values()), dim=0)
        sumC = bc_stack.sum(0)
        denom = float(n_mods_sample) + 2.0 * sumC
        logZ = (denom + eps).log()
    else:
        # no pairwise components -> Z = M
        logZ = torch.full_like(uss[0][0, :, 0], math.log(float(n_mods_sample)))
    # log mixture weights: 1/Z per unimodal component, 2 * bc_ij / Z per pair
    log_w_single = {k: -logZ for k in range(n_mods_sample)}
    log_w_pairs = {
        k: (math.log(2.0) - logZ + BCs[k].clamp_min(eps).log())
        for k in BCs
    }

    # Per-modality posteriors q(w | x_m) and q(z | x_m).
    qz_xs, qw_xs = [], []
    for r, qu_x in enumerate(qu_xs):
        if not getattr(model.params, "use_disen", False):
            # joint posterior over u = [w, z]: split its parameters per factor
            if detach:
                qu_x_r_mean, qu_x_r_scale = (p.detach() for p in model.vaes[r].qu_x_params)
            else:
                qu_x_r_mean, qu_x_r_scale = model.vaes[r].qu_x_params
            qw_x_mean, qz_x_mean = torch.split(qu_x_r_mean, split_sizes, dim=-1)
            qw_x_scale, qz_x_scale = torch.split(qu_x_r_scale, split_sizes, dim=-1)
            qw_x = model.vaes[r].qu_x(qw_x_mean, qw_x_scale)
            qz_x = model.vaes[r].qu_x(qz_x_mean, qz_x_scale)
        else:
            qz_x = qu_x["z"]
            qw_x = qu_x["w"]
            if detach:
                qz_x = model.vaes[r].qu_x(qz_x.loc.detach(), qz_x.scale.detach())
                qw_x = model.vaes[r].qu_x(qw_x.loc.detach(), qw_x.scale.detach())
        qz_xs.append(qz_x)
        qw_xs.append(qw_x)

    def log_q_holder_z(zs):
        """Return log q_H(zs | X) using the posteriors in ``qz_xs`` and ``logZ``."""
        # NOTE(review): the sqrt terms range over every pair i < j, whereas Z
        # only sums bc_ij over the pairs present in ``pairwise``. q_H is
        # normalised only if ``pairwise`` contains every pair.
        # NOTE(review): ``logZ`` is not detached when ``detach=True``. If bc_ij
        # depends on encoder parameters, log q_H keeps a direct parameter
        # gradient that DReG normally removes — confirm this is intended.
        log_q_list = [qz_x.log_prob(zs).sum(-1) for qz_x in qz_xs]
        log_q_stack = torch.stack(log_q_list, dim=0)
        # log sum_m q_m(z)
        term1 = torch.logsumexp(log_q_stack, dim=0)
        # log(2 * sqrt(q_i q_j)) for every pair i < j
        sqrt_terms = []
        for i in range(n_mods_sample):
            for j in range(i + 1, n_mods_sample):
                log_sqrt = 0.5 * (log_q_list[i] + log_q_list[j]) + math.log(2.0)
                sqrt_terms.append(log_sqrt)
        if sqrt_terms:
            term2 = torch.logsumexp(torch.stack(sqrt_terms, dim=0), dim=0)
            lq = torch.logaddexp(term1, term2)
        else:
            lq = term1
        return lq - logZ

    # ------------------------
    # Unimodal components (CMVAE weights, Holder-plus z-posterior)
    # ------------------------
    # NOTE(review): these terms use ``uss[r]`` directly rather than a clone,
    # so their gradient cannot be hooked separately from the pairwise clones.
    beta = getattr(model.params, "beta", 1.0)
    lws = {}
    for r, qw_x in enumerate(qw_xs):
        ws, zs = torch.split(uss[r], split_sizes, dim=-1)

        pc = model.pc_params
        lpc = torch.log(pc)
        lpc = lpc.unsqueeze(1).repeat((1, zs.size()[1], 1))
        lpz_c_l = [model.pz(*model.pz_params(idx)).log_prob(zs).sum(-1) for idx in
                   range(model.params.latent_dim_c)]
        lpz_c = torch.stack(lpz_c_l, dim=-1)
        pc_z = F.softmax((lpc + lpz_c), dim=-1) + eta  # eta keeps the log finite
        lpc_z = torch.log(pc_z)
        lpw = model.pw(*model.pw_params()).log_prob(ws).sum(-1)

        lqz_x = log_q_holder_z(zs)
        lqw_x = qw_x.log_prob(ws).sum(-1)
        lpx_u = [px_u.log_prob(x[d]).view(*px_u.batch_shape[:2], -1)
                     .mul(model.vaes[d].llik_scaling).sum(-1)
                 for d, px_u in enumerate(px_us[r])]
        lpx_u = torch.stack(lpx_u).sum(0)

        lw = lpx_u + beta * (
            ((lpz_c + lpc - lpc_z) * pc_z).sum(-1) - lqz_x
        ) + beta * (lpw - lqw_x)
        lws[r] = lw

    # ------------------------
    # Pairwise Holder-plus components
    # ------------------------
    if pairwise:
        for (mi, mj), d in pairwise.items():
            z_ij = d["z_ij"]

            # Clone the unimodal samples so gradients reaching them through
            # this pair can be hooked on their own (see cholderplus_dreg).
            # The clones are stored back into the model's ``pairwise`` dict.
            u_i = uss[mi].clone()
            u_j = uss[mj].clone()
            d["u_i"] = u_i
            d["u_j"] = u_j
            # private codes w_i, w_j
            w_i, _ = torch.split(u_i, split_sizes, dim=-1)
            w_j, _ = torch.split(u_j, split_sizes, dim=-1)

            # cluster prior evaluated at the shared pairwise sample z_ij
            pc = model.pc_params
            lpc = torch.log(pc)
            lpc = lpc.unsqueeze(1).repeat((1, z_ij.size()[1], 1))
            lpz_ij_c_l = [model.pz(*model.pz_params(idx)).log_prob(z_ij).sum(-1) for idx in
                          range(model.params.latent_dim_c)]
            lpz_ij_c = torch.stack(lpz_ij_c_l, dim=-1)
            pc_z = F.softmax((lpc + lpz_ij_c), dim=-1) + eta
            lpc_z = torch.log(pc_z)

            pw_dist_pair = model.pw(*model.pw_params())
            lpw_i = pw_dist_pair.log_prob(w_i).sum(-1)
            lpw_j = pw_dist_pair.log_prob(w_j).sum(-1)

            # Holder q_H(z_ij | X)
            lqz_ij = log_q_holder_z(z_ij)

            # q(w_i | x_i) and q(w_j | x_j)
            lqw_i = qw_xs[mi].log_prob(w_i).sum(-1)
            lqw_j = qw_xs[mj].log_prob(w_j).sum(-1)

            # likelihood of the pairwise reconstructions
            lpx_ij_list = []
            for d_mod, px_ij in d["recon"].items():
                lpx_ij_mod = (
                    px_ij.log_prob(x[d_mod])
                    .view(*px_ij.batch_shape[:2], -1)
                    .mul(model.vaes[d_mod].llik_scaling)
                    .sum(-1)
                )
                lpx_ij_list.append(lpx_ij_mod)

            if lpx_ij_list:
                lpx_ij = torch.stack(lpx_ij_list, dim=0).sum(0)
            else:
                lpx_ij = torch.zeros_like(lpz_ij_c[..., 0])

            lw_ij = lpx_ij + beta * (
                ((lpz_ij_c + lpc - lpc_z) * pc_z).sum(-1) - lqz_ij
            ) + beta * (lpw_i + lpw_j - lqw_i - lqw_j)
            lws[(mi, mj)] = lw_ij

    return lws, torch.stack(uss), log_w_single, log_w_pairs, pairwise


def _cholderplus_iwae(model, x, K=1):
    """IWAE estimate for log p_\theta(x) for multi-modal vae -- fully vectorised
    This version is the looser bound---with the average over modalities outside the log
    """
    lws, _, log_w_single, log_w_pairs, _ = _cholderplus_log_likelihood(model, x, K, detach=False)
    return lws, log_w_single, log_w_pairs

def cholderplus_iwae(model, x, K=1):
    r"""Computes iwae estimate for log p_\theta(x) for the Holder-plus multi-modal vae.

    For each mixture component, the IWAE bound is the log-mean-exp over the K
    samples (dim 0). Components are then combined with their Holder-plus
    mixture weights outside the log.

    Args:
        model: multimodal model consumed by ``_cholderplus_iwae``.
        x: sequence of per-modality data tensors.
        K: number of importance samples.

    Returns:
        Scalar: ``sum_batch sum_key exp(log_w_key) * log_mean_exp_K(lw_key)``.

    Raises:
        ValueError: if no microbatch was produced (empty batch).
    """
    S = compute_microbatch_split(x, K)
    x_split = zip(*[_x.split(S) for _x in x])

    # Collect per-microbatch pieces (component keys fixed by the first
    # microbatch) and concatenate once at the end. Re-concatenating the
    # accumulator every iteration copies it each time.
    parts = None
    for _x in x_split:
        mb_dicts = _cholderplus_iwae(model, _x, K)
        if parts is None:
            parts = tuple({k: [] for k in d} for d in mb_dicts)
        for acc, mb in zip(parts, mb_dicts):
            for k, chunks in acc.items():
                chunks.append(mb[k])
    if parts is None:
        raise ValueError("cholderplus_iwae received an empty batch")

    lws_parts, single_parts, pair_parts = parts
    lws = {k: torch.cat(v, dim=1) for k, v in lws_parts.items()}              # batch is dim 1
    log_w_single = {k: torch.cat(v, dim=0) for k, v in single_parts.items()}  # batch is dim 0
    log_w_pairs = {k: torch.cat(v, dim=0) for k, v in pair_parts.items()}     # empty if no pairs

    # IWAE over K per key, then weight outside K.
    total = None
    for key, lw in lws.items():
        f_key = log_mean_exp(lw, dim=0)
        log_w = log_w_pairs[key] if isinstance(key, tuple) else log_w_single[key]
        contrib = log_w.exp() * f_key
        total = contrib if total is None else total + contrib

    return total.sum()


def _cholderplus_dreg(model, x, K=1):
    """Return the Holder-plus likelihood pieces for DReG on one microbatch.

    Delegates to ``_cholderplus_log_likelihood`` with ``detach=True`` and
    passes its 5-tuple ``(lws, uss, log_w_single, log_w_pairs, pairwise)``
    through unchanged.
    """
    result = _cholderplus_log_likelihood(model, x, K, detach=True)
    lws, latents, single_log_w, pair_log_w, pair_info = result  # enforce the 5-tuple
    return lws, latents, single_log_w, pair_log_w, pair_info

def cholderplus_dreg(model, x, K=1):
    """DReG surrogate loss for the Holder-plus objective, mixture weights outside K.

    For each microbatch and each component ``key``:

    - The log-weights ``lw`` are self-normalised over the K samples (dim 0)
      without gradient, giving ``sample_wt``.
    - The mixture weights ``exp(log_w_key)`` are also taken without gradient.

    The microbatch surrogate is ``sum_batch sum_key w_key * sum_K(sample_wt * lw)``.
    Gradient hooks rescale latent gradients by the same ``sample_wt``.

    Returns:
        The surrogate summed over all microbatches (``0.0`` if there are none).
    """
    def _rescale(weights):
        # hook factory: binds ``weights`` now, avoiding late-binding closures
        return lambda grad: weights.unsqueeze(-1) * grad

    split_size = compute_microbatch_split(x, K)
    microbatches = zip(*[_x.split(split_size) for _x in x])

    total = 0.0
    for batch in microbatches:
        lws, uss, log_w_single, log_w_pairs, pairwise = _cholderplus_dreg(model, batch, K)

        sample_wt, mix_wt = {}, {}
        for key, lw in lws.items():
            log_mix = log_w_pairs[key] if isinstance(key, tuple) else log_w_single[key]
            with torch.no_grad():
                sample_wt[key] = (lw - torch.logsumexp(lw, dim=0, keepdim=True)).exp()
                mix_wt[key] = log_mix.exp()

        n_mods = len(model.vaes)

        # Unimodal hooks.
        # NOTE(review): ``uss`` is a ``torch.stack`` built after the log-weights,
        # and ``uss[m]`` is a fresh view of it. Neither is on the autograd path
        # of ``lws``, so these hooks most likely never fire.
        for m in range(n_mods):
            if uss[m].requires_grad:
                uss[m].register_hook(_rescale(sample_wt[m]))

        # Pairwise hooks: shared sample z_ij plus the per-pair clones u_i / u_j.
        if pairwise:
            for (mi, mj), entry in pairwise.items():
                pair_hook = _rescale(sample_wt[(mi, mj)])
                for name in ("z_ij", "u_i", "u_j"):
                    if entry[name].requires_grad:
                        entry[name].register_hook(pair_hook)

        mb_loss = sum(
            (mix_wt[key] * (lw * sample_wt[key]).sum(0) for key, lw in lws.items()),
            0.0,
        )
        total = total + mb_loss.sum()

    return total
