import json
import os
import subprocess
import tempfile
import time
from typing import List, Literal, Optional, Tuple, TypeAlias

from datasets import load_dataset

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultLiveCodeBench
from llm_mcts.tasks.base import Task
from llm_mcts.tasks.live_code_bench_code_generation.problem import CodeGenerationProblem
from llm_mcts.tasks.live_code_bench_code_generation.testing_utils import run_test

LIVE_CODE_BENCH_CODE_GENERATION_VERSION: TypeAlias = Literal[
    "release_v1", "release_v2", "release_v3", "release_v4"
]


class LiveCodeBenchCodeGenerationTask(Task):
    """Task wrapping a single LiveCodeBench code-generation problem.

    The task keeps the problem together with its public and private test
    cases, and evaluates LLM-generated Python code against them via
    ``run_test``.
    """

    def __init__(self, problem: CodeGenerationProblem) -> None:
        """Store ``problem`` and expose its public/private test cases."""
        self.problem = problem
        self.public_tests = self.problem.public_test_cases
        self.private_tests = self.problem.private_test_cases

    @classmethod
    def load_record(
        cls,
        idx: int,
        release_version: LIVE_CODE_BENCH_CODE_GENERATION_VERSION = "release_v4",
    ) -> "LiveCodeBenchCodeGenerationTask":
        """Build a task from record ``idx`` of the LiveCodeBench test split.

        Args:
            idx: Zero-based index of the record in the ``test`` split.
            release_version: Dataset release tag passed as ``version_tag``.

        Returns:
            A task constructed from the selected dataset record.

        Raises:
            ValueError: If ``idx`` is outside ``0 <= idx < len(dataset)``.
        """
        dataset = load_dataset(
            "livecodebench/code_generation_lite",
            split="test",
            version_tag=release_version,
            trust_remote_code=True,
        )

        # The upper bound is exclusive: ``idx == len(dataset)`` is already out
        # of range and must be rejected here rather than by the dataset lookup.
        if not 0 <= idx < len(dataset):
            raise ValueError(
                f"Invalid idx {idx}; idx should be in the range 0 <= idx < {len(dataset)}"
            )
        return cls(CodeGenerationProblem(**dataset[idx]))

    def generate_eval_results(
        self, llm_answer: GenerationResult, kind: Action
    ) -> Optional[List[EvalResultLiveCodeBench]]:
        """Evaluate the answer's code on the public tests only.

        Args:
            llm_answer: Generation result from which Python code is parsed.
            kind: Action that produced the answer; not used by this method.

        Returns:
            One result per public test, or ``None`` when no Python code could
            be parsed from ``llm_answer``.
        """
        llm_answer_code = llm_answer.parse_python_code()
        if llm_answer_code is None:
            return None

        return self._run_and_collect(
            self.get_public_evaluation_sample(), llm_answer_code
        )

    def evaluate_on_test(
        self, llm_answer: GenerationResult
    ) -> Tuple[List[EvalResultLiveCodeBench], float]:
        """Evaluate the answer's code on the public and the private tests.

        Args:
            llm_answer: Generation result from which Python code is parsed.

        Returns:
            A pair ``(eval_results, score)``. ``eval_results`` lists the public
            test results followed by the private ones; ``score`` is the sum
            (not the mean) of the individual result scores. When no Python
            code could be parsed the pair is ``([], 0.0)``.
        """
        py_code = llm_answer.parse_python_code()
        if py_code is None:
            return [], 0.0

        # Execute the code on all tests and combine the results
        eval_results: List[EvalResultLiveCodeBench] = [
            *self._run_and_collect(self.get_public_evaluation_sample(), py_code),
            *self._run_and_collect(self.get_private_evaluation_sample(), py_code),
        ]

        score = sum(eval_result.get_score() for eval_result in eval_results)
        return eval_results, score

    @staticmethod
    def _run_and_collect(sample, code) -> List[EvalResultLiveCodeBench]:
        """Run ``code`` on ``sample`` and wrap each test outcome in a result.

        ``sample`` is an evaluation sample as built by
        ``_build_evaluation_sample``; ``code`` is the value returned by
        ``parse_python_code()``. The ``additional_info`` returned by
        ``run_test`` is shared by every result of the same run.
        """
        expected_outputs, generated_outputs, scores, additional_info = run_test(
            sample, code
        )
        return [
            EvalResultLiveCodeBench(
                answer=generated_output,
                groundtruth=expected_output,
                additional_info=additional_info,
                score=score,
            )
            for expected_output, generated_output, score in zip(
                expected_outputs, generated_outputs, scores
            )
        ]

    def _build_evaluation_sample(self, tests) -> dict:
        """Serialize ``tests`` into the ``input_output`` sample for ``run_test``.

        The JSON payload holds the test inputs, the expected outputs, and the
        problem's ``func_name`` metadata entry (``None`` when absent).
        """
        return {
            "input_output": json.dumps(
                {
                    "inputs": [t.input for t in tests],
                    "outputs": [t.output for t in tests],
                    "fn_name": self.problem.metadata.get("func_name", None),
                }
            ),
        }

    # this function is based on the original code generation task
    # https://github.com/LiveCodeBench/LiveCodeBench/blob/f05cda286956b0a976df08afe2e2a323358d32d1/lcb_runner/benchmarks/code_generation.py#L106-L121
    def get_private_evaluation_sample(self):
        """Return the evaluation sample built from the private test cases."""
        return self._build_evaluation_sample(self.private_tests)

    def get_public_evaluation_sample(self):
        """Return the evaluation sample built from the public test cases."""
        return self._build_evaluation_sample(self.public_tests)
