import os
from typing import List

import matplotlib
import pandas as pd
import pylab
import scipy
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.ticker import LogLocator, AutoMinorLocator, ScalarFormatter

from scripts.evaluation.plotting.configs import names
from scripts.evaluation.plotting.configs import styles
from scripts.evaluation.plotting.configs.names import env_names


def _version_key(version: str) -> tuple:
    """Sort key for dotted version strings that compares numeric parts as numbers.

    With this key ``"3.10"`` ranks above ``"3.2"`` (a plain string sort gets that
    wrong). Non-numeric parts rank above numeric ones and compare as strings, so
    mixed versions never raise ``TypeError``.
    """
    return tuple(
        (0, int(part), "") if part.isdigit() else (1, 0, part)
        for part in version.split(".")
    )


def get_dataset(neurips: bool = True, path: str = "output/evaluation_dataframes/") -> pd.DataFrame:
    """Load the evaluation dataframe with the highest version number from *path*.

    Args:
        neurips: If True, look for ``neurips_data_v<version>.csv`` files,
            otherwise for ``data_v<version>.csv`` files.
        path: Directory that holds the evaluation CSV files.

    Returns:
        The selected CSV, read with its first column as the index.

    Raises:
        FileNotFoundError: If *path* contains no file matching the requested
            ``<prefix><version>.csv`` pattern.
    """
    prefix = "neurips_data_v" if neurips else "data_v"
    suffix = ".csv"
    # Only files that match the exact naming scheme take part in the version
    # search, so the file we load is guaranteed to exist.
    versions = [
        name[len(prefix):-len(suffix)]
        for name in os.listdir(path)
        if name.startswith(prefix) and name.endswith(suffix)
    ]
    if not versions:
        raise FileNotFoundError(f"No '{prefix}<version>{suffix}' file found in {path}")
    # Numeric comparison: v3.10 is newer than v3.2.
    latest = max(versions, key=_version_key)
    dataframe_name = f"{prefix}{latest}{suffix}"
    print(f"Loading Dataset {dataframe_name}")
    return pd.read_csv(os.path.join(path, dataframe_name), index_col=0)


class MagnitudeFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter with an optionally pinned order of magnitude and one decimal.

    Args:
        exponent: Order of magnitude to factor out of the tick labels into the
            axis offset text (e.g. ``-4``). ``None`` keeps matplotlib's automatic
            choice. ``0`` is honored as an explicit exponent.

    NOTE(review): this overrides matplotlib's private hooks
    ``_set_order_of_magnitude`` / ``_set_format``; re-check after matplotlib
    upgrades.
    """

    def __init__(self, exponent=None):
        super().__init__()
        self._fixed_exponent = exponent

    def _set_order_of_magnitude(self):
        # `is not None` (not truthiness) so an explicit exponent of 0 is used
        # instead of silently falling back to the automatic exponent.
        if self._fixed_exponent is not None:
            self.orderOfMagnitude = self._fixed_exponent
        else:
            super()._set_order_of_magnitude()

    def _set_format(self):
        # Always one decimal place, replacing matplotlib's automatic precision.
        self.format = "%1.1f"


def _context_tick_label(context) -> str:
    """Return the x tick label for one context value, e.g. ``"mesh_context_010"`` -> ``"10"``.

    All digits in the value are concatenated, and leading zeros are dropped via an
    ``int`` round-trip. A value without any digit is returned unchanged (as ``str``)
    instead of failing on ``int("")``.
    """
    digits = "".join(ch for ch in str(context) if ch.isdigit())
    return str(int(digits)) if digits else str(context)


def _style_spines(ax, linewidth: float, color) -> None:
    """Give all four spines of *ax* the same line width and color."""
    for side in ("bottom", "left", "top", "right"):
        ax.spines[side].set_linewidth(linewidth)
        ax.spines[side].set_color(color)


def _add_translated_legend(ax, legend_order, method_names, label_fontsize, tick_fontsize) -> None:
    """Draw the legend on *ax*, optionally reordered, with method keys shown as display names."""
    if legend_order is not None:
        handles, labels = ax.get_legend_handles_labels()
        legend = ax.legend([handles[idx] for idx in legend_order],
                           [labels[idx] for idx in legend_order])
    else:
        legend = ax.legend()
    for text in legend.get_texts():
        original_label = text.get_text()
        # Keys missing from the translation table keep their raw label.
        text.set_text(method_names.get(original_label, original_label))
        text.set_fontsize(tick_fontsize)
    legend.set_title('Methods')
    legend.get_title().set_fontsize(label_fontsize)


def _save_figure(out_path: str, overwrite: bool) -> None:
    """Save the current figure to *out_path*, skipping an existing file unless *overwrite*."""
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    if os.path.isfile(out_path) and not overwrite:
        print(f"File {out_path} already exists. Skipping...")
        return
    # Tight bounding box set explicitly at save time.
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
    print(f"Saved figure to {out_path}")


def plot(data: pd.DataFrame, env: str, methods: List[str],
         figure_width: float = 10,
         figure_height: float = 5,
         contexts: List[str] | None = None,
         metric: str = "full_rollout_mean_mse",
         reduced: bool = False,
         mode: str = "show",
         unique_name: str | None = None,
         formatter_magnitude: int = -4,
         use_major_formatter: bool = False,
         use_minor_formatter: bool = False,
         y_top_limit: float | None = None,
         y_bottom_limit: float | None = None,
         legend_order: List[int] | None = None,
         label_fontsize: int = 19,
         tick_fontsize: int = 16,
         linewidth: float = 3,
         boarder_linewidth: float = 3,
         border_color: str = 'darkgray',
         show_legend: bool = True,
         overwrite: bool = False):
    """Plot *metric* against the context (anchor time step) for one environment.

    Each method is drawn as one seaborn line (hue and marker style by method) with
    confidence-interval error bands, on a logarithmic y axis. The figure is NOT
    closed here; callers are expected to call ``plt.close()``.

    Args:
        data: Evaluation dataframe with at least ``env``, ``method``, ``context``
            and *metric* columns.
        env: Environment to plot; also selects the title and output sub-directory.
        methods: Method keys to include.
        figure_width, figure_height: Figure size in inches (set via rcParams).
        contexts: Optional subset of context values to keep.
        metric: Column plotted on the y axis.
        reduced: Use the reduced color/marker/name tables from the config modules.
        mode: ``"show"`` to display the figure, ``"save"`` to write a PDF.
        unique_name: File name (without extension) for ``mode="save"``; defaults
            to ``<metric>___<sorted methods joined by "__">``.
        formatter_magnitude: Fixed exponent for :class:`MagnitudeFormatter`.
        use_major_formatter, use_minor_formatter: Apply that formatter to the
            major / minor y tick labels.
        y_top_limit, y_bottom_limit: Optional y axis limits.
        legend_order: Optional permutation of legend entries.
        label_fontsize, tick_fontsize: Font sizes for labels/title and tick labels.
        linewidth: Line width of the plotted lines.
        boarder_linewidth, border_color: Spine width and spine/tick-mark color.
        show_legend: Draw the (translated) legend, or remove seaborn's legend.
        overwrite: In save mode, replace an existing PDF.

    Raises:
        ValueError: If *mode* is neither ``"save"`` nor ``"show"``. This is checked
            before any plotting work.
    """
    if mode not in ("save", "show"):
        raise ValueError(f"Unknown mode {mode}")

    # Keep only the requested environment, methods and (optionally) contexts.
    data = data[data["env"] == env]
    data = data[data["method"].isin(methods)]
    if contexts is not None:
        data = data[data["context"].isin(contexts)]
    # Sort by context so seaborn's categorical x positions follow the same order
    # as data["context"].unique(), which the tick labels below rely on.
    data = data.sort_values(by="context")

    # `sns.set` is a deprecated alias of `set_theme`; one call sets the whitegrid
    # theme and the figure size.
    sns.set_theme(style="whitegrid", rc={'figure.figsize': (figure_width, figure_height)})
    # Force left y tick marks on (the seaborn style may disable them).
    plt.rcParams.update({"ytick.left": True})
    sns.lineplot(data=data, x="context", y=metric,
                 hue="method",
                 style="method",
                 palette=styles.method_colors_reduced if reduced else styles.method_colors,
                 markers=styles.method_markers_reduced if reduced else styles.method_markers,
                 markersize=10,
                 dashes=False,
                 errorbar="ci",
                 # proportiontocut=0.0 makes this the plain mean; 0.25 would give the
                 # interquartile mean (IQM).
                 estimator=lambda scores: scipy.stats.trim_mean(scores, proportiontocut=0.0, axis=None),
                 linewidth=linewidth,
                 )

    ax = plt.gca()

    # One x tick per context, labelled with the numeric part of the context name.
    x_labels = [_context_tick_label(context) for context in data["context"].unique()]
    plt.xticks(ticks=range(len(x_labels)), labels=x_labels)
    plt.xlabel("Anchor Time Step", fontdict={'size': label_fontsize})
    plt.ylabel(names.metric_names[metric], fontdict={'size': label_fontsize})

    # Tick marks in the border color, tick labels in black.
    for which in ("major", "minor"):
        ax.tick_params(axis='both', which=which, labelsize=tick_fontsize,
                       colors=border_color, labelcolor="black")
    ax.yaxis.get_offset_text().set_fontsize(tick_fontsize)

    # Log y scale with an optional fixed-exponent number format. Separate formatter
    # instances: a formatter keeps the tick locations of its last call, so sharing
    # one between major and minor ticks lets the minor pass leak into the offset text.
    plt.yscale('log')
    if use_major_formatter:
        ax.yaxis.set_major_formatter(MagnitudeFormatter(formatter_magnitude))
    if use_minor_formatter:
        ax.yaxis.set_minor_formatter(MagnitudeFormatter(formatter_magnitude))

    if y_top_limit is not None:
        plt.ylim(top=y_top_limit)
    if y_bottom_limit is not None:
        plt.ylim(bottom=y_bottom_limit)

    _style_spines(ax, boarder_linewidth, border_color)

    plt.title(env_names[env], fontdict={'size': label_fontsize})

    if show_legend:
        method_names = names.method_names_reduced if reduced else names.method_names
        _add_translated_legend(ax, legend_order, method_names, label_fontsize, tick_fontsize)
    else:
        # get_legend() returns None when no legend was created.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    if mode == "save":
        if unique_name is not None:
            file_name = unique_name
        else:
            method_filename = "__".join(sorted(methods))
            file_name = f"{metric}___{method_filename}"
        _save_figure(f"output/figures/{env}/{file_name}.pdf", overwrite)
    else:
        plt.show()


### PAPER PLOTS ###

def quantitative_1(reduced=False, show_legend=False, data: pd.DataFrame | None = None):
    """Save the quantitative_1 figures for deformable plate, planar bending and planar bending OOD-E.

    For each of ``full_rollout_mean_mse`` and ``10_steps_mean_mse``, one PDF per
    environment is written via :func:`plot` (``mode="save"``, overwriting).

    Args:
        reduced: Use each environment's reduced method selection (and the reduced
            style/name tables). Otherwise the four base methods are plotted and the
            file names get an extra ``_val``.
        show_legend: Forwarded to :func:`plot`.
        data: Evaluation dataframe. Defaults to the module-level ``df``, which is
            only bound when this file runs as a script.
    """
    if data is None:
        data = df  # module-level global set in the __main__ block
    methods = ["mgn", "cnp_mp", "mgn_history", "cnp_mp_history"]
    # (env, reduced method selection, file-name tag, env-specific plot kwargs)
    panels = [
        ("deformable_plate_v2",
         ["mgn_history_task_prop", "mgn_history", "cnp_mp_history", "egno"],
         "dpv2",
         {"formatter_magnitude": -4, "y_top_limit": 1.01e-3}),
        ("planar_bending",
         ["mgn_history_task_prop", "mgn_history", "cnp_mp_history", "egno"],
         "pb",
         {"formatter_magnitude": -6}),
        ("planar_bending_oode",
         ["mgn_history_task_prop", "mgn_history", "cnp_mp", "egno"],
         "pb_ood",
         {"formatter_magnitude": -6}),
    ]
    all_methods = "" if reduced else "_val"
    for metric in ["full_rollout_mean_mse", "10_steps_mean_mse"]:
        for env, reduced_methods, tag, env_kwargs in panels:
            plot(data, env=env,
                 methods=reduced_methods if reduced else methods,
                 metric=metric,
                 contexts=None,
                 figure_width=10,
                 figure_height=7,
                 mode="save",
                 reduced=reduced,
                 unique_name=f"quantitative_1_{metric}{all_methods}_{tag}_tmlr",
                 use_major_formatter=True,
                 use_minor_formatter=False,
                 label_fontsize=30,
                 tick_fontsize=20,
                 show_legend=show_legend,
                 overwrite=True,
                 **env_kwargs)
            plt.close()


def quantitative_2(reduced=False, show_legend=False, data: pd.DataFrame | None = None):
    """Save the quantitative_2 figures for tissue manipulation, mofmat and teddy fall.

    For each of ``full_rollout_mean_mse`` and ``10_steps_mean_mse``, one PDF per
    environment is written via :func:`plot` (``mode="save"``, overwriting). Each
    environment shows a fixed subset of contexts.

    Args:
        reduced: Use each environment's reduced method selection (and the reduced
            style/name tables). Otherwise the four base methods are plotted and the
            file names get an extra ``_val``.
        show_legend: Forwarded to :func:`plot`.
        data: Evaluation dataframe. Defaults to the module-level ``df``, which is
            only bound when this file runs as a script.
    """
    # Original note: select 3 out of the 4 plots for main paper.
    if data is None:
        data = df  # module-level global set in the __main__ block
    figure_width = 10
    methods = ["mgn", "cnp_mp", "mgn_history", "cnp_mp_history"]
    tm_methods_reduced = ["mgn_history_task_prop", "mgn_history", "cnp_mp", "egno"]
    mofmat_methods_reduced = ["mgn_task_prop", "mgn", "cnp_mp_history", "egno"]
    tf_methods_reduced = ["mgn_history_task_prop", "mgn_history", "cnp_mp_history", "egno"]
    tm_contexts = ["mesh_context_002", "mesh_context_003", "mesh_context_005",
                   "mesh_context_010", "mesh_context_015"]
    late_contexts = ["mesh_context_010", "mesh_context_015", "mesh_context_020",
                     "mesh_context_025", "mesh_context_030"]
    all_methods = "" if reduced else "_val"
    for metric in ["full_rollout_mean_mse", "10_steps_mean_mse"]:
        full_rollout = metric == "full_rollout_mean_mse"
        # (env, reduced method selection, file-name tag, env-specific plot kwargs)
        panels = [
            ("tissue_manipulation", tm_methods_reduced, "tm",
             {"figure_width": figure_width + 0.25,
              "contexts": tm_contexts,
              "use_minor_formatter": full_rollout,
              "formatter_magnitude": -5 if full_rollout else -6}),
            ("mofmat", mofmat_methods_reduced, "mofmat",
             {"figure_width": figure_width,
              "contexts": late_contexts,
              "use_minor_formatter": True,
              "formatter_magnitude": -3}),
            ("teddy_fall_nopc", tf_methods_reduced, "tf",
             {"figure_width": figure_width,
              "contexts": late_contexts,
              "use_minor_formatter": True,
              "formatter_magnitude": -4 if full_rollout else -6}),
        ]
        for env, reduced_methods, tag, env_kwargs in panels:
            plot(data, env=env,
                 methods=reduced_methods if reduced else methods,
                 metric=metric,
                 figure_height=7,
                 mode="save",
                 reduced=reduced,
                 unique_name=f"quantitative_2_{metric}{all_methods}_{tag}_tmlr",
                 use_major_formatter=True,
                 show_legend=show_legend,
                 label_fontsize=30,
                 tick_fontsize=20,
                 overwrite=True,
                 **env_kwargs)
            plt.close()

def pb_ood(data: pd.DataFrame | None = None):
    """Save planar bending OOD-I / OOD-E figures for mgn, mgn_task_prop and cnp_mp.

    For each of ``full_rollout_mean_mse`` and ``10_steps_mean_mse``, one PDF per
    environment is written via :func:`plot` (``mode="save"``, overwriting, legend shown).

    Args:
        data: Evaluation dataframe. Defaults to the module-level ``df``, which is
            only bound when this file runs as a script.
    """
    if data is None:
        data = df  # module-level global set in the __main__ block
    methods = ["mgn", "mgn_task_prop", "cnp_mp"]
    # NOTE(review): the file tags ("dpv2", "pb") and the first panel's y limit and
    # exponent match the deformable_plate_v2 settings in quantitative_1 and look
    # copy-pasted; confirm they are intended for planar_bending_oodi/oode.
    # (env, file-name tag, env-specific plot kwargs)
    panels = [
        ("planar_bending_oodi", "dpv2", {"formatter_magnitude": -4, "y_top_limit": 1.01e-3}),
        ("planar_bending_oode", "pb", {"formatter_magnitude": -6}),
    ]
    for metric in ["full_rollout_mean_mse", "10_steps_mean_mse"]:
        for env, tag, env_kwargs in panels:
            plot(data, env=env,
                 methods=methods,
                 metric=metric,
                 contexts=None,
                 figure_width=12,
                 figure_height=5,
                 mode="save",
                 unique_name=f"quantitative_1_{metric}_{tag}",
                 legend_order=[1, 0, 2],
                 use_major_formatter=True,
                 use_minor_formatter=False,
                 label_fontsize=30,
                 tick_fontsize=20,
                 show_legend=True,
                 overwrite=True,
                 **env_kwargs)
            plt.close()

def method_ablation(data: pd.DataFrame | None = None):
    """Save the method-ablation full-rollout MSE figures for deformable plate and planar bending.

    Args:
        data: Evaluation dataframe. Defaults to the module-level ``df``, which is
            only bound when this file runs as a script.
    """
    if data is None:
        data = df  # module-level global set in the __main__ block
    methods = ["cnp_mp", "mgn_mp", "cnp", "np", "np_mp"]
    # (env, output file name, fixed y exponent)
    panels = [
        ("deformable_plate_v2", "method_ablation_dpv2", -4),
        ("planar_bending", "method_ablation_pb", -6),
    ]
    for env, file_name, magnitude in panels:
        plot(data, env=env,
             methods=methods,
             metric="full_rollout_mean_mse",
             contexts=None,
             figure_width=12,
             figure_height=5,
             mode="save",
             unique_name=file_name,
             label_fontsize=28,
             tick_fontsize=17,
             use_major_formatter=True,
             use_minor_formatter=False,
             # Only used when show_legend is True; kept so the legend can be re-enabled.
             legend_order=[1, 4, 2, 3, 0],
             formatter_magnitude=magnitude,
             show_legend=False,
             overwrite=True)
        plt.close()


def context_ablation(data: pd.DataFrame | None = None):
    """Save the context-ablation full-rollout MSE figures for deformable plate and planar bending.

    Args:
        data: Evaluation dataframe. Defaults to the module-level ``df``, which is
            only bound when this file runs as a script.
    """
    if data is None:
        data = df  # module-level global set in the __main__ block
    methods = ["cnp_mp", "abl_cnp_mp_max_max", "abl_cnp_mp_max_trafo", "abl_cnp_mp_null_trafo"]
    # (env, output file name, env-specific plot kwargs)
    panels = [
        ("deformable_plate_v2", "context_ablation_dpv2",
         {"legend_order": [3, 0, 1, 2], "use_minor_formatter": True, "formatter_magnitude": -5}),
        ("planar_bending", "context_ablation_pb",
         {"use_minor_formatter": False, "formatter_magnitude": -7}),
    ]
    for env, file_name, env_kwargs in panels:
        plot(data, env=env,
             methods=methods,
             metric="full_rollout_mean_mse",
             contexts=None,
             figure_width=12,
             figure_height=5,
             mode="save",
             unique_name=file_name,
             label_fontsize=28,
             tick_fontsize=17,
             use_major_formatter=True,
             show_legend=False,
             overwrite=True,
             **env_kwargs)
        plt.close()


def all_methods_figure(env="tissue_manipulation", mode="show", data: pd.DataFrame | None = None):
    """Plot a fixed list of seven methods for one environment (debugging / simple overview).

    Args:
        env: Environment to plot.
        mode: ``"show"`` or ``"save"``, forwarded to :func:`plot`.
        data: Evaluation dataframe. Defaults to the module-level ``df``, which is
            only bound when this file runs as a script.
    """
    if data is None:
        data = df  # module-level global set in the __main__ block
    methods = ["mgn", "mgn_task_prop", "mgn_history_task_prop", "cnp", "cnp_mp",
               "cnp_mp_history", "mgn_history"]
    # Other metric columns previously plotted here: "20_steps_mean_mse", "30_steps_mean_mse",
    # "50_steps_mean_mse", "80_steps_mean_mse", "100_steps_mean_mse",
    # "120_steps_mean_mse", "150_steps_mean_mse".
    for metric in ["full_rollout_mean_mse", "10_steps_mean_mse"]:
        plot(data, env=env, methods=methods,
             metric=metric,
             contexts=None,
             mode=mode,
             overwrite=True)
        plt.close()


if __name__ == "__main__":
    # Module-level dataframe read by the plotting functions above.
    # Alternative: df = get_dataset(neurips=True)
    df = pd.read_csv("output/evaluation_dataframes/data_v3.2.csv", index_col=0)

    # Uncomment the paper figures to (re)generate.
    # quantitative_1(reduced=True, show_legend=False)
    # quantitative_2(reduced=True, show_legend=False)
    # quantitative_1(reduced=False, show_legend=False)
    # quantitative_2(reduced=False, show_legend=False)
    # method_ablation()
    context_ablation()

    # Per-environment overviews shown interactively:
    # all_methods_figure(env="deformable_plate_v2", mode="show")
    # all_methods_figure(env="tissue_manipulation", mode="show")
    # all_methods_figure(env="teddy_fall_nopc", mode="show")
    # all_methods_figure(env="planar_bending", mode="show")
    # all_methods_figure(env="planar_bending_oode", mode="show")
    # all_methods_figure(env="mofmat", mode="show")
