from typing import Optional

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.best_first_search.algo import BestFirstSearchAlgo
from llm_mcts.mcts_algo.hierarchical_thompson.algo import HierarchicalThompsonAlgo
from llm_mcts.mcts_algo.hierarchical_thompson.thompson_state import (
    GaussianPrior,
    PriorConfig,
)
from llm_mcts.mcts_algo.linucb.algo import LinUCBWithGenNodeAlgo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.multi_armed_bandit_ucb.algo import MultiArmedBanditUCBAlgo
from llm_mcts.mcts_algo.pymc_mixed.algo import PyMCMixedAlgo, PyMCMixedAlgoConfig
from llm_mcts.mcts_algo.score_funcs import ScoreFunc, UCTScore
from llm_mcts.mcts_algo.standard.algo import StandardAlgo
from llm_mcts.mcts_algo.tree_of_thoughts_bfs.algo import TreeOfThoughtsBFSAlgo


def build_algo(
    algo_name: str, config: MCTSConfig, score_func: Optional[ScoreFunc] = None
) -> MCTSAlgo:
    """Construct the MCTS algorithm registered under ``algo_name``.

    Name matching is case-insensitive for every supported algorithm.

    Supported names:
        - ``"standard"``: ``StandardAlgo`` driven by ``score_func``; when
          ``score_func`` is None a ``UCTScore()`` is used and a notice is
          printed.
        - ``"linucb"``: ``LinUCBWithGenNodeAlgo``.
        - ``"thompson"``: ``HierarchicalThompsonAlgo`` without an explicit
          ``prior_config`` (i.e. whatever that class defaults to).
        - ``"thompson-gaussian-tau-square-01"``: ``HierarchicalThompsonAlgo``
          with a Gaussian prior using ``tau_square=0.1``.
        - ``"thompson-beta"``: ``HierarchicalThompsonAlgo`` with
          ``PriorConfig(dist_type="beta")``.
        - ``"best-first-search"``: ``BestFirstSearchAlgo``.
        - ``"tot-bfs"``: ``TreeOfThoughtsBFSAlgo``.
        - ``"mab-ucb"``: ``MultiArmedBanditUCBAlgo``.
        - ``"pymc-thompson"``: ``PyMCMixedAlgo`` with
          ``PyMCMixedAlgoConfig(algo="thompson", enable_pruning=True)``.

    Args:
        algo_name: Identifier of the algorithm to build (case-insensitive).
        config: Search configuration passed to the selected algorithm's
            constructor.
        score_func: Node-scoring function. Only the ``"standard"`` algorithm
            uses it; it is ignored for every other name.

    Returns:
        A newly constructed algorithm instance.

    Raises:
        NotImplementedError: If ``algo_name`` matches none of the supported
            names.
    """
    # Normalize once so every name -- including "standard" -- is matched
    # case-insensitively and the lowering is not repeated per branch.
    name = algo_name.lower()

    if name == "standard":
        if score_func is None:
            score_func = UCTScore()
            print("score_func not specified, will use UCT score...")
        return StandardAlgo(config, score_func=score_func)
    if name == "linucb":
        return LinUCBWithGenNodeAlgo(config)
    if name == "thompson":
        return HierarchicalThompsonAlgo(config)
    if name == "thompson-gaussian-tau-square-01":
        return HierarchicalThompsonAlgo(
            config,
            prior_config=PriorConfig(
                dist_type="gaussian", prior=GaussianPrior(tau_square=0.1)
            ),
        )
    if name == "thompson-beta":
        return HierarchicalThompsonAlgo(
            config, prior_config=PriorConfig(dist_type="beta")
        )
    if name == "best-first-search":
        return BestFirstSearchAlgo(config)
    if name == "tot-bfs":
        return TreeOfThoughtsBFSAlgo(config)
    if name == "mab-ucb":
        return MultiArmedBanditUCBAlgo(config)
    if name == "pymc-thompson":
        return PyMCMixedAlgo(
            config,
            pymc_config=PyMCMixedAlgoConfig(algo="thompson", enable_pruning=True),
        )
    # Report the caller's original spelling, not the normalized form.
    raise NotImplementedError(f"algo_name {algo_name} not supported.")
