from typing import Callable, Dict, List, Tuple

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import (
    GenerationRequest,
    GenerationResult,
    Message,
)
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.solver.base import AggregatedSolver
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.tasks.base import Task


class RandomScoreSolver(AggregatedSolver):
    """
    A toy solver which randomly picks scores at each step.

    Instead of querying a model, each "generation" is a fenced ``score`` block
    whose body is a number drawn from the random number generator registered
    under the currently selected solver name.
    """

    def __init__(
        self, random_number_generators: Dict[str, Callable[[], float]], task: Task
    ) -> None:
        """
        Args:
            random_number_generators: Mapping from solver name to a zero-argument
                callable producing a score. The keys, in insertion order, become
                the available solvers; the first key is selected initially.
            task: Task used to turn each generation into evaluation results.

        Raises:
            ValueError: If ``random_number_generators`` is empty.
        """
        if not random_number_generators:
            raise ValueError(
                "random_number_generators must contain at least one entry"
            )
        # Copy the mapping so later mutation by the caller cannot make
        # `self.rngs` disagree with the `self.solvers` snapshot below.
        self.rngs = dict(random_number_generators)
        self.solvers = list(self.rngs.keys())
        self.task = task

        self.current_solver = self.solvers[0]

    def generate_child_nodes(
        self,
        node: Node,
        kind: Action,
        num_samples: int,
        scorer: MCTSScorer,
        next_serial_number: int,
    ) -> Tuple[List[Node], int]:
        """
        Create ``num_samples`` children of ``node`` using the current solver.

        Each child is appended to ``node.children`` and receives a consecutive
        serial number starting at ``next_serial_number``.

        Args:
            node: Parent node; its ``next_prompt`` is used as the request.
            kind: Action passed to ``task.generate_eval_results`` and recorded
                as each child's ``last_action``.
            num_samples: Number of children to create.
            scorer: Not used by this solver (presumably required by the
                ``AggregatedSolver`` interface).
            next_serial_number: Serial number assigned to the first child.

        Returns:
            A tuple of the newly created children and the next unused serial
            number.
        """
        children: List[Node] = []
        for _ in range(num_samples):
            gen_result = self.generate_gen_result(next_prompt=node.next_prompt)
            eval_results = self.task.generate_eval_results(gen_result, kind=kind)

            child = Node(
                serial_number=next_serial_number,
                next_prompt=self.build_next_prompt(gen_result),
                llm_name=self.current_solver,
                parent=node,
                last_action=kind,
                eval_results=eval_results,
            )
            node.children.append(child)
            children.append(child)
            next_serial_number += 1

        return children, next_serial_number

    def generate_gen_result(self, next_prompt: GenerationRequest) -> GenerationResult:
        """
        Generate only the score using the current rng.

        Args:
            next_prompt: Request the result is attached to.

        Returns:
            A result whose generation is a fenced ``score`` block holding one
            freshly drawn value.
        """
        generation = f"```score\n{self.rngs[self.current_solver]()}\n```"

        return GenerationResult(request=next_prompt, generation=generation)

    def build_next_prompt(self, result: GenerationResult) -> GenerationRequest:
        """
        Build a new request: ``result``'s conversation followed by the generated
        text as an assistant turn.
        """
        # A list with a bare Message cannot be concatenated (TypeError), so
        # build a fresh list. This also leaves the parent's request unmodified.
        # Assumes `messages` is a sequence of Message objects.
        return GenerationRequest(
            messages=[
                *result.request.messages,
                Message(role="assistant", content=result.generation),
            ]
        )

    def get_solvers(self) -> List[str]:
        """Return the names of the available solvers (a copy, in registration order)."""
        return list(self.solvers)

    def set_solver(self, solver_name: str) -> None:
        """
        Select the random number generator used by subsequent generations.

        Raises:
            ValueError: If no generator is registered under ``solver_name``.
        """
        if solver_name not in self.rngs:
            raise ValueError(
                f"Unknown solver {solver_name!r}; expected one of {self.solvers}"
            )
        self.current_solver = solver_name
