from typing import Dict, Optional

from pydantic import BaseModel, Field

from llm_mcts.data_types import Action
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_scorer.base import MCTSScorer


class DefaultScorerConfig(BaseModel):
    """Configuration for ``DefaultScorer``.

    ``score_factor`` maps an action name to the multiplier applied to a
    node's mean evaluation score when the node's ``last_action`` is that
    action. Actions missing from the mapping are scored with factor 1.0 by
    ``DefaultScorer``.
    """

    # A factory (rather than a shared literal) gives every config instance
    # its own dict, so mutating one config never leaks into another.
    score_factor: Dict[Action, float] = Field(
        default_factory=lambda: dict(
            transform=1.0,
            question=0.3,
            multi_questions=0.05,
        )
    )


class DefaultScorer(MCTSScorer):
    """Score a node as the mean of its evaluation results, scaled per action.

    The mean of ``get_score()`` over ``node.eval_results`` is multiplied by
    the factor configured for ``node.last_action`` in
    ``DefaultScorerConfig.score_factor``; actions without an entry use 1.0.
    """

    def __init__(self, config: Optional[DefaultScorerConfig] = None):
        """Create the scorer.

        Args:
            config: Scoring configuration. When ``None``, a fresh
                ``DefaultScorerConfig()`` with its default factors is used.
        """
        if config is None:
            config = DefaultScorerConfig()

        self.scorer_config = config

    def get_score(self, node: Node) -> float:
        """Return the action-weighted mean evaluation score of ``node``.

        Args:
            node: Node whose ``eval_results`` and ``last_action`` are read.

        Returns:
            ``0.0`` when the node has no evaluation results (``None`` or an
            empty collection). Otherwise, the mean of ``r.get_score()`` over
            ``node.eval_results`` times the configured factor for
            ``node.last_action`` (1.0 if that action has no factor).
        """
        results = node.eval_results
        # An empty result collection would divide by zero below; treat it the
        # same as "not evaluated yet" (None) instead of crashing.
        if not results:
            return 0.0

        mean_score = sum(result.get_score() for result in results) / len(results)
        factor = self.scorer_config.score_factor.get(node.last_action, 1.0)
        return factor * mean_score
