import math

from algorithms.abstract.node import ZeroNode


class AlphaZeroNode(ZeroNode):
    """A node in the AlphaZero search tree.

    An AlphaZeroNode represents a state and the possible continuations from
    it. Each child represents a possible action and the expected result of
    taking it.

    Attributes:
      action: The action from the parent node's perspective. Not important for
        the root node, as the actions that lead to it are in the past.
      player: Which player made this action.
      prior: A prior probability for how likely this action will be selected.
      explore_count: How many times this node was explored.
      total_value: The sum of rewards of rollouts through this node, from the
        parent node's perspective. The average reward of this node is
        `total_value / explore_count`.
      outcome: The rewards for all players if this is a terminal node or the
        subtree has been proven, otherwise None. Indexed by player.
      children: A list of AlphaZeroNodes representing the possible actions
        from this node, along with their expected rewards.
      bellman_action: Declared as a slot but not assigned in this class's
        `__init__`; presumably set by search code elsewhere (TODO confirm).
    """
    __slots__ = [
        "action",
        "player",
        "prior",
        "explore_count",
        "total_value",
        "outcome",
        "children",
        "bellman_action",
    ]

    def __init__(self, action, player, prior):
        # NOTE(review): ZeroNode.__init__ is assumed to initialise `action`,
        # `prior`, `explore_count` and `total_value`; confirm against
        # algorithms.abstract.node.
        ZeroNode.__init__(self, action, prior)
        self.player = player
        self.outcome = None
        self.children = []

    def uct_value(self, parent_explore_count, uct_c):
        """Returns the UCT (UCB1) value of this node as a child.

        Args:
          parent_explore_count: Visit count of the parent node.
          uct_c: Exploration constant.

        Returns:
          The proven reward for `self.player` if the subtree is proven;
          infinity if the node has never been visited (so every child is tried
          at least once); otherwise the mean value plus the UCB1 exploration
          bonus `uct_c * sqrt(ln(parent_explore_count) / explore_count)`.
        """
        if self.outcome is not None:
            return self.outcome[self.player]

        if self.explore_count == 0:
            return float("inf")

        return self.total_value / self.explore_count + uct_c * math.sqrt(
            math.log(parent_explore_count) / self.explore_count)

    def puct_value(self, parent_explore_count, uct_c):
        """Returns the PUCT value of this node as a child.

        Args:
          parent_explore_count: Visit count of the parent node.
          uct_c: Exploration constant.

        Returns:
          The proven reward for `self.player` if the subtree is proven;
          otherwise `Q + uct_c * prior * sqrt(parent_explore_count) /
          (explore_count + 1)`, where Q is the mean value (0 for an unvisited
          node).
        """
        if self.outcome is not None:
            return self.outcome[self.player]

        # `explore_count and ...` yields 0 for an unvisited node, avoiding a
        # division by zero while giving it a neutral value estimate.
        return ((self.explore_count and self.total_value / self.explore_count) +
                uct_c * self.prior * math.sqrt(parent_explore_count) /
                (self.explore_count + 1))

    def sort_key(self):
        """Returns a tuple key used to rank children when choosing an action.

        The key is `(proven reward for self.player, or 0 if unproven,
        explore_count, total_value)`, compared lexicographically. With `max`,
        this ordering leads to choosing:
        - Highest proven score > 0 over anything else, including a promising
          but unproven action.
        - A proven draw only if it has higher exploration than others that are
          uncertain, or the others are losses.
        - Uncertain action with most exploration over a loss of any difficulty.
        - The most-explored loss if everything is a loss.
        - Highest total value if explore counts are equal (unlikely).
        """
        return (0 if self.outcome is None else self.outcome[self.player],
                self.explore_count, self.total_value)

    def best_child(self):
        """Returns the best child in order of the sort key.

        Raises:
          ValueError: If this node has no children (from the builtin `max`).
        """
        return max(self.children, key=AlphaZeroNode.sort_key)

    def children_str(self, state=None):
        """Returns the string rep of this node's children.

        They are ordered based on the sort key, best first, i.e. in order of
        being chosen to play.

        Args:
          state: A `pyspiel.State` object, to be used to convert the action id
            into a human readable format. If None, the action integer id is
            used.
        """
        return "\n".join([
            c.to_str(state)
            for c in reversed(sorted(self.children, key=AlphaZeroNode.sort_key))
        ])

    def to_str(self, state=None):
        """Returns the string rep of this node.

        Args:
          state: A `pyspiel.State` object, to be used to convert the action id
            into a human readable format. If None, the action integer id is
            used.
        """
        action = (
            state.action_to_string(state.current_player(), self.action)
            if state and self.action is not None else str(self.action))
        # Test `outcome` against None (as the other methods do) rather than by
        # truthiness, which is ambiguous for array-like outcomes.
        outcome_str = ("{:4.1f}".format(self.outcome[self.player])
                       if self.outcome is not None else "none")
        return ("{:>6}: player: {}, prior: {:5.3f}, value: {:6.3f}, sims: {:5d}, "
                "outcome: {}, {:3d} children").format(
                        action,
                        self.player,
                        self.prior,
                        self.explore_count and self.total_value / self.explore_count,
                        self.explore_count,
                        outcome_str,
                        len(self.children))

    def __str__(self):
        """Returns the string rep of this node without action-name lookup."""
        return self.to_str(None)
