import logging
import os

import matplotlib.pyplot as plt
import torch

import utils
from evaluation.images.caption_to_image import CaptionToImageSampler, CaptionToImagePlotter
from utils.visualization import process_caption

logger = logging.getLogger('custom')


class CaptionToImageSamplerMdvae(CaptionToImageSampler):
    """Caption-to-image sampler for the MDVAE model.

    Captions are encoded with ``self.model.vaes['x2']``; the resulting ``g``
    samples are concatenated with draws from ``self.model.pz1`` and the
    joint latents are decoded with ``self.model.vaes['x1']``.
    """

    def _generate_images(self, samples):
        """Decode a batch of latent samples into images.

        :param samples: 3-D tensor (N x K x D); each of the N slices along
            dimension 0 is decoded independently by the ``'x1'`` VAE.
        :return: tensor stacking the decoder's ``'samples'`` entry for every
            slice along a new leading dimension.
        """
        assert len(samples.size()) == 3  # N x K x D
        decoder = self.model.vaes['x1']
        decoded = [decoder.decode(latent)['samples'] for latent in samples]
        return torch.stack(decoded)

    def _perform_bottom_up_pass(self, caption_features, k=100):
        """Build joint latents for a batch of captions.

        :param caption_features: caption inputs; dimension 0 is the number
            of captions N.
        :param k: number of samples drawn per caption.
        :return: shape captions x K x D
        """
        # Encode before sampling the prior so random draws happen in the
        # same order as before.
        encoded = self.model.vaes['x2'].encode(caption_features, k)
        g = encoded['g']['samples'].transpose(0, 1)  # N x K x D
        prior = self.model.pz1(*self.model.pz1_params)
        batch = caption_features.size(0)
        # Draw K x N samples; squeeze(-2) presumably removes a singleton axis
        # carried by pz1_params (TODO confirm), then swap to N x K x D.
        z1 = prior.sample((k, batch)).squeeze(-2).transpose(0, 1)
        return torch.cat([z1, g], dim=-1)


class CaptionToImagePlotterMdvae(CaptionToImagePlotter):
    """Plotter that saves one image-grid figure per conditioning caption."""

    def run(self, captions, images):
        """Save one figure for each (caption, images) pair.

        :param captions: captions, paired positionally with ``images``.
        :param images: per-caption image batches, each turned into a grid by
            ``self._make_grid``.
        """
        # zip stops at the shorter of the two inputs.
        for idx, (cpt, cur_img) in enumerate(zip(captions, images)):
            grid = self._make_grid(cur_img)
            self._create_figure(grid, cpt, idx)

    def _create_figure(self, grid, caption, index, **kwargs):
        """Render ``grid`` under ``caption`` and write it to ``self.save_dir``.

        The file is named ``c2i_{index}_{self.split}.png``.

        :param grid: image grid; presumably C x H x W, since it is transposed
            with ``(1, 2, 0)`` before ``imshow`` (TODO confirm).
        :param caption: caption text, formatted with ``process_caption``.
        :param index: position of the caption, used in the file name.
        :param kwargs: accepted for interface compatibility; unused.
        """
        fig, ax = plt.subplots(figsize=(8.27, 8.27))
        try:
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_axis_off()
            ax.set_title(process_caption(caption))
            im = utils.to_np(grid).transpose((1, 2, 0))
            ax.imshow(im)
            name = f'c2i_{index}_{self.split}.png'
            save_path = os.path.join(self.save_dir, name)
            # Encode as PNG so the bytes match the ".png" file name; the
            # previous JPEG encoding produced mislabelled files and ignored
            # transparent=True (JPEG has no alpha channel).
            fig.savefig(save_path, format='png', dpi=150,
                        transparent=True, bbox_inches='tight')
        finally:
            # Close this specific figure even if rendering/saving fails, so
            # repeated calls do not accumulate open figures.
            plt.close(fig)
