import pickle
import numpy as np
from pathlib import Path
from bandit.best_reward import SuccessiveEliminationBME
from ex_ante_amd.simple_game import simple_single_item_matching, SimpleGame, SimpleMAB


# Directory (relative to the current working directory) under which cached
# games, bandit runs and learning results are stored. It is created at
# import time if it does not already exist.
result_path = Path("results")
result_path.mkdir(exist_ok=True)


def save_learning_results(game, n_players, n_types, epsilon, delta, seed, IR=False):
    """Write the learning state of ``game`` to a ``.npy`` file in ``result_path``.

    ``game._estimated_pivots`` and ``game._n_evaluations`` are stacked with
    ``np.hstack`` (pivots first, evaluation count after) and stored in a
    single array. The file name is built from the run parameters and gets
    an ``_IR`` suffix when ``IR`` is true.

    Args:
        game: Object exposing ``_estimated_pivots`` and ``_n_evaluations``.
        n_players: Number of players, used only in the file name.
        n_types: Number of types, used only in the file name.
        epsilon: Accuracy parameter, used only in the file name.
        delta: Confidence parameter, used only in the file name.
        seed: Random seed, used only in the file name.
        IR: Selects the ``_IR`` variant of the file name.
    """
    filename = (
        f"learn_{n_players}_{n_types}_{epsilon}_{delta}_{seed}_IR.npy"
        if IR
        else f"learn_{n_players}_{n_types}_{epsilon}_{delta}_{seed}.npy"
    )
    target = result_path / filename

    # Pack pivots and the evaluation counter into one array so that a
    # single file holds the full learning state.
    payload = np.hstack([game._estimated_pivots, game._n_evaluations])
    np.save(target, payload)


def load_learning_results(game, n_players, n_types, epsilon, delta, seed, IR=False):
    """Restore the learning state of ``game`` from ``result_path`` if available.

    Looks for the ``.npy`` file named after the run parameters (with an
    ``_IR`` suffix when ``IR`` is true). When the file exists, all but the
    last entry of the stored array are assigned to
    ``game._estimated_pivots`` and the last entry, converted to ``int``, to
    ``game._n_evaluations``. When the file is missing, ``game`` is left
    untouched.

    Args:
        game: Object whose ``_estimated_pivots`` and ``_n_evaluations``
            attributes are overwritten on a successful load.
        n_players: Number of players, used only in the file name.
        n_types: Number of types, used only in the file name.
        epsilon: Accuracy parameter, used only in the file name.
        delta: Confidence parameter, used only in the file name.
        seed: Random seed, used only in the file name.
        IR: Selects the ``_IR`` variant of the file name.

    Returns:
        The same ``game`` object that was passed in.
    """
    filename = (
        f"learn_{n_players}_{n_types}_{epsilon}_{delta}_{seed}_IR.npy"
        if IR
        else f"learn_{n_players}_{n_types}_{epsilon}_{delta}_{seed}.npy"
    )
    source = result_path / filename

    # Nothing stored for this configuration: hand the game back unchanged.
    if not source.exists():
        return game

    stored = np.load(source)
    game._estimated_pivots = stored[:-1]
    # The evaluation counter is the trailing entry of the stored array.
    game._n_evaluations = int(stored[-1])
    return game


def get_game(n_players, n_types, single_item_matching=simple_single_item_matching, seed=42, precompute=True):
    """Return a ``SimpleGame``, using a pickle cache when ``precompute`` is set.

    With ``precompute`` true, a previously pickled game for the same
    ``n_players``/``n_types``/``seed`` is loaded from ``result_path`` if it
    exists; otherwise a new game is built, its values are precomputed and
    the result is pickled. With ``precompute`` false, a fresh game is always
    built and nothing is read from or written to the cache.

    Args:
        n_players: Number of players passed to ``SimpleGame``.
        n_types: Determines the type space ``arange(-n_types, n_types + 1)``.
        single_item_matching: Matching function passed to ``SimpleGame``.
        seed: Seed for the ``np.random.RandomState`` given to the game; also
            part of the cache file name.
        precompute: Whether to precompute values and use the pickle cache.

    Returns:
        The loaded or newly created game.
    """
    cache_file = result_path / f"bandit_{n_players}_{n_types}_{seed}_game.pkl"

    if cache_file.exists() and precompute:
        # NOTE: pickle is only appropriate for cache files this project
        # wrote itself; never point this at untrusted data.
        print("loading", cache_file, end="...")
        with cache_file.open(mode='br') as stream:
            cached_game = pickle.load(stream)
        print("successful")
        return cached_game

    print("creating", cache_file)

    # Build the simple game from scratch.
    type_space = np.arange(-n_types, n_types + 1)
    game_rng = np.random.RandomState(seed)
    game = SimpleGame(
        n_players=n_players,
        n_types=n_types,
        type_space=type_space,
        single_item_matching=single_item_matching,
        rng=game_rng,
    )

    if not precompute:
        return game

    # Warm up every value needed for later analysis before caching.
    game.compute_all_values()
    for p in range(n_players):
        # NOTE(review): iterates range(n_types) although the type space has
        # 2 * n_types + 1 entries -- confirm this is intended.
        for t in range(n_types):
            game._get_conditional_value(p, t)

    with cache_file.open(mode='wb') as stream:
        pickle.dump(game, stream)

    return game


def run(game, n_players, n_types, player_idx, epsilon, delta, seed=42, rerun=False):
    """Run (or load) a ``SuccessiveEliminationBME`` bandit for one player.

    If a pickled algorithm for this configuration exists in ``result_path``
    and ``rerun`` is false, it is loaded and returned. Otherwise a
    ``SimpleMAB`` environment is built on ``game`` for ``player_idx``, the
    algorithm is run on it, and the resulting algorithm object is pickled
    before being returned.

    Args:
        game: Game passed to ``SimpleMAB``; its ``_type_space`` is also used
            to derive the upper end of the algorithm's support.
        n_players: Number of players, used only in the cache file name.
        n_types: Number of types, used only in the cache file name.
        player_idx: Player index passed to ``SimpleMAB``; also part of the
            cache file name.
        epsilon: First argument of ``SuccessiveEliminationBME``.
        delta: Second argument of ``SuccessiveEliminationBME``.
        seed: Seed of the environment's ``np.random.RandomState``.
        rerun: When true, ignore any cached result and run again.

    Returns:
        The loaded or freshly run algorithm object.
    """
    cache_file = result_path / (
        f"bandit_{n_players}_{n_types}_{player_idx}_{epsilon}_{delta}_{seed}_alg.pkl"
    )

    if cache_file.exists() and not rerun:
        # NOTE: pickle is only appropriate for cache files this project
        # wrote itself; never point this at untrusted data.
        with cache_file.open(mode='br') as stream:
            return pickle.load(stream)

    # The game is supplied by the caller; only the bandit environment and
    # the algorithm are created here.
    env_rng = np.random.RandomState(seed)
    bandit_env = SimpleMAB(game, player_idx, env_rng)

    upper = simple_single_item_matching(game._type_space)
    algorithm = SuccessiveEliminationBME(epsilon, delta, support=[0, upper])

    # Execute the bandit algorithm on the environment.
    algorithm.run(bandit_env, detail=True)

    with cache_file.open(mode='wb') as stream:
        pickle.dump(algorithm, stream)

    return algorithm
