import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os


# data_info maps case study -> algorithm -> list of run names; each run name is
# the stem of a CSV file read later as "<case>/<algorithm>/<run name>.csv".
# JSON is read as UTF-8 explicitly instead of relying on the locale default.
with open("data_info.json", encoding="utf-8") as f:
    data_info = json.load(f)
# Debug output: show which quadrotor runs were registered.
print(data_info["quadrotor"])
# Case studies, in the order of the figure columns (left to right).
case_study_list = ["cartpole", "glucose", "quadrotor", "quadruped"]

# All available algorithms, kept for reference (each figure below selects its own subset):
# algorithm_list =  ["model_based", "sac_base", "sac_autosafe_exp", "sac_autosafe_linear",
#                   "sac_autosafe_opt", "sac_lag", "sac_lam_exp", "sac_lam_linear",
#                   "sac_lam_opt", "sac_lyapunov", "sac_residual", "sac_simplex"]


# One entry per figure row: (value of the CSV "tag" column, short readable name).
_PLOT_ROWS = (
    ("training/episodic_return_ep", "return"),
    ("training/termination", "safety violation"),
)
plot_list = [tag for tag, _ in _PLOT_ROWS]
plot_name_list = [name for _, name in _PLOT_ROWS]


# Central plotting configuration, keyed first by setting and then by case study,
# algorithm id or metric tag.
plot_config = {
    # Span of the exponentially weighted moving average used to smooth each
    # run; a larger span gives stronger smoothing.
    "smooth_span": {
        "cartpole": 5,
        "glucose": 5,
        "quadrotor": 5,
        "quadruped": 5,
    },

    # Legend label for every algorithm id.
    "algo_labels": {
        "model_based": "Safe Policy",
        "sac_base": "SAC",
        "sac_autosafe_exp": "AutoSafe (exp)",
        "sac_autosafe_linear": "AutoSafe (linear)",
        "sac_autosafe_opt": "AutoSafe",
        "sac_lag": "Lagrangian",
        "sac_lam_exp": "AdaLam (exp)",
        "sac_lam_linear": "AdaLam (linear)",
        "sac_lam_opt": "AdaLam",
        "sac_lyapunov": "Lyapunov",
        "sac_residual": "Residual",
        "sac_simplex": "Simplex",
        "sac_autosafe_opt_cor": "AutoSafe (with corr.)",
    },

    # Subplot title for every case study.
    "env_labels": {
        "cartpole": "Cartpole",
        "glucose": "Glucose",
        "quadrotor": "3D Quadrotor",
        "quadruped": "Quadruped Navigation",
    },

    # y-axis label for every plotted metric tag.
    "y_labels": {
        "training/episodic_return_ep": "Performance Return",
        "training/termination": "Safety Violations",
    },

    # Fixed [low, high] y-axis range per metric tag and case study. Anything
    # not listed here keeps matplotlib's automatic range.
    "y_limits": {
        "training/episodic_return_ep": {
            "cartpole": [0, 500],
            # "glucose": [-1000, -50],   (disabled)
            "quadrotor": [-10, 300],
            "quadruped": [-100, 6000],
        },
        # Disabled limits for the termination metric:
        # "training/termination": {
        #     "cartpole": [0, 1],
        #     "glucose": [-1, 1.2],
        #     "quadrotor": [-1, 1],
        #     "quadruped": [-1, 1],
        # },
    },

    # Right end of the x-axis (training episodes) per case study.
    "max_episode": {
        "cartpole": 200,
        "glucose": 200,
        "quadrotor": 1000,
        "quadruped": 180,
    },

    # Number of evenly spaced points in [0, max_episode] onto which every run
    # is interpolated.
    "n_grid": {
        "cartpole": 200,
        "glucose": 200,
        "quadrotor": 1000,
        "quadruped": 180,
    },

    # [x, y] data coordinates of the "Safe Policy Prior" text box.
    "text_loc": {
        "cartpole": [135, 20],
        "glucose": [135, -2950],
        "quadrotor": [20, 280],
        "quadruped": [5, 5600],
    },

    # Disabled: fixed palette index per algorithm.
    # "color_id": {"model_based": 0,
    #              "sac_base": 1,
    #              "sac_autosafe_exp": 2,
    #              "sac_autosafe_linear": 3,
    #              "sac_autosafe_opt": 4,
    #              "sac_lag": 5,
    #              "sac_lam_exp": 6,
    #              "sac_lam_linear": 7,
    #              "sac_lam_opt": 8,
    #              "sac_lyapunov": 9,
    #              "sac_residual": 10,
    #              "sac_simplex": 11},

    # Disabled: explicit subplot grid position per environment.
    # "loc": {"cartpole": [0, 0],
    #         "glucose": [0, 1],
    #         "quadrotor": [0, 2],
    #         "quadruped": [1, 0],
    #         "Humanoid-v4": [1, 1]},
}

figure_name_list = ["all", "ab_autosafe", "ab_lam", "ab_cor"]

# Algorithms drawn in each figure. The position in the list selects the palette
# color and, in the bottom row, the x position of the marker.
_FIGURE_ALGORITHMS = {
    # main comparison against all baselines
    "all": ["sac_autosafe_opt", "sac_simplex", "sac_lam_linear", "sac_residual",
            "sac_lyapunov", "sac_lag", "sac_base"],
    # ablation on autosafe
    "ab_autosafe": ["sac_autosafe_opt", "sac_autosafe_exp", "sac_autosafe_linear"],
    # ablation on adaptive lam
    "ab_lam": ["sac_lam_opt", "sac_lam_exp", "sac_lam_linear"],
    # ablation on safety constraint
    "ab_cor": ["sac_autosafe_opt", "sac_autosafe_opt_cor"],
}


def _saved_tag(case, algo, data_tag):
    """Return the tag under which ``data_tag`` is stored in the run CSVs.

    The termination metric is logged as ``training/termination_ep`` for the
    quadrotor and quadruped cases and for the ``sac_autosafe_opt_cor`` runs;
    every other combination uses ``data_tag`` unchanged.
    """
    logged_per_episode = case in ("quadrotor", "quadruped") or algo == "sac_autosafe_opt_cor"
    if data_tag == "training/termination" and logged_per_episode:
        return "training/termination_ep"
    return data_tag


def _load_seed_curves(case, algo, run_names, data_tag, grid, span):
    """Load, smooth and resample every run of one algorithm on one case study.

    Each run is read from ``<case>/<algo>/<run name>.csv`` (columns ``tag``,
    ``step`` and ``value`` are used), restricted to the requested tag, smoothed
    with an exponentially weighted moving average of the given ``span`` and
    linearly interpolated onto ``grid``.

    Returns:
        DataFrame with one column per run and one row per grid point; it has
        no columns when ``run_names`` is empty.

    Raises:
        ValueError: if a CSV contains no rows for the requested tag.
    """
    tag = _saved_tag(case, algo, data_tag)
    curves = {}
    for idx, run_name in enumerate(run_names):
        csv_path = os.path.join(case, algo, run_name + ".csv")
        log = pd.read_csv(csv_path)
        rows = log[log["tag"] == tag]
        if rows.empty:
            # np.interp would fail here with an error that names neither the
            # file nor the tag, so report the actual problem instead.
            raise ValueError(f"no rows tagged {tag!r} in {csv_path}")
        smoothed = rows["value"].ewm(span=span).mean()
        # NOTE(review): np.interp expects increasing sample points; the
        # "step" column is assumed to be sorted -- confirm for new logs.
        curves[f"run_{idx + 1}"] = np.interp(grid, rows["step"], smoothed)
    return pd.DataFrame(curves)


def _mean_and_stderr(curves):
    """Return the per-row mean and standard error across the run columns.

    The standard deviation is pandas' default sample estimate, so with a
    single run the standard error is NaN.
    """
    mean = curves.mean(axis=1)
    stderr = curves.std(axis=1) / np.sqrt(curves.shape[1])
    return mean, stderr


def _plot_curve(ax, grid, curves, label, color, **line_kwargs):
    """Draw the mean curve over ``grid`` with a +/- one standard error band."""
    mean, stderr = _mean_and_stderr(curves)
    summary = pd.DataFrame({"Steps": grid, "Mean": mean})
    sns.lineplot(x="Steps", y="Mean", data=summary, ax=ax, label=label,
                 legend=False, color=color, linewidth=2, **line_kwargs)
    ax.fill_between(grid, mean - stderr, mean + stderr, alpha=0.2, color=color)


# One PDF per figure: a 2 x 4 grid with one column per case study. The top row
# shows the smoothed training return over episodes, the bottom row the final
# value of the termination metric per method with standard-error bars.
for figure_name in figure_name_list:
    if figure_name not in _FIGURE_ALGORITHMS:
        raise NotImplementedError(f"unknown figure name: {figure_name!r}")
    algorithm_list = _FIGURE_ALGORITHMS[figure_name]

    sns.set_style('whitegrid')
    palette = sns.color_palette()

    fig, axes = plt.subplots(2, 4, figsize=(18, 9))
    for raw, data_tag in enumerate(plot_list):
        for j, case in enumerate(case_study_list):
            ax = axes[raw, j]
            span = plot_config["smooth_span"][case]
            max_episode = plot_config["max_episode"][case]
            data_horizon = np.linspace(0, max_episode, plot_config["n_grid"][case])

            for m, algo in enumerate(algorithm_list):
                algo_label = plot_config["algo_labels"][algo]
                # In the main comparison the linear variant stands for AdaLam.
                if figure_name == 'all' and algo == 'sac_lam_linear':
                    algo_label = "AdaLam"

                curves = _load_seed_curves(case, algo, data_info[case][algo],
                                           data_tag, data_horizon, span)
                if curves.shape[1] == 0:
                    continue  # no runs registered for this algorithm

                color = palette[m]
                if raw == 0:
                    _plot_curve(ax, data_horizon, curves, algo_label, color)
                else:
                    # Only the value at the end of training is shown.
                    mean, stderr = _mean_and_stderr(curves)
                    ax.errorbar(
                        [m], mean.values[-1], yerr=stderr.values[-1], fmt='o', color=color,
                        capsize=6, markersize=8, elinewidth=2, label=algo_label
                    )

            # Axis cosmetics do not depend on the algorithm, so apply them once
            # per subplot (after plotting, so seaborn's default axis labels
            # are overwritten).
            ax.set_ylabel(plot_config["y_labels"][data_tag], fontsize=17)
            ax.set_title(plot_config["env_labels"][case], fontsize=15, fontweight="bold")
            ax.tick_params(labelsize=13)
            if raw == 0:
                ax.set_xlabel("Training Episodes", fontsize=17)
                ax.set_xlim([0, max_episode])
            else:
                ax.set_xlabel("Methods", fontsize=17)
                ticks = np.arange(len(algorithm_list))
                ax.set_xticks(ticks, ticks + 1)  # label methods 1..N

            # A fixed y-range is optional; without one matplotlib autoscales.
            y_lims = plot_config["y_limits"].get(data_tag, {}).get(case)
            if y_lims is not None:
                ax.set_ylim([y_lims[0], y_lims[1]])

    if figure_name == 'all':
        # Add the baseline performance of the model-based policy to the top row.
        baseline_algo = "model_based"
        baseline_color = palette[-1]
        for j, case in enumerate(case_study_list):
            span = plot_config["smooth_span"][case]
            data_horizon = np.linspace(0, plot_config["max_episode"][case],
                                       plot_config["n_grid"][case])
            curves = _load_seed_curves(case, baseline_algo, data_info[case][baseline_algo],
                                       "evaluation/episodic_return", data_horizon, span)
            if curves.shape[1] == 0:
                continue

            ax = axes[0, j]
            _plot_curve(ax, data_horizon, curves, plot_config["algo_labels"][baseline_algo],
                        baseline_color, linestyle='--')
            text_x, text_y = plot_config["text_loc"][case]
            ax.text(
                text_x, text_y,  # x, y in data coordinates
                "Safe Policy Prior",
                fontsize=9, color=baseline_color,  # text color matches the line
                bbox=dict(facecolor='white', edgecolor=baseline_color, boxstyle='round, pad=0.3')
            )

    num_methods = len(algorithm_list)
    fig.tight_layout()

    # One shared legend built from the first subplot; the baseline is left out
    # because it is annotated directly inside the plots.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    baseline_label = plot_config["algo_labels"]["model_based"]
    filtered = [(h, l) for h, l in zip(handles, labels) if l != baseline_label]
    if filtered:
        handles_to_show, labels_to_show = zip(*filtered)
        leg = axes[0, 0].legend(handles_to_show, labels_to_show, loc="lower center",
                                bbox_to_anchor=(2.5, -1.70), fancybox=True, shadow=True,
                                ncol=num_methods, fontsize=18)
        for line in leg.get_lines():
            line.set_linewidth(4.0)

    fig.subplots_adjust(bottom=0.15)
    fig.savefig(f"{figure_name}.pdf", dpi=300)
    # Release the figure; otherwise every iteration leaves one open in pyplot.
    plt.close(fig)
