import torch

import utils
from methods.estimator import LikelihoodEstimator

# Short alias for utils.log_mean_exp; applied to the joint likelihood when the
# estimator runs in evaluation mode.
lme = utils.log_mean_exp


class PoeLikelihoodEstimator(LikelihoodEstimator):
    """Likelihood estimator for the PoE (presumably product-of-experts) model.

    Combines the unimodal likelihoods provided by the base class with a joint
    likelihood computed from the shared posterior and, during evaluation,
    with the cross-modal likelihoods.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base estimator."""
        super().__init__(**kwargs)

    def get_training_likelihood(self, **kwargs):
        """Collect the unimodal and joint likelihood terms.

        :param kwargs: forwarded unchanged to ``_unimodal_likelihood_wrapper``
            and ``_joint_likelihood`` (the latter expects ``model``, ``x``
            and ``output``)
        :return: dict merging the results of both helpers; the joint term is
            stored under ``'joint'``
        """
        likelihood = {}
        likelihood.update(self._unimodal_likelihood_wrapper(**kwargs))
        likelihood.update(self._joint_likelihood(**kwargs))
        return likelihood

    def get_evaluation_likelihood(self, **kwargs):
        """Collect all likelihood terms used for evaluation.

        Switches the estimator into evaluation mode, gathers the training
        terms plus the cross-modal terms and adds their total as ``'sum'``.

        :param kwargs: forwarded to ``get_training_likelihood`` and
            ``_crossmodal_likelihood``
        :return: dict with (at least) the keys ``'x1'``, ``'x2'``,
            ``'joint'``, ``'x1|x2'``, ``'x2|x1'`` and ``'sum'``
        :raises AssertionError: if ``self.regularization`` is not ``False``
        """
        # NOTE(review): the flag is never reset in this class, so the
        # estimator stays in evaluation mode after this call -- confirm that
        # callers reset it (or use a separate instance) before training.
        self.eval = True

        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if self.regularization is not False:
            raise AssertionError('Do not use regularization during evaluation.')

        likelihood = self.get_training_likelihood(**kwargs)
        likelihood.update(self._crossmodal_likelihood(
            likelihood=likelihood, **kwargs))
        likelihood['sum'] = (
            likelihood['x1']
            + likelihood['x2']
            + likelihood['joint']
            + likelihood['x1|x2']
            + likelihood['x2|x1']
        )

        return likelihood

    def _joint_likelihood(self, model, x, output):
        """ Likelihood of joint model.

        Computed as the (optionally weighted) sum of the per-modality
        reconstruction likelihoods minus the (optionally beta-scaled) sum of
        the top-level KL and the lower-level hierarchical KLs.

        :param model: model which contains prior parameters and llik-scaling
        :param x: input; indexed per modality in the order of
            ``model.modalities``
        :param output: nested dict of model outputs; the entries
            ``'posterior'``, ``'prior'`` and ``'reconstruction'`` are read
        :return: dict with the single key ``'joint'``; a Python number in
            evaluation mode, otherwise the tensor ``rec - kl``
        """
        kls = []

        # Top-level KL regularization
        p = model.pg(*model.pg_params)
        joint_posterior = output['posterior']['joint']['joint']
        kls.append(self._compute_kl(
            p=p,
            q=joint_posterior['dist'],
            samples=joint_posterior['samples']))

        # KL regularization for unimodal lower levels
        for m in model.modalities:
            level_kls = self._compute_hierarchical_kls(
                prior=output['prior'][m]['joint'][:-1],
                posterior=output['posterior'][m]['joint'][:-1])
            # The previous guard (`torch.is_tensor`) contradicted the
            # following `torch.stack`, which requires a sequence of tensors:
            # a list result was silently dropped and a tensor result raised.
            if torch.is_tensor(level_kls):
                # assumes dim 0 enumerates the hierarchical levels -- TODO
                # confirm against `_compute_hierarchical_kls`
                if level_kls.dim() > 0:
                    level_kls = level_kls.sum(0)
                kls.append(level_kls)
            elif isinstance(level_kls, (list, tuple)) and len(level_kls) > 0:
                # sum over hierarchical levels
                kls.append(torch.stack(list(level_kls)).sum(0))
            # Anything else (e.g. an empty list when there are no lower
            # levels) contributes no KL term.

        # Total KL
        kl = torch.stack(kls).sum(0)

        # Reconstruct every modality from shared posterior
        recs = [
            self._compute_reconstruction_likelihood(
                x=x[i],
                reconstruction=output['reconstruction'][m]['joint'])
            for i, m in enumerate(model.modalities)
        ]
        if self.rec_factors is not None:
            # NOTE(review): expects one factor per modality; fewer factors
            # would silently drop the remaining reconstruction terms.
            recs = [recs[i] * f for i, f in enumerate(self.rec_factors)]
        rec = torch.stack(recs).sum(0)

        # Postprocessing
        if self.beta is not None:
            kl = kl * self.beta
        lik = rec - kl
        if self.eval:
            # outside of eval-mode, this postprocessing step is done in the
            # class for the objective
            lik = lme(lik).mean(-1).item()

        return {'joint': lik}

    def crossmodal_lik_regularization(self, x, output):
        """ Regularization: Reconstruct x2 from x1 to semantically align
         top-level modality experts.

        :param x: input; ``x[1]`` is used as the reconstruction target
        :param output: nested dict of model outputs; reads
            ``output['ancestral_samples']['x2']['x1'][0]``
        :return: reconstruction likelihood, scaled by ``self.rec_factors[1]``
            when reconstruction factors are configured
        """
        rec = self._compute_reconstruction_likelihood(
            x=x[1],
            reconstruction=output['ancestral_samples']['x2']['x1'][0])

        # Explicit None/empty check instead of truthiness: works for
        # tensor/array factors and, consistent with `_joint_likelihood`,
        # also applies a factor of 0 instead of skipping it.
        factors = self.rec_factors
        if factors is not None and len(factors) > 0:
            rec = rec * factors[1]

        return rec
