from __future__ import absolute_import, division, print_function
import pyspiel
from typing import Tuple
from algorithms.utils.types import SpielGame
from algorithms.utils.params import Params
from absl import flags

from nannon.extractor import NannonExtractor

# Command-line flags for a training run. They are registered on the global absl FLAGS object at
# import time and read back in init_game_and_params().
FLAGS = flags.FLAGS

# --- Algorithm / search ---
flags.DEFINE_string("algorithm", 'mu_zero', 'Algorithm to be used, can be "mu_zero" or "alpha_zero"')
flags.DEFINE_integer("num_rounds", 100, "The number of rounds of self-play followed by neural net training.")
flags.DEFINE_string("self_play_agent", 'zero', 'Agent passed to produce self-play games')
flags.DEFINE_integer("num_simulations", 100, "The number of simulations to play per MCTS.")

# --- Layer activations / initializers ---
flags.DEFINE_string("hidden_layer_activation", "selu", "Activation function for hidden layers")
flags.DEFINE_string("hidden_layer_initializer", "lecun_normal", "Initialization method for hidden layer weights")
flags.DEFINE_string("hidden_state_activation", "selu", "Activation function for hidden state output")
flags.DEFINE_string("hidden_state_initializer", "lecun_normal", "Initialization method for hidden state weights")
flags.DEFINE_boolean("use_batchnorm", False, "Whether to use batch normalization while training")
flags.DEFINE_boolean("scale_gradient", False, "Whether to scale gradient and shrink hidden states in training")

# --- Network architecture, weights and device ---
# NOTE(review): load_weights_name is a string flag whose default is the literal string "None", yet its
# help text reads like a boolean; confirm how consumers interpret the value before changing either.
flags.DEFINE_string("load_weights_name", "None", 'Whether to load pretrained model weights.')
flags.DEFINE_boolean("save_weights", False, 'Whether to persist model weights while training')
flags.DEFINE_string("net_type", "dense", "The type of network to use. Can be either 'dense' or 'conv'.")
flags.DEFINE_string("device", "cpu", "Device to evaluate neural nets on. Can be 'cpu', 'tpu' or 'gpu'.")
flags.DEFINE_integer("num_hidden_weights", 256, "Number of hidden weights per layer")
flags.DEFINE_integer("num_hidden_layers", 2, "Number of layers in the neural net")

# --- Nannon game configuration ---
flags.DEFINE_integer("num_points", 6, "Number of points on the Nannon board")
flags.DEFINE_integer("num_chex", 3, "Number of checkers per Nannon player")
flags.DEFINE_integer("num_die", 6, "Number of sides on the die for Nannon")
flags.DEFINE_integer("k", 6, "Number of lookahead for MuZero")

# --- Cluster / service control ---
flags.DEFINE_string("service", "train", 'Which cluster operation to perform.')
flags.DEFINE_boolean("local", False, 'Whether we are running locally or on the cloud.')
flags.DEFINE_boolean("delete_queues", True, 'Whether to delete the queues')

# --- Training workers and optimizer ---
flags.DEFINE_string("train_mode", "parallel", 'Whether to use workers in training')
flags.DEFINE_integer("num_epochs", 1, "The number of passes over the replay buffer done during training.")
flags.DEFINE_integer("num_train_workers", 1, "Number of training workers for parallelization")
flags.DEFINE_float('num_train_cpus', 1, 'Number of CPUs assigned per worker during network updates')
flags.DEFINE_float('num_train_gpus', 1, 'Number of GPUs assigned per worker during network updates')
flags.DEFINE_integer("batch_size", 512, "Number of transitions to sample at each learning step.")
flags.DEFINE_float("learning_rate", 1e-3, "Learning rate for optimizer.")
flags.DEFINE_float("l2_regularization", 1e-4, "L2 regularization term.")

# --- Self-play workers and replay buffer ---
flags.DEFINE_integer("num_play_workers", 300, "Number of self-play workers for parallelization")
flags.DEFINE_float('num_play_cpus', 1, 'Number of CPUs assigned per worker during self-play')
flags.DEFINE_float('num_play_gpus', 0, 'Number of GPUs assigned per worker during self-play')
flags.DEFINE_integer("num_self_play", 300, "The number of self-play games to play in a round.")
flags.DEFINE_integer("buffer_capacity", 4000, "The size of the replay buffer.")

# --- Evaluation workers and test counts ---
# The CPU/GPU help strings here previously said "during self-play" (copied from the block above).
flags.DEFINE_integer("num_eval_workers", 500, "Number of evaluation workers for parallelization")
flags.DEFINE_float('num_eval_cpus', 1, 'Number of CPUs assigned per worker during evaluation')
flags.DEFINE_float('num_eval_gpus', 0, 'Number of GPUs assigned per worker during evaluation')
flags.DEFINE_integer("num_eval_dyn", 500, "Number of rounds of dynamics tests to perform")
flags.DEFINE_integer("num_eval_skill", 500, "Number of rounds of skill tests to perform")

# --- Run identity (embedded in the queue and output file names built from these flags) ---
flags.DEFINE_string('id', '', 'Unique id associated with this training run')


def init_game_and_params() -> Tuple[SpielGame, Params]:
    """
    Builds the openspiel Nannon game and the Params object for a run from the parsed absl flags.

    Flags that Params stores under their own name are copied across unchanged. Everything else is
    computed here: tensor shapes derived from the board size, player ids, the pass action, AMQP queue
    names and output filenames (all suffixed with the game configuration and the run id).
    Per the original notes, Params is a namedtuple intended to be serialized and sent to ray workers.
    Args: None
    Returns:
        A (game, params) pair: the openspiel game object and the populated Params.
    """
    n_points, n_chex, n_die = FLAGS.num_points, FLAGS.num_chex, FLAGS.num_die

    game_config = {
        'n_points': pyspiel.GameParameter(n_points),
        'n_chex': pyspiel.GameParameter(n_chex),
        'n_die': pyspiel.GameParameter(n_die),
    }
    game = pyspiel.load_game('nannon', game_config)  # type: SpielGame

    num_actions = game.num_distinct_actions()
    pass_action = n_points + 1
    # The extractor's final positional argument receives the same value as Params.pass_action.
    extractor = NannonExtractor(game, n_points, n_chex, n_die, num_actions, pass_action)

    # Every network shape is expressed in terms of the board width (num_points + 2).
    width = n_points + 2
    state_shape = (4, 1, width)

    # Values substituted into each queue / file name template below.
    run_key = (n_points, n_chex, n_die, FLAGS.id)

    # Flags whose Params field has exactly the same name as the flag itself.
    mirrored_flags = (
        'num_points', 'num_chex', 'num_die',
        'num_hidden_weights', 'num_hidden_layers',
        'hidden_layer_activation', 'hidden_layer_initializer',
        'hidden_state_activation', 'hidden_state_initializer',
        'use_batchnorm', 'scale_gradient',
        'load_weights_name', 'learning_rate', 'l2_regularization',
        'num_simulations', 'device', 'k',
        'train_mode', 'num_epochs', 'batch_size',
        'num_rounds', 'num_self_play', 'self_play_agent',
        'num_eval_skill', 'num_eval_dyn',
        'algorithm', 'local', 'buffer_capacity', 'service', 'delete_queues',
        'num_train_workers', 'num_play_workers', 'num_eval_workers',
        'num_train_cpus', 'num_train_gpus',
        'num_play_cpus', 'num_play_gpus',
        'num_eval_cpus', 'num_eval_gpus',
    )
    fields = {name: getattr(FLAGS, name) for name in mirrored_flags}

    # Fields derived from the game object and the board size.
    fields.update(
        image_width=width,
        num_actions=num_actions,
        num_players=game.num_players() + 2,  # plus chance + terminal
        chance_player_id=2,
        terminal_player_id=3,
        pass_action=pass_action,
        rep_input_shape=state_shape,
        rep_output_shape=state_shape,
        pred_input_shape=state_shape,
        # One more leading plane than the representation input (5 vs 4).
        dyn_input_shape=(5, 1, width),
        dyn_output_shape=state_shape,
        flat_input_size=4 * width,
        extractor=extractor,
    )

    # Queue names and output filenames, unique per game configuration and run id.
    fields.update(
        train_queue_name='train_in_{}_{}_{}_{}'.format(*run_key),
        play_queue_name='play_in_{}_{}_{}_{}'.format(*run_key),
        eval_queue_name='eval_in_{}_{}_{}_{}'.format(*run_key),
        backup_weights_queue='backup_weights_{}_{}_{}_{}'.format(*run_key),
        backup_buffer_queue='backup_buffer_{}_{}_{}_{}'.format(*run_key),
        backup_log_queue='backup_log_{}_{}_{}_{}'.format(*run_key),
        log_name='logs_{}_{}_{}_{}.csv'.format(*run_key),
        models_name='models_{}_{}_{}_{}.p'.format(*run_key),
    )

    return game, Params(**fields)
