import logging
from typing import Optional, Dict

import torch

import evaluation
import utils

# Logger retrieved under the name 'custom' (any module calling
# logging.getLogger('custom') shares this same logger object).
logger = logging.getLogger('custom')
# Shorthand alias for utils.log_mean_exp; used by the estimators below to
# aggregate likelihood estimates in evaluation mode.
lme = utils.log_mean_exp


class BaseEstimator:
    """ Estimates likelihoods given one or two modalities.

    Provides the shared building blocks (reconstruction log-likelihoods and
    hierarchical KL terms) used by concrete estimators.
    """

    def __init__(self,
                 regularization: Optional[dict] = None):
        """
        :param regularization: regularization is only enabled when this
        argument is a non-empty dict (None or {} disable it). Recognized keys:
            - 'beta': scale factor applied to the KL-regularization terms
            - 'rec_factors': sequence with one factor per modality, multiplied
              on the reconstruction likelihoods
        Keys that are missing are stored as None.
        """
        self.eval = False  # whether to compute evaluation likelihood
        if regularization:
            assert isinstance(regularization, dict)
            self.regularization = True
            self.beta = regularization.get('beta')
            self.rec_factors = regularization.get('rec_factors')
        else:
            self.regularization = False
            self.beta = None
            self.rec_factors = None

    @staticmethod
    def _compute_reconstruction_likelihood(x, reconstruction):
        """ Log-likelihood of the ground truth under the reconstruction.

        :param x: ground truth
        :param reconstruction: dict with key 'type' and, depending on the type,
            either 'dist' (a torch distribution: 'normal', 'laplace',
            'categorical', 'bernoulli') or 'samples' (an iterable of
            probability tensors: 'sigmoid')
        :return: log likelihood of true samples under reconstruction
            distribution (higher is better)
        :raises ValueError: if reconstruction['type'] is not supported
        """
        rtype = reconstruction['type']

        if rtype in ('normal', 'laplace'):
            # maximize likelihood of true samples under a Gaussian or Laplace
            # distribution; per-element log-probs are summed over dimensions
            r = evaluation.sum_over_dims(reconstruction['dist'].log_prob(x))

        elif rtype in ('categorical', 'bernoulli'):
            # maximize likelihood of true samples under a Categorical or
            # Bernoulli distribution.
            # NOTE(review): unlike 'normal'/'laplace', no summation over
            # dimensions happens here -- confirm the distribution already
            # yields one log-prob per sample (e.g. wrapped in Independent).
            r = reconstruction['dist'].log_prob(x)

        elif rtype == 'sigmoid':
            # Binary cross entropy between each reconstruction sample and the
            # ground truth. Samples are processed one at a time because
            # BCELoss cannot broadcast over the extra sample dimension.
            bce = torch.nn.BCELoss(reduction='none')
            r = torch.stack([bce(s, x) for s in reconstruction['samples']])
            # sum over dimensions, then negate so that, like the other
            # estimates, higher is better
            r = -evaluation.sum_over_dims(r)

        else:
            raise ValueError(f'"{rtype}" is illegal reconstruction type.')

        return r

    def _compute_hierarchical_kls(self, prior, posterior, model=None):
        """ Computes log q/p for every hierarchical latent level.

        :param prior: per-level dicts holding a 'dist'; an entry of None means
            the unconditional prior model.pg(*model.pg_params) is used instead
        :param posterior: per-level dicts holding 'dist' and 'samples'
        :param model: only necessary when computing KL divergence with
            unconditional prior
        :return: list of KL-divergences between posteriors and priors, one per
            hierarchical latent level, or None if there are no levels
        """
        kls = []

        # zip stops at the shorter of the two sequences
        for cur_post, cur_prior in zip(posterior, prior):
            if cur_prior is None:
                # unconditional prior
                assert model is not None
                p = model.pg(*model.pg_params)
            else:
                p = cur_prior['dist']
            kls.append(self._compute_kl(p=p,
                                        q=cur_post['dist'],
                                        samples=cur_post['samples']))

        # no hierarchical latent levels -> None
        return kls or None

    @staticmethod
    def _compute_kl(p, q, samples, closed_form=True):
        """ KL(q || p), reduced via evaluation.reduce_kl.

        The closed form is used when requested and the exact pair
        (type(q), type(p)) is registered with torch.distributions; otherwise
        a Monte-Carlo estimate based on `samples` is computed.
        """
        # `and` short-circuits: the (private) registry is only consulted when
        # a closed form is requested. Exact-type lookup is deliberately
        # conservative; unregistered pairs fall back to the MC estimate.
        if closed_form and \
                (type(q), type(p)) in torch.distributions.kl._KL_REGISTRY:
            kl = torch.distributions.kl_divergence(q, p)
            if kl.dim() == 2:
                # add importance sampling dimension
                # (assumes dim 0 of `samples` indexes importance samples)
                kl = kl.unsqueeze(0).repeat(samples.size(0), 1, 1)
        else:
            kl = evaluation.get_mc_estimate(p=p, q=q, samples=samples)
        return evaluation.reduce_kl(kl)


class LikelihoodEstimator(BaseEstimator):
    """ Estimates unimodal likelihoods p(x_m) and, during evaluation,
    crossmodal likelihoods p(x_t | x_c) for a two-modality setting. """

    def _unimodal_likelihood_wrapper(self, model, output, x):
        """ Estimates the unimodal likelihood of every modality of `model`.

        :param model: model exposing `modalities` (and, if multimodal,
            `n_modalities`)
        :param output: nested dict; output[key][m][m] is read for the keys
            'posterior', 'prior' and 'reconstruction'
        :param x: per-modality inputs, indexed in the order of model.modalities
        :return: dict mapping each modality name to its likelihood estimate
        """
        if hasattr(model, 'n_modalities'):
            # assume multimodal model
            assert model.n_modalities == 2

        likelihood = {}
        for i, m in enumerate(model.modalities):
            rec_factor = self.rec_factors[i] if self.rec_factors else None
            likelihood[m] = self._unimodal_likelihood(
                model=model,
                x=x[i],
                distributions={
                    'posterior': output['posterior'][m][m],
                    'prior': output['prior'][m][m],
                    'reconstruction': output['reconstruction'][m][m]},
                rec_factor=rec_factor)

        return likelihood

    def _unimodal_likelihood(self,
                             model,
                             x: torch.Tensor,
                             distributions: dict,
                             rec_factor: Optional[int] = None):
        """ Estimates unimodal likelihood p(x_m).

        Regularization (beta on the KL term, rec_factor on the reconstruction
        term) is only applied outside eval mode, so that evaluation reports
        the unregularized estimate.

        :param model: model which contains prior parameters
        :param x: input
        :param distributions: posterior, prior and reconstruction
        :param rec_factor: factor with which to scale reconstruction loss
        :return: per-sample tensor (rec - kl) outside eval mode; in eval mode
            a Python float obtained via lme(...).mean(-1).item()
        """
        assert torch.is_tensor(x), 'Pass data for a single modality.'

        posterior = distributions['posterior']
        prior = distributions['prior']
        reconstruction = distributions['reconstruction']

        # kl regularization
        kls = self._compute_hierarchical_kls(prior=prior,
                                             posterior=posterior,
                                             model=model)
        if kls:
            # sum over hierarchical levels, because KL-terms are factorized
            # in log-space
            kl = torch.stack(kls).sum(0)
            if self.beta is not None and not self.eval:
                # out-of-place to avoid in-place ops on autograd tensors
                kl = kl * self.beta
        else:
            # no latent levels: there is no KL term to subtract
            kl = 0.

        # reconstruction
        rec = self._compute_reconstruction_likelihood(x, reconstruction)
        if rec_factor is not None and not self.eval:
            rec = rec * rec_factor

        # postprocessing
        likelihood = rec - kl
        if self.eval:
            # outside of eval-mode, this postprocessing step is done in class
            # for the objective
            likelihood = lme(likelihood).mean(-1).item()

        return likelihood

    def _crossmodal_likelihood(self, model, likelihood, **kwargs):
        """ Adds the crossmodal estimates 'x1|x2' and 'x2|x1' to `likelihood`.

        Evaluation only. `likelihood` must already contain float entries for
        'x1' and 'x2'; remaining keyword arguments (x, output) are forwarded
        to _update_crossmodal_likelihood.
        """
        msg = 'The code solely supports a setting with two modalities.'
        assert list(model.modalities) == ['x1', 'x2'], msg
        assert self.eval, 'Only use this method during evaluation'
        for xt, xc in (('x1', 'x2'), ('x2', 'x1')):
            likelihood = self._update_crossmodal_likelihood(
                xt=xt, xc=xc, likelihood=likelihood, **kwargs)
        return likelihood

    def _update_crossmodal_likelihood(self,
                                      xt: str, xc: str,
                                      likelihood: Dict[str, float],
                                      x, output):
        """ Stores the estimate for p(xt | xc) in `likelihood`.

        :param xt: target modality
        :param xc: conditioning modality
        :param likelihood: evaluation-mode estimates; likelihood[xc] must be a
            float (the marginal estimate of the conditioning modality)
        :param x: per-modality inputs; index 0 for 'x1', 1 for 'x2'
        :param output: model output; output['ancestral_samples'][xt][xc][0]
            is used as the reconstruction of xt given xc
        :return: `likelihood`, mutated in place with key f'{xt}|{xc}' set to
            joint - marginal
        """
        assert xt in ['x1', 'x2'], f'"{xt}" is not a legal modality.'
        marginal = likelihood[xc]
        assert isinstance(marginal, float)

        i = 0 if xt == 'x1' else 1
        rec = self._compute_reconstruction_likelihood(
            x=x[i],
            reconstruction=output['ancestral_samples'][xt][xc][0])

        # NOTE(review): `marginal` is a scalar, so if lme is a log-mean-exp
        # over the sample dimension it passes straight through and cancels
        # in `joint - marginal` below -- confirm this is intended.
        joint = lme(marginal + rec).mean(-1).item()
        assert isinstance(joint, float)

        likelihood[f'{xt}|{xc}'] = joint - marginal
        return likelihood
