from copy import deepcopy

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from pylab import rcParams
import seaborn as sns
import pandas as pd
from pathlib import Path
import json

# Directory layout for analysis artefacts.
LOGS_DIR = Path("logs")
ANALYSIS_ROOT_DIR = Path("/home/idscadmin/path_learning", "analysis")
ANALYSIS_LOGDIR = ANALYSIS_ROOT_DIR.joinpath(LOGS_DIR)
ANALYSIS_PLOTS = ANALYSIS_ROOT_DIR.joinpath("plots")


# Global plotting settings, applied once when this module is imported.
rcParams['figure.figsize'] = (9, 8)
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.6)
# Paper compatibility (core PS/PDF fonts, LaTeX text rendering); applied after
# the seaborn calls above, in this order.
for _rc_name, _rc_value in {
    'font.family': "serif",
    'ps.useafm': True,
    'pdf.use14corefonts': True,
    'text.usetex': True,
}.items():
    matplotlib.rcParams[_rc_name] = _rc_value
# Not referenced in the code visible here; presumably a seaborn ``legend=`` value.
legend_bool = "brief"


def keys_to_analyze(config):
    """Return the ``"methods"`` entry of a JSON analysis configuration file.

    :param config: path of the JSON file to read.
    :return: the value stored under the top-level ``"methods"`` key.
    :raises KeyError: if the file has no ``"methods"`` key.
    """
    with open(config, "r") as config_file:
        return json.load(config_file)["methods"]


# TODO: Add plot that visualizes path


def plot_model_step(df, title, analysis_config, plot_type, model_step: int = 1):
    """Plot every configured metric of one model task step against ``Task-step``.

    For each method listed under ``"methods"`` in the JSON file
    ``analysis_config`` and each split suffix (``-train`` then ``-test``), one
    figure is drawn with ``Task-step`` on the x-axis and the
    ``<method><suffix>`` column on the y-axis, and saved as
    ``ANALYSIS_PLOTS/<title without spaces>_<method><suffix>.pdf``.

    :param df: results frame with ``Model-task-step``, ``Task-step``, the metric
        columns, and the hue column (``Model-type`` for bar/scatter plots,
        ``Experiment`` for swarm plots).
    :param title: figure title; with spaces removed it is also the file-name prefix.
    :param analysis_config: path to the JSON analysis configuration.
    :param plot_type: ``"barplot"``, ``"scatterplot"`` or ``"swarmplot"``.
    :param model_step: only rows whose ``Model-task-step`` equals
        ``str(model_step)`` are plotted (the column is compared as a string here).
    :raises NotImplementedError: if ``plot_type`` is not a supported type.
    """
    # Validate first so an unsupported type fails before any data/config/file work.
    if plot_type not in ("barplot", "scatterplot", "swarmplot"):
        raise NotImplementedError("Pick an implemented and valid plot type")

    # Boolean-mask selection already returns a new frame; nothing below mutates it.
    df_plot = df.loc[df["Model-task-step"] == str(model_step)]
    # Read the config once rather than once per split.
    methods = keys_to_analyze(analysis_config)
    file_prefix = str(ANALYSIS_PLOTS / title.replace(" ", ""))

    for analysis_mode in ["-train", "-test"]:
        for method in methods:
            key = method + analysis_mode
            file_name = file_prefix + f"_{key}.pdf"
            print(f"FILE SAVING: {file_name}")
            plt.figure()
            if plot_type == "barplot":
                sns.barplot(x="Task-step", y=key, hue="Model-type", data=df_plot,
                            palette="muted")
            elif plot_type == "scatterplot":
                sns.scatterplot(x="Task-step", y=key, hue="Model-type", data=df_plot,
                                palette="muted")
            else:  # "swarmplot" (validated above)
                sns.swarmplot(x="Task-step", y=key, hue="Experiment", data=df_plot)
            plt.title(title)
            plt.ylabel(key.capitalize())
            plt.savefig(file_name, bbox_inches='tight', pad_inches=0)
            plt.close()


def performance_stability_plot(df_input, title, analysis_config, plot_type=None):
    """Scatter stability/sensitivity metrics against later test performance.

    Only ``Model-type == "Trained"`` rows are used. The plotted metrics come from
    the model at ``Model-task-step == 1`` evaluated on ``Task-step == 2``; the
    ``task-loss-test`` column of those rows is replaced by the value of the model
    at ``Model-task-step == 2`` on ``Task-step == 2``. Both selections are sorted
    descending by (Experiment, Model-type, Seed) and paired by position, so they
    must contain the same runs in the same order.

    One log-log scatter plot (hue ``Experiment``) per metric pair is saved to
    ``ANALYSIS_PLOTS/<title without spaces>_<x metric>_<y metric>.pdf``.

    :param df_input: results frame; it is not modified.
    :param title: figure title; with spaces removed it is also the file-name prefix.
    :param analysis_config: not used in this function.
    :param plot_type: not used in this function.
    :raises ValueError: if the two selections have different numbers of rows.
    """
    trained = df_input.loc[df_input["Model-type"] == "Trained"]
    sort_keys = ['Experiment', 'Model-type', 'Seed']

    # Performance of the model trained up to task step 2, evaluated on task step 2.
    df_performance = trained.loc[(trained["Model-task-step"] == 2) & (trained["Task-step"] == 2)]
    df_performance = df_performance.sort_values(by=sort_keys, ascending=False)

    print("-------------------")
    print(f"Example row1: {df_performance.iloc[0]}")

    # Metrics of the model trained up to task step 1, evaluated on task step 2;
    # its "task-loss-test" values are overridden below.
    df_plot = trained.loc[(trained["Model-task-step"] == 1) & (trained["Task-step"] == 2)]
    df_plot = df_plot.sort_values(by=sort_keys, ascending=False)

    print("-------------------")
    print(f"Example row2: {df_plot.iloc[0]}")

    # Positional pairing: a length mismatch means the selections cover different runs.
    if len(df_plot) != len(df_performance):
        raise ValueError(
            f"Cannot pair rows: {len(df_plot)} rows at Model-task-step 1 vs "
            f"{len(df_performance)} rows at Model-task-step 2 (Task-step 2, Trained)"
        )
    df_plot = df_plot.assign(**{"task-loss-test": df_performance["task-loss-test"].to_numpy()})

    pairs = [("absolute-sensitivity-stability-train", "task-loss-test"),
             ("stability-train", "task-loss-test"),
             ("task-loss-stability-train", "task-loss-test"),
             ("absolute-gradient-stability-train", "task-loss-test"),
             ("absolute-gradient-stability-train", "task-loss-stability-train"),
             ("task-loss-sensitivity-train", "task-loss-test"),
             ("absolute-sensitivity-train", "task-loss-test"),
             ("input-sensitivity-train", "task-loss-test"),
             ("output-sensitivity-train", "task-loss-test"),
             ("input-output-sensitivity-train", "task-loss-test"),
             ("task-loss-sensitivity-train", "task-loss-stability-train"),
             ("task-loss-sensitivity-train", "absolute-gradient-stability-train"),
             ("task-loss-sensitivity-train", "stability-train")]
    file_prefix = str(ANALYSIS_PLOTS / title.replace(" ", ""))
    for x_key, y_key in pairs:
        fig, ax = plt.subplots()
        ax.set(xscale="log", yscale="log")
        sns.scatterplot(x=x_key, y=y_key, hue="Experiment", data=df_plot,
                        palette="muted", ax=ax)
        plt.title(title)
        plt.xlabel(x_key.replace("-", " "))
        plt.ylabel(y_key.replace("-", " "))
        plt.savefig(file_prefix + f"_{x_key}_{y_key}.pdf",
                    bbox_inches='tight', pad_inches=0)
        plt.close(fig)


def step_number_effect(df_input, title, analysis_config, plot_type=None):
    """Swarm-plot each train/test metric against the evaluated task step.

    Uses trained models (``Model-type == "Trained"``) with ``Model-task-step == 1``,
    evaluated on every ``Task-step`` other than 1, sorted descending by
    (Experiment, Model-type, Seed). One PDF per metric is written to
    ``ANALYSIS_PLOTS/<title without spaces>_step_effect_<metric>.pdf``.
    ``analysis_config`` and ``plot_type`` are not used.
    """
    keep = (
        (df_input["Model-task-step"] == 1)
        & (df_input["Task-step"] != 1)
        & (df_input["Model-type"] == "Trained")
    )
    selected = df_input.loc[keep].sort_values(
        by=['Experiment', 'Model-type', 'Seed'], ascending=False)

    metrics = (
        "stability-train",
        "input-output-sensitivity-train",
        "task-loss-sensitivity-train",
        "task-loss-stability-train",
        "absolute-sensitivity-stability-train",
        "stability-test",
        "input-output-sensitivity-test",
        "task-loss-sensitivity-test",
        "task-loss-stability-test",
        "absolute-sensitivity-stability-test",
    )
    file_prefix = str(ANALYSIS_PLOTS / title.replace(" ", ""))

    for metric in metrics:
        fig, ax = plt.subplots()
        # A log y-axis (ax.set(yscale="log")) is currently disabled.
        sns.swarmplot(x="Task-step", y=metric, hue="Model-type",
                      data=selected, palette="muted", ax=ax)
        plt.title(title)
        plt.xlabel("Task step")
        plt.ylabel(metric.replace("-", " "))
        plt.savefig(file_prefix + f"_step_effect_{metric}.pdf",
                    bbox_inches='tight', pad_inches=0)
        plt.close()


def plot_toy_dataset(dir, dataloader):
    """Scatter-plot a 2-D synthetic dataset and save it as ``synthetic_dataset.pdf``.

    The figure-specific style (8x8 figure, seaborn "talk" context, paper fonts)
    is applied inside ``matplotlib.rc_context()``. That context restores the
    previous rcParams on exit, so these settings no longer overwrite the
    module-wide plotting defaults used by the other plot functions.

    :param dir: output directory, ``str`` or ``Path`` (name kept for keyword callers).
    :param dataloader: iterable of ``(data, target)`` batches; converted to a
        DataFrame by ``gather_toy_results``.
    """
    out_dir = Path(dir)
    print(f"Saving synthetic dataset in directory: {str(dir)}")
    with matplotlib.rc_context():
        # Plotting settings for this figure only.
        rcParams['figure.figsize'] = 8, 8
        sns.set_style("whitegrid")
        sns.set_context("talk", font_scale=1.8)
        # Important for paper compatibility
        matplotlib.rcParams['font.family'] = "serif"
        matplotlib.rcParams['ps.useafm'] = True
        matplotlib.rcParams['pdf.use14corefonts'] = True
        matplotlib.rcParams['text.usetex'] = True

        plt.figure()
        df = gather_toy_results(dataloader)
        sns.scatterplot(x="x", y="y", hue="Class", palette="muted", data=df)
        plt.title("Synthetic dataset")
        plt.xlabel("x-axis")
        plt.ylabel("y-axis")
        plt.savefig(out_dir / "synthetic_dataset.pdf",
                    bbox_inches='tight', pad_inches=0)
        plt.close()


def gather_toy_results(dataloader):
    """Flatten a 2-D toy dataloader into a DataFrame with columns ``x``, ``y``, ``Class``.

    Iterates over a deep copy of ``dataloader``. Each batch ``(data, target)`` is
    converted with ``.cpu().numpy()``. For every entry ``i`` of ``target``, the
    row ``(data[i, 0], data[i, 1], str(int(target[i])))`` is emitted.

    Rows are collected in lists and the frame is built once at the end. Growing a
    DataFrame one ``df.loc[count] = ...`` at a time copies it on every new row.

    :param dataloader: iterable of ``(data, target)`` pairs exposing
        ``.cpu().numpy()`` (presumably torch tensors — confirm at call sites).
    :return: DataFrame with a 0-based RangeIndex and one row per target entry.
    """
    xs, ys, classes = [], [], []
    for data, target in deepcopy(dataloader):
        data = data.cpu().numpy()
        target = target.cpu().numpy()
        n_rows = target.shape[0]  # one row per target entry, as before
        xs.extend(data[:n_rows, 0].tolist())
        ys.extend(data[:n_rows, 1].tolist())
        classes.extend(str(int(label)) for label in target)
    return pd.DataFrame({"x": xs, "y": ys, "Class": classes}, columns=["x", "y", "Class"])


def discrete_vs_continuous(df, title, analysis_config, plot_type=None):
    """Bar-plot MNIST results of the "Blur N=2" and "Blur N=4" experiments.

    For each experiment, rows are selected whose ``Model-task-step`` and
    ``Task-step`` both equal the experiment name's last character (a string),
    whose ``Dataset`` is ``"MNIST"`` and whose ``Experiment`` matches. Bars are
    drawn onto one shared figure only when ``plot_type == "barplot"``. The
    figure is saved as
    ``ANALYSIS_PLOTS/<title without spaces>_discrete_vs_continuous.pdf``
    for any ``plot_type``. ``analysis_config`` is not used.

    NOTE(review): other functions in this module read lowercase ``task-loss-*``
    columns; confirm that a ``Task-loss`` column exists.
    """
    # TODO: loop over unique entries of experiments in df
    # TODO: loop over unique entries of dataset
    draw_bars = plot_type == "barplot"
    plt.figure()
    for experiment_name in ("Blur N=2", "Blur N=4"):
        step = experiment_name[-1]
        selected = df.loc[
            (df["Model-task-step"] == step)
            & (df["Task-step"] == step)
            & (df["Dataset"] == "MNIST")
            & (df["Experiment"] == experiment_name)
        ]
        if draw_bars:
            sns.barplot(x="Task-step", y="Task-loss", hue="Task-step",
                        data=selected, palette="muted")
    plt.title(title)
    plt.ylabel("Test loss")
    out_path = str(ANALYSIS_PLOTS / title.replace(" ", "")) + "_discrete_vs_continuous.pdf"
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
    plt.close()


def vector_plot(df, title, analysis_config, plot_type=None):
    """Draw quiver plots of 2-D projected weights vs. gradients for two toy experiments.

    Loops over two hard-coded experiment names and the task steps "1" and "2".
    For each combination it takes the first row that matches:

    - ``Experiment``;
    - ``Task-step`` and ``Model-task-step`` (both compared as strings);
    - ``Seed == 101``.

    That row's ``vector-plot-weights`` and ``vector-plot-grads`` entries are
    projected to 2 components by a TruncatedSVD fitted on the gradients, then
    drawn as an arrow field. Each figure is saved to
    ``ANALYSIS_PLOTS/<experiment>_quiver_<task step>.pdf``.

    ``title``, ``analysis_config`` and ``plot_type`` are not used.

    Raises if no row matches (``iloc[0]``), or if the projected arrays do not
    hold exactly 25 * 25 rows (``reshape``).

    NOTE(review): output is not reproducible from run to run. The gradient
    jitter uses the unseeded global ``np.random``, and ``TruncatedSVD`` is built
    without a ``random_state`` — presumably making the randomized solver
    non-deterministic too; confirm.
    """
    # PCA and TSNE are only referenced in commented-out alternatives below.
    from sklearn.decomposition import PCA, TruncatedSVD
    from sklearn.manifold import TSNE
    # TODO: Plot weights vs. weight-derivatives
    # TODO: project to lower dimensional space
    # TODO: first iteration - take only first two weight dimensions and plot against each other
    for experiment_name in ["Toy-equal-batch-resnet-experiment-reverse", "Toy-equal-batchrelu-experiment-reverse"]:
        for task_step in ["1", "2"]:
            df_plot = deepcopy(df)
            print(f"Experiment: {experiment_name}, task step: {task_step}")
            df_plot = df_plot.loc[df_plot["Experiment"] == experiment_name]
            df_plot = df_plot.loc[df_plot["Task-step"] == task_step]
            df_plot = df_plot.loc[df_plot["Model-task-step"] == task_step]
            df_plot = df_plot.loc[df_plot["Seed"] == 101]
            print(f"df_plot: {df_plot['vector-plot-weights'].shape}")
            # Only the first matching row is used.
            weights = np.array(df_plot["vector-plot-weights"].iloc[0])
            grads = np.array(df_plot["vector-plot-grads"].iloc[0])
            # Small unseeded Gaussian jitter (scale 1e-5) on the 2-D gradient array.
            # Presumably added to avoid degenerate columns in the SVD — confirm.
            grads += 0.00001 * np.random.randn(grads.shape[0], grads.shape[1], )
            print(f"weights: {weights.shape}, grads: {grads.shape}")

            # NOTE(review): TruncatedSVD below hard-codes 2 instead of using n_components.
            n_components = 2
            # The weights are evenly arranged, but the changes in gradients are important
            # Need to reduce dimensionality to plot
            # embedding = TSNE(n_components=2).fit(grads)
            embedding = TruncatedSVD(n_components=2, algorithm="randomized").fit(grads)
            # embedding = PCA(n_components=n_components, svd_solver='auto',
            #           whiten=False, tol=0.001).fit(grads)

            # Both arrays are projected with the gradient-fitted SVD and laid out
            # on a 25 x 25 grid; this assumes exactly 625 sample points.
            weights = embedding.transform(weights).reshape((25, 25, n_components))
            grads = embedding.transform(grads).reshape((25, 25, n_components))
            # Divides each of the 25 grid rows by the norm of that whole row (all
            # 25 arrows jointly), not each arrow individually.
            # NOTE(review): confirm that row-wise rather than per-arrow normalisation is intended.
            for i in range(grads.shape[0]):
                grads[i] = grads[i] / np.linalg.norm(grads[i])
            # Reference arrows for the two tasks, given as 4-vectors that are NOT
            # passed through `embedding`; only components 0 and 1 are plotted.
            # NOTE(review): confirm these should live in the projected space.
            weight_task1 = np.array([[[1, 1, -1, -1]]])
            grads_task1 = weight_task1
            weight_task2 = np.array([[[-1, -1, 1, 1]]])
            grads_task2 = weight_task2

            # `q` is only needed by the commented-out quiverkey call.
            fig, ax = plt.subplots()
            q = ax.quiver(weights[:, :, 0], weights[:, :, 1], grads[:, :, 0], grads[:, :, 1])
            q = ax.quiver(weight_task1[:, :, 0], weight_task1[:, :, 1], grads_task1[:, :, 0], grads_task1[:, :, 1],
                          color='r')
            q = ax.quiver(weight_task2[:, :, 0], weight_task2[:, :, 1], grads_task2[:, :, 0], grads_task2[:, :, 1],
                          color='g')
            # ax.quiverkey(q, X=0.3, Y=1.1, U=10,
            #              label='Quiver key, length = 10', labelpos='E')
            plt.savefig(ANALYSIS_PLOTS / str(experiment_name + "_quiver_" + task_step + ".pdf"),
                        bbox_inches='tight', pad_inches=0)
            plt.close()
