import os
import time
import torch
import random
import pyspiel
import numpy as np
from absl import app
import os.path as osp
from absl import flags
from open_spiel.python import rl_environment
from open_spiel.python.algorithms import exploitability
from open_spiel.python.algorithms import policy_aggregator
from open_spiel.python.algorithms.psro_v2 import strategy_selectors
import mb_psro.mb_psro as psro_v2
import mb_psro.mb_psro_rl_oracle as rl_oracle
import mb_psro.torch_rl_policy as rl_policy
from network.env_model import DynamicModel

# Restrict CUDA to device index 0 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Command-line configuration of the experiment (absl flags).
FLAGS = flags.FLAGS
flags.DEFINE_integer("seed", 2, "Seed.")
flags.DEFINE_string("game_name", "leduc_poker", "Game name.")
flags.DEFINE_integer("n_players", 3, "The number of players.")
flags.DEFINE_integer("game_length", 10, "Game Length")
flags.DEFINE_integer("gpsro_iterations", 30, "Number of training steps for GPSRO.")

# PSRO related
flags.DEFINE_integer("number_training_episodes", int(1e3), "Number training episodes per RL policy.")
flags.DEFINE_string("meta_strategy_method", "alpharank", "Name of meta strategy computation method.")
flags.DEFINE_integer("number_policies_selected", 1, "Number of new strategies trained at each PSRO iteration.")
flags.DEFINE_integer("sims_per_entry", 1000, "Number of simulations to estimate elements of the game outcome matrix.")
flags.DEFINE_integer("prd_iterations", 50000, "Number of training steps for PRD.")
flags.DEFINE_bool("symmetric_game", False, "Whether to consider the current game as a symmetric game.")

# Rectify options
flags.DEFINE_string("rectifier", "", "Which rectifier to use. Choices are ''(No filtering),'rectified' for rectified.")
flags.DEFINE_string("training_strategy_selector", "probabilistic",
                    "Which strategy selector to use. Choices are "
                    " - 'top_k_probabilities': select top `number_policies_selected` strategies. "
                    " - 'probabilistic': Randomly samples `number_policies_selected` strategies with probability "
                    "equal to their selection probabilities. "
                    " - 'uniform': Uniformly sample `number_policies_selected` strategies. "
                    " - 'rectified': Select every non-zero-selection-probability strategy available to each player.")

# General (RL) agent parameters
flags.DEFINE_integer("hidden_layer_size", 256, "Hidden layer size")
flags.DEFINE_integer("n_hidden_layers", 4, "# of hidden layers")
flags.DEFINE_float("dqn_learning_rate", 1e-2, "DQN learning rate.")
flags.DEFINE_integer("batch_size", 32, "Batch size")
flags.DEFINE_float("sigma", 0.0, "Policy copy noise (Gaussian Dropout term).")
flags.DEFINE_string("optimizer_str", "adam", "'adam' or 'sgd'")
flags.DEFINE_integer("learn_every", 10, "Learn every [X] steps.")
flags.DEFINE_float("self_play_proportion", 0.0, "Self play proportion")
flags.DEFINE_integer("update_target_network_every", 1000, "Update target network every [X] steps")

# Environment-model options. `env_model_file` is a template whose two `{}`
# slots are filled with `replay_buffer` and `train_epoch` when the model path
# is built in main().
flags.DEFINE_bool("use_round", True, "env model whether use round to normalize state")
flags.DEFINE_bool("use_deterministic", True, "Whether to use a deterministic env model.")
flags.DEFINE_string("env_model_location", "mix_offline_dataset_trained_env_model/", "Root directory of the trained env models.")
flags.DEFINE_string("env_model_file", "game_leduc_poker_players_3_hidden_layer_128_buffer_{}_lr_0.05_train_epoch_{}_batch_size_128.pkl",
                    "File name template of the env model; formatted with replay_buffer and train_epoch.")
flags.DEFINE_integer("replay_buffer", 50000, "env model replay buffer")
flags.DEFINE_integer("train_epoch", 10000, "Number of training epochs of the env model.")

flags.DEFINE_string("device", "cuda", "device")
# Replace the declared "cuda" default with "cpu" when no GPU is available.
FLAGS.device = "cuda" if torch.cuda.is_available() else "cpu"


def get_policy_result_dir(proportion):
    """Return the file path under which the trained policies are saved.

    Despite the name, the returned value is the path of a ``.pkl`` file, not
    a directory. The parent directory is created when it does not exist yet.

    Args:
        proportion: Identifier of the env model the policies were trained
            against; it is embedded verbatim in the file name.

    Returns:
        The policy file path derived from the current flag values.
    """
    result_dir = "mb_method_results/mb_psro_train_policy/{}_{}_players/train_data_{}".format(
        FLAGS.game_name, FLAGS.n_players, FLAGS.replay_buffer)
    # exist_ok avoids the check-then-create race when several runs (e.g.
    # different seeds) start at the same time and share this directory.
    os.makedirs(result_dir, exist_ok=True)

    # Runs without state rounding are tagged so they do not overwrite the
    # results of runs that use it.
    round_tag = "" if FLAGS.use_round else "_not_use_round"
    result_name = "seed_{}{}_policy_env_train_epoch_{}_number_training_episodes_{}_proportion_{}.pkl".format(
        FLAGS.seed, round_tag, FLAGS.train_epoch, FLAGS.number_training_episodes, proportion)
    return osp.join(result_dir, result_name)


def get_nash_conv_result_dir():
    """Return the file path under which the NashConv history is saved.

    Despite the name, the returned value is the path of a ``.pkl`` file, not
    a directory. The parent directory is created when it does not exist yet.

    Returns:
        The NashConv file path derived from the current flag values.
    """
    result_dir = "mb_method_results/mb_psro_train_nash_conv/{}_{}_players/train_data_{}".format(
        FLAGS.game_name, FLAGS.n_players, FLAGS.replay_buffer)
    # exist_ok avoids the check-then-create race when several runs (e.g.
    # different seeds) start at the same time and share this directory.
    os.makedirs(result_dir, exist_ok=True)

    # Runs without state rounding are tagged so they do not overwrite the
    # results of runs that use it.
    round_tag = "" if FLAGS.use_round else "_not_use_round"
    result_name = "seed_{}{}_train_epoch_{}_number_training_episodes_{}.pkl".format(
        FLAGS.seed, round_tag, FLAGS.train_epoch, FLAGS.number_training_episodes)
    return osp.join(result_dir, result_name)


def init_dqn_responder(env, env_model):
    """Create the DQN-based oracle together with one initial policy per player.

    Args:
        env: RL environment; its observation and action specs provide the
            input size and the number of actions of the DQN policies.
        env_model: Environment model passed on to the oracle.

    Returns:
        A tuple ``(oracle, agents)``: the ``RLOracle`` and a list holding one
        ``DQNPolicy`` per player, each of which has had ``freeze()`` called.
    """
    policy_cls = rl_policy.DQNPolicy

    info_state_size = env.observation_spec()["info_state"][0]
    n_actions = env.action_spec()["num_actions"]
    hidden_sizes = [FLAGS.hidden_layer_size] * FLAGS.n_hidden_layers

    # Shared constructor arguments for the oracle's policies and the
    # initial policies created below.
    policy_kwargs = dict(
        state_representation_size=info_state_size,
        num_actions=n_actions,
        hidden_layers_sizes=hidden_sizes,
        batch_size=FLAGS.batch_size,
        learning_rate=FLAGS.dqn_learning_rate,
        update_target_network_every=FLAGS.update_target_network_every,
        learn_every=FLAGS.learn_every,
        optimizer_str=FLAGS.optimizer_str,
    )

    oracle = rl_oracle.RLOracle(
        env,
        env_model,
        FLAGS.game_length,
        policy_cls,
        policy_kwargs,
        number_training_episodes=FLAGS.number_training_episodes,
        self_play_proportion=FLAGS.self_play_proportion,
        sigma=FLAGS.sigma,
    )

    # Build every policy first, then freeze them all.
    agents = []
    for player_id in range(FLAGS.n_players):
        agents.append(policy_cls(env, player_id, **policy_kwargs))
    for initial_policy in agents:
        initial_policy.freeze()

    return oracle, agents


def gpsro_looper(env, oracle, agents):
    """Build a PSRO solver and run it for ``FLAGS.gpsro_iterations`` iterations.

    Args:
        env: RL environment whose ``game`` attribute is handed to the solver.
        oracle: Oracle handed to the solver.
        agents: Initial policies handed to the solver.

    Returns:
        The ``PSROSolver`` after all iterations have been executed.
    """
    print("Game : {}".format(FLAGS.game_name))
    print("Seed: {}".format(FLAGS.seed))

    # Fall back to the probabilistic selector when the flag is empty.
    selector = FLAGS.training_strategy_selector
    if not selector:
        selector = strategy_selectors.probabilistic_strategy_selector

    solver_options = dict(
        initial_policies=agents,
        training_strategy_selector=selector,
        rectifier=FLAGS.rectifier,
        sims_per_entry=FLAGS.sims_per_entry,
        number_policies_selected=FLAGS.number_policies_selected,
        meta_strategy_method=FLAGS.meta_strategy_method,
        prd_iterations=FLAGS.prd_iterations,
        prd_gamma=1e-10,
        sample_from_marginals=True,
        symmetric_game=FLAGS.symmetric_game,
    )
    solver = psro_v2.PSROSolver(env.game, oracle, **solver_options)

    began_at = time.time()
    for step in range(FLAGS.gpsro_iterations):
        elapsed = time.time() - began_at
        print("Iteration : {}".format(step))
        print("Time so far: {}".format(elapsed))

        solver.iteration()

    return solver


def setup_seed(seed):
    """Seed the torch, numpy and stdlib random number generators.

    Also switches cuDNN to its deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


def main(argv):
    """Train PSRO policies against each learned env model and save the results.

    For every env-model proportion in 7..10 the function loads the matching
    trained env model, runs the GPSRO loop on it, computes the NashConv of
    the aggregated policies on the real game, and writes the NashConv history
    and the learned policies to disk.

    Args:
        argv: Command-line arguments passed to the entry point; anything
            beyond the first element is rejected.

    Raises:
        app.UsageError: If more than one argument is supplied.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    setup_seed(FLAGS.seed)

    # load liar's dice game
    # game = pyspiel.load_game(FLAGS.game_name, {"players": FLAGS.n_players, "numdice": FLAGS.numdice, "dice_sides": 6})
    # load poker game
    game = pyspiel.load_game(FLAGS.game_name, {"players": FLAGS.n_players})
    # load phantom ttt game
    # game = pyspiel.load_game(FLAGS.game_name, {"obstype": "reveal-nothing"})
    state = game.new_initial_state()
    # The env model's action space is sized to the larger of the player
    # action count and the chance-outcome count of the initial state.
    chance_action = len(state.chance_outcomes())
    num_actions = max(game.num_distinct_actions(), chance_action)

    env = rl_environment.Environment(game)
    env.seed(FLAGS.seed)

    # Loop-invariant pieces of the env model path and configuration.
    # osp.join keeps the path valid even when --env_model_location is given
    # without a trailing slash.
    model_root = osp.join(
        FLAGS.env_model_location,
        "{}_{}_players".format(FLAGS.game_name, FLAGS.n_players))
    model_file = FLAGS.env_model_file.format(FLAGS.replay_buffer, FLAGS.train_epoch)
    info_state_size = env.observation_spec()["info_state"][0]

    ex_list = []
    for proportion in range(7, 11):
        model_location = osp.join(
            model_root, "env_model_proportion_{}".format(proportion), model_file)
        # map_location lets a model saved on a GPU be loaded on the CPU
        # fallback selected for FLAGS.device.
        # NOTE: torch.load unpickles the file; only load trusted env models.
        trained_model = torch.load(model_location, map_location=FLAGS.device)
        env_args = {"state_length": info_state_size,
                    "env_action_number": num_actions,
                    "legal_action_number": game.num_distinct_actions(),
                    "player_number": FLAGS.n_players,
                    "game_length": FLAGS.game_length,
                    "env_model": trained_model,
                    "use_round": FLAGS.use_round,
                    "use_deterministic_env_model": FLAGS.use_deterministic,
                    "device": FLAGS.device}

        env_model = DynamicModel(**env_args)

        # Initialize oracle and agents
        oracle, agents = init_dqn_responder(env, env_model)
        g_psro_solver = gpsro_looper(env, oracle, agents)
        # evaluate
        meta_probabilities = g_psro_solver.get_meta_strategies()
        policies = g_psro_solver.get_policies()
        # compute nash conv
        aggregator = policy_aggregator.PolicyAggregator(env.game)
        aggr_policies = aggregator.aggregate(range(FLAGS.n_players), policies, meta_probabilities)
        # Only the total NashConv is recorded; the per-player part is unused.
        exploitabilities, _ = exploitability.nash_conv(env.game, aggr_policies,
                                                       return_only_nash_conv=False)
        # save exploitability
        ex_list.append(exploitabilities)
        # Per player, the `_policy` attribute of every policy in the
        # population, followed by the meta-strategy as the last element.
        best_policy = [[policy._policy for policy in player_policies]
                       for player_policies in policies]
        best_policy.append(meta_probabilities)

        print("conv:", ex_list)
        # The NashConv file is rewritten on every pass with the full history.
        torch.save(ex_list, get_nash_conv_result_dir())
        torch.save(best_policy, get_policy_result_dir(proportion))


# Script entry point: hand main() to absl's app runner.
if __name__ == "__main__":
    app.run(main)
