import logging
import os

import matplotlib.pyplot as plt
import torch

import utils
from hyperparams.load import get_config

# Logger fetched by the fixed name 'custom' (not __name__), so it shares
# whatever handlers/level are attached to that name elsewhere; it is not used
# within this part of the module.
logger = logging.getLogger('custom')
# Hyperparameter config, loaded once at import time. ImageToCaptionSampler
# reads ``config.save_generated_data`` to decide whether to persist results.
config = get_config()


class ImageToCaptionSampler:
    """Generate caption features conditioned on a few dataset images.

    ``model.vaes[0]`` encodes one conditioning image per class bottom-up into
    top-level posterior samples; ``model.vaes[1]`` then generates caption
    features from those samples, once for every ``mode_layer`` setting.
    """

    def __init__(self, model, dataset, save_dir, device, n_samples=10):
        """
        :param model: object exposing ``vaes[0]`` (used for the bottom-up
            pass over images) and ``vaes[1]`` (used to generate captions).
        :param dataset: provides ``s['y']`` labels, ``x[0]`` image entries and
            an ``open_image`` loader.
        :param save_dir: root directory used by ``_save_data``.
        :param device: device the stacked conditioning images are moved to.
        :param n_samples: number of posterior samples (``k``) drawn per image
            in the bottom-up pass; defaults to the previously hard-coded 10.
        """
        self.model = model
        self.dataset = dataset
        self.device = device
        self.save_dir = save_dir
        self.n_samples = n_samples

    def run(self):
        """Select images, sample, decode captions and (optionally) save.

        :return: ``(images, caption_features)`` -- the stacked conditioning
            images and a dict mapping each ``mode_layer`` to its features.
        """
        images, top_level_samples = self._infer_top_level_samples()
        caption_features = self._generate_caption_features(top_level_samples)
        self._save_data(images, caption_features)
        return images, caption_features

    def _save_data(self, images, caption_features):
        """Write both tensors to ``<save_dir>/data/image_to_caption.pt``.

        Does nothing unless ``config.save_generated_data`` is truthy.
        """
        if not config.save_generated_data:
            return
        save_dir = os.path.join(self.save_dir, 'data')
        os.makedirs(save_dir, exist_ok=True)
        data = {'images': images, 'caption_features': caption_features}
        torch.save(data, os.path.join(save_dir, 'image_to_caption.pt'))

    def _infer_top_level_samples(self):
        """Pick the conditioning images and encode them bottom-up.

        :return: ``(images, top_level_samples)``.
        """
        images = self._select_conditioning_images()
        top_level_samples = self._perform_bottom_up_pass(images)
        return images, top_level_samples

    def _perform_bottom_up_pass(self, image_features):
        """Draw ``self.n_samples`` top-level posterior samples per image.

        :param image_features: batch of conditioning images (N of them).
        :return: shape samples: N x K x D, with K = ``self.n_samples``
            (assumes ``posterior['samples']`` is K x N x D before the
            transpose -- TODO confirm against ``vae.bottom_up``).
        """
        vae = self.model.vaes[0]
        # Sample variation over g
        _, posterior = vae.bottom_up(image_features, k=self.n_samples)
        return posterior['samples'].transpose(0, 1)

    def _select_conditioning_images(self, n_classes=5):
        """Load the first dataset image of each of the first ``n_classes`` labels.

        Labels are taken in the order returned by ``s['y'].unique()``.

        :param n_classes: maximum number of distinct labels to use.
        :return: the opened images stacked along a new first axis, on
            ``self.device``.
        """
        labels = self.dataset.s['y']
        # For every label, the first matching entry (index 0) is used.
        images = [
            self.dataset.open_image(self.dataset.x[0][labels == label][0])
            for label in labels.unique()[:n_classes]
        ]
        return torch.stack(images).to(self.device)

    def _generate_caption_features(self, top_level_samples):
        """Generate caption features from the top-level samples per mode layer.

        :param top_level_samples: output of ``_perform_bottom_up_pass``.
        :return: dict mapping ``mode_layer`` (``None``, then
            ``1 .. vae.num_levels - 1``) to the ``.mean`` of the ``'dist'`` of
            the first entry returned by ``vae.generate`` (annotated
            N x K x D in the original code).
        """
        vae = self.model.vaes[1]
        caption_features = {}
        for mode_layer in [None, *range(1, vae.num_levels)]:
            ancestral_samples = vae.generate(top_level_samples,
                                             mode_layer=mode_layer)
            caption_features[mode_layer] = ancestral_samples[0]['dist'].mean
        return caption_features


class ImageToCaptionPlotter:
    """Plot each conditioning image next to its nearest-neighbour captions.

    One figure is saved per (mode_layer, conditioning image) pair into
    ``<save_dir>/qual_results/image_to_caption``.
    """

    def __init__(self, dataset, save_dir, split, device):
        self.dataset = dataset
        self.save_dir = os.path.join(save_dir, 'qual_results', 'image_to_caption')
        self.split = split  # only used to tag output file names
        self.device = device
        os.makedirs(self.save_dir, exist_ok=True)

    def run(self, images, caption_features):
        """Render and save one figure per mode layer and conditioning image.

        :param images: conditioning images, indexed in the same order as the
            first axis of every value in ``caption_features``.
        :param caption_features: dict mapping mode_layer -> tensor whose first
            axis enumerates the conditioning images (presumably N x K x D as
            produced by ``ImageToCaptionSampler``).
        """
        for mode_layer, cpt in caption_features.items():
            for index, cur_cpt in enumerate(cpt):
                self._make_figure()
                self._plot_image(images[index])
                self._plot_captions(cur_cpt)
                self._save(
                    name=f'i2c_{index}_{self.split}_mode_layer_{mode_layer}'
                )

    def _save(self, name, format='pdf'):
        """Save the current figure as ``<save_dir>/<name>.<format>`` and close it."""
        save_path = os.path.join(self.save_dir, f'{name}.{format}')
        plt.savefig(save_path, format=format, dpi=145, transparent=True,
                    bbox_inches='tight')
        plt.close()

    def _make_figure(self):
        """Create an A4-landscape figure with a borderless 1 x 2 grid.

        Also switches matplotlib to LaTeX text rendering (global ``plt.rc``).
        """
        fig = plt.figure(figsize=(11.69, 8.27))
        plt.rc(r'text', usetex=True)
        self.gs = fig.add_gridspec(1, 2,
                                   top=1.0, left=0.00, right=1.0, bottom=0,
                                   wspace=0, hspace=0)
        self.fig = fig

    def _plot_image(self, image):
        """Show ``image`` in the left grid cell, without ticks or axes.

        Assumes ``image`` is channels-first (C x H x W); the transpose moves
        axis 0 last, the layout ``imshow`` expects.
        """
        ax = self.fig.add_subplot(self.gs[0])
        ax.imshow(utils.to_np(image).transpose((1, 2, 0)))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()

    def _plot_captions(self, features, border=0.2):
        """Write the nearest-neighbour caption of each feature row as text.

        :param features: one caption-feature row per sample (presumably K x D).
        :param border: fraction of the axes height kept free at the bottom;
            captions are stacked upwards starting at ``y=border``.
        """
        ax = self.fig.add_subplot(self.gs[1])
        sample_idx, caption_idx = self._find_nearest_neighbor_caption(features)
        caption_paths = self.dataset.s['caption_paths'][sample_idx]
        n_paths = caption_paths.size
        # Factor to squeeze the relative positions into the reduced interval.
        squeeze = (1 - border) / (1 + border)
        for i, (path, idx) in enumerate(zip(caption_paths, caption_idx)):
            with open(path, 'r') as fh:
                cur_captions = fh.read().splitlines()
            r = i / n_paths  # relative location, in [0, 1)
            string = utils.prepare_string_for_latex(cur_captions[idx])
            ax.text(x=0,
                    y=border + r * squeeze,
                    s=string,
                    fontsize=12,
                    transform=ax.transAxes,
                    horizontalalignment='left',
                    verticalalignment='center')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()

    def _find_nearest_neighbor_caption(self, features):
        """Find, for every query feature, the closest caption in the dataset.

        :param features: query caption features (presumably N2 x D).
        :return: ``(sample_idx, caption_idx)`` -- CPU integer tensors with
            one entry per query: the dataset sample and which of its ten
            captions was closest.
        :raises ValueError: if the dataset does not hold exactly ten captions
            per sample.
        """
        n_captions = 10
        n_found = self.dataset.x[1].size(1)
        if n_found != n_captions:
            # An explicit raise, unlike ``assert``, is not stripped by ``-O``.
            raise ValueError(
                f'code assumes ten captions per sample, got {n_found}')
        d = utils.get_distance(s=self.dataset.x[1],  # N1 x n_captions x D
                               q=features,  # N2 x D
                               device=self.device)
        # Flatten the (sample, caption) axes so one argmin over dim 0 searches
        # every caption of every sample. NOTE(review): the decoding below is
        # only correct if d is N1 x n_captions x N2 (D already reduced by
        # get_distance); the original comment also listed D -- confirm.
        # reshape, unlike view, also copes with a non-contiguous result.
        d = d.reshape(-1, d.size(-1))
        # Indices go to CPU: _plot_captions indexes a host-side container with
        # them (presumably a NumPy array, since it reads `.size` as an
        # attribute), which a CUDA tensor cannot index. No-op on CPU.
        idx = d.argmin(0).cpu()
        # Decipher location in the flattened (N1 * n_captions) axis.
        sample_idx = torch.div(idx, n_captions, rounding_mode='floor')
        caption_idx = idx % n_captions
        return sample_idx, caption_idx
