import os
from collections import defaultdict

import matplotlib.pyplot as plt
import torch

import utils
from evaluation.features.misc import make_grid_from_paths
from hyperparams.load import get_config

# Project configuration, loaded once when this module is imported. In this
# file only `config.save_generated_data` is read (by the plotter's
# `_save_data`).
config = get_config()


class CaptionToImageSampler:
    """Sample image features conditioned on captions.

    One caption per class is selected from ``dataset``, passed through the
    bottom-up pass of ``model.vaes[1]`` at several temperatures, and the
    resulting samples are fed to ``model.vaes[0].generate``.

    The class-level constants below hold the sampling settings; override them
    on a subclass or on an instance to change the behaviour without editing
    the methods.
    """

    # Temperatures passed as ``temp`` to the bottom-up pass; they also become
    # the keys of the dictionaries returned by the sampling methods.
    TEMPERATURES = (0.5, 0.7, 1.0)
    # ``k`` passed to ``bottom_up`` when the caption VAE has a single level.
    NUM_TOP_LEVEL_SAMPLES = 100
    # ``k`` passed to ``generate`` of the image VAE.
    NUM_IMAGE_SAMPLES = 64
    # Only the first MAX_CLASSES unique labels are used; set to None to
    # capture all classes.
    MAX_CLASSES = 20
    # Labels for which ``_get_caption(..., extended=True)`` collects more than
    # one caption (flying birds in the test set).
    EXTENDED_CLASSES = (1, 66, 72, 84)
    # Number of image/caption pairs collected for an extended class.
    NUM_EXTENDED_CAPTIONS = 50

    def __init__(self, model, save_path, device, dataset):
        """
        :param model: object exposing ``vaes``; index 0 is used to generate
            image features, index 1 to encode captions
        :param save_path: base directory; ``<save_path>/data`` is stored in
            ``save_data_path``
        :param device: device the stacked caption features are moved to
        :param dataset: object exposing ``x`` and ``s`` (``s['y']`` and
            ``s['caption_paths']`` are read here)
        """
        self.model = model
        self.save_data_path = os.path.join(save_path, 'data')
        self.device = device
        self.dataset = dataset

    def get_samples(self):
        """
        :return: tuple ``(captions, image_features)`` where ``captions`` is a
            list of strings and ``image_features`` maps each temperature to
            the generated image features
        """
        captions, top_level_samples = self._infer_top_level_samples()
        image_features = self._generate_image_features(top_level_samples)
        return captions, image_features

    def _infer_top_level_samples(self):
        """Select the conditioning captions and encode them.

        :return: tuple ``(captions, top_level_samples)``
        """
        captions, caption_features = self._select_conditioning_captions()
        top_level_samples = self._perform_bottom_up_pass(caption_features)
        return captions, top_level_samples

    def _perform_bottom_up_pass(self, caption_features):
        """Run the caption VAE bottom-up once per temperature.

        :param caption_features: stacked caption features
        :return: dict mapping each temperature to the posterior samples
        """
        # The VAE and its depth do not depend on the temperature.
        vae = self.model.vaes[1]
        single_level = vae.num_levels == 1

        top_level_samples = {}
        for temp in self.TEMPERATURES:
            if single_level:
                # sample variation over g
                _, posterior = vae.bottom_up(
                    caption_features, k=self.NUM_TOP_LEVEL_SAMPLES, temp=temp)
                # shape: captions x K x D
                top_level_samples[temp] = posterior['samples'].transpose(0, 1)
            else:
                # sample identity over g and variation over unimodal variables
                _, posterior = vae.bottom_up(caption_features, temp=temp)
                # shape: captions x D
                # NOTE(review): squeeze() without a dim drops every size-1
                # axis, so a single conditioning caption would also lose its
                # captions axis -- confirm the sample shape before relying
                # on this with fewer than two captions.
                top_level_samples[temp] = posterior['samples'].squeeze()
        return top_level_samples

    def _select_conditioning_captions(self):
        """Select the first caption for each of the first MAX_CLASSES classes.

        :return: tuple ``(captions, caption_features)``; the features are
            stacked into one tensor and moved to ``self.device``
        """
        caption_features, captions = [], []
        labels = self.dataset.s['y']
        # Slicing with None keeps every class.
        for cur_y in labels.unique()[:self.MAX_CLASSES]:
            idx = (labels == cur_y)
            x2 = self.dataset.x[1][idx]
            paths = self.dataset.s['caption_paths'][idx]
            cur_feat, cur_capt = self._get_caption(paths, x2, cur_y)
            caption_features.extend(cur_feat)
            captions.extend(cur_capt)
        caption_features = torch.stack(caption_features).to(self.device)
        return captions, caption_features

    def _get_caption(self, paths, x2, cur_y, extended=False):
        """Collect caption features and caption texts for one class.

        :param paths: caption file paths of the class
        :param x2: caption features of the class, aligned with ``paths``
        :param cur_y: label of the class
        :param extended: if True and ``cur_y`` is in EXTENDED_CLASSES, the
            first NUM_EXTENDED_CAPTIONS image/caption pairs are used instead
            of only the first one
        :return: tuple ``(caption_feats, captions)`` of two equally long lists
        """
        # Check the flag first so the label comparison is skipped when the
        # extended mode is off.
        if extended and cur_y in self.EXTENDED_CLASSES:
            # flying birds in test set -> get more instances
            indices = range(self.NUM_EXTENDED_CAPTIONS)
        else:
            indices = (0,)

        caption_feats, captions = [], []
        for i in indices:
            cur_feat, cur_caption = self._get_cur_caption(paths, x2, i)
            caption_feats.append(cur_feat)
            captions.append(cur_caption)
        return caption_feats, captions

    @staticmethod
    def _get_cur_caption(paths, feats, i):
        """
        :param paths: caption file paths
        :param feats: caption features, aligned with ``paths``
        :param i: which image/caption pair to use
        :return: tuple ``(feat, caption)`` with the first feature entry of
            pair ``i`` and the first line of its caption file
        :raises IndexError: if the caption file is empty
        """
        # We use the first caption from each image/caption pair
        feat = feats[i][0]
        with open(paths[i], 'r') as f:
            # Only the first line is needed, so the rest of the file is not
            # read. splitlines() strips the line terminator exactly as it
            # would when applied to the whole file content.
            first_line = f.readline().splitlines()
        if not first_line:
            raise IndexError(f'caption file is empty: {paths[i]}')
        return feat, first_line[0]

    def _generate_image_features(self, samples):
        """Generate image features from the top-level samples.

        :param samples: dict mapping temperature to top-level samples
        :return: dict mapping temperature to the mean of the first generated
            distribution
        """
        # The VAE and its depth do not depend on the temperature.
        vae = self.model.vaes[0]
        two_levels = vae.num_levels == 2

        image_features = {}
        for temp, cur_sample in samples.items():
            ancestral_samples = vae.generate(cur_sample,
                                             k=self.NUM_IMAGE_SAMPLES)
            rec = ancestral_samples[0]['dist'].mean
            if two_levels:
                # sampling is done in unimodal variable -> put K at right place
                rec = rec.transpose(1, 0)
            image_features[temp] = rec  # image features x K x D
        return image_features


class CaptionToImagePlotter:
    """Plot, per caption, the dataset images closest to generated features.

    For every temperature and every caption, the nearest dataset image is
    looked up for each generated image feature, the images are arranged in a
    grid, and the grid is saved as a PNG titled with the caption.
    """

    def __init__(self, save_path, device, dataset):
        """
        :param save_path: base directory; figures go to
            ``<save_path>/caption_to_image`` (created here) and saved data to
            ``<save_path>/data`` (created only when data is saved)
        :param device: device handed to ``utils.get_distance``
        :param dataset: object exposing ``x`` and ``s`` (``x[0]`` and
            ``s['image_paths']`` are read here)
        """
        self.save_paths = {
            'output': os.path.join(save_path, 'caption_to_image'),
            'data': os.path.join(save_path, 'data')
        }
        self.device = device
        self.dataset = dataset
        os.makedirs(self.save_paths['output'], exist_ok=True)

    def make_plot(self, captions, image_features):
        """Create one figure per (temperature, caption) pair.

        :param captions: caption strings, indexed like the entries of each
            value in ``image_features``
        :param image_features: dict mapping temperature to an iterable with
            one group of generated image features per caption
        """
        image_paths = defaultdict(list)
        for temp, temp_feat in image_features.items():
            for index, feat in enumerate(temp_feat):
                # for every class
                cur_paths = self._find_nearest_neighbor_images(feat)
                grid, _ = make_grid_from_paths(cur_paths)
                image_paths[temp].append(cur_paths)
                self._create_figure(grid, captions, temp, index)
        utils.shell_command_for_download(self.save_paths['output'])
        self._save_data(captions, image_paths)

    def _save_data(self, captions, image_paths):
        """Store captions and nearest-neighbor image paths.

        Nothing is written unless ``config.save_generated_data`` is truthy.

        :param captions: caption strings
        :param image_paths: dict mapping temperature to a list of image-path
            collections, one per caption
        """
        if not config.save_generated_data:
            return
        data = {'captions': captions,
                'image_paths': image_paths}
        os.makedirs(self.save_paths['data'], exist_ok=True)
        torch.save(data,
                   os.path.join(self.save_paths['data'], 'captions_to_images.pt'))

    def _find_nearest_neighbor_images(self, image_features):
        """Look up the dataset image closest to each generated feature.

        :param image_features: generated image features of one caption
        :return: entries of ``dataset.s['image_paths']`` selected by the
            arg-min over axis 0 of the distances
        """
        d = utils.get_distance(s=self.dataset.x[0],
                               q=image_features,
                               device=self.device)
        idx = d.argmin(0)  # K image features belonging to one caption
        paths = self.dataset.s['image_paths'][idx]
        return paths

    def _create_figure(self, grid, captions, temp, index):
        """Save the image grid of one caption as a PNG.

        The file is written to
        ``<output dir>/c2i_<index>_temp_<temp>.png``.

        :param grid: image grid; converted with ``utils.to_np`` and its first
            axis moved last before plotting (presumably channels-first to
            channels-last -- confirm against ``make_grid_from_paths``)
        :param captions: caption strings; ``captions[index]`` is the title
        :param temp: temperature, used in the file name
        :param index: index of the caption, used for title and file name
        """
        fig, ax = plt.subplots(figsize=(8.27, 8.27))
        try:
            # Operate on this figure/axes explicitly instead of pyplot's
            # implicit "current" figure.
            ax.set_title(captions[index])
            im = utils.to_np(grid).transpose((1, 2, 0))
            ax.imshow(im)
            save_path = os.path.join(self.save_paths['output'],
                                     f'c2i_{index}_temp_{temp}.png')
            fig.savefig(save_path,
                        format='png', dpi=500, transparent=True,
                        bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails;
            # otherwise it stays registered in pyplot across the many
            # figures created by make_plot.
            plt.close(fig)
