import argparse
import pickle
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from bandit.mab import *
from bandit.best_arm import *
from bandit.best_reward import *


def run(epsilon, delta, repeat=10, ns=tuple(2**n for n in range(1, 11)), seed=42):
    """Run successive elimination (best arm vs. best mean reward) and pickle results.

    For every arm count ``n_arms`` in ``ns``, a ``BernoulliMAB`` is built whose
    means are ``(i + 0.5) / n_arms`` for ``i in range(n_arms)``, sorted in
    decreasing order.  ``SuccessiveElimination`` ("Best Arm") and
    ``SuccessiveEliminationBME`` ("Best Reward") are each instantiated once per
    arm count and their ``run`` method is called ``repeat`` times.  After every
    call the return value of ``run`` and ``get_total_n_sample()`` are recorded.

    Two pickles are written to ``results/`` (created if missing):

    * ``succelim_{epsilon}_{delta}_result_val.pkl`` -- return values of ``run``
    * ``succelim_{epsilon}_{delta}_result_num.pkl`` -- total sample counts

    each shaped ``{"Best Arm": {n_arms: [...]}, "Best Reward": {n_arms: [...]}}``
    (the inner mappings are ``defaultdict(list)``).

    Args:
        epsilon: Forwarded unchanged to both algorithm constructors
            (presumably the PAC accuracy parameter).
        delta: Forwarded unchanged to both algorithm constructors
            (presumably the PAC confidence parameter).
        repeat: Number of ``run`` calls per algorithm and arm count.
        ns: Iterable of arm counts to evaluate.
        seed: Seed for the environment's ``numpy`` random generator.
    """
    result_path = Path("results")
    result_path.mkdir(exist_ok=True)

    # Result key -> algorithm class.  Both algorithms share the exact same
    # experimental setup, so they are driven by a single loop below.
    algorithms = (
        ("Best Arm", SuccessiveElimination),
        ("Best Reward", SuccessiveEliminationBME),
    )
    result_val = {key: defaultdict(list) for key, _ in algorithms}
    result_num = {key: defaultdict(list) for key, _ in algorithms}

    for n_arms in ns:
        means = sorted((np.arange(n_arms) + 0.5) / n_arms, reverse=True)

        for key, algorithm_cls in algorithms:
            # Re-seed per algorithm so each one starts from an identically
            # seeded environment, keeping the comparison fair.
            env = BernoulliMAB(means, rng=np.random.default_rng(seed))
            algorithm = algorithm_cls(epsilon, delta)
            for _ in range(repeat):
                result_val[key][n_arms].append(algorithm.run(env))
                result_num[key][n_arms].append(algorithm.get_total_n_sample())

    outputs = (
        (f"succelim_{epsilon}_{delta}_result_val.pkl", result_val),
        (f"succelim_{epsilon}_{delta}_result_num.pkl", result_num),
    )
    for filename, result in outputs:
        with open(result_path / filename, mode="wb") as fo:
            pickle.dump(result, fo)


def plot(epsilon, delta, ns=tuple(2**n for n in range(1, 11))):
    """Plot total sample size against number of arms from saved results.

    Loads ``results/succelim_{epsilon}_{delta}_result_num.pkl`` (the sample
    counts written by ``run``).  For each method it draws the mean total sample
    size over the recorded runs, with a one-standard-deviation error bar, on
    log-log axes.  The figure is saved to
    ``figs/succelim_{epsilon}_{delta}.pdf`` (``figs/`` is created if missing)
    and then closed.

    Args:
        epsilon: Value used in the result/figure file names.
        delta: Value used in the result/figure file names.
        ns: Arm counts to plot; each must be present in the saved results.

    Raises:
        FileNotFoundError: If the result pickle does not exist.
    """
    figure_path = Path("figs")
    figure_path.mkdir(exist_ok=True)

    result_file = Path("results") / f"succelim_{epsilon}_{delta}_result_num.pkl"
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # you produced yourself.
    with open(result_file, mode="rb") as fi:
        result_num = pickle.load(fi)

    styles = {
        "Best Arm": {"label": "Arm", "color": "k", "linestyle": "--"},
        "Best Reward": {"label": "Mean", "color": "r", "linestyle": "-"},
    }

    # Scope the font-size change to this figure instead of mutating the
    # process-wide rcParams.
    with plt.rc_context({"font.size": 16}):
        fig, ax = plt.subplots(figsize=(4, 3))
        try:
            ax.set_xscale("log")
            ax.set_yscale("log")
            for method, samples in result_num.items():
                counts = [samples[n] for n in ns]
                ax.errorbar(
                    ns,
                    [np.mean(c) for c in counts],
                    yerr=[np.std(c) for c in counts],
                    marker=".",
                    **styles[method],
                )
            ax.legend(loc=4)
            ax.set_xlim([min(ns), max(ns)])
            ax.set_ylim([100, 10**7])
            ax.set_xlabel("Number of arms")
            ax.set_ylabel("Total sample size")
            fig.savefig(
                figure_path / f"succelim_{epsilon}_{delta}.pdf",
                bbox_inches="tight",
            )
        finally:
            # Release the figure; pyplot otherwise keeps it alive indefinitely.
            plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: optionally run the experiment and/or plot it.
    parser = argparse.ArgumentParser(description='best arm identification vs. best mean-reward estimation')
    parser.add_argument("-e", '--epsilon', type=float, help="epsilon passed to the algorithms")
    parser.add_argument("-d", "--delta", type=float, help="delta passed to the algorithms")
    parser.add_argument('--run', action='store_true', help="run the experiment and save results")
    parser.add_argument('--plot', action='store_true', help="plot previously saved results")
    args = parser.parse_args()

    # The flags are optional at parse time, but both actions need them; fail
    # fast instead of silently producing "succelim_None_None" files.
    if (args.run or args.plot) and (args.epsilon is None or args.delta is None):
        parser.error("--epsilon and --delta are required with --run/--plot")

    if args.run:
        run(args.epsilon, args.delta)

    if args.plot:
        plot(args.epsilon, args.delta)
