import utils
from methods.mm_method import MultiModalMethod
from vae.layers import ProductOfExperts
from vae.layers.distributions import get_distribution
from .estimator import PoeLikelihoodEstimator


class Model(MultiModalMethod):
    """Multimodal VAE whose shared posterior is a product of experts (PoE).

    Each modality VAE infers a unimodal expert over the shared latent space;
    the joint posterior is the product of all unimodal experts, by default
    together with a prior expert.
    """

    def __init__(self, args):
        super().__init__(args)
        self.multimodal = True

        # prepare shared posterior: no scale regularization, as such is
        # already done over unimodal posteriors
        self.qg = get_distribution(dist_type=args.stoc_dist)
        self.experts = ProductOfExperts()

    def forward(self, x, eval=False, **kwargs):
        """
        Run inference and generation for all modalities.

        :param x: modality inputs; ``x[i]`` is fed to ``self.vaes[i]``
        :param eval: if True, run the cross-modal passes and generations;
            otherwise only generate x2 from x1 for regularization. (The name
            shadows the builtin ``eval``; kept for keyword callers.)
        :param kwargs: forwarded to the unimodal bottom-up passes and to the
            product-of-experts computation
        :return: nested output mapping with 'posterior', 'prior' and
            'reconstruction' entries (possibly extended by the cross-modal
            helpers)
        """
        output, bu_tensors = self._infer_unimodal_experts(x, **kwargs)
        output = self._compute_product_wrapper(output, **kwargs)
        output = self._top_down_pass(output, bu_tensors)

        if eval:
            # NOTE(review): _crossmodal_* helpers are not defined in this
            # class; presumably inherited from MultiModalMethod
            output = self._crossmodal_passes(output, bu_tensors=bu_tensors)
            output = self._crossmodal_generations(output)
        else:
            # for regularization where we reconstruct x2 from x1
            output = self._crossmodal_generation(output,
                                                 modalities={'c': 'x1',
                                                             't': 'x2'})

        return output

    def _infer_unimodal_experts(self, x, **kwargs):
        """
        Infer marginal posteriors q(g|x_i) with each modality's bottom-up pass.

        :param x: modality inputs; ``x[i]`` is fed to ``self.vaes[i]``
        :return: ``(output, bu_tensors)`` where ``output['posterior'][m][m]``
            and ``output['prior'][m][m]`` hold what ``bottom_up_wrapper``
            returned for modality ``m``, and ``bu_tensors[i]`` holds the
            hidden states of the i-th bottom-up pass
        """
        # rec_defaultdict presumably creates nested keys on first access,
        # which the two-level assignments below rely on
        output = utils.rec_defaultdict()
        bu_tensors = []

        # infer unimodal experts
        for i, vae in enumerate(self.vaes):
            cur_bu_tensors, posterior, prior = vae.bottom_up_wrapper(
                x[i], **kwargs)
            m = self.modalities[i]
            output['posterior'][m][m] = posterior
            output['prior'][m][m] = prior
            bu_tensors.append(cur_bu_tensors)

        return output, bu_tensors

    def _compute_product_wrapper(self, output, **kwargs):
        """
        Compute the joint product of experts, then replace each unimodal
        top-level expert by its own single-expert product (which includes
        the prior expert unless ``prior_expert=False`` is passed in kwargs).

        No modality subsampling is performed: the joint product always uses
        every modality's top-level expert.

        :param output: nested output from ``_infer_unimodal_experts``; the
            last entry of ``output['posterior'][m][m]`` must expose the
            top-level expert distribution under ``'dist'``
        :return: ``output`` with ``output['posterior']['joint']['joint']``
            set and every ``output['posterior'][m][m][-1]`` replaced
        """
        # joint expert over all modalities
        experts = [output['posterior'][m][m][-1]['dist']
                   for m in self.modalities]
        # joint posterior solely exists for top-level
        output['posterior']['joint']['joint'] = self._compute_product(
            experts, **kwargs)

        # ordering sanity check: the unimodal entries are overwritten below,
        # so the joint product must have been computed first
        msg = 'Product must be computed before including prior experts in ' \
              'marginal posteriors'
        assert 'joint' in output['posterior'].keys(), msg

        # unimodal experts: reuse the top-level distributions collected above
        for m, expert in zip(self.modalities, experts):
            output['posterior'][m][m][-1] = self._compute_product(
                experts=[expert], **kwargs)

        return output

    def _compute_product(self, experts: list, prior_expert=True, **kwargs):
        """
        Compute product over given modality experts.

        :param experts: modality expert distributions; each must expose
            ``.mean`` and ``.scale`` whose first dimension is the batch size
        :param prior_expert: whether to include a prior expert built from
            ``self.pg_params``
        :param kwargs: forwarded to ``self.qg``
        :return: the output of ``self.qg`` for the product parameters
        :raises ValueError: if ``experts`` is empty while ``prior_expert`` is
            set, since the batch size to tile the prior to is then unknown
        """
        experts = list(experts)
        mus = [e.mean for e in experts]
        scales = [e.scale for e in experts]

        if prior_expert:
            if not experts:
                raise ValueError('Cannot include a prior expert without at '
                                 'least one modality expert to infer the '
                                 'batch size from')
            # batch size taken from the last expert (experts presumably
            # share it)
            n = experts[-1].mean.size(0)
            # NOTE(review): self.pg_params presumably holds the prior
            # (mean, scale) with a leading singleton batch dim -- confirm;
            # repeat(n, 1) tiles it along the first dimension
            mu, scale = list(self.pg_params)
            mus.append(mu.repeat(n, 1))
            scales.append(scale.repeat(n, 1))

        # product over experts
        mu_poe, scale_poe = self.experts(mus=mus, scales=scales)
        return self.qg([mu_poe, scale_poe], **kwargs)

    def _top_down_pass(self, output, bu_tensors):
        """
        Generate modalities from posteriors over shared latent space: first
        from each unimodal posterior, then from the joint posterior.
        """
        output = self._unimodal_top_down_pass(output, bu_tensors)

        output = self._joint_top_down_pass(
            output,
            shared_posterior=output['posterior']['joint']['joint'],
            bu_tensors=bu_tensors)

        return output

    def _unimodal_top_down_pass(self, output, bu_tensors):
        """
        Run each modality's top-down pass conditioned on samples from its own
        unimodal top-level posterior.

        Hierarchical levels returned by ``vae.top_down`` are prepended, so the
        top-level entry stays last in ``output['posterior'][m][m]``.
        """
        for i, vae in enumerate(self.vaes):
            m = self.modalities[i]

            # get conditioning samples
            samples = output['posterior'][m][m][-1]['samples']

            # propagate through top-down pass
            posterior, prior, reconstruction = vae.top_down(
                samples, bu_tensors=bu_tensors[i])

            # save values
            if posterior:
                # assume hierarchy
                output['posterior'][m][m] = posterior + output['posterior'][m][m]
                output['prior'][m][m] = prior + output['prior'][m][m]
            output['reconstruction'][m][m] = reconstruction

        return output

    def _joint_top_down_pass(self,
                             output: dict,
                             shared_posterior: dict,
                             bu_tensors: list):
        """ Pass shared posterior through every unimodal generative
        top-down pass.
        :param output: nested output mapping to fill in
        :param shared_posterior: joint posterior; its ``'samples'`` entry
            conditions every top-down pass
        :param bu_tensors: hidden states from unimodal bottom-up pass
        :return: ``output`` with ``[m]['joint']`` set for every modality ``m``
            under 'posterior', 'prior' and 'reconstruction'
        """
        # the same joint samples condition every modality
        samples = shared_posterior['samples']

        for i, vae in enumerate(self.vaes):
            # propagate through top-down pass
            posterior, prior, reconstruction = vae.top_down(
                samples, bu_tensors=bu_tensors[i])

            # save values
            m = self.modalities[i]
            # highest hierarchical level already defined (it lives under
            # output['posterior']['joint']['joint']); the None placeholder
            # keeps these lists the same length as the unimodal ones.
            # ``or []`` tolerates a falsy posterior/prior (e.g. None), as the
            # unimodal pass does.
            output['posterior'][m]['joint'] = (posterior or []) + [None]
            output['prior'][m]['joint'] = (prior or []) + [None]
            output['reconstruction'][m]['joint'] = reconstruction

        return output

    def evaluate_likelihood(self, inp, output, **kwargs):
        """
        Estimate the evaluation likelihood with a PoE likelihood estimator.

        :param inp: pair whose first element is the model input ``x``
        :param output: model output, passed on to the estimator
        :param kwargs: forwarded to the estimator
        :return: the result of
            ``PoeLikelihoodEstimator.get_evaluation_likelihood``
        """
        x, _ = inp
        estimator = PoeLikelihoodEstimator()
        return estimator.get_evaluation_likelihood(
            model=self, x=x, output=output, **kwargs)
