import argparse
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
from exp_bandit_core import get_game, run, save_learning_results, load_learning_results
from ex_ante_amd.simple_game import simple_single_item_matching, SimpleGame


# Output directories (relative to the current working directory). They are
# created as soon as this module is loaded: numeric results go to
# ``results/`` and PDF figures to ``figs/``.
result_path, figure_path = Path("results"), Path("figs")
result_path.mkdir(exist_ok=True)
figure_path.mkdir(exist_ok=True)


def compute_rmse(exact, learned):
    """Return the root-mean-square error between two aligned sequences.

    Args:
        exact: Reference values.
        learned: Estimated values, aligned element-wise with ``exact``.

    Returns:
        ``sqrt(mean((exact - learned) ** 2))`` as a NumPy scalar.
    """
    diff = np.asarray(exact) - np.asarray(learned)
    return np.sqrt(np.mean(diff ** 2))


def plot_rmse_vs_samples(results, epsilons, column, ylabel, filename):
    """Scatter RMSE against sample size for each epsilon and save the figure.

    Args:
        results: Mapping epsilon -> 2-D array with one row per seed, laid out as
            ``[n_evaluations, n_evaluated, RMSE_BB, RMSE_IR]`` (the layout
            written by the ``--epsilon`` mode of this script).
        epsilons: Epsilons to plot; list order determines the colour order.
        column: Column of each results array used as the y value
            (2 for BB, 3 for IR). Column 0 is used as the x value.
        ylabel: Y-axis label.
        filename: Output file name, written inside ``figure_path``.
    """
    cmap = plt.get_cmap("rainbow")
    # Spread colours over the whole colormap; avoid dividing by zero when
    # only a single epsilon is plotted.
    denom = max(len(epsilons) - 1, 1)

    fig, ax = plt.subplots(figsize=(4, 3))
    for i, epsilon in enumerate(epsilons):
        ax.scatter(
            results[epsilon][:, 0],
            results[epsilon][:, column],
            marker=".",
            color=cmap(i / denom),
            label=f"{epsilon}",
        )
    ticks = [0, 0.1, 0.2, 0.3, 0.4]
    ax.set_ylim([0, 0.43])
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)
    ax.set_xlabel("Sample size")
    ax.set_ylabel(ylabel)
    ax.set_xscale("log")
    # NOTE(review): series are labelled, but no legend is drawn (matches the
    # original figures); add ``ax.legend()`` if one is wanted.
    fig.savefig(figure_path.joinpath(filename), bbox_inches="tight")
    # Free the figure; otherwise pyplot keeps every figure alive.
    plt.close(fig)


# Usage:
#   -p P -t T -e EPS [-d DELTA]  run n_seeds games and save per-seed RMSEs
#   -p P -t T --plot             plot RMSE vs. sample size from saved results
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='RMSE in BB and IR')
    # Both modes build file names from players/types, so they are mandatory.
    parser.add_argument("-p", "--players", type=int, required=True)
    parser.add_argument("-t", "--types", type=int, required=True)
    parser.add_argument("-e", "--epsilon", type=float)
    parser.add_argument("-d", "--delta", type=float, default=0.1)
    parser.add_argument('--plot', action='store_true')
    args = parser.parse_args()

    # Previously ``-e 0`` was silently ignored (0.0 is falsy); reject
    # non-positive values explicitly instead.
    if args.epsilon is not None and args.epsilon <= 0:
        parser.error("--epsilon must be positive")

    n_players = args.players
    n_types = args.types
    delta = args.delta

    n_seeds = 10

    if args.epsilon is not None:
        epsilon = args.epsilon

        # One row per seed: [n_evaluations, n_evaluated, RMSE_BB, RMSE_IR].
        results = []
        for seed in range(n_seeds):
            rng = np.random.RandomState(seed)

            game = SimpleGame(
                n_players=n_players,
                n_types=n_types,
                type_space=np.arange(-n_types, n_types + 1),
                single_item_matching=simple_single_item_matching,
                rng=rng,
            )

            # Keep the original query order (exact BB, exact IR, learned BB,
            # learned IR): the game object may carry state across calls.
            bb_exact = [game.get_budget_balance()]
            ir_exact = [
                game.get_individual_rationality(p_idx)
                for p_idx in range(n_players)
            ]
            bb_learn = [game.get_budget_balance(epsilon=epsilon, delta=delta)]
            ir_learn = [
                game.get_individual_rationality(p_idx, epsilon=epsilon, delta=delta)
                for p_idx in range(n_players)
            ]

            save_learning_results(game, n_players, n_types, epsilon, delta, seed)

            results.append([
                game._n_evaluations,
                game.get_n_evaluated(),
                compute_rmse(bb_exact, bb_learn),
                compute_rmse(ir_exact, ir_learn),
            ])

        filename = f"results_rmse_{n_players}_{n_types}_{epsilon}.npy"
        np.save(result_path.joinpath(filename), np.array(results))

    if args.plot:
        plt.rcParams["font.size"] = 16

        # Each epsilon's result file must already exist from an earlier
        # ``--epsilon`` run with the same --players/--types.
        epsilons = [1.0, 0.5, 0.4, 0.3, 0.25, 0.2, 0.15]
        results = {
            epsilon: np.load(result_path.joinpath(
                f"results_rmse_{n_players}_{n_types}_{epsilon}.npy"
            ))
            for epsilon in epsilons
        }

        plot_rmse_vs_samples(
            results, epsilons, 2, "RMSE in BB",
            f"BB_rmse_{n_players}_{n_types}.pdf",
        )
        plot_rmse_vs_samples(
            results, epsilons, 3, "RMSE in IR",
            f"IR_rmse_{n_players}_{n_types}.pdf",
        )
