from copy import deepcopy
from typing import Dict
import pathlib

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from pylab import rcParams
import seaborn as sns
import pandas as pd


# Module-wide plotting defaults, applied once at import time.
# NOTE: the plotting functions in this file overwrite several of these
# settings (figure size, font scale) when they are called.
rcParams['figure.figsize'] = (12, 4)
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.8)
# Important for paper compatibility: serif text rendered through LaTeX and
# only the core fonts in PS/PDF output. These are applied after the seaborn
# calls above, in the same order as before.
matplotlib.rcParams.update({
    'font.family': "serif",
    'ps.useafm': True,
    'pdf.use14corefonts': True,
    'text.usetex': True,
})
# Legend mode string; not referenced in the visible part of this file.
legend_bool = "brief"


def plot_single_across_steps(df_input: pd.DataFrame, logdir: pathlib.Path, method: str):
    """
    Save one line plot per (result type, task, y-scale) combination.

    For every distinct value of the ``"result name"`` column and every task
    (``"task uid"``) that has rows for it, the ``"result"`` values are plotted
    against ``"step"`` twice: once with a linear and once with a logarithmic
    y-axis. Each figure is written to
    ``logdir / "task-<task>-<result>_<mode>_single_plot.pdf"``.

    Note that this function overwrites global matplotlib/seaborn settings
    (figure size, style, context, fonts, LaTeX rendering) and does not
    restore them afterwards.

    :param df_input: frame with the columns ``"result name"``, ``"task uid"``,
        ``"result"`` and ``"step"``; it is only read, never modified
    :param logdir: output directory; created (including parents) if missing
    :param method: currently unused, kept for interface compatibility
    :return: None
    """
    # General settings of figures
    rcParams['figure.figsize'] = 8, 6
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=1.6)
    # Important for paper compatibility
    matplotlib.rcParams['font.family'] = "serif"
    matplotlib.rcParams['ps.useafm'] = True
    matplotlib.rcParams['pdf.use14corefonts'] = True
    matplotlib.rcParams['text.usetex'] = True

    # Accept plain strings as well and make sure savefig has somewhere to write.
    logdir = pathlib.Path(logdir)
    logdir.mkdir(parents=True, exist_ok=True)

    # The frame is only read below, so no defensive copy is needed.
    df = df_input
    print(f"df: {df.head()}")

    for result_type in df["result name"].unique():
        df_select = df.loc[df["result name"] == result_type]
        # Underscores are replaced before the name is used as a label,
        # presumably because they are special characters under text.usetex.
        result_type_print = str(result_type).replace("_", "-")
        file_stem = result_type_print.replace(" ", "")

        # Only iterate over tasks that actually have rows for this result
        # type; this avoids producing empty plots.
        for task_id in df_select["task uid"].unique():
            df_task = df_select.loc[df_select["task uid"] == task_id]
            # Steps become the index, results the single (labelled) column.
            # The data does not depend on the y-scale, so build it once.
            data = pd.DataFrame(
                data=df_task["result"].values,
                index=df_task["step"].values,
                columns=[f"Task {str(task_id)}-{result_type_print}"],
            )

            for mode in ["regular", "log"]:
                fig, ax = plt.subplots()
                ax.set_yscale('log' if mode == "log" else 'linear')
                sns.lineplot(data=data, palette="tab10", linewidth=2.5, ax=ax)

                out_path = logdir / f'task-{str(task_id)}-{file_stem}_{mode}_single_plot.pdf'
                print(f'Saving plot in {out_path}')
                try:
                    fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
                finally:
                    # Close exactly the figure created here. A bare
                    # plt.close() would also close whatever other figure
                    # happens to be current (e.g. one owned by the caller).
                    plt.close(fig)


def plot_plasticity(df_input: pd.DataFrame, logdir: pathlib.Path, method: str):
    """
    Plot plasticity of multiple experiments in different colors.

    This is a stub: it always raises ``NotImplementedError``.

    :param df_input: results frame (unused until implemented)
    :param logdir: output directory (unused until implemented)
    :param method: method name (unused until implemented)
    :return: never returns
    :raises NotImplementedError: always
    """
    # TODO: select plasticity rows
    # TODO: plot area under the curve
    # TODO: add alpha
    # The exception was previously constructed but never raised, so calling
    # this stub silently did nothing.
    raise NotImplementedError("not implemented yet.")


def _stack_task_curves(df_select: pd.DataFrame, task_ids) -> tuple:
    """
    Concatenate the per-task curves of one result type on a shared step axis.

    The steps of each task are shifted by the end position of the previous
    task, so the tasks appear one after another.

    :param df_select: rows of a single result type with the columns
        ``"task uid"``, ``"step"`` and ``"result"``
    :param task_ids: task ids in the order in which they should be stacked
    :return: ``(x_values, y_values, x_switch_values)`` where the first two are
        1-D arrays of equal length and the last is the list of x positions at
        which each task ends
    """
    x_parts, y_parts, x_switch_values = [], [], []
    for task_id in task_ids:
        df_task = df_select.loc[df_select["task uid"] == task_id]
        if df_task.empty:
            # Nothing recorded for this task; indexing the last step would fail.
            continue
        steps = df_task["step"].values
        if x_switch_values:
            steps = steps + x_switch_values[-1]
        x_parts.append(steps)
        y_parts.append(df_task["result"].values)
        # Use the largest step as the task boundary so the offset is also
        # correct when the rows are not sorted by step.
        x_switch_values.append(steps.max())

    if not x_parts:
        return np.array([]), np.array([]), []
    return np.concatenate(x_parts, axis=0), np.concatenate(y_parts, axis=0), x_switch_values


def plot_task_1_2_joint(df_input: pd.DataFrame, logdir: pathlib.Path, method: str):
    """
    Save, per result type, one plot with all tasks stacked along the x-axis.

    For every distinct value of the ``"result name"`` column the curves of all
    tasks (sorted by ``"task uid"``) are placed one after another on a shared
    step axis, and a dashed vertical line marks the end of each task. Each
    plot is written twice, with a linear and a logarithmic y-axis, to
    ``logdir / "<result>_<mode>_joint_plot.pdf"``.

    Note that this function overwrites global matplotlib/seaborn settings
    (figure size, style, context, fonts, LaTeX rendering) and does not
    restore them afterwards.

    :param df_input: frame with the columns ``"result name"``, ``"task uid"``,
        ``"result"`` and ``"step"``; it is only read, never modified
    :param logdir: output directory; created (including parents) if missing
    :param method: currently unused, kept for interface compatibility
    :return: None
    """
    # General settings of figures
    rcParams['figure.figsize'] = 8, 6
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=1.6)
    # Important for paper compatibility
    matplotlib.rcParams['font.family'] = "serif"
    matplotlib.rcParams['ps.useafm'] = True
    matplotlib.rcParams['pdf.use14corefonts'] = True
    matplotlib.rcParams['text.usetex'] = True

    # Accept plain strings as well and make sure savefig has somewhere to write.
    logdir = pathlib.Path(logdir)
    logdir.mkdir(parents=True, exist_ok=True)

    # The frame is only read below, so no defensive copy is needed.
    df = df_input
    print(f"df: {df.head()}")

    for result_type in df["result name"].unique():
        df_select = df.loc[df["result name"] == result_type]
        # The stacked curve does not depend on the y-scale, so build it once.
        x_values, y_values, x_switch_values = _stack_task_curves(
            df_select, sorted(df_select["task uid"].unique())
        )
        if len(x_values) == 0:
            continue

        # Underscores are replaced before the name is used as a label,
        # presumably because they are special characters under text.usetex.
        result_type_print = str(result_type).replace("_", "-")
        file_stem = result_type_print.replace(" ", "")
        data = pd.DataFrame(data=y_values, index=x_values, columns=[f"{result_type_print}"])

        for mode in ["regular", "log"]:
            # One explicit figure per mode. Previously only the first mode got
            # an explicitly created figure; the second relied on an implicit one.
            fig, ax = plt.subplots()
            ax.set_yscale('log' if mode == "log" else 'linear')
            sns.lineplot(data=data, palette="tab10", linewidth=2.5, ax=ax)
            for x_switch in x_switch_values:
                ax.axvline(x=x_switch, color='k', linestyle='--')

            out_path = logdir / f'{file_stem}_{mode}_joint_plot.pdf'
            print(f'Saving plot in {out_path}')
            try:
                fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
            finally:
                # Close exactly the figure created here. A bare plt.close()
                # would also close whatever other figure happens to be current.
                plt.close(fig)
