import logging

import torch

import utils

# Module logger, shared under the 'custom' name.
logger = logging.getLogger(name='custom')
# Short alias for utils.log_mean_exp (not used by the code in this chunk;
# presumably referenced by other modules that import it from here).
lme = utils.log_mean_exp


class Objective:
    """Base training objective: negative ELBO plus optional cross-modal
    regularization.

    Subclasses implement ``_compute_nelbo`` and ``_compute_diagnostics``.
    When ``crossmodal_regularization > 0``, ``self.likelihood_estimator``
    must be assigned a callable that accepts a ``regularization`` keyword
    argument and returns an object exposing
    ``crossmodal_lik_regularization(x, output)``.
    """

    def __init__(self, args):
        """
        :param args: object accepted by ``vars()`` (e.g. an
            ``argparse.Namespace``). Reads ``crossmodal_regularization``
            (default 0.0), ``x1_rec_factor`` and ``x2_rec_factor``
            (default None).
        """
        options = vars(args)

        # regularization weight; a value <= 0 disables the cross-modal term
        self.crossmodal_regularization = options.get('crossmodal_regularization', 0.0)

        # scaling of reconstruction likelihoods for x1 and x2
        self.rec_factors = [options.get('x1_rec_factor'),
                            options.get('x2_rec_factor')]

        # optimization: must be assigned before _compute_regularization runs
        # with a positive crossmodal_regularization
        self.likelihood_estimator = None

    def __call__(self, model, data, beta=1):
        """Compute the loss and diagnostics for one batch.

        :param model: forwarded to ``_compute_loss`` / ``_compute_diagnostics``
        :param data: dict with ``'inp'`` (a pair whose first element is used
            as ``x``; the second is ignored) and ``'output'``
        :param beta: weight forwarded to the NELBO and regularization terms
        :return: ``(loss, diagnostics)``
        """
        diagnostics = {}
        x, _ = data['inp']
        output = data['output']

        loss, cur_diag = self._compute_loss(model, x, output, beta)
        diagnostics = utils.update(diagnostics, cur_diag)
        cur_diag = self._compute_diagnostics(loss, model, x, output, beta)
        diagnostics = utils.update(diagnostics, cur_diag)

        return loss, diagnostics

    def _compute_loss(self, model, x, output, beta):
        """Sum the negative ELBO and the optional regularization loss.

        :return: ``(loss, diagnostics)``, where ``loss`` is a 0-dim tensor.
            ``diagnostics['loss']`` holds ``'total'`` (the NELBO term only),
            ``'beta'``, and ``'reg_loss'`` when regularization is active.
        :raises AssertionError: if either loss term is not a scalar
        """
        diagnostics = utils.rec_defaultdict()
        loss = []
        msg = 'loss must be scalar'

        cur_loss = self._compute_nelbo(model, x, output, beta)
        assert len(cur_loss.size()) == 0, msg
        # NOTE: 'total' records only the NELBO term; the regularization term
        # is logged separately under 'reg_loss'.
        diagnostics['loss'] = {'total': cur_loss.item(), 'beta': beta}
        loss.append(cur_loss)

        cur_loss = self._compute_regularization(x, output, beta)
        if torch.is_tensor(cur_loss):
            # Check the shape before .item(): .item() on a multi-element
            # tensor would otherwise fail first with a less helpful error.
            assert len(cur_loss.size()) == 0, msg
            diagnostics['loss']['reg_loss'] = cur_loss.item()
            loss.append(cur_loss)

        loss = torch.stack(loss).sum(0)

        return loss, diagnostics

    def _compute_nelbo(self, *args, **kwargs):
        """Return the negative ELBO as a 0-dim tensor (subclass hook)."""
        raise NotImplementedError('Define in subclass.')

    def _compute_diagnostics(self, *args, **kwargs):
        """Return a diagnostics mapping for the batch (subclass hook)."""
        raise NotImplementedError('Define in subclass.')

    @staticmethod
    def _get_loss_from_likelihood(likelihood):
        """
        :param likelihood: positive likelihood
        :return: loss, the negated mean over all elements
        """
        loss = -likelihood.mean()  # mean over N and K outside log
        return loss

    def _compute_regularization(self, x, output, beta):
        """Regularization for semantic alignment of top-level unimodal
        posteriors (cross-modal likelihood: reconstruct x2 from x1).

        :param x: forwarded to the estimator
        :param output: forwarded to the estimator
        :param beta: forwarded to the estimator inside ``regularization``
        :return: the averaged, weighted, negated cross-modal likelihood, or
            ``None`` when no regularization is enabled. If the estimator
            returns a (K, N) tensor, as the inline comments below assume,
            the result is 0-dim.
        :raises TypeError: if regularization is enabled but
            ``self.likelihood_estimator`` has not been set
        """
        losses = []

        # Reconstruct x2 from x1
        if self.crossmodal_regularization > 0:
            if self.likelihood_estimator is None:
                raise TypeError(
                    'likelihood_estimator must be set when '
                    'crossmodal_regularization > 0')
            estimator = self.likelihood_estimator(
                regularization={'beta': beta,
                                'rec_factors': self.rec_factors})
            lik = estimator.crossmodal_lik_regularization(x, output)
            # Out-of-place on purpose: scaling in place would mutate the
            # estimator's tensor and can break autograd (leaf tensors that
            # require grad, or tensors saved for backward).
            losses.append(-self.crossmodal_regularization * lik)

        if not losses:
            # no regularization loss
            return None

        # Postprocessing
        loss = torch.stack(losses).sum(0)
        loss = loss.mean(-1)  # mean over N outside log, (K,)
        return loss.mean(0)
