import logging
import os
import shutil

import torch
import wandb

import utils
from evaluation.evaluator import Evaluator
from evaluation.images.caption_to_image import CaptionToImageSampler, CaptionToImagePlotter
from evaluation.images.classification import NNClassifier
from evaluation.images.embedder import Embedder
from evaluation.images.fid_score import FidImageGenerator, calc_fid
from evaluation.images.image_to_caption import ImageToCaptionSampler, ImageToCaptionPlotter
from evaluation.images.precision_recall import compute_precision_recall
from evaluation.images.prior_to_captions import CaptionVaeSampler
from evaluation.images.prior_to_image import PriorToImagePlotter
from evaluation.images.variances import VarianceCalculator
from evaluation.images.visualizer import Visualizer
from hyperparams.load import get_config

# Logger used by this module; 'custom' is the logger name configured elsewhere
# in the project.
logger = logging.getLogger(name='custom')
# Global hyper-parameter configuration, loaded once at import time. It is read
# directly by ImageEvaluator._fid_score (for config.dirs['data']).
config = get_config()


class ImageEvaluator(Evaluator):
    """Evaluator for the image / caption (``x1`` / ``x2``) setting.

    Which analyses are run is controlled by
    ``config.result_flags_image_data``, a mapping from analysis name (e.g.
    ``'fid_score'``, ``'variances'``, ``'nn_classification'``) to a
    truthy/falsy flag. Quantitative results are stored via
    ``self._update_quant_results`` for every ``mode_layer`` setting, and
    additionally logged to wandb only for ``mode_layer=None``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config=config, **kwargs)
        # Flags selecting which image-data analyses `evaluate` performs.
        self.result_flags = config.result_flags_image_data

    @torch.no_grad()
    def evaluate(self, model, epoch):
        """Run all enabled qualitative and quantitative analyses for ``epoch``.

        Switches ``model`` to eval mode (it is not switched back here) and
        runs without gradient tracking. Qualitative outputs go under
        ``<artifact_dir>/qual_results``.

        Args:
            model: Model exposing ``model.vaes[0]`` / ``model.vaes[1]`` (one
                per modality ``x1`` / ``x2``), each with a ``num_levels``
                attribute.
            epoch: Current epoch; used as the step for logged results.
        """
        with utils.Timer(f'Evaluate split "{self.split}":',
                         event_frequency='medium'):
            self._prepare_evaluation_run(epoch)
            model.eval()
            super().evaluate(model, epoch)
            qual_results_dir = os.path.join(self.artifact_dir, 'qual_results')
            os.makedirs(qual_results_dir, exist_ok=True)

            self.prior_to_images(model)
            self.images_to_captions(model)
            self.captions_to_images(model)
            self.prior_to_captions(model)

            # Specify over which variables to sample the mean: for each
            # modality, ``None`` plus every level index 1..num_levels-1.
            mode_config = {
                mod: [None, *range(1, model.vaes[i].num_levels)]
                for i, mod in enumerate(['x1', 'x2'])
            }
            # Perform actual analysis
            for mod, mode_layers in mode_config.items():
                for mode_layer in mode_layers:
                    kwargs = {'mod': mod, 'mode_layer': mode_layer}
                    self.fid_score_wrapper(model, epoch, **kwargs)
                    self.variances(model, epoch, **kwargs)
                    # The cross-modal embeddings are only consumed by the
                    # three analyses below, which all return early when
                    # disabled; don't compute the embeddings for nothing.
                    if not self._needs_crossmodal_data():
                        continue
                    data = self.get_crossmodal_data(model, mode_layer=mode_layer)
                    self.visualizer(data, **kwargs)
                    self.nn_classification(data, epoch, **kwargs)
                    self.precision_recall(data, epoch, **kwargs)

            utils.shell_command_for_download(qual_results_dir, 'qualitative results')

    def _needs_crossmodal_data(self):
        """Return True if any analysis consuming `get_crossmodal_data` is enabled."""
        return any(
            self.result_flags.get(flag)
            for flag in ('visualizations', 'nn_classification', 'precision_recall')
        )

    def prior_to_images(self, model):
        """Plot unconditionally generated images (if the flag is enabled)."""
        if not self.result_flags.get('prior_to_images'):
            return

        with utils.Timer(f'Unconditional image generation ({self.split})'):
            save_paths = {'images': os.path.join(self.artifact_dir, 'qual_results', 'prior_to_image'),
                          'data': os.path.join(self.artifact_dir, 'data')}
            plot_builder = PriorToImagePlotter(model, save_paths)
            plot_builder.build_plots()

    def images_to_captions(self, model):
        """Sample captions conditioned on images and plot them (if enabled)."""
        if not self.result_flags.get('images_to_captions'):
            return

        with utils.Timer(f'Image-to-caption generation ({self.split})'):
            # Load all sentences for each image (not just their average) for nearest-neighbor lookup
            dataset = self._load_data(average=False)
            sampler = ImageToCaptionSampler(
                model, dataset, self.artifact_dir, self.device)
            images, caption_features = sampler.run()
            plotter = ImageToCaptionPlotter(
                dataset, self.artifact_dir, self.split, self.device)
            plotter.run(images, caption_features)

    def captions_to_images(self, model):
        """Sample images conditioned on captions and plot them (if enabled)."""
        if not self.result_flags.get('captions_to_images'):
            return

        with utils.Timer(f'Caption-to-image generation ({self.split})'):
            # Do not use averaged captions here (so we can visualize single captions)
            dataset = self._load_data(average=False)
            sampler = CaptionToImageSampler(
                model, self.device, self.artifact_dir, dataset)
            captions, images = sampler.run()
            plotter = CaptionToImagePlotter(dataset, self.artifact_dir, self.split)
            plotter.run(captions, images)

    def prior_to_captions(self, model):
        """Sample p(x2|g) with g ~ p(g), i.e. unconditional captions (if enabled)."""
        if not self.result_flags.get('prior_to_captions'):
            return

        with utils.Timer(f'Unconditional caption generation ({self.split})'):
            # Load all sentences for each image (not just their average) for nearest-neighbor lookup
            dataset = self._load_data(average=False)
            sampler = CaptionVaeSampler(model,
                                        dataset,
                                        self.device,
                                        self.artifact_dir)
            sampler.run()

    def fid_score_wrapper(self, model, epoch, mod=None, mode_layer=None):
        """Compute and record the FID score; only runs for ``mod == 'x1'``."""
        if not (self.result_flags.get('fid_score') and mod == 'x1'):
            return
        with utils.Timer(
                f'Frechet Inception Distances '
                f'({mod}; mode_layer: {mode_layer}; {self.split})'
        ):
            dataset = self._load_data(average=True)
            k1 = 'fid/'
            k2 = f'{mod}_mode_layer_{mode_layer}'
            v = self._fid_score(model, dataset, mode_layer)
            log = {k1: {k2: v}}
            self._update_quant_results(log, epoch)
            # Only the default setting is pushed to wandb.
            if mode_layer is None:
                wandb.log(log, step=epoch)

    def _fid_score(self, model, dataset, mode_layer):
        """Return the FID between reference images and freshly generated ones.

        Reference images are read from
        ``<data dir>/flowers_images/jpg_64_fid/<split>``. Generated images are
        written to ``<artifact_dir>/tmp``, and that directory is removed
        afterwards, also when generation or the FID computation raises.

        NOTE(review): reads the module-level ``config`` rather than the config
        passed to ``__init__``; confirm the two are always the same object.
        """
        dirs = dict(
            p=os.path.join(
                config.dirs['data'], 'flowers_images/jpg_64_fid', f'{self.split}'),
            q=os.path.join(self.artifact_dir, 'tmp'))
        try:
            generator = FidImageGenerator(model, dataset, self.device, dirs['q'])
            generator.run(mode_layer=mode_layer)
            return calc_fid(dirs, self.device)
        finally:
            # Always clean up so a failed run does not leave generated images
            # behind for the next FID computation to pick up. ignore_errors
            # keeps a cleanup failure from masking the original exception.
            shutil.rmtree(dirs['q'], ignore_errors=True)

    def variances(self, model, epoch, mod=None, mode_layer=None):
        """Compute and record sample variances for modality ``mod`` (if enabled)."""
        if not self.result_flags.get('variances'):
            return

        with utils.Timer(
                f'Sample variances '
                f'({mod}; mode_layer: {mode_layer}; {self.split})'
        ):
            dataset = self._load_data(average=True)
            calc = VarianceCalculator(model, dataset, self.device)
            k1 = f'variances_{mod}/'
            k2 = f'{mod}_mode_layer_{mode_layer}'
            v = calc.run(mod=mod, mode_layer=mode_layer)
            log = {k1: {k2: v}}
            self._update_quant_results(log, epoch)
            # Only the default setting is pushed to wandb.
            if mode_layer is None:
                wandb.log(log, step=epoch)

    def get_crossmodal_data(self, model, **kwargs):
        """Return the embedding data produced by `Embedder.get_data(**kwargs)`."""
        emb = Embedder(model, self.device, self.split)
        return emb.get_data(**kwargs)

    def visualizer(self, data, mod=None, mode_layer=None):
        """Render the image overview from ``data``; only runs for ``mod == 'x1'``."""
        if not (self.result_flags.get('visualizations') and mod == 'x1'):
            return

        with utils.Timer(
                f'Image overview '
                f'(mode_layer: {mode_layer}; {self.split})'
        ):
            dataset = self._load_data(average=True)
            vis = Visualizer(dataset, self.artifact_dir)
            vis.run(data, mode_layer=mode_layer)

    def nn_classification(self, data, epoch, mod=None, mode_layer=None):
        """Compute and record nearest-neighbor classification results (if enabled)."""
        if not self.result_flags.get('nn_classification'):
            return

        with utils.Timer(
                f'Nearest-neighbor classification '
                f'({mod}; mode_layer: {mode_layer}; {self.split})'
        ):
            dataset = self._load_data(average=True)
            calc = NNClassifier(dataset)
            k1 = f'nn_classification_{mod}/'
            k2 = f'{mod}_mode_layer_{mode_layer}'
            v = calc.run(data, mod=mod)
            log = {k1: {k2: v}}
            self._update_quant_results(log, epoch)
            # Only the default setting is pushed to wandb.
            if mode_layer is None:
                wandb.log(log, step=epoch)

    def precision_recall(self, data, epoch, mod=None, mode_layer=None):
        """Compute and record precision/recall for modality ``mod`` (if enabled)."""
        if not self.result_flags.get('precision_recall'):
            return

        with utils.Timer(
                f'Precision recall '
                f'({mod}; mode_layer: {mode_layer}; {self.split})'
        ):
            k1 = f'precision_recall_{mod}/'
            k2 = f'{mod}_mode_layer_{mode_layer}'
            v = compute_precision_recall(data, mod=mod)
            log = {k1: {k2: v}}
            self._update_quant_results(log, epoch)
            # Only the default setting is pushed to wandb.
            if mode_layer is None:
                wandb.log(log, step=epoch)
