# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""AlphaZero self-play gymnasium: generates training data via MCTS self-play."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from algorithms.abstract.bot import ZeroBot
from algorithms.abstract.gymnasium import Gymnasium
from algorithms.alpha_zero.game_history import AlphaZeroGameHistory
from algorithms.alpha_zero.utils import np_softmax
from algorithms.utils.types import AlphaZeroResult, SpielGame
from algorithms.utils.params import Params


class AlphaZeroGymnasium(Gymnasium):
    """AlphaZero self-play gymnasium.

    Follows the pseudocode AlphaZero implementation given in the paper
    DOI:10.1126/science.aar6404.
    """

    def __init__(self, game: SpielGame, bellman_bot: ZeroBot, params: Params):
        """AlphaZero constructor.

        Args:
          game: the game to self-play (a `SpielGame`, presumably a
            `pyspiel.Game` — it must provide `new_initial_state()` and
            `num_distinct_actions()`).
          bellman_bot: bot forwarded unchanged to `Gymnasium.__init__`.
          params: configuration forwarded to `Gymnasium.__init__`;
            `params.extractor.state_feature_extractor` is additionally stored
            and used to turn each decision state into training features.
        """
        Gymnasium.__init__(self, game, bellman_bot, params)
        self._feature_extractor = params.extractor.state_feature_extractor

    def self_play_single(self, bot: 'ZeroBot', with_nodes: bool = False):
        """Play one self-play game and return its training records.

        Chance nodes are resolved by sampling their outcome distribution.
        At every decision node `bot.search` is run. The visit counts of the
        root's children become the policy target, and the move to play is
        sampled via `_select_action`. Once the game ends, each recorded
        decision receives, as its value target, the terminal reward of the
        player who acted there.

        Args:
          bot: bot whose `search(state)` returns a root node whose `children`
            expose `action` and `explore_count`.
          with_nodes: if True, keep each search root node in the stored
            results; otherwise store None so search trees are not retained.

        Returns:
          An `AlphaZeroGameHistory` holding one `AlphaZeroResult` per
          decision node, in play order.
        """
        state = self._game.new_initial_state()
        game_history = AlphaZeroGameHistory()
        # One entry per decision node, kept in play order.
        state_features, policy_targets, player_to_act, nodes = [], [], [], []

        while not state.is_terminal():
            if state.is_chance_node():
                outcomes, probs = zip(*state.chance_outcomes())
                state.apply_action(np.random.choice(outcomes, p=probs))
                continue

            root_node = bot.search(state)
            player_to_act.append(state.current_player())
            state_features.append(self._feature_extractor(state))
            policy_targets.append(self._visit_count_policy(root_node.children))
            nodes.append(root_node if with_nodes else None)
            state.apply_action(self._select_action(root_node.children))
            # `steps` is presumably initialised by Gymnasium.__init__.
            self.steps += 1

        terminal_rewards = np.array(state.rewards(), dtype=np.float32)
        for feature, policy, player, node in zip(
                state_features, policy_targets, player_to_act, nodes):
            game_history.store(AlphaZeroResult(
                state_feature=feature,
                target_policy=policy,
                target_value=terminal_rewards[player],
                node=node))
        return game_history

    def _visit_count_policy(self, children):
        """Return the visit-count distribution over all distinct actions.

        Actions without a child keep probability 0. If the total visit count
        is 0, normalising would divide by zero and produce a NaN target.
        In that case a uniform distribution over the children's actions is
        returned instead.
        """
        policy = np.zeros(self._game.num_distinct_actions(), dtype=np.float32)
        for child in children:
            policy[child.action] = child.explore_count
        total = policy.sum()
        if total > 0:
            policy /= total
        elif len(children) > 0:
            for child in children:
                policy[child.action] = 1.0 / len(children)
        return policy

    @staticmethod
    def _select_action(children):
        """Sample an action from a softmax over the children's visit counts.

        NOTE(review): the AlphaZero paper samples moves proportionally to
        visit counts (optionally with a temperature). A softmax over raw
        counts is far sharper and is nearly greedy for large counts. If
        `np_softmax` is not max-shifted, large counts may also overflow.
        The current behaviour is kept; confirm that it is intended.
        """
        children = list(children)
        counts = np.array([child.explore_count for child in children], dtype=np.float64)
        probs = np_softmax(counts)
        index = np.random.choice(len(probs), p=probs)
        return children[index].action

