#!/usr/bin/env python3.7

import argparse
from pathlib import Path
from typing import Callable, Dict

import numpy as np
import seaborn as sns

from utils import colors

import matplotlib
from matplotlib.figure import Figure
matplotlib.use("Agg")
import matplotlib.pyplot as plt


# Registry of plot builders, keyed by the name accepted by --plot_type.
# main() calls every entry as
#   builder(refs, news, ratios, metric=..., c=..., ref_folder=...,
#           new_folder=..., nbins=...)
# and saves the Figure it returns.
CALL_DICT: Dict[str, Callable] = dict()


def histogram(refs: np.ndarray, news: np.ndarray, ratios: np.ndarray,
              metric: str, c: int, ref_folder: str, new_folder: str,
              nbins: int = 100) -> Figure:
    """Plot a histogram of the per-sample ratios between two result folders.

    Args:
        refs: Not used here; accepted so every CALL_DICT builder shares one signature.
        news: Not used here; accepted so every CALL_DICT builder shares one signature.
        ratios: Values to histogram. Non-finite entries (NaN, +/-inf) are
            dropped, since they cannot be placed in a bin.
        metric: Metric name, shown in the x-axis label and the title.
        c: Number of classes, appended to the legend label.
        ref_folder: Reference folder path; only its last component is displayed.
        new_folder: New folder path; only its last component is displayed.
        nbins: Number of histogram bins.

    Returns:
        The newly created matplotlib Figure.
    """
    ref_name = Path(ref_folder).name
    new_name = Path(new_folder).name

    # Non-finite ratios would make the automatic bin range NaN/infinite and
    # crash the histogram, so only finite values are plotted.
    values = np.asarray(ratios)
    values = values[np.isfinite(values)]

    fig = plt.figure(figsize=(14, 9))
    ax = fig.gca()
    # The ratio is the binned quantity (x axis); the y axis holds bin counts.
    ax.set_xlabel(f"{metric} ratio between {ref_name} and {new_name}")
    ax.set_ylabel("Count")
    ax.grid(True, axis='y')
    ax.set_title(f"{metric} histograms")

    # Passing an integer gives exactly `nbins` equal-width bins spanning
    # [values.min(), values.max()]. Explicit np.linspace(min, max, nbins)
    # edges only yield nbins - 1 bins. NumPy also widens a zero-width range
    # by 0.5 on each side, so constant input no longer gives degenerate limits.
    _, edges, _ = ax.hist(values, bins=nbins, alpha=0.8,
                          label=f"{ref_name}/{new_name}-{c}",
                          color=colors[0])
    ax.set_xlim([edges[0], edges[-1]])
    ax.legend()

    fig.tight_layout()

    return fig


CALL_DICT['histogram'] = histogram


def scatter(refs: np.ndarray, news: np.ndarray, ratios: np.ndarray,
            metric: str, c: int, ref_folder: str, new_folder: str,
            nbins: int = 100) -> Figure:
    """Scatter refs (y) against news (x) and overlay polynomial regressions.

    The plot contains the y=x diagonal. For each polynomial degree in 1..3 it
    also contains:
      * the least-squares fit refs ~ poly(news), drawn over [news.min(), news.max()];
      * a shaded band from 0.9x to 1.1x of that fitted curve;
      * a legend entry with the fitted coefficients (highest degree first) and
        the std / mean of the absolute fit error on the data points.

    Args:
        refs: Values plotted on the y axis.
        news: Values plotted on the x axis.
        ratios: Not used here; accepted so every CALL_DICT builder shares one signature.
        metric: Not used here; accepted so every CALL_DICT builder shares one signature.
        c: Not used here; accepted so every CALL_DICT builder shares one signature.
        ref_folder: Reference folder path; only its last component is displayed.
        new_folder: New folder path; only its last component is displayed.
        nbins: Not used here; accepted so every CALL_DICT builder shares one signature.

    Returns:
        The newly created matplotlib Figure.
    """
    xs: np.ndarray = np.linspace(news.min(), news.max(), 10000)

    fig = plt.figure(figsize=(14, 9))
    ax = fig.gca()
    ax.set_xlim([0, news.max() * 1.1])
    ax.set_ylim([0, refs.max() * 1.1])
    ax.grid(True, axis='both')

    ax.scatter(news, refs, s=2, alpha=.7)
    max_max: float = max(refs.max(), news.max())
    ax.plot([0, max_max], [0, max_max],
            linewidth=.5, label="y=x", color=colors[0])

    # Regressions
    polys_degrees = [1, 2, 3]
    # A coefficient formatted with "{: 5.03e}" is typically 10 characters
    # wide (" 1.234e+00"). Joined by single spaces, this width right-aligns
    # every degree's coefficient list to the highest-degree one, so the
    # monospace legend stays column-aligned.
    z_width = (max(polys_degrees) + 1) * 11
    for i, p in enumerate(polys_degrees):
        z = np.polyfit(news, refs, p)
        str_z = " ".join(f"{v: 5.03e}" for v in z)
        padded_z = f"{str_z:>{z_width}}"

        poly = np.poly1d(z)
        err = np.abs(poly(news) - refs)
        # Evaluate the fit on the dense grid once; the curve and its band share it.
        curve = poly(xs)

        color = colors[i + 1]

        ax.plot(xs, curve,
                linewidth=.75,
                label=(f"y=polyfit_z(x), z={padded_z}, "
                       f"err_std={err.std():.2f}, "
                       f"err_mean={err.mean():.2f}"),
                color=color)
        lower_bound, upper_bound = 0.9 * curve, 1.1 * curve
        ax.plot(xs, lower_bound, xs, upper_bound,
                linewidth=.5,
                linestyle='--',
                color=color)
        ax.fill_between(xs, lower_bound, upper_bound, color=color, alpha=.3)

    ax.set_title(f"Scatter between {Path(ref_folder).name} and {Path(new_folder).name}")
    ax.set_xlabel("News")
    ax.set_ylabel("Preds")
    ax.legend(prop={'family': 'monospace', 'size': 9})

    fig.tight_layout()

    return fig


CALL_DICT['scatter'] = scatter


def kde(refs: np.ndarray, news: np.ndarray, ratios: np.ndarray,
        metric: str, c: int, ref_folder: str, new_folder: str,
        nbins: int = 100) -> Figure:
    """Draw a filled 2-D kernel density estimate of (news, refs) with the raw points on top.

    Both axes start at 0 and extend 10% past the largest news (x) / refs (y)
    value. ``ratios``, ``metric``, ``c`` and ``nbins`` are not used. They are
    accepted so that every CALL_DICT builder can be called the same way.

    Returns:
        The newly created matplotlib Figure.
    """
    figure = plt.figure(figsize=(14, 9))
    axes = figure.gca()
    axes.set_xlim([0, news.max() * 1.1])
    axes.set_ylim([0, refs.max() * 1.1])
    axes.grid(True, axis='both')

    # No ax= is passed, so seaborn draws on pyplot's current axes. That is
    # `axes`, because its figure was created just above.
    sns.kdeplot(news, refs,
                cmap="Blues", shade=True, shade_lowest=False, gridsize=100)
    axes.scatter(news, refs, s=1.5, alpha=.3)

    ref_name = Path(ref_folder).name
    new_name = Path(new_folder).name
    axes.set_title(f"Scatter between {ref_name} and {new_name}")
    axes.set_xlabel("News")
    axes.set_ylabel("Preds")

    figure.tight_layout()
    return figure


CALL_DICT['kde'] = kde


def get_args() -> argparse.Namespace:
    """Parse the command-line options and echo them to stdout.

    Every option except ``--nbins`` (default 100) and ``--debug`` is required.
    ``--plot_type`` must name a builder registered in CALL_DICT.

    Returns:
        The parsed argparse namespace.
    """
    parser = argparse.ArgumentParser(
        description='Plot the stats, with different options or types of plots.')

    # (flag, extra add_argument options) for the mandatory options, listed
    # in the order they appear in --help.
    required_options = (
        ('--plot_type', dict(type=str, choices=CALL_DICT.keys())),
        ('--new_folder', dict(type=str)),
        ('--ref_folder', dict(type=str)),
        ('--save_dest', dict(type=str, help="The file where to save the plot")),
        ('--metric', dict(type=str)),
        ('--num_classes', dict(type=int)),
        ('--source', dict(type=str)),
    )
    for flag, options in required_options:
        parser.add_argument(flag, required=True, **options)

    parser.add_argument("--nbins", type=int, default=100)
    parser.add_argument("--debug", action="store_true", help="Dummy for compatibility")

    parsed = parser.parse_args()
    print(parsed)
    return parsed


def main() -> None:
    """Load the arrays from --source, build the requested plot and save it.

    The archive at --source is indexed with the keys "refs", "news" and
    "ratios". The builder selected by --plot_type is looked up in CALL_DICT.
    Missing parent directories of --save_dest are created before saving.
    """
    args = get_args()

    # np.load on an .npz archive returns an NpzFile that keeps the underlying
    # file open. Indexing it reads each array into memory, so the file can be
    # closed as soon as the three arrays have been extracted.
    with np.load(args.source) as npzfile:
        refs = npzfile["refs"]
        news = npzfile["news"]
        ratios = npzfile["ratios"]

    plot_fn = CALL_DICT[args.plot_type]
    res_fig = plot_fn(refs, news, ratios,
                      metric=args.metric, c=args.num_classes, ref_folder=args.ref_folder,
                      new_folder=args.new_folder, nbins=args.nbins)

    dest_path = Path(args.save_dest)
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    res_fig.savefig(str(dest_path))
    # Release the figure held by pyplot's figure manager once it is saved.
    plt.close(res_fig)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
