from typing import List, Optional

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResult


class Node:
    """A node of the search tree that accumulates visit statistics.

    Each node stores the request used to generate its children (``next_prompt``),
    optionally the completion/action that led to it, and running totals
    (``visit_count``, ``value_sum``) that feed the exploitation term of the
    UCT score via :meth:`value`.
    """

    def __init__(
        self,
        serial_number: int,
        next_prompt: GenerationRequest,
        llm_name: Optional[str] = None,
        completion: Optional[GenerationResult] = None,
        parent: Optional["Node"] = None,
        last_action: Optional[Action] = None,
        eval_results: Optional[List[EvalResult]] = None,
    ):
        """Create a node with zero visits and no children.

        Args:
            serial_number: Identifier of this node (presumably unique within a
                tree -- not enforced here).
            next_prompt: Generation request stored on this node.
            llm_name: Optional name of the LLM associated with this node.
            completion: Optional generation result stored on this node.
            parent: Parent node, or ``None`` for the root.
            last_action: Optional action stored on this node.
            eval_results: Optional list of evaluation results for this node.
        """
        self.parent = parent
        self.serial_number = serial_number
        self.next_prompt = next_prompt
        self.last_action = last_action
        self.eval_results = eval_results
        self.llm_name = llm_name
        self.completion = completion
        # Running statistics updated by backpropagate().
        self.visit_count = 0
        self.value_sum = 0.0
        # Populated by the caller; this class never appends to it itself.
        self.children: List["Node"] = list()

        # Unset until assigned through the ``prior`` property setter.
        self._prior: Optional[float] = None

    def __repr__(self) -> str:
        # The parent is deliberately omitted so repr() never walks the tree.
        return (
            f"{type(self).__name__}(serial_number={self.serial_number}, "
            f"llm_name={self.llm_name!r}, visit_count={self.visit_count}, "
            f"value_sum={self.value_sum}, children={len(self.children)})"
        )

    def is_expanded(self) -> bool:
        """Return True iff at least one child has been attached to this node."""
        return len(self.children) > 0

    def value(self) -> float:
        """Exploitation term of UCT score: the mean backpropagated value.

        Returns 0.0 for an unvisited node. The result lies in [0, 1] only if
        every value passed to :meth:`backpropagate` does; this is expected by
        UCT normalization but not enforced here.
        """
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count

    def backpropagate(self, value: float) -> None:
        """Add ``value`` to this node and every ancestor, counting one visit each.

        Recursion (rather than a loop) is kept so that each ancestor's own
        ``backpropagate`` is dispatched, preserving any subclass override.
        Recursion depth equals the depth of this node in the tree.
        """
        self.value_sum += value
        self.visit_count += 1

        if self.parent is not None:
            self.parent.backpropagate(value)

    def is_root(self) -> bool:
        """Return True iff this node has no parent."""
        return self.parent is None

    @property
    def prior(self) -> float:
        """Prior score assigned to this node via the setter.

        Raises:
            AssertionError: If the prior has not been set yet.
        """
        if self._prior is None:
            # Explicit raise instead of ``assert``: asserts are stripped under
            # ``python -O``, which would silently return None here. The
            # AssertionError type is kept for compatibility with the old check.
            raise AssertionError(
                f"prior has not been set for node {self.serial_number}"
            )
        return self._prior

    @prior.setter
    def prior(self, prior: float) -> None:
        self._prior = prior
