import torch

import utils
from methods.multimodal_vae_moe.estimator import MoeLikelihoodEstimator
from methods.objective import Objective

# Short alias for utils.log_mean_exp; used by MoeObjective._compute_diagnostics.
lme = utils.log_mean_exp


class MoeObjective(Objective):
    """Objective that scores model outputs with ``MoeLikelihoodEstimator``.

    The joint likelihood produced by the estimator is turned into a scalar
    loss by the base class's ``_get_loss_from_likelihood``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep the estimator *class*, not an instance: each call below builds
        # its own estimator with call-specific regularization settings.
        self.likelihood_estimator = MoeLikelihoodEstimator

    def _compute_nelbo(self, model, x, output, beta):
        """Return the training loss for one batch.

        Args:
            model: The model that produced ``output``.
            x: The input batch passed to the estimator.
            output: The model's forward output for ``x``.
            beta: Regularization weight forwarded to the estimator.

        Returns:
            The loss computed from the estimator's ``'joint'`` likelihood.
        """
        estimator = self.likelihood_estimator(
            regularization={'beta': beta,
                            # presumably populated by the base Objective —
                            # confirm
                            'rec_factors': self.rec_factors})
        likelihood = estimator.get_training_likelihood(
            x=x, model=model, output=output)
        return self._get_loss_from_likelihood(likelihood['joint'])

    @torch.no_grad()
    def _compute_diagnostics(self, loss, model, x, output, beta):
        """
        Compute loss without weighting factors (such as beta) and with
        non-changing weights (as opposed to training).

        The estimator is built without a ``regularization`` argument, so
        whatever defaults ``MoeLikelihoodEstimator`` uses apply (presumably
        unweighted — confirm in the estimator). ``loss`` and ``beta`` are
        accepted for interface compatibility but are not used here.

        The model is switched to eval mode for the computation. Its previous
        train/eval mode is restored afterwards, even if an exception is
        raised, so calling this mid-training does not leave dropout or
        batch-norm layers in eval mode for later training steps.

        Returns:
            A ``utils.rec_defaultdict`` whose ``'loss'`` entry holds
            ``'lik_joint'`` and ``'total_unweighted'`` as Python scalars.
        """
        was_training = model.training
        model.eval()
        try:
            diagnostics = utils.rec_defaultdict()
            estimator = self.likelihood_estimator()
            lik = estimator.get_training_likelihood(
                model=model, x=x, output=output)
            unweighted_loss = self._get_loss_from_likelihood(lik['joint'])
            diagnostics['loss'].update({
                # log-mean-exp, then average over the last dim; exact tensor
                # layout of lik['joint'] is defined by the estimator — confirm
                'lik_joint': lme(lik['joint']).mean(-1).item(),
                'total_unweighted': unweighted_loss.item()})
        finally:
            # Restore whatever mode the caller had the model in.
            model.train(was_training)
        return diagnostics
