import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os


# Index of logged runs: data_info[case][algo] is a list of run names; each run is
# read from "<case>/<algo>/<run>.csv" when plotting. An explicit encoding keeps the
# JSON decoding independent of the platform's default locale.
with open("data_info.json", encoding="utf-8") as f:
    data_info = json.load(f)

# Case studies to plot, one panel each, in left-to-right order.
case_study_list = ["cartpole", "glucose", "quadrotor", "quadruped"]

# Full set of algorithm keys that have labels in plot_config["algo_labels"]; the
# subset actually drawn is chosen per figure further down.
# algorithm_list =  ["model_based", "sac_base", "sac_autosafe_exp", "sac_autosafe_linear",
#                   "sac_autosafe_opt", "sac_lag", "sac_lam_exp", "sac_lam_linear",
#                   "sac_lam_opt", "sac_lyapunov", "sac_residual", "sac_simplex"]

# Metric tags to plot (matched against the "tag" column of each CSV) and their short names.
plot_list = ["training/episodic_return_ep"]
plot_name_list = ["return"]


# Presentation settings for the comparison figures, keyed first by setting name and
# then by case study (or by metric tag / algorithm key where noted).
plot_config = {
    # EWMA span applied to each run before averaging; a larger span smooths more.
    "smooth_span": {
        "cartpole": 5,
        "glucose": 5,
        "quadrotor": 5,
        "quadruped": 5,
    },
    # Legend text for each algorithm key.
    "algo_labels": {
        "model_based": "Safe Policy",
        "sac_base": "SAC",
        "sac_autosafe_exp": "AutoSafe (exp)",
        "sac_autosafe_linear": "AutoSafe (linear)",
        "sac_autosafe_opt": "AutoSafe",
        "sac_lag": "Lagrangian",
        "sac_lam_exp": "AdaLam (exp)",
        "sac_lam_linear": "AdaLam (linear)",
        "sac_lam_opt": "AdaLam",
        "sac_lyapunov": "Lyapunov",
        "sac_residual": "Residual",
        "sac_simplex": "Simplex",
    },
    # Panel titles for each case study.
    "env_labels": {
        "cartpole": "Cartpole",
        "glucose": "Glucose",
        "quadrotor": "3D Quadrotor",
        "quadruped": "Quadruped Navigation",
    },
    # Y-axis label for each metric tag.
    "y_labels": {
        "training/episodic_return_ep": "Performance Return",
        "training/termination": "Safety Violations",
    },
    # Optional fixed y-range per metric tag and case; cases left out are autoscaled.
    # Previously used values, kept for reference:
    #   "training/episodic_return_ep" -> "glucose": [-1000, -50]
    #   "training/termination" -> {"cartpole": [0, 1], "glucose": [-1, 1.2],
    #                              "quadrotor": [-1, 1], "quadruped": [-1, 1]}
    "y_limits": {
        "training/episodic_return_ep": {
            "cartpole": [0, 500],
            "quadrotor": [-10, 300],
            "quadruped": [-100, 6000],
        },
    },
    # Right end of the x-axis (training episodes) for each case.
    "max_episode": {
        "cartpole": 200,
        "glucose": 200,
        "quadrotor": 1000,
        "quadruped": 200,
    },
    # Number of points on the shared x-grid that runs are interpolated onto.
    "n_grid": {
        "cartpole": 200,
        "glucose": 200,
        "quadrotor": 1000,
        "quadruped": 200,
    },
}

figure_name_list = ["performance_comparison"]

for figure_name in figure_name_list:
    # Pick the algorithms shown in this figure. Their position in the list also
    # selects the seaborn palette colour used for them.
    if figure_name == "performance_comparison":
        algorithm_list = ["sac_autosafe_opt", "model_based"]
    else:
        raise NotImplementedError(f"Unknown figure name: {figure_name!r}")

    sns.set_style("whitegrid")

    # One panel per case study in a single row (18 x 4.5 inches for four panels).
    # squeeze=False keeps `axes` indexable even when only one case study is listed.
    n_cases = len(case_study_list)
    fig, axes = plt.subplots(1, n_cases, figsize=(4.5 * n_cases, 4.5), squeeze=False)
    axes = axes[0]

    for row, (plot_name, data_tag) in enumerate(zip(plot_name_list, plot_list)):
        if row > 0:
            # The layout has a single row of panels, so only the first metric is
            # drawn; later metrics would otherwise overwrite its y-limits.
            break

        for j, case in enumerate(case_study_list):
            ax = axes[j]
            span = plot_config["smooth_span"][case]
            case_info = data_info[case]
            max_episode = plot_config["max_episode"][case]
            n_grid = plot_config["n_grid"][case]
            # Shared x-grid onto which every run is resampled so runs can be
            # averaged point by point.
            data_horizon = np.linspace(0, max_episode, n_grid)

            for m, algo in enumerate(algorithm_list):
                # The model-based baseline logs its return under an evaluation tag.
                saved_data_tag = "evaluation/episodic_return" if algo == "model_based" else data_tag

                # One smoothed, resampled curve per run listed for this (case, algo).
                runs = {}
                for i, data_name in enumerate(case_info[algo], start=1):
                    data_path = os.path.join(case, algo, data_name + ".csv")
                    df = pd.read_csv(data_path)
                    # np.interp needs increasing sample points, so order rows by step
                    # (a stable sort leaves already-sorted logs untouched).
                    filtered_df = df[df["tag"] == saved_data_tag].sort_values("step", kind="stable")
                    if filtered_df.empty:
                        raise ValueError(f"No rows with tag {saved_data_tag!r} in {data_path}")
                    # Smooth with an exponentially weighted moving average, then
                    # interpolate onto the shared horizon.
                    smoothed = filtered_df["value"].ewm(span=span).mean()
                    runs[f"col_from_df{i}"] = np.interp(data_horizon, filtered_df["step"], smoothed)

                if not runs:
                    continue

                run_frame = pd.DataFrame(runs)
                row_mean = run_frame.mean(axis=1)
                # Standard error of the mean across runs (NaN when there is one run).
                row_std_error = run_frame.std(axis=1) / np.sqrt(len(runs))

                color = sns.color_palette()[m]
                sns.lineplot(
                    x=data_horizon,
                    y=row_mean.to_numpy(),
                    ax=ax,
                    label=plot_config["algo_labels"][algo],
                    legend=False,
                    color=color,
                    linewidth=2,
                    # A dashed line marks the model-based baseline.
                    linestyle="--" if algo == "model_based" else "-",
                )
                ax.fill_between(
                    data_horizon,
                    row_mean - row_std_error,
                    row_mean + row_std_error,
                    alpha=0.2,
                    color=color,
                )

            # Axis decoration is per panel, so apply it once after all algorithms.
            ax.set_xlabel("Training Episodes", fontsize=17)
            ax.set_ylabel(plot_config["y_labels"][data_tag], fontsize=17)
            ax.set_title(plot_config["env_labels"][case], fontsize=15, fontweight="bold")
            ax.tick_params(labelsize=13)
            ax.set_xlim([0, max_episode])
            # Fixed y-range only where one is configured; otherwise keep autoscaling.
            y_lims = plot_config["y_limits"].get(data_tag, {}).get(case)
            if y_lims is not None:
                ax.set_ylim([y_lims[0], y_lims[1]])

    plt.tight_layout()
    leg = axes[0].legend(fancybox=True, shadow=True, fontsize=16, facecolor="lightgray")
    # Thicken the legend swatches so the line styles are easy to tell apart.
    for line in leg.get_lines():
        line.set_linewidth(4.0)
    plt.subplots_adjust(bottom=0.15)
    fig.savefig(f"{figure_name}.pdf", dpi=300)
    # Release the figure so repeated iterations do not accumulate open figures.
    plt.close(fig)
