from abc import ABC
from math import sqrt
from typing import Optional

from pydantic import BaseModel

from llm_mcts.mcts_algo.node import Node


class UCTConfig(BaseModel):
    """Configuration for the UCT child-scoring function.

    Attributes:
        ucb_c: Constant multiplier applied to the exploration bonus.
    """

    ucb_c: float = 1.25


class ScoreFunc(ABC):
    """Interface for scoring a child node relative to its parent.

    Concrete scorers override :meth:`score`; this base version only raises.
    """

    def score(self, parent: Node, child: Node) -> float:
        """Return the score of ``child`` considered as a candidate under ``parent``.

        Raises:
            NotImplementedError: Always, because subclasses must provide it.
        """
        raise NotImplementedError


class UCTScore(ScoreFunc):
    """
    UCT score with a per-child prior (PUCT-style).

    Balances exploitation and exploration. Exploitation is the child's current
    value estimate (``child.value()``). Exploration is a confidence-bound bonus
    scaled by a constant and by the child's prior:

        score = value(child) + ucb_c * prior(child) * sqrt(N(parent)) / (1 + N(child))

    where ``N`` is a node's visit count. The constant factor comes from UCT
    (UCB applied to trees); the prior factor follows, e.g., the AlphaGo/AlphaZero
    pseudocode.
    """

    def __init__(self, config: Optional[UCTConfig] = None):
        # Fall back to the default configuration when none is supplied.
        self.config = config if config is not None else UCTConfig()

    def score(
        self,
        parent: Node,
        child: Node,
    ) -> float:
        """Return the UCT score of ``child`` as a candidate under ``parent``.

        Args:
            parent: Node whose children are being compared; only its
                ``visit_count`` is read.
            child: Candidate child; its ``value()``, ``prior`` and
                ``visit_count`` are read, so it must not be ``None``.

        Returns:
            The exploitation term plus the prior-weighted exploration bonus.
        """
        exploitation = child.value()
        # ``1 + visit_count`` keeps the bonus finite for a child that has never
        # been visited; such a child receives the undiminished bonus.
        exploration = (
            self.config.ucb_c
            * child.prior
            * sqrt(parent.visit_count)
            / (1 + child.visit_count)
        )
        return exploitation + exploration
