import numpy as np, time, joblib
from abc import ABC, abstractmethod
from algorithms.abstract import Evaluator, ZeroBot, Gymnasium, ReplayBuffer
from algorithms.abstract.game_history import GameHistory
from algorithms.utils.types import SpielGame, LossValues, ModelWeights, DynamicsTestLog
from algorithms.utils.params import Params
from typing import Tuple, List


class ZeroWorker(ABC):
    """Abstract worker bundling an evaluator, three bots and a gymnasium.

    A worker can receive/return model weights, train its evaluator on samples
    drawn from a replay buffer, generate self-play game histories and
    evaluate the zero bot against the random and bellman bots.  Subclasses
    must implement `eval_dynamics`.
    """

    def __init__(self,
                 i: int,
                 game: SpielGame,
                 evaluator: Evaluator,
                 bots: Tuple[ZeroBot, ZeroBot, ZeroBot],
                 gymnasium: Gymnasium,
                 params: Params):
        """Store the collaborators and create a per-worker random generator.

        Args:
            i: index of this worker; stored as the worker id and used to
                derive the RNG seed.
            game: game object used to create initial states for evaluation.
            evaluator: object whose weights are synchronised and trained.
            bots: the `(zero, bellman, random)` bots, in that order.
            gymnasium: object used to play single self-play games.
            params: configuration; `k` and `self_play_agent` are read from it.
        """
        # Wall-clock seconds floor-divided by (i + 1), so that workers with
        # different indices generally end up with different seeds.
        self._rng = np.random.RandomState(int(time.time()) // (i + 1))
        self._id = i
        self._game = game
        # Forwarded as `k` to both replay-buffer sampling and evaluator updates.
        self._k = params.k
        self._self_play_agent = params.self_play_agent
        self._zero_bot, self._bellman_bot, self._random_bot = bots
        self._evaluator = evaluator
        self._gymnasium = gymnasium
        # Stays None until `set_replay_buffer` is called.
        self._replay_buffer = None  # type: ReplayBuffer

    def _require_replay_buffer(self):
        """Return the replay buffer, failing loudly when none has been set.

        Raises:
            RuntimeError: if `set_replay_buffer` has not been called yet.
        """
        # Compare against None explicitly: an empty buffer could be falsy.
        if self._replay_buffer is None:
            raise RuntimeError(
                'Replay buffer is not set; call set_replay_buffer() first.')
        return self._replay_buffer

    def set_weights(self, model_weights: ModelWeights):
        """Load `model_weights` into the evaluator and return `'ok'`.

        NOTE(review): the `'ok'` return value presumably serves as an
        acknowledgement for the caller -- confirm against call sites.
        """
        self._evaluator.set_weights(model_weights)
        return 'ok'

    def set_replay_buffer(self, replay_buffer) -> None:
        """Attach the replay buffer used by training and gradient methods."""
        self._replay_buffer = replay_buffer

    def update_replay_buffer(self, game_histories: List[GameHistory]):
        """Add every game history in `game_histories` to the replay buffer.

        Raises:
            RuntimeError: if no replay buffer has been set.
        """
        replay_buffer = self._require_replay_buffer()
        for game_history in game_histories:
            replay_buffer.add(game_history)

    def get_weights(self) -> ModelWeights:
        """Return the evaluator's current weights."""
        return self._evaluator.get_weights()

    def train(self, num_epoch_iters: int, batch_size: int, lr: float):
        """Run `num_epoch_iters` evaluator updates on sampled batches.

        A new Adam optimizer is created on every call, so optimizer state is
        not carried over between calls.

        Args:
            num_epoch_iters: number of sample/update iterations to run.
            batch_size: batch size passed to the replay buffer's `sample`.
            lr: learning rate for the Adam optimizer.

        Returns:
            A `(weights, losses)` tuple: the evaluator weights after training
            and the list of loss values, one entry per iteration.

        Raises:
            RuntimeError: if no replay buffer has been set.
        """
        replay_buffer = self._require_replay_buffer()
        import os
        # Must be set before tensorflow is imported to quiet its C++ logging.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        all_losses = []  # type: List[LossValues]
        for _ in range(num_epoch_iters):
            train_data = replay_buffer.sample(batch_size, k=self._k)
            # The gradients are not needed here; only the losses are kept.
            _, losses = self._evaluator.update(train_data, k=self._k, optimizer=optimizer)
            all_losses.append(losses)
        return self.get_weights(), all_losses

    def get_grads(self, batch_size: int):
        """Sample one batch and return the evaluator's `(grads, losses)`.

        The evaluator is called with `optimizer=None`.
        NOTE(review): presumably this makes the evaluator compute gradients
        without applying them -- confirm in the evaluator implementation.

        Raises:
            RuntimeError: if no replay buffer has been set.
        """
        train_data = self._require_replay_buffer().sample(batch_size, k=self._k)
        grads, losses = self._evaluator.update(train_data, k=self._k, optimizer=None)
        return grads, losses

    def self_play(self, num_self_play_games=5000, agent='zero', with_nodes=False):
        """Play full self-play games with the selected bot.

        Args:
            num_self_play_games: the number of self-play games to play.
            agent: which agent is to play: `'zero'`, `'random'` or `'bellman'`.
            with_nodes: forwarded to the gymnasium's `self_play_single`;
                whether to return the search nodes along with the game
                history.

        Returns:
            A list with one entry per game, each being whatever the
            gymnasium's `self_play_single` returned.

        Raises:
            ValueError: if `agent` is not one of the supported names.
        """
        if agent == 'zero':
            bot = self._zero_bot
        elif agent == 'random':
            bot = self._random_bot
        elif agent == 'bellman':
            bot = self._bellman_bot
        else:
            raise ValueError('Invalid agent passed to self-play.')

        return [self._gymnasium.self_play_single(bot, with_nodes=with_nodes)
                for _ in range(num_self_play_games)]

    def evaluate(self, num_eval_skill: int, num_eval_dyn: int) -> Tuple[int, int, DynamicsTestLog]:
        """Evaluate the zero bot's skill and the subclass-defined dynamics test.

        Args:
            num_eval_skill: number of games per seat against each adversary.
            num_eval_dyn: iteration count passed to `eval_dynamics`.

        Returns:
            `(wins vs. random bot, wins vs. bellman bot, dynamics results)`.
        """
        rand_wins = self.eval_skill(num_eval_skill, self._random_bot)
        bell_wins = self.eval_skill(num_eval_skill, self._bellman_bot)
        dyn_results = self.eval_dynamics(num_eval_dyn)
        return rand_wins, bell_wins, dyn_results

    def eval_skill(self, num_evals: int, adversary_bot: ZeroBot) -> int:
        """Count the zero bot's wins against `adversary_bot` from both seats.

        `num_evals` games are played with the zero bot moving first and
        another `num_evals` with it moving second, so the result lies in
        `[0, 2 * num_evals]`.  Draws are not counted as wins.
        """
        wins_first, _ = self.evaluate_bots([self._zero_bot, adversary_bot],
                                           num_evaluations=num_evals)
        # With the zero bot in the second seat, its wins are the first
        # seat's losses.
        _, wins_second = self.evaluate_bots([adversary_bot, self._zero_bot],
                                            num_evaluations=num_evals)
        return wins_first + wins_second

    def _play_single_game(self, bots: List[ZeroBot]):
        """Play one game between `bots` and return the terminal returns."""
        state = self._game.new_initial_state()
        while not state.is_terminal():
            if state.is_chance_node():
                # Sample the chance outcome with this worker's own RNG.
                outcomes, probs = zip(*state.chance_outcomes())
                action = self._rng.choice(outcomes, p=probs)
            else:
                legal_actions = state.legal_actions()
                if len(legal_actions) == 1:
                    # Forced move: no need to query the bot.
                    action = legal_actions[0]
                else:
                    action = bots[state.current_player()].step(state)
            state.apply_action(action)
        return state.returns()

    def evaluate_bots(self, bots: List[ZeroBot], num_evaluations: int) -> Tuple[int, int]:
        """Play `bots[0]` against `bots[1]` and count the first bot's results.

        A game counts as a win when the first player's return truncates to 1
        and as a loss when the second player's return truncates to 1.  Any
        other outcome (e.g. a draw) is counted as neither, so
        `wins + losses` may be smaller than `num_evaluations`.

        Args:
            bots: the bots, indexed by the player id they control.
            num_evaluations: number of games to play.

        Returns:
            `(wins, losses)` from the point of view of `bots[0]`.
        """
        wins = 0
        losses = 0
        for _ in range(num_evaluations):
            returns = self._play_single_game(bots)
            if int(returns[0]) == 1:
                wins += 1
            elif int(returns[1]) == 1:
                losses += 1
        return wins, losses

    @abstractmethod
    def eval_dynamics(self, num_iterations: int) -> DynamicsTestLog:
        """Run the dynamics test; must be implemented by subclasses."""
        raise NotImplementedError



