import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

from copy import deepcopy

def _load_mind_regret_stats(base_dir, algs):
    """Load per-seed cumulative regret and reduce it to mean / standard error.

    For every algorithm ``alg`` the directory ``<base_dir>/mind_<alg>`` is
    listed and each entry is treated as a seed directory expected to hold a
    ``results_dict.pkl`` file with a ``"cum_regret"`` entry.

    Args:
        base_dir: Root directory containing one ``mind_<alg>`` folder per
            algorithm.
        algs: Algorithm name suffixes to load.

    Returns:
        A ``(means, sems)`` pair of arrays with one row per algorithm and one
        column per step.

    Raises:
        ValueError: If ``algs`` is empty, an algorithm has no readable
            results, or the algorithms do not share the same horizon.
    """
    if not algs:
        raise ValueError("no algorithms to plot")

    means, sems = [], []
    for alg in algs:
        path = os.path.join(base_dir, f"mind_{alg}")

        seed_curves = []
        for seed in sorted(os.listdir(path)):
            pkl_path = os.path.join(path, seed, "results_dict.pkl")

            # Check the pickle itself: `path` was just listed, so testing it
            # would never detect a seed directory with missing results.
            if not os.path.exists(pkl_path):
                print(f"{pkl_path} not found")
                continue

            # NOTE: pickle.load can execute arbitrary code; only point
            # base_dir at result files you produced yourself.
            with open(pkl_path, "rb") as f:
                results_dict = pickle.load(f)
            # The first entry is dropped (presumably a step-0 placeholder;
            # confirm against the code that writes results_dict).
            seed_curves.append(results_dict["cum_regret"][1:])

        if not seed_curves:
            raise ValueError(f"no results_dict.pkl found under {path}")

        seed_curves = np.array(seed_curves)
        mean_reg = seed_curves.mean(axis=0)
        # Standard error of the mean across seeds.
        std_reg = seed_curves.std(axis=0) / np.sqrt(len(seed_curves))

        means.append(mean_reg)
        sems.append(std_reg)

        # (horizon - final mean regret) as a percentage of the horizon.
        print(path, (len(mean_reg) - mean_reg[-1]) * 100 / len(mean_reg))

    horizons = sorted({len(m) for m in means})
    if len(horizons) > 1:
        raise ValueError(
            f"algorithms have different horizons {horizons}; cannot compare them step by step"
        )

    return np.stack(means).astype(float), np.stack(sems).astype(float)


def plot_mind_cumregret(base_dir, alg_list, label_list, figsize=(9, 5),
                        titles=None, conf=1):
    """Plot baseline-adjusted mean cumulative regret for several algorithms.

    The mean cumulative regret of each algorithm is converted to per-step
    increments (the first increment is set to 0), the smallest increment
    across algorithms is subtracted at every step, and the result is
    accumulated again. Each curve therefore shows regret relative to the
    best algorithm at each step. The shaded band is ``conf`` times the
    standard error across seeds of the *raw* cumulative regret; it is not
    re-baselined.

    Args:
        base_dir: Root directory containing one ``mind_<alg>`` folder per
            algorithm, each with one sub-directory per seed.
        alg_list: Algorithm name suffixes (``mind_<alg>``).
        label_list: Legend labels, paired positionally with ``alg_list``.
        figsize: Figure size passed to ``plt.subplots``.
        titles: Axes title; ``"MIND"`` is used when ``None``.
        conf: Multiplier applied to the standard error for the shaded band.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.

    Raises:
        ValueError: If no algorithm is given, an algorithm has no readable
            results, or the algorithms have different horizons.
    """
    pairs = list(zip(alg_list, label_list))

    # Load everything before creating the figure so that a loading failure
    # does not leave an empty figure behind.
    all_mean_regrets, all_std_regrets = _load_mind_regret_stats(
        base_dir, [alg for alg, _ in pairs]
    )

    # Per-step increments, with a leading zero column so curves start at 0.
    diff_reg = np.concatenate(
        (np.zeros((len(all_mean_regrets), 1)), np.diff(all_mean_regrets)), axis=-1
    )
    # Remove the per-step increment of the best algorithm from every curve.
    diff_reg -= diff_reg.min(axis=0).reshape(1, -1)

    mean_filter = np.cumsum(diff_reg, axis=-1)
    std_filter = all_std_regrets

    print(mean_filter[:, -1])

    fig, ax = plt.subplots(figsize=figsize)
    steps = np.arange(mean_filter.shape[1])

    for (_, label), mean_reg, std_reg in zip(pairs, mean_filter, std_filter):
        (line,) = ax.plot(mean_reg, label=label, alpha=0.8)
        # Tie the band colour to its line explicitly rather than relying on
        # the line and patch colour cycles staying in step.
        ax.fill_between(steps, mean_reg - conf * std_reg, mean_reg + conf * std_reg,
                        facecolor=line.get_color(), alpha=0.2)

    ax.set_xlabel("Number of steps")
    ax.set_ylabel("Cumulative regret")
    ax.set_title("MIND" if titles is None else titles)
    ax.legend()
    fig.tight_layout()

    return fig, ax

# Script entry point: render the MIND regret comparison and save it as a PNG.
if __name__ == "__main__":
    base_dir = "/scratch/ssd004/scratch/wmloh/ProjectionBeta/old_results_2"
    plot_mind_cumregret(base_dir, ["c3", "bayeslr", "tt_small", "tt_med"], 
                                ["$C_3$", "BayesLR", "Two Tower (small)", "Two Tower (large)"], 
                                titles="MIND", figsize=(4, 3.5))

    # savefig does not create missing directories, so make sure the output
    # folder exists instead of failing after all the data has been loaded.
    out_dir = "figures"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, "MIND_regret_fixed_new.png"), dpi=200, pad_inches=0, bbox_inches='tight')