import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

import utils
from hyperparams.load import get_config
from utils.visualization import process_caption

# Module-level logger registered under the fixed name 'custom'; it is not
# referenced by the code visible in this part of the file.
logger = logging.getLogger('custom')
# Configuration object loaded once at import time. The visible code reads its
# `save_generated_data` attribute to decide whether samples are written to disk.
config = get_config()


class CaptionToImageSampler:
    """
    Sample images conditioned on captions.

    Caption features are passed through ``model.vaes[1].bottom_up`` to obtain
    top-level samples, which are then decoded with ``model.vaes[0].generate``
    once per mode layer.

    :param model: object exposing ``vaes``; index 0 is used for generation and
        index 1 for the bottom-up pass over caption features
    :param device: device the selected caption features are moved to
    :param save_dir: root directory; generated data goes to ``<save_dir>/data``
    :param dataset: object exposing ``s['y']``, ``s['caption_paths']`` and
        ``x[1]`` (caption features)
    :param n_samples: number of samples requested per caption in the
        bottom-up pass (passed as ``k``)
    :param n_classes: number of distinct labels used as conditioning classes
    """

    def __init__(self, model, device, save_dir, dataset, *, n_samples=100,
                 n_classes=30):
        self.model = model
        self.device = device
        self.save_dir = save_dir
        self.dataset = dataset
        self.n_samples = n_samples
        self.n_classes = n_classes

    def run(self):
        """
        Select captions, sample images for them and optionally save both.

        :return: tuple ``(captions, images)`` where ``images`` maps each mode
            layer (``None`` first, then ``1 .. num_levels - 1``) to a stacked
            tensor of generated samples
        """
        captions, top_level_samples = self._infer_top_level_samples()
        images = self._generate_images(top_level_samples)
        self._save_data(captions, images)
        return captions, images

    def _save_data(self, captions, images):
        """
        Write captions and images to ``<save_dir>/data/caption_to_image.pt``.

        Nothing is written unless ``config.save_generated_data`` is truthy.
        """
        if not config.save_generated_data:
            return
        save_dir = os.path.join(self.save_dir, 'data')
        os.makedirs(save_dir, exist_ok=True)
        data = {'captions': captions, 'images': images}
        torch.save(data, os.path.join(save_dir, 'caption_to_image.pt'))

    def _generate_images(self, samples):
        """
        Decode top-level samples into images for every mode layer.

        :param samples: 3-dimensional tensor (n_classes x K x g_dim)
        :return: dict mapping mode layer to the stacked ``'samples'`` entry of
            the first element returned by ``vae.generate`` for each class
        """
        vae = self.model.vaes[0]
        assert samples.dim() == 3  # n_classes x K x g_dim
        mode_iterator = [None, *range(1, vae.num_levels)]
        images = {}
        # Pure inference: no autograd graph is needed, and keeping one alive
        # for every class, sample and mode layer only costs memory.
        with torch.no_grad():
            for mode_layer in mode_iterator:
                per_class = [
                    vae.generate(cur_samples,
                                 mode_layer=mode_layer)[0]['samples']
                    for cur_samples in samples
                ]
                images[mode_layer] = torch.stack(per_class)
        return images

    def _select_conditioning_captions(self, i=0, j=0, n_classes=30):
        """
        Pick one caption (text and features) for each of the first classes.

        Note that there are 10 captions for each image
        :param i: select the i-th image-captions-pair for each class
        :param j: select the j-th caption for each image-captions-pair
        :param n_classes: number of unique labels (in the order returned by
            ``unique()``) to select a caption for
        :return: tuple ``(captions, caption_features)``; ``captions`` is a
            list of strings, ``caption_features`` a stacked tensor on
            ``self.device``
        """
        caption_features, captions = [], []
        labels = self.dataset.s['y']
        for cur_y in labels.unique()[:n_classes]:
            # Mask selecting all entries that belong to the current class
            idx = (labels == cur_y)
            x2 = self.dataset.x[1][idx]
            paths = self.dataset.s['caption_paths'][idx]
            caption_features.append(x2[i][j])
            # The caption file holds one caption per line
            with open(paths[i], 'r') as f:
                caption = f.read().splitlines()
            captions.append(caption[j])
        caption_features = torch.stack(caption_features).to(self.device)
        return captions, caption_features

    def _infer_top_level_samples(self):
        """
        :return: tuple ``(captions, top_level_samples)`` for the selected
            conditioning captions
        """
        captions, caption_features = self._select_conditioning_captions(
            n_classes=self.n_classes)
        top_level_samples = self._perform_bottom_up_pass(caption_features)
        return captions, top_level_samples

    def _perform_bottom_up_pass(self, caption_features):
        """
        Draw ``self.n_samples`` top-level samples for each caption.

        :param caption_features: stacked caption features, one row per caption
        :return: shape captions N x K x D
        """
        vae = self.model.vaes[1]
        # Inference only; see _generate_images for why grad is disabled.
        with torch.no_grad():
            _, posterior = vae.bottom_up(caption_features, k=self.n_samples)
            # Swap the first two axes so captions come first (presumably the
            # posterior samples arrive as K x N x D; confirm against the VAE).
            top_level_samples = posterior['samples'].transpose(0, 1)
        return top_level_samples


class CaptionToImagePlotter:
    """
    Save one titled image grid per (mode layer, conditioning caption) pair.

    Figures are written to ``<save_dir>/qual_results/c2i``, which is created
    on construction.

    :param dataset: stored on the instance; not used by the visible methods
    :param save_dir: root output directory
    :param split: name embedded in every output file name
    """

    # Single source of truth for both the file extension and the format
    # handed to savefig, so the two cannot disagree.
    _IMAGE_FORMAT = 'jpg'

    def __init__(self, dataset, save_dir, split):
        self.dataset = dataset
        self.save_dir = os.path.join(save_dir, 'qual_results', 'c2i')
        self.split = split
        os.makedirs(self.save_dir, exist_ok=True)

    def run(self, captions, images):
        """
        Create a figure for every caption under every mode layer.

        :param captions: sequence of caption strings
        :param images: dict mapping mode layer to a batch of image sets, one
            set per caption and in the same order as ``captions``
        """
        # For each mode_layer
        for mode_layer, cur_images in images.items():
            # For each conditioning caption
            for idx, (cpt, cur_img) in enumerate(zip(captions, cur_images)):
                grid = self._make_grid(cur_img)
                self._create_figure(grid, cpt, idx, mode_layer)

    @staticmethod
    def _make_grid(im):
        """
        Arrange a batch of images into a grid without padding.

        :param im: tensor whose first dimension indexes the images
        :return: grid produced by ``torchvision.utils.make_grid`` with
            ``floor(sqrt(n_images))`` images per row
        """
        images_per_row = int(np.sqrt(im.size(0)))
        return torchvision.utils.make_grid(
            tensor=im, nrow=images_per_row, padding=0)

    def _figure_path(self, index, mode_layer):
        """
        Build the output path for one figure.

        :param index: position of the conditioning caption
        :param mode_layer: mode layer the samples were generated with
        :return: path inside ``self.save_dir`` whose extension matches the
            format the figure is saved in
        """
        name = (f'c2i_{index}_{self.split}_mode_layer_{mode_layer}'
                f'.{self._IMAGE_FORMAT}')
        return os.path.join(self.save_dir, name)

    def _create_figure(self, grid, caption, index, mode_layer):
        """
        Render ``grid`` with the processed caption as title and save it.

        :param grid: image grid as returned by ``_make_grid``
        :param caption: raw caption; passed through ``process_caption``
        :param index: position of the conditioning caption (file name only)
        :param mode_layer: mode layer of the samples (file name only)
        """
        fig, ax = plt.subplots(figsize=(8.27, 8.27))
        try:
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_axis_off()
            ax.set_title(process_caption(caption))
            # Move the first axis last for imshow (assumes to_np yields a
            # channels-first C x H x W array; confirm against utils.to_np).
            im = utils.to_np(grid).transpose((1, 2, 0))
            ax.imshow(im)
            # `transparent` is kept from the original call even though JPEG
            # has no alpha channel.
            fig.savefig(self._figure_path(index, mode_layer),
                        format=self._IMAGE_FORMAT, dpi=150,
                        pil_kwargs={'quality': 90}, transparent=True,
                        bbox_inches='tight')
        finally:
            # Always release the figure, even if rendering or saving failed.
            plt.close(fig)
