from pathlib import Path
from typing import Dict, List

import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import torch

import utils

# Module-wide switch for matplotlib's ``text.usetex`` setting; it is applied
# through ``plt.rc`` every time ``DistributionVisualizer`` builds a figure.
usetex = True


class DistributionVisualizer:
    """ Visualizes pass of unimodal, crossmodal or unconditional samples
    through unimodal network.

    The figure is a grid with ``n_levels + 2`` rows and eight columns.
    Rows ``0 .. n_levels - 1`` show the hierarchical levels (top level
    first), row ``n_levels`` the reconstructions and row ``n_levels + 1``
    the input. Column pairs hold, from left to right: posteriors, priors,
    ancestral samples from the top-level posterior and ancestral samples
    from the top-level prior. The first column of every pair shows
    distribution means, the second one samples.
    """

    def __init__(self,
                 n_levels: int,
                 color_dict,
                 mod: Dict[str, str],
                 scales: bool = True):
        """ Sets up shared characteristics.
        :param n_levels: number of hierarchical levels
        :param color_dict: forwarded unchanged to every
            ``utils.vis.scatter`` call as keyword ``color_dict``
        :param mod: target 't' or conditioning 'c' modality
        :param scales: whether to use scales
        """
        self.n_levels = n_levels
        self.mod = mod
        self.scales = scales

        # figure and grid are created lazily by ``_build_figure``
        self.fig, self.gs = None, None
        self.properties = _setup_shared_properties(color_dict)

    def create_plot(self,
                    x: List[torch.Tensor],
                    labels: torch.Tensor,
                    output: dict,
                    *args, **kwargs):
        """ Creates and populates figure.
        :param x: input data with entries representing modalities
        :param labels: can describe class or class clusters
        :param output: for example posteriors or reconstructions
        :param args: forwarded to ``_save_fig``
        :param kwargs: forwarded to ``_save_fig``
        """
        self._build_figure()
        self._populate_grid(x, output, labels)
        self._create_titles()
        _save_fig(*args, **kwargs)

    def _build_figure(self):
        """ Creates the figure and its grid and stores both on the instance.
        :return: the newly created figure
        """
        fig = plt.figure(figsize=(11.69, 8.27))  # A4-width landscape
        gs = gridspec.GridSpec(self.n_levels + 2, 8, figure=fig,
                               hspace=0.15)
        plt.rc(r'text', usetex=usetex)
        fig.subplots_adjust(top=0.92, left=0.04, right=1.0,
                            wspace=0, hspace=0,
                            # make space for legend below plots
                            bottom=0.05)
        self.fig, self.gs = fig, gs
        return fig

    @staticmethod
    def _create_titles():
        """ Writes one heading above each of the four column pairs. """
        # titles over several subplots need to be hard-coded as there is no
        # built-in functionality; the x positions are roughly the centers of
        # the four column pairs of the eight-column grid
        titles = (
            (1 / 8, 'Posteriors'),
            (3 / 8, 'Priors'),
            (5 / 8, 'Ancestral Sampling\nFrom Top-Level Posterior'),
            (7 / 8, 'Ancestral Sampling\nFrom Top-Level Prior'),
        )
        for x_pos, text in titles:
            plt.figtext(x_pos, 0.98, text, va="center", ha="center", size=15)

    def _populate_grid(self, x, output, labels):
        """ Fills every column pair of the grid from ``output``.
        :param x: input data with entries representing modalities
        :param output: dict with keys 'posterior', 'prior',
            'reconstruction' and 'ancestral_samples', each indexed by
            target and then by conditioning modality
        :param labels: forwarded to the scatter plots for coloring
        """
        posterior = output['posterior']
        prior = output['prior']
        reconstruction = output['reconstruction']
        ancestral_samples = output['ancestral_samples']

        target, conditioning = self.mod['t'], self.mod['c']

        # input; only targets 'x1' and 'x2' have an entry in ``x``
        input_index = {'x1': 0, 'x2': 1}.get(target)
        if input_index is not None:
            self._plot_input(x[input_index], labels)

        # posteriors
        q = posterior[target][conditioning]
        if q:
            self._plot_posteriors(q, labels, columns=(0, 1))

        # reconstruction
        r = reconstruction[target][conditioning]
        if r:
            self._plot_reconstructions(r, labels, columns=(0, 1))

        # priors from inference pass
        p = prior[target][conditioning]
        if p:
            self._plot_priors(p, labels, columns=(2, 3))

        # ancestral samples; entry 0 is the reconstruction, the remaining
        # entries are the levels
        ac = ancestral_samples[target][conditioning]
        if ac:
            self._plot_priors(ac[1:], labels, columns=(4, 5))
            self._plot_reconstructions(ac[0], labels, columns=(4, 5))

        # ancestral_samples from top-level unconditional samples; these are
        # plotted without labels
        if target != 'joint':
            unconditional = ancestral_samples[target]['g']
            self._plot_priors(unconditional[1:], columns=(6, 7))
            self._plot_reconstructions(unconditional[0], columns=(6, 7))

    def _hide_ticks(self, ax):
        """ Removes the axis ticks of ``ax`` unless scales are enabled. """
        if not self.scales:
            ax.set_xticks([])
            ax.set_yticks([])

    def _posterior_condition(self, level):
        """ Returns the conditioning part of a posterior title.
        :param level: zero-based level, ``n_levels - 1`` is the top level
        """
        if level == self.n_levels - 1:
            return f'|{self.mod["c"]}'
        return f'|{self.mod["c"]}, z_{level + 1}'

    def _prior_condition(self, level):
        """ Returns the conditioning part of a prior title.
        :param level: zero-based level, ``n_levels - 1`` is the top level
        """
        # the top level has no level above it to be conditioned on
        return '' if level == self.n_levels - 1 else f'|z_{level + 1}'

    def _plot_posteriors(self,
                         posterior,
                         labels=None,
                         columns=(0, 1)):
        """ Plots posterior means and samples for every level.
        :param posterior: list indexed by level, or a single entry that is
            then used for every row; entries provide 'dist' (with a
            ``mean`` attribute) and 'samples'
        :param labels: forwarded to the scatter plots for coloring
        :param columns: grid columns for means and samples
        """
        # rows run top-down, hence the highest level goes into row 0
        for row, level in enumerate(reversed(range(self.n_levels))):
            cond = self._posterior_condition(level)

            # get posterior; the joint posterior only exists for the
            # top-level and is then not wrapped in a list
            q = posterior[level] if isinstance(posterior, list) \
                else posterior

            # means
            ax = self.fig.add_subplot(self.gs[row, columns[0]])
            ax.set_title(r'$q(\bar{'
                         rf'z_{level}'
                         r'}'
                         rf'{cond})$')
            ax.set_ylabel(f'Level {level}', size=20)
            if q:
                utils.vis.scatter(ax, x=q['dist'].mean, y=labels,
                                  **self.properties)
            self._hide_ticks(ax)

            # samples
            ax = self.fig.add_subplot(self.gs[row, columns[1]])
            ax.set_title(rf'$q(z_{level}{cond})$')
            if q:
                utils.vis.scatter(ax, x=q['samples'],
                                  y=labels, **self.properties)
            self._hide_ticks(ax)

    def _plot_priors(self,
                     prior,
                     labels=None,
                     columns=(0, 1)):
        """ Plots prior means and samples for every level that has a prior.
        :param prior: list indexed by level; ``None`` entries are skipped
        :param labels: forwarded to the scatter plots for coloring
        :param columns: grid columns for means and samples
        """
        if prior == [None]:
            # no hierarchy
            return

        # rows run top-down, hence the highest level goes into row 0
        for row, level in enumerate(reversed(range(self.n_levels))):
            level_prior = prior[level]
            if level_prior is None:
                # unconditional prior is already saved in datastructure from
                # ancestral sampling
                continue
            cond = self._prior_condition(level)

            # means
            ax = self.fig.add_subplot(self.gs[row, columns[0]])
            ax.set_title(r'$p(\bar{'
                         rf'z_{level}'
                         r'}'
                         rf'{cond})$')
            utils.vis.scatter(ax, x=level_prior['dist'].mean, y=labels,
                              **self.properties)
            self._hide_ticks(ax)

            # samples
            ax = self.fig.add_subplot(self.gs[row, columns[1]])
            ax.set_title(rf'$p(z_{level}{cond})$')
            utils.vis.scatter(ax, x=level_prior['samples'], y=labels,
                              **self.properties)
            self._hide_ticks(ax)

    def _plot_reconstructions(self,
                              reconstruction,
                              labels=None,
                              columns=(0, 1)):
        """ Plots reconstruction means and samples below the levels.
        :param reconstruction: provides 'type', 'dist' (with a ``mean``
            attribute) and 'samples'; type 'categorical' is not plotted
        :param labels: forwarded to the scatter plots for coloring
        :param columns: grid columns for means and samples
        """
        if reconstruction['type'] == 'categorical':
            return

        # means
        ax = self.fig.add_subplot(self.gs[self.n_levels, columns[0]])
        ax.set_title(r'$p(\bar{'
                     rf'{self.mod["t"]}'
                     r'}'
                     rf'|{self.mod["c"]})$')
        if 0 in columns:
            # only the leftmost column carries the row label
            ax.set_ylabel('Reconstruction', size=20)
        utils.vis.scatter(ax, x=reconstruction['dist'].mean,
                          y=labels, **self.properties)
        self._hide_ticks(ax)

        # samples
        ax = self.fig.add_subplot(self.gs[self.n_levels, columns[1]])
        ax.set_title(rf'$p({self.mod["t"]}|{self.mod["c"]})$')
        utils.vis.scatter(ax, x=reconstruction['samples'],
                          y=labels, **self.properties)
        self._hide_ticks(ax)

    def _plot_input(self, x, y):
        """ Plots the input of the target modality into the last row.
        :param x: input data of the target modality
        :param y: forwarded to the scatter plot for coloring
        """
        if len(x.size()) > 1:
            # otherwise modality are labels
            ax = self.fig.add_subplot(self.gs[self.n_levels + 1, 0])
            ax.set_title(rf'$p({self.mod["t"]})$')
            ax.set_ylabel('Input', size=20)
            utils.vis.scatter(ax, x, y, **self.properties)
            self._hide_ticks(ax)


def distribution_scatters_wrapper(data: dict,
                                  mod: dict,
                                  save_path,
                                  **kwargs):
    """ Creates and saves two figures for one modality combination.

    The first figure is color-coded by classes, the second one by modes
    and saved with the additional suffix '_modes'.
    :param data: provides the network 'output' and the pair 'inp' of
        inputs and a mapping holding the labels 'y' and 'm'
    :param mod: target 't' or conditioning 'c' modality
    :param save_path: path prefix; modalities are appended to it
    :param kwargs: forwarded to ``DistributionVisualizer``
    """
    base_path = save_path + f'_{mod["t"]}_{mod["c"]}'
    outputs = data['output']
    # the number of levels is read off the target's own posterior
    n_levels = len(outputs['posterior'][mod['t']][mod['t']])
    inputs, side_info = data['inp']

    # (label key, file-name suffix): 'y' color-codes classes,
    # 'm' color-codes modes
    variants = (('y', ''), ('m', '_modes'))
    for label_key, suffix in variants:
        visualizer = DistributionVisualizer(n_levels, mod=mod, **kwargs)
        visualizer.create_plot(x=inputs,
                               labels=side_info[label_key],
                               output=outputs,
                               save_path=base_path + suffix)


def _setup_shared_properties(color_dict):
    """ Sets global marker defaults and builds the shared scatter kwargs.

    Note that the matplotlib rc settings are changed process-wide.
    :param color_dict: stored under the key 'color_dict'
    :return: keyword arguments shared by all scatter plots
    """
    rc_overrides = (
        ('lines.markersize', 5),  # s
        ('lines.linewidth', 0),
        ('scatter.edgecolors', None),
    )
    for rc_key, rc_value in rc_overrides:
        mpl.rcParams[rc_key] = rc_value

    return dict(color_dict=color_dict,
                marker='.',  # necessary for proxy artists
                zorder=0)


def _save_fig(save_path: str):
    """ Saves the current pyplot figure as PNG and closes it.

    The figure is closed even if saving or the download helper fails, so
    that failed plots do not pile up as open figures in pyplot.
    :param save_path: target path without the '.png' extension
    """
    save_path = save_path + '.png'
    try:
        plt.savefig(save_path, format='png', dpi=500, transparent=True,
                    bbox_inches='tight')
        utils.shell_command_for_download(save_path)
    finally:
        plt.close()
