import argparse
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from exp_bandit_core import get_game, run


# Output directory for the PDF figures written by plot(). It is resolved
# relative to the current working directory and created at import time if
# it does not exist yet.
figure_path = Path("figs")
if not figure_path.is_dir():
    figure_path.mkdir(exist_ok=True)


def plot(n_players, n_types, player_idx, epsilon, delta):
    """Run the bandit experiment for one player and save a convergence plot.

    For every type index sampled by the algorithm, the figure shows two things:

    * the running mean of that type's (negated) samples against
      ``alg._n[t_idx]``, on a log x-axis;
    * a dotted horizontal reference line at
      ``game._get_conditional_value(player_idx, t_idx)``.

    The type with the lowest reference value is drawn in red. The others are
    drawn in grey shades ordered by reference value. The y-limits are fitted
    to the running means only, so reference lines outside that range are
    clipped. The running min/max of each curve is printed to stdout.

    The figure is saved as
    ``figs/bandit_{n_players}_{n_types}_{player_idx}_{epsilon}_{delta}.pdf``
    and then closed.

    Args:
        n_players: Number of players, forwarded to ``get_game`` and ``run``.
        n_types: Number of types, forwarded to ``get_game`` and ``run``.
        player_idx: Index of the player whose values are estimated.
        epsilon: Forwarded unchanged to ``run`` (presumably the accuracy
            parameter -- confirm against ``exp_bandit_core``).
        delta: Forwarded unchanged to ``run`` (presumably the confidence
            parameter -- confirm against ``exp_bandit_core``).
    """
    game = get_game(n_players, n_types)

    alg = run(game, n_players, n_types, player_idx, epsilon, delta)

    plt.rcParams["font.size"] = 16
    fig, ax = plt.subplots(figsize=(4, 3))
    try:
        ax.set_xscale("log")

        max_x = max(v[-1] for v in alg._n.values())

        # Key values and colors by type index rather than by list position, so
        # the lookups stay correct even if alg._sample's keys are not exactly
        # 0..k-1 in iteration order.
        t_indices = list(alg._sample)
        true_value = {
            t_idx: game._get_conditional_value(player_idx, t_idx)
            for t_idx in t_indices
        }
        ranks = np.argsort(np.argsort([true_value[t] for t in t_indices]))
        n_sampled = len(t_indices)
        # Rank 0 (lowest reference value) is red; the rest are grey levels
        # given as matplotlib grayscale strings in [0, 1).
        color = {
            t_idx: "r" if r == 0 else str(r / n_sampled)
            for t_idx, r in zip(t_indices, ranks)
        }

        max_y = -np.inf
        min_y = np.inf
        for t_idx in t_indices:
            cumsum = np.cumsum(alg._sample[t_idx])
            ave = -cumsum / np.arange(1, len(cumsum) + 1)
            print(t_idx, min(ave), max(ave))
            ax.semilogx(alg._n[t_idx], ave, label=t_idx, color=color[t_idx])
            ax.semilogx([1, max_x], [true_value[t_idx]] * 2,
                        linestyle=":", color=color[t_idx])
            max_y = max(max_y, max(ave))
            min_y = min(min_y, min(ave))

        # Pad each bound outward by 1% of its magnitude. Scaling by 0.99/1.01
        # only widens the range for positive values; the running means are
        # negated sums and may be negative, where scaling would clip them.
        lower = min_y - 0.01 * abs(min_y)
        upper = max_y + 0.01 * abs(max_y)
        if lower < upper:
            ax.set_ylim([lower, upper])
        # Otherwise (e.g. all-zero data) keep matplotlib's autoscaled limits
        # instead of setting identical bounds.

        ax.set_xlim([1, max_x])
        ax.set_xlabel("Sample size")
        ax.set_ylabel("Estimated value")
        filename = f"bandit_{n_players}_{n_types}_{player_idx}_{epsilon}_{delta}.pdf"
        fig.savefig(figure_path.joinpath(filename), bbox_inches="tight")
    finally:
        # plot() is called once per player; release the figure so repeated
        # calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='best arm identification vs. best mean-reward estimation')
    # --players and --types have no sensible default and are used directly in
    # range() and get_game(). Let argparse reject a missing value with a usage
    # message instead of failing later with a TypeError.
    parser.add_argument("-p", "--players", type=int, required=True,
                        help="number of players; every player index is processed")
    parser.add_argument("-t", "--types", type=int, required=True,
                        help="number of types")
    parser.add_argument("-e", '--epsilon', type=float, default=0.5,
                        help="epsilon forwarded to run() (default: %(default)s)")
    parser.add_argument("-d", "--delta", type=float, default=0.1,
                        help="delta forwarded to run() (default: %(default)s)")
    parser.add_argument('--run', action='store_true',
                        help="run the experiment for each player")
    parser.add_argument('--plot', action='store_true',
                        help="run the experiment and save a figure for each player")
    args = parser.parse_args()

    # Without either action flag the loop below would silently do nothing.
    if not (args.run or args.plot):
        parser.error("nothing to do: pass --run and/or --plot")

    for p_idx in range(args.players):

        if args.run:
            game = get_game(args.players, args.types)
            run(game, args.players, args.types, p_idx, args.epsilon, args.delta)

        if args.plot:
            # plot() builds the game and calls run() itself.
            plot(args.players, args.types, p_idx, args.epsilon, args.delta)
