from typing import List, Literal, Optional, Tuple, TypeAlias

from datasets import load_dataset

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationResult, Model
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultWithScore
from llm_mcts.tasks.base import Task
from llm_mcts.tasks.math_vista.llm_judge import MathVistaLLMJudge
from llm_mcts.tasks.math_vista.problem import MathVistaProblem


# Dataset identifier accepted by MathVistaTask.load_record; it is passed
# straight through to `load_dataset`.
MATH_VISTA_DATASET_NAME: TypeAlias = Literal["AI4Math/MathVista"]
# Split names accepted by MathVistaTask.load_record ("testmini" is its default).
MATH_VISTA_SPLIT: TypeAlias = Literal["testmini", "test"]


class MathVistaTask(Task):
    """Task that scores answers to a single MathVista problem with an LLM judge.

    The task holds one ``MathVistaProblem`` and a ``MathVistaLLMJudge`` built
    around the supplied judge model; both scoring methods delegate to that
    judge.
    """

    def __init__(self, problem: MathVistaProblem, judge_model: Model) -> None:
        """Store ``problem`` and wrap ``judge_model`` in a ``MathVistaLLMJudge``."""
        self.problem = problem
        self.llm_judge = MathVistaLLMJudge(judge_model)

    @classmethod
    def load_record(
        cls,
        idx: int,
        judge_model: Model,
        dataset_name: MATH_VISTA_DATASET_NAME = "AI4Math/MathVista",
        split: MATH_VISTA_SPLIT = "testmini",
    ) -> "MathVistaTask":
        """Build a task from record ``idx`` of the given dataset split.

        Args:
            idx: Zero-based position of the record within the split.
            judge_model: Model handed to the LLM judge.
            dataset_name: Dataset identifier passed to ``load_dataset``.
            split: Split passed to ``load_dataset``.

        Returns:
            A task whose problem is built from the fields of the selected
            record.

        Raises:
            ValueError: If ``idx`` is negative or not smaller than the number
                of records in the split.
        """
        dataset = load_dataset(dataset_name, split=split)
        num_records = len(dataset)

        # The last valid index is num_records - 1. idx == num_records must be
        # rejected here so the caller gets the documented ValueError instead
        # of a failure from the dataset lookup below.
        if not 0 <= idx < num_records:
            raise ValueError(
                f"Invalid idx {idx}; idx should be in the range 0 <= idx < {num_records}"
            )
        # The record is a mapping; its keys become MathVistaProblem keyword
        # arguments.
        return cls(MathVistaProblem(**dataset[idx]), judge_model=judge_model)

    def generate_eval_results(
        self, llm_answer: GenerationResult, kind: Action
    ) -> Optional[List[EvalResult]]:
        """Score ``llm_answer`` with the judge's ``estimate_score``.

        Args:
            llm_answer: Generation whose ``generation`` attribute is judged.
            kind: Action that produced the generation; only ``"answer"`` is
                supported.

        Returns:
            A single-element list holding the judge's score and reason.

        Raises:
            NotImplementedError: If ``kind`` is anything other than
                ``"answer"``.
        """
        if kind != "answer":
            raise NotImplementedError()

        reason, score = self.llm_judge.estimate_score(
            problem=self.problem, answer=llm_answer.generation
        )
        return [EvalResultWithScore(score=score, reason=reason)]

    def evaluate_on_test(
        self, llm_answer: GenerationResult
    ) -> Tuple[List[EvalResult], float]:
        """Score ``llm_answer`` with the judge's ``check_test_score``.

        Args:
            llm_answer: Generation whose ``generation`` attribute is judged.

        Returns:
            A pair of (single-element list holding the score and reason, the
            same score on its own).
        """
        reason, score = self.llm_judge.check_test_score(
            problem=self.problem, answer=llm_answer.generation
        )

        return [EvalResultWithScore(score=score, reason=reason)], score
