import logging
import os
from typing import List

import matplotlib.pyplot as plt
import torch

import utils
from data.cub.main_ft import load_cub_ft_data
from evaluation.features.misc import make_grid_from_paths
from hyperparams.load import get_config

# Module-wide logger, looked up under the project's shared 'custom' name,
# and the global run configuration (read by UnconditionalSampler below).
logger = logging.getLogger(name='custom')
config = get_config()


class UnconditionalSampler:
    """Draw unconditional samples from a two-modality model.

    Top-level latents are sampled from the model's prior and then decoded by
    each of the model's two modality VAEs.
    """

    def __init__(self, model, save_path):
        """
        :param model: model exposing ``sample_from_prior(k=..., temp=...)`` and a
            ``vaes`` sequence (exactly two entries) whose items provide
            ``generate(samples, k=...)``
        :param save_path: directory the generated tensors are written to when
            ``config.save_generated_data`` is enabled
        """
        self.model = model
        self.save_path = save_path

    def get_samples(self, n=2, k=100) -> List[torch.Tensor]:
        """
        :param n: samples from prior
        :param k: samples from lower-level variable
        :return: samples for both modalities, each with shape N x K x D
            N: samples from p(g)
            K: samples from p(z_1|g)
            D: Dimension of p(x_1|z_1)
        """
        prior_samples = self.sample_prior(k=n, temp=0.5)
        return self.generate_modalities(prior_samples, k=k)

    def sample_prior(self, k=100, temp=0.5):
        """Return ``k`` samples from the model's unconditional prior.

        :param k: number of prior samples to draw
        :param temp: sampling temperature forwarded to ``model.sample_from_prior``
        :return: the ``'samples'`` entry of the prior-sampling result
        """
        unconditional_prior = self.model.sample_from_prior(k=k, temp=temp)
        return unconditional_prior['samples']

    def generate_modalities(self, samples, **kwargs):
        """Decode prior samples into both modalities, optionally saving them.

        :param samples: prior samples, e.g. as returned by ``sample_prior``
        :param kwargs: forwarded to ``_generate_modalities`` (e.g. ``k``)
        :return: list ``[x1, x2]`` with one tensor per modality
        """
        fx1, fx2 = self._generate_modalities(samples, **kwargs)
        gens = [fx1, fx2]
        if config.save_generated_data:
            os.makedirs(self.save_path, exist_ok=True)
            torch.save(gens, os.path.join(self.save_path, 'prior_to_image.pt'))
        return gens

    def _generate_modalities(self, cur_samples, k=10):
        """Run each modality VAE's ``generate`` on the given prior samples.

        :param cur_samples: prior samples passed to every ``vae.generate``
        :param k: number of lower-level samples requested per prior sample
        :return: list with one tensor per VAE, with the first two axes of the
            VAE output swapped
        """
        n_vaes = len(self.model.vaes)
        assert n_vaes == 2, f'expected exactly two modality VAEs, got {n_vaes}'
        features = []
        for vae in self.model.vaes:
            ancestral_samples = vae.generate(cur_samples, k=k)
            # NOTE(review): presumably entry 0 is the observation level of the
            # ancestral pass — confirm against vae.generate.
            gens = ancestral_samples[0]['samples']
            # Swap the two leading axes. get_samples documents the result as
            # N x K x D, so generate presumably returns K x N x D — TODO confirm.
            features.append(gens.permute(1, 0, -1))
        return features


class UnconditionalSamplesPlotter:
    """Visualise model samples through their nearest neighbours in the data.

    For every prior sample, generated image features and caption features are
    matched to their nearest neighbours in the CUB fine-tuning split. The
    matching images (as a grid) and caption sentences (as text) are drawn side
    by side and saved as one PNG per prior sample.
    """

    def __init__(self, save_path, split, device):
        """
        :param save_path: directory the PNG figures are written to (created if missing)
        :param split: dataset split, passed to ``load_cub_ft_data`` as ``mode``
        :param device: device forwarded to ``utils.get_distance``
        """
        self.save_path = save_path
        self.device = device
        self.fig = None
        self.gs = None

        # Load all sentences for each image (not just their average) for nearest-neighbor lookup
        self.dataset, _ = load_cub_ft_data(mode=split, average=False)

        os.makedirs(save_path, exist_ok=True)

    def make_plot(self, samples):
        """Build and save one figure per prior sample.

        :param samples: two-element sequence ``[image_features, caption_features]``;
            both are iterated in lockstep over their first dimension and figure
            ``i`` is saved as ``prior_samples_{i}.png``
        """
        for i, (x1, x2) in enumerate(zip(*samples)):
            self._make_figure()
            self._plot_images(x1)
            self._plot_captions(x2)
            self._close_figure(i)

    def _make_figure(self):
        """Create a landscape-A4 figure with a 1 x 2 grid (images | captions)."""
        # size is landscape A4
        fig = plt.figure(constrained_layout=True, figsize=(11.69, 8.27))
        self.gs = fig.add_gridspec(1, 2,
                                   top=1.0, left=0.00, right=1.0, bottom=0,
                                   wspace=0, hspace=0)
        self.fig = fig

    def _plot_captions(self, samples, border=0.2, max_captions=30):
        """Write the nearest-neighbour caption of each caption feature as text.

        :param samples: caption features for one prior sample (indexed along dim 0)
        :param border: axes-coordinate y position of the first caption; later
            captions are placed above it, compressed by ``(1 - border) / (1 + border)``
        :param max_captions: only the first ``max_captions`` samples are drawn so
            the text column does not become crowded
        """
        ax = self.fig.add_subplot(self.gs[1])
        samples = samples[:max_captions]
        sample_idx, caption_idx = self._find_nearest_neighbor_caption(samples)
        # Torch indices may live on self.device; host-side path arrays can only
        # be indexed with CPU data (.cpu() is a no-op for CPU tensors).
        caption_paths = self.dataset.s['caption_paths'][sample_idx.cpu()]
        n_rows = caption_paths.size
        # Factor squeezing relative positions into a reduced interval; it does
        # not depend on the row, so compute it once.
        squeeze = (1 - border) / (1 + border)
        for row, (path, idx) in enumerate(zip(caption_paths, caption_idx.cpu().tolist())):
            with open(path, 'r') as fh:
                captions = fh.read().splitlines()
            rel = row / n_rows  # relative location, from 0 to 1
            ax.text(x=0,
                    y=border + rel * squeeze,
                    s=utils.prepare_string_for_latex(captions[idx]),
                    fontsize=6,
                    transform=ax.transAxes,
                    horizontalalignment='left',
                    verticalalignment='center')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()

    def _find_nearest_neighbor_caption(self, features):
        """Locate the dataset caption closest to each query feature.

        :param features: caption features to match
        :return: ``(sample_idx, caption_idx)`` tensors: the image index and the
            caption index (0..9) within that image's caption file
        :raises AssertionError: if the dataset does not hold ten captions per image
        """
        n_captions = 10
        msg = 'code assumes ten captions per sample'
        assert self.dataset.x[1].size(1) == n_captions, msg
        d = utils.get_distance(s=self.dataset.x[1],
                               q=features,
                               device=self.device)
        # Flatten all leading dims so argmin over dim 0 searches every caption.
        # The decoding below treats the flat row as image * n_captions + caption,
        # i.e. assumes d is (n_images, n_captions, n_queries) — TODO confirm.
        d = d.view(-1, d.size(-1))
        idx = d.argmin(0)
        # Decipher location in original array
        sample_idx = torch.div(idx, n_captions, rounding_mode='floor')
        caption_idx = idx % n_captions
        return sample_idx, caption_idx

    def _plot_images(self, samples):
        """Draw a grid of the dataset images nearest to the given image features.

        :param samples: image features for one prior sample
        """
        d = utils.get_distance(s=self.dataset.x[0],
                               q=samples,
                               device=self.device)
        idx = d.argmin(0)
        # Host-side path array: index with CPU data (no-op for CPU tensors).
        paths = self.dataset.s['image_paths'][idx.cpu()]
        grid, _ = make_grid_from_paths(paths=paths)

        # CHW -> HWC for imshow.
        im = grid.numpy().transpose((1, 2, 0))
        ax = self.fig.add_subplot(self.gs[0])
        ax.imshow(im)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()

    def _close_figure(self, i):
        """Save the figure built by ``_make_figure`` as PNG ``i`` and close it."""
        save_path = os.path.join(self.save_path, f'prior_samples_{i}.png')
        # Act on our own figure explicitly instead of pyplot's "current" figure,
        # which helpers called in between could have changed.
        self.fig.savefig(save_path, format='png', dpi=500, transparent=True,
                         bbox_inches='tight')
        plt.close(self.fig)
        self.fig = None
        self.gs = None
