"""MuZero Bot implemented in TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from algorithms.abstract import ZeroBot
from algorithms.abstract.gymnasium import Gymnasium
from algorithms.mu_zero import MuZeroBot
from algorithms.mu_zero.game_history import MuZeroGameHistory
from algorithms.mu_zero.types import MuZeroHistoryItem
from algorithms.mu_zero.node import MuZeroNode
from algorithms.utils.types import SpielGame, SpielState
from algorithms.utils.params import Params


class MuZeroGymnasium(Gymnasium):
    """Gymnasium specialisation that records MuZero self-play histories."""

    def __init__(self,
                 game: SpielGame,
                 bellman_bot: ZeroBot,
                 params: Params):
        """
        Initialise the base Gymnasium and cache the self-play parameters.

        Args:
          game: a pyspiel.Game object
          bellman_bot: a ZeroBot object, forwarded to Gymnasium.__init__.
          params: a Params named tuple; image_width, num_players,
            chance_player_id and pass_action are read from it.
        """
        Gymnasium.__init__(self, game, bellman_bot, params)
        self._image_width = params.image_width
        self._num_players = params.num_players
        self._chance_player_id = params.chance_player_id
        self._pass_action = params.pass_action

    def _chance_step_fields(self, state: SpielState) -> dict:
        """Sample a chance outcome and build the history fields for it."""
        helper = self._helper
        outcomes, probs = zip(*state.chance_outcomes())
        sampled = np.random.choice(outcomes, p=probs)
        tau = helper.tau_policy_deterministic(self._chance_player_id)
        # The chance target encodes the sampled outcome only; the alternative
        # helper.chance_policy(probs) is not used.
        chance = helper.chance_policy_deterministic(sampled)
        choice = helper.choice_policy_deterministic(self._pass_action)
        feature = helper.state_feature_extractor(state)
        image = helper.action_to_image(sampled)
        return dict(
            action=sampled,
            tau_target=tau,
            action_image=image,
            choice_target=choice,
            chance_target=chance,
            node=MuZeroNode(None, -1, 1),
            state_feature=feature)

    def _bot_step_fields(self, state: SpielState, bot: MuZeroBot, with_nodes: bool) -> dict:
        """Let ``bot`` pick the move at a non-chance node and build its history fields."""
        helper = self._helper
        tau = helper.tau_policy_deterministic(state.current_player())
        chance = helper.chance_policy_deterministic(self._pass_action)
        chosen, choice, search_node = bot.action_policy_and_node(state)
        feature = helper.state_feature_extractor(state)
        image = helper.action_to_image(chosen)
        if not with_nodes:
            # The bot's node is dropped and a placeholder is stored instead.
            search_node = MuZeroNode(None, -1, 1)
        elif len(state.legal_actions()) > 1:
            search_node.bellman_action = self._bellman_bot.step(state)
        else:
            # With at most one legal action the bellman bot is not consulted.
            search_node.bellman_action = chosen
        return dict(
            action=chosen,
            tau_target=tau,
            action_image=image,
            choice_target=choice,
            chance_target=chance,
            node=search_node,
            state_feature=feature)

    def self_play_single(self, bot: MuZeroBot, with_nodes: bool = False) -> MuZeroGameHistory:
        """
        Play one game from a fresh initial state and return its history.

        Chance nodes are resolved by sampling from ``state.chance_outcomes()``;
        every other node is played by ``bot``. Nothing is stored until the
        first non-chance node has been reached; from then on each step,
        chance nodes included, is stored. ``self.steps`` is incremented once
        per call.

        Args:
          bot: MuZeroBot whose ``action_policy_and_node`` picks the moves.
          with_nodes: when True, the node returned by the bot is kept and its
            ``bellman_action`` is set; when False, a placeholder MuZeroNode
            is stored instead.

        Returns:
          The MuZeroGameHistory, after ``update_for_board_games`` has been
          called with the terminal state's returns.
        """
        state = self._game.new_initial_state()  # type: SpielState
        helper = self._helper
        history = MuZeroGameHistory(helper.final_target(), helper.final_action_image())
        recording = False
        self.steps += 1

        while not state.is_terminal():
            if state.is_chance_node():
                fields = self._chance_step_fields(state)
            else:
                recording = True
                fields = self._bot_step_fields(state, bot, with_nodes)
            if recording:
                history.store(MuZeroHistoryItem(
                    state_string=str(state),
                    active_player=state.current_player(),
                    value=0,
                    **fields))
            state.apply_action(fields["action"])

        history.update_for_board_games(state.returns())
        return history

