import logging

import torch

import utils
from data.flowers.main_ft import load_flowers_ft_data
from evaluation.features.caption_to_image import CaptionToImageSampler, CaptionToImagePlotter
from evaluation.features.evaluator import FeatureEvaluator
from evaluation.features.image_to_caption import ImageToCaptionSampler, ImageToCaptionPlotter

logger = logging.getLogger('custom')


class FlowersFtEvaluator(FeatureEvaluator):
    """Evaluator for models fine-tuned on the Flowers dataset.

    Runs the generic feature evaluation from ``FeatureEvaluator`` on
    ``self.split``. When enabled through ``self.result_flags``, it also
    produces caption-to-image and image-to-caption sample plots.
    """

    @torch.no_grad()
    def evaluate(self, model, epoch):
        """Evaluate ``model`` on split ``self.split`` of the Flowers dataset.

        The model is put into eval mode for the duration of the run. It is
        restored to its previous train/eval mode afterwards, even if an
        evaluation step raises, so the caller is not left with dropout and
        batch-norm unexpectedly disabled.

        Args:
            model: Model to evaluate. Assumed to be a ``torch.nn.Module``,
                since it must expose ``.training`` / ``.eval()`` / ``.train()``.
            epoch: Current epoch. Forwarded to the evaluation-run setup and to
                the base evaluator.
        """
        # Remember the caller's mode so evaluation has no lasting side effect.
        was_training = model.training
        with utils.Timer(
                f'Evaluate split "{self.split}" on Flowers dataset:',
                event_frequency='medium'
        ):
            model.eval()
            try:
                self._prepare_evaluation_run(epoch)
                super().evaluate(model, epoch)
                self._captions_to_images(model)
                self._images_to_captions(model)
            finally:
                model.train(was_training)

    def _captions_to_images(self, model):
        """Sample image features from captions and plot them.

        Does nothing unless ``self.result_flags['captions_to_images']`` is
        truthy. The plot is produced by ``CaptionToImagePlotter``, which is
        given ``self.artifact_dir`` (presumably where it writes its output).
        """
        if not self.result_flags.get('captions_to_images'):
            return

        with utils.Timer('Caption-to-image generation'):
            # Load all sentences for each image (not just their average) for
            # nearest-neighbor lookup.
            dataset = self._load_data(average=False)

            # Caption -> image features.
            sampler = CaptionToImageSampler(
                model, self.artifact_dir, self.device, dataset)
            captions, image_features = sampler.get_samples()
            plotter = CaptionToImagePlotter(
                self.artifact_dir, self.device, dataset)
            plotter.make_plot(captions, image_features)

    def _images_to_captions(self, model):
        """Sample caption features from images and plot them.

        Does nothing unless ``self.result_flags['images_to_captions']`` is
        truthy. The plot is produced by ``ImageToCaptionPlotter``, which is
        given ``self.artifact_dir`` (presumably where it writes its output).
        """
        if not self.result_flags.get('images_to_captions'):
            return

        with utils.Timer('Image-to-caption generation'):
            # Load all sentences for each image (not just their average) for
            # nearest-neighbor lookup.
            dataset = self._load_data(average=False)

            # Image -> caption features.
            sampler = ImageToCaptionSampler(model, dataset, self.device)
            image_paths, caption_features = sampler.get_samples()
            plotter = ImageToCaptionPlotter(
                self.artifact_dir, self.device, dataset)
            plotter.make_plot(image_paths, caption_features)

    def _load_data(self, average):
        """Load the Flowers fine-tuning dataset for ``self.split``.

        Args:
            average: Forwarded to ``load_flowers_ft_data``. Callers here pass
                ``False`` to keep every sentence per image rather than an
                average.

        Returns:
            The first element returned by ``load_flowers_ft_data``; the second
            element is discarded.
        """
        dataset, _ = load_flowers_ft_data(mode=self.split, average=average)
        return dataset
