"""
Likelihoods that depend on the shared posterior formulation,
which is a mixture of experts here.
"""

import torch

import evaluation
import utils
from disentanglement_vae.likelihoods.likelihoods import LikelihoodEstimator

# Short module-level alias for the log-mean-exp helper from `utils`.
lme = utils.log_mean_exp


class MoeLikelihoodEstimator(LikelihoodEstimator):
    """
    Likelihood estimator for a shared posterior that is a mixture of experts.

    The modality-specific terms are hard-wired to the modalities 'x1'
    (sigmoid reconstruction) and 'x2' (Gaussian reconstruction).
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to the base estimator."""
        super().__init__(**kwargs)

    def get_joint_likelihood(self, model, x, output):
        """
        Joint likelihood is maximized during training.

        Computes reconstruction minus (optionally beta-weighted) KL for every
        modality expert and hands the result to
        `_postprocess_joint_likelihood`.

        :param model: model exposing `modalities`, `pg` and `pg_params`.
        :param x: per-modality inputs, indexed in the order of
            `model.modalities`.
        :param output: nested dict of model outputs with the keys
            'posterior' and 'reconstruction'.
        :return: a tensor (mean over all entries) when `self.eval` is falsy,
            otherwise a Python float.
        """
        kl_g = self._kl_shared_variable_joint_case(model, output)  # M x K x N
        # `_kl_disentangled_variable` is not defined in this class; it is
        # presumably inherited from the base estimator.
        kl_z1 = self._kl_disentangled_variable(model, output, mod='x1')
        kl_z2 = self._kl_disentangled_variable(model, output, mod='x2')
        kl_z = torch.stack([kl_z1, kl_z2])  # M x K x N
        kl = kl_g + kl_z  # M x K x N
        if self.beta is not None:
            # `kl` is a fresh tensor created above, so scaling it in place
            # does not touch any of the inputs.
            kl *= self.beta
        rec = self._reconstruction_likelihood_joint_case(model, x, output)  # M x K x N
        likelihood = rec - kl  # M x K x N
        return self._postprocess_joint_likelihood(likelihood)

    @torch.no_grad()
    def get_evaluation_likelihood(self, model, x, output):
        """
        Evaluate the joint likelihood in evaluation mode, without gradients.

        :return: dict with the single key 'joint'.
        """
        # NOTE(review): the flag is never reset here, so later calls to
        # `get_joint_likelihood` on this instance stay in evaluation mode
        # unless `self.eval` is reset elsewhere -- confirm against the base
        # class / callers.
        self.eval = True
        return {'joint': self.get_joint_likelihood(model, x, output)}

    def _reconstruction_likelihood_joint_case(self, model, x, output):
        """
        Every modality expert must reconstruct every modality.

        For each expert `c`, the reconstruction terms of all target
        modalities `t` are summed; the per-expert sums are then stacked
        along a new leading dimension.

        :raises ValueError: if a modality other than 'x1' or 'x2' is found.
        """
        recs = []
        for c in model.modalities:
            inner_recs = []
            for i, t in enumerate(model.modalities):
                # output['reconstruction'] is indexed [target][expert].
                reconstruction = output['reconstruction'][t][c]
                if t == 'x1':
                    rec = self.compute_rec_sigmoid(
                        target=x[i],
                        rec_samples=reconstruction['samples'])
                elif t == 'x2':
                    # Gaussian likelihood
                    rec = self.compute_rec_gaussian(
                        x=x[i],
                        rec_dist=reconstruction['dist'])
                else:
                    raise ValueError(
                        "Unsupported modality {!r}; expected 'x1' or 'x2'."
                        .format(t))
                inner_recs.append(rec)
            # sum reconstructions for all modalities given same expert
            recs.append(torch.stack(inner_recs).sum(0))
        return torch.stack(recs)

    @staticmethod
    def _kl_shared_variable_joint_case(model, output):
        """
        KL between q(g|x_1, x_2) and p(g)

        The mixture-of-experts posterior density is evaluated, for the
        samples of every expert, as the log-mean-exp over the log-densities
        of all experts.
        """
        # The prior and the expert posteriors do not depend on which
        # expert's samples are scored, so build them once.
        prior = model.pg(*model.pg_params)
        experts = [output['posterior'][v][v]['g']['dist']
                   for v in model.modalities]

        kls = []
        for m in model.modalities:
            g = output['posterior'][m][m]['g']['samples']  # K x N x D
            log_p = prior.log_prob(g)
            # modality experts must be inside log
            log_q = lme(
                torch.stack([dist.log_prob(g) for dist in experts]),  # each K x N x D
                dim=0)
            kls.append(evaluation.reduce_kl(log_q - log_p))
        return torch.stack(kls)  # M x K x N

    def _postprocess_joint_likelihood(self, likelihood):
        """
        Reduce the M x K x N likelihood tensor to a single value.

        In training mode the plain mean over all entries is returned as a
        tensor. In evaluation mode an importance-sampling estimate is
        returned as a Python float.
        """
        if not self.eval:
            # Mean over modalities, K, and N outside logarithm
            return likelihood.mean()

        # Omit half of the importance samples
        # See:
        #   Shi et al., Appendix A
        #   Appendix C of our paper, which portrays our likelihood estimation
        # Workaround to allow compatibility with other methods
        num_experts = likelihood.size(0)
        num_samples = likelihood.size(1)
        # T = K / M, floored. Keep at least one sample per expert so that
        # K < M does not produce an empty slice and an undefined estimate.
        T = max(num_samples // num_experts, 1)
        likelihood = likelihood[:, 0:T, ...]  # M x T x N

        # Concatenate modality- and importance-sampling dimension
        likelihood = likelihood.contiguous().view(-1, likelihood.size(-1))  # M*T x N

        # 1. Mean over importance samples inside logarithm
        # 2. Mean over batch dimension outside logarithm
        return lme(likelihood).mean(-1).item()
