import torch

from evaluation import sum_over_dims


class LikelihoodEstimator:
    """ELBO estimator for a two-modality model (``x1``, ``x2``) with a shared
    latent ``g`` and per-modality latents ``z``.

    NOTE(review): :meth:`get_single_elbo` calls ``self.reconstruction``, which
    is not defined in this class -- presumably supplied by a subclass; TODO
    confirm.
    """

    def __init__(self, beta, eval=False):
        """
        :param beta: mapping with keys ``'z'`` and ``'g'``; the z-KL and g-KL
            terms are multiplied by ``beta['z']`` and ``beta['g']``.
        :param eval: stored as ``self.eval``; not read by this class itself.
        """
        # NOTE(review): ``eval`` shadows the builtin; kept for API compatibility.
        self.eval = eval
        self.beta = beta

    def get_total_elbo(self, model, x, output):
        """Sum of the ELBOs computed with each importance distribution
        q(g|x_{1:2}), q(g|x_1) and q(g|x_2)."""
        elbo_joint = self.get_single_elbo(model, x, output, g_mod='joint')
        elbo_x1 = self.get_single_elbo(model, x, output, g_mod='x1')
        elbo_x2 = self.get_single_elbo(model, x, output, g_mod='x2')
        return elbo_joint + elbo_x1 + elbo_x2

    def get_single_elbo(self, model, x, output, g_mod):
        """ELBO using importance samples from
        either q(g|x_{1:2}), q(g|x_1), or q(g|x_2).

        :param g_mod: ``'joint'``, ``'x1'`` or ``'x2'``; selects the
            importance distribution over g.
        :return: ``reconstruction - kl_z - kl_g``.
        """
        # Reconstruction factors do not depend on g, so g_mod is not passed.
        rec = self.reconstruction(x, output)
        # Both KL-regularization factors depend on the choice of q(g|.).
        kl_z = self.kl_z(output, g_mod)
        kl_g = self.kl_g(output, model, g_mod)
        return rec - kl_z - kl_g

    def kl_z(self, output, g_mod):
        """
        KL between q(z_i|x_i) and p(z_i|g), estimated at z_i ~ q(z_i|x_i),
        summed over both modalities and scaled by ``beta['z']``.

        For each modality ``m`` the posterior is read from
        ``output[m][m]['z']['posterior']`` and the prior from
        ``output[m][g_mod]['z']['prior']``.

        :param g_mod: conditioning modality for importance distribution over g
            'joint': q(g|x_{1:2})
            'x1': q(g|x_1)
            'x2': q(g|x_2)
        """
        per_modality = [
            self._compute_kl(
                samples=output[m][m]['z']['posterior']['samples'],
                q=output[m][m]['z']['posterior']['dist'],
                p=output[m][g_mod]['z']['prior']['dist'])
            for m in ('x1', 'x2')
        ]
        kl = torch.stack(per_modality).sum(0)
        # Out-of-place scaling: never mutate a tensor that is part of the graph.
        return kl * self.beta['z']

    def kl_g(self, output, model, g_mod):
        """
        KL between q(g|x) and the prior p(g) = ``model.pg(*model.pg_params)``,
        estimated at the posterior samples and scaled by ``beta['g']``.

        :param g_mod: conditioning modality for importance distribution
            'joint': q(g|x_{1:2})
            'x1': q(g|x_1)
            'x2': q(g|x_2)
        """
        posterior = output[g_mod][g_mod]['g']['posterior']
        kl = self._compute_kl(
            samples=posterior['samples'],
            q=posterior['dist'],
            p=model.pg(*model.pg_params))
        return kl * self.beta['g']

    @staticmethod
    def _compute_kl(samples, q, p):
        """Sample-based KL estimate: ``log q(samples) - log p(samples)``,
        summed over the last dimension.

        :param samples: points at which both densities are evaluated; assumed
            to be drawn from ``q`` (presumably K x N x D -- TODO confirm).
        :param q: distribution exposing ``log_prob``.
        :param p: distribution exposing ``log_prob``.
        :return: tensor with the last dimension summed out (K x N for
            K x N x D input).
        """
        log_p = p.log_prob(samples)
        log_q = q.log_prob(samples)
        return torch.sum(log_q - log_p, dim=-1)  # sum over dimensions

    @staticmethod
    def _compute_rec_gaussian(x, rec_dist):
        """
        Log-likelihood of the ground truth under the reconstruction
        distribution, summed over the last dimension.

        :param x: ground truth (presumably N x D).
        :param rec_dist: learned reconstruction distribution exposing
            ``log_prob`` (presumably batch shape K x N x D).
        :return: log-likelihood summed over the last dimension, i.e. total
            nats per sample (not per dimension); K x N for the shapes above.
        """
        log_lik = rec_dist.log_prob(x)
        return log_lik.sum(dim=-1)

    @staticmethod
    def _compute_rec_sigmoid(target, rec_samples):
        """
        Negative binary cross-entropy between sigmoid-activated
        learned tensors and the ground truth
        (as in the HMVAE and the other baselines).

        :param target: true samples; expected to match ``rec_samples`` in shape.
        :param rec_samples: reconstructed samples, already in [0, 1]
            (binary cross-entropy rejects values outside that range).
        :return: ``-BCE`` reduced by ``sum_over_dims``; higher is better,
            matching the sign convention of the other estimates.
        """
        bce = torch.nn.functional.binary_cross_entropy(
            rec_samples, target, reduction='none')
        # Negate out of place: the tensor returned by ``sum_over_dims`` may
        # alias something we must not mutate.
        return -sum_over_dims(bce)
