from dataclasses import dataclass, field
from typing import Dict

from llm_mcts.data_types import Action, ScorerType
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.mcts_scorer.default import DefaultScorer, DefaultScorerConfig
from llm_mcts.mcts_scorer.arc.verifier import VerifierScorer, VerifierScorerConfig
from llm_mcts.tasks.base import Task


@dataclass
class ScorerConfig:
    """Configuration consumed by ``build_scorer``.

    Attributes:
        score_factor: Per-action weights passed to the selected scorer's
            config as ``score_factor``. Presumably each action's score is
            scaled by its weight; confirm in the scorer implementations.
        scorer_type: Which scorer ``build_scorer`` constructs. It handles
            ``"default"`` and ``"verifier"``.
    """

    # default_factory gives every instance its own dict, so instances never
    # share (and accidentally mutate) one mapping.
    score_factor: Dict[Action, float] = field(
        default_factory=lambda: dict(
            transform=1.0,
            question=0.3,
            multi_questions=0.05,
        )
    )
    # Falls back to the plain scorer unless the verifier one is requested.
    scorer_type: ScorerType = "default"


def build_scorer(config: ScorerConfig, task: Task) -> MCTSScorer:
    """Construct the MCTS scorer selected by ``config.scorer_type``.

    Args:
        config: Scorer configuration. ``config.score_factor`` is passed
            through as-is (the same dict object, not a copy) to the chosen
            scorer's config.
        task: The task being solved. Only the ``"verifier"`` scorer receives
            it; the ``"default"`` scorer does not use it.

    Returns:
        A ``DefaultScorer`` for ``"default"``, or a ``VerifierScorer`` for
        ``"verifier"``.

    Raises:
        NotImplementedError: If ``config.scorer_type`` is not one of the
            supported values. The message names the offending value.
    """
    scorer_type = config.scorer_type

    if scorer_type == "default":
        return DefaultScorer(
            config=DefaultScorerConfig(score_factor=config.score_factor)
        )

    if scorer_type == "verifier":
        return VerifierScorer(
            config=VerifierScorerConfig(score_factor=config.score_factor),
            task=task,
        )

    # Keep NotImplementedError (callers may already catch it), but report
    # what was requested so a misconfigured run is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported scorer_type {scorer_type!r}; "
        "expected 'default' or 'verifier'"
    )
