import logging
import os
from pathlib import Path

import torch

import utils
from evaluation.evaluator import Evaluator
from evaluation.features.unconditional_sampling import UnconditionalSampler, UnconditionalSamplesPlotter

logger = logging.getLogger('custom')


class FeatureEvaluator(Evaluator):
    """ Evaluator for models that maximize the likelihood of feature vectors.

    Extends the base ``Evaluator`` with an optional step that draws samples
    from the model's prior and plots them. The step is enabled by the
    ``sample_generation_from_prior`` entry of ``config.result_flags_ft_data``.
    """

    def __init__(self, config, **kwargs):
        """
        Args:
            config: Configuration object. It is forwarded to
                ``Evaluator.__init__``, and its ``result_flags_ft_data``
                attribute (a mapping queried with ``.get``) is stored as
                ``self.result_flags``.
            **kwargs: Forwarded unchanged to ``Evaluator.__init__``.
        """
        super().__init__(config=config, **kwargs)
        self.result_flags = config.result_flags_ft_data

    @torch.no_grad()
    def evaluate(self, model, epoch):
        """Run the base evaluation and then the optional prior-sampling step.

        The whole method runs under ``torch.no_grad()``.

        Args:
            model: Model to evaluate. It is put into eval mode first.
            epoch: Forwarded to ``Evaluator.evaluate``.
        """
        # NOTE(review): the model is not switched back to train mode here;
        # presumably the caller handles that. Confirm.
        model.eval()
        super().evaluate(model, epoch)
        self._sample_generation_from_prior(model)

    def _sample_generation_from_prior(self, model):
        """Sample from the model's prior, save the data, and plot it.

        This is a no-op in either of these cases:
        - the ``sample_generation_from_prior`` result flag is falsy or absent;
        - ``model.vaes[0]`` or ``model.vaes[1]`` has ``num_levels == 1``.

        Outputs go under ``self.artifact_dir``:
        - samples to ``<artifact_dir>/data``;
        - plots to ``<artifact_dir>/prior``.
        ``self.artifact_dir``, ``self.split`` and ``self.device`` are assumed
        to be set by the base ``Evaluator``. TODO confirm.

        Args:
            model: Model exposing a ``vaes`` sequence whose first two entries
                have a ``num_levels`` attribute.
        """
        # Check the flag first, so a disabled feature never touches
        # ``model.vaes``. Evaluating everything eagerly could raise for
        # models without two VAEs.
        if not self.result_flags.get('sample_generation_from_prior'):
            return
        # Skip single-level VAEs. Presumably prior sampling is only
        # meaningful for multi-level (hierarchical) models. Confirm.
        if model.vaes[0].num_levels == 1 or model.vaes[1].num_levels == 1:
            return

        with utils.Timer('Generate samples from prior'):
            vis_dir = os.path.join(self.artifact_dir, 'prior')
            data_dir = os.path.join(self.artifact_dir, 'data')

            sampler = UnconditionalSampler(model, data_dir)
            samples = sampler.get_samples()

            plotter = UnconditionalSamplesPlotter(vis_dir, self.split, self.device)
            plotter.make_plot(samples)

            # Parent of the plot directory, i.e. the artifact directory itself.
            utils.shell_command_for_download(Path(vis_dir).parent)
