import os
from typing import List

import matplotlib
import pandas as pd
import pylab
import scipy
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.ticker import LogLocator, AutoMinorLocator, ScalarFormatter

from scripts.evaluation.plotting.configs import names
from scripts.evaluation.plotting.configs import styles
from scripts.evaluation.plotting.configs.names import env_names



class MagnitudeFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter with an optionally pinned offset exponent and one-decimal labels.

    ``ScalarFormatter`` normally picks the order of magnitude shown in the axis
    offset text on its own. Passing ``exponent`` forces that order instead, which
    keeps the scaling identical across several figures.

    NOTE(review): this overrides the private hooks ``_set_order_of_magnitude`` and
    ``_set_format`` of matplotlib's ``ScalarFormatter``; confirm they still exist
    when upgrading matplotlib.

    Args:
        exponent: Order of magnitude to use for the offset (e.g. ``-4``), or
            ``None`` to let matplotlib choose it automatically.
    """

    def __init__(self, exponent: int | None = None):
        super().__init__()
        self._fixed_exponent = exponent

    def _set_order_of_magnitude(self):
        # Compare against None explicitly: an exponent of 0 is a valid request
        # ("no scaling") and must not fall back to the automatic choice.
        if self._fixed_exponent is not None:
            self.orderOfMagnitude = self._fixed_exponent
        else:
            super()._set_order_of_magnitude()

    def _set_format(self):
        # Always print one digit after the decimal point.
        self.format = "%1.1f"


def plot(data: pd.DataFrame, env: str, methods: List[str],
         figure_width: int = 10,
         figure_height: int = 5,
         context_size: int = 2,
         mode: str = "show",
         formatter_magnitude: int = -4,
         use_major_formatter: bool = False,
         use_minor_formatter: bool = False,
         unique_name: str | None = None,
         y_top_limit: float | None = None,
         y_bottom_limit: float | None = None,
         legend_order: List[int] | None = None,
         label_fontsize: int = 19,
         tick_fontsize: int = 16,
         linewidth: int = 3,
         border_linewidth: int = 3,
         border_color='darkgray',
         show_legend: bool = True,
         overwrite: bool = False,
         trim_proportion: float = 0.0):
    """Plot the error over time of several methods for one environment and context size.

    ``data`` is filtered to ``env``, ``methods`` and ``context_size``. One line per
    method (x: ``time_step``, y: ``error_over_time``) is drawn with seaborn on the
    current matplotlib axes, together with a confidence-interval band.

    Args:
        data: Long-format results containing the columns ``env``, ``method``,
            ``context_size``, ``time_step`` and ``error_over_time``.
        env: Environment to plot. It is also used to look up the title in
            ``env_names`` and as the output sub-directory.
        methods: Values of the ``method`` column to include.
        figure_width: Figure width in inches (applied through the seaborn theme).
        figure_height: Figure height in inches (applied through the seaborn theme).
        context_size: Value of the ``context_size`` column to keep.
        mode: ``"show"`` displays the figure; ``"save"`` writes a PDF.
        formatter_magnitude: Fixed exponent for ``MagnitudeFormatter``, used
            when one of the ``use_*_formatter`` flags is set.
        use_major_formatter: Apply ``MagnitudeFormatter`` to the major y ticks.
        use_minor_formatter: Apply ``MagnitudeFormatter`` to the minor y ticks.
        unique_name: Output file name without extension for ``mode="save"``.
            Defaults to a name built from the context size and method names.
        y_top_limit: Optional upper y-axis limit.
        y_bottom_limit: Optional lower y-axis limit.
        legend_order: Indices into the legend entries giving their display order.
        label_fontsize: Font size of the axis labels, title and legend title.
        tick_fontsize: Font size of the tick labels, offset text and legend entries.
        linewidth: Line width of the curves.
        border_linewidth: Line width of the axes frame (spines).
        border_color: Colour of the axes frame and tick marks.
        show_legend: Draw a legend with translated method names, or remove it.
        overwrite: With ``mode="save"``, replace an existing file instead of
            skipping it.
        trim_proportion: Fraction of values cut from *each* end before averaging
            at every time step (``scipy.stats.trim_mean``). ``0.0`` (default) is
            the plain mean; ``0.25`` gives the interquartile mean (IQM).

    Raises:
        ValueError: If ``mode`` is neither ``"show"`` nor ``"save"``.
    """
    # Validate up front so a bad mode fails before any plotting work is done.
    if mode not in ("show", "save"):
        raise ValueError(f"Unknown mode {mode}")

    # Import the submodule explicitly. A bare ``import scipy`` does not
    # guarantee that ``scipy.stats`` is available as an attribute.
    from scipy import stats

    # Keep only the requested env, methods and context size, ordered by time step.
    data = data[
        (data["env"] == env)
        & data["method"].isin(methods)
        & (data["context_size"] == context_size)
    ].sort_values(by="time_step")

    sns.set_theme(rc={'figure.figsize': (figure_width, figure_height)})
    sns.set_style("whitegrid")
    plt.rcParams.update({"ytick.left": True})
    ax = sns.lineplot(data=data, x="time_step", y="error_over_time",
                      hue="method",
                      style="method",
                      palette=styles.method_colors_reduced,
                      dashes=False,
                      errorbar="ci",
                      # Trimmed mean over all rows sharing a time step and method.
                      estimator=lambda scores: stats.trim_mean(
                          scores, proportiontocut=trim_proportion, axis=None),
                      linewidth=linewidth,
                      )

    # Axis labels.
    ax.set_xlabel("Time Steps", fontdict={'size': label_fontsize})
    ax.set_ylabel("Error over Time", fontdict={'size': label_fontsize})

    # Tick marks take the border colour; the tick numbers stay black.
    ax.tick_params(axis='both', which='both', labelsize=tick_fontsize,
                   colors=border_color, labelcolor="black")
    ax.yaxis.get_offset_text().set_fontsize(tick_fontsize)

    # Use one formatter instance per tick set. ``format_ticks`` stores the tick
    # locations on the formatter, so a shared instance would let the minor ticks
    # overwrite the state used for the major labels and offset text.
    if use_major_formatter:
        ax.yaxis.set_major_formatter(MagnitudeFormatter(formatter_magnitude))
    if use_minor_formatter:
        ax.yaxis.set_minor_formatter(MagnitudeFormatter(formatter_magnitude))

    # Optional y limits.
    if y_top_limit is not None:
        ax.set_ylim(top=y_top_limit)
    if y_bottom_limit is not None:
        ax.set_ylim(bottom=y_bottom_limit)

    # Frame around the plot.
    for side in ("bottom", "left", "top", "right"):
        ax.spines[side].set_linewidth(border_linewidth)
        ax.spines[side].set_color(border_color)

    # Title: use the raw env id when no display name is registered for it.
    title = env_names[env] if env in env_names else env
    ax.set_title(title, fontdict={'size': label_fontsize})

    if show_legend:
        if legend_order is not None:
            handles, labels = ax.get_legend_handles_labels()
            legend = ax.legend([handles[idx] for idx in legend_order],
                               [labels[idx] for idx in legend_order])
        else:
            legend = ax.legend()

        # Replace raw method ids with their display names where one is known.
        method_names = names.method_names_reduced
        for text in legend.get_texts():
            original_label = text.get_text()
            text.set_text(method_names.get(original_label, original_label))
            text.set_fontsize(tick_fontsize)
        legend.set_title('Methods')
        legend.get_title().set_fontsize(label_fontsize)
    else:
        # get_legend() returns None when the axes have no legend to remove.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    if mode == "show":
        plt.show()
        return

    # mode == "save"
    if unique_name is not None:
        file_name = unique_name
    else:
        method_filename = "__".join(sorted(methods))
        file_name = f"Context_size_{context_size}___{method_filename}"
    out_path = f"output/error_over_time_figures/{env}/{file_name}.pdf"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # Never clobber an existing figure unless explicitly asked to.
    if os.path.isfile(out_path) and not overwrite:
        print(f"File {out_path} already exists. Skipping...")
    else:
        # bbox_inches='tight' applies a tight layout at save time.
        plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
        print(f"Saved figure to {out_path}")


def standard_error_over_time_figure(env="deformable_plate_v2", mode="show",
                                    data: pd.DataFrame | None = None,
                                    methods: List[str] | None = None):
    """Draw one error-over-time figure for every context size present for ``env``.

    Each figure is produced with ``plot`` (legend hidden, existing files
    overwritten) and closed right after, so figures do not pile up.

    Args:
        env: Environment id to plot.
        mode: Forwarded to ``plot`` (``"show"`` or ``"save"``).
        data: Results dataframe. When omitted, the module-level ``df`` is used.
            That global is only created when this file runs as a script;
            otherwise a NameError is raised, as before.
        methods: Methods to compare. Defaults to
            ``['cnp_mp', 'mgn', 'mgn_task_prop', 'egno']``.
    """
    if data is None:
        # Backward compatibility: fall back to the global set in ``__main__``.
        data = df
    if methods is None:
        methods = ['cnp_mp', 'mgn', 'mgn_task_prop', 'egno']

    context_sizes = sorted(data.loc[data["env"] == env, "context_size"].unique())
    for context_size in context_sizes:
        plot(data, env=env, methods=methods,
             context_size=context_size,
             mode=mode,
             overwrite=True,
             show_legend=False)
        # Close the figure so the next context size starts on a fresh canvas.
        plt.close()


if __name__ == "__main__":
    # Load the pre-computed error-over-time results. ``standard_error_over_time_figure``
    # reads this module-level ``df``.
    df = pd.read_csv("output/error_over_time_dataframes/error_over_time_data_v1.csv")

    # Save one set of figures per environment.
    for env_id in (
        "deformable_plate_v2",
        "planar_bending",
        "planar_bending_oode",
        "tissue_manipulation",
        "teddy_fall_nopc",
        "mofmat",
    ):
        standard_error_over_time_figure(env=env_id, mode="save")
