import os
import time
import copy
import torch
import random
import pyspiel
import numpy as np
from absl import app
import os.path as osp
from absl import flags
import tensorflow.compat.v1 as tf
from open_spiel.python import rl_environment
from open_spiel.python.algorithms.psro_v2 import strategy_selectors
import sys
sys.path.append("/home/kangjie/lsx/offline_learning/Dataset_Generator")
from generate_dataset.expert_dataset_generator.utils import rl_oracle, prso_oracle as psro_v2, \
    torch_rl_policy as rl_policy

# Default to physical GPU 6, but let an already-exported CUDA_VISIBLE_DEVICES
# win so an operator or scheduler can redirect the job without editing this
# file. This is set before the torch.cuda.is_available() call below so the
# mask is in place when CUDA is first queried.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")

FLAGS = flags.FLAGS
flags.DEFINE_integer("seed", 1, "Seed.")
# flags.DEFINE_string("game_name", "leduc_poker", "Game name.")
flags.DEFINE_string("game_name", "phantom_ttt", "Name of the game")
flags.DEFINE_integer("n_players", 2, "The number of players.")
flags.DEFINE_integer("gpsro_iterations", 30, "Number of training steps for GPSRO.")

# PSRO related
flags.DEFINE_integer("number_training_episodes", int(1e5), "Number training episodes per RL policy.")
flags.DEFINE_string("meta_strategy_method", "alpharank", "Name of meta strategy computation method.")
flags.DEFINE_integer("number_policies_selected", 1, "Number of new strategies trained at each PSRO iteration.")
flags.DEFINE_integer("sims_per_entry", 1000, "Number of simulations to estimate elements of the game outcome matrix.")
flags.DEFINE_integer("prd_iterations", 50000, "Number of training steps for PRD.")
flags.DEFINE_bool("symmetric_game", False, "Whether to consider the current game as a symmetric game.")

# Rectify options
flags.DEFINE_string("rectifier", "", "Which rectifier to use. Choices are ''(No filtering),'rectified' for rectified.")
flags.DEFINE_string("training_strategy_selector", "probabilistic",
                    "Which strategy selector to use. Choices are "
                    " - 'top_k_probabilities': select top `number_policies_selected` strategies. "
                    " - 'probabilistic': Randomly samples `number_policies_selected` strategies with probability "
                    "equal to their selection probabilities. "
                    " - 'uniform': Uniformly sample `number_policies_selected` strategies. "
                    " - 'rectified': Select every non-zero-selection-probability strategy available to each player.")

# General (RL) agent parameters
flags.DEFINE_integer("hidden_layer_size", 256, "Hidden layer size")
flags.DEFINE_integer("n_hidden_layers", 4, "# of hidden layers")
flags.DEFINE_float("dqn_learning_rate", 1e-2, "DQN learning rate.")
flags.DEFINE_integer("batch_size", 32, "Batch size")
flags.DEFINE_float("sigma", 0.0, "Policy copy noise (Gaussian Dropout term).")
flags.DEFINE_string("optimizer_str", "adam", "'adam' or 'sgd'")
flags.DEFINE_integer("learn_every", 10, "Learn every [X] steps.")
flags.DEFINE_float("self_play_proportion", 0.0, "Self play proportion")
flags.DEFINE_integer("update_target_network_every", 1000, "Update target network every [X] steps")

# The default is resolved from the visible hardware when the module is
# imported ("cuda" if available, otherwise "cpu"). Declaring it as the flag
# default, rather than assigning FLAGS.device before flags are parsed, keeps
# the flag definition the single source of truth; an explicit --device on the
# command line still takes precedence.
flags.DEFINE_string("device", "cuda" if torch.cuda.is_available() else "cpu", "device")


def get_result_dir(iteration, ex):
    """Return the checkpoint file path for one snapshot, creating its folder.

    The folder is ``expert_policy/psro/<game_name>_<n_players>_players``
    (relative to the current working directory), built from the
    ``game_name`` and ``n_players`` flags.

    Args:
        iteration: Value embedded in the file name as the iteration count.
        ex: Value embedded in the file name after ``nash_conv_``.

    Returns:
        Path of the form
        ``<folder>/train_<iteration>_iterations_nash_conv_<ex>.pth``.
    """
    sub_folder = FLAGS.game_name + "_" + str(FLAGS.n_players) + "_players"
    result_dir = osp.join("expert_policy/psro", sub_folder)
    # exist_ok=True replaces the previous exists()-then-makedirs() pair, which
    # could raise if another process created the folder between the two calls.
    os.makedirs(result_dir, exist_ok=True)

    result_name = "train_{}_iterations_nash_conv_{}.pth".format(iteration, ex)
    return osp.join(result_dir, result_name)


def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to PyTorch (default generator and all CUDA devices),
    NumPy, Python's ``random`` module and TensorFlow, then switches cuDNN
    into deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
        tf.random.set_random_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


def init_dqn_responder(env):
    """Create the DQN best-response oracle and one initial DQN policy per player.

    Network and optimizer settings are read from the command-line flags and
    shared between the oracle and the initial policies.

    Args:
        env: RL environment supplying the observation and action specs.

    Returns:
        Tuple ``(oracle, agents)``: the ``RLOracle`` and a list with one
        ``DQNPolicy`` per player, each of which has had ``freeze()`` called.
    """
    info_state_size = env.observation_spec()["info_state"][0]
    action_count = env.action_spec()["num_actions"]

    agent_class = rl_policy.DQNPolicy
    agent_kwargs = dict(
        state_representation_size=info_state_size,
        num_actions=action_count,
        hidden_layers_sizes=[FLAGS.hidden_layer_size] * FLAGS.n_hidden_layers,
        batch_size=FLAGS.batch_size,
        learning_rate=FLAGS.dqn_learning_rate,
        update_target_network_every=FLAGS.update_target_network_every,
        learn_every=FLAGS.learn_every,
        optimizer_str=FLAGS.optimizer_str,
    )

    # The oracle is built before the agents, matching the original order.
    oracle = rl_oracle.RLOracle(
        env,
        agent_class,
        agent_kwargs,
        number_policies_sampled=1000,
        number_training_episodes=FLAGS.number_training_episodes,
        self_play_proportion=FLAGS.self_play_proportion,
        sigma=FLAGS.sigma,
    )

    agents = []
    for player_id in range(FLAGS.n_players):
        agents.append(agent_class(env, player_id, **agent_kwargs))

    # Freeze only after every agent has been constructed.
    for initial_agent in agents:
        initial_agent.freeze()

    return oracle, agents


def gpsro_looper(env, oracle, agents, game, save_every=10):
    """Initializes and executes the GPSRO training loop.

    Builds a ``PSROSolver`` from the command-line flags, runs
    ``FLAGS.gpsro_iterations`` solver iterations and periodically saves the
    current policies and meta-strategies with ``torch.save``.

    Args:
        env: RL environment; ``env.game`` is handed to the solver.
        oracle: Best-response oracle handed to the solver.
        agents: Initial policies, passed as ``initial_policies``.
        game: Not used by this function; kept so existing callers keep working.
        save_every: Checkpoint period in iterations. Defaults to 10, the
            previously hard-coded value. The final iteration is always
            checkpointed too, so runs whose length is not a multiple of
            ``save_every`` (or is shorter than it) no longer lose their result.

    Raises:
        ValueError: If ``save_every`` is smaller than 1.
    """
    if save_every < 1:
        raise ValueError("save_every must be >= 1, got {}".format(save_every))

    print("Game : {}".format(FLAGS.game_name))
    print("Seed: {}".format(FLAGS.seed))

    sample_from_marginals = True
    # Fall back to the probabilistic selector when the flag is set to "".
    training_strategy_selector = (FLAGS.training_strategy_selector
                                  or strategy_selectors.probabilistic_strategy_selector)

    # define psro solver
    g_psro_solver = psro_v2.PSROSolver(env.game,
                                       oracle,
                                       initial_policies=agents,
                                       training_strategy_selector=training_strategy_selector,
                                       rectifier=FLAGS.rectifier,
                                       sims_per_entry=FLAGS.sims_per_entry,
                                       number_policies_selected=FLAGS.number_policies_selected,
                                       meta_strategy_method=FLAGS.meta_strategy_method,
                                       prd_iterations=FLAGS.prd_iterations,
                                       prd_gamma=1e-10,
                                       sample_from_marginals=sample_from_marginals,
                                       symmetric_game=FLAGS.symmetric_game)

    # run psro iterations
    start_time = time.time()
    final_iteration = FLAGS.gpsro_iterations - 1
    for gpsro_iteration in range(FLAGS.gpsro_iterations):
        print("Iteration : {}".format(gpsro_iteration))
        print("Time so far: {}".format(time.time() - start_time))

        # NOTE(review): presumably one best-response utility per player; the
        # sum is what ends up after "nash_conv_" in the checkpoint name.
        # Confirm against prso_oracle.PSROSolver.iteration.
        best_utility = g_psro_solver.iteration()
        exploitabilities = sum(best_utility)
        print(exploitabilities)

        periodic_save = (gpsro_iteration + 1) % save_every == 0
        if not (periodic_save or gpsro_iteration == final_iteration):
            continue

        meta_probabilities = g_psro_solver.get_meta_strategies()
        policies = g_psro_solver.get_policies()
        # Store only the wrapped `_policy` objects, grouped per player.
        policy_list = [[member._policy for member in player_policies]
                       for player_policies in policies]
        # The file name carries the 0-based iteration index (9, 19, 29 with
        # the default flags), unchanged from the previous behaviour.
        torch.save([policy_list, copy.deepcopy(meta_probabilities)],
                   get_result_dir(gpsro_iteration, exploitabilities))

# # compute exploitability for policies
#     # policy_list = []
#     # for i, p in enumerate(policies):
#     #     policy_list.append(psro_policy([p[i]._policy for i in range(len(p))], meta_probabilities[i], env.action_spec()["num_actions"]))
#
#     # aggregator = policy_aggregator.PolicyAggregator(env.game)
#     # aggr_policies = aggregator.aggregate(range(FLAGS.n_players), policies, meta_probabilities)
#     # aggr_policies = tabular_policy_from_callable(game, policy_list, device=FLAGS.device)
#     #
#     # exploitabilities, _ = exploitability.nash_conv(env.game, aggr_policies, return_only_nash_conv=False)
#     # print(exploitabilities)
#     # if best_exploi > exploitabilities:
#     #     # save policy
#     #     best_exploi = exploitabilities
#         best_policy_list = [[p[i]._policy for i in range(len(p))] for p in policies]
#         best_meta_probabilities = copy.deepcopy(meta_probabilities)
#
#     torch.save([best_policy_list, best_meta_probabilities], get_result_dir(FLAGS.gpsro_iterations, best_exploi))


def main(argv):
    """Seed the run, build the game environment and launch GPSRO training.

    Args:
        argv: Positional command-line arguments; anything beyond the program
            name is rejected.

    Raises:
        app.UsageError: If extra positional arguments are supplied.
    """
    if argv[1:]:
        raise app.UsageError("Too many command-line arguments.")

    setup_seed(FLAGS.seed)

    # The game is loaded with the "reveal-nothing" observation type.
    # Variants tried previously: {"players": FLAGS.n_players} and
    # {"obstype": "reveal-numturns"}.
    game_params = {"obstype": "reveal-nothing"}
    game = pyspiel.load_game(FLAGS.game_name, game_params)

    env = rl_environment.Environment(game)
    env.seed(FLAGS.seed)

    # Build the oracle and the initial agents, then start training.
    best_response_oracle, initial_agents = init_dqn_responder(env)
    gpsro_looper(env, best_response_oracle, initial_agents, game)


# Script entry point: when executed directly (not imported), hand `main` to
# absl's app.run, which invokes it.
if __name__ == "__main__":
    app.run(main)
