import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

from copy import deepcopy

def plot_regret(base_dir, label, ax):
    """Plot the mean cumulative regret across seeds with a confidence band.

    Every entry of ``base_dir`` is treated as one seed directory that should
    contain a ``results_dict.pkl`` with a ``"cum_regret"`` entry. The first
    element of each curve is replaced by 0 before the curves are averaged.
    Seeds without the pickle are reported on stdout and skipped; a missing
    ``base_dir`` is reported and handled like a directory with no usable seed.

    Args:
        base_dir: Directory holding one sub-directory per seed.
        label: Legend label attached to the mean curve.
        ax: Axes-like object providing ``plot``, ``fill_between``,
            ``set_xlabel`` and ``set_ylabel`` (e.g. a matplotlib ``Axes``).

    Returns:
        None. The curve is drawn on ``ax`` and a summary line is printed.
    """
    if os.path.isdir(base_dir):
        seed_names = sorted(os.listdir(base_dir))
    else:
        # Treat a missing run directory like a run with no seeds instead of
        # aborting the whole figure.
        print(f"{base_dir} not found")
        seed_names = []

    all_regret = []
    for seed in seed_names:
        path = os.path.join(base_dir, seed, "results_dict.pkl")

        if not os.path.exists(path):  # results_dict cannot be found
            print(f"{path} not found")
            continue

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files produced by trusted runs.
        with open(path, "rb") as f:
            results_dict = pickle.load(f)
        # assumes "cum_regret" is a list; the first entry is forced to 0
        all_regret.append([0] + results_dict["cum_regret"][1:])

    if not all_regret:
        # Draw empty artists so this algorithm still occupies a slot on the
        # axes (presumably to keep the colour cycle aligned - TODO confirm).
        ax.plot([], [])
        ax.fill_between([], [], [])
        return

    # assumes every seed produced a curve of the same length
    all_regret = np.array(all_regret)

    mean_reg = all_regret.mean(axis=0)
    # Standard error of the mean across seeds.
    std_reg = all_regret.std(axis=0) / np.sqrt(len(all_regret))

    # Share of steps not counted as regret at the end of the run, in percent.
    print(base_dir, (len(mean_reg) - mean_reg[-1]) * 100 / len(mean_reg))

    ax.plot(mean_reg, label=label, alpha=0.8)
    # 1.96 standard errors on each side: approximate 95% normal interval.
    ax.fill_between(np.arange(len(mean_reg)), mean_reg - 1.96 * std_reg, mean_reg + 1.96 * std_reg, alpha=0.2)
    ax.set_xlabel("Number of steps")
    ax.set_ylabel("Cumulative regret")

def plot_time(base_dir, label, ax):
    """Plot test progress (%) against mean elapsed time across seeds.

    Every entry of ``base_dir`` is treated as one seed directory that should
    contain a ``results_dict.pkl`` with a ``"time_steps"`` entry. A leading 0
    is prepended to each series before the series are averaged. Seeds without
    the pickle are reported on stdout and skipped; a missing ``base_dir`` is
    reported and handled like a directory with no usable seed.

    Args:
        base_dir: Directory holding one sub-directory per seed.
        label: Legend label attached to the mean curve.
        ax: Axes-like object providing ``plot``, ``fill_between``,
            ``fill_betweenx``, ``set_xlabel`` and ``set_ylabel``
            (e.g. a matplotlib ``Axes``).

    Returns:
        None. The curve is drawn on ``ax``.
    """
    if os.path.isdir(base_dir):
        seed_names = sorted(os.listdir(base_dir))
    else:
        # Treat a missing run directory like a run with no seeds instead of
        # aborting the whole figure.
        print(f"{base_dir} not found")
        seed_names = []

    all_times = []
    for seed in seed_names:
        path = os.path.join(base_dir, seed, "results_dict.pkl")

        if not os.path.exists(path):  # results_dict cannot be found
            print(f"{path} not found")
            continue

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files produced by trusted runs.
        with open(path, "rb") as f:
            results_dict = pickle.load(f)
        # assumes "time_steps" is a list of elapsed times - TODO confirm units
        all_times.append([0] + results_dict["time_steps"])

    if not all_times:
        # Draw empty artists so this algorithm still occupies a slot on the
        # axes (presumably to keep the colour cycle aligned - TODO confirm).
        ax.plot([], [])
        ax.fill_between([], [], [])
        return

    # assumes every seed produced a series of the same length
    all_times = np.array(all_times)

    mean_times = all_times.mean(axis=0)
    # Standard error of the mean across seeds.
    std_times = all_times.std(axis=0) / np.sqrt(len(all_times))

    # Progress is spread evenly from 0% to 100% over the recorded entries.
    progress = np.linspace(0, 1, num=len(mean_times)) * 100

    ax.plot(mean_times, progress, label=label, alpha=0.8)
    # Time is on the x-axis, so the band is horizontal: 1.96 standard errors
    # on each side (approximate 95% normal interval).
    ax.fill_betweenx(progress, mean_times - 1.96 * std_times, mean_times + 1.96 * std_times, alpha=0.2)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Test progress (%)")


def separate_results(base_dir, datasets, alg_list, label_list, kind, subplot_format, figsize=(9, 5), titles=None):
    """Draw one subplot per dataset comparing several algorithms.

    For every dataset, each algorithm's results are read from
    ``<base_dir>/<dataset>_<alg>`` and drawn with ``plot_regret`` (when
    ``kind == "regret"``) or ``plot_time`` (any other ``kind``).

    Args:
        base_dir: Root directory holding the ``<dataset>_<alg>`` folders.
        datasets: Dataset names; dataset ``i`` is drawn on the ``i``-th
            subplot in row-major order.
        alg_list: Algorithm names used in the folder names.
        label_list: Legend labels, paired positionally with ``alg_list``.
        kind: ``"regret"`` for regret curves, anything else for time curves.
        subplot_format: ``(nrows, ncols)`` of the subplot grid.
        figsize: Figure size forwarded to ``plt.subplots``.
        titles: Optional per-dataset titles; the dataset name is used when
            this is ``None``.

    Returns:
        None. The figure is left as the current matplotlib figure.
    """
    # squeeze=False always returns a 2-D array of axes, so the same indexing
    # works for any grid shape (1x1, 1xN, Nx1, NxM).
    _, axes = plt.subplots(*subplot_format, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()  # row-major: index i -> (i // ncols, i % ncols)
    single_axes = flat_axes.size == 1

    draw = plot_regret if kind == "regret" else plot_time

    for i, dataset in enumerate(datasets):
        # With a single subplot every dataset is drawn on that one axes.
        ax = flat_axes[0] if single_axes else flat_axes[i]
        for alg, label in zip(alg_list, label_list):
            draw(os.path.join(base_dir, f"{dataset}_{alg}"), label, ax)

        ax.set_title(dataset if titles is None else titles[i])
        ax.legend()
    plt.tight_layout()

def _load_seed_regrets(dir_path):
    """Return the cumulative-regret curve of every seed found under ``dir_path``."""
    if not os.path.isdir(dir_path):
        print(f"{dir_path} not found")
        return []

    curves = []
    for seed in sorted(os.listdir(dir_path)):
        path = os.path.join(dir_path, seed, "results_dict.pkl")

        if not os.path.exists(path):  # results_dict cannot be found
            print(f"{path} not found")
            continue

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files produced by trusted runs.
        with open(path, "rb") as f:
            results_dict = pickle.load(f)
        # assumes "cum_regret" is a list; the first entry is forced to 0
        curves.append([0] + results_dict["cum_regret"][1:])
    return curves


def get_improvement(base_dir, datasets, alg_list, c3_index=0):
    """Summarise per-algorithm accuracy and the gain of ``"c3"`` over its best rival.

    For each ``(dataset, alg)`` pair the seed curves under
    ``<base_dir>/<dataset>_<alg>`` are averaged and converted to a percentage
    ``(T - final_mean_regret) * 100 / T`` where ``T`` is the curve length.
    Pairs with no usable results are skipped.

    The printed improvement is the mean, over datasets, of the ``"c3"`` score
    minus the best score among the other algorithms. Datasets that lack a
    ``"c3"`` score or any competing score are left out of that mean.

    Args:
        base_dir: Root directory holding the ``<dataset>_<alg>`` folders.
        datasets: Dataset names.
        alg_list: Algorithm names; the reference algorithm is the one named
            ``"c3"``.
        c3_index: Currently unused; kept so existing callers keep working.

    Returns:
        A tuple ``(result, dataset_acc)`` where ``result`` maps each algorithm
        to its score averaged over the datasets it has results for (0 when it
        has none) and ``dataset_acc`` maps ``dataset -> {alg: score}``.
    """
    result = {alg: 0 for alg in alg_list}
    dataset_acc = {d: dict() for d in datasets}

    for alg in alg_list:
        overall_acc = []
        for dataset in datasets:
            curves = _load_seed_regrets(os.path.join(base_dir, f"{dataset}_{alg}"))
            if not curves:
                continue

            # assumes every seed produced a curve of the same length
            mean_reg = np.array(curves).mean(axis=0)
            acc = (len(mean_reg) - mean_reg[-1]) * 100 / len(mean_reg)

            overall_acc.append(acc)
            dataset_acc[dataset][alg] = acc

        result[alg] = sum(overall_acc) / len(overall_acc) if overall_acc else 0

    gaps = []
    for dataset in datasets:
        scores = dataset_acc[dataset]
        rivals = [acc for alg, acc in scores.items() if alg != "c3"]
        if "c3" not in scores or not rivals:
            # Without both sides there is nothing to compare on this dataset.
            print(f"Skipping {dataset} in the improvement average: missing results")
            continue
        gaps.append(scores["c3"] - max(rivals))

    # NaN signals that no dataset could be compared at all.
    improvement = sum(gaps) / len(gaps) if gaps else float("nan")
    print(f"Improvement of C3 over the best/next best algorithm, averaged over datasets is {improvement}")

    return result, dataset_acc

if __name__ == "__main__":
    # Earlier figure-generation invocations are kept below, commented out,
    # as usage examples for separate_results.

    # base_dir = "/scratch/ssd004/scratch/wmloh/ProjectionBeta/old_results_2"

    # --- Regret curves on the four bandit datasets ---
    # base_dir = "results"
    # separate_results(base_dir, ["shuttle", "magic", "cover", "mnist"],
    #                             ["c3", "linucb", "neuralucb", "lts", "neuralts", "squarecb"],
    #                             ["$C_3$", "LinUCB", "NeuralUCB", "LinTS", "NeuralTS", "SquareCB"],
    #                             "regret", (2, 2),
    #                             titles=["shuttle", "MagicTelescope", "covertype", "MNIST"])
    # plt.savefig(os.path.join("figures", "bandit_regret_new.png"), dpi=200)

    # --- Timing curves on the four bandit datasets ---
    # separate_results("results", ["shuttle", "magic", "covertype", "mnist"],
    #                             ["c3", "lucb", "nucb", "lts", "nts"],
    #                             ["$C_3$", "LinUCB", "NeuralUCB", "LinTS", "NeuralTS"],
    #                             "time", (2, 2),
    #                             titles=["shuttle", "MagicTelescope", "covertype", "MNIST"])
    # plt.savefig(os.path.join("figures", "bandit_time.png"), dpi=100)

    # --- Regret curves on MIND ---
    # base_dir = "results"
    # separate_results(base_dir, ["mind"],
    #                             ["c3", "bayeslr", "tt_small", "tt_med"],
    #                             ["$C_3$", "BayesLR", "Two Tower (small)", "Two Tower (large)"],
    #                             "regret", (1, 1), titles=["MIND"], figsize=(4, 4))
    # plt.savefig(os.path.join("figures", "MIND_regret_fixed.png"), dpi=100)

    # Active run: print the per-algorithm and per-dataset summary returned by
    # get_improvement for the bandit benchmarks.
    bandit_datasets = ["shuttle", "magic", "cover", "mnist"]
    bandit_algs = ["c3", "linucb", "neuralucb", "lts", "neuralts", "squarecb"]
    summary = get_improvement("results", bandit_datasets, bandit_algs)
    print(summary)