from __future__ import annotations

import json
import math
import re
import sys
from pathlib import Path
from typing import Any, Callable, Sequence

from joblib import Parallel, delayed

from .base import ProbResult, Task

# Matches a Markdown fenced code block, optionally tagged "python"
# (case-insensitive), and captures the text between the fences without the
# surrounding whitespace.
_CODE_BLOCK_RE = re.compile(r"```(?:python)?\s*([\s\S]*?)\s*```", re.IGNORECASE)
# "code_eval" directory located next to this module; the task classes in this
# file put sub-directories of it on sys.path before importing their harnesses.
_CODE_EVAL_ROOT = Path(__file__).resolve().parent / "code_eval"


def _extract_code(text: str) -> str:
    """Return the longest fenced code block in *text*, stripped.

    When *text* contains no fenced block, the whole text is returned with
    leading and trailing whitespace removed.
    """
    candidates = [body.strip() for body in _CODE_BLOCK_RE.findall(text)]
    if not candidates:
        return text.strip()
    # max() keeps the first block among equally long ones.
    return max(candidates, key=len)


def _extract_expression(text: str) -> str:
    """Return the first non-blank line of the code extracted from *text*.

    Falls back to the stripped extracted code when it has no non-blank line.
    """
    code = _extract_code(text)
    stripped_lines = (raw.strip() for raw in code.splitlines())
    return next((candidate for candidate in stripped_lines if candidate), code.strip())


def _extract_lcb_execution_expression(text: str) -> str:
    """Extract the predicted right-hand side of an LCB execution answer.

    Processing steps, applied to the code extracted from *text*:

    1. keep only what follows the first ``[ANSWER]`` marker, if present;
    2. keep only what follows the first ``==``, if present;
    3. keep only what precedes ``[/ANSWER]`` if present, otherwise keep
       only the first line.

    Returns:
        The stripped expression, or ``""`` when nothing is left after the
        steps above (for example an empty response).
    """
    code = _extract_code(text)
    if "[ANSWER]" in code:
        code = code.split("[ANSWER]", 1)[1].strip()
    if "==" in code:
        code = code.split("==", 1)[1].strip()
    if "[/ANSWER]" in code:
        code = code.split("[/ANSWER]", 1)[0].strip()
    else:
        # "".splitlines() is [], so indexing [0] unguarded would raise
        # IndexError on an empty answer; treat that case as an empty result.
        lines = code.splitlines()
        code = lines[0].strip() if lines else ""
    return code.strip()


def _normalize_lcb_generation_result(result: Any) -> bool:
    try:
        import numpy as np
    except Exception:
        np = None
    if np is not None and isinstance(result, np.ndarray):
        try:
            result = result.item(0)
        except Exception:
            return False
    if np is not None and isinstance(result, np.bool_):
        return bool(result)
    if isinstance(result, bool):
        return result
    try:
        return bool(result == True)
    except Exception:
        return False


def _run_evalplus_untrusted_chunk(
    untrusted_check: Callable[..., tuple[str, Any]],
    dataset: str,
    solution: str,
    entry_point: str,
    inputs: list[Any],
    expected: list[Any],
    ref_time: list[float],
    atol: float,
    fast_check: bool,
    min_time_limit: float | None,
    gt_time_limit_factor: float | None,
) -> tuple[str, list[bool]]:
    """Run ``untrusted_check`` on one chunk of inputs.

    Module-level so that it can be handed to ``joblib.delayed``.

    Returns:
        The status reported by ``untrusted_check`` and the per-test details
        as a list padded with ``False`` up to ``len(inputs)`` entries.
    """
    check_options = {
        "expected": expected,
        "atol": atol,
        "ref_time": ref_time,
        "fast_check": fast_check,
        "min_time_limit": min_time_limit,
        "gt_time_limit_factor": gt_time_limit_factor,
    }
    status, raw_details = untrusted_check(
        dataset, solution, inputs, entry_point, **check_options
    )
    return status, EvalPlusTask._normalize_details(raw_details, len(inputs))


def _run_lcb_generation_chunk(
    check_correctness: Callable[..., tuple[list, dict]],
    chunk_sample: dict[str, str],
    answer: str,
    timeout: int,
    debug: bool,
    expected_len: int,
) -> tuple[list, dict]:
    chunk_results, chunk_metadata = check_correctness(
        chunk_sample, answer, timeout=timeout, debug=debug
    )
    chunk_results = list(chunk_results)
    if len(chunk_results) < expected_len:
        chunk_results.extend([False] * (expected_len - len(chunk_results)))
    return chunk_results, chunk_metadata


def _ensure_on_path(root: Path) -> None:
    root_str = str(root)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)


def _maybe_force_evalplus_fork(
    untrusted_check: Callable[..., tuple[str, Any]]
) -> None:
    try:
        import multiprocessing as mp
    except Exception:
        return
    try:
        default_method = mp.get_context().get_start_method()
    except Exception:
        default_method = None
    if default_method == "fork":
        return
    if "fork" not in mp.get_all_start_methods():
        return

    module = sys.modules.get(getattr(untrusted_check, "__module__", ""))
    if module is None or getattr(module, "_bd_mcts_force_fork", False):
        return
    try:
        ctx = mp.get_context("fork")
    except Exception:
        return

    # evalplus stores multiprocessing objects at module scope.
    module.multiprocessing = ctx
    if hasattr(module, "Array"):
        module.Array = ctx.Array
    if hasattr(module, "Value"):
        module.Value = ctx.Value
    module._bd_mcts_force_fork = True


class _PromptTokenizerStub:
    chat_template = None


def _looks_like_full_solution(answer: str, entry_point: str) -> bool:
    needle = f"def {entry_point}"
    if needle in answer:
        return True
    return f"class {entry_point}" in answer


class EvalPlusTask(Task):
    """Task backed by the EvalPlus HumanEval+/MBPP+ harness.

    ``evaluate`` runs only the base inputs and reports the fraction of tests
    that passed; ``submit`` also runs the plus inputs (unless ``base_only``
    is set or the problem has none) and reports 1.0 or 0.0.
    """

    def __init__(
        self,
        dataset: str = "humaneval",
        *,
        base_only: bool = False,
        fast_check: bool = True,
        min_time_limit: float | None = None,
        gt_time_limit_factor: float | None = None,
        mini: bool = False,
        noextreme: bool = False,
        version: str = "default",
        sort_task_ids: bool = False,
        sanitize: bool = False,
        public_test_workers: int = 1,
        private_test_workers: int = 1,
    ) -> None:
        """Import evalplus, load the requested dataset and store settings.

        Args:
            dataset: ``"humaneval"`` or ``"mbpp"``.
            base_only: When true, the plus inputs are never executed.
            fast_check: Forwarded to evalplus ``untrusted_check``.
            min_time_limit: Forwarded to ``untrusted_check``; ``None`` selects
                evalplus' ``DEFAULT_MIN_TIME_LIMIT``.
            gt_time_limit_factor: Forwarded to ``untrusted_check``; ``None``
                selects evalplus' ``DEFAULT_GT_TIME_LIMIT_FACTOR``.
            mini: Forwarded to the evalplus dataset loader.
            noextreme: Forwarded to the evalplus dataset loader.
            version: Forwarded to the evalplus dataset loader.
            sort_task_ids: Sort the task ids instead of keeping loader order.
            sanitize: Post-process parsed answers with ``evalplus.sanitize``.
            public_test_workers: Maximum parallel chunks for base inputs.
            private_test_workers: Maximum parallel chunks for plus inputs.

        Raises:
            ValueError: A worker count is below 1 or *dataset* is unsupported.
            RuntimeError: evalplus (or ``evalplus.sanitize``) cannot be
                imported.
        """
        if public_test_workers < 1 or private_test_workers < 1:
            raise ValueError("test workers must be >= 1")
        try:
            api = self._import_evalplus()
        except Exception:
            # Forget any partially imported evalplus modules so the retry
            # resolves against the vendored copy put on sys.path below.
            for name in list(sys.modules):
                if name == "evalplus" or name.startswith("evalplus."):
                    sys.modules.pop(name, None)
            _ensure_on_path(_CODE_EVAL_ROOT / "evalplus")
            try:
                api = self._import_evalplus()
            except Exception as exc:  # pragma: no cover - import guard
                raise RuntimeError(
                    "evalplus is not available. Install it (e.g. `pip install "
                    "\"evalplus\" --upgrade`) or ensure "
                    "`src/bd_mcts/tasks/code_eval/evalplus` is on PYTHONPATH."
                ) from exc

        self._dataset = dataset
        self._base_only = base_only
        self._fast_check = fast_check
        self._min_time_limit = (
            api["DEFAULT_MIN_TIME_LIMIT"]
            if min_time_limit is None
            else min_time_limit
        )
        self._gt_time_limit_factor = (
            api["DEFAULT_GT_TIME_LIMIT_FACTOR"]
            if gt_time_limit_factor is None
            else gt_time_limit_factor
        )
        self._pass_token = api["PASS"]
        self._fail_token = api["FAIL"]
        self._timeout_token = api["TIMEOUT"]
        self._untrusted_check = api["untrusted_check"]
        _maybe_force_evalplus_fork(self._untrusted_check)
        self._trusted_exec = api["trusted_exec"]
        self._mbpp_output_not_none = api["MBPP_OUTPUT_NOT_NONE_TASKS"]
        self._public_test_workers = public_test_workers
        self._private_test_workers = private_test_workers

        if dataset == "humaneval":
            problems = api["get_human_eval_plus"](
                mini=mini, noextreme=noextreme, version=version
            )
        elif dataset == "mbpp":
            problems = api["get_mbpp_plus"](
                mini=mini, noextreme=noextreme, version=version
            )
        else:
            raise ValueError(f"Unsupported EvalPlus dataset: {dataset}")

        task_ids = list(problems.keys())
        if sort_task_ids:
            task_ids.sort()

        self._problems = problems
        self._task_ids = task_ids
        # task_id -> {"base", "base_time"[, "plus", "plus_time"]} reference
        # outputs and timings produced by the canonical solution.
        self._expected_cache: dict[str, dict[str, Any]] = {}
        self._sanitize_fn: Callable[[str, str | None], str] | None = None
        if sanitize:
            try:
                from evalplus.sanitize import sanitize as sanitize_fn
            except Exception as exc:  # pragma: no cover - import guard
                raise RuntimeError(
                    "evalplus.sanitize is not available. Install tree-sitter "
                    "dependencies or disable sanitize."
                ) from exc
            self._sanitize_fn = sanitize_fn

    @staticmethod
    def _import_evalplus() -> dict[str, Any]:
        """Import the evalplus names this task needs, keyed by name.

        Kept in one place so the initial attempt and the vendored-copy retry
        in ``__init__`` import exactly the same set of names.
        """
        from evalplus.config import (
            DEFAULT_GT_TIME_LIMIT_FACTOR,
            DEFAULT_MIN_TIME_LIMIT,
        )
        from evalplus.data import get_human_eval_plus, get_mbpp_plus
        from evalplus.eval import FAIL, PASS, TIMEOUT, untrusted_check
        from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS
        from evalplus.gen.util import trusted_exec

        return {
            "DEFAULT_GT_TIME_LIMIT_FACTOR": DEFAULT_GT_TIME_LIMIT_FACTOR,
            "DEFAULT_MIN_TIME_LIMIT": DEFAULT_MIN_TIME_LIMIT,
            "get_human_eval_plus": get_human_eval_plus,
            "get_mbpp_plus": get_mbpp_plus,
            "FAIL": FAIL,
            "PASS": PASS,
            "TIMEOUT": TIMEOUT,
            "untrusted_check": untrusted_check,
            "MBPP_OUTPUT_NOT_NONE_TASKS": MBPP_OUTPUT_NOT_NONE_TASKS,
            "trusted_exec": trusted_exec,
        }

    def parse_answer(self, sample_id: int, lm_response: str) -> str:
        """Extract code from *lm_response*, sanitizing it when enabled."""
        code = _extract_code(lm_response)
        if self._sanitize_fn is not None:
            task_id = self._task_ids[sample_id]
            entry_point = self._problems[task_id]["entry_point"]
            code = self._sanitize_fn(code, entry_point)
        return code

    def evaluate(self, sample_id: int, answer: str) -> ProbResult:
        """Run the base inputs and return the fraction of tests passed.

        When no per-test details are available the metric falls back to 1.0
        or 0.0 depending on the overall status.
        """
        task_id = self._task_ids[sample_id]
        problem = self._problems[task_id]
        solution = self._compose_solution(problem, answer)
        base_stat, base_details = self._check_public(
            problem, solution, self._public_test_workers
        )
        passed = base_stat == self._pass_token
        num_tests = len(base_details)
        num_passed = sum(1 for passed_test in base_details if passed_test)
        if num_tests:
            metric = num_passed / num_tests
        else:
            metric = 1.0 if passed else 0.0
        detail = {
            "dataset": self._dataset,
            "task_id": task_id,
            "split": "public",
            "status": base_stat,
            "details": base_details,
        }
        return ProbResult(metric=metric, sample_detail=detail)

    def submit(self, sample_id: int, answer: str) -> ProbResult:
        """Run base and (when applicable) plus inputs; metric is 1.0 or 0.0.

        The answer passes when the base status is the pass token and, if the
        plus inputs were actually used, the plus status is too.
        """
        task_id = self._task_ids[sample_id]
        problem = self._problems[task_id]
        solution = self._compose_solution(problem, answer)
        public_stat, public_details = self._check_public(
            problem, solution, self._public_test_workers
        )
        private_stat = None
        private_details = None
        used_private = False
        if not self._base_only and "plus_input" in problem:
            private_stat, private_details, used_private = self._check_private(
                problem, solution, self._private_test_workers
            )
        passed = public_stat == self._pass_token and (
            not used_private or private_stat == self._pass_token
        )
        detail = {
            "dataset": self._dataset,
            "task_id": task_id,
            "split": "private",
            "public_status": public_stat,
            "public_details": public_details,
            "private_status": private_stat,
            "private_details": private_details,
            "used_private": used_private,
        }
        return ProbResult(metric=1.0 if passed else 0.0, sample_detail=detail)

    def _compose_solution(self, problem: dict[str, Any], answer: str) -> str:
        """Return *answer* as is if it defines the entry point, else prompt + answer."""
        entry_point = problem["entry_point"]
        if _looks_like_full_solution(answer, entry_point):
            return answer
        return problem["prompt"] + answer

    def _expected_for(
        self, problem: dict[str, Any], *, include_plus: bool
    ) -> dict[str, Any]:
        """Return (and cache) reference outputs/timings for *problem*.

        The canonical solution is executed with ``trusted_exec`` on the base
        inputs, and on the plus inputs too when *include_plus* is true and
        ``base_only`` is off.  Results are cached per task id.
        """
        task_id = problem["task_id"]
        cached = self._expected_cache.get(task_id, {})

        output_not_none = False
        if self._dataset == "mbpp":
            output_not_none = problem["entry_point"] in self._mbpp_output_not_none

        if "base" not in cached:
            base, base_time = self._trusted_exec(
                problem["prompt"] + problem["canonical_solution"],
                problem["base_input"],
                problem["entry_point"],
                record_time=True,
                output_not_none=output_not_none,
            )
            cached["base"] = base
            cached["base_time"] = base_time

        if include_plus and not self._base_only and "plus" not in cached:
            plus, plus_time = self._trusted_exec(
                problem["prompt"] + problem["canonical_solution"],
                problem["plus_input"],
                problem["entry_point"],
                record_time=True,
                output_not_none=output_not_none,
            )
            cached["plus"] = plus
            cached["plus_time"] = plus_time

        self._expected_cache[task_id] = cached
        return cached

    def _check_public(
        self, problem: dict[str, Any], solution: str, test_workers: int
    ) -> tuple[str, list[bool]]:
        """Check *solution* against the base inputs."""
        expected = self._expected_for(problem, include_plus=False)
        return self._run_untrusted_checks(
            problem,
            solution,
            problem["base_input"],
            expected["base"],
            expected["base_time"],
            test_workers,
        )

    def _check_private(
        self, problem: dict[str, Any], solution: str, test_workers: int
    ) -> tuple[str, list[bool], bool]:
        """Check *solution* against the plus inputs when they are usable.

        Returns:
            ``(status, details, used_private)``.  When the plus inputs cannot
            be used the base check is run instead and ``used_private`` is
            ``False``.
        """
        if self._base_only or "plus_input" not in problem:
            stat, details = self._check_public(problem, solution, test_workers)
            return stat, details, False

        expected = self._expected_for(problem, include_plus=True)
        if "plus" not in expected:
            stat, details = self._check_public(problem, solution, test_workers)
            return stat, details, False

        stat, details = self._run_untrusted_checks(
            problem,
            solution,
            problem["plus_input"],
            expected["plus"],
            expected["plus_time"],
            test_workers,
        )
        return stat, details, True

    def _run_untrusted_checks(
        self,
        problem: dict[str, Any],
        solution: str,
        inputs: list[Any],
        expected: list[Any],
        ref_time: list[float],
        test_workers: int,
    ) -> tuple[str, list[bool]]:
        """Run ``untrusted_check`` over *inputs*, optionally in parallel chunks.

        With a single worker or at most one input the check runs in-process
        in one call.  Otherwise the inputs are split into contiguous chunks
        that run through joblib; the merged status is the timeout token if
        any chunk timed out, else the fail token if any chunk failed, else
        the pass token, and the details are concatenated in chunk order.
        """
        if test_workers <= 1 or len(inputs) <= 1:
            stat, details = self._untrusted_check(
                self._dataset,
                solution,
                inputs,
                problem["entry_point"],
                expected=expected,
                atol=problem["atol"],
                ref_time=ref_time,
                fast_check=self._fast_check,
                min_time_limit=self._min_time_limit,
                gt_time_limit_factor=self._gt_time_limit_factor,
            )
            return stat, self._normalize_details(details, len(inputs))

        chunks = self._chunk_slices(len(inputs), test_workers)
        backend = "loky"
        qualname = getattr(self._untrusted_check, "__qualname__", "")
        # A function defined inside another function is run with threads
        # instead -- presumably because it cannot be shipped to loky worker
        # processes; confirm before changing.
        if "<locals>" in qualname:
            backend = "threading"
        results = Parallel(n_jobs=len(chunks), backend=backend)(
            delayed(_run_evalplus_untrusted_chunk)(
                self._untrusted_check,
                self._dataset,
                solution,
                problem["entry_point"],
                inputs[sl],
                expected[sl],
                ref_time[sl],
                problem["atol"],
                self._fast_check,
                self._min_time_limit,
                self._gt_time_limit_factor,
            )
            for sl in chunks
        )

        overall = self._pass_token
        for stat, _ in results:
            if stat == self._timeout_token:
                # Timeout dominates every other status.
                overall = self._timeout_token
                break
            if stat == self._fail_token:
                overall = self._fail_token

        merged_details: list[bool] = []
        for _, details in results:
            merged_details.extend(details)
        return overall, merged_details

    @staticmethod
    def _normalize_details(details: Any, expected_len: int) -> list[bool]:
        """Convert *details* to a list padded with ``False`` to *expected_len*.

        Objects exposing ``tolist()`` are converted through it; anything else
        goes through ``list()``.  Longer inputs are returned unshortened.
        """
        if hasattr(details, "tolist"):
            normalized = details.tolist()
        else:
            normalized = list(details)
        if len(normalized) < expected_len:
            normalized.extend([False] * (expected_len - len(normalized)))
        return normalized

    @staticmethod
    def _chunk_slices(total: int, workers: int) -> list[slice]:
        """Split ``range(total)`` into at most *workers* contiguous slices."""
        if total <= 0:
            return []
        workers = min(workers, total)
        chunk_size = max(1, math.ceil(total / workers))
        return [
            slice(start, min(start + chunk_size, total))
            for start in range(0, total, chunk_size)
        ]


class BigCodeBenchTask(Task):
    """Task that scores code answers with the BigCodeBench harness.

    ``evaluate`` and ``submit`` run the same check; they differ only in the
    ``split`` label (and the ``public_equivalent`` flag) placed in the
    result detail.
    """

    def __init__(
        self,
        *,
        subset: str = "full",
        split: str | None = None,
        min_time_limit: float = 1.0,
        gt_time_limit: float = 20.0,
        max_as_limit: float = 30 * 1024,
        max_data_limit: float = 30 * 1024,
        max_stack_limit: float = 10,
        use_gt_time: bool = False,
        calibrated: bool = False,
        sort_task_ids: bool = False,
    ) -> None:
        """Import bigcodebench, load the problems and store the limits.

        Args:
            subset: Forwarded to ``get_bigcodebench``.
            split: Forwarded to ``get_bigcodebench`` when not ``None``; if the
                loader raises ``TypeError`` the call is retried without it.
            min_time_limit: Passed to the harness check functions.
            gt_time_limit: Time limit used unless a measured time is
                available (see ``use_gt_time``).
            max_as_limit: Passed to the harness check functions.
            max_data_limit: Passed to the harness check functions.
            max_stack_limit: Passed to the harness check functions.
            use_gt_time: Measure the canonical solution with ``trusted_check``
                and use that time instead of ``gt_time_limit``.
            calibrated: Prefix solutions with the problem's ``code_prompt``
                plus a ``pass`` stub when that key exists.
            sort_task_ids: Sort the task ids instead of keeping loader order.

        Raises:
            RuntimeError: bigcodebench cannot be imported.
        """
        _ensure_on_path(_CODE_EVAL_ROOT / "bigcodebench")
        try:
            from bigcodebench.data import get_bigcodebench
            from bigcodebench.eval import PASS, untrusted_check
            from bigcodebench.gen.util import trusted_check
        except Exception as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "bigcodebench is not available. Install it or ensure "
                "`src/bd_mcts/tasks/code_eval/bigcodebench` is on PYTHONPATH."
            ) from exc

        loader_kwargs: dict[str, Any] = {"subset": subset}
        if split is not None:
            loader_kwargs["split"] = split
        try:
            problems = get_bigcodebench(**loader_kwargs)
        except TypeError:
            # Retry with the subset only (e.g. loader without a split option).
            problems = get_bigcodebench(subset=subset)

        ordered_ids = list(problems.keys())
        if sort_task_ids:
            ordered_ids = sorted(ordered_ids)

        self._problems = problems
        self._task_ids = ordered_ids
        self._pass_token = PASS
        self._untrusted_check = untrusted_check
        self._trusted_check = trusted_check
        self._min_time_limit = min_time_limit
        self._gt_time_limit = gt_time_limit
        self._max_as_limit = max_as_limit
        self._max_data_limit = max_data_limit
        self._max_stack_limit = max_stack_limit
        self._use_gt_time = use_gt_time
        self._calibrated = calibrated
        # task_id -> measured canonical time (None when not measurable).
        self._expected_time_cache: dict[str, float | None] = {}

    def parse_answer(self, sample_id: int, lm_response: str) -> str:
        """Extract the code portion of *lm_response*."""
        return _extract_code(lm_response)

    def evaluate(self, sample_id: int, answer: str) -> ProbResult:
        """Check *answer* and label the result as the public split."""
        return self._run_check(sample_id, answer, split="public")

    def submit(self, sample_id: int, answer: str) -> ProbResult:
        """Check *answer* and label the result as private (public-equivalent)."""
        return self._run_check(
            sample_id, answer, split="private", public_equivalent=True
        )

    def _run_check(
        self,
        sample_id: int,
        answer: str,
        *,
        split: str,
        public_equivalent: bool = False,
    ) -> ProbResult:
        """Run ``untrusted_check`` for one problem; metric is 1.0 or 0.0."""
        task_id = self._task_ids[sample_id]
        problem = self._problems[task_id]
        candidate = self._compose_solution(problem, answer)
        time_limit = self._expected_time_for(problem)
        status, check_details = self._untrusted_check(
            candidate,
            problem["test"],
            problem["entry_point"],
            self._max_as_limit,
            self._max_data_limit,
            self._max_stack_limit,
            self._min_time_limit,
            time_limit,
        )
        detail = {
            "task_id": task_id,
            "split": split,
            "status": status,
            "details": check_details,
        }
        if public_equivalent:
            detail["public_equivalent"] = True
        if status == self._pass_token:
            score = 1.0
        else:
            score = 0.0
        return ProbResult(metric=score, sample_detail=detail)

    def _compose_solution(self, problem: dict[str, Any], answer: str) -> str:
        """Build the program text to execute for *answer*.

        Answers that do not define the entry point are appended to the
        problem's ``complete_prompt``.  In calibrated mode the problem's
        ``code_prompt`` followed by a ``pass`` stub is put in front.
        """
        if _looks_like_full_solution(answer, problem["entry_point"]):
            body = answer
        else:
            body = problem["complete_prompt"] + answer
        if not (self._calibrated and "code_prompt" in problem):
            return body
        return problem["code_prompt"] + "\n    pass\n" + body

    def _expected_time_for(self, problem: dict[str, Any]) -> float:
        """Return the time limit to hand to ``untrusted_check``.

        Unless ``use_gt_time`` is on this is the configured
        ``gt_time_limit``.  Otherwise the canonical solution is measured once
        per task with ``trusted_check`` and cached; the configured limit is
        used when there is no canonical solution or no measured time.
        """
        fallback = self._gt_time_limit
        if not self._use_gt_time:
            return fallback

        task_id = problem["task_id"]
        cache = self._expected_time_cache
        if task_id in cache:
            measured = cache[task_id]
        elif not problem.get("canonical_solution"):
            measured = None
            cache[task_id] = measured
        else:
            report = self._trusted_check(
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                self._max_as_limit,
                self._max_data_limit,
                self._max_stack_limit,
                self._min_time_limit,
            )
            measured = report["time"]
            cache[task_id] = measured
        return fallback if measured is None else measured


class CruxEvalTask(Task):
    """Task that scores CRUXEval input- or output-prediction answers.

    An answer is checked by executing the sample's code followed by
    ``assert <output> == <answer>`` through the harness'
    ``check_correctness``.
    """

    def __init__(
        self,
        mode: str,
        *,
        dataset_path: str | None = None,
        samples: Sequence[dict[str, Any]] | None = None,
    ) -> None:
        """Import the cruxeval utilities and load the samples.

        Args:
            mode: ``"input"`` or ``"output"``.
            dataset_path: JSON-lines file to read samples from when *samples*
                is not given.
            samples: Pre-loaded samples; takes precedence over
                *dataset_path*.  Each sample is read through its ``"code"``,
                ``"input"`` and ``"output"`` keys.

        Raises:
            ValueError: *mode* is not ``"input"`` or ``"output"``.
            RuntimeError: the cruxeval utilities (or ``datasets``, when it is
                needed) cannot be imported.
        """
        if mode not in ("input", "output"):
            raise ValueError("CruxEval mode must be 'input' or 'output'.")

        _ensure_on_path(_CODE_EVAL_ROOT / "cruxeval")
        try:
            from evaluation.utils_execute import check_correctness
        except Exception as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "cruxeval evaluation utilities are not available. Ensure "
                "`src/bd_mcts/tasks/code_eval/cruxeval` is on PYTHONPATH."
            ) from exc

        self._mode = mode
        self._check_correctness = check_correctness
        self._samples = (
            list(samples) if samples is not None else self._load_samples(dataset_path)
        )

    def parse_answer(self, sample_id: int, lm_response: str) -> str:
        """Return the first non-blank line of the code in *lm_response*."""
        return _extract_expression(lm_response)

    def evaluate(self, sample_id: int, answer: str) -> ProbResult:
        """Check *answer* and label the result as the public split."""
        return self._run_check(sample_id, answer, split="public")

    def submit(self, sample_id: int, answer: str) -> ProbResult:
        """Check *answer* and label the result as private (public-equivalent)."""
        return self._run_check(
            sample_id, answer, split="private", public_equivalent=True
        )

    def _run_check(
        self,
        sample_id: int,
        answer: str,
        *,
        split: str,
        public_equivalent: bool = False,
    ) -> ProbResult:
        """Evaluate *answer* for one sample; metric is 1.0 or 0.0."""
        sample = self._samples[sample_id]
        code = sample["code"]
        inp = sample["input"]
        out = sample["output"]
        passed = self._evaluate_one(code, inp, out, answer)
        detail = {
            "mode": self._mode,
            "sample_id": sample_id,
            "split": split,
        }
        if public_equivalent:
            detail["public_equivalent"] = True
        return ProbResult(metric=1.0 if passed else 0.0, sample_detail=detail)

    def _load_samples(self, dataset_path: str | None) -> list[dict[str, Any]]:
        """Load samples from a JSON-lines file or from the Hugging Face hub.

        Blank lines in the file (for example a trailing empty line) are
        skipped instead of being handed to the JSON parser.

        Raises:
            RuntimeError: no *dataset_path* was given and ``datasets`` cannot
                be imported.
        """
        if dataset_path is not None:
            with open(dataset_path, "r", encoding="utf-8") as f:
                # json.loads("") raises, so whitespace-only lines are dropped.
                return [json.loads(line) for line in f if line.strip()]

        try:
            from datasets import load_dataset
        except Exception as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "datasets is required to load cruxeval from HF. "
                "Provide dataset_path or install datasets."
            ) from exc
        dataset = load_dataset("cruxeval-org/cruxeval", split="test")
        return list(dataset)

    def _evaluate_one(self, code: str, inp: str, out: str, gen: str) -> bool:
        """Execute ``assert out == gen`` after *code* and report the outcome.

        Rejected without executing anything: in input mode, answers that do
        not contain ``f(``; in output mode, answers that contain the literal
        call ``f(<inp>)``.
        """
        if self._mode == "input" and "f(" not in gen:
            return False
        if self._mode == "output" and f"f({inp})" in gen:
            return False
        code_to_execute = f"{code}\nassert {out} == {gen}"
        # Second argument is passed positionally to the harness; presumably a
        # timeout in seconds -- confirm against check_correctness.
        return self._check_correctness(code_to_execute, 3)


class LiveCodeBenchExecutionTask(Task):
    """Task that scores LiveCodeBench code-execution (output prediction) answers."""

    def __init__(
        self,
        *,
        release_version: str = "release_v1",
        samples: Sequence[dict[str, Any]] | None = None,
    ) -> None:
        """Import the LiveCodeBench utilities and load the samples.

        Args:
            release_version: Forwarded to ``load_code_execution_dataset`` when
                *samples* is not given.
            samples: Pre-built evaluation samples, read through their
                ``"code"``, ``"input"`` and ``"output"`` keys.

        Raises:
            RuntimeError: the LiveCodeBench utilities cannot be imported.
        """
        _ensure_on_path(_CODE_EVAL_ROOT / "lcb")
        try:
            from lcb_runner.benchmarks.code_execution import (
                load_code_execution_dataset,
            )
            from lcb_runner.evaluation.utils_execute import (
                BASE_IMPORTS,
                check_correctness,
            )
        except Exception as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "LiveCodeBench utilities are not available. Ensure "
                "`src/bd_mcts/tasks/code_eval/lcb` is on PYTHONPATH."
            ) from exc

        self._base_imports = BASE_IMPORTS
        self._check_correctness = check_correctness
        if samples is None:
            dataset = load_code_execution_dataset(release_version=release_version)
            loaded = [item.get_evaluation_sample() for item in dataset]
        else:
            loaded = list(samples)
        self._samples = loaded

    def parse_answer(self, sample_id: int, lm_response: str) -> str:
        """Extract the predicted expression from *lm_response*."""
        return _extract_lcb_execution_expression(lm_response)

    def evaluate(self, sample_id: int, answer: str) -> ProbResult:
        """Check *answer* and label the result as the public split."""
        return self._run_check(sample_id, answer, split="public")

    def submit(self, sample_id: int, answer: str) -> ProbResult:
        """Check *answer* and label the result as private (public-equivalent)."""
        return self._run_check(
            sample_id, answer, split="private", public_equivalent=True
        )

    def _run_check(
        self,
        sample_id: int,
        answer: str,
        *,
        split: str,
        public_equivalent: bool = False,
    ) -> ProbResult:
        """Execute ``assert <output> == <answer>`` after the sample's code.

        An answer that contains the sample's input text fails without being
        executed.  The metric is 1.0 on success and 0.0 otherwise.
        """
        sample = self._samples[sample_id]
        code, inp, out = sample["code"], sample["input"], sample["output"]

        passed = False
        if inp not in answer:
            program = f"{self._base_imports}\n{code}\nassert {out} == {answer}"
            passed = self._check_correctness(program, 3)

        detail = {"sample_id": sample_id, "split": split}
        if public_equivalent:
            detail["public_equivalent"] = True
        score = 1.0 if passed else 0.0
        return ProbResult(metric=score, sample_detail=detail)


class LiveCodeBenchGenerationTask(Task):
    def __init__(
        self,
        *,
        release_version: str = "release_v1",
        start_date: str | None = None,
        end_date: str | None = None,
        difficulty: str | None = None,
        samples: Sequence[Any] | None = None,
        timeout: int = 6,
        debug: bool = False,
        public_test_workers: int = 1,
        private_test_workers: int = 1,
        prompt_use_instruct: bool = True,
        prompt_add_prefix: bool = False,
        prompt_tokenizer_name: str | None = None,
        prompt_trust_remote_code: bool = False,
        prompt_use_fast: bool = True,
    ) -> None:
        """Import the LiveCodeBench utilities and prepare per-problem tests.

        Args:
            release_version: Forwarded to ``load_code_generation_dataset``
                when *samples* is not given.
            start_date: Forwarded to ``load_code_generation_dataset``.
            end_date: Forwarded to ``load_code_generation_dataset``.
            difficulty: Forwarded to ``load_code_generation_dataset``.
            samples: Pre-loaded problems; skips dataset loading.
            timeout: Stored as ``_timeout`` for the correctness checks.
            debug: Stored as ``_debug`` for the correctness checks.
            public_test_workers: Worker count used by ``evaluate``.
            private_test_workers: Worker count used by ``submit``.
            prompt_use_instruct: Selects the instruction-style prompt in
                ``get_prompt``.
            prompt_add_prefix: Appends an opening code fence to that prompt.
            prompt_tokenizer_name: Stored on the instance.
            prompt_trust_remote_code: Stored on the instance.
            prompt_use_fast: Stored on the instance.

        Raises:
            ValueError: A worker count is below 1.
            RuntimeError: the LiveCodeBench utilities cannot be imported.
        """
        if public_test_workers < 1 or private_test_workers < 1:
            raise ValueError("test workers must be >= 1")
        _ensure_on_path(_CODE_EVAL_ROOT / "lcb")
        try:
            from lcb_runner.benchmarks.code_generation import (
                load_code_generation_dataset,
            )
            from lcb_runner.evaluation.compute_code_generation_metrics import (
                check_correctness,
            )
        except Exception as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "LiveCodeBench code-generation utilities are not available. Ensure "
                "`src/bd_mcts/tasks/code_eval/lcb` is on PYTHONPATH."
            ) from exc

        if samples is None:
            samples = load_code_generation_dataset(
                release_version=release_version,
                start_date=start_date,
                end_date=end_date,
                difficulty=difficulty,
            )
        # Materialize exactly once: the per-sample tables below must line up
        # index-for-index with _problems even if *samples* is a one-shot
        # iterable.
        self._problems = list(samples)

        self._check_correctness = check_correctness
        self._timeout = timeout
        self._debug = debug
        self._public_test_workers = public_test_workers
        self._private_test_workers = private_test_workers
        self._prompt_use_instruct = prompt_use_instruct
        self._prompt_add_prefix = prompt_add_prefix
        self._prompt_tokenizer_name = prompt_tokenizer_name
        self._prompt_trust_remote_code = prompt_trust_remote_code
        self._prompt_use_fast = prompt_use_fast
        self._prompt_tokenizer = None
        self._public_samples: list[dict[str, str]] = []
        self._private_samples: list[dict[str, str]] = []
        self._has_private: list[bool] = []

        for sample in self._problems:
            public_sample = self._build_sample(sample, include_private=False)
            has_private = bool(getattr(sample, "private_test_cases", None))
            if has_private:
                private_sample = self._build_sample(sample, include_private=True)
            else:
                # Without private tests, the private split reuses the public ones.
                private_sample = public_sample
            self._public_samples.append(public_sample)
            self._private_samples.append(private_sample)
            self._has_private.append(has_private)

    def parse_answer(self, sample_id: int, lm_response: str) -> str:
        """Return the last fenced code block of *lm_response*, stripped.

        A response without any fenced block is returned whole, stripped.
        """
        fenced = _CODE_BLOCK_RE.findall(lm_response)
        chosen = fenced[-1] if fenced else lm_response
        return chosen.strip()

    def get_prompt(self, sample_id: int) -> str:
        """Build the prompt text for problem *sample_id*.

        In instruct mode the prompt is assembled here from a fixed
        instruction, the question and either the starter code or a
        placeholder code fence; otherwise it is delegated to
        ``run_lcb.prompt_formatter``.
        """
        _ensure_on_path(_CODE_EVAL_ROOT / "lcb")
        from run_lcb import prompt_formatter
        from lcb_runner.prompts.code_generation import PromptConstants

        problem = self._problems[sample_id]
        if not self._prompt_use_instruct:
            return prompt_formatter(
                problem,
                use_instruct_prompt=False,
            )

        # Build raw instruction text; the runner applies chat templates later.
        pieces = [
            "You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests. You will NOT return anything except for the program.\n\n",
            f"Question:\n{problem.question_content}\n\n",
        ]
        if problem.starter_code:
            pieces.append(f"{PromptConstants.FORMATTING_MESSAGE_WITH_STARTER_CODE}\n")
            pieces.append(f"```python\n{problem.starter_code}\n```\n\n")
        else:
            pieces.append(f"{PromptConstants.FORMATTING_WITHOUT_STARTER_CODE}\n\n")
            pieces.append("```python\n# YOUR CODE HERE\n```\n\n")
        if self._prompt_add_prefix:
            pieces.append("```python\n")
        return "".join(pieces)

    def evaluate(self, sample_id: int, answer: str) -> ProbResult:
        """Check *answer* on the public split, requesting a fractional metric."""
        workers = self._public_test_workers
        return self._run_check(
            sample_id, answer, split="public", test_workers=workers, use_fraction=True
        )

    def submit(self, sample_id: int, answer: str) -> ProbResult:
        """Check *answer* on the private split with the private worker count."""
        workers = self._private_test_workers
        return self._run_check(
            sample_id, answer, split="private", test_workers=workers
        )

    def _build_sample(self, sample: Any, *, include_private: bool) -> dict[str, str]:
        if isinstance(sample, dict) and "input_output" in sample:
            return {"input_output": sample["input_output"]}

        public_tests = getattr(sample, "public_test_cases", [])
        private_tests = getattr(sample, "private_test_cases", [])
        tests = list(public_tests)
        if include_private:
            tests += list(private_tests)
        fn_name = None
        metadata = getattr(sample, "metadata", None)
        if isinstance(metadata, dict):
            fn_name = metadata.get("func_name", None)

        return {
            "input_output": json.dumps(
                {
                    "inputs": [t.input for t in tests],
                    "outputs": [t.output for t in tests],
                    "fn_name": fn_name,
                }
            )
        }

    def _expected_test_count(self, sample: dict[str, str]) -> int | None:
        input_output = sample.get("input_output")
        if not input_output:
            return None
        if isinstance(input_output, str):
            try:
                input_output = json.loads(input_output)
            except Exception:
                return None
        if not isinstance(input_output, dict):
            return None
        inputs = input_output.get("inputs")
        if isinstance(inputs, list):
            return len(inputs)
        return None

    def _run_check(
        self,
        sample_id: int,
        answer: str,
        *,
        split: str,
        test_workers: int,
        use_fraction: bool = False,
    ) -> ProbResult:
        """Run ``answer`` against one sample's tests and package the outcome.

        Args:
            sample_id: Index into the stored public/private sample tables.
            answer: Candidate solution handed to ``_check_with_workers``.
            split: ``"public"`` selects ``self._public_samples``; any other
                value selects ``self._private_samples``.
            test_workers: Worker count forwarded to ``_check_with_workers``.
            use_fraction: When true the metric is ``num_passed / num_tests``
                (``0.0`` if there are no tests); otherwise it is ``1.0`` only
                when there is at least one test and all of them pass.

        Returns:
            A ``ProbResult`` whose ``sample_detail`` records the split,
            whether private tests were used, the test/pass counts and the
            checker metadata.
        """
        if split != "public":
            sample = self._private_samples[sample_id]
            used_private = self._has_private[sample_id]
        else:
            sample = self._public_samples[sample_id]
            used_private = False

        raw_results, metadata = self._check_with_workers(sample, answer, test_workers)
        outcomes = list(raw_results)

        # Tests the sample declares but the checker returned no result for are
        # counted as failures.  If the checker returned more results than
        # declared, the result count wins, so after this step the number of
        # tests is always len(outcomes).
        declared = self._expected_test_count(sample)
        if declared is not None and declared > len(outcomes):
            outcomes.extend([False] * (declared - len(outcomes)))

        verdicts = [_normalize_lcb_generation_result(outcome) for outcome in outcomes]
        num_tests = len(verdicts)
        num_passed = sum(1 for verdict in verdicts if verdict)

        if use_fraction:
            metric = num_passed / num_tests if num_tests else 0.0
        else:
            all_passed = num_tests > 0 and num_passed == num_tests
            metric = 1.0 if all_passed else 0.0

        return ProbResult(
            metric=metric,
            sample_detail={
                "split": split,
                "used_private": used_private,
                "num_tests": num_tests,
                "num_passed": num_passed,
                "metadata": metadata,
            },
        )

    def _check_with_workers(
        self, sample: dict[str, str], answer: str, test_workers: int
    ) -> tuple[list, dict]:
        if test_workers <= 1:
            return self._check_correctness(
                sample, answer, timeout=self._timeout, debug=self._debug
            )

        in_outs = json.loads(sample["input_output"])
        inputs = list(in_outs.get("inputs", []))
        outputs = list(in_outs.get("outputs", []))
        fn_name = in_outs.get("fn_name", None)
        if len(inputs) <= 1:
            return self._check_correctness(
                sample, answer, timeout=self._timeout, debug=self._debug
            )

        chunks = EvalPlusTask._chunk_slices(len(inputs), test_workers)
        jobs = [
            delayed(_run_lcb_generation_chunk)(
                self._check_correctness,
                {
                    "input_output": json.dumps(
                        {
                            "inputs": inputs[sl],
                            "outputs": outputs[sl],
                            "fn_name": fn_name,
                        }
                    )
                },
                answer,
                self._timeout,
                self._debug,
                len(inputs[sl]),
            )
            for sl in chunks
        ]
        chunk_pairs = Parallel(n_jobs=len(chunks), backend="loky")(jobs)

        merged_results: list = []
        metadata_chunks: list[dict] = []
        for chunk_results, chunk_metadata in chunk_pairs:
            merged_results.extend(chunk_results)
            metadata_chunks.append(chunk_metadata)
        return merged_results, {"chunks": metadata_chunks}


def make_code_eval_task(name: str, **kwargs: Any) -> Task:
    """Build the code-eval task registered under ``name``.

    Lookup is case-insensitive and treats ``-`` and ``_`` as interchangeable,
    so every alias is accepted in both spellings (for example ``lcb_gen`` and
    ``lcb-gen``).  Recognised names, in underscore form:

    * ``humaneval``, ``humaneval_plus``
    * ``mbpp``, ``mbpp_plus``
    * ``math``, ``hendrycks_math``
    * ``bigcodebench``, ``bcb``
    * ``cruxeval_input``, ``cruxeval_output``
    * ``lcb_gen``, ``livecodebench_gen``, ``livecodebench_generation``
    * ``lcb_exec``, ``livecodebench_exec``

    Args:
        name: Task name or alias.
        **kwargs: Forwarded unchanged to the selected task constructor.

    Returns:
        The constructed task.

    Raises:
        ValueError: If ``name`` does not match any known task.
    """
    # Previously only some aliases had a hyphenated spelling; normalising the
    # separator makes all of them accept either form.
    key = name.lower().replace("-", "_")
    if key in ("humaneval", "humaneval_plus"):
        return EvalPlusTask(dataset="humaneval", **kwargs)
    if key in ("mbpp", "mbpp_plus"):
        return EvalPlusTask(dataset="mbpp", **kwargs)
    if key in ("math", "hendrycks_math"):
        # Imported lazily so the math module is only loaded when requested.
        from .math import make_hendrycks_math_task

        return make_hendrycks_math_task(**kwargs)
    if key in ("bigcodebench", "bcb"):
        return BigCodeBenchTask(**kwargs)
    if key == "cruxeval_input":
        return CruxEvalTask(mode="input", **kwargs)
    if key == "cruxeval_output":
        return CruxEvalTask(mode="output", **kwargs)
    if key in ("lcb_gen", "livecodebench_gen", "livecodebench_generation"):
        return LiveCodeBenchGenerationTask(**kwargs)
    if key in ("lcb_exec", "livecodebench_exec"):
        return LiveCodeBenchExecutionTask(**kwargs)
    # Report the name exactly as the caller supplied it.
    raise ValueError(f"Unknown code eval task name: {name}")
