import logging
import os

import torch

import utils
from ..evaluator import SyntheticEvaluator
from ..scatter import distribution_scatters_wrapper

# Module logger. It uses the fixed project-wide name 'custom' rather than
# __name__ — presumably handlers/levels are configured under that name elsewhere.
logger = logging.getLogger(name='custom')


class SyntheticMultimodalPoeEvaluator(SyntheticEvaluator):
    """Synthetic-dataset evaluator that can also emit distribution scatter plots.

    Runs the inherited ``SyntheticEvaluator.evaluate`` and then, for
    multimodal models with the ``'scatter_plots'`` result flag set, draws
    scatter plots for every (target, condition) modality pair.
    """

    @torch.no_grad()
    def evaluate(self, model, epoch):
        """Evaluate ``model`` on ``self.split`` with gradient tracking disabled.

        The whole run is wrapped in a ``utils.Timer``. It puts the model into
        eval mode, calls ``_prepare_evaluation_run(epoch)``, delegates to the
        parent ``evaluate``, and finally calls ``_scatter_plots``.

        Args:
            model: the model under evaluation; ``_scatter_plots`` reads its
                ``multimodal`` attribute.
            epoch: forwarded to ``_prepare_evaluation_run`` and the parent
                ``evaluate``.
        """
        timer_label = f'Evaluate split "{self.split}" for synthetic dataset:'
        with utils.Timer(timer_label, event_frequency='medium'):
            model.eval()
            self._prepare_evaluation_run(epoch)
            super().evaluate(model, epoch)
            self._scatter_plots(model)

    def _scatter_plots(self, model):
        """Draw distribution scatter plots for each modality pair.

        Does nothing unless both ``model.multimodal`` and
        ``self.result_flags.get('scatter_plots')`` are truthy. Otherwise it
        loads the data and runs a forward pass via ``_get_forward_data``,
        requesting ``prior_samples=1000``. It then calls
        ``distribution_scatters_wrapper`` once per (target, condition) pair,
        with ``save_path`` set to ``<artifact_dir>/scatter_<split>``.
        """
        # Evaluate both gate conditions before testing them, as before.
        is_multimodal = model.multimodal
        wants_scatter = self.result_flags.get('scatter_plots')
        if not (is_multimodal and wants_scatter):
            return

        logger.debug('\n====> Distribution Scatter:')
        dataset = self._load_data()
        output = self._get_forward_data(model, dataset, prior_samples=1000)
        inp = dataset[:]

        # Workaround: the 'joint' target is only paired with the 'joint'
        # condition. That pair shows the joint posterior over the top-level
        # variable. Every other target is paired with all three conditions.
        modalities = ('x1', 'x2', 'joint')
        modality_pairs = [
            (target, condition)
            for target in modalities
            for condition in modalities
            if target != 'joint' or condition == 'joint'
        ]

        # These do not change between plots, so resolve them once.
        save_path = os.path.join(self.artifact_dir, f'scatter_{self.split}')
        colors = dataset.color_dict['all']

        for target, condition in modality_pairs:
            distribution_scatters_wrapper(
                data={'inp': inp, 'output': output},
                save_path=save_path,
                color_dict=colors,
                mod={'t': target, 'c': condition},
            )
