from concurrent.futures import ThreadPoolExecutor
import os
import subprocess
import tempfile
import time
from typing import List, Literal, Optional, Tuple, TypeAlias

from datasets import load_dataset

from llm_mcts.data_types import Action
from llm_mcts.llm_generation_interface import GenerationResult
from llm_mcts.mcts_algo.eval_result import EvalResult, EvalResultWithAns
from llm_mcts.tasks.base import Task
from llm_mcts.tasks.code_contest.problem import CodeContestProblem


CODE_CONTEST_DATASET_NAME: TypeAlias = Literal["deepmind/code_contest",]
CODE_CONTEST_SPLIT: TypeAlias = Literal["train", "validation", "test"]


class CodeContestTask(Task):
    def __init__(self, problem: CodeContestProblem, num_workers: int = 1) -> None:
        assert num_workers > 0, "num_workers should be greater than 0"
        self.problem = problem
        self.public_tests = self.problem.public_tests
        self.private_tests = self.problem.private_tests
        self.generated_tests = self.problem.generated_tests
        self.num_workers = num_workers

    @classmethod
    def load_record(
        cls,
        idx: int,
        dataset_name: CODE_CONTEST_DATASET_NAME = "deepmind/code_contests",
        split: CODE_CONTEST_SPLIT = "valid",
        num_workers: int = 1,
    ) -> "CodeContestTask":
        dataset = load_dataset(dataset_name, split=split)

        if idx < 0 or idx > len(dataset):
            raise ValueError(
                f"Invalid idx {idx}; idx should be in the range 0 <= idx < {len(dataset)}"
            )
        return cls(CodeContestProblem(**dataset[idx]), num_workers=num_workers)

    def generate_eval_results(
        self, llm_answer: GenerationResult, kind: Action
    ) -> Optional[List[EvalResult]]:
        llm_answer_code = llm_answer.parse_python_code()
        if llm_answer_code is None:
            return None

        public_results = self.execute_python_code_on_tests(llm_answer_code, "public")

        return public_results

    def evaluate_on_test(
        self, llm_answer: GenerationResult
    ) -> Tuple[List[EvalResult], float]:
        py_code = llm_answer.parse_python_code()
        if py_code is None:
            return [], 0.0

        # Execute the code on all tests and combine the results
        public_results = self.execute_python_code_on_tests(py_code, "public")
        private_results = self.execute_python_code_on_tests(py_code, "private")
        generated_results = self.execute_python_code_on_tests(py_code, "generated")
        eval_results: List[EvalResult] = (
            public_results + private_results + generated_results
        )
        score = sum(eval_result.get_score() for eval_result in eval_results)
        return eval_results, score

    def execute_python_code_on_tests(
        self, code: str, test_type: Literal["public", "private", "generated"]
    ) -> List[EvalResult]:
        match test_type:
            case "public":
                input_list = self.public_tests["input"]
                expected_output_list = self.public_tests["output"]
            case "private":
                input_list = self.private_tests["input"]
                expected_output_list = self.private_tests["output"]
            case "generated":
                input_list = self.generated_tests["input"]
                expected_output_list = self.generated_tests["output"]

        timeout = (
            self.problem.time_limit["seconds"] + self.problem.time_limit["nanos"] / 1e9
        )
        tasks = []
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            for idx, (input_str, expected_output_str) in enumerate(
                zip(input_list, expected_output_list)
            ):
                future = executor.submit(
                    self.execute_python_code, code, input_str, timeout
                )
                tasks.append((idx, expected_output_str, future))
            eval_results: List[EvalResultWithAns] = [None] * len(tasks)
            for idx, expected_output, future in tasks:
                output, error, _exec_time = future.result()
                eval_results[idx] = EvalResultWithAns(
                    answer=output.strip() if output != "" else error.strip(),
                    groundtruth=expected_output.strip(),
                )

        return eval_results

    def execute_python_code(
        self,
        code: str,
        input_values: Optional[str] = None,
        timeout: float = 2.0,
    ) -> Tuple[str, str, float]:
        """
        Executes Python code as if it were a standalone Python file and measures the execution time.
        Optionally enforces a memory limit.

        Args:
            code (str): The Python code to execute, provided as a string.
            input_values (Optional[str]): The input values to simulate standard input, if required.
            timeout (float): The maximum time allowed for execution in seconds. Defaults to 2.0 seconds.

        Returns:
            Tuple[str, str, float]: A tuple containing the standard output, the standard error,
                                    and the time taken to execute the code in seconds.
        """
        # Create a temporary Python file to store the code
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False
        ) as temp_file:
            temp_file.write(code)
            temp_filename = temp_file.name

        try:
            start_time = time.perf_counter()

            try:
                # Run the Python file using a subprocess and capture the output
                result = subprocess.run(
                    ["python", temp_filename],
                    input=input_values,
                    text=True,
                    capture_output=True,
                    timeout=timeout,
                )

                end_time = time.perf_counter()
                exec_time = end_time - start_time

                return result.stdout, result.stderr, exec_time
            except subprocess.TimeoutExpired:
                return (
                    "",
                    f"Error: Execution timed out after {timeout} seconds.\n",
                    timeout,
                )
            except subprocess.CalledProcessError as e:
                return "", f"Error: {e}\n", -1.0
        except Exception as e:
            return "", str(e), -1.0
        finally:
            # Ensure the temporary file is removed after execution
            os.remove(temp_filename)
