from __future__ import annotations

import json
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence

from .base import ProbResult, Task


@dataclass(frozen=True)
class TaskItem:
    """Immutable pairing of a sample identifier with a candidate answer.

    The runners in this module read ``sample_id`` and ``answer`` from each item
    and forward them to the task's ``evaluate``/``submit`` call (in-process) or
    to the worker subprocess as part of its JSON payload.
    """

    # Identifier of the sample this answer belongs to.
    sample_id: int
    # Candidate answer to evaluate or submit for that sample.
    answer: str


def _find_repo_root() -> Path:
    """Locate the directory to treat as the repository root.

    Walks upward from this file's resolved location and returns the nearest
    ancestor directory that contains a ``pyproject.toml``. When no ancestor has
    one, falls back to the fourth ancestor (``parents[3]``), or to the topmost
    available ancestor when the file sits fewer than four directories deep.

    Returns:
        The directory chosen as the repository root.
    """
    current = Path(__file__).resolve()
    ancestors = list(current.parents)
    for parent in ancestors:
        if (parent / "pyproject.toml").exists():
            return parent
    # Indexing parents[3] unconditionally raises IndexError for shallow
    # locations (e.g. /opt/mod.py); clamp to the last (topmost) ancestor.
    return ancestors[min(3, len(ancestors) - 1)]


class InProcessTaskRunner:
    """Run a task's ``evaluate``/``submit`` calls on a local thread pool.

    Each batch call fans the items out over a fresh ``ThreadPoolExecutor`` and
    returns the results in the same order as the input items.
    """

    def __init__(
        self, task: Task, *, evaluate_workers: int = 1, submit_workers: int = 1
    ) -> None:
        """Store the task and the pool size used for each kind of call.

        Args:
            task: Object exposing the ``evaluate`` and ``submit`` methods.
            evaluate_workers: Thread-pool size used by ``evaluate_many``.
            submit_workers: Thread-pool size used by ``submit_many``.

        Raises:
            ValueError: If either worker count is below 1.
        """
        worker_counts = (evaluate_workers, submit_workers)
        if any(count < 1 for count in worker_counts):
            raise ValueError("workers must be >= 1")
        self._task = task
        self._evaluate_workers = evaluate_workers
        self._submit_workers = submit_workers

    def evaluate_many(self, items: Sequence[TaskItem]) -> list[ProbResult]:
        """Call the task's ``evaluate`` for every item, preserving order."""
        pool_size = self._evaluate_workers
        return self._run_many(items, method="evaluate", max_workers=pool_size)

    def submit_many(self, items: Sequence[TaskItem]) -> list[ProbResult]:
        """Call the task's ``submit`` for every item, preserving order."""
        pool_size = self._submit_workers
        return self._run_many(items, method="submit", max_workers=pool_size)

    def _run_many(
        self, items: Sequence[TaskItem], *, method: str, max_workers: int
    ) -> list[ProbResult]:
        """Dispatch the named task method over ``items`` on a thread pool.

        Returns an empty list without creating a pool when ``items`` is empty.
        Results are collected in submission order; the first exception raised
        by a call propagates to the caller.
        """
        if not items:
            return []
        call = getattr(self._task, method)
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            # Queue every call first so they can overlap, then gather in order.
            pending = []
            for item in items:
                pending.append(pool.submit(call, item.sample_id, item.answer))
            results = [future.result() for future in pending]
        return results


class SubprocessTaskRunner:
    """Run each task call in its own Python subprocess.

    Every item is serialized as a JSON payload, piped on stdin to the
    ``bd_mcts.tasks.code_eval_worker`` module launched with the configured
    interpreter, and the worker's JSON stdout is turned into a ``ProbResult``.
    """

    def __init__(
        self,
        *,
        python_path: str,
        task_name: str,
        task_kwargs: dict,
        evaluate_workers: int = 1,
        submit_workers: int = 1,
        repo_root: str | None = None,
    ) -> None:
        """Record the worker configuration.

        Args:
            python_path: Interpreter executable used to launch the worker.
            task_name: Sent to the worker as ``task_name``.
            task_kwargs: Sent to the worker as ``task_kwargs``; it is passed
                through ``json.dumps`` for every item.
            evaluate_workers: Thread-pool size used by ``evaluate_many``.
            submit_workers: Thread-pool size used by ``submit_many``.
            repo_root: Directory placed first on the worker's ``PYTHONPATH``;
                when falsy it is discovered with ``_find_repo_root()``.

        Raises:
            ValueError: If either worker count is below 1.
        """
        worker_counts = (evaluate_workers, submit_workers)
        if any(count < 1 for count in worker_counts):
            raise ValueError("workers must be >= 1")
        self._python_path = python_path
        self._task_name = task_name
        self._task_kwargs = task_kwargs
        self._evaluate_workers = evaluate_workers
        self._submit_workers = submit_workers
        if repo_root:
            self._repo_root = Path(repo_root)
        else:
            self._repo_root = _find_repo_root()

    def evaluate_many(self, items: Sequence[TaskItem]) -> list[ProbResult]:
        """Run the worker's ``evaluate`` method for every item, in order."""
        pool_size = self._evaluate_workers
        return self._run_many(items, method="evaluate", max_workers=pool_size)

    def submit_many(self, items: Sequence[TaskItem]) -> list[ProbResult]:
        """Run the worker's ``submit`` method for every item, in order."""
        pool_size = self._submit_workers
        return self._run_many(items, method="submit", max_workers=pool_size)

    def _run_many(
        self, items: Sequence[TaskItem], *, method: str, max_workers: int
    ) -> list[ProbResult]:
        """Launch one worker subprocess per item from a thread pool.

        Returns an empty list without creating a pool when ``items`` is empty.
        Results are collected in submission order; the first exception raised
        by a worker call propagates to the caller.
        """
        if not items:
            return []
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            # Queue every launch first so they can overlap, then gather in order.
            pending = []
            for item in items:
                pending.append(pool.submit(self._run_one, method, item))
            results = [future.result() for future in pending]
        return results

    def _worker_env(self) -> dict[str, str]:
        """Copy the current environment with the repo root first on PYTHONPATH."""
        env = os.environ.copy()
        search_path = [str(self._repo_root)]
        inherited = env.get("PYTHONPATH", "")
        if inherited:
            # Keep whatever the parent process already had, after the repo root.
            search_path.append(inherited)
        env["PYTHONPATH"] = os.pathsep.join(search_path)
        return env

    def _run_one(self, method: str, item: TaskItem) -> ProbResult:
        """Run a single item through the worker module and parse its reply.

        Raises:
            RuntimeError: If the worker exits with a non-zero return code; the
                message includes the code and the captured stderr.
        """
        request = json.dumps(
            {
                "task_name": self._task_name,
                "task_kwargs": self._task_kwargs,
                "method": method,
                "sample_id": item.sample_id,
                "answer": item.answer,
            }
        )
        completed = subprocess.run(
            [self._python_path, "-m", "bd_mcts.tasks.code_eval_worker"],
            input=request,
            env=self._worker_env(),
            capture_output=True,
            text=True,
            check=False,
        )
        if completed.returncode != 0:
            raise RuntimeError(
                f"worker failed (code {completed.returncode}): {completed.stderr}"
            )
        reply = json.loads(completed.stdout)
        return ProbResult(metric=reply["metric"], sample_detail=reply["sample_detail"])
