import logging
import os
import shutil

import torch
import wandb

import utils
from data.flowers.main_raw import load_flowers_data
from evaluation.images.classification import NNClassifier
from evaluation.images.fid_score import calc_fid
from evaluation.images.precision_recall import compute_precision_recall
from evaluation.images.visualizer import Visualizer
from hyperparams.load import get_config
from mhvae_vasco.evaluator.evaluator import Evaluator
from mhvae_vasco.evaluator.images.caption_to_image import CaptionToImageSampler, CaptionToImagePlotter
from mhvae_vasco.evaluator.images.embedder import EmbedderVasco
from mhvae_vasco.evaluator.images.fid_score import FidImageGeneratorMhvae
from mhvae_vasco.evaluator.images.image_to_image import ImageToImageSampler, ImageToImagePlotter
from mhvae_vasco.evaluator.images.variances import VarianceCalculatorMhvae

# Module logger; the 'custom' name is presumably configured elsewhere in the project.
logger = logging.getLogger('custom')
# Project configuration, loaded once at import time (read below via config.dirs['data']).
config = get_config()


class ImageEvaluator(Evaluator):
    """Evaluate a caption/image model on the flowers dataset.

    Produces qualitative outputs (caption-to-image and image-to-image plots,
    an image overview visualization) and quantitative metrics (FID, sample
    variances, nearest-neighbour classification, precision/recall).
    Quantitative results are recorded via ``self._update_quant_results`` and
    logged to wandb with ``step=epoch``.

    ``self.split``, ``self.device`` and ``self.artifact_dir`` are read but not
    set in this class; presumably they are provided by ``Evaluator``.
    """

    @torch.no_grad()
    def evaluate(self, model, epoch):
        """Run every evaluation step for ``model`` on ``self.split``.

        Args:
            model: Model to evaluate; it is switched to eval mode.
            epoch: Epoch index, passed to ``_prepare_evaluation_run`` and used
                as the step for all quantitative logging.
        """
        with utils.Timer(f'Evaluate split "{self.split}":',
                         event_frequency='medium'):
            self._prepare_evaluation_run(epoch)
            # NOTE(review): the model is left in eval mode afterwards; the
            # caller is presumably responsible for switching back to train().
            model.eval()
            self.caption_to_image(model)
            self.image_to_image(model)

            self.fid_score_wrapper(model, epoch)
            data = self.get_crossmodal_data(model)
            self.visualizer(data)
            for mod in ['x1', 'x2']:
                self.variances(model, epoch, mod)
                self.nn_classification(data, epoch, mod)
                self.precision_recall(data, epoch, mod)

            utils.shell_command_for_download(self.artifact_dir, name='results')

    def caption_to_image(self, model):
        """Sample images conditioned on captions and write the plots to ``artifact_dir``."""
        with utils.Timer(
                f'Generating caption-to-image plots ({self.split}):'
        ):
            dataset, _ = load_flowers_data(mode=self.split, average=False)
            sampler = CaptionToImageSampler(
                model, dataset, self.artifact_dir, self.device
            )
            captions, images = sampler.get_data()
            plotter = CaptionToImagePlotter(self.artifact_dir, self.split)
            plotter.run(captions, images)

    def image_to_image(self, model):
        """Sample images conditioned on images and write the plots to ``artifact_dir``."""
        with utils.Timer(
                f'Generating image-to-image plots ({self.split}):'
        ):
            dataset, _ = load_flowers_data(mode=self.split, average=False)
            sampler = ImageToImageSampler(
                model, dataset, self.artifact_dir, self.device
            )
            cond_images, gen_images = sampler.run()
            plotter = ImageToImagePlotter(self.artifact_dir, self.split)
            plotter.run(cond_images, gen_images)

    def fid_score_wrapper(self, model, epoch):
        """Compute the FID of generated images and log it under ``fid/``."""
        with utils.Timer(f'Frechet Inception Distances ({self.split})'):
            dataset, _ = load_flowers_data(mode=self.split, average=True)
            v = self._fid_score(model, dataset)
            # The '..._mode_layer_None' sub-key keeps compatibility with
            # hierarchical models.
            self._log_result('fid/', 'x1_mode_layer_None', v, epoch)

    def _fid_score(self, model, dataset):
        """Generate images into a temporary directory and compute FID.

        Reference images are read from
        ``<config.dirs['data']>/flowers_images/jpg_64_fid/<split>``. Generated
        images are written to ``<artifact_dir>/tmp``, which is removed
        afterwards, including when generation or FID computation fails.

        Returns:
            The value returned by ``calc_fid`` for the two directories.
        """
        dirs = dict(
            p=os.path.join(
                config.dirs['data'], 'flowers_images/jpg_64_fid', f'{self.split}'),
            q=os.path.join(self.artifact_dir, 'tmp'))
        generator = FidImageGeneratorMhvae(
            model, dataset, self.device, dirs['q'])
        try:
            generator.run()
            fids = calc_fid(dirs, self.device)
        except BaseException:
            # Do not leave partially generated images behind. Ignore cleanup
            # errors here (e.g. the directory was never created) so they
            # cannot mask the original exception.
            shutil.rmtree(dirs['q'], ignore_errors=True)
            raise
        shutil.rmtree(dirs['q'])
        return fids

    def variances(self, model, epoch, mod):
        """Compute sample variances for modality ``mod`` and log them under ``variances_<mod>/``."""
        with utils.Timer(f'Sample variances '
                         f'({mod}; {self.split})'):
            # NOTE(review): unlike the other loaders this relies on
            # load_flowers_data's default for ``average``; confirm intended.
            dataset, _ = load_flowers_data(mode=self.split)
            calc = VarianceCalculatorMhvae(model, dataset, self.device)
            v = calc.run(mod=mod)
            self._log_result(f'variances_{mod}/', f'{mod}_mode_layer_None', v, epoch)

    def get_crossmodal_data(self, model):
        """Return the data produced by ``EmbedderVasco.get_data`` for ``self.split``."""
        emb = EmbedderVasco(model, self.device, self.split)
        data = emb.get_data()
        return data

    def visualizer(self, data):
        """Render the image overview visualizations for ``data`` into ``artifact_dir``."""
        with utils.Timer(f'Visualizations (image overview; {self.split})'):
            # Pass ``mode`` by keyword, as every other loader call here does.
            dataset, _ = load_flowers_data(mode=self.split, average=True)
            vis = Visualizer(dataset, self.artifact_dir)
            vis.run(data)

    def nn_classification(self, data, epoch, mod):
        """Run nearest-neighbour classification for ``mod`` and log under ``nn_classification_<mod>/``."""
        with utils.Timer(f'Nearest-neighbor classification '
                         f'({mod}; {self.split})'):
            dataset, _ = load_flowers_data(mode=self.split, average=True)
            classifier = NNClassifier(dataset)
            v = classifier.run(data, mod=mod)
            self._log_result(
                f'nn_classification_{mod}/', f'{mod}_mode_layer_None', v, epoch)

    def precision_recall(self, data, epoch, mod):
        """Compute precision/recall for ``mod`` and log under ``precision_recall_<mod>/``."""
        with utils.Timer(f'Precision recall '
                         f'({mod}; {self.split})'):
            v = compute_precision_recall(data, mod=mod)
            self._log_result(
                f'precision_recall_{mod}/', f'{mod}_mode_layer_None', v, epoch)

    def _log_result(self, key, sub_key, value, epoch):
        """Record ``{key: {sub_key: value}}`` locally and log it to wandb at ``step=epoch``."""
        log = {key: {sub_key: value}}
        self._update_quant_results(log, epoch)
        wandb.log(log, step=epoch)
