from __future__ import absolute_import, division, print_function
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import ray, logging, tensorflow as tf, numpy as np, sys

from absl import app

from algorithms.abstract import ZeroWorker
from algorithms.alpha_zero import AlphaZeroEvaluator, AlphaZeroReplayBuffer
from algorithms.alpha_zero.worker import AlphaZeroWorker
from algorithms.mu_zero.game_history import MuZeroGameHistory
from algorithms.mu_zero.worker import make_muzero_worker
from nannon.params import init_game_and_params
from algorithms.utils.types import SpielGame
from algorithms.utils.params import Params
from algorithms.services import eval, train, play, backup
from algorithms.mu_zero import MuZeroEvaluator, MuZeroReplayBuffer
from typing import List, Set

# Quiet TensorFlow's Python-side logging: raise the threshold of tf's logger to ERROR and
# also disable the 'tensorflow' logger obtained through the standard logging module.
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
logging.getLogger('tensorflow').disabled = True

# numpy globals: ignore floating-point divide-by-zero warnings, and print floats in
# fixed-point notation rather than scientific notation.
np.seterr(divide='ignore')
np.set_printoptions(suppress=True)


def make_workers(num_workers: int, num_cpus: int, num_gpus: float, game: SpielGame, params: Params) -> List[ZeroWorker]:
    """
    Makes ray remote workers.

    Args:
        num_workers: `int`, how many workers to create
        num_cpus: `int`, how many cpus to reserve per ray worker
        num_gpus: `float`, how many gpu fractions to reserve per ray worker
        game: `SpielGame`, an openspiel game object
        params: `Params`, a named tuple with all game-specific parameters; `params.algorithm`
            selects which kind of worker is built ('mu_zero' or 'alpha_zero')
    Returns:
        A list of `num_workers` workers, either MuZeroWorker or AlphaZeroWorker objects,
        with ids 0 .. num_workers - 1
    Raises:
        ValueError if `params.algorithm` is neither 'mu_zero' nor 'alpha_zero'
    """
    if params.algorithm == 'mu_zero':
        return [make_muzero_worker(id_,
                                   game,
                                   params,
                                   num_cpus=num_cpus,
                                   num_gpus=num_gpus)
                for id_ in range(num_workers)]
    if params.algorithm == 'alpha_zero':
        # AlphaZeroWorker is a ray actor class (it is instantiated via `.remote`). Attach the
        # requested per-worker resources with `.options`, as the MuZero path does through
        # make_muzero_worker; a bare `.remote(...)` would silently ignore num_cpus/num_gpus.
        worker_cls = AlphaZeroWorker.options(num_cpus=num_cpus, num_gpus=num_gpus)
        return [worker_cls.remote(id_, game, params) for id_ in range(num_workers)]
    raise ValueError('Invalid algorithm!')


def analyze(game: SpielGame, params: Params) -> None:
    """
    Function for analyzing results of MuZero training. Prints information from nodes that did
    not select the bellman action.

    Creates one MuZero worker (id 0) and runs 5 self-play games on it with search nodes
    recorded. Every recorded step whose node has a bellman action (not None and > -1) but whose
    taken action differs from it is printed: the node, its state string, its (action, prior)
    pairs, the bellman action, and the taken action. Finally, the number of such mismatches
    is printed.

    Args:
        game: `SpielGame`, an openspiel game object
        params: `Params`, a named tuple with all game-specific parameters
    """
    worker = make_muzero_worker(0, game, params)
    print('made worker')
    games = ray.get(worker.self_play.remote(5, with_nodes=True))  # type: List[MuZeroGameHistory]

    num_wrong = 0
    for g in games:
        for item in g.items:
            zero_node = item.node
            bellman_action = zero_node.bellman_action
            # Compare against None explicitly: a plain truthiness test would also skip bellman
            # action 0, which passes the `> -1` bound and is therefore a real action.
            has_bellman = bellman_action is not None and bellman_action > -1
            if has_bellman and item.action != bellman_action:
                zero_action = item.action
                print('----')
                print(zero_node)
                print(zero_node.state_string)
                print('priors:', list(zip(zero_node.child_actions, zero_node.child_priors)))
                print('bell:', bellman_action)
                print('zero:', zero_action)
                num_wrong += 1
    print('num_wrong:', num_wrong)


@ray.remote
def test_nodes() -> str:
    """
    Ray task for testing whether all members of the cluster are accessible to Ray.

    Each call returns the IP address of the node that executed it. Launching many of these
    tasks and collecting the results into a set shows which node IP addresses are in use.

    Returns:
        `str`, the IP address of the node this task ran on
    """
    import time
    # Short sleep so each task occupies a worker briefly; presumably this encourages many
    # concurrently launched tasks to spread across nodes (TODO confirm intent).
    time.sleep(.01)
    return ray.services.get_node_ip_address()


def main(_):
    """
    Entry point: builds the game and params, starts ray, and launches the requested service.

    The overall algorithm is broken into three main services, each of which is meant to be run
    on its own separate cluster: a training service, a self-play service, and an evaluation
    service. A backup service is also available. The `main` method should be called from a
    docker container on the head node of each cluster, with different arguments passed to
    select which service will be created.

    See algorithms/utils/params.py for the definition of the params named tuple and the type of
    each field. See nannon/params.py for an example of how the arguments are parsed into a
    params object.

    Two algorithms are supported: Nondeterministic MuZero and Nondeterministic AlphaZero.
    Setting the algorithm to 'analyze' instead runs `analyze` and returns without starting
    a service.

    Raises:
        ValueError if `params.algorithm` or `params.service` is not recognized
    """
    game, params = init_game_and_params()
    if params.local:
        print('starting local ray')
        ray.init(memory=5.3e8)
    else:
        print('starting ray in cloud')
        # Connect to an already running ray cluster.
        ray.init(address="auto")
    # Uncomment to list the node IP addresses reachable by ray:
    # print('IP addresses in use:', set(ray.get([test_nodes.remote() for _ in range(1000)])))

    if params.algorithm == 'mu_zero':
        evaluator = MuZeroEvaluator(params)
        replay_buffer = MuZeroReplayBuffer(params.buffer_capacity)
    elif params.algorithm == 'alpha_zero':
        evaluator = AlphaZeroEvaluator(params)
        replay_buffer = AlphaZeroReplayBuffer(params.buffer_capacity)
    elif params.algorithm == 'analyze':
        analyze(game, params)
        return
    else:
        raise ValueError('Invalid algorithm')

    # Each worker-based service reads its own worker count and per-worker resources from params.
    if params.service == 'train':
        workers = make_workers(params.num_train_workers,
                               params.num_train_cpus,
                               params.num_train_gpus,
                               game,
                               params)
        service = train.TrainService(evaluator, workers, replay_buffer, params)
    elif params.service == 'eval':
        workers = make_workers(params.num_eval_workers,
                               params.num_eval_cpus,
                               params.num_eval_gpus,
                               game,
                               params)
        service = eval.EvaluationService(evaluator, workers, params)
    elif params.service == 'play':
        workers = make_workers(params.num_play_workers,
                               params.num_play_cpus,
                               params.num_play_gpus,
                               game,
                               params)
        service = play.SelfPlayService(evaluator, workers, params)
    elif params.service == 'backup':
        # No workers are created for the backup service; ray is shut down before it starts.
        ray.shutdown()
        service = backup.BackupService(replay_buffer, params)
    else:
        raise ValueError('Invalid service')
    try:
        service.start()
    except KeyboardInterrupt:
        ray.shutdown()
        print('finished with early cancellation.')
        sys.exit()


if __name__ == '__main__':
    # Hand control to absl's app runner; it calls main with an argument that main ignores.
    app.run(main)
