from typing import Optional

from llm_mcts.mcts_algo.algo_base import MCTSAlgo
from llm_mcts.mcts_algo.hierarchical_thompson.algo import HierarchicalThompsonAlgo
from llm_mcts.mcts_algo.hierarchical_thompson.thompson_state import (
    GaussianPrior,
    PriorConfig,
)
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.pymc_mixed.algo import PyMCMixedAlgo, PyMCMixedAlgoConfig
from llm_mcts.mcts_algo.standard.algo import StandardAlgo


def build_reward_average_priors(
    answer_models_str: Optional[str], prior_str: Optional[str]
) -> Optional[float | dict]:
    if prior_str is None or answer_models_str is None:
        return None
    answer_models = answer_models_str.split(",")
    priors = list(map(float, prior_str.split(",")))
    if len(priors) == 1:
        return priors[0]
    assert len(answer_models) == len(
        priors
    ), f"The number of answer models {len(answer_models)} does not match the number of priors {len(priors)}"

    return {model: prior for model, prior in zip(answer_models, priors)}


def build_algo(algo_name: str, config: MCTSConfig, **kwargs) -> MCTSAlgo:
    """Construct the MCTS algorithm selected by ``algo_name``.

    The name is matched case-insensitively. Supported names:

    * ``"standard"``: ``StandardAlgo``.
    * ``"thompson"``: ``HierarchicalThompsonAlgo`` with its default prior
      configuration; ``multimodel_strategy`` is not forwarded.
    * ``"thompson-gaussian-tau-square-01"``: ``HierarchicalThompsonAlgo``
      with a Gaussian prior using ``tau_square=0.1``.
    * ``"thompson-beta"``: ``HierarchicalThompsonAlgo`` with a beta prior.
    * ``"pymc-thompson"``: ``PyMCMixedAlgo`` with pruning enabled.

    Args:
        algo_name: Name of the algorithm to build.
        config: MCTS configuration passed to the algorithm constructor.
        **kwargs: Optional settings read by the non-standard algorithms:
            ``priors`` and ``answer_models`` (comma-separated strings passed
            to ``build_reward_average_priors``) and ``multimodel_strategy``
            (forwarded as ``model_selection_strategy``).

    Returns:
        The constructed algorithm instance.

    Raises:
        NotImplementedError: If ``algo_name`` is not one of the names above.
    """
    # Normalise once so every name, including "standard", is matched
    # case-insensitively.
    name = algo_name.lower()

    if name == "standard":
        return StandardAlgo(config)

    priors = kwargs.get("priors")
    answer_models = kwargs.get("answer_models")
    strategy = kwargs.get("multimodel_strategy")

    if name == "thompson":
        reward_average_priors = build_reward_average_priors(answer_models, priors)
        return HierarchicalThompsonAlgo(
            config, reward_average_priors=reward_average_priors
        )

    if name == "thompson-gaussian-tau-square-01":
        reward_average_priors = build_reward_average_priors(answer_models, priors)
        return HierarchicalThompsonAlgo(
            config,
            prior_config=PriorConfig(
                dist_type="gaussian", prior=GaussianPrior(tau_square=0.1)
            ),
            reward_average_priors=reward_average_priors,
            model_selection_strategy=strategy,
        )

    if name == "thompson-beta":
        reward_average_priors = build_reward_average_priors(answer_models, priors)
        return HierarchicalThompsonAlgo(
            config,
            prior_config=PriorConfig(dist_type="beta"),
            reward_average_priors=reward_average_priors,
            model_selection_strategy=strategy,
        )

    if name == "pymc-thompson":
        if priors is None:
            # NOTE(review): `multimodel_strategy` is not forwarded on this
            # path, so the config's own default applies. Confirm this is
            # intended rather than an oversight.
            pymc_config = PyMCMixedAlgoConfig(algo="thompson", enable_pruning=True)
        else:
            reward_average_priors = build_reward_average_priors(answer_models, priors)
            pymc_config = PyMCMixedAlgoConfig(
                enable_pruning=True,
                algo="thompson",
                reward_average_priors=reward_average_priors,
                model_selection_strategy=strategy,
            )
        return PyMCMixedAlgo(config, pymc_config=pymc_config)

    # Report the name exactly as the caller supplied it.
    raise NotImplementedError(f"algo_name {algo_name} not supported.")
