import torch

import utils
from methods.multimodal_vae_poe.estimator import PoeLikelihoodEstimator
from methods.objective import Objective

# Short module-level alias for ``utils.log_mean_exp``; used below to reduce
# per-sample likelihood terms when reporting diagnostics.
lme = utils.log_mean_exp


class PoeObjective(Objective):
    """Training objective for a Product-of-Experts (PoE) multimodal VAE.

    Likelihood terms are produced by ``PoeLikelihoodEstimator``. The
    ``'x1'``, ``'x2'`` and ``'joint'`` terms are summed and turned into a
    loss via ``_get_loss_from_likelihood`` (inherited from ``Objective``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored as a class, not an instance: a fresh estimator is built on
        # every call so each call can pass its own regularization settings.
        self.likelihood_estimator = PoeLikelihoodEstimator

    @staticmethod
    def _sum_likelihood_terms(lik):
        """Return ``lik['x1'] + lik['x2'] + lik['joint']``.

        Shared by training and diagnostics so both always combine the same
        likelihood terms.
        """
        return lik['x1'] + lik['x2'] + lik['joint']

    def _compute_nelbo(self, model, x, output, beta):
        """Compute the NELBO for a Product of Experts VAE with subsampling.

        Args:
            model: The VAE being trained; forwarded to the estimator.
            x: The input batch; forwarded to the estimator.
            output: The model's forward output for ``x``; forwarded to the
                estimator.
            beta: Passed to the estimator as ``regularization['beta']``.

        Returns:
            The loss produced by ``_get_loss_from_likelihood`` from the sum
            of the ``'x1'``, ``'x2'`` and ``'joint'`` likelihood terms.
        """
        # NOTE(review): ``self.rec_factors`` is not set in this class;
        # presumably it is provided by ``Objective`` — confirm.
        estimator = self.likelihood_estimator(
            regularization={'beta': beta,
                            'rec_factors': self.rec_factors})
        lik = estimator.get_training_likelihood(x=x, model=model, output=output)
        return self._get_loss_from_likelihood(self._sum_likelihood_terms(lik))

    @torch.no_grad()
    def _compute_diagnostics(self, loss, model, x, output, beta):
        """Compute evaluation diagnostics without training-time weighting.

        The estimator is built without regularization settings (no ``beta``
        or ``rec_factors``). The computation runs under ``torch.no_grad()``
        with the model temporarily switched to eval mode. The model's
        previous train/eval mode is restored before returning, even if an
        exception is raised.

        Args:
            loss: Not used here; kept for signature compatibility with
                ``Objective``.
            model: The VAE; forwarded to the estimator.
            x: The input batch; forwarded to the estimator.
            output: The model's forward output for ``x``; forwarded to the
                estimator.
            beta: Not used here (diagnostics are unweighted); kept for
                signature compatibility.

        Returns:
            A ``utils.rec_defaultdict`` whose ``'loss'`` entry holds:

            - ``'lik_<key>'``: ``lme(v).mean(-1).item()`` for each likelihood
              term ``v``.
            - ``'total_unweighted'``: the unweighted loss as a Python float.
        """
        was_training = model.training
        model.eval()
        try:
            estimator = self.likelihood_estimator()
            lik = estimator.get_training_likelihood(
                model=model, x=x, output=output)
            unweighted_loss = self._get_loss_from_likelihood(
                self._sum_likelihood_terms(lik))

            diagnostics = utils.rec_defaultdict()
            diagnostics['loss'].update({
                f'lik_{k}': lme(v).mean(-1).item() for k, v in lik.items()})
            diagnostics['loss']['total_unweighted'] = unweighted_loss.item()
        finally:
            # Restore the caller's mode. Leaving the model in eval mode would
            # silently change dropout/batch-norm behaviour for any training
            # steps that follow.
            model.train(was_training)
        return diagnostics
