import torch
from torch import distributions as D


# comparison with analytical posterior
def kldFull(mean_lin: torch.Tensor, Sigma_lin: torch.Tensor,
            mean_prop: torch.Tensor, Sigma_prop: torch.Tensor) -> torch.Tensor:
    ''' looped version for KLD between two MVNs with non-diagonal variance

    For every batch row i, the dimensions where mean_lin[i] is exactly zero
    are dropped from both distributions, and KL(lin_i || prop_i) is computed
    on the remaining sub-vector / sub-matrix.

    Args:
        mean_lin: means of the first distribution, indexed as [batch, dim]
        Sigma_lin: covariances of the first distribution, indexed as [batch, dim, dim]
        mean_prop: means of the second distribution, indexed as [batch, dim]
        Sigma_prop: covariances of the second distribution, indexed as [batch, dim, dim]

    Returns:
        tensor with one KL value per batch row, allocated with the dtype and
        device of mean_prop. Rows without any non-zero entry in mean_lin
        are reported as 0.
    '''
    mask = (mean_lin != 0.)
    b = mean_prop.shape[0]
    # allocate next to the inputs so float64 / non-CPU inputs are not
    # silently cast to a float32 CPU result
    losses = mean_prop.new_zeros(b)
    for i in range(b):
        mask_i = mask[i]
        if not mask_i.any():
            # nothing left to compare after masking; keep the KL at 0
            # instead of building a zero-dimensional MultivariateNormal
            continue
        # keep only the active rows, then the active columns
        Sigma_lin_i = Sigma_lin[i, mask_i][..., mask_i]
        Sigma_prop_i = Sigma_prop[i, mask_i][..., mask_i]
        post_lin = D.MultivariateNormal(mean_lin[i, mask_i], Sigma_lin_i)
        post_prop = D.MultivariateNormal(mean_prop[i, mask_i], Sigma_prop_i)
        losses[i] = D.kl.kl_divergence(post_lin, post_prop)
    return losses


def kldMarginal(mean_lin: torch.Tensor, var_lin: torch.Tensor,
                mean_prop: torch.Tensor, var_prop: torch.Tensor,
                eps: float = 1e-8) -> torch.Tensor:
    ''' element-wise KLD between two diagonal Gaussians, summed over the last axis

    Computes KL(N(mean_lin, var_lin) || N(mean_prop, var_prop)) per dimension.
    Dimensions where mean_lin is exactly zero are multiplied by 0 before the
    sum. eps is added to both variances before dividing and taking logs.
    '''
    active = mean_lin.ne(0).float()
    v_prop = var_prop + eps
    v_lin = var_lin + eps
    # squared mean difference, scaled by the variance of the second distribution
    sq_dist = torch.square(mean_lin - mean_prop) / v_prop
    var_ratio = v_lin / v_prop
    log_ratio = torch.log(v_prop) - torch.log(v_lin)
    per_dim = 0.5 * (sq_dist + var_ratio + log_ratio - 1.) * active
    return torch.sum(per_dim, dim=-1)


def compareWithAnalytical(loc_ana: dict[str, torch.Tensor], sigma_error: torch.Tensor,
                          loc_prop: torch.Tensor, scale_prop: torch.Tensor,
                          marginal: bool = True) -> torch.Tensor:
    ''' KLD between the analytical posterior and the proposed posterior

    Args:
        loc_ana: analytical solution; read keys are 'mu' (mean) and 'Sigma'
            (covariance before scaling by the error variance)
        sigma_error: error scale; its square multiplies the analytical
            covariance (assumed to be a standard deviation with one value
            per batch row - TODO confirm against caller)
        loc_prop: proposed locations; the last entry along the final axis is
            dropped (presumably a slot that has no analytical counterpart -
            TODO confirm)
        scale_prop: proposed scales, squared into variances; last entry
            dropped like loc_prop
        marginal: if True compare diagonal variances only (kldMarginal),
            otherwise compare full covariance matrices (kldFull)

    Returns:
        the KL divergence as returned by kldMarginal / kldFull
    '''
    # prepare proposed posterior
    mean_prop = loc_prop[..., :-1]
    var_prop = scale_prop[..., :-1].square()
    # prepare analytical solution
    mean_lin, Sigma_lin = loc_ana['mu'], loc_ana['Sigma']
    # both branches must scale the covariance by the same factor: the
    # error *variance*, i.e. sigma_error squared
    error_var = sigma_error.square()
    # calculate KL Divergence
    if marginal:
        var_lin = torch.diagonal(Sigma_lin, dim1=-2, dim2=-1) * error_var.view(-1, 1)
        return kldMarginal(mean_lin, var_lin, mean_prop, var_prop)
    Sigma_prop = torch.diag_embed(var_prop)
    Sigma_lin = Sigma_lin * error_var.view(-1, 1, 1)
    return kldFull(mean_lin, Sigma_lin, mean_prop, Sigma_prop)