import logging
import os

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils
from data.synthetic import load_synthetic_data
from evaluation import accumulate_distributions
from evaluation.evaluator import Evaluator
from utils.visualization import scatter

logger = logging.getLogger('custom')


class SyntheticEvaluator(Evaluator):
    """ Evaluates variational autoencoders from this package on the synthetic
    dataset.

    On top of the base ``Evaluator.evaluate``, and only when the
    ``visualize_input`` result flag is set, this runs the model over the
    evaluation split, writes inputs and merged model outputs to
    ``<artifact_dir>/input_output.pt`` and scatter-plots the input modalities.
    """

    # Batch size used by ``_get_forward_data`` when iterating the split.
    # Override on a subclass or instance to trade peak memory for speed.
    forward_batch_size = 8192

    # Entries of the model's forward output that are merged across batches,
    # in the order they are accumulated.
    _FORWARD_KEYS = ('posterior', 'prior', 'reconstruction',
                     'ancestral_samples')

    def __init__(self, config, **kwargs):
        super().__init__(config=config, **kwargs)
        # Switches for optional evaluation artifacts of the synthetic data
        # (e.g. ``visualize_input``); queried with ``.get``.
        self.result_flags = config.result_flags_synthetic_data

    @torch.no_grad()
    def evaluate(self, model, epoch):
        """Run the base evaluation, then the optional input visualization.

        The model is put into eval mode and is not switched back here.
        """
        model.eval()
        super().evaluate(model, epoch)
        self._visualize_input(model)

    def _visualize_input(self, model):
        """Save the split's input/output to disk and plot the inputs.

        Does nothing unless the ``visualize_input`` result flag is truthy.
        """
        if not self.result_flags.get('visualize_input'):
            return

        dataset = self._load_data()
        x, s, _ = self._save_input_output_to_disk(model, dataset)
        # Refers to the module-level plotting helper, not this method:
        # method bodies do not resolve names in class scope.
        _visualize_input(x=x, y=s['y'], save_dir=self.artifact_dir,
                         split=self.split)

    def _get_forward_data(self, model, dataset, prior_samples=2000, **kwargs):
        """Run ``model`` over ``dataset`` batch-wise and merge the outputs.

        Inputs are moved to ``self.device`` before the forward pass; whether
        the merged output ends up on CPU depends on
        ``accumulate_distributions`` (not visible here).

        Args:
            model: object providing ``forward(x, eval=True, k=1, **kwargs)``
                and ``ancestral_sampling_from_prior(k=..., **kwargs)``.
            dataset: dataset whose batches unpack into ``(x, s)``.
            prior_samples: passed as ``k`` to
                ``model.ancestral_sampling_from_prior`` on every batch.
            **kwargs: forwarded to both model calls.

        Returns:
            dict with keys ``posterior``, ``prior``, ``reconstruction`` and
            ``ancestral_samples``: the first batch's values, merged with every
            later batch via ``accumulate_distributions``. Empty if the loader
            yields no batches.
        """
        output = {}
        loader = DataLoader(dataset, self.forward_batch_size, shuffle=False,
                            pin_memory=True)

        # Collect data iteratively to bound the memory of a single pass.
        batches = tqdm(enumerate(loader), total=len(loader),
                       desc='Obtaining evaluation data')
        for idx, inp in batches:
            x, _ = utils.to_device(inp, self.device)
            new_output = model.forward(x, eval=True, k=1, **kwargs)
            # NOTE(review): prior sampling runs once per batch, i.e.
            # len(loader) times overall — confirm that is intended.
            prior_gen = model.ancestral_sampling_from_prior(
                k=prior_samples, **kwargs)
            new_output['ancestral_samples'] = utils.update(
                new_output['ancestral_samples'], prior_gen)

            for key in self._FORWARD_KEYS:
                if idx == 0:
                    # First batch initializes the collected output.
                    output[key] = new_output[key]
                else:
                    output[key] = accumulate_distributions(
                        old=output[key], new=new_output[key])

        return output

    def _save_input_output_to_disk(self, model, dataset):
        """Collect forward data and save it together with the full input.

        Writes ``{'input': dataset[:], 'output': <forward data>}`` to
        ``<artifact_dir>/input_output.pt`` with ``torch.save``.

        Returns:
            tuple ``(x, s, output)`` where ``(x, s) = dataset[:]``.
        """
        output = self._get_forward_data(model, dataset)
        inp = dataset[:]
        x, s = inp
        save_path = os.path.join(self.artifact_dir, 'input_output.pt')
        torch.save({'input': inp, 'output': output}, save_path)
        logger.debug('\n====> Saved experiment data (input, output) at:\n'
                     f'{save_path}\n')
        return x, s, output

    def _load_data(self, *args, **kwargs):
        """Load the synthetic dataset for ``self.split``.

        Positional and keyword arguments are accepted but ignored.
        """
        dataset, _ = load_synthetic_data(mode=self.split)
        return dataset


def _visualize_input(x, y, save_dir, split):
    """Scatter-plot every multi-dimensional input modality, colored by label.

    Args:
        x: iterable of per-modality tensors; entries with at most one
            dimension are treated as labels and skipped.
        y: label tensor (supports ``.unique()``) used to color the points.
        save_dir: directory the PNG files are written to.
        split: name of the data split, used in the file names.

    Writes one ``input_x{idx + 1}_{split}.png`` per plotted modality, where
    ``idx`` is the modality's 0-based position in ``x``.
    """
    logger.debug('\n====> Visualizing input:')
    # Fixed palette when there are at most three classes; otherwise
    # ``color_dict=None`` is passed on to ``scatter``.
    # NOTE(review): the palette assumes those labels are 0, 1, 2 — confirm.
    if len(y.unique()) < 4:
        color_dict = {0: '#FFA630', 1: '#588B8B', 2: '#3089ff'}
    else:
        color_dict = None
    properties = {'s': 150,
                  'linewidth': 0,
                  'alpha': 1.0,
                  'color_dict': color_dict}

    for idx, cur_x in enumerate(x):
        if len(cur_x.size()) <= 1:
            # 1-D modalities are labels; there is nothing to scatter.
            continue

        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        path = os.path.join(save_dir, f'input_x{idx + 1}_{split}.png')
        try:
            scatter(ax, x=cur_x, y=y, **properties)
            ax.set_xticks([])
            ax.set_yticks([])
            # Save this figure explicitly instead of relying on pyplot's
            # "current figure".
            fig.savefig(path, format='png', dpi=500, transparent=True,
                        bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails,
            # so figures do not pile up in pyplot's figure manager.
            plt.close(fig)
        utils.shell_command_for_download(path)
