import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os


# Per-case run index, used below as data_info[case][algo] -> list of run names;
# each run name is read from "<case>/<algo>/<run name>.csv".
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale-dependent default encoding.
with open("data_info.json", encoding="utf-8") as f:
    data_info = json.load(f)

# Environments shown as the columns of each figure, in left-to-right order.
case_study_list = [
    "cartpole",
    "glucose",
    "quadrotor",
    "quadruped",
]

# Every algorithm key that can be plotted has a legend label in
# plot_config["algo_labels"]; each figure selects its own subset.

# Logged scalar tags shown as the figure rows, each paired with a short display
# name. zip transposes the (tag, name) pairs into the two parallel lists.
plot_list, plot_name_list = (
    list(column)
    for column in zip(
        ("agent/log_pi", "Log Prob."),
        ("agent/tem", "Evolution of T"),
    )
)


# Plot settings read by the figure loop. Per-environment tables are keyed by
# the environment names used in case_study_list.
plot_config = {
    # Span of the exponentially weighted moving average applied to each run;
    # a larger span gives stronger smoothing.
    "smooth_span": dict.fromkeys(("cartpole", "glucose", "quadrotor", "quadruped"), 5),
    # Legend label for each algorithm key.
    "algo_labels": {
        "model_based": "Safe Policy",
        "sac_base": "SAC",
        "sac_autosafe_exp": "AutoSafe (exp)",
        "sac_autosafe_linear": "AutoSafe (linear)",
        "sac_autosafe_opt": "AutoSafe",
        "sac_lag": "Lagrangian",
        "sac_lam_exp": "AdaLam (exp)",
        "sac_lam_linear": "AdaLam (linear)",
        "sac_lam_opt": "AdaLam",
        "sac_lyapunov": "Lyapunov",
        "sac_residual": "Residual",
        "sac_simplex": "Simplex",
        "sac_autosafe_opt_cor": "AutoSafe (with corr.)",
    },
    # Subplot title for each environment.
    "env_labels": {
        "cartpole": "Cartpole",
        "glucose": "Glucose",
        "quadrotor": "3D Quadrotor",
        "quadruped": "Quadruped Navigation",
    },
    # y-axis label for each plotted tag.
    "y_labels": {
        "agent/log_pi": "Log Prob.",
        "agent/tem": "Temperature $T$",
    },
    # Optional fixed y-axis range per tag and environment. Glucose has no
    # entry here (its former [-1000, -50] range is disabled).
    "y_limits": {
        "training/episodic_return_ep": {
            "cartpole": [0, 500],
            "quadrotor": [-10, 300],
            "quadruped": [-100, 6000],
        },
    },
    # Right end of the x-axis and of the interpolation grid, in training steps.
    "max_steps": {
        "cartpole": 200000,
        "glucose": 100000,
        "quadrotor": 1000000,
        "quadruped": 500000,
    },
    # Number of evenly spaced grid points each run is resampled onto.
    "n_grid": dict.fromkeys(("cartpole", "glucose", "quadrotor", "quadruped"), 200),
}

figure_name_list = ["log_pi_tem"]

# Algorithms compared in each figure. An algorithm's position in its list
# selects its palette colour; "model_based" entries are skipped when plotting
# but still occupy a colour slot.
figure_algorithms = {
    "log_pi_tem": ["sac_autosafe_opt", "sac_autosafe_opt_cor"],
}


def _resolve_saved_tag(case, data_tag):
    """Map a requested tag to the tag name stored in the CSV logs.

    For quadrotor and quadruped, "training/termination" is read from the
    "training/termination_ep" rows; every other (case, tag) pair is unchanged.
    """
    if case in ("quadrotor", "quadruped") and data_tag == "training/termination":
        return "training/termination_ep"
    return data_tag


def _load_seed_curve(csv_path, saved_tag, span, data_horizon, negate=False):
    """Load one run's scalar series, smooth it and resample it onto a grid.

    Args:
        csv_path: CSV file with "tag", "step" and "value" columns.
        saved_tag: only rows whose "tag" equals this value are used.
        span: span of the exponentially weighted moving average.
        data_horizon: step values to interpolate the smoothed series onto.
        negate: flip the sign of the values before smoothing.

    Returns:
        NumPy array with one interpolated value per entry of data_horizon.

    Raises:
        ValueError: if the file has no rows tagged ``saved_tag``.
    """
    run_df = pd.read_csv(csv_path)
    rows = run_df[run_df["tag"] == saved_tag]
    if rows.empty:
        # np.interp would otherwise fail with a less descriptive error.
        raise ValueError(f"no rows with tag {saved_tag!r} in {csv_path}")
    values = -rows["value"] if negate else rows["value"]
    smoothed = values.ewm(span=span).mean()
    # NOTE(review): np.interp assumes the logged steps are increasing; confirm
    # the logger never writes out-of-order steps.
    return np.interp(data_horizon, rows["step"], smoothed)


def _mean_and_stderr(curves):
    """Return the per-step mean and standard error of the mean across runs.

    The standard deviation is pandas' sample estimate (ddof=1), so a single
    run produces a NaN error band.
    """
    stacked = pd.DataFrame({f"run_{i}": curve for i, curve in enumerate(curves)})
    mean = stacked.mean(axis=1)
    stderr = stacked.std(axis=1) / np.sqrt(stacked.shape[1])
    return mean, stderr


def _format_axis(ax, data_tag, case, max_episode):
    """Apply the shared labels, tick style and axis limits to one subplot."""
    ax.set_xlabel("Training Steps", fontsize=17)
    ax.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))
    ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
    ax.set_ylabel(plot_config["y_labels"][data_tag], fontsize=17)
    ax.set_title(plot_config["env_labels"][case], fontsize=15, fontweight="bold")
    ax.tick_params(labelsize=13)
    ax.set_xlim([0, max_episode])
    # y-limits are optional: only tag/environment pairs listed in
    # plot_config["y_limits"] get a fixed range; the rest autoscale.
    y_lims = plot_config["y_limits"].get(data_tag, {}).get(case)
    if y_lims is not None:
        ax.set_ylim([y_lims[0], y_lims[1]])


sns.set_style("whitegrid")

for figure_name in figure_name_list:
    if figure_name not in figure_algorithms:
        raise NotImplementedError(f"figure {figure_name!r} is not configured")
    algorithm_list = figure_algorithms[figure_name]
    # n_colors cycles the current palette when there are more algorithms
    # than palette colours, so palette[m] never goes out of range.
    palette = sns.color_palette(n_colors=max(len(algorithm_list), 1))

    # One row per plotted tag, one column per environment.
    n_rows, n_cols = len(plot_list), len(case_study_list)
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4.5 * n_cols, 4.5 * n_rows), squeeze=False
    )
    for row, data_tag in enumerate(plot_list):
        for col, case in enumerate(case_study_list):
            ax = axes[row, col]
            span = plot_config["smooth_span"][case]
            max_episode = plot_config["max_steps"][case]
            data_horizon = np.linspace(0, max_episode, plot_config["n_grid"][case])
            saved_tag = _resolve_saved_tag(case, data_tag)

            for m, algo in enumerate(algorithm_list):
                if algo == "model_based":
                    continue
                curves = [
                    _load_seed_curve(
                        os.path.join(case, algo, data_name + ".csv"),
                        saved_tag,
                        span,
                        data_horizon,
                        # Actor loss is plotted with its sign flipped.
                        negate=data_tag == "agent/actor_loss",
                    )
                    for data_name in data_info[case][algo]
                ]
                if not curves:
                    continue
                mean, stderr = _mean_and_stderr(curves)
                sns.lineplot(
                    x="Steps",
                    y="Mean",
                    data=pd.DataFrame({"Steps": data_horizon, "Mean": mean}),
                    ax=ax,
                    label=plot_config["algo_labels"][algo],
                    legend=False,
                    color=palette[m],
                    linewidth=2,
                )
                # Shaded band: mean +/- one standard error across runs.
                ax.fill_between(
                    data_horizon, mean - stderr, mean + stderr,
                    alpha=0.2, color=palette[m],
                )

            # Formatting runs after plotting so it overrides the axis labels
            # seaborn's lineplot assigns.
            _format_axis(ax, data_tag, case, max_episode)

    fig.tight_layout()
    n_legend_entries = max(sum(algo != "model_based" for algo in algorithm_list), 1)
    # NOTE(review): this anchor offset appears tuned for the 2x4 grid;
    # re-check legend placement if the grid shape changes.
    leg = axes[0, 0].legend(
        loc="lower center",
        bbox_to_anchor=(2.5, -1.75),
        fancybox=True,
        shadow=True,
        ncol=n_legend_entries,
        fontsize=18,
    )
    # Draw legend lines thicker than the plotted curves.
    for line in leg.get_lines():
        line.set_linewidth(4.0)
    fig.subplots_adjust(bottom=0.15)
    fig.savefig(f"{figure_name}.pdf", dpi=300)
    # Release the figure so repeated figures do not accumulate in memory.
    plt.close(fig)
