import argparse
import pickle
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
from exp_bandit_core import get_game, save_learning_results, load_learning_results


# Output locations: pickled metrics are written under results/ and PDF
# figures under figs/. Both are created up front; an existing directory is
# not an error.
result_path = Path("results")
figure_path = Path("figs")
for _directory in (result_path, figure_path):
    _directory.mkdir(exist_ok=True)
del _directory


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='number of evaluations')
    parser.add_argument("-p", "--players", type=int)
    parser.add_argument("-t", "--types", type=int)
    parser.add_argument("-e", '--epsilon', type=float, default=0.25)
    parser.add_argument("-d", "--delta", type=float, default=0.1)
    parser.add_argument("--IR", action='store_true')
    args = parser.parse_args()

    n_players = args.players
    n_types = args.types
    epsilon = args.epsilon
    delta = args.delta

    # Run
    BB = defaultdict(list)
    IR = defaultdict(list)
    for seed in range(10):
        rng=np.random.RandomState(seed)
        game = get_game(n_players, n_types, seed=seed)
        #game = load_learning_results(game, n_players, n_types, epsilon, delta, seed, IR=args.IR)
    
        BB["exact"] += [game.get_budget_balance(IR=args.IR)]
        IR["exact"] += [game.get_individual_rationality(p_idx, IR=args.IR) for p_idx in range(n_players)]
        for rho in [0]:
            BB[rho] += [game.get_budget_balance(epsilon=epsilon, delta=delta, rho=rho, IR=args.IR)]
            IR[rho] += [game.get_individual_rationality(p_idx, epsilon=epsilon, delta=delta, rho=rho, IR=args.IR) for p_idx in range(n_players)]

        save_learning_results(game, n_players, n_types, epsilon, delta, seed, IR=args.IR)

    if args.IR:
        filename = f"BB_{n_players}_{n_types}_{epsilon}_{delta}_IRok.pkl"
    else:
        filename = f"BB_{n_players}_{n_types}_{epsilon}_{delta}_BBok.pkl"
    with open(result_path.joinpath(filename), mode='wb') as fo:
        pickle.dump(BB, fo)

    if args.IR:
        filename = f"IR_{n_players}_{n_types}_{epsilon}_{delta}_IRok.pkl"
    else:
        filename = f"IR_{n_players}_{n_types}_{epsilon}_{delta}_BBok.pkl"
    with open(result_path.joinpath(filename), mode='wb') as fo:
        pickle.dump(IR, fo)
    

    plt.rcParams["font.size"] = 16
    
    # Plot Budget Balance

    min_v = min([np.min(v) for v in BB.values()]) - 0.05
    max_v = max([np.max(v) for v in BB.values()]) + 0.05
    
    for rho in [0, 0.05]:
        fig = plt.figure(figsize=(3,3))
        ax = plt.axes()
        ax.plot([min_v, max_v], [min_v, max_v], linestyle="-", linewidth=1, color="gray")
        for exact, learn in zip(BB["exact"], BB[rho]):
            ax.scatter(exact, learn, marker=".", color="red")
        ax.set_xlim([min_v, max_v])
        ax.set_ylim([min_v, max_v])
        ticks = [-0.1, 0, 0.1]
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(ticks)
        ax.set_yticklabels(ticks)
        ax.set_xlabel("Exact computation")
        ax.set_ylabel("Proposed approach")
        if args.IR:
            filename = f"amd_bb_{args.players}_{args.types}_{args.epsilon}_{args.delta}_{rho}_IRok.pdf"
        else:
            filename = f"amd_bb_{args.players}_{args.types}_{args.epsilon}_{args.delta}_{rho}_BBok.pdf"
        fig.savefig(figure_path.joinpath(filename), bbox_inches="tight")

    # Plot Individual Rationality
    
    min_v = min([np.min(v) for v in IR.values()]) - 0.05
    max_v = max([np.max(v) for v in IR.values()]) + 0.05
    
    for rho in [0, 0.05]:
        fig = plt.figure(figsize=(3,3))
        ax = plt.axes()
        ax.plot([min_v, max_v], [min_v, max_v], linestyle="-", linewidth=1, color="gray")
        for exact, learn in zip(IR["exact"], IR[rho]):
            ax.scatter(exact, learn, marker=".", s=1, color="red")
        ticks = [0, 1, 2]
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(ticks)
        ax.set_yticklabels(ticks)
        ax.set_xlim([min_v, max_v])
        ax.set_ylim([min_v, max_v])
        ax.set_xlabel("Exact computation")
        ax.set_ylabel("Proposed approach")
        if args.IR:
            filename = f"amd_ir_{args.players}_{args.types}_{args.epsilon}_{args.delta}_{rho}_IRok.pdf"
        else:
            filename = f"amd_ir_{args.players}_{args.types}_{args.epsilon}_{args.delta}_{rho}_BBok.pdf"
        fig.savefig(figure_path.joinpath(filename), bbox_inches="tight")
