import os

import torch
from matplotlib import pyplot as plt

import utils
from hyperparams.load import get_config

# Project configuration, loaded once at import time. In this section it is
# read by ImageToImageSampler._save_data (`config.save_generated_data`).
config = get_config()


class ImageToImageSampler:
    """Generate images conditioned on one real image per class.

    For each of the first ``n_classes`` labels reported by the dataset, the
    first image with that label is used as the condition; the model encodes
    it and decodes ``k`` samples from the resulting latent.
    """

    n_classes = 5  # from how many classes to sample

    def __init__(self, model, dataset, save_dir, device):
        """
        Args:
            model: Model exposing ``encoder['backbone_x1']``,
                ``encoder['encoder_z1']`` and ``decoder['z1_to_x1']``.
            dataset: Dataset exposing ``s['y']`` (labels), ``x[0]``
                (per-sample image paths/keys) and ``open_image(path)``.
            save_dir: Root directory; generated data is written below
                ``save_dir/data``.
            device: Device the conditioning images are moved to.
        """
        self.model = model
        self.dataset = dataset
        self.save_dir = save_dir
        self.device = device

    def run(self):
        """Pick conditioning images, generate samples and optionally save.

        Returns:
            Tuple ``(cond_images, gen_images)``.
        """
        cond_images = self._sample_conditioning_images()
        gen_images = self._generate_images(cond_images)
        self._save_data(cond_images, gen_images)
        return cond_images, gen_images

    def _save_data(self, cond_images, gen_images):
        """Write both tensors to ``save_dir/data/image_to_image.pt``.

        Does nothing unless ``config.save_generated_data`` is truthy.
        """
        if not config.save_generated_data:
            return
        save_dir = os.path.join(self.save_dir, 'data')
        os.makedirs(save_dir, exist_ok=True)
        data = {'cond_images': cond_images, 'gen_images': gen_images}
        torch.save(data, os.path.join(save_dir, 'image_to_image.pt'))

    def _generate_images(self, cond_images, k=100):
        """Draw ``k`` generated images for every conditioning image.

        Runs under ``torch.no_grad()``: this is pure sampling, so building an
        autograd graph over ``k`` copies of each input only wastes memory.

        Args:
            cond_images: Batch of conditioning images, batch axis first.
            k: Number of samples per conditioning image (forwarded to
                ``encoder_z1``).

        Returns:
            Generated images grouped per conditioning image
            (``N x K x 3 x 64 x 64`` per the original shape note).
        """
        with torch.no_grad():
            h1 = self.model.encoder['backbone_x1'](cond_images)
            _, z1 = self.model.encoder['encoder_z1'](h1, k)
            _, gen_images = self.model.decoder['z1_to_x1'](z1)
        # Assumes the decoder returns K x N x ...; swapping the first two axes
        # groups samples by conditioning image -> N x K x 3 x 64 x 64.
        return gen_images.permute(1, 0, 2, 3, 4)

    def _sample_conditioning_images(self):
        """Select the first image of each of the first ``n_classes`` labels.

        Labels are taken in the order returned by ``dataset.s['y'].unique()``.
        For each label, the first entry of ``dataset.x[0]`` carrying that
        label is opened with ``dataset.open_image``.

        Returns:
            Stacked images on ``self.device`` (``N x 3 x 64 x 64`` per the
            original note), where N = min(n_classes, number of labels).
        """
        labels = self.dataset.s['y']
        images = []
        # Only ``n_classes`` bounds the selection (no extra hard-coded cap).
        for cur_y in labels.unique()[:self.n_classes]:
            mask = labels == cur_y
            path = self.dataset.x[0][mask][0]
            images.append(self.dataset.open_image(path))
        return torch.stack(images).to(self.device)


class ImageToImagePlotter:
    """Render one JPEG per conditioning image: the condition on top and its
    generated samples tiled in a square grid underneath."""

    def __init__(self, save_path, split):
        """
        Args:
            save_path: Root directory; figures are written to
                ``save_path/qual_results/i2i`` (created if missing).
            split: Tag appended to every output filename.
        """
        self.save_path = os.path.join(save_path, 'qual_results', 'i2i')
        os.makedirs(self.save_path, exist_ok=True)
        self.split = split
        # Figure / outer gridspec currently being drawn; set by _make_fig().
        self.fig = None
        self.gs = None

    def run(self, cond_images, gen_images):
        """Write one figure per (condition, samples) pair.

        Args:
            cond_images: Iterable of channel-first conditioning images.
            gen_images: Iterable of channel-first sample batches, aligned
                with ``cond_images``.
        """
        # For every class
        for idx, (cond, gen) in enumerate(zip(cond_images, gen_images)):
            self._make_fig()
            try:
                self._plot_cond_image(cond)
                self._plot_generated_images(gen)
                self._save_fig(idx)
            finally:
                # Release the figure even if plotting fails; closing an
                # already-closed figure is a no-op.
                plt.close(self.fig)

    def _make_fig(self, n=9):
        """Create the figure and its two-row outer grid.

        The bottom row is ``n`` times taller than the top row. With a figure
        10 wide and ``10 * (n + 1) / n`` tall, the bottom row is exactly
        10 x 10, i.e. square, for any ``n``.
        """
        fig = plt.figure(
            # Lower grid must be quadratic
            figsize=(10, 10 * (n + 1) / n))
        self.gs = fig.add_gridspec(nrows=2, ncols=1,
                                   top=1.0, left=0.00, right=1.0, bottom=0,
                                   wspace=0, hspace=0,
                                   height_ratios=[1, n])
        self.fig = fig
        # NOTE: mutates global matplotlib rc state (and needs a LaTeX install);
        # it stays in effect for every figure created afterwards.
        plt.rc(r'text', usetex=True)

    def _plot_cond_image(self, image):
        """Draw a bold "Condition:" label followed by ``image`` in the top row.

        ``image`` is channel-first; it is transposed to H x W x C for imshow.
        """
        subgrid = self.gs[0].subgridspec(nrows=1, ncols=3,
                                         width_ratios=[1, 1, 1.8],
                                         wspace=0.0, hspace=0.0)

        # Add text
        ax = self.fig.add_subplot(subgrid[0])
        ax.text(x=0.0, y=0.5,
                s=r'\textbf{Condition:}',
                fontsize=50,
                ha='left', va='center')
        ax.axis("off")

        # Add image
        ax = self.fig.add_subplot(subgrid[1])
        im = utils.to_np(image).transpose((1, 2, 0))
        ax.imshow(im, aspect="equal")
        ax.axis("off")

    def _plot_generated_images(self, images, n=10):
        """Tile ``images`` row-major into an ``n`` x ``n`` grid in the bottom row.

        Args:
            images: Iterable of channel-first images; at most ``n * n``
                (the default grid holds 100).
            n: Number of rows and columns of the grid.
        """
        subgrid = self.gs[1].subgridspec(nrows=n, ncols=n,
                                         wspace=0.0, hspace=0.0)
        for idx, cur_image in enumerate(images):
            im = utils.to_np(cur_image).transpose((1, 2, 0))
            row, col = divmod(idx, n)
            ax = self.fig.add_subplot(subgrid[row, col])
            ax.axis("off")
            ax.imshow(im, aspect="equal")

    def _save_fig(self, index):
        """Save ``self.fig`` as ``i2i_{index}_{split}.jpg`` and close it.

        Works on ``self.fig`` explicitly rather than pyplot's implicit
        "current figure", so the intended figure is saved and released even
        if another figure became active in the meantime.
        """
        fig = self.fig
        fig.subplots_adjust(wspace=0, hspace=0)
        save_path = os.path.join(self.save_path, f'i2i_{index}_{self.split}.jpg')
        fig.savefig(save_path, format='jpg', dpi=150, transparent=True,
                    pil_kwargs={'quality': 90}, bbox_inches='tight')
        plt.close(fig)
