import os

import torch

import utils
from hyperparams.load import get_config
from utils.visualization.images import make_image_grid

# Experiment configuration, loaded once when this module is imported.
# PriorToImagePlotter.generate_images reads `config.save_generated_data`
# from it to decide whether to also save the generated tensors.
config = get_config()


class PriorToImagePlotter:
    """Sample a model's prior, decode the samples, and save image grids.

    For every sampling temperature, ``num_samples`` samples are drawn with
    ``model.sample_from_prior``. Each set is then decoded by
    ``model.vaes[0].generate`` once per ``mode_layer`` option, and every
    decoded batch is written out as an image grid.

    Args:
        model: Object exposing ``sample_from_prior(k=..., temp=...)`` (whose
            result has a ``'samples'`` entry) and a ``vaes`` sequence whose
            first element has ``num_levels`` and
            ``generate(samples, mode_layer=...)``.
        save_paths: Mapping with an ``'images'`` directory, created on
            construction. When ``config.save_generated_data`` is truthy it
            also needs a ``'data'`` directory for the raw results.
        temperatures: Prior sampling temperatures. Defaults to
            ``(0.5, 0.7, 1.0)``.
        num_samples: Number of prior samples drawn per temperature
            (passed as ``k``). Defaults to 100.
    """

    def __init__(self, model, save_paths, temperatures=(0.5, 0.7, 1.0),
                 num_samples=100):
        self.model = model
        self.save_paths = save_paths
        self.temperatures = tuple(temperatures)
        self.num_samples = num_samples
        os.makedirs(self.save_paths['images'], exist_ok=True)

    def build_plots(self):
        """Sample the prior, decode the samples and save the image grids."""
        samples = self.sample_prior()
        images = self.generate_images(samples)
        self.plot(images)

    def plot(self, images):
        """Save one image grid per (temperature, mode_layer) pair.

        Each grid is written by ``make_image_grid`` to
        ``save_paths['images']/images_temp_{temp}_mode_{mode}``. With the
        default 100 samples this is meant as a 10x10 grid for p(x1|g). The
        actual layout is chosen by ``make_image_grid``.

        Args:
            images: Nested mapping ``images[temp][mode_layer]``, as returned
                by ``generate_images``.
        """
        for temp, per_mode in images.items():
            for mode, batch in per_mode.items():
                name = f'images_temp_{temp}_mode_{mode}'
                save_path = os.path.join(self.save_paths['images'], name)
                make_image_grid(batch, save_path)

    @torch.no_grad()
    def sample_prior(self):
        """Draw prior samples for every configured temperature.

        Autograd is disabled because these samples are only decoded, plotted
        and saved. Building graphs for them would just hold memory.

        Returns:
            dict mapping each temperature to the ``'samples'`` entry of
            ``model.sample_from_prior(k=num_samples, temp=temperature)``.
        """
        samples = {}
        for temp in self.temperatures:
            unconditional_prior = self.model.sample_from_prior(
                k=self.num_samples, temp=temp)
            samples[temp] = unconditional_prior['samples']
        return samples

    @torch.no_grad()
    def generate_images(self, samples):
        """Decode each temperature's samples under every ``mode_layer`` option.

        ``mode_layer`` takes ``None`` followed by each level index
        ``1 .. num_levels - 1`` of ``model.vaes[0]``. The meaning of each
        value is defined by that VAE's ``generate``. NOTE(review): it
        presumably selects which level is decoded at its mode rather than
        sampled; confirm.

        Args:
            samples: dict mapping temperature to prior samples, as returned
                by ``sample_prior``.

        Returns:
            Nested mapping ``images[temp][mode_layer]`` holding the
            ``'samples'`` entry of the first element returned by
            ``generate``. When ``config.save_generated_data`` is truthy, the
            mapping is also saved to ``save_paths['data']/prior_to_image.pt``.
        """
        num_levels = self.model.vaes[0].num_levels
        mode_layers = [None, *range(1, num_levels)]
        images = utils.rec_defaultdict()
        for temp, cur_samples in samples.items():
            for mode_layer in mode_layers:
                ancestral_samples = self._generate_images(
                    cur_samples, mode_layer)
                images[temp][mode_layer] = ancestral_samples[0]['samples']
        if config.save_generated_data:
            os.makedirs(self.save_paths['data'], exist_ok=True)
            # NOTE(review): torch.save pickles `images`. If
            # utils.rec_defaultdict uses a lambda as its default_factory,
            # this will fail to pickle. Confirm, or convert to plain dicts
            # before saving.
            torch.save(images,
                       os.path.join(self.save_paths['data'], 'prior_to_image.pt'))
        return images

    def _generate_images(self, cur_samples, mode_layer):
        """Decode one batch of prior samples with ``model.vaes[0].generate``.

        Returns the ``generate`` result unchanged. Callers read
        ``result[0]['samples']``.
        """
        return self.model.vaes[0].generate(cur_samples, mode_layer=mode_layer)
