import os
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image

import utils
from hyperparams.load import get_config

# Run configuration, loaded once at import time. In this module it is only
# read for the `save_generated_data` flag (see ImageToCaptionPlotter._save_data).
config = get_config()


class ImageToCaptionSampler:
    """Generate caption features conditioned on a few dataset images.

    One image is picked per class, encoded with ``model.vaes[0]`` at several
    sampling temperatures, and the resulting top-level samples are decoded
    with ``model.vaes[1]`` into caption features.

    :param model: object exposing ``vaes``; ``vaes[0]`` must provide
        ``bottom_up`` and ``num_levels``, ``vaes[1]`` must provide
        ``generate`` and ``num_levels``.
    :param dataset: object exposing ``x[0]`` (image features) and the
        entries ``s['y']`` (labels) and ``s['image_paths']``.
    :param device: device the conditioning image features are moved to.
    :param temperatures: sampling temperatures; one result per temperature.
    :param num_samples: value passed as ``k`` to ``bottom_up`` (single-level
        case) and to ``generate``.
    :param num_classes: number of classes (taken from the front of
        ``s['y'].unique()``) that contribute one conditioning image each.
    """

    def __init__(self, model, dataset, device,
                 temperatures=(0.5, 0.7, 1.0), num_samples=10, num_classes=5):
        self.model = model
        self.dataset = dataset
        self.device = device
        self.temperatures = tuple(temperatures)
        self.num_samples = num_samples
        self.num_classes = num_classes

    def get_samples(self):
        """Return ``(image_paths, caption_features)``.

        ``caption_features`` is a dict keyed by temperature. The whole pass
        runs under ``torch.no_grad()`` because the samples are only used for
        inspection, so no autograd graph needs to be kept alive.
        """
        with torch.no_grad():
            image_paths, top_level_samples = self._infer_top_level_samples()
            caption_features = self._generate_caption_features(
                top_level_samples)
        return image_paths, caption_features

    def _infer_top_level_samples(self):
        """Pick the conditioning images and encode them at every temperature."""
        image_paths, image_features = self._select_conditioning_images()
        top_level_samples = self._perform_bottom_up_pass(image_features)
        return image_paths, top_level_samples

    def _perform_bottom_up_pass(self, image_features):
        """Encode ``image_features`` with the first VAE, once per temperature.

        :return: dict mapping temperature -> top-level samples.
        """
        vae = self.model.vaes[0]  # same encoder for every temperature
        top_level_samples = {}
        for temp in self.temperatures:
            if vae.num_levels == 1:
                # sample variation over g
                _, posterior = vae.bottom_up(image_features,
                                             k=self.num_samples, temp=temp)
                # expected shape after the swap: images x K x D
                top_level_samples[temp] = posterior['samples'].transpose(0, 1)
            else:
                # sample identity over g and variation over unimodal variables
                _, posterior = vae.bottom_up(image_features, temp=temp)
                # expected shape: images x D
                # NOTE(review): squeeze() drops *every* size-1 dim, so a
                # single conditioning image would lose its batch dim too --
                # confirm the layout of posterior['samples'] before
                # narrowing this to squeeze(dim).
                top_level_samples[temp] = posterior['samples'].squeeze()
        return top_level_samples

    def _select_conditioning_images(self):
        """Select the first image of each of the first ``num_classes`` classes.

        :return: ``(image_paths, image_features)`` where ``image_paths`` is a
            list with one entry per class and ``image_features`` is the
            stacked feature tensor moved to ``self.device``.
        """
        labels = self.dataset.s['y']
        image_features, image_paths = [], []
        for cur_y in labels.unique()[:self.num_classes]:
            mask = (labels == cur_y)
            # Index 0: the first sample belonging to the current class.
            image_features.append(self.dataset.x[0][mask][0])
            image_paths.append(self.dataset.s['image_paths'][mask][0])
        image_features = torch.stack(image_features).to(self.device)
        return image_paths, image_features

    def _generate_caption_features(self, top_level_samples):
        """Decode top-level samples into caption features with the second VAE.

        :param top_level_samples: dict mapping temperature -> samples, as
            returned by ``_perform_bottom_up_pass``.
        :return: dict mapping temperature -> mean of the first generated
            distribution (expected layout: caption features x K x D).
        """
        vae = self.model.vaes[1]  # same decoder for every temperature
        caption_features = {}
        for temp, cur_sample in top_level_samples.items():
            ancestral_samples = vae.generate(cur_sample, k=self.num_samples)
            rec = ancestral_samples[0]['dist'].mean
            if vae.num_levels == 2:
                # sampling is done in unimodal variable -> put K at right place
                rec = rec.transpose(1, 0)
            caption_features[temp] = rec  # caption features x K x D
        return caption_features


class ImageToCaptionPlotter:
    """Render each conditioning image next to its nearest-neighbour captions.

    For every temperature and every image one PNG is written to
    ``<save_path>/image_to_caption``. When ``config.save_generated_data`` is
    set, the plotted images and caption strings are additionally stored in
    ``<save_path>/data/images_to_captions.pt``.

    :param save_path: root directory for all outputs.
    :param device: device forwarded to ``utils.get_distance``.
    :param dataset: object exposing ``x[1]`` (caption features) and
        ``s['caption_paths']``.
    """

    def __init__(self, save_path, device, dataset):
        self.save_paths = {
            'output': os.path.join(save_path, 'image_to_caption'),
            'data': os.path.join(save_path, 'data')}
        self.device = device
        self.dataset = dataset
        os.makedirs(self.save_paths['output'], exist_ok=True)

    def make_plot(self, image_paths, caption_features):
        """Create and save one figure per (temperature, image) pair.

        :param image_paths: sequence of image file paths, indexed like the
            first dimension of every entry of ``caption_features``.
        :param caption_features: dict mapping temperature -> features; each
            element along the first dimension belongs to one image.
        """
        images = defaultdict(list)
        captions = defaultdict(list)
        for temp, temp_feat in caption_features.items():
            for index, feat in enumerate(temp_feat):
                self._make_figure()
                try:
                    images[temp].append(self._plot_image(image_paths[index]))
                    captions[temp].append(self._plot_captions(feat))
                    save_path = os.path.join(
                        self.save_paths['output'],
                        f'image_to_caption_{index}_temp_{temp}.png')
                    self.fig.savefig(save_path, format='png', dpi=500,
                                     transparent=True, bbox_inches='tight')
                finally:
                    # Always release the figure, even if plotting failed, so
                    # figures do not pile up in pyplot's registry.
                    plt.close(self.fig)
        utils.shell_command_for_download(self.save_paths['output'])
        self._save_data(images, captions)

    def _save_data(self, images, captions):
        """Persist plotted images/captions if the config asks for it."""
        if not config.save_generated_data:
            return
        os.makedirs(self.save_paths['data'], exist_ok=True)
        torch.save(
            {'images': images, 'captions': captions},
            os.path.join(self.save_paths['data'], 'images_to_captions.pt')
        )

    def _make_figure(self):
        """Create a fresh two-column figure and store it on the instance."""
        # size is landscape A4
        fig = plt.figure(constrained_layout=True, figsize=(11.69, 8.27))
        self.gs = fig.add_gridspec(1, 2,
                                   top=1.0, left=0.00, right=1.0, bottom=0,
                                   wspace=0, hspace=0)
        self.fig = fig

    def _plot_image(self, *args, **kwargs):
        """Draw the image in the left grid cell and return it as an array.

        All arguments are forwarded to ``_extract_image``.
        """
        im = self._extract_image(*args, **kwargs)
        ax = self.fig.add_subplot(self.gs[0])
        ax.imshow(im)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()
        return im

    @staticmethod
    def _extract_image(path, width=300):
        """Load an image, center-crop it to a square and resize it.

        :param path: image file to open.
        :param width: edge length (pixels) of the square output.
        :return: numpy array in height x width x channels layout.
        """
        # The context manager closes the underlying file deterministically;
        # every transform runs inside it so the pixel data is read first.
        with Image.open(path) as img:
            crop = torchvision.transforms.CenterCrop(min(img.width, img.height))
            resize = torchvision.transforms.Resize((width, width))
            to_tensor = torchvision.transforms.ToTensor()
            image = to_tensor(resize(crop(img)))
        # ToTensor yields channels-first; imshow wants channels-last.
        return image.numpy().transpose((1, 2, 0))

    def _plot_captions(self, features, border=0.2):
        """Write the nearest-neighbour caption of each feature into the right cell.

        :param features: caption features; one caption line is drawn per
            entry along the first dimension.
        :param border: lower offset (in axes coordinates) of the text column.
        :return: list of the caption strings that were drawn.
        """
        captions = []
        ax = self.fig.add_subplot(self.gs[1])
        sample_idx, caption_idx = self._find_nearest_neighbor_caption(features)
        caption_paths = self.dataset.s['caption_paths'][sample_idx]
        n_rows = caption_paths.size
        # factor to squeeze the rows into the reduced vertical interval
        scale = (1 - border) / (1 + border)
        for row, (path, idx) in enumerate(zip(caption_paths, caption_idx)):
            with open(path, 'r') as caption_file:
                cur_captions = caption_file.read().splitlines()
            rel_pos = row / n_rows  # relative location, from 0 to 1
            string = utils.prepare_string_for_latex(cur_captions[idx])
            captions.append(string)
            ax.text(x=0,
                    y=border + rel_pos * scale,
                    s=string,
                    fontsize=12,
                    transform=ax.transAxes,
                    horizontalalignment='left',
                    verticalalignment='center')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()
        return captions

    def _find_nearest_neighbor_caption(self, features):
        """Locate, for every query feature, the closest dataset caption.

        :param features: query caption features (N2 x D per the call below).
        :return: ``(sample_idx, caption_idx)`` -- for each query, the index
            of the dataset sample and the index of the caption within that
            sample whose distance to the query is smallest.
        """
        n_captions = 10
        msg = 'code assumes ten captions per sample'
        assert self.dataset.x[1].size(1) == n_captions, msg
        d = utils.get_distance(s=self.dataset.x[1],  # N1 x n_captions x D
                               q=features,  # N2 x D
                               device=self.device)
        # Flatten all leading dims so each column holds the distances of one
        # query. NOTE(review): the index decoding below is only valid if
        # get_distance returns N1 x n_captions x N2 -- confirm in utils.
        d = d.view(-1, d.size(-1))
        idx = d.argmin(0)
        # Decipher location in original array
        sample_idx = torch.div(idx, n_captions, rounding_mode='floor')
        caption_idx = idx % n_captions
        return sample_idx, caption_idx
