import os
from typing import List

import matplotlib
import pandas as pd
import pylab
import scipy
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.ticker import LogLocator, AutoMinorLocator, ScalarFormatter

from scripts.evaluation.plotting.configs import names
from scripts.evaluation.plotting.configs import styles
from scripts.evaluation.plotting.configs.names import env_names



def plot(data: pd.DataFrame, env: str, methods: List[str],
         figure_width: int = 10,
         figure_height: int = 5,
         contexts: List[str] | None = None,
         metric: str = "full_rollout_mean_mse",
         mode: str = "show",
         unique_name: str = None,
         formatter_magnitude: int = -4,
         use_major_formatter: bool = False,
         use_minor_formatter: bool = False,
         y_top_limit: float | None = None,
         y_bottom_limit: float | None = None,
         legend_order: List[int] | None = None,
         label_fontsize: int = 19,
         tick_fontsize: int = 16,
         linewidth: int = 3,
         boarder_linewidth: int = 3,
         border_color='darkgray',
         show_legend: bool = True,
         overwrite: bool = False):
    # filter data to only contain the specified env, methods and contexts
    data = data[data["env"] == env]
    data = data[data["method"].isin(methods)]

    if contexts is not None:
        data = data[data["context"].isin(contexts)]
    # sort alphanumerically by context
    data = data.sort_values(by="context")
    # call seaborn
    sns.set(rc={'figure.figsize': (figure_width, figure_height)})
    sns.set_theme()
    sns.set_style("whitegrid")
    plt.rcParams.update({"ytick.left": True})
    sns.lineplot(data=data, x="context", y=metric,
                 hue="method",
                 style="method",
                 palette=styles.method_colors,
                 markers=styles.method_markers,
                 markersize=10,
                 dashes=False,
                 # errorbar=("pi", 75),
                 errorbar="ci",
                 # n_boot=1000,
                 estimator=lambda scores: scipy.stats.trim_mean(scores, proportiontocut=0.25, axis=None),  # IQM
                 linewidth=linewidth,
                 )

    ax = plt.gca()

    # ticks and labels appearance
    x_labels = data["context"].unique()
    # only get numbers from context
    x_labels = ["".join([c for c in x if c.isdigit()]) for x in x_labels]
    # remove trailing zeros
    x_labels = [str(int(x)) for x in x_labels]

    plt.xticks(ticks=range(len(x_labels)), labels=x_labels)
    plt.xlabel("Anchor Time Step", fontdict={'size': label_fontsize})
    plt.ylabel(names.metric_names[metric], fontdict={'size': label_fontsize})

    # Set the fontsize  and colors for the numbers on the ticks and the offset text.
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize, colors=border_color, labelcolor="black")
    ax.tick_params(axis='both', which='minor', labelsize=tick_fontsize, colors=border_color, labelcolor="black")
    ax.yaxis.get_offset_text().set_fontsize(tick_fontsize)

    # Ticks number formatting and y scale
    plt.yscale('log')
    fmt = MagnitudeFormatter(formatter_magnitude)
    if use_major_formatter:
        ax.yaxis.set_major_formatter(fmt)
    if use_minor_formatter:
        ax.yaxis.set_minor_formatter(fmt)

    # y limits
    if y_top_limit is not None:
        plt.ylim(top=y_top_limit)
    if y_bottom_limit is not None:
        plt.ylim(bottom=y_bottom_limit)

    # boarder
    plt.gca().spines['bottom'].set_linewidth(boarder_linewidth)
    plt.gca().spines['left'].set_linewidth(boarder_linewidth)
    plt.gca().spines['top'].set_linewidth(boarder_linewidth)
    plt.gca().spines['right'].set_linewidth(boarder_linewidth)
    plt.gca().spines['right'].set_color(border_color)
    plt.gca().spines['top'].set_color(border_color)
    plt.gca().spines['bottom'].set_color(border_color)
    plt.gca().spines['left'].set_color(border_color)

    # title
    plt.title(env_names[env], fontdict={'size': label_fontsize})

    # legend
    if show_legend:
        # Get the handles and labels of the current axes
        if legend_order is not None:
            handles, labels = plt.gca().get_legend_handles_labels()
            legend = plt.legend([handles[idx] for idx in legend_order], [labels[idx] for idx in legend_order])
        else:
            legend = plt.legend()

        # Translate the labels using the dictionary
        for text in legend.get_texts():
            original_label = text.get_text()
            translated_label = names.method_names.get(original_label, original_label)
            text.set_text(translated_label)
            text.set_fontsize(tick_fontsize)
        # Set the legend title
        legend.set_title('Methods')
        legend.get_title().set_fontsize(label_fontsize)
    else:
        ax.get_legend().remove()

    # save or show
    if mode == "save":
        if unique_name is not None:
            file_name = unique_name
        else:
            method_filename = "__".join(sorted(methods))
            file_name = f"{metric}___{method_filename}"
        out_path = f"output/figures/{env}/{file_name}.pdf"
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        # check if file exists and only write if overwrite is true
        if os.path.isfile(out_path) and not overwrite:
            print(f"File {out_path} already exists. Skipping...")
        else:
            # set tight layout explicit in savefig
            plt.savefig(out_path, bbox_inches='tight', pad_inches=0, )
            print(f"Saved figure to {out_path}")
    elif mode == "show":
        plt.show()
    else:
        raise ValueError(f"Unknown mode {mode}")


### PAPER PLOTS ###

def quantitative_1(data: pd.DataFrame | None = None):
    """Save the quantitative_1 figures for Deformable Plate v2 and Planar Bending.

    For each of the two rollout metrics, one PDF per environment is written
    through ``plot`` in ``"save"`` mode. Existing files are overwritten and
    legends are hidden.

    Args:
        data: Evaluation dataframe. Defaults to the module-level ``df`` loaded
            by the ``__main__`` block, so script usage is unchanged while
            importers can pass a dataframe explicitly.
    """
    if data is None:
        data = df  # module global set in __main__; NameError if never loaded

    # Settings common to every figure of this group.
    shared = {
        "methods": ["mgn", "mgn_task_prop", "cnp_mp"],
        "contexts": None,
        "figure_width": 12,
        "figure_height": 5,
        "mode": "save",
        "legend_order": [1, 0, 2],
        "use_major_formatter": True,
        "use_minor_formatter": False,
        "label_fontsize": 30,
        "tick_fontsize": 20,
        "show_legend": False,
        "overwrite": True,
    }
    # (environment, file-name suffix, environment-specific settings)
    env_configs = (
        ("deformable_plate_v2", "dpv2", {"formatter_magnitude": -4, "y_top_limit": 1.01e-3}),
        ("planar_bending", "pb", {"formatter_magnitude": -6}),
    )
    for metric in ["full_rollout_mean_mse", "10_steps_mean_mse"]:
        for env, suffix, env_settings in env_configs:
            plot(data, env=env, metric=metric,
                 unique_name=f"quantitative_1_{metric}_{suffix}",
                 **shared, **env_settings)
            plt.close()


def quantitative_2(data: pd.DataFrame | None = None):
    """Save the quantitative_2 figures for tissue manipulation, mofmat and teddy fall.

    The original note says three of four candidate plots are selected here for
    the main paper. For each of the two rollout metrics, one PDF per
    environment is written through ``plot`` in ``"save"`` mode. Existing files
    are overwritten and legends are hidden.

    Args:
        data: Evaluation dataframe. Defaults to the module-level ``df`` loaded
            by the ``__main__`` block, so script usage is unchanged while
            importers can pass a dataframe explicitly.
    """
    if data is None:
        data = df  # module global set in __main__; NameError if never loaded

    figure_width = 10
    mesh_contexts = ["mesh_context_010", "mesh_context_015", "mesh_context_020",
                     "mesh_context_025", "mesh_context_030"]
    # Settings common to every figure of this group.
    shared = {
        "methods": ["mgn", "mgn_task_prop", "cnp_mp"],
        "figure_height": 7,
        "mode": "save",
        "use_major_formatter": True,
        "show_legend": False,
        "label_fontsize": 30,
        "tick_fontsize": 20,
        "overwrite": True,
    }
    for metric in ["full_rollout_mean_mse", "10_steps_mean_mse"]:
        is_full_rollout = metric == "full_rollout_mean_mse"
        # (environment, file-name suffix, environment-specific settings)
        env_configs = (
            ("tissue_manipulation", "tm", {
                # slightly wider than the others (original value; presumably for
                # wider tick labels — TODO confirm)
                "figure_width": figure_width + 0.25,
                "contexts": None,
                "use_minor_formatter": is_full_rollout,
                "formatter_magnitude": -5 if is_full_rollout else -6,
            }),
            ("mofmat", "mofmat", {
                "figure_width": figure_width,
                "contexts": mesh_contexts,
                "use_minor_formatter": True,
                "formatter_magnitude": -3,
                # only takes effect if show_legend is switched on
                "legend_order": [2, 1, 0],
            }),
            ("teddy_fall_nopc", "tf", {
                "figure_width": figure_width,
                "contexts": mesh_contexts,
                "use_minor_formatter": True,
                "formatter_magnitude": -4 if is_full_rollout else -6,
            }),
        )
        for env, suffix, env_settings in env_configs:
            plot(data, env=env, metric=metric,
                 unique_name=f"quantitative_2_{metric}_{suffix}",
                 **shared, **env_settings)
            plt.close()


def method_ablation(data: pd.DataFrame | None = None):
    """Save the method-ablation figures for Deformable Plate v2 and Planar Bending.

    One full-rollout MSE PDF per environment is written through ``plot`` in
    ``"save"`` mode. Existing files are overwritten and legends are hidden.

    Args:
        data: Evaluation dataframe. Defaults to the module-level ``df`` loaded
            by the ``__main__`` block, so script usage is unchanged while
            importers can pass a dataframe explicitly.
    """
    if data is None:
        data = df  # module global set in __main__; NameError if never loaded

    # Settings common to both environments.
    shared = {
        "methods": ["cnp_mp", "mgn_mp", "cnp", "np", "np_mp"],
        "metric": "full_rollout_mean_mse",
        "contexts": None,
        "figure_width": 12,
        "figure_height": 5,
        "mode": "save",
        "label_fontsize": 28,
        "tick_fontsize": 17,
        "use_major_formatter": True,
        "use_minor_formatter": False,
        "legend_order": [1, 4, 2, 3, 0],
        "show_legend": False,
        "overwrite": True,
    }
    # (environment, file-name suffix, y tick magnitude)
    for env, suffix, magnitude in (("deformable_plate_v2", "dpv2", -4),
                                   ("planar_bending", "pb", -6)):
        plot(data, env=env,
             unique_name=f"method_ablation_{suffix}",
             formatter_magnitude=magnitude,
             **shared)
        plt.close()


def context_ablation(data: pd.DataFrame | None = None):
    """Save the context-ablation figures for Deformable Plate v2 and Planar Bending.

    One full-rollout MSE PDF per environment is written through ``plot`` in
    ``"save"`` mode. Existing files are overwritten and legends are hidden.

    Args:
        data: Evaluation dataframe. Defaults to the module-level ``df`` loaded
            by the ``__main__`` block, so script usage is unchanged while
            importers can pass a dataframe explicitly.
    """
    if data is None:
        data = df  # module global set in __main__; NameError if never loaded

    # Settings common to both environments.
    shared = {
        "methods": ["cnp_mp", "abl_cnp_mp_max_max", "abl_cnp_mp_max_trafo", "abl_cnp_mp_null_trafo"],
        "metric": "full_rollout_mean_mse",
        "contexts": None,
        "figure_width": 12,
        "figure_height": 5,
        "mode": "save",
        "label_fontsize": 28,
        "tick_fontsize": 17,
        "use_major_formatter": True,
        "show_legend": False,
        "overwrite": True,
    }
    # (environment, file-name suffix, environment-specific settings)
    env_configs = (
        ("deformable_plate_v2", "dpv2", {
            "legend_order": [3, 0, 1, 2],
            "use_minor_formatter": True,
            "formatter_magnitude": -5,
        }),
        ("planar_bending", "pb", {
            "use_minor_formatter": False,
            "formatter_magnitude": -7,
        }),
    )
    for env, suffix, env_settings in env_configs:
        plot(data, env=env,
             unique_name=f"context_ablation_{suffix}",
             **shared, **env_settings)
        plt.close()


def all_methods_figure(env="tissue_manipulation", mode="show", data: pd.DataFrame | None = None):
    """Plot every known method for one environment, once per rollout metric.

    Args:
        env: Environment key passed to ``plot``.
        mode: ``"show"`` or ``"save"``, forwarded to ``plot``. In save mode
            existing files are overwritten.
        data: Evaluation dataframe. Defaults to the module-level ``df`` loaded
            by the ``__main__`` block. It is appended last so existing
            positional calls keep working.
    """
    if data is None:
        data = df  # module global set in __main__; NameError if never loaded

    methods = ["mgn", "ltsgns_mp", "ltsgns_mp_old", "mgn_task_prop", "cnp", "cnp_mp",
               "np", "np_mp", "mgn_mp", "ltsgns_step"]
    for metric in ["full_rollout_mean_mse", "10_steps_mean_mse"]:
        plot(data, env=env, methods=methods,
             metric=metric,
             contexts=None,
             mode=mode,
             overwrite=True)
        # plot() leaves the figure open; close it before the next one.
        plt.close()


if __name__ == "__main__":
    # Load the evaluation results. The paper-figure functions above use this
    # module-level ``df``, so it must stay a global rather than a local.
    df = pd.read_csv("output/evaluation_dataframes/iclr_data_v1.10.csv", index_col=0)

    quantitative_1()
    quantitative_2()
    context_ablation()
    method_ablation()
