import os
import argparse
import pickle
import numpy as np
import matplotlib.pyplot as plt


def aggregate_result(base_dir, pkl_name="results_dict.pkl", reg_key="cum_regret", time_key="time_steps"):
    """Aggregate regret and timing curves over the seed sub-directories of ``base_dir``.

    Every entry of ``base_dir`` is treated as one seed. The pickle at
    ``base_dir/<seed>/<pkl_name>`` is loaded when it exists; entries without
    such a file are skipped.

    Args:
        base_dir: Directory containing one sub-directory per seed.
        pkl_name: File name of the pickled results dictionary inside each seed
            directory.
        reg_key: Key of the regret sequence in the results dictionary. Its
            first element is replaced by 0.
        time_key: Key of the time sequence in the results dictionary. A 0 is
            prepended to it.

    Returns:
        A 4-tuple ``(regret_mean, regret_sem, time_mean, time_sem)`` of numpy
        arrays computed across seeds, where ``sem`` is the (population)
        standard deviation divided by the square root of the number of seeds.
        ``(None, None, None, None)`` is returned when ``base_dir`` is not a
        directory or when no seed contains a results pickle.
    """
    # A missing base directory is reported the same way as "no results found"
    # so that callers only need to handle a single failure mode.
    if not os.path.isdir(base_dir):
        return None, None, None, None

    all_cum_regrets = []
    all_times = []
    for seed in sorted(os.listdir(base_dir)):
        path = os.path.join(base_dir, seed, pkl_name)
        if not os.path.isfile(path):  # results_dict cannot be found
            continue

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files that come from a trusted source.
        with open(path, "rb") as f:
            results_dict = pickle.load(f)

        # list(...) makes the prepend work for lists, tuples and 1-D arrays
        # alike ("[0] + ndarray" would broadcast-add instead of concatenating).
        all_cum_regrets.append([0] + list(results_dict[reg_key])[1:])
        all_times.append([0] + list(results_dict[time_key]))

    if not all_cum_regrets:
        return None, None, None, None

    all_cum_regrets = np.array(all_cum_regrets)
    all_times = np.array(all_times)
    sqrt_n_seeds = np.sqrt(len(all_cum_regrets))

    return (
        all_cum_regrets.mean(axis=0),
        all_cum_regrets.std(axis=0) / sqrt_n_seeds,
        all_times.mean(axis=0),
        all_times.std(axis=0) / sqrt_n_seeds,
    )

def step_time_plots(base_dirs, label_names, title, pkl_name="results_dict.pkl", reg_key="cum_regret", time_key="time_steps", figsize=(10, 4)):
    """Draw cumulative regret vs. steps and progress vs. time for several runs.

    A new two-panel figure is created. For each ``(base_dir, label)`` pair the
    seeds are aggregated with ``aggregate_result``; the mean curve is drawn
    with a shaded band of +/- 1.96 times the returned spread.

    Args:
        base_dirs: Directories, each holding the per-seed results of one run.
        label_names: Legend label for each entry of ``base_dirs``.
        title: Title placed above both panels.
        pkl_name: Forwarded to ``aggregate_result``.
        reg_key: Forwarded to ``aggregate_result``.
        time_key: Forwarded to ``aggregate_result``.
        figsize: Figure size passed to ``plt.subplots``.
    """
    base_dirs = list(base_dirs)
    label_names = list(label_names)
    if len(base_dirs) != len(label_names):
        # zip() below stops at the shorter sequence; make the truncation visible.
        print(f"Warning: got {len(base_dirs)} base dirs but {len(label_names)} labels; extra entries are ignored.")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Axis labels do not depend on the data, so set them once up front; this
    # also keeps the panels labelled when no directory yields any results.
    ax1.set_xlabel("Number of steps")
    ax1.set_ylabel("Cumulative regret")
    ax2.set_xlabel("Time (s)")
    ax2.set_ylabel("Test progress (%)")

    for base_dir, label in zip(base_dirs, label_names):
        mean, std, time_mean, time_std = aggregate_result(base_dir, pkl_name=pkl_name, reg_key=reg_key, time_key=time_key)
        if mean is None:  # results_dict cannot be found
            print(f"{base_dir} does not have a results pickle file.")

            # empty plot lines to ensure the colour matches for subsequent iterations
            ax1.plot([], [])
            ax1.fill_between([], [], [])
            ax2.plot([], [])
            ax2.fill_between([], [], [])

            continue

        # Progress runs linearly from 0% to 100% over the recorded time points.
        progress = np.linspace(0, 1, num=len(time_mean)) * 100

        ax1.plot(mean, label=label, alpha=0.8)
        ax1.fill_between(np.arange(len(mean)), mean - 1.96 * std, mean + 1.96 * std, alpha=0.2)

        # Time is on the x-axis here, so the band is horizontal (fill_betweenx).
        ax2.plot(time_mean, progress, label=label, alpha=0.8)
        ax2.fill_betweenx(progress, time_mean - 1.96 * time_std, time_mean + 1.96 * time_std, alpha=0.2)

    fig.suptitle(title)

    # Only build a legend when something labelled was drawn; calling legend()
    # on an axes without labelled artists makes matplotlib emit a warning.
    for ax in (ax1, ax2):
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()


# Command-line entry point: aggregate the given result directories, plot them
# and write the figure to disk.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("base_dirs", help="List of directories where results are stored (separated by commas)", type=str)
    parser.add_argument("labels", help="Labels of each base dir (separated by commas)", type=str)
    parser.add_argument("title", help="Title for entire figure", type=str)
    parser.add_argument("save_dir", help="Path to save output figure", type=str)
    parser.add_argument("--pkl_name", help="Name of pickle file", default="results_dict.pkl", type=str)
    parser.add_argument("--reg_key", help="Dictionary key for regret", default="cum_regret", type=str)
    parser.add_argument("--time_key", help="Dictionary key for time", default="time_steps", type=str)

    args = parser.parse_args()

    # Strip whitespace so that "a, b" is read as the entries "a" and "b".
    base_dirs = [entry.strip() for entry in args.base_dirs.split(",")]
    labels = [entry.strip() for entry in args.labels.split(",")]

    # Every directory needs exactly one label; a mismatch would otherwise
    # silently drop the surplus entries from the figure.
    if len(base_dirs) != len(labels):
        parser.error(f"got {len(base_dirs)} base dirs but {len(labels)} labels; the counts must match")

    step_time_plots(base_dirs, labels, args.title,
                    pkl_name=args.pkl_name, reg_key=args.reg_key, time_key=args.time_key)

    # Despite its name, save_dir is the path of the output file; make sure the
    # directory it lives in exists before saving.
    output_parent = os.path.dirname(args.save_dir)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)
    plt.savefig(args.save_dir, dpi=100)