from typing import List, Optional

from tqdm import tqdm

from llm_mcts.llm_generation_interface import GenerationRequest, Message, Model
from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.algo_builder import build_algo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_algo.score_funcs import UCTScore
from llm_mcts.mcts_algo.solver.llm_solver import LLMSolver
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.prompts.base import PromptTemplate
from llm_mcts.prompts.prompt_builder import build_prompt_template
from llm_mcts.tasks.base import Task


def run_mcts(
    task: Task,
    model: Model,
    scorer: MCTSScorer,
    mcts_config: Optional[MCTSConfig] = None,
    prompt_config: Optional[PromptConfig] = None,
    fewshot_prompts: Optional[List[Message]] = None,
    mcts_algo: Optional[MCTSAlgo] = None,
    prompt_template: Optional[PromptTemplate] = None,
) -> MCTSResult:
    """Run an MCTS search for ``task`` and return the result built from the root node.

    Args:
        task: Task handed to the prompt-template builder and to the solver.
        model: Model handed to the ``LLMSolver``.
        scorer: Scorer forwarded to every ``run_mcts_step`` call.
        mcts_config: Search configuration; a default ``MCTSConfig()`` is
            created when omitted.
        prompt_config: Prompt configuration; ``PromptConfig(is_o1=False)`` is
            created when omitted. Only consulted when ``prompt_template`` is
            not supplied.
        fewshot_prompts: Optional messages placed in front of the initial
            user message.
        mcts_algo: Search algorithm; when omitted, the ``"standard"``
            algorithm is built with ``mcts_config`` and a ``UCTScore``.
        prompt_template: Ready-made template; when omitted, one is built from
            ``prompt_config`` and ``task``.

    Returns:
        An ``MCTSResult`` wrapping the root node after all steps have run.
    """
    # Fill in defaults for whatever the caller left unspecified.
    mcts_config = MCTSConfig() if mcts_config is None else mcts_config
    prompt_config = (
        PromptConfig(is_o1=False) if prompt_config is None else prompt_config
    )
    mcts_algo = (
        build_algo("standard", config=mcts_config, score_func=UCTScore())
        if mcts_algo is None
        else mcts_algo
    )

    if prompt_template is not None:
        print(
            "prompt_template is given as an argument to run_mcts, skipping building prompt from prompt_config..."
        )
    else:
        print("Building prompt template from prompt_config...")
        prompt_template = build_prompt_template(prompt_config=prompt_config, task=task)

    # Conversation sent with the root request: optional few-shot messages
    # first, then the template's initial prompt as a user message.
    opening_message = Message(role="user", content=prompt_template.initial_prompt())
    conversation = [opening_message]
    if fewshot_prompts is not None:
        conversation = fewshot_prompts + conversation

    root_request = GenerationRequest(messages=conversation)
    root = Node(serial_number=0, next_prompt=root_request)
    solver = LLMSolver(model, prompt_template, task=task)

    # A configured count of 0 still runs a single step.
    step_count = mcts_config.num_simulations
    if step_count == 0:
        step_count = 1

    # Every step starts from the root node.
    for _ in tqdm(range(step_count)):
        mcts_algo.run_mcts_step(root, solver, scorer=scorer)

    return MCTSResult(root)
