import sys
import shutil
import argparse
import logging
import pickle
import os
import warnings
from typing import Any, Dict, Iterable, List, Tuple

import matplotlib.figure
import torch
import matplotlib
import matplotlib.pyplot as plt

import empirical_error
import utils.logging
import utils.path_config
import utils.plotting

# argparse.BooleanOptionalAction (used for --pgf below) requires Python >= 3.9; see also
# https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
# Compare the whole version tuple (checking only ``minor`` would reject e.g. 4.0), and
# raise instead of asserting so the check is not stripped under ``python -O``.
if sys.version_info < (3, 9):
    raise RuntimeError("relies on new argparse syntax")

# (figure, axes) pair returned by the plotting helpers. NOTE: the second element built
# in this file is the 2-D array of Axes from ``plt.subplots(..., squeeze=False)``,
# not a single Axes object.
FigAx = Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]

# Project directory layout; this file reads the "logs", "results", "plots" and
# "project" entries.
paths = utils.path_config.get_paths()

# Promote every warning, process-wide, into an exception.
warnings.filterwarnings("error")

# Custom numeric level between logging.DEBUG (10) and logging.INFO (20).
logging_level = 15

logger = logging.getLogger(__name__)
logger.setLevel(logging_level)

log_dir_str = paths["logs"]
standard_streamhandler = utils.logging.get_standard_streamhandler()
standard_filehandler = utils.logging.get_standard_filehandler(log_dir_str)

logger.addHandler(standard_streamhandler)
logger.addHandler(standard_filehandler)

# Edge length of a square subfigure; used as matplotlib ``figsize``, i.e. inches.
subfigure_sizes = {
    "large": 5,
    "medium": 1.9,
    "small": 1.5
}


def _plot_kernel(mults: List[float],
                 means1: torch.Tensor,
                 means2: torch.Tensor,
                 stds1: torch.Tensor) -> FigAx:
    """Draw one medium-sized panel of two curves against ``mults``.

    ``means1`` is drawn dashed with circle markers and surrounded by a shaded
    band ``means1 +/- 1.96 * stds1``; ``means2`` is drawn as a solid black
    line with no band.

    Args:
        mults: x-coordinates shared by both curves.
        means1: y-values of the banded curve (presumably 1-D with
            ``len(mults)`` entries -- TODO confirm against callers).
        means2: y-values of the unbanded curve, same length as ``means1``.
        stds1: per-point spread that sets the half-width of the band
            (divided by 1.96).

    Returns:
        ``(fig, axs)`` where ``axs`` is the 1x1 array of axes created with
        ``plt.subplots(..., squeeze=False)``.
    """
    band_style = {"color": "green", "alpha": .25}
    side = subfigure_sizes["medium"]
    fig, axs = plt.subplots(1, 1, figsize=(side, side), squeeze=False)
    ax = axs[0, 0]

    ax.plot(mults, means1, linestyle='--', marker='o')
    ax.plot(mults, means2, linestyle='-', color="k")
    # 1.96 is the two-sided 95% quantile of the standard normal distribution.
    ax.fill_between(mults,
                    means1 - 1.96 * stds1,
                    means1 + 1.96 * stds1,
                    **band_style)
    ax.grid(axis="y")

    # Scientific y tick labels with power limits (-1, 1); mathtext rendering of
    # the offset/exponent text on both axes. Uses the public ticklabel_format
    # API instead of writing the private ``formatter._useMathText`` attribute.
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1), useMathText=True)
    ax.ticklabel_format(axis='x', useMathText=True)

    # Ticks at 1, 3, 5, ... <= len(mults).
    # NOTE(review): this assumes mults are the integers 1..len(mults); confirm,
    # otherwise the ticks will not line up with the plotted points.
    ax.xaxis.set_ticks(torch.arange(1, len(mults) + 1, 2))

    fig.tight_layout()
    return fig, axs


def _stack_split(results: Iterable[Dict[str, Any]],
                 criterion_key: str,
                 split: str) -> torch.Tensor:
    """Stack ``result[criterion_key][split]`` of every result along a new dim 0."""
    return torch.stack([result[criterion_key][split] for result in results])


def _std_over_runs(values: torch.Tensor) -> torch.Tensor:
    """Return the standard deviation over dim 0, or NaNs when there is one run.

    With a single run the (Bessel-corrected) std is undefined: torch returns
    NaN and may also emit a degrees-of-freedom warning, which the module-level
    ``warnings.filterwarnings("error")`` would turn into an exception. Returning
    NaN directly keeps the numeric result while avoiding that crash; NaN band
    limits mean no band is drawn.
    """
    if values.shape[0] < 2:
        return torch.full_like(values[0], float("nan"))
    return values.std(0)


def plot_results(results_to_plot: List[Dict[str, Any]]) -> Dict[str, FigAx]:
    """Build the terminal-error-by-width figures for a collection of runs.

    Args:
        results_to_plot: one dict per run. ``"mults"`` is read from the first
            run only, so all runs are assumed to share it. Each run must hold,
            under ``"evaluate_criterion_values"`` and
            ``"optimize_criterion_values"``, ``"train"`` and ``"test"``
            tensors that ``torch.stack`` can combine across runs.

    Returns:
        ``{"terminal_error_by_width_evaluate": (fig, axs),
           "terminal_error_by_width_optimize": (fig, axs)}``. Each panel shows
        the across-run mean test value (with a band from the across-run test
        std) and the across-run mean train value.

    Raises:
        ValueError: if ``results_to_plot`` is empty.
    """
    if not results_to_plot:
        raise ValueError("plot_results needs at least one result to plot")
    mults = results_to_plot[0]["mults"]

    # Reduce every criterion before creating any figure, so a malformed result
    # fails without leaving some figures already open.
    curves = {}
    for criterion in ("evaluate", "optimize"):
        criterion_key = f"{criterion}_criterion_values"
        test_values = _stack_split(results_to_plot, criterion_key, "test")
        train_values = _stack_split(results_to_plot, criterion_key, "train")
        curves[criterion] = (test_values.mean(0),
                             train_values.mean(0),
                             _std_over_runs(test_values))

    return {f"terminal_error_by_width_{criterion}": _plot_kernel(mults, *curve)
            for criterion, curve in curves.items()}


def plot_experiment_results(results_filedir: str,
                            plots_filedir: str,
                            experiment_ident: str,
                            fig_format: str = "pgf"):
    """Load one pickled experiment, plot it, save every figure and close them.

    Reads ``<results_filedir>/<experiment_ident>.pkl``, passes its contents to
    ``plot_results`` and saves each resulting figure with
    ``utils.plotting.smart_save_fig`` under the name ``experiment_ident``,
    using ``os.path.join(plots_filedir, plot_ident)`` as its path argument.

    Args:
        results_filedir: directory holding ``<experiment_ident>.pkl``.
        plots_filedir: root directory for the saved plots.
        experiment_ident: file stem of the results pickle; also used as the
            saved figure name.
        fig_format: format/extension passed to ``smart_save_fig``
            (defaults to ``"pgf"``, the previously hard-coded value).
    """
    pickle_fullfilename = os.path.join(results_filedir, experiment_ident + ".pkl")
    # SECURITY: pickle.load can execute arbitrary code; only run this on
    # result files you produced yourself.
    with open(pickle_fullfilename, 'rb') as handle:
        results_to_plot = pickle.load(handle)

    plots = plot_results(results_to_plot)
    try:
        for plot_ident, (fig, _axs) in plots.items():
            filepath = os.path.join(plots_filedir, plot_ident)
            logger.info(f"saving {experiment_ident}.{fig_format} to {filepath} "
                        f"[{filepath}/{experiment_ident}.{fig_format}]")
            utils.plotting.smart_save_fig(fig, experiment_ident, fig_format, filepath)
    finally:
        # Close every figure, including unsaved ones if a save raised, so
        # pyplot does not accumulate open figures across experiments.
        for fig, _axs in plots.values():
            plt.close(fig)


def get_valid_experiment_idents(results_filedir: str) -> List[str]:
    """Return the stems of ``.pkl`` files in ``results_filedir`` that are valid identifiers.

    A stem is kept when ``empirical_error.invert_identifier`` accepts it
    without raising; otherwise it is logged as skipped. Files without a
    ``.pkl`` suffix are ignored, because ``plot_experiment_results`` can only
    load ``<ident>.pkl``.

    Returns:
        The accepted identifiers, in ``os.listdir`` order (arbitrary).
    """
    valid_experiment_idents = []
    for filename in os.listdir(results_filedir):
        if not filename.endswith(".pkl"):
            continue
        # removesuffix strips only the trailing extension (``replace`` would
        # also remove ".pkl" occurring inside the name).
        identifier = filename.removesuffix(".pkl")
        try:
            empirical_error.invert_identifier(identifier)
        except Exception:
            # Deliberate best-effort: a malformed name is skipped rather than
            # aborting the whole run. ``Exception`` (not a bare except) lets
            # KeyboardInterrupt/SystemExit propagate.
            logger.info("Skipping %s", identifier)
            continue
        valid_experiment_idents.append(identifier)
    return valid_experiment_idents


if __name__ == "__main__":
    def _str2bool(value: str) -> bool:
        """Parse a command-line boolean.

        ``type=bool`` is wrong for argparse because ``bool("False")`` is True;
        this accepts the usual spellings and rejects anything else.
        """
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=11)
    parser.add_argument("--cuda_wanted", type=_str2bool, default=True)
    parser.add_argument('--pgf', default=True, action=argparse.BooleanOptionalAction)

    parser.add_argument("--mode", type=str, default="", help="Swallow PyCharm args")
    parser.add_argument("--port", type=str, default="", help="Swallow PyCharm args")
    parser.add_argument("-f", type=str, default="", help="Swallow IPython arg")

    args = parser.parse_args()

    results_filedir = paths["results"]
    plots_filedir = paths["plots"]
    if args.pgf:
        utils.plotting.initialise_pgf_plots("pdflatex", "serif")
    # NOTE(review): figures are saved in plot_experiment_results' default
    # format (pgf) even when --no-pgf is given; confirm this is intended.

    for experiment_ident in sorted(get_valid_experiment_idents(results_filedir)):
        plot_experiment_results(results_filedir,
                                plots_filedir,
                                experiment_ident)
