from typing import List, Literal, Optional, Tuple, TypeAlias

from datasets import load_dataset

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationResult, Model
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultWithScore
from llm_mcts.tasks.base import Task
from llm_mcts.tasks.omni_math.llm_judge import LLMJudge
from llm_mcts.tasks.omni_math.problem import OmniMathProblem


OMNI_MATH_DATASET_NAME: TypeAlias = Literal[
    "KbsdJames/Omni-MATH"
]


class OmniMathTask(Task):
    """Task built around a single Omni-MATH problem, scored by an ``LLMJudge``."""

    def __init__(
        self,
        problem: OmniMathProblem,
        judge_model: Model | str,
        reward_model_name: Optional[str] = None,
        only_reward_model: bool = False,
        is_sigmoid: bool = True,
    ) -> None:
        """Store the problem and build the judge used for scoring.

        Args:
            problem: The problem this task evaluates answers against.
            judge_model: Forwarded to ``LLMJudge`` as its first argument.
            reward_model_name: Forwarded to ``LLMJudge``.
            only_reward_model: Forwarded to ``LLMJudge``.
            is_sigmoid: Forwarded to ``LLMJudge``.

        NOTE(review): the exact meaning of the judge options is defined by
        ``LLMJudge`` and is not visible from this class -- confirm there.
        """
        self.problem = problem
        self.llm_judge = LLMJudge(
            judge_model, reward_model_name, only_reward_model, is_sigmoid
        )

    @classmethod
    def load_record(
        cls,
        idx: int,
        judge_model: Model | str,
        dataset_name: OMNI_MATH_DATASET_NAME = "KbsdJames/Omni-MATH",
        reward_model_name: Optional[str] = None,
        only_reward_model: bool = False,
        is_sigmoid: bool = True,
    ) -> "OmniMathTask":
        """Build a task from record ``idx`` of the dataset's ``"test"`` split.

        Args:
            idx: Zero-based position of the record; must satisfy
                ``0 <= idx < len(dataset)``.
            judge_model: Passed through to the constructor.
            dataset_name: Identifier handed to ``load_dataset``.
            reward_model_name: Passed through to the constructor.
            only_reward_model: Passed through to the constructor.
            is_sigmoid: Passed through to the constructor.

        Returns:
            A new task wrapping the selected record.

        Raises:
            ValueError: If ``idx`` is outside the valid range.
        """
        dataset = load_dataset(dataset_name, split="test")
        num_records = len(dataset)

        # The upper bound is exclusive: ``idx == num_records`` is already one
        # past the last record, so it must be rejected here as well.
        if idx < 0 or idx >= num_records:
            raise ValueError(
                f"Invalid idx {idx}; idx should be in the range 0 <= idx < {num_records}"
            )

        # Assumes the record is a mapping whose keys match the keyword
        # arguments of ``OmniMathProblem`` -- TODO confirm against the schema.
        return cls(
            OmniMathProblem(**dataset[idx]),
            judge_model=judge_model,
            reward_model_name=reward_model_name,
            only_reward_model=only_reward_model,
            is_sigmoid=is_sigmoid,
        )

    def generate_eval_results(
        self, llm_answer: GenerationResult, kind: Action
    ) -> Optional[List[EvalResult]]:
        """Score a generated answer with the judge's ``estimate_score``.

        Args:
            llm_answer: Generation whose ``generation`` attribute is judged.
            kind: Action kind; only ``"answer"`` is supported.

        Returns:
            A single-element list holding the score and the judge's reason.

        Raises:
            NotImplementedError: If ``kind`` is anything other than
                ``"answer"``.
        """
        if kind != "answer":
            raise NotImplementedError(
                f"OmniMathTask only supports the 'answer' action, got {kind!r}"
            )

        reason, score = self.llm_judge.estimate_score(
            problem=self.problem, answer=llm_answer.generation
        )
        return [EvalResultWithScore(score=score, reason=reason)]

    def evaluate_on_test(
        self, llm_answer: GenerationResult
    ) -> Tuple[List[EvalResult], float]:
        """Score a generated answer with the judge's ``check_test_score``.

        Args:
            llm_answer: Generation whose ``generation`` attribute is judged.

        Returns:
            A pair of (single-element list holding the score and reason,
            the same score on its own).
        """
        reason, score = self.llm_judge.check_test_score(
            problem=self.problem, answer=llm_answer.generation
        )

        return [EvalResultWithScore(score=score, reason=reason)], score
