import pickle
import argparse
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
from bandit.best_reward import SuccessiveEliminationBME
from exp_bandit_core import run, get_game
from ex_ante_amd.simple_game import simple_single_item_matching, SimpleGame


# Output directories, relative to the current working directory. They are
# created (if missing) as soon as this module is imported: ``results/`` holds
# the pickled evaluation counts and ``figs/`` holds the generated figures.
result_path, figure_path = Path("results"), Path("figs")
for _output_dir in (result_path, figure_path):
    _output_dir.mkdir(exist_ok=True)
del _output_dir


# Line colours for successive curves; reused cyclically when there are more
# curves than colours.
_COLORS = ["r", "g", "b", "k"]


def size_sweep(limit, additive=False):
    """Return the problem sizes to sweep, starting at 4 and not exceeding *limit*.

    Args:
        limit: Largest size allowed in the sweep (inclusive).
        additive: If true, sizes grow in steps of 4 (4, 8, 12, ...);
            otherwise they double (4, 8, 16, ...).

    Returns:
        The list of sizes. It is empty when *limit* is below 4.
    """
    sizes = []
    n = 4
    while n <= limit:
        sizes.append(n)
        n = n + 4 if additive else n * 2
    return sizes


def _parse_args():
    """Parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser(description='number of evaluations')
    parser.add_argument("-p", "--players", type=int, required=True)
    parser.add_argument("-t", "--types", type=int, required=True)
    parser.add_argument("-e", '--epsilon', type=float, default=0.25)
    parser.add_argument("-d", "--delta", type=float, default=0.1)
    parser.add_argument('--rotate', action='store_true')
    parser.add_argument(
        '--save-game', action='store_true',
        help="also pickle every SimpleGame instance into the results directory",
    )
    args = parser.parse_args()
    # Both sweeps start at 4; a smaller limit would leave nothing to run or
    # plot, so fail early with a usage message instead of crashing later.
    if args.players < 4:
        parser.error("--players must be at least 4")
    if args.types < 4:
        parser.error("--types must be at least 4")
    return args


def _run_experiments(n_players_list, n_types_list, args, seed=42):
    """Run the bandit for every (n_players, n_types) pair and count evaluations.

    For each pair a fresh ``SimpleGame`` (seeded with *seed*) is built and
    ``run(..., rerun=True)`` is called once per player index.

    Returns:
        ``(n_evaluation, n_unique_eval)``: two ``defaultdict(int)`` keyed by
        ``(n_players, n_types)``. ``n_evaluation`` sums
        ``alg.get_total_n_sample()`` over all player indices, and
        ``n_unique_eval`` is ``game.get_n_evaluated()`` once every player
        index has run.
    """
    n_evaluation = defaultdict(int)
    n_unique_eval = defaultdict(int)
    for n_types in n_types_list:
        type_space = np.arange(-n_types, n_types + 1)
        for n_players in n_players_list:
            game = SimpleGame(
                n_players=n_players,
                n_types=n_types,
                type_space=type_space,
                single_item_matching=simple_single_item_matching,
                rng=np.random.RandomState(seed),
            )
            key = (n_players, n_types)
            for p_idx in range(n_players):
                alg = run(game, n_players, n_types, p_idx, args.epsilon, args.delta, rerun=True)
                n_evaluation[key] += alg.get_total_n_sample()
            n_unique_eval[key] = game.get_n_evaluated()

            if args.save_game:
                filename = f"bandit_{n_players}_{n_types}_{seed}_game.pkl"
                with open(result_path.joinpath(filename), mode='wb') as fo:
                    pickle.dump(game, fo)
    return n_evaluation, n_unique_eval


def _save_counts(args, n_evaluation, n_unique_eval):
    """Pickle both count dictionaries into the results directory."""
    suffix = f"{args.players}_{args.types}_{args.epsilon}_{args.delta}_{args.rotate}"
    for prefix, counts in (("n_evaluation", n_evaluation), ("n_unique_eval", n_unique_eval)):
        with open(result_path.joinpath(f"{prefix}_{suffix}.pkl"), mode='wb') as fo:
            pickle.dump(counts, fo)


def _plot_trend(args, n_players_list, n_types_list, counts):
    """Plot *counts* (keyed by ``(n_players, n_types)``) and save it as a PDF.

    With ``--rotate`` the x axis is the number of types (linear scale) and
    there is one curve per player count. Otherwise the x axis is the number
    of players (log scale) and there is one curve per type count. The y axis
    is always log scale. Each solid curve is paired with a dashed reference
    curve of ``n_types ** n_players`` in the same colour.
    """
    plt.rcParams["font.size"] = 16
    fig = plt.figure(figsize=(4, 3))
    ax = plt.axes()
    if args.rotate:
        x = n_types_list
        curves = [
            (f"{n_players} players", [(n_players, n_types) for n_types in x])
            for n_players in n_players_list
        ]
        xlabel = "Number of types"
        filename = f"bandit_trend_type_{args.players}_{args.types}_{args.epsilon}_{args.delta}.pdf"
    else:
        ax.set_xscale("log")
        x = n_players_list
        curves = [
            (f"{n_types} types", [(n_players, n_types) for n_players in x])
            for n_types in n_types_list
        ]
        xlabel = "Number of players"
        filename = f"bandit_trend_player_{args.players}_{args.types}_{args.epsilon}_{args.delta}.pdf"
    ax.set_yscale("log")

    for i, (label, keys) in enumerate(curves):
        # Cycle colours so more than len(_COLORS) curves do not raise IndexError.
        color = _COLORS[i % len(_COLORS)]
        reference = [n_types ** n_players for n_players, n_types in keys]
        ax.plot(x, reference, linestyle="--", marker=".", color=color)
        ax.plot(x, [counts[key] for key in keys], linestyle="-", marker=".", color=color, label=label)

    ax.set_xlabel(xlabel)
    ax.set_ylabel("Number of evaluations")
    ax.legend()
    ax.set_xlim([min(x), max(x)])
    fig.savefig(figure_path.joinpath(filename), bbox_inches="tight")
    plt.close(fig)


def main():
    """Run the evaluation-count sweep, pickle the counts and plot the trend.

    Only the unique-evaluation counts are plotted; the total sample counts
    are pickled alongside them in the results directory.
    """
    args = _parse_args()
    n_types_list = size_sweep(args.types, additive=args.rotate)
    n_players_list = size_sweep(args.players)
    n_evaluation, n_unique_eval = _run_experiments(n_players_list, n_types_list, args)
    _save_counts(args, n_evaluation, n_unique_eval)
    _plot_trend(args, n_players_list, n_types_list, n_unique_eval)


if __name__ == "__main__":
    main()
