from typing import List, Literal, Optional, Tuple, TypeAlias

from datasets import load_dataset

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationResult, Model
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultWithScore
from llm_mcts.tasks.base import Task
from llm_mcts.tasks.swe_bench.llm_judge import SWEBenchLLMJudge
from llm_mcts.tasks.swe_bench.problem import SWEBenchProblem


# Dataset identifiers accepted by `SWEBenchTask.load_record`, which forwards the
# chosen value to `datasets.load_dataset` as the dataset name.
SWE_BENCH_DATASET_NAME: TypeAlias = Literal[
    "princeton-nlp/SWE-bench_Lite",
    "princeton-nlp/SWE-bench_Lite_bm25_13K",
    "princeton-nlp/SWE-bench_Lite_bm25_27K",
    "princeton-nlp/SWE-bench_Lite_bm25_40K",
    "princeton-nlp/SWE-bench_Lite_bm25_50K_llama",
    "princeton-nlp/SWE-bench_Lite_oracle",
]
# Split names accepted by `SWEBenchTask.load_record` (passed as `split=` to
# `datasets.load_dataset`).
SWE_BENCH_SPLIT: TypeAlias = Literal["dev", "test"]


class SWEBenchTask(Task):
    """A single SWE-bench problem whose answers are scored by an LLM judge.

    The task keeps the problem and a ``SWEBenchLLMJudge`` built from the
    supplied judge model; both scoring methods delegate to that judge.
    """

    def __init__(self, problem: SWEBenchProblem, judge_model: Model) -> None:
        """Store the problem and wrap ``judge_model`` in a ``SWEBenchLLMJudge``.

        Args:
            problem: The problem that answers are judged against.
            judge_model: Model handed to ``SWEBenchLLMJudge``.
        """
        self.problem = problem
        self.llm_judge = SWEBenchLLMJudge(judge_model)

    @classmethod
    def load_record(
        cls,
        idx: int,
        judge_model: Optional[Model] = None,
        dataset_name: SWE_BENCH_DATASET_NAME = "princeton-nlp/SWE-bench_Lite_bm25_13K",
        split: SWE_BENCH_SPLIT = "dev",
    ) -> "SWEBenchTask":
        """Build a task from record ``idx`` of the requested dataset split.

        Args:
            idx: Zero-based position of the record within the split.
            judge_model: Model forwarded to the constructor.
            dataset_name: Dataset identifier passed to ``load_dataset``.
            split: Dataset split passed to ``load_dataset``.

        Returns:
            A task wrapping ``SWEBenchProblem(**dataset[idx])``.

        Raises:
            ValueError: If ``idx`` is outside ``0 <= idx < len(dataset)``.
        """
        # NOTE(review): judge_model defaults to None although __init__ is
        # annotated as requiring a Model -- confirm SWEBenchLLMJudge accepts None.
        dataset = load_dataset(dataset_name, split=split)
        num_records = len(dataset)

        # The upper bound is exclusive: idx == num_records is already past the
        # last record, so it must be rejected here rather than failing later
        # inside dataset[idx].
        if not 0 <= idx < num_records:
            raise ValueError(
                f"Invalid idx {idx}; idx should be in the range 0 <= idx < {num_records}"
            )
        # Presumably the record's columns match SWEBenchProblem's fields, since
        # the row is splatted as keyword arguments -- verify against the schema.
        return cls(SWEBenchProblem(**dataset[idx]), judge_model=judge_model)

    def generate_eval_results(
        self, llm_answer: GenerationResult, kind: Action
    ) -> Optional[List[EvalResult]]:
        """Score an answer with the judge's ``estimate_score``.

        Args:
            llm_answer: Generation whose ``generation`` attribute is judged.
            kind: Action kind; only ``"answer"`` is supported.

        Returns:
            A one-element list holding an ``EvalResultWithScore`` with the
            judge's score and reason.

        Raises:
            NotImplementedError: If ``kind`` is anything other than ``"answer"``.
        """
        if kind != "answer":
            raise NotImplementedError()

        reason, score = self.llm_judge.estimate_score(
            problem=self.problem, answer=llm_answer.generation
        )
        return [EvalResultWithScore(score=score, reason=reason)]

    def evaluate_on_test(
        self, llm_answer: GenerationResult
    ) -> Tuple[List[EvalResult], float]:
        """Score an answer with the judge's ``check_test_score``.

        Args:
            llm_answer: Generation whose ``generation`` attribute is judged.

        Returns:
            A pair of (one-element list holding an ``EvalResultWithScore``,
            the same score as a bare value).
        """
        reason, score = self.llm_judge.check_test_score(
            problem=self.problem, answer=llm_answer.generation
        )

        return [EvalResultWithScore(score=score, reason=reason)], score
