"""Visualization utilities for diffeq solutions and models

TODO: this is a mess and too tailored to the Trainer class, make more general
"""
import glob
import shutil
from distutils.spawn import find_executable
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

# ----------------------------- Add some ~class~ to the font -----------------------------
# NOTE: Producing plots with these font families requires a working system installation
# of LaTeX. If no installation is found, matplotlib falls back to its built-in mathtext
# renderer and default fonts.
#
# ... consult here https://matplotlib.org/stable/tutorials/text/usetex.html for more info
# usetex needs the `latex` binary and, for raster output through the Agg backend
# (PNG, the default format of the plotting functions in this module), `dvipng` too;
# enabling usetex without either makes every later savefig call fail.
# ``shutil.which`` replaces ``distutils.spawn.find_executable``: distutils is
# deprecated (PEP 632) and was removed from the standard library in Python 3.12.
if all(shutil.which(tool) for tool in ("latex", "dvipng")):
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.sans-serif": ["Helvetica"],
    })
# ----------------------------------------------------------------------------------------


def visualize_solver_heatmap(predictions, dataset, save_prefix, config):
    """Render an NNDE solver solution as a 2D heatmap and save it to disk.

    Args:
        predictions: array reshaped to ``dataset.domain_shape`` before plotting.
        dataset: provides ``xlabel``/``ylabel`` (axis labels), ``domain_shape``
            (``[0]`` indexes the y-axis rows, ``[1]`` the x-axis columns) and
            ``domain_lims`` (``[0]`` gives the y tick range, ``[1]`` the x tick range).
        save_prefix: output path without extension; the figure is written to
            ``f"{save_prefix}.{format}"``.
        config: dict-like of optional overrides: ``text_fs``, ``num_fs``, ``nticks``,
            ``title``, ``padding``, ``figsize``, ``cmap``, ``format``, ``bbox``, ``dpi``.
    """
    # unpack plotting parameters
    xlabel = dataset.xlabel  # x-axis text label
    ylabel = dataset.ylabel  # y-axis text label
    text_fs = config.get("text_fs", 30)  # text fontsize
    num_fs = config.get("num_fs", 30)  # numerical fontsize
    nticks = config.get("nticks", 5)  # number of ticks per axis
    title = config.get("title", "Heatmap")  # figure title
    padding = config.get("padding", 50)  # title padding
    fmt = config.get("format", "png")  # output format, also used as file extension
    nrows, ncols = dataset.domain_shape[0], dataset.domain_shape[1]
    # figure width follows the row/column ratio; clamp to at least 1 inch because
    # int() truncates to 0 for very wide domains and matplotlib rejects zero sizes
    figsize = config.get("figsize") or (max(1, int(10 * (nrows / ncols))), 10)

    heatmap = predictions.reshape(dataset.domain_shape)

    fig = plt.figure(figsize=figsize)
    try:
        plt.title(f"{title}", pad=padding, fontsize=text_fs)
        plt.xlabel(xlabel, fontsize=text_fs)
        plt.ylabel(ylabel, fontsize=text_fs)
        im = plt.imshow(
            heatmap,
            cmap=config.get("cmap", "inferno"),
            vmin=heatmap.min(),
            vmax=heatmap.max(),
            origin='lower',
            # square grids keep square pixels; otherwise stretch to fill the axes
            aspect='equal' if nrows == ncols else 'auto'
        )

        # nticks evenly spaced ticks per axis, labelled with domain coordinates
        plt.xticks(
            np.linspace(0, ncols - 1, nticks, dtype=int),
            np.linspace(*dataset.domain_lims[1], nticks).round(3),
            fontsize=num_fs
        )
        # NOTE(review): with origin='lower' row 0 is drawn at the bottom, yet these
        # labels run from domain_lims[0][1] down to domain_lims[0][0]. That is only
        # correct if row 0 holds the upper y limit -- confirm against the dataset grid.
        plt.yticks(
            np.linspace(0, nrows - 1, nticks, dtype=int),
            np.linspace(*dataset.domain_lims[0][::-1], nticks).round(3),
            fontsize=num_fs
        )

        # colorbar in a thin axis appended to the right of the heatmap
        divider = make_axes_locatable(plt.gca())
        cax = divider.append_axes("right", size="5%", pad=0.05)
        cbar = plt.colorbar(im, cax=cax)
        cbar.ax.tick_params(labelsize=num_fs)

        plt.tight_layout()
        plt.savefig(
            f"{save_prefix}.{fmt}",
            format=fmt,
            bbox_inches=config.get('bbox', None),
            dpi=config.get('dpi', 200)
        )
    finally:
        # release the figure even if plotting or saving raises, so repeated
        # calls do not accumulate open figures
        plt.close(fig)


def visualize_solver_phase_space(predictions, dataset, save_prefix, config):
    """Plot predicted vs. ground-truth ODE solver trajectories in phase space.

    The columns ``dataset.viz_coords`` of ``dataset.sol_data`` (ground truth, solid
    black) and of ``predictions`` (dashed red) are plotted, and the matching entries
    of ``dataset.initial_cond`` are marked with a star. The figure is written to
    ``f"{save_prefix}.{format}"``.

    Args:
        predictions: 2D array whose columns are selected by ``dataset.viz_coords``.
        dataset: provides ``xlabel``, ``ylabel``, ``viz_coords``, ``sol_data`` and
            ``initial_cond``.
        save_prefix: output path without extension.
        config: dict-like of optional overrides: ``title``, ``text_fs``, ``num_fs``,
            ``figsize``, ``pt_size``, ``format``, ``bbox``, ``dpi``.
    """
    # unpack plotting parameters
    title = config.get("title", "Phase Space")  # plot title
    xlabel = dataset.xlabel  # x-axis text label
    ylabel = dataset.ylabel  # y-axis text label
    text_fs = config.get("text_fs", 30)  # text fontsize
    num_fs = config.get("num_fs", 30)  # numerical fontsize
    fmt = config.get("format", "png")  # output format, also used as file extension
    coords = dataset.viz_coords  # presumably two column indices (x, y) -- confirm
    figsize = config.get("figsize") or (10, 10)

    fig = plt.figure(figsize=figsize)
    try:
        plt.title(title, fontdict={'fontsize': text_fs})
        plt.xlabel(xlabel, fontsize=text_fs)
        plt.ylabel(ylabel, fontsize=text_fs)
        plt.plot(*dataset.sol_data[:, coords].T, 'k-', label='groundtruth')
        plt.plot(*predictions[:, coords].T, 'r--', label='prediction')
        plt.scatter(
            *dataset.initial_cond[coords],
            marker='*', c='r',
            s=config.get("pt_size", 10),
            label='initial condition'
        )
        plt.xticks(fontsize=num_fs)
        plt.yticks(fontsize=num_fs)
        plt.ticklabel_format(style='sci', axis='both', scilimits=(-2, 2), useMathText=True)
        plt.legend()
        plt.tight_layout()

        plt.savefig(
            f"{save_prefix}.{fmt}",
            bbox_inches=config.get('bbox', None),
            format=fmt,
            dpi=config.get('dpi', 200)
        )
    finally:
        # release the figure even if plotting or saving raises
        plt.close(fig)


def visualize_residual_heatmap(residual, residual_pred, dataset, save_prefix, config):
    """Render true and predicted NNDE solver residuals as two separate heatmaps.

    One figure is saved per residual, to ``f"{save_prefix}_True Error.{format}"``
    and ``f"{save_prefix}_Predicted Error.{format}"`` respectively.

    Args:
        residual: array of true residuals, reshaped to ``dataset.domain_shape``.
        residual_pred: array of predicted residuals, reshaped the same way.
        dataset: provides ``xlabel``/``ylabel`` (axis labels), ``domain_shape``
            (``[0]`` indexes the y-axis rows, ``[1]`` the x-axis columns) and
            ``domain_lims`` (``[0]`` gives the y tick range, ``[1]`` the x tick range).
        save_prefix: output path prefix; label and extension are appended.
        config: dict-like of optional overrides: ``text_fs``, ``num_fs``, ``nticks``,
            ``padding``, ``figsize``, ``cmap``, ``format``, ``bbox``, ``dpi``.
    """
    # unpack plotting parameters
    xlabel = dataset.xlabel  # x-axis text label
    ylabel = dataset.ylabel  # y-axis text label
    text_fs = config.get("text_fs", 30)  # text fontsize
    num_fs = config.get("num_fs", 30)  # numerical fontsize
    nticks = config.get("nticks", 5)  # number of ticks per axis
    padding = config.get("padding", 50)  # title padding
    fmt = config.get("format", "png")  # output format, also used as file extension
    nrows, ncols = dataset.domain_shape[0], dataset.domain_shape[1]
    # figure width follows the row/column ratio; clamp to at least 1 inch because
    # int() truncates to 0 for very wide domains and matplotlib rejects zero sizes
    figsize = config.get("figsize") or (max(1, int(10 * (nrows / ncols))), 10)

    for data, lbl in zip([residual, residual_pred], ["True Error", "Predicted Error"]):
        heatmap = data.reshape(dataset.domain_shape)

        fig = plt.figure(figsize=figsize)
        try:
            plt.title(lbl, pad=padding, fontsize=text_fs)
            plt.xlabel(xlabel, fontsize=text_fs)
            plt.ylabel(ylabel, fontsize=text_fs)
            im = plt.imshow(
                heatmap,
                cmap=config.get("cmap", "inferno"),
                vmin=heatmap.min(),
                vmax=heatmap.max(),
                origin='lower',
                # square grids keep square pixels; otherwise stretch to fill the axes
                aspect='equal' if nrows == ncols else 'auto'
            )

            # nticks evenly spaced ticks per axis, labelled with domain coordinates
            plt.xticks(
                np.linspace(0, ncols - 1, nticks, dtype=int),
                np.linspace(*dataset.domain_lims[1], nticks).round(3),
                fontsize=num_fs
            )
            # NOTE(review): with origin='lower' row 0 is drawn at the bottom, yet these
            # labels run from domain_lims[0][1] down to domain_lims[0][0]. That is only
            # correct if row 0 holds the upper y limit -- confirm against the dataset grid.
            plt.yticks(
                np.linspace(0, nrows - 1, nticks, dtype=int),
                np.linspace(*dataset.domain_lims[0][::-1], nticks).round(3),
                fontsize=num_fs
            )

            # colorbar in a thin axis appended to the right of the heatmap
            divider = make_axes_locatable(plt.gca())
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cbar = plt.colorbar(im, cax=cax)
            cbar.ax.tick_params(labelsize=num_fs)

            plt.tight_layout()
            # NOTE: the label is used verbatim, so these filenames contain a space
            plt.savefig(
                f"{save_prefix}_{lbl}.{fmt}",
                format=fmt,
                bbox_inches=config.get('bbox', None),
                dpi=config.get('dpi', 200)
            )
        finally:
            # release each figure even if plotting or saving raises
            plt.close(fig)


def visualize_residual_phase_space(residual, residual_pred, dataset, save_prefix, config):
    """Plot true vs. predicted ODE solver residuals in phase space.

    The columns ``dataset.viz_coords`` of ``residual`` (solid black, labelled
    'groundtruth') and ``residual_pred`` (dashed red, labelled 'prediction') are
    plotted, and the first row of ``residual`` is marked with a star. The figure is
    written to ``f"{save_prefix}_ResEstimation.{format}"``.

    Args:
        residual: 2D array of true residuals; columns selected by ``viz_coords``.
        residual_pred: 2D array of predicted residuals, same layout.
        dataset: provides ``xlabel``, ``ylabel`` and ``viz_coords``.
        save_prefix: output path prefix; ``_ResEstimation`` and the extension are appended.
        config: dict-like of optional overrides: ``title``, ``text_fs``, ``num_fs``,
            ``figsize``, ``pt_size``, ``format``, ``bbox``, ``dpi``.
    """
    # unpack plotting parameters
    title = config.get("title", "Phase Space")  # plot title
    xlabel = dataset.xlabel  # x-axis text label
    ylabel = dataset.ylabel  # y-axis text label
    text_fs = config.get("text_fs", 30)  # text fontsize
    num_fs = config.get("num_fs", 30)  # numerical fontsize
    fmt = config.get("format", "png")  # output format, also used as file extension
    figsize = config.get("figsize") or (10, 10)
    coords = dataset.viz_coords  # presumably two column indices (x, y) -- confirm

    fig = plt.figure(figsize=figsize)
    try:
        plt.title(title, fontdict={'fontsize': text_fs})
        plt.xlabel(xlabel, fontsize=text_fs)
        plt.ylabel(ylabel, fontsize=text_fs)
        plt.plot(*residual[:, coords].T, 'k-', label='groundtruth')
        plt.plot(*residual_pred[:, coords].T, 'r--', label='prediction')
        # the first row of the true residual marks the start of the trajectory
        plt.scatter(
            *residual[0, coords],
            marker='*', c='r',
            s=config.get("pt_size", 10),
            label='initial condition'
        )
        plt.xticks(fontsize=num_fs)
        plt.yticks(fontsize=num_fs)
        plt.ticklabel_format(style='sci', axis='both', scilimits=(-2, 2), useMathText=True)
        plt.legend()
        plt.tight_layout()

        plt.savefig(
            f"{save_prefix}_ResEstimation.{fmt}",
            bbox_inches=config.get('bbox', None),
            format=fmt,
            dpi=config.get('dpi', 200)
        )
    finally:
        # release the figure even if plotting or saving raises
        plt.close(fig)


def _load_metrics_file(path):
    """Load one metrics file with ``np.load``, eagerly reading and closing ``.npz`` archives.

    For ``.npz`` archives ``np.load`` returns an ``NpzFile`` that keeps the underlying
    file open and reads members lazily; its arrays are copied into a plain dict and
    the archive is closed. Any other result (e.g. a structured array) is returned as-is.
    """
    data = np.load(path)
    if hasattr(data, "files"):  # NpzFile lists its member names in ``.files``
        with data:
            return {key: data[key] for key in data.files}
    return data


def _collect_error_correction_metrics(enc, metrics_dir, prefix, num_orders):
    """Load every metrics file for ``enc`` and group it by correction order.

    Returns ``(metrics_dict, save_prefix)``: ``metrics_dict`` maps
    order -> list of ``(data, (start, stop, skip))`` and ``save_prefix`` is the stem
    of the last (in sorted order) base-model file, or None if none was found.
    """
    metrics_dict = {order: [] for order in range(num_orders + 1)}
    save_prefix = None
    # sort so the plot and the chosen save_prefix do not depend on the arbitrary
    # order in which the filesystem lists files
    for file in sorted(glob.glob(f"{metrics_dir}/{enc}/*{prefix}*")):
        stem = Path(file).stem  # '_'-separated fields; numeric values follow a '-'
        fields = stem.split('_')
        skip = int(fields[-1].split('-')[1])
        data = _load_metrics_file(file)

        if fields[0] == prefix:  # base model: the bare prefix means order 0
            save_prefix = stem
            order, start = 0, 0
            stop = int(fields[-2].split('-')[1])
        else:  # correction model: the order is encoded in the first field
            stop = int(fields[1].split('-')[-1])
            start = stop - int(fields[-2].split('-')[1])
            order = int(fields[0].split('-')[1])
        # setdefault keeps orders beyond the configured count instead of raising KeyError
        metrics_dict.setdefault(order, []).append((data, (start, stop, skip)))
    return metrics_dict, save_prefix


def visualize_error_correction(enc, metrics_dir, save_dir, config):
    """Plot validation metrics across all orders and checkpoints.

    NOTE: this function is intended for post-training analysis
    -> it does not take the metrics data itself as arguments; it loads every file
    matching ``{metrics_dir}/{enc}/*{config.model.prefix}*`` instead.

    One figure per metric is written to ``{save_dir}/{save_prefix}_{metric}.{format}``
    (spaces in the metric name replaced by '-'), where ``save_prefix`` is the stem of
    the base-model (order 0) metrics file. Nothing is plotted if no base-model file
    is found.

    Args:
        enc: subdirectory of ``metrics_dir`` holding the metrics files.
        metrics_dir: root directory of the saved metrics.
        save_dir: directory the figures are written to (created if missing).
        config: provides ``figures`` (plot options), ``model`` (reads ``prefix``) and
            the optional top-level key ``orders``.
    """
    figures = config.figures
    text_fs = figures.get('text_fs', 30)
    num_fs = figures.get('num_fs', 30)
    legend_fs = figures.get('legend_fs', 30)
    log_scale = figures.get("log_scale", 0)  # >0: log y-axis, >1: also log x-axis
    fmt = figures.get('format', 'png')
    # figsize lives with the other figure options; the top-level key is still
    # honoured for backward compatibility
    figsize = figures.get("figsize") or config.get("figsize") or (10, 10)
    mse_skip = figures.get("mse_skip", 100)
    colormap = plt.get_cmap("tab10").colors
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    metrics_dict, save_prefix = _collect_error_correction_metrics(
        enc, metrics_dir, config.model.get("prefix"), config.get('orders', 0)
    )

    # failed to load any data (no base model metrics)
    if not metrics_dict.get(0):
        print(f"[VISUALIZATION] Failed to load metrics data for {enc}")
        return

    for metric, eq in {
        #### Below are the nice LaTeX version of the equations
        # "RMSE": r"$\sqrt{\langle (\Phi - \Phi_N)^2 \rangle}$",
        # "MAE": r"$\langle |\Phi - \Phi_N| \rangle$",
        # "Max-AE": r"$max\{|\Phi - \Phi_N|\}$",
        # "MSE": r"$\langle |\Phi - \Phi_N|^2 \rangle$",
        # "Relative Error": r"$\langle ||\Phi - \Phi_N||_2^2 \rangle / \langle ||\Phi||_2^2\rangle$"
        #### Below are the equations consistent with the manuscript notation
        "RMSE": r"$\sqrt{|\delta z|^2}$",
        "MAE": r"$|\delta z|$",
        "Max-AE": r"$\max|\delta z|$",
        "MSE": r"$L$",
        "Relative Error": r"$||\delta z||_2^2 / ||\Phi||_2^2$"
    }.items():

        fig = plt.figure(figsize=figsize, facecolor='white')
        try:
            plt.xlabel(figures.get("xlabel", "Epochs"), fontsize=text_fs)
            plt.ylabel(eq, fontsize=text_fs)
            plt.xticks(fontsize=num_fs)
            plt.yticks(fontsize=num_fs)

            xmax = 0  # right x-limit: the largest step plotted by any run
            # TODO refine legend labels
            for order in sorted(metrics_dict):
                color = colormap[order % len(colormap)]  # tab10 has 10 colours; wrap
                for i, (metrics_data, (start, stop, skip)) in enumerate(metrics_dict[order]):
                    if metric == 'MSE':
                        # MSE is recorded every iteration rather than every `skip`
                        # steps: subsample it, and its x positions cover [start, stop)
                        skip = mse_skip
                        linedata = metrics_data[metric][::skip]
                        stop -= 1
                    else:
                        linedata = metrics_data[metric]
                        plt.scatter(np.arange(start, stop + 1, skip), linedata, color=color)

                    plt.plot(
                        np.arange(start, stop + 1, skip),
                        linedata,
                        color=color,
                        label=f"Order-{order}" if i == 0 else '_nolegend_'
                    )
                    if order > 0:  # mark where each correction order starts
                        plt.axvline(start, linestyle=':', color=color, alpha=0.5)
                    xmax = max(xmax, stop)

            plt.xlim(0, xmax)
            plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0), useMathText=True)
            plt.gca().yaxis.offsetText.set_fontsize(num_fs)
            plt.gca().xaxis.offsetText.set_fontsize(num_fs)
            plt.legend(loc=figures.get('legend_loc', 'best'), fontsize=legend_fs)

            if log_scale > 0:
                plt.yscale('log')
            if log_scale > 1:
                plt.xscale('log')

            fig.savefig(
                f"{save_dir}/{save_prefix}_{metric.replace(' ', '-')}.{fmt}",
                bbox_inches=figures.get('bbox', None),
                format=fmt,
                transparent=False,
                dpi=figures.get('dpi', 200)
            )
        finally:
            # release the figure even if plotting or saving raises
            plt.close(fig)


def visualize_metrics(metrics, save_prefix, config):
    """Plot metric values recorded during training, one figure per metric.

    For each of ``RMSE``, ``MAE`` and ``Relative Error``, ``metrics[metric]`` is
    plotted against its index and saved to
    ``f"{save_prefix}_{metric}.{format}"`` (spaces in the name replaced by '-').

    Args:
        metrics: mapping from metric name to a sequence of recorded values.
        save_prefix: output path prefix; metric name and extension are appended.
        config: provides ``figures`` (plot options: ``text_fs``, ``num_fs``,
            ``log_scale``, ``format``, ``figsize``, ``xlabel``, ``bbox``, ``dpi``).
    """
    figures = config.figures
    text_fs = figures.get('text_fs', 30)
    num_fs = figures.get('num_fs', 30)
    log_scale = figures.get("log_scale", 0)  # >0: log y-axis, >1: also log x-axis
    fmt = figures.get('format', 'png')
    # figsize lives with the other figure options; the top-level key is still
    # honoured for backward compatibility
    figsize = figures.get("figsize") or config.get("figsize") or (10, 10)

    for metric, eq in {
        "RMSE": r"$\sqrt{\langle (\phi - \phi_N)^2 \rangle}$",
        "MAE": r"$\langle |\phi - \phi_N| \rangle$",
        "Relative Error": r"$\langle ||\phi - \phi_N||_2^2 \rangle / \langle ||\phi||_2^2\rangle$"
    }.items():
        fig = plt.figure(figsize=figsize, facecolor='white')
        try:
            values = metrics[metric]
            plt.plot(np.arange(len(values)), values, label=metric)
            plt.title(eq, fontdict={'fontsize': text_fs})
            plt.xlabel(figures.get("xlabel", "Epochs"), fontsize=text_fs)
            plt.ylabel(metric, fontsize=text_fs)
            plt.xticks(fontsize=num_fs)
            plt.yticks(fontsize=num_fs)

            if log_scale > 0:
                plt.yscale('log')
            if log_scale > 1:
                plt.xscale('log')

            plt.savefig(
                f"{save_prefix}_{metric.replace(' ', '-')}.{fmt}",
                bbox_inches=figures.get('bbox', None),
                format=fmt,
                dpi=figures.get('dpi', 200),
                transparent=False
            )
        finally:
            # release the figure even if plotting or saving raises
            plt.close(fig)
        