import logging
import os

import torch
from matplotlib import pyplot as plt

import utils

# Project-wide logger, looked up by the fixed name 'custom' rather than
# __name__ (presumably so it shares handlers configured elsewhere — verify).
logger = logging.getLogger('custom')


class CaptionVaeSampler:
    """Turn unconditional prior samples into captions via nearest-neighbour lookup.

    Pipeline (see :meth:`run`):
      1. draw samples with ``model.sample_from_prior``;
      2. decode them with ``model.vaes[1]`` and take the mean of the first
         generated distribution as caption features;
      3. for every feature vector, find the closest caption embedding in
         ``dataset.x[1]`` and read the matching caption line from disk;
      4. render the captions into a PDF and save the caption strings to a
         ``.pt`` file.

    :param model: object exposing ``sample_from_prior(k=...)`` and ``vaes``.
    :param dataset: object exposing ``x[1]`` (caption embeddings, presumably
        N1 x n_captions x D — TODO confirm) and ``s['caption_paths']``
        (array of caption-file paths, indexable with an integer array).
    :param device: device forwarded to ``utils.get_distance``.
    :param save_dir: root directory; the figure goes to
        ``<save_dir>/qual_results/prior_to_caption`` and the caption data to
        ``<save_dir>/data``.
    :param n_samples: number of prior samples to draw (default 50).
    """

    # Each dataset sample is paired with this many captions; the flat-index
    # decoding in _find_nearest_neighbor_caption depends on it.
    N_CAPTIONS = 10

    def __init__(self, model, dataset, device, save_dir, n_samples=50):
        self.model = model
        self.dataset = dataset
        self.device = device
        self.n_samples = n_samples
        # Populated by _make_figure().
        self.fig = None
        self.gs = None

        self.save_dirs = {
            'output': os.path.join(save_dir, 'qual_results', 'prior_to_caption'),
            'data': os.path.join(save_dir, 'data'),
        }
        for directory in self.save_dirs.values():
            os.makedirs(directory, exist_ok=True)

    def run(self):
        """Sample, decode, plot and save (see the class docstring)."""
        caption_features = self._get_data()
        self._make_figure()
        try:
            captions = self._plot_captions(caption_features)
            self._save_fig()
        except Exception:
            # Don't leak the open matplotlib figure if plotting/saving fails.
            plt.close(self.fig)
            raise
        self._save_data(captions)

    def _make_figure(self):
        """Create a portrait figure (8.27 x 11.69 in, i.e. A4) with a 1x2 grid.

        Only the right grid cell (``gs[1]``) is drawn into by ``_plot_captions``.
        """
        fig = plt.figure(constrained_layout=True, figsize=(8.27, 11.69))
        self.gs = fig.add_gridspec(1, 2,
                                   top=1.0, left=0.00, right=1.0, bottom=0,
                                   wspace=0, hspace=0)
        self.fig = fig

    def _save_fig(self):
        """Save ``self.fig`` as a PDF in the output directory, then close it."""
        save_path = os.path.join(self.save_dirs['output'], 'prior_to_captions.pdf')
        # Save/close our own figure explicitly instead of relying on pyplot's
        # "current figure" state.
        self.fig.savefig(save_path, format='pdf', dpi=100, transparent=True,
                         bbox_inches='tight')
        plt.close(self.fig)

    def _plot_captions(self, features, border=0.2):
        """Write the nearest-neighbour caption of each feature into the right column.

        :param features: query caption features, one row per prior sample
            (N2 x D).
        :param border: vertical offset (axes fraction) of the first caption;
            also sets the factor that squeezes the remaining positions.
        :return: list of LaTeX-prepared caption strings, in plotting order.
        """
        ax = self.fig.add_subplot(self.gs[1])
        sample_idx, caption_idx = self._find_nearest_neighbor_caption(features)
        # The index tensors may live on an accelerator: numpy fancy indexing
        # needs host memory, and plain ints are what list indexing wants.
        caption_paths = self.dataset.s['caption_paths'][sample_idx.cpu().numpy()]
        caption_idx = caption_idx.tolist()

        n_rows = caption_paths.size
        squeeze = (1 - border) / (1 + border)  # factor to squeeze into reduced interval
        captions = []
        for i, (path, idx) in enumerate(zip(caption_paths, caption_idx)):
            with open(path, 'r', encoding='utf-8') as fh:
                cur_captions = fh.read().splitlines()
            cpt = utils.prepare_string_for_latex(cur_captions[idx])
            captions.append(cpt)
            r = i / n_rows  # relative location in [0, 1)
            ax.text(x=0,
                    y=border + r * squeeze,
                    s=cpt,
                    fontsize=12,
                    transform=ax.transAxes,
                    horizontalalignment='left',
                    verticalalignment='center')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()
        return captions

    def _find_nearest_neighbor_caption(self, features):
        """Find the closest dataset caption embedding for each feature vector.

        Each dataset sample is paired with ``N_CAPTIONS`` captions, so the
        search runs over the flattened (sample, caption) grid and the flat
        argmin is decoded back into two indices.

        :param features: query features, N2 x D.
        :return:
            sample_idx: "outer" index into the dataset samples (length N2)
            caption_idx: index within the ``N_CAPTIONS`` captions (length N2)
        :raises ValueError: if ``dataset.x[1]`` does not hold ``N_CAPTIONS``
            captions per sample.
        """
        n_captions = self.N_CAPTIONS
        caption_embeddings = self.dataset.x[1]  # N1 x n_captions x D
        if caption_embeddings.size(1) != n_captions:
            # Explicit check rather than `assert`, so it survives `python -O`.
            raise ValueError('code assumes ten captions per sample')
        # NOTE(review): the decoding below is only correct if get_distance
        # returns N1 x n_captions x N2 (no feature axis) — confirm in utils.
        d = utils.get_distance(s=caption_embeddings,
                               q=features,  # N2 x D
                               device=self.device)
        # reshape (unlike view) also copes with a non-contiguous result.
        d = d.reshape(-1, d.size(-1))  # (N1 * n_captions) x N2
        idx = d.argmin(0)  # best flat row for each query
        # Decipher location in original array
        sample_idx = torch.div(idx, n_captions, rounding_mode='floor')
        caption_idx = idx % n_captions
        return sample_idx, caption_idx

    def _get_data(self):
        """Draw prior samples and decode them into caption features."""
        samples = self._sample_prior()
        return self._generate_caption_features(samples)

    def _save_data(self, captions):
        """Persist the plotted caption strings as ``{'captions': [...]}``."""
        data = {'captions': captions}
        dst = os.path.join(self.save_dirs['data'], 'prior_to_caption.pt')
        torch.save(data, dst)

    def _sample_prior(self):
        """Return ``self.n_samples`` samples from the model's unconditional prior."""
        unconditional_prior = self.model.sample_from_prior(k=self.n_samples)
        return unconditional_prior['samples']

    def _generate_caption_features(self, top_level_samples):
        """Decode prior samples with ``model.vaes[1]`` and return the mean of
        the first generated distribution (``ancestral_samples[0]['dist']``)."""
        vae = self.model.vaes[1]
        ancestral_samples = vae.generate(top_level_samples)
        return ancestral_samples[0]['dist'].mean
