from __future__ import annotations
import numpy as np
from algorithms.abstract import ZeroBot
from algorithms.mu_zero import MuZeroEvaluator
from algorithms.mu_zero.node import MuZeroNode
from algorithms.utils.types import NodeType, ChoicePolicy, SpielAction
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from algorithms.utils.types import SpielGame, SpielState
    from algorithms.utils.params import Params
    from typing import List, Tuple


class MuZeroBot(ZeroBot):
    """Bot that uses NDMZ variant of the Monte-Carlo Tree Search algorithm."""

    CHANCE_PLAYER_ID = None
    TERMINAL_PLAYER_ID = None

    def __init__(self,
                 game: SpielGame,
                 evaluator: MuZeroEvaluator,
                 params: Params,
                 uct_c: float = 1.0,
                 random_state=None,
                 child_selection_fn=MuZeroNode.puct_value,
                 dirichlet_noise: Tuple[float, float] = (0.25, 1.),
                 verbose=False):
        """Set up a bot that runs the NDMZ variant of MCTS.

        Note that this constructor overwrites the class-level
        `CHANCE_PLAYER_ID` and `TERMINAL_PLAYER_ID` attributes, so the values
        are shared by every `MuZeroBot` instance in the process.

        NOTE(review): the previous docstring stated that in multiplayer or
        non-zero-sum games the players play the greedy strategy; nothing in
        this class shows that, so confirm before relying on it.

        Args:
          game: A pyspiel.Game to play. Its `max_utility()` is cached on the bot.
          evaluator: The `MuZeroEvaluator` the search queries for hidden
            states, priors and values.
          params: Project parameters. `num_simulations`, `num_actions`,
            `chance_player_id` and `terminal_player_id` are read here; the
            object is also forwarded to `ZeroBot.__init__`.
          uct_c: The exploration constant handed to `child_selection_fn`.
          random_state: An optional numpy RandomState. When omitted (or
            falsy) a fresh, unseeded `np.random.RandomState()` is created.
          child_selection_fn: The function used to score children of a choice
            node during descent. The default is `MuZeroNode.puct_value`.
          dirichlet_noise: A tuple of (epsilon, alpha) for adding dirichlet
            noise to the policy at the root, as in the alpha-zero paper. A
            falsy value disables the noise.
          verbose: Forwarded to `ZeroBot.__init__` and stored on the bot.
        """
        ZeroBot.__init__(self, game, params, verbose)

        # Collaborators and search budget.
        self.evaluator = evaluator
        self.uct_c = uct_c
        self.max_simulations = params.num_simulations
        self.verbose = verbose
        self.max_utility = game.max_utility()

        # Root noise configuration and the source of randomness for it.
        self._dirichlet_noise = dirichlet_noise
        if random_state:
            self._random_state = random_state
        else:
            self._random_state = np.random.RandomState()

        self._child_selection_fn = child_selection_fn
        self._num_actions = params.num_actions

        # Stored on the class (not the instance): `expand` reads them from there.
        MuZeroBot.CHANCE_PLAYER_ID = params.chance_player_id
        MuZeroBot.TERMINAL_PLAYER_ID = params.terminal_player_id

    def search(self, state: SpielState) -> MuZeroNode:
        """
        Build a search tree for `state` and return its root node.

        The root is created and expanded first. The loop then runs with a budget of
        `max_simulations - 1` (presumably the root expansion counts as the first simulation).
        Every pass selects a path with `_tree_select` and hands it to `_backup`. A pass only
        consumes budget when the last node of its path is a choice node or a terminal node;
        passes that end on any other node type are repeated without counting.
        """
        root = MuZeroNode(None, None, 1)
        self.expand_root(root, state)

        remaining = self.max_simulations - 1
        while remaining > 0:
            path = self._tree_select(root, state)
            self._backup(path)
            leaf_type = path[-1].type
            if leaf_type is NodeType.CHOICE or leaf_type is NodeType.TERMINAL:
                remaining -= 1
        return root

    def expand_root(self, root: MuZeroNode, state: SpielState) -> None:
        """
        Turn `root` into an expanded choice node for `state`.

        The root has no tau (identity) parent, so its "action" is set to the state's current player,
        standing in for the action the missing tau node would have carried. The evaluator produces
        the root hidden state and, given the legal actions and their mask, the root value, child
        actions and child priors. Dirichlet noise is mixed into the priors used to build the
        children when noise is configured (`root.child_priors` itself is left as predicted).
        Finally the root's own value is added to its total and its visit count is set to one.
        """
        evaluator = self.evaluator

        root.type = NodeType.CHOICE
        root.hidden_state = evaluator.root_representation(state)
        root.action = state.current_player()

        legal = state.legal_actions()
        legal_mask = np.array(state.legal_actions_mask())
        value, child_actions, child_priors = evaluator.root_prediction(root.hidden_state, legal, legal_mask)
        root.value = value
        root.child_actions = child_actions
        root.child_priors = child_priors
        root.state_string = str(state)

        pairs = [(action, prior) for action, prior in zip(root.child_actions, root.child_priors)]
        if self._dirichlet_noise:
            pairs = self.add_noise(pairs)

        children = []
        for action, prior in pairs:
            children.append(MuZeroNode(root, action, prior))
        root.children = children

        # The root expansion counts as the first visit of the root.
        root.total_value += root.value
        root.explore_count = 1

    @staticmethod
    def _backup(visit_path: List[MuZeroNode]) -> None:
        """
        Propagate the result of one simulation back along `visit_path`.

        The backup for NDMZ differs from MuZero in a few respects. We want to back up the value
        of the choice nodes, since those are what are trained against in terms of game outcome.

        The value and the player it belongs to are taken from the last node of the path:

        * choice leaf: its own `value`; the player is the leaf's `action`.
        * terminal leaf: its own `value`; the player is the `action` of the leaf's grandparent.
        * any other leaf: the simulation is incomplete, so a value of 0 is backed up and the
          visit counts are left unchanged.

        Player ids are compared by value (`==`). Identity comparison (`is`) on integers only
        works by accident of CPython's small-int caching and fails for numpy integers.

        Args:
          visit_path: Non-empty list of nodes from the root to the leaf reached by the simulation.
        """
        last_node = visit_path[-1]
        if last_node.type is NodeType.CHOICE:
            value = last_node.value
            player = last_node.action
            explore = 1
        elif last_node.type is NodeType.TERMINAL:
            value = last_node.value
            player = last_node.parent.parent.action
            explore = 1
        else:
            value = 0
            player = -1
            explore = 0

        for node in reversed(visit_path):
            # If the node is an identity node, and the parent is a choice node,
            # we want to change the value in order to influence PUCT choice of the node
            # in future simulations. If the player is the same, we increase the value,
            # otherwise we decrease the value.
            if node.type is NodeType.TAU and node.parent.type is NodeType.CHOICE:
                if node.parent.action == player:
                    node.total_value += value
                else:
                    node.total_value -= value
            # This does not affect search, but is useful for logging the search's overall
            # appraisal of the choice node.
            elif node.type is NodeType.CHOICE:
                if node.action == player:
                    node.total_value += value
                else:
                    node.total_value -= value
            node.explore_count += explore

    def _tree_select(self, root: MuZeroNode, state: SpielState) -> List[MuZeroNode]:
        """
        Descend from `root` until an unexpanded or terminal node is reached.

        A clone of `state` is advanced alongside the descent so that newly expanded nodes can
        record the matching game state. The real state is only advanced for chance and choice
        selections; selecting a child of a tau (identity) node does not apply an action. Because
        the tree is built from learned predictions, a selected action may be illegal in the real
        game: once applying an action fails, the state is replaced by the string
        'INVALID STATE' for the rest of the descent and no further actions are applied.

        Chance and identity children are sampled with the bot's own `_random_state`, so a seeded
        bot is reproducible.

        Args:
          root: The expanded root of the search tree.
          state: The game state the root corresponds to; it is cloned, not modified.

        Returns:
          The nodes visited, from the root to the leaf. If the leaf was unexpanded it has been
          expanded before returning.

        Raises:
          Exception: If a node with an unrecognised type is encountered.
        """
        state = state.clone()  # type: SpielState
        visit_path = [root]
        current_node = root
        is_valid = True
        while current_node.type is not NodeType.TERMINAL:
            node_type = current_node.type
            if node_type is NodeType.UNKNOWN:
                # If we do not know the type of the node, that means that it is unexpanded
                self.expand(current_node, state)
                return visit_path

            if node_type is NodeType.CHANCE or node_type is NodeType.TAU:
                # If it is a chance node or an identity node, then we take roulette wheel selection
                # of either the chance action or the player identity based on the priors taken
                # from the prediction network
                action = self._random_state.choice(current_node.child_actions, p=current_node.child_priors)
                chosen_child = next(c for c in current_node.children if c.action == action)
                # Picking a player identity does not change the real game state.
                advances_game = node_type is NodeType.CHANCE
            elif node_type is NodeType.CHOICE:
                # We select the child of a choice node based on PUCT
                parent_visits = current_node.explore_count
                chosen_child = max(
                    current_node.children,
                    key=lambda c: self._child_selection_fn(c, parent_visits, self.uct_c),
                )  # type: MuZeroNode
                advances_game = True
            else:
                raise Exception('Invalid node type!')

            if is_valid and advances_game:
                try:
                    state.apply_action(chosen_child.action)
                except Exception:
                    # The predicted action is not playable in the real game; keep searching
                    # in the learned model only. `Exception` (rather than a bare except)
                    # lets KeyboardInterrupt and SystemExit propagate.
                    is_valid = False
                    state = 'INVALID STATE'

            current_node = chosen_child
            visit_path.append(current_node)

        return visit_path

    def expand(self, node: MuZeroNode, state: SpielState) -> None:
        """
        Expand an unexpanded `node`: assign its type, priors, value and children.

        What the node becomes depends on the type of its parent:

        * parent is a choice or chance node: the node is a tau (identity) node. The dynamics
          network maps the parent's hidden state and the node's action to a new hidden state and
          the identity priors; the prediction network then yields chance priors, choice priors
          and a value for that hidden state.
        * parent is a tau node: the node's action is a player identity. It becomes a chance node,
          a terminal node or a choice node, reusing the parent's hidden state and the priors and
          value the parent already computed. A terminal node has no children and a value of 1.

        Player identities are compared by value (`==`); identity comparison (`is`) on integers
        only works by accident of CPython's small-int caching.

        NOTE(review): a parent of any other type leaves the node's type and priors untouched;
        the search as written should never request that.

        Args:
          node: The node to expand; it must have a parent.
          state: The game state (or the 'INVALID STATE' marker) reached at this node; only its
            string form is stored.
        """
        parent = node.parent
        node.state_string = str(state)
        if parent.type is NodeType.CHOICE or parent.type is NodeType.CHANCE:
            # This will always be an identity node
            node.type = NodeType.TAU
            # An identity node will always have 4 children, player 1, player 2, chance, or terminal
            node.child_actions = list(range(4))
            # Input to the dynamics network is the hidden state of the parent, prior to the chance or choice move,
            # and the action performed by the chance or choice node.
            # We perform a dynamics evaluation to produce the next hidden state and identity priors
            node.hidden_state, node.child_priors = self.evaluator.dynamics(parent.hidden_state, node.action)
            # We perform a prediction evaluation on the hidden state produced by the dynamics network.
            # This will produce the priors for the chance child and choice children,
            # as well as the value of the node for PUCT
            node.chance_priors, node.choice_priors, node.value = self.evaluator.prediction(node.hidden_state)
        elif parent.type is NodeType.TAU:
            # Choosing an identity does not advance the model, so the hidden state is shared
            node.hidden_state = parent.hidden_state
            node.child_actions = list(range(self._num_actions))
            if node.action == MuZeroBot.CHANCE_PLAYER_ID:
                node.type = NodeType.CHANCE
                node.child_priors = parent.chance_priors
                node.value = parent.value
            elif node.action == MuZeroBot.TERMINAL_PLAYER_ID:
                # A terminal node has no children or child priors
                # It will always have a value of 1
                node.child_actions, node.child_priors = [], []
                node.type = NodeType.TERMINAL
                node.value = 1
            else:
                node.type = NodeType.CHOICE
                node.child_priors = parent.choice_priors
                node.value = parent.value

        # Child priors are indexed by action here.
        priors = node.child_priors
        action_priors = [(action, priors[action]) for action in node.child_actions]
        node.children = [MuZeroNode(node, action, prior) for action, prior in action_priors]

    def add_noise(self, action_priors):
        """Mix Dirichlet noise into a list of (action, prior) pairs.

        One noise vector with a component per pair is drawn from the bot's
        random state, using concentration `alpha` for every component. Each
        prior becomes `(1 - epsilon) * prior + epsilon * noise`, where
        `(epsilon, alpha)` is the configured `_dirichlet_noise` tuple.

        Args:
          action_priors: List of (action, prior) tuples.

        Returns:
          A new list of (action, noisy_prior) tuples in the same order.
        """
        epsilon, alpha = self._dirichlet_noise
        concentration = [alpha] * len(action_priors)
        noise = self._random_state.dirichlet(concentration)

        keep = 1 - epsilon
        mixed = []
        for (action, prior), sample in zip(action_priors, noise):
            mixed.append((action, keep * prior + epsilon * sample))
        return mixed

    def step(self, state: SpielState):
        """Run a search from `state` and return the action of the root's best child."""
        search_root = self.search(state)
        return search_root.best_child().action

    def action_and_policy(self, state: SpielState):
        """Return the (action, policy) part of `action_policy_and_node`, dropping the root node."""
        chosen_action, policy_target, _ = self.action_policy_and_node(state)
        return chosen_action, policy_target

    def action_policy_and_node(self, state: SpielState) -> Tuple[SpielAction, ChoicePolicy, MuZeroNode]:
        """Choose an action for `state` and return it with its policy target and search root.

        When exactly one action is legal no search is run: that action is
        returned with a deterministic policy target and an unexpanded
        placeholder root node. Otherwise a full search is run and the action is
        sampled in proportion to the root children's visit counts.

        The action is sampled with the bot's own `_random_state` (so a seeded
        bot is reproducible) and converted to a plain `int`, matching the type
        returned by the single-legal-action branch.

        Args:
          state: The game state to act in.

        Returns:
          A tuple `(action, choice_policy_target, root_node)`.
        """
        legal_actions = state.legal_actions()
        if len(legal_actions) == 1:
            # Forced move: skip the search entirely.
            player_action = legal_actions[0]
            choice_policy_target = self._extractor.choice_policy_deterministic(player_action)
            root_node = MuZeroNode(None, None, 1)
            return player_action, choice_policy_target, root_node

        root_node = self.search(state)
        probs_list = self.get_probs_from_root_node(root_node)
        # Index i of probs_list is the probability of action i.
        player_action = int(self._random_state.choice(len(probs_list), p=probs_list))
        choice_policy_target = self._extractor.choice_policy(probs_list)
        return player_action, choice_policy_target, root_node

    def get_probs_from_root_node(self, root_node: MuZeroNode) -> ChoicePolicy:
        """Convert the root children's visit counts into a policy over all actions.

        The result has one entry per distinct action of the game; entry `a` is
        the share of visits received by the root child with action `a`, and
        actions without a child get probability 0.

        If no child has been visited yet (for example when the simulation
        budget is too small for any simulation to run), the visit counts carry
        no information, so the policy falls back to a uniform distribution over
        the root's children instead of dividing by zero.

        Args:
          root_node: The root of a search tree.

        Returns:
          A list of probabilities indexed by action.

        Raises:
          ZeroDivisionError: If the root has no visited children and no
            children at all.
        """
        num_actions = self._game.num_distinct_actions()
        visit_counts = [0] * num_actions
        for child in root_node.children:
            visit_counts[child.action] = child.explore_count

        total_visits = sum(visit_counts)
        if total_visits == 0:
            child_actions = {child.action for child in root_node.children}
            # Still raises ZeroDivisionError for a childless root, as before.
            uniform = 1 / len(child_actions)
            probs = [0.0] * num_actions
            for action in child_actions:
                probs[action] = uniform
            return probs

        return [count / total_visits for count in visit_counts]
