import math
import random
import numpy as np


class SearchDomain:
    """Problem definition consumed by ``MCTS``.

    Subclasses override every hook below except :meth:`log`; the base
    implementations simply raise ``NotImplementedError``. States are used as
    dictionary keys by the search, so they must be hashable.
    """

    def get_initial_state(self):
        """Return the state the search tree is rooted at."""
        raise NotImplementedError()

    def get_valid_actions(self, state):
        """Return the actions that can be applied to ``state``.

        The search stores the returned object as a node's untried actions and
        consumes it with ``pop()``, so a fresh mutable list is expected.
        """
        raise NotImplementedError()

    def apply_action(self, state, action):
        """Return the successor state obtained by applying ``action`` to ``state``."""
        raise NotImplementedError()

    def is_terminal(self, state):
        """Return whether ``state`` is final (no further actions are taken)."""
        raise NotImplementedError()

    def evaluate(self, state):
        """Return the reward of final ``state``; the search keeps the maximum."""
        raise NotImplementedError()

    def rollout_action(self, state):
        """Return the action to take from ``state`` during a random rollout."""
        raise NotImplementedError()

    def log(self, node, final_state, reward):
        """Observe one search step; the default implementation does nothing."""


class MCTSNode:
    """One node of the search tree; all fields are mutated by ``MCTS``.

    Attributes:
        state: Domain state represented by this node.
        parent: Parent node, or ``None`` for the root.
        children: Child nodes expanded so far.
        N: Visit count.
        Q: Best reward backed up through this node, divided by the search's
            current scale.
        untried_actions: Actions of ``state`` not yet expanded into children.
        reward: Best raw reward backed up through this node.
    """

    def __init__(self, state, parent=None):
        self.state, self.parent = state, parent
        self.children = []
        self.N = 0
        # Q and reward start at 0 (not -inf), so max-backups never drop below 0.
        self.Q = 0
        self.untried_actions = []
        self.reward = 0

    def __repr__(self):
        # Same as f'{self.state}': format with an empty spec.
        return format(self.state, '')


class MCTS:
    """Monte Carlo Tree Search over a ``SearchDomain`` using max-backup.

    Each node records the best reward ever backed up through it:
    ``node.reward`` holds it raw and ``node.Q`` holds it divided by
    ``self.scale``. Whenever a new best *positive* reward is found, ``scale``
    is set to it and every ``Q`` in the tree is renormalised, so ``Q`` values
    stay relative to the best reward seen so far.

    NOTE(review): nodes start with ``Q == reward == 0`` and back up with
    ``max``, so negative rewards are effectively clamped to 0 in the tree
    (and in the ``_evaluate`` shortcut). Rewards are presumably meant to be
    non-negative -- confirm with the domains in use.
    """

    def __init__(self, domain: 'SearchDomain', exploration=1 / np.sqrt(2), *,
                 iterations=None, random_select_prob=0.2, num_rollouts=1):
        """Build a tree rooted at ``domain.get_initial_state()``.

        Args:
            domain: Problem definition. States are used as dict keys, so they
                must be hashable.
            exploration: Exploration constant of the UCB score.
            iterations: Default number of steps run by :meth:`search`.
            random_select_prob: Probability of descending to a uniformly random
                child instead of the best-UCB child during selection.
            num_rollouts: Random rollouts per simulation; the best is kept.

        Raises:
            ValueError: if ``num_rollouts`` is smaller than 1.
        """
        if num_rollouts < 1:
            raise ValueError(f'num_rollouts must be at least 1, got {num_rollouts}')
        self.domain = domain
        self.root = MCTSNode(self.domain.get_initial_state())
        self.root.untried_actions = self.domain.get_valid_actions(self.root.state)
        # Latest expanded node per state; _evaluate uses it to skip
        # re-evaluating final states that are already visited tree nodes.
        self.nodes = {self.root.state: self.root}

        if iterations is not None:
            # Only set when given, so a class-level ``iterations`` still applies.
            self.iterations = iterations
        self.current_iter = 0
        self.num_eval = 0  # number of calls made to domain.evaluate
        self.exploration = exploration
        self.random_select_prob = random_select_prob
        self.num_rollouts = num_rollouts

        # Divisor applied to rewards before they are stored in node.Q.
        self.scale = 1

        self.best_state = None
        self.best_reward = -float('inf')

    def _update_scale(self):
        """Renormalise every node's Q so rewards are divided by ``best_reward``.

        The tree is walked from the root instead of iterating ``self.nodes``:
        when two paths reach the same state, ``self.nodes`` only keeps the
        most recently expanded node, and the shadowed one would otherwise
        keep a Q expressed on the old scale.
        """
        new_scale = self.best_reward
        # A zero best reward would divide by zero, and a negative one would
        # flip the ordering of all Q values, so only rescale on positive ones.
        if new_scale <= 0:
            return
        factor = self.scale / new_scale
        stack = [self.root]
        while stack:
            node = stack.pop()
            node.Q *= factor
            stack.extend(node.children)
        self.scale = new_scale

    def search(self, iterations=None):
        """Run several search steps and return the best final state found.

        Each iteration is one :meth:`step`, so ``best_state``/``best_reward``
        and the Q scale are kept up to date, and ``current_iter`` counts steps.

        Args:
            iterations: Number of steps; defaults to ``self.iterations``.

        Returns:
            The highest-reward final state seen so far (``None`` if none has
            scored above ``-inf``).

        Raises:
            ValueError: if no iteration count is available.
        """
        if iterations is None:
            iterations = getattr(self, 'iterations', None)
        if iterations is None:
            raise ValueError('iterations must be passed to search() or MCTS()')
        for _ in range(iterations):
            self.step()
        return self.best_state

    def step(self):
        """Run one select / simulate / backpropagate iteration.

        Returns:
            ``(node, final_state, reward)``: the selected (or newly expanded)
            node, the final state its simulation reached, and that reward.
        """
        self.current_iter += 1
        node = self._select(self.root)
        final_state, reward = self._simulate(node)
        # Backpropagate on the current scale; _update_scale then rescales
        # every Q, including those just written, by the same factor.
        self._backpropagate(node, reward)
        if reward > self.best_reward:
            self.best_reward = reward
            self.best_state = final_state
            self._update_scale()
        self.domain.log(node, final_state, reward)
        return node, final_state, reward

    def _select(self, node):
        """Descend from ``node`` to the node to simulate from.

        Returns a freshly expanded child as soon as a node with untried
        actions is met, or the terminal node reached. Otherwise, with
        probability ``random_select_prob`` it moves to a random child, else
        to the child with the highest UCB score.
        """
        while not self.domain.is_terminal(node.state):
            if node.untried_actions:
                return self._expand(node)
            if random.random() < self.random_select_prob:
                node = random.choice(node.children)
            else:
                node = max(node.children, key=self._ucb_score)
        return node

    def _expand(self, node):
        """Create, register and return the child for one untried action of ``node``."""
        action = node.untried_actions.pop()
        new_state = self.domain.apply_action(node.state, action)
        child = MCTSNode(new_state, parent=node)
        child.untried_actions = self.domain.get_valid_actions(new_state)
        node.children.append(child)

        self.nodes[child.state] = child
        return child

    def _simulate(self, node):
        """Play out from ``node`` and return ``(final_state, reward)``.

        Terminal nodes are evaluated directly. Otherwise ``num_rollouts``
        rollouts driven by ``domain.rollout_action`` are played and the one
        with the highest reward is returned.
        """
        if self.domain.is_terminal(node.state):
            return node.state, self._evaluate(node.state)

        best_state = None
        best_reward = -float('inf')
        for _ in range(self.num_rollouts):
            state = node.state
            while not self.domain.is_terminal(state):
                action = self.domain.rollout_action(state)
                state = self.domain.apply_action(state, action)

            reward = self._evaluate(state)
            if reward > best_reward:
                best_state, best_reward = state, reward

        return best_state, best_reward

    def _backpropagate(self, node, reward):
        """Propagate ``reward`` from ``node`` to the root using max-backup."""
        while node is not None:
            node.N += 1
            node.Q = max(node.Q, reward / self.scale)
            node.reward = max(node.reward, reward)
            node = node.parent

    def _ucb_score(self, node):
        """UCB-style score of child ``node``; unvisited children score +inf.

        NOTE(review): ``node.Q`` is a max-backed-up value, yet it is still
        divided by ``node.N`` as in mean-based UCB1, so the exploitation term
        shrinks with every visit. Confirm whether ``node.Q`` alone is intended.
        """
        if node.N == 0:
            return float('inf')
        exploit = node.Q / node.N
        explore = self.exploration * math.sqrt(math.log(node.parent.N) / node.N)
        return exploit + explore

    def _evaluate(self, state):
        """Return the reward of final ``state``.

        If ``state`` is an already-visited tree node, its stored best reward
        is reused; otherwise ``domain.evaluate`` is called and ``num_eval``
        is incremented.
        """
        node = self.nodes.get(state)
        if node is not None and node.N > 0:
            return node.reward
        self.num_eval += 1
        return self.domain.evaluate(state)
