"""
Likelihood terms that are independent of the parameterization of the shared posterior.
"""

import torch

import utils
from evaluation import sum_over_dims

# Module-level shorthand for utils.log_mean_exp.
# NOTE(review): not referenced by LikelihoodEstimator below; presumably used
# by other code that imports it from this module — verify before removing.
lme = utils.log_mean_exp


class LikelihoodEstimator:
    """
    Per-sample KL and reconstruction log-likelihood terms.

    Originally documented as assuming a Mixture of Experts parameterization
    of the shared posterior.
    NOTE(review): none of the methods defined here read the shared
    posterior, so that assumption is not exercised by this class itself —
    confirm against subclasses / callers.
    """

    def __init__(self, eval=False, beta=1):
        """
        :param eval: flag stored as ``self.eval``; not read by the methods
            of this class (presumably consulted by subclasses — verify).
        :param beta: value stored as ``self.beta``; not read by the methods
            of this class (presumably a KL weight — verify).
        """
        # ``eval`` shadows the builtin, but it is part of the public
        # constructor signature, so the name is kept.
        self.eval = eval
        self.beta = beta

    def _kl_disentangled_variable(self, model, output, mod):
        """
        KL between q(z_i|x_i) and p(z_i) for a single modality.

        :param model: object exposing ``pz1``/``pz1_params`` and
            ``pz2``/``pz2_params``; the prior is built as
            ``model.pzX(*model.pzX_params)``.
        :param output: nested dict; posterior samples and distribution are
            read from ``output['posterior'][mod][mod]['z']``.
        :param mod: modality key. ``'x1'`` selects prior ``pz1``; ANY other
            value selects ``pz2``.
        :return: ``log q(z) - log p(z)`` summed over the last dimension.
        """
        kl = self.compute_kl(
            samples=output['posterior'][mod][mod]['z']['samples'],
            q=output['posterior'][mod][mod]['z']['dist'],
            # Anything other than 'x1' silently falls through to pz2.
            p=model.pz1(*model.pz1_params) if mod == 'x1'
            else model.pz2(*model.pz2_params))
        return kl

    @staticmethod
    def compute_kl(samples, q, p):
        """
        Pointwise KL estimate ``log q(samples) - log p(samples)``.

        When ``samples`` are drawn from ``q`` this is a Monte Carlo estimate
        of KL(q || p).

        :param samples: points at which both densities are evaluated
            (the original comment indicates K x N x D — confirm).
        :param q: distribution exposing ``log_prob`` (posterior).
        :param p: distribution exposing ``log_prob`` (prior).
        :return: the log-density difference summed over the last dimension
            (K x N for K x N x D input).
        """
        log_p = p.log_prob(samples)
        log_q = q.log_prob(samples)
        kl = log_q - log_p  # K x N x D
        return torch.sum(kl, dim=-1)  # sum over dimensions -> K x N

    @staticmethod
    def compute_rec_gaussian(x, rec_dist):
        """
        Log-likelihood of the ground truth under the reconstruction distribution.

        :param x: ground truth, N x D (broadcast against ``rec_dist``).
        :param rec_dist: learned reconstruction distribution exposing
            ``log_prob`` (presumably Gaussian), batch shape K x N x D.
        :return: log-likelihood of ``x`` summed over the last dimension,
            i.e. nats per sample (K x N) — not per dimension.
        """
        # Maximize likelihood of true samples under the learned distribution.
        r = rec_dist.log_prob(x)  # K x N x D
        return r.sum(dim=-1)  # K x N

    @staticmethod
    def compute_rec_sigmoid(target, rec_samples):
        """
        Negative binary cross-entropy between reconstructions and ground truth.

        Maximizing the returned value minimizes the BCE between the
        sigmoid-activated learned tensors and the ground truth (as in the
        HMVAE and the other baselines).

        :param target: true samples.
        :param rec_samples: reconstructed samples; binary cross-entropy
            requires these to already lie in [0, 1] (e.g. after a sigmoid).
        :return: ``-BCE`` reduced with ``sum_over_dims``; higher is better.
        """
        # Functional form is exactly what nn.BCELoss(reduction='none') calls,
        # without constructing a module on every invocation.
        r = torch.nn.functional.binary_cross_entropy(
            rec_samples, target, reduction='none')
        r = sum_over_dims(r)
        # Negate (out of place, so the helper's returned tensor is never
        # mutated) to be compliant with the other estimates: higher is better.
        return -r
