import os
import time
import torch
import random
import pyspiel
import logging
import itertools
import numpy as np
import os.path as osp
import mb_jpsro.mb_jpsro_dqn as jpsro
from network.env_model import DynamicModel
from open_spiel.python import rl_environment


# Short game names understood by get_game().
GAMES = ("kuhn_poker_3p", "leduc_poker_3p", "trade_comm_2p_2i", "sheriff_2p_gabriele")
BR_SELECTIONS = ("all", "all_novel", "random", "random_novel", "largest_gap")
# NOTE(review): the sweep below uses policy_init="dqn", which is not listed
# here; confirm which values mb_jpsro_dqn actually accepts.
INIT_POLICIES = ("uniform", "random_deterministic")
UPDATE_PLAYERS_STRATEGY = ("all", "cycle", "random")
BRS = ("cce", "ce")
META_SOLVERS = ("uni", "undominated_uni", "rj", "undominated_rj", "rd", "undominated_rd", "prd", "alpharank", "mgce",
                "min_epsilon_mgce", "approx_mgce", "rmwce", "mwce", "rvce", "mgcce", "min_epsilon_mgcce", "approx_mgcce",
                "rmwcce", "mwcce", "rvcce")

# Hyper-parameter grid as (keyword, candidate values) pairs. Every combination
# of the candidate values becomes one kwargs dict in SWEEP_KWARGS; the pair
# order fixes the iteration order (first pair varies slowest).
_SWEEP_GRID = (
    ("iterations", [10]),
    ("game_name", ["leduc_poker_3p"]),
    ("policy_init", ["dqn"]),
    ("update_players_strategy", ["all"]),
    ("target_equilibrium", ["cce"]),
    ("br_selection", ["largest_gap"]),
    ("train_meta_solver", ["mgcce"]),
    # ("eval_meta_solver", ["mwce"]),  # currently disabled
    ("ignore_repeats", [True]),
)

_sweep_names, _sweep_values = zip(*_SWEEP_GRID)
SWEEP_KWARGS = [dict(zip(_sweep_names, combo)) for combo in itertools.product(*_sweep_values)]


def get_policy_result_dir(seed, game_name, data_size, proportion):
    """Return the file path for a saved policy, creating its directory if needed.

    The path has the form
    ``mb_method_results/mb_jpsro_train_policy/<game_name>/train_data_<data_size>/
    seed_<seed>_policy_proportion_<proportion>.pkl``.

    Args:
        seed: Run seed, embedded in the file name.
        game_name: Sub-directory name for the game. A trailing slash is allowed.
        data_size: Training-data size, embedded in the directory name.
        proportion: Env-model proportion, embedded in the file name.

    Returns:
        The path as a string. Only the directory is created, not the file.
    """
    # osp.join inserts the separator that plain concatenation omitted when
    # game_name had no trailing slash.
    result_dir = osp.join("mb_method_results", "mb_jpsro_train_policy", game_name,
                          "train_data_{}".format(data_size))
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(result_dir, exist_ok=True)

    result_name = "seed_{}_policy_proportion_{}.pkl".format(seed, proportion)
    return osp.join(result_dir, result_name)


def get_nash_conv_result_dir(seed, game_name, data_size):
    """Return the file path for saved convergence results, creating its directory.

    The path has the form
    ``mb_method_results/mb_jpsro_train_nash_conv/<game_name>/train_data_<data_size>/
    seed_<seed>_conv_results_7.pkl``.

    Args:
        seed: Run seed, embedded in the file name.
        game_name: Sub-directory name for the game. A trailing slash is allowed.
        data_size: Training-data size, embedded in the directory name.

    Returns:
        The path as a string. Only the directory is created, not the file.
    """
    # osp.join inserts the separator that plain concatenation omitted when
    # game_name had no trailing slash.
    result_dir = osp.join("mb_method_results", "mb_jpsro_train_nash_conv", game_name,
                          "train_data_{}".format(data_size))
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(result_dir, exist_ok=True)

    # NOTE(review): the "_7" suffix is fixed rather than derived from any
    # argument; presumably it tags the first env-model proportion. Confirm.
    result_name = "seed_{}_conv_results_7.pkl".format(seed)
    return osp.join(result_dir, result_name)


def get_game(game_name):
    """Load the turn-based OpenSpiel game behind one of this script's short names.

    Each short name maps to an OpenSpiel game id plus its parameters. The game
    is loaded with ``pyspiel.load_game_as_turn_based``.

    Raises:
        ValueError: if ``game_name`` is not one of the supported short names.
    """
    # (short name, OpenSpiel game id, game parameters)
    catalogue = (
        ("kuhn_poker_3p", "kuhn_poker", {"players": int(3)}),
        ("leduc_poker_3p", "leduc_poker", {"players": int(3)}),
        ("trade_comm_2p_2i", "trade_comm", {"num_items": int(2)}),
        ("sheriff_2p_gabriele", "sheriff", {
            "item_penalty": float(1.0),
            "item_value": float(5.0),
            "max_bribe": int(2),
            "max_items": int(10),
            "num_rounds": int(2),
            "sheriff_penalty": float(1.0),
        }),
    )
    for alias, spiel_name, params in catalogue:
        if game_name == alias:
            return pyspiel.load_game_as_turn_based(spiel_name, params)
    raise ValueError("Unrecognised game: %s" % game_name)


def setup_seed(seed):
    """Seed NumPy, Python ``random`` and PyTorch, and make cuDNN deterministic.

    PyTorch is seeded on the CPU and on every CUDA device. The cuDNN flag asks
    cuDNN to use deterministic algorithms.
    """
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True


def test_jpsro_cce(kwargs, seed=1):
    """Run JPSRO against learned environment models and save the outputs.

    For each env-model proportion in 7..10, this function:

    1. Loads a trained environment model from disk.
    2. Wraps it in a ``DynamicModel``.
    3. Calls ``jpsro.run_loop``.
    4. Saves the returned ``best_policy`` and the running list of returned
       ``final_conv`` values with ``torch.save``.

    Args:
        kwargs: Sweep configuration, e.g. one entry of ``SWEEP_KWARGS``. It must
            contain ``"game_name"``. It is not modified: a copy, extended with
            ``agent_kwargs`` and ``number_train``, is forwarded to
            ``jpsro.run_loop``.
        seed: Seed for the environment and all RNGs. Defaults to 1, the value
            previously hard-coded.
    """
    game = get_game(kwargs["game_name"])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    env = rl_environment.Environment(game)
    env.seed(seed)
    setup_seed(seed)

    state_representation_size = env.observation_spec()["info_state"][0]
    agent_kwargs = {
        "state_representation_size": state_representation_size,
        "num_actions": env.action_spec()["num_actions"],
        "hidden_layers_sizes": [256] * 2,
        "batch_size": 32,
        "learning_rate": 1e-2,
        "update_target_network_every": 1000,
        "learn_every": 10,
        "optimizer_str": "adam",
    }
    # Build a copy instead of writing into ``kwargs``: callers typically pass
    # a shared module-level entry such as SWEEP_KWARGS[0].
    run_kwargs = dict(kwargs, agent_kwargs=agent_kwargs, number_train=int(1e4))

    # Action count for the env model: the larger of the player action count
    # and the number of chance outcomes at the initial state. Presumably this
    # lets the model represent chance actions too.
    chance_outcome_count = len(game.new_initial_state().chance_outcomes())
    env_action_number = max(game.num_distinct_actions(), chance_outcome_count)

    # NOTE(review): the model location and the result sub-directory below are
    # hard-coded to 3-player Leduc poker and ignore kwargs["game_name"].
    env_model_location = "mix_offline_dataset_trained_env_model"
    model_game_dir = "leduc_poker_3_players/"
    data_size = 10000
    env_model_file = ("game_leduc_poker_players_3_hidden_layer_128_buffer_10000_"
                      "lr_0.05_train_epoch_10000_batch_size_128.pkl")
    nash_conv_list = []
    for proportion in range(7, 11):
        logging.info("proportion: %d", proportion)
        model_location = osp.join(env_model_location, model_game_dir,
                                  "env_model_proportion_" + str(proportion),
                                  env_model_file)
        # Loaded onto the CPU; DynamicModel is given the target device separately.
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Recent PyTorch versions default to weights_only=True,
        # which may reject a pickled model object; confirm against the torch
        # version in use.
        trained_model = torch.load(model_location, map_location=torch.device('cpu'))

        env_model = DynamicModel(
            env=env,
            state_length=state_representation_size,
            env_action_number=env_action_number,
            legal_action_number=game.num_distinct_actions(),
            player_number=env.num_players,
            game_length=10,
            env_model=trained_model,
            use_round=True,
            use_deterministic_env_model=True,
            device=device,
        )

        final_conv, best_policy = jpsro.run_loop(env_model=env_model, **run_kwargs)

        torch.save(best_policy, get_policy_result_dir(seed, model_game_dir, data_size, proportion))
        nash_conv_list.append(final_conv)
        # Saved after every proportion, so the file always holds every result
        # completed so far.
        torch.save(nash_conv_list, get_nash_conv_result_dir(seed, model_game_dir, data_size))


if __name__ == "__main__":
    log_dir = "./log_files"
    # logging.FileHandler does not create missing directories; without this
    # the script fails with FileNotFoundError on a fresh checkout.
    os.makedirs(log_dir, exist_ok=True)

    current_time = time.strftime("%H:%M:%S_%Y-%m-%d", time.localtime(time.time()))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(filename)s:%(lineno)d - %(message)s',
        handlers=[
            logging.FileHandler(osp.join(log_dir, f"runtime_log_{current_time}.log"), mode="a"),
            logging.StreamHandler(),
        ],
    )

    test_jpsro_cce(SWEEP_KWARGS[0])
