from __future__ import annotations

import json
import os
import subprocess
from pathlib import Path
from typing import Any

from .base import ProbResult, Task


def _find_repo_root() -> Path:
    """Return the directory treated as the repository root for this module.

    Walks the ancestors of this file's resolved path, nearest first, and returns
    the first one that contains a ``pyproject.toml``. If none does, falls back
    to ``parents[3]`` of this file (three levels above its directory). That
    fallback raises ``IndexError`` if the path has too few ancestors.
    """
    here = Path(__file__).resolve()
    marked = next(
        (ancestor for ancestor in here.parents if (ancestor / "pyproject.toml").exists()),
        None,
    )
    if marked is None:
        # Only computed when no marker was found, so shallow paths that do
        # contain a pyproject.toml never hit the IndexError.
        marked = here.parents[3]
    return marked


class SubprocessTask(Task):
    """Task that delegates every call to a worker running in a separate process.

    Each call to :meth:`evaluate`, :meth:`submit` or :meth:`parse_answer`
    launches ``<python_path> -m <worker_module>``. It writes one JSON request
    (``task_name``, ``task_kwargs``, ``method``, ``sample_id``, ``answer``) to
    the worker's stdin and decodes the worker's stdout as JSON. This lets the
    task run under a different interpreter or environment than the caller.

    Args:
        python_path: Interpreter executable used to launch the worker.
        task_name: Forwarded to the worker as ``"task_name"``.
        task_kwargs: Forwarded to the worker as ``"task_kwargs"``. It must be
            JSON-serializable. A shallow copy is stored.
        repo_root: Directory prepended to the worker's ``PYTHONPATH``. If
            falsy, the nearest ancestor of this file containing a
            ``pyproject.toml`` is used.
        worker_module: Module executed with ``python -m``.
        timeout: Optional per-call limit in seconds. ``None`` (the default)
            waits for the worker indefinitely.
    """

    def __init__(
        self,
        *,
        python_path: str,
        task_name: str,
        task_kwargs: dict[str, Any] | None = None,
        repo_root: str | None = None,
        worker_module: str = "bd_mcts.tasks.code_eval_worker",
        timeout: float | None = None,
    ) -> None:
        self._python_path = python_path
        self._task_name = task_name
        # Copy so later mutation of the caller's dict cannot alter requests.
        self._task_kwargs = dict(task_kwargs) if task_kwargs else {}
        self._repo_root = Path(repo_root) if repo_root else _find_repo_root()
        self._worker_module = worker_module
        self._timeout = timeout

    def evaluate(self, sample_id: int, answer: str) -> ProbResult:
        """Score ``answer`` for ``sample_id`` using the worker's ``evaluate`` method."""
        return self._run_scored("evaluate", sample_id, answer)

    def submit(self, sample_id: int, answer: str) -> ProbResult:
        """Score ``answer`` for ``sample_id`` using the worker's ``submit`` method."""
        return self._run_scored("submit", sample_id, answer)

    def parse_answer(self, sample_id: int, lm_response: str) -> str:
        """Extract an answer from ``lm_response`` using the worker's ``parse_answer``.

        Returns the ``"parsed"`` field of the worker's reply.
        """
        data = self._run_worker("parse_answer", sample_id, lm_response)
        return data["parsed"]

    def _run_scored(self, method: str, sample_id: int, answer: str) -> ProbResult:
        """Run a worker method whose reply carries ``metric`` and ``sample_detail``."""
        data = self._run_worker(method, sample_id, answer)
        return ProbResult(metric=data["metric"], sample_detail=data["sample_detail"])

    def _worker_env(self) -> dict[str, str]:
        """Return a copy of ``os.environ`` with ``repo_root`` first on ``PYTHONPATH``."""
        env = os.environ.copy()
        existing = env.get("PYTHONPATH", "")
        env["PYTHONPATH"] = (
            f"{self._repo_root}{os.pathsep}{existing}"
            if existing
            else str(self._repo_root)
        )
        return env

    def _run_worker(self, method: str, sample_id: int, answer: str) -> dict[str, Any]:
        """Send one request to a fresh worker process and return its decoded reply.

        Raises:
            RuntimeError: The worker exceeded ``timeout``, exited with a
                non-zero status, or wrote stdout that is not valid JSON.
        """
        payload = {
            "task_name": self._task_name,
            "task_kwargs": self._task_kwargs,
            "method": method,
            "sample_id": sample_id,
            "answer": answer,
        }
        cmd = [self._python_path, "-m", self._worker_module]
        try:
            result = subprocess.run(
                cmd,
                input=json.dumps(payload),
                text=True,
                capture_output=True,
                check=False,
                env=self._worker_env(),
                timeout=self._timeout,
            )
        except subprocess.TimeoutExpired as err:
            # subprocess.run has already killed and reaped the child here.
            raise RuntimeError(
                f"worker timed out after {self._timeout}s "
                f"(method={method!r}, sample_id={sample_id})"
            ) from err
        if result.returncode != 0:
            raise RuntimeError(
                f"worker failed (code {result.returncode}): {result.stderr}"
            )
        try:
            return json.loads(result.stdout)
        except json.JSONDecodeError as err:
            # Stray prints or truncated output from the worker land here;
            # keep the raw streams so the failure is diagnosable.
            raise RuntimeError(
                f"worker returned invalid JSON ({err}); "
                f"stdout={result.stdout[:1000]!r}; stderr={result.stderr}"
            ) from err
