import logging
import os
import shutil
from pathlib import Path

import torch
import wandb

import utils
from data.flowers.main_raw import load_flowers_data
from disentanglement_vae.evaluator.caption_to_image import CaptionToImageSamplerMdvae, CaptionToImagePlotterMdvae
from disentanglement_vae.evaluator.embedder import EmbedderMdvae
from disentanglement_vae.evaluator.fid_score import FidImageGeneratorMdvae
from disentanglement_vae.evaluator.image_to_image import ImageToImageSamplerMdvae, ImageToImagePlotterMdvae
from disentanglement_vae.evaluator.variances import VarianceCalculatorMdvae
from evaluation.images.classification import NNClassifier
from evaluation.images.fid_score import calc_fid
from evaluation.images.precision_recall import compute_precision_recall
from hyperparams.load import get_config

# Shared project logger (registered under the name 'custom') and the global
# project configuration; the latter supplies data directories used below.
logger = logging.getLogger(name='custom')
config = get_config()


class Evaluator:
    """Run the image-domain evaluation suite on one data split of a run.

    Per-epoch artifacts (plots, temporary FID images) are written under
    ``<run_path>/<split>/epoch_<epoch>``. Scalar results are accumulated in
    ``<run_path>/<split>/quant_results.pt`` (a dict keyed by epoch) and are
    also logged to wandb at ``step=epoch``.
    """

    def __init__(self, run_path, split, args, device, debug=False):
        """
        :param run_path: Directory of the run; its last path component is used
            as the wandb run id.
        :param split: Data split to evaluate; passed as ``mode`` to
            ``load_flowers_data`` and used as the wandb group.
        :param args: Run arguments; ``vars(args)`` is sent to wandb as config and
            ``args.model`` / ``args.trial`` form the wandb run name.
        :param device: Device handed to the samplers and the FID computation.
        :param debug: True for runs solely used for debugging purposes.
        """
        # Extract run id from run path to comply with ID of the form f'{ID}_debug'
        self.run_id = Path(run_path).name
        self.split_dir = os.path.join(run_path, split)
        os.makedirs(self.split_dir, exist_ok=True)
        self.split = split
        self.args = args
        self.debug = debug
        self.device = device
        self.artifact_dir = None  # Set per epoch by set_artifact_dir()

    def _prepare_evaluation_run(self, epoch):
        """Create the artifact dir for ``epoch`` and initialise the wandb run."""
        self.set_artifact_dir(epoch)
        utils.init_wandb(
            # Use the id derived from run_path in __init__ (it may have the
            # form f'{ID}_debug'), not the raw id stored in args.
            run_id=self.run_id,
            project='hmvae_images',
            group=self.split,
            wandb_config=vars(self.args),
            name=f'{self.args.model}_{self.args.trial}',
            tags=['debug'] if self.debug else None,
        )

    def _update_quant_results(self, log, epoch):
        """Merge ``log`` into the results for ``epoch`` in quant_results.pt.

        The file stores ``{epoch: log_dict}``. A new epoch gets ``log`` as is.
        An existing entry is combined with ``log`` via ``utils.update``
        (presumably a nested dict merge; confirm in utils).
        """
        path = os.path.join(self.split_dir, 'quant_results.pt')
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which rejects e.g. numpy scalars; confirm the stored metric values
        # are plain Python numbers / tensors.
        data = torch.load(path) if os.path.isfile(path) else {}
        data[epoch] = utils.update(data[epoch], log) if epoch in data else log
        # Save to a sibling temp file and atomically swap it in, so a crash
        # during saving cannot corrupt the results accumulated over all epochs.
        tmp_path = path + '.tmp'
        torch.save(data, tmp_path)
        os.replace(tmp_path, path)

    def _log_metric(self, k1, k2, value, epoch):
        """Store ``{k1: {k2: value}}`` for ``epoch`` on disk and log it to wandb."""
        log = {k1: {k2: value}}
        self._update_quant_results(log, epoch)
        wandb.log(log, step=epoch)

    def set_artifact_dir(self, epoch):
        """ Prepare dir that gathers experimental artifacts. """
        self.artifact_dir = os.path.join(self.split_dir, f'epoch_{epoch}')
        os.makedirs(self.artifact_dir, exist_ok=True)

    @torch.no_grad()
    def evaluate(self, model, epoch):
        """Run all qualitative and quantitative evaluations for ``epoch``.

        The model is put in eval mode for the duration. If it was in training
        mode beforehand, it is switched back afterwards, even if an evaluation
        step raises.
        """
        with utils.Timer(f'Evaluate split "{self.split}":',
                         event_frequency='medium'):
            self._prepare_evaluation_run(epoch)
            was_training = getattr(model, 'training', False)
            model.eval()
            try:
                self.captions_to_images(model)
                self.image_to_image(model)

                self.fid_score_wrapper(model, epoch)
                data = self.get_crossmodal_data(model)
                for mod in ['x1', 'x2']:
                    self.variances(model, epoch, mod)
                    self.nn_classification(data, epoch, mod)
                    self.precision_recall(data, epoch, mod)
            finally:
                # Don't leave a model that is mid-training stuck in eval mode.
                if was_training:
                    model.train()

            utils.shell_command_for_download(self.artifact_dir, name='results')

    def captions_to_images(self, model):
        """Sample images conditioned on single captions and plot them."""
        with utils.Timer(f'Generating caption-to-image plots ({self.split}):'):
            # Do not use averaged captions here (so we can visualize single captions)
            dataset, _ = load_flowers_data(mode=self.split, average=False)
            sampler = CaptionToImageSamplerMdvae(
                model, self.device, self.artifact_dir, dataset)
            captions, images = sampler.run()
            plotter = CaptionToImagePlotterMdvae(dataset, self.artifact_dir, self.split)
            plotter.run(captions, images)

    def image_to_image(self, model):
        """Sample images conditioned on images and plot them."""
        with utils.Timer(f'Generating image-to-image plots ({self.split}):'):
            dataset, _ = load_flowers_data(mode=self.split, average=True)
            sampler = ImageToImageSamplerMdvae(
                model, dataset, self.artifact_dir, self.device
            )
            cond_images, gen_images = sampler.run()
            plotter = ImageToImagePlotterMdvae(self.artifact_dir, self.split)
            plotter.run(cond_images, gen_images)

    def fid_score_wrapper(self, model, epoch):
        """Compute the FID result for this split and record it under 'fid/'."""
        with utils.Timer(f'Frechet Inception Distances ({self.split})'):
            dataset, _ = load_flowers_data(mode=self.split, average=True)
            v = self._fid_score(model, dataset)
            # 'x1_mode_layer_None': compatibility with hierarchical models
            self._log_metric('fid/', 'x1_mode_layer_None', v, epoch)

    def _fid_score(self, model, dataset):
        """Generate images into a temp dir and compare them with the reference set.

        ``p`` is the precomputed reference image dir for this split. ``q`` is a
        temporary dir of generated images that is always removed afterwards.
        Returns whatever ``calc_fid`` returns.
        """
        dirs = dict(
            p=os.path.join(
                config.dirs['data'], 'flowers_images/jpg_64_fid', self.split),
            q=os.path.join(self.artifact_dir, 'tmp'))
        try:
            generator = FidImageGeneratorMdvae(
                model, dataset, self.device, dirs['q'])
            generator.run()
            return calc_fid(dirs, self.device)
        finally:
            # Clean up generated images even if generation or FID fails;
            # ignore_errors so cleanup never masks the original exception.
            shutil.rmtree(dirs['q'], ignore_errors=True)

    def variances(self, model, epoch, mod):
        """Compute sample variances for modality ``mod`` and record them."""
        with utils.Timer(f'Sample variances '
                         f'({mod}; {self.split})'):
            dataset, _ = load_flowers_data(mode=self.split, average=True)
            calc = VarianceCalculatorMdvae(model, dataset, self.device)
            v = calc.run(mod=mod)
            # f'{mod}_mode_layer_None': compatibility with hierarchical models
            self._log_metric(f'variances_{mod}/', f'{mod}_mode_layer_None', v, epoch)

    def get_crossmodal_data(self, model):
        """Return the embedding data produced by ``EmbedderMdvae`` for this split."""
        emb = EmbedderMdvae(model, self.device, self.split)
        return emb.get_data()

    def nn_classification(self, data, epoch, mod):
        """Run nearest-neighbor classification on ``data`` for ``mod`` and record it."""
        with utils.Timer(f'Nearest-neighbor classification '
                         f'({mod}; {self.split})'):
            dataset, _ = load_flowers_data(mode=self.split, average=True)
            classifier = NNClassifier(dataset)
            v = classifier.run(data, mod=mod)
            # f'{mod}_mode_layer_None': compatibility with hierarchical models
            self._log_metric(
                f'nn_classification_{mod}/', f'{mod}_mode_layer_None', v, epoch)

    def precision_recall(self, data, epoch, mod):
        """Compute precision/recall on ``data`` for ``mod`` and record it."""
        with utils.Timer(f'Precision recall '
                         f'({mod}; {self.split})'):
            v = compute_precision_recall(data, mod=mod)
            # f'{mod}_mode_layer_None': compatibility with hierarchical models
            self._log_metric(
                f'precision_recall_{mod}/', f'{mod}_mode_layer_None', v, epoch)
