import math
from algorithms.abstract.node import ZeroNode
from algorithms.utils.types import SpielAction, NodeType, HiddenState
from typing import Optional, Tuple, List


class MuZeroNode(ZeroNode):
    """A node in the MuZero search tree.

    Each node corresponds to the result of taking `action` from `parent`; its
    `children` are the continuations that have been added below it.

    Attributes:
      parent: The node this one hangs off, or None when no parent was given.
      action: The action from the parent node's perspective. Not important for the
        root node, as the actions that lead to it are in the past.
      prior: A prior probability for how likely this action will be selected.
      type: A `NodeType`; starts as `NodeType.UNKNOWN`.
      explore_count: How many times this node was explored. Not assigned in this
        class (presumably initialised by `ZeroNode.__init__` -- confirm).
      total_value: The sum of values accumulated through this node. The mean value
        is `total_value / explore_count`. Not assigned in this class (presumably
        initialised by `ZeroNode.__init__` -- confirm).
      hidden_state: Optional `HiddenState` attached to this node; starts as None.
      value: A per-node value, reported separately from the search mean by
        `to_str`; starts at 0.
      children: A list of MuZeroNodes representing the continuations from this
        node.
      child_actions: List of integer action ids; starts empty.
      child_priors: List of float priors; starts empty.
      prediction_output: Starts as None; filled in by code outside this class.
      state_string: Starts as None; printed verbatim by `to_str`.
      choice_priors, chance_priors, bellman_action: Declared in `__slots__` but
        not assigned by `__init__` here.
    """
    __slots__ = [
        'parent',
        'action',
        'prior',
        'type',
        'explore_count',
        'total_value',
        'hidden_state',
        'value',
        'children',
        'child_actions',
        'child_priors',
        'prediction_output',
        'choice_priors',
        'chance_priors',
        'state_string',
        'bellman_action'
    ]

    def __init__(self, parent: 'Optional[MuZeroNode]', action: SpielAction, prior: float):
        """Creates a node below `parent` that is reached via `action`.

        Args:
          parent: The parent node, or None.
          action: Forwarded to `ZeroNode.__init__`.
          prior: Forwarded to `ZeroNode.__init__`.
        """
        ZeroNode.__init__(self, action, prior)
        self.parent = parent
        self.type = NodeType.UNKNOWN
        self.value = 0  # type: float
        self.hidden_state = None  # type: Optional[HiddenState]
        self.children = []  # type: List[MuZeroNode]
        self.child_actions = []  # type: List[int]
        self.child_priors = []  # type: List[float]
        self.prediction_output = None
        self.state_string = None
        # NOTE(review): `choice_priors`, `chance_priors` and `bellman_action` are
        # slots that are not assigned here; unless ZeroNode.__init__ sets them,
        # reading one before it is assigned raises AttributeError.

    def _mean_value(self) -> float:
        """Returns `total_value / explore_count`, or 0.0 for an unexplored node."""
        if not self.explore_count:
            return 0.0
        return self.total_value / self.explore_count

    def uct_value(self, parent_explore_count: int, uct_c: float) -> float:
        """Returns the UCT value of child.

        Args:
          parent_explore_count: Explore count of the parent node. Must be positive
            whenever this node has been explored, since its logarithm is taken.
          uct_c: Exploration constant scaling the exploration bonus.

        Returns:
          `inf` for an unexplored node, otherwise the mean value plus
          `uct_c * sqrt(log(parent_explore_count) / explore_count)`.
        """
        if self.explore_count == 0:
            # Unexplored children always win the comparison.
            return float('inf')

        exploration = math.sqrt(math.log(parent_explore_count) / self.explore_count)
        return self._mean_value() + uct_c * exploration

    def puct_value(self, parent_explore_count: int, uct_c: float) -> float:
        """Returns the PUCT value of child.

        Args:
          parent_explore_count: Explore count of the parent node.
          uct_c: Exploration constant scaling the prior-weighted bonus.

        Returns:
          The mean value (0 when unexplored) plus
          `uct_c * prior * sqrt(parent_explore_count) / (explore_count + 1)`.
        """
        exploration = (uct_c * self.prior * math.sqrt(parent_explore_count) /
                       (self.explore_count + 1))
        return self._mean_value() + exploration

    def sort_key(self) -> Tuple[int, float]:
        """Returns the key used to rank this node against its siblings.

        The key is the tuple `(explore_count, total_value)`, so nodes are ordered
        by explore count first and by accumulated value when the counts are equal.
        """
        return self.explore_count, self.total_value

    def best_child(self) -> 'MuZeroNode':
        """Returns the best child in order of the sort key.

        Raises:
          ValueError: If this node has no children (raised by `max`).
        """
        return max(self.children, key=MuZeroNode.sort_key)

    def children_str(self, state=None) -> str:
        """Returns the string rep of this node's children.

        They are ordered based on the sort key, highest first, one per line.

        Args:
          state: A `pyspiel.State` object, to be used to convert the action id into
            a human readable format. If None, the action integer id is used.
        """
        # Keep reversed(sorted(...)) rather than sorted(..., reverse=True): the
        # two differ in how children with equal keys are ordered.
        ranked = reversed(sorted(self.children, key=MuZeroNode.sort_key))
        return "\n".join([child.to_str(state) for child in ranked])

    def to_str(self, state=None) -> str:
        """Returns the string rep of this node.

        The line contains the node type, action, prior, value, mean search value,
        explore count, number of children, the PUCT value computed with an
        exploration constant of 1.0 (0 when there is no parent) and the state
        string.

        Args:
          state: A `pyspiel.State` object, to be used to convert the action id into
            a human readable format. If None, the action integer id is used.
        """
        if state and self.action is not None:
            action = state.action_to_string(state.current_player(), self.action)
        else:
            action = str(self.action)

        # PUCT needs the parent's explore count, so a parentless node reports 0.
        if self.parent is not None:
            puct = self.puct_value(self.parent.explore_count, 1.0)
        else:
            puct = 0

        template = ('Type: {}, Action: {:>6}, prior: {:5.3f}, value: {:6.3f}, '
                    'mcts_value: {:6.3f}, sims: {:5d}, children: {:3d}, '
                    'pUCT: {:6.3f}, state: {} ')

        return template.format(self.type.name,
                               action,
                               self.prior,
                               self.value,
                               self._mean_value(),
                               self.explore_count,
                               len(self.children),
                               puct,
                               self.state_string)

    def __str__(self) -> str:
        """Returns `to_str()` without a state, so actions print as raw ids."""
        return self.to_str(None)
