import logging
import os

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt

import utils
from hyperparams.load import get_config
from utils.visualization import process_caption

# Module-wide logger, looked up under the fixed name 'custom' (not the module
# name), so it shares handlers/levels with whatever else uses that name.
logger = logging.getLogger('custom')
# Configuration is loaded once, at import time. In this file it is only read
# by CaptionToImageSampler._save_data (the `save_generated_data` flag).
config = get_config()


class CaptionToImageSampler:
    """Generate images conditioned on captions taken from a dataset.

    One caption is selected per class label, its feature vector is passed
    through the model's ``backbone_x2`` encoder, and the result is decoded
    via ``encoder_g`` -> ``g_to_z1`` -> ``z1_to_x1``.
    """

    def __init__(self, model, dataset, save_dir, device):
        """
        :param model: object with ``encoder`` and ``decoder`` mappings; the
            keys used here are ``'backbone_x2'``, ``'encoder_g'``,
            ``'g_to_z1'`` and ``'z1_to_x1'``
        :param dataset: object with ``x`` (``x[1]`` is indexed for caption
            features) and ``s`` (mapping with ``'y'`` and
            ``'caption_paths'``)
        :param save_dir: root directory; generated data goes to
            ``<save_dir>/data`` when saving is enabled in the config
        :param device: torch device (or device string) used for inference
        """
        self.model = model
        self.dataset = dataset
        self.save_dir = save_dir
        self.device = device

    def get_data(self):
        """Select captions, generate images for them and optionally save.

        :return: tuple ``(captions, images)``; ``captions`` is a list of
            strings and ``images`` is the tensor produced by
            :meth:`_infer_images`
        """
        # NOTE(review): nothing here disables autograd. If the caller only
        # needs samples, wrapping this call in ``torch.no_grad()`` would
        # save memory - confirm the model does not need gradients
        # internally before doing so.
        captions, caption_features = self._select_conditioning_captions()
        images = self._generate_images(caption_features)
        self._save_data(captions, images)
        return captions, images

    def _save_data(self, captions, images):
        """Persist captions and images if ``config.save_generated_data``.

        Writes ``<save_dir>/data/caption_to_image.pt`` containing a dict
        with the keys ``'captions'`` and ``'images'``.
        """
        if not config.save_generated_data:
            return
        save_dir = os.path.join(self.save_dir, 'data')
        os.makedirs(save_dir, exist_ok=True)
        data = {'captions': captions, 'images': images}
        torch.save(data, os.path.join(save_dir, 'caption_to_image.pt'))

    def _generate_images(self, caption_features):
        """Run the caption features through the full generation path."""
        h = self._infer_h(caption_features)
        return self._infer_images(h)

    def _infer_h(self, caption_features, pad_dim=512):
        """Encode captions and left-pad the result with zeros.

        :param caption_features: input for the ``backbone_x2`` encoder
        :param pad_dim: width of the zero block concatenated in front of
            the encoder output along the last dimension (presumably the
            slot of the features that are not observed here - confirm
            against the model definition)
        :return: ``cat([zeros, h2], dim=-1)``
        """
        h2 = self.model.encoder['backbone_x2'](caption_features)
        # Allocate directly on the target device instead of building the
        # tensor on the host and copying it over.
        h_zeros = torch.zeros((h2.size(0), pad_dim), device=self.device)
        return torch.cat([h_zeros, h2], dim=-1)

    def _infer_images(self, h, k=100):
        """Decode ``h`` into images.

        :param h: output of :meth:`_infer_h`
        :param k: second argument forwarded to ``encoder_g`` (the original
            shape note suggests it is the number of samples per caption)
        :return: 5-D tensor with its first two dimensions swapped relative
            to the decoder output; per the original author's note the
            layout is N x K x 3 x 64 x 64
        """
        # Each stage returns a pair; only the second element is used.
        _, g = self.model.encoder['encoder_g'](h, k)
        _, z = self.model.decoder['g_to_z1'](g)
        _, images = self.model.decoder['z1_to_x1'](z)
        return images.permute(1, 0, 2, 3, 4)

    def _select_conditioning_captions(self, i=0, j=0, num_classes=30):
        """
        Pick one caption (text and features) for each class label.

        Note that there are 10 captions for each image
        :param i: select the i-th image-captions-pair for each class
        :param j: select the j-th caption for each image-captions-pair
        :param num_classes: use at most this many of the unique labels,
            in the order returned by ``unique()``
        :return: tuple ``(captions, caption_features)``; the features are
            stacked into one tensor and moved to ``self.device``
        """
        caption_features, captions = [], []
        labels = self.dataset.s['y'].unique()
        for cur_y in labels[:num_classes]:
            # Mask of all samples belonging to the current class
            idx = (self.dataset.s['y'] == cur_y)
            x2 = self.dataset.x[1][idx]
            paths = self.dataset.s['caption_paths'][idx]
            caption_features.append(x2[i][j])
            # The caption file holds one caption per line
            with open(paths[i], 'r') as f:
                caption = f.read().splitlines()
            captions.append(caption[j])
        caption_features = torch.stack(caption_features).to(self.device)
        return captions, caption_features


class CaptionToImagePlotter:
    """Plot one image grid per conditioning caption and save it to disk.

    Figures are written to ``<save_path>/qual_results/c2i``.
    """

    def __init__(self, save_path, split):
        """
        :param save_path: root directory; the output sub-directory
            ``qual_results/c2i`` is created below it if missing
        :param split: label embedded in every output file name
        """
        self.save_path = os.path.join(save_path, 'qual_results', 'c2i')
        os.makedirs(self.save_path, exist_ok=True)
        self.split = split

    def run(self, captions, images):
        """Create and save one figure per ``(caption, images)`` pair.

        :param captions: iterable of caption strings
        :param images: iterable aligned with ``captions``; each element is
            the batch of images generated for that caption
        """
        # For each conditioning caption
        for idx, (cpt, img) in enumerate(zip(captions, images)):
            grid = self._make_grid(img)
            self._create_figure(grid, cpt, idx)

    @staticmethod
    def _make_grid(im):
        """Tile a batch of images into a single grid image.

        The number of images per row is ``int(sqrt(batch size))``, so a
        batch whose size is a perfect square yields a square grid.
        """
        per_row = int(np.sqrt(im.size(0)))
        return torchvision.utils.make_grid(tensor=im, nrow=per_row, padding=0)

    def _create_figure(self, grid, caption, index):
        """Render ``grid`` under its caption and save it as a PNG.

        :param grid: channel-first image (it is transposed to channel-last
            before display)
        :param caption: caption text, passed through ``process_caption``
            with ``threshold=40`` before being used as the title
        :param index: running index embedded in the file name
        """
        fig, ax = plt.subplots(figsize=(8.27, 8.27))
        try:
            ax.axis("off")
            caption = process_caption(caption, threshold=40)
            ax.set_title(caption, fontsize=20)
            im = utils.to_np(grid).transpose((1, 2, 0))
            ax.imshow(im)
            save_path = os.path.join(
                self.save_path, f'c2i_{index}_{self.split}.png')
            # The file name says PNG, so write PNG. The previous
            # format='jpg' put JPEG bytes under a .png name and could not
            # honour transparent=True (JPEG has no alpha channel). The
            # JPEG-only 'quality' option is dropped for the same reason.
            fig.savefig(save_path, format='png', dpi=150, transparent=True,
                        bbox_inches='tight')
        finally:
            # Close this specific figure even if rendering or saving fails
            plt.close(fig)
