import torch

import evaluation
import utils
from methods.estimator import LikelihoodEstimator

# Short alias for the project's log-mean-exp helper. It is used below wherever
# an average has to be taken inside the logarithm (over importance samples
# and over modality experts).
lme = utils.log_mean_exp


class MoeLikelihoodEstimator(LikelihoodEstimator):
    """ Likelihood estimator for a mixture-of-experts (MoE) posterior.

    Each modality acts as one expert of the mixture. Tensors carry the
    shape ``M x K x N`` (modality experts x importance samples x batch)
    until they are reduced.

    Besides the methods defined here, the class reads ``self.beta``,
    ``self.rec_factors`` and ``self.eval`` and calls
    ``_unimodal_likelihood_wrapper``, ``_crossmodal_likelihood``,
    ``_compute_hierarchical_kls`` and ``_compute_reconstruction_likelihood``,
    none of which are defined in this class (presumably provided by
    ``LikelihoodEstimator``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_training_likelihood(self, **kwargs):
        """ Joint likelihood used for training.
        :param kwargs: forwarded to ``_joint_likelihood``
            (``model``, ``x``, ``output``)
        :return: dict with the single key ``'joint'``
        """
        return {'joint': self._joint_likelihood(**kwargs)}

    def get_evaluation_likelihood(self, **kwargs):
        """ Joint, unimodal and cross-modal likelihoods plus their sum.
        :param kwargs: forwarded unchanged to the joint, unimodal and
            cross-modal estimators
        :return: dict of likelihoods; ``'sum'`` adds up the entries
            ``'x1'``, ``'x2'``, ``'joint'``, ``'x1|x2'`` and ``'x2|x1'``,
            so exactly these keys must have been produced
        """
        # NOTE(review): the flag is switched on here and never reset inside
        # this class, so later calls on the same instance keep using the
        # evaluation branch of `_joint_likelihood` unless it is reset
        # elsewhere - confirm against the base class / callers.
        self.eval = True

        likelihood = {}
        likelihood.update(self.get_training_likelihood(**kwargs))
        likelihood.update(self._unimodal_likelihood_wrapper(**kwargs))
        # The cross-modal estimator receives the values gathered so far.
        likelihood.update(self._crossmodal_likelihood(
            likelihood=likelihood, **kwargs))

        likelihood['sum'] = (likelihood['x1']
                             + likelihood['x2']
                             + likelihood['joint']
                             + likelihood['x1|x2']
                             + likelihood['x2|x1'])
        return likelihood

    def _joint_likelihood(self, model, x, output):
        """ Likelihood of joint model.
        :param model: model which contains prior parameters and llik-scaling
        :param x: input, indexed per modality
        :param output: model output holding the ``'prior'``,
            ``'posterior'`` and ``'reconstruction'`` entries
        :return: in evaluation mode a Python float; otherwise a tensor
            with the modality-expert dimension averaged out (K x N)
        """
        kl = self._joint_kl(model, output)  # M x K x N
        if self.beta is not None:
            # Out-of-place on purpose: do not mutate the tensor handed back
            # by `_joint_kl`.
            kl = kl * self.beta
        rec = self._joint_reconstruction(model, x, output)  # M x K x N

        likelihood = rec - kl  # M x K x N

        if not self.eval:
            # Mean over modalities outside logarithm.
            # Other postprocessing steps are done in class for the objective
            return likelihood.mean(0)

        # Omit part of the importance samples
        # See:
        #   Shi et al., Appendix A
        #   Appendix C of our paper, which portrays our likelihood estimation
        # Workaround to allow compatibility with other methods
        num_experts = likelihood.size(0)
        num_samples = likelihood.size(1)
        T = num_samples // num_experts  # T = K / M
        likelihood = likelihood[:, :T, ...]  # M x T x N

        # Concatenate modality- and importance-sampling dimension
        likelihood = likelihood.contiguous().view(
            -1, likelihood.size(-1))  # M*T x N

        # 1. Mean over importance samples inside logarithm
        # 2. Mean over batch dimension outside logarithm
        return lme(likelihood).mean(-1).item()

    def _joint_kl(self, *args, **kwargs):
        """ Computes log q/p
        :param args: forwarded to ``_top_kl`` and ``_lower_kls``
        :param kwargs: forwarded to ``_top_kl`` and ``_lower_kls``
        :return: shape M x K x N
        """
        top_kl = self._top_kl(*args, **kwargs)
        lower_kls = self._lower_kls(*args, **kwargs)

        # `_lower_kls` returns an empty list when there is no lower-level
        # term; only a tensor can be added to the top-level KL.
        if torch.is_tensor(lower_kls):
            # Add top-level and lower-level KLs for each modality expert
            return top_kl + lower_kls
        return top_kl

    def _lower_kls(self, model, output):
        """ Lower-level KL-regularization
        :param model: model providing the iterable ``modalities``
        :param output: model output with ``'prior'`` and ``'posterior'``
            entries indexed as ``[target modality][modality expert]``
        :return: stacked tensor (M x K x N), or an empty list if no
            lower-level KL term was produced
        """
        # Sum all terms corresponding to one modality expert
        lower_kls = []
        for expert in model.modalities:  # modality expert
            per_target = []
            for target in model.modalities:  # target modality
                # The last hierarchical level is excluded here; it is
                # handled by `_top_kl`.
                kl = self._compute_hierarchical_kls(
                    prior=output['prior'][target][expert][:-1],
                    posterior=output['posterior'][target][expert][:-1])
                if kl is not None:
                    # sum over hierarchical levels
                    per_target.append(torch.stack(kl).sum(0))
            # sum over target modalities
            if per_target:
                lower_kls.append(torch.stack(per_target).sum(0))  # K x N
        if lower_kls:
            lower_kls = torch.stack(lower_kls)  # M x K x N
        return lower_kls

    @staticmethod
    def _top_kl(model, output):
        """
        Compute top-level KL separately, because the mixture formulation
        complicates things.
        :param model: model providing ``modalities``, ``pg`` and
            ``pg_params``
        :param output: model output; the last level of
            ``output['posterior'][m][m]`` must hold ``'samples'`` and
            ``'dist'``
        :return: shape M x K x N
        """
        # Neither the prior nor the set of expert posteriors depends on
        # the expert being evaluated, so build them once.
        prior = model.pg(*model.pg_params)
        posteriors = [output['posterior'][v][v][-1]['dist']
                      for v in model.modalities]

        top_kl = []
        for m in model.modalities:
            g = output['posterior'][m][m][-1]['samples']  # K x N x D
            log_p = prior.log_prob(g)
            # modality experts must be inside log
            log_q = lme(
                torch.stack([dist.log_prob(g) for dist in posteriors]),
                dim=0)  # K x N x D
            top_kl.append(evaluation.reduce_kl(log_q - log_p))
        return torch.stack(top_kl)  # M x K x N

    def _joint_reconstruction(self, model, x, output):
        """ Reconstruction likelihood of every modality under every expert.
        :param model: model providing the iterable ``modalities``
        :param x: input, indexed by modality position
        :param output: model output with ``'reconstruction'`` indexed as
            ``[target modality][modality expert]``
        :return: shape M x K x N
        """
        # Every modality expert must reconstruct every modality
        recs = []
        for expert in model.modalities:
            inner_recs = []
            for i, target in enumerate(model.modalities):
                rec = self._compute_reconstruction_likelihood(
                    x=x[i],
                    reconstruction=output['reconstruction'][target][expert])
                rec_factor = self.rec_factors[i] if self.rec_factors else None
                # NOTE(review): the truthiness test also skips a factor of
                # 0, i.e. such a modality is left unscaled rather than
                # zeroed out - confirm that this is intended.
                if rec_factor:
                    # Out-of-place on purpose: do not mutate the tensor
                    # handed back by the reconstruction helper.
                    rec = rec * rec_factor
                inner_recs.append(rec)
            # sum reconstructions for all modalities given same expert
            recs.append(torch.stack(inner_recs).sum(0))
        return torch.stack(recs)