import logging
import os

import torch
from matplotlib import gridspec
from matplotlib import pyplot as plt

import utils

# Module-level logger, looked up by the fixed name 'custom' (not __name__).
logger = logging.getLogger(name='custom')


class Visualizer:
    """Render per-class image grids and save them as JPEG files.

    Files are written to ``<save_dir>/qual_results/visualizations``.
    """

    def __init__(self, dataset, save_dir):
        """
        Args:
            dataset: Object exposing ``get_image_tensors()`` and a mapping
                ``s`` whose ``'y'`` entry holds the labels (see ``run``).
            save_dir: Root output directory; the visualization
                sub-directory is created beneath it if missing.
        """
        self.dataset = dataset
        self.save_dir = os.path.join(
            save_dir, 'qual_results', 'visualizations'
        )
        os.makedirs(self.save_dir, exist_ok=True)

    def run(self, data, mode_layer=None):
        """Plot the dataset images and ``data['x1|x2']`` grouped by class.

        Args:
            data: Mapping that must contain the key ``'x1|x2'``; its value
                is plotted against the dataset labels ``self.dataset.s['y']``.
            mode_layer: Optional identifier appended to the output name as
                ``-mode-layer-{mode_layer}``. Any value other than ``None``
                (including ``0``) is appended, so each layer gets its own
                file instead of overwriting the un-layered one.
        """
        self.plot_images_by_class(
            self.dataset.get_image_tensors(), self.dataset.s['y'], name='x1')
        name = 'x1|x2'
        if mode_layer is not None:
            name += f'-mode-layer-{mode_layer}'
        self.plot_images_by_class(
            data['x1|x2'], self.dataset.s['y'],
            name=name
        )

    def plot_images_by_class(self, x: torch.Tensor, y: torch.Tensor, name=None,
                             max_per_class: int = 50):
        """Create an image grid where each row holds the images of one class.

        Args:
            x: Images indexed along dim 0; each ``x[i]`` is permuted with
                ``(1, 2, 0)`` before display, so channel-first images are
                expected (presumably ``(N, C, H, W)`` -- confirm with callers).
            y: Labels; ``y == label`` must give a boolean mask over dim 0
                of ``x``.
            name: Figure title; also embedded in the output file name.
            max_per_class: Maximum number of images (grid columns) per class.
        """
        labels = [v.item() for v in y.unique()]
        if not labels:
            # GridSpec rejects zero rows; there is nothing to draw.
            logger.warning('No labels to plot for %r; skipping.', name)
            return
        class2img = {label: x[y == label] for label in labels}

        # A4 landscape is 11.69 x 8.27 inches; the figure height grows with
        # the number of rows relative to the number of columns.
        nrows, ncols = len(labels), max_per_class
        fig = plt.figure(figsize=(11.69, 8.27 * nrows / ncols))
        try:
            gs = gridspec.GridSpec(
                nrows=nrows, ncols=ncols, wspace=0.0, hspace=0.0, top=0.90,
                bottom=0.00, left=0.0, right=1.0, figure=fig)

            # Populate grid with images: row i = class, column j = sample.
            for i, (label, imgs) in enumerate(class2img.items()):
                for j in range(min(ncols, len(imgs))):
                    ax = fig.add_subplot(gs[i, j])
                    ax.axis('off')
                    if j == 0:
                        ax.set_title(f'class {label}', size=5)
                    img = imgs[j].permute(1, 2, 0)
                    if img.shape[-1] == 1:
                        # imshow rejects (H, W, 1) arrays; draw
                        # single-channel images as 2-D greyscale.
                        ax.imshow(utils.to_np(img.squeeze(-1)), cmap='gray')
                    else:
                        ax.imshow(utils.to_np(img))

            fig.suptitle(name, size=20, y=1.0)
            # NOTE: ``name`` is used verbatim in the file name; characters
            # such as '|' (as produced by run()) are invalid on Windows.
            save_path = os.path.join(
                self.save_dir, f'images_by_classes_overview_{name}.jpg')
            fig.savefig(save_path,
                        dpi=700,
                        format='jpg',
                        pil_kwargs={'quality': 90},
                        transparent=True,
                        bbox_inches='tight')
        finally:
            # Release this specific figure even if drawing or saving fails.
            plt.close(fig)
