from __future__ import annotations

import json
import multiprocessing
import sys
import os
import signal
import concurrent.futures
import urllib.request
import urllib.parse
import urllib.error
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple
import re

from op_eval.dataset import dataset
from op_eval.config import (
    OP_EVAL_REMOTE_HTTP_TIMEOUT_S,
    OP_EVAL_WORKER_TIMEOUT_S,
    prepare_ascend_workdir,
)
from op_eval.models import EvalRequest
from op_eval.utils.evaluation_utils import eval_single, extract_first_code
import tempfile
import shutil

# Languages that, when evaluating locally (no remote_url), get an Ascend
# workdir via prepare_ascend_workdir() and run in spawned worker processes
# bound to a device id; every other language is evaluated in-process by
# _run_sequential().
ASCEND_LANGUAGES = {"ascendc", "tilelang_ascend"}


@dataclass(frozen=True)
class _RemoteWorkItem:
    """A single remote evaluation job: one request tagged with its run index.

    Instances are immutable (frozen); `_run_remote` hands them to its worker
    threads and passes them back to the on_result / on_drop callbacks.
    """

    req: EvalRequest  # request whose code is POSTed to the remote /evaluate endpoint
    run: int  # index from evaluate_requests' `for run in range(runs)` loop; part of the batch key


def _infer_run_from_run_name(run_name: str, fallback: int) -> int:
    if run_name:
        match = re.search(r"_run(\d+)$", run_name)
        if match:
            return int(match.group(1))
    return fallback


def _evaluate_op_process(req: EvalRequest, device_queue):
    """
    Evaluate a single operator using the same pattern as server.
    Parent acquires slot from device_queue, then spawns worker process.

    The code is loaded before a device slot is taken. A device id is then
    taken from ``device_queue`` (device 0 when no queue is given),
    ``_run_local_worker`` runs in a "spawn" child with a fresh temporary
    workspace, and its result is awaited for at most
    ``OP_EVAL_LOCAL_TIMEOUT_S`` seconds (default ``OP_EVAL_WORKER_TIMEOUT_S``).
    The slot is always put back on the queue and the workspace is always
    removed.

    Returns:
        ``(req.op, result)`` where ``result`` is the worker's dict, or an error
        record (``compiled=False`` plus ``error``) on timeout or when the worker
        exits without reporting anything.
    """
    import queue as _queue
    import time as _time

    def _stop(proc) -> None:
        # Escalate terminate -> kill, reaping the child either way.
        proc.terminate()
        proc.join(timeout=5)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=5)

    ctx = multiprocessing.get_context("spawn")

    # Load code first so an unreadable request never holds a device slot.
    code_text = req.load_code()

    # Acquire device slot (parent-side, like server).
    device_id = device_queue.get() if device_queue is not None else 0
    tmp_dir = None
    try:
        tmp_dir = tempfile.mkdtemp(prefix=f"eval_local_{req.op}_")
        result_queue = ctx.Queue()

        timeout_s = int(
            os.environ.get("OP_EVAL_LOCAL_TIMEOUT_S", str(OP_EVAL_WORKER_TIMEOUT_S))
        )

        p = ctx.Process(
            target=_run_local_worker,
            args=(req.op, code_text, req.language, device_id, result_queue, tmp_dir),
        )
        p.start()

        # Read the result while the worker runs instead of join()-ing first:
        # a child that put a large object on a multiprocessing.Queue cannot exit
        # until the parent drains it, so join-then-get would stall until the
        # timeout and mis-report a finished evaluation as a timeout.
        deadline = _time.monotonic() + timeout_s
        result = None
        received = False
        while not received and p.is_alive():
            remaining = deadline - _time.monotonic()
            if remaining <= 0:
                break
            try:
                result = result_queue.get(timeout=min(1.0, remaining))
                received = True
            except _queue.Empty:
                pass

        if not received and not p.is_alive():
            # The worker has exited; what it queued should already be in the pipe.
            try:
                result = result_queue.get(timeout=1.0)
                received = True
            except _queue.Empty:
                pass

        if received:
            # Give the worker a moment for its own teardown; force it only if it hangs.
            p.join(timeout=5)
            if p.is_alive():
                _stop(p)
            return req.op, result

        if p.is_alive():
            _stop(p)
            return req.op, {
                "compiled": False,
                "correctness": None,
                "performance": None,
                "error": f"Evaluation timeout after {timeout_s}s",
            }

        return req.op, {
            "compiled": False,
            "correctness": None,
            "performance": None,
            "error": "Worker returned no result",
        }
    finally:
        # Release device slot back to queue.
        if device_queue is not None:
            device_queue.put(device_id)
        # Best-effort cleanup of the per-task workspace.
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)


def _run_local_worker(op, code_content, language, device_id, result_queue, workspace_dir):
    """Child-process body for local evaluation (the server's run_worker, minus Flask).

    Steps, in order:

    - Force the Ascend workspace/device environment variables to
      ``workspace_dir`` / ``device_id``.
    - Point TMPDIR/TMP/TEMP at ``<workspace_dir>/_tmp`` unless they are
      already set.
    - Try to select the NPU through ``torch_npu``; a failure there is only a
      warning.
    - Run ``eval_single`` and put its result dict on ``result_queue``.

    Any other exception is logged to stderr and reported on the queue as an
    error record.
    """
    import traceback

    try:
        device_str = str(device_id)
        # Hard overrides: every task gets its own workspace and device binding.
        os.environ.update({
            "ASCEND_OP_ROOT": workspace_dir,
            "ASCEND_OP_RUNS_ROOT": workspace_dir,
            "ASCEND_DEVICE_ID": device_str,
            "DEVICE_ID": device_str,
        })
        os.environ.setdefault("OP_EVAL_SKIP_NPU_CLEANUP", "1")

        # Keep temp files inside the task workspace unless the caller chose a location.
        scratch_dir = os.path.join(workspace_dir, "_tmp")
        os.makedirs(scratch_dir, exist_ok=True)
        for var_name in ("TMPDIR", "TMP", "TEMP"):
            os.environ.setdefault(var_name, scratch_dir)

        # Device selection is best effort: report the failure but keep going.
        try:
            import torch_npu
            torch_npu.npu.set_device(device_id)
        except Exception as e:
            print(f"[WARN] Failed to set NPU device {device_id}: {e}", file=sys.stderr, flush=True)

        result_queue.put(eval_single(code_content, op, language, device_id=device_id))
    except Exception as e:
        print(f"[ERROR] _run_local_worker: Exception for {op}: {e}", file=sys.stderr, flush=True)
        traceback.print_exc()
        failure = {
            "compiled": False,
            "correctness": None,
            "performance": None,
            "error": f"Worker Process Failure: {str(e)}",
        }
        result_queue.put(failure)


def _run_parallel(
    requests: Iterable[EvalRequest],
    device_queue,
    max_workers: int,
) -> Dict[str, Dict]:
    """Evaluate ``requests`` concurrently, sharing device slots via ``device_queue``.

    Each request goes through ``_evaluate_op_process``. That call takes a device
    slot, spawns its own worker process and returns the slot afterwards. So at
    most ``max_workers`` evaluations are in flight, and no more than the number
    of queued slots hold a device at any time.

    Returns:
        ``{op: result}``. An exception raised while evaluating a request becomes
        an error record carrying the exception text.
    """
    # Driver threads instead of a multiprocessing.Pool. Pool workers are
    # daemonic, and a daemonic process may not start children ("daemonic
    # processes are not allowed to have children"). Yet _evaluate_op_process
    # must spawn a worker process for every op. The spawned worker still gives
    # each op a fresh process; these threads merely wait on it.
    workers = max(1, max_workers) if max_workers is not None else None
    results: Dict[str, Dict] = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        tasks = [
            (req.op, executor.submit(_evaluate_op_process, req, device_queue))
            for req in requests
        ]
        for op, future in tasks:
            try:
                _, result = future.result()
            except Exception as exc:
                result = {
                    "compiled": False,
                    "correctness": None,
                    "performance": None,
                    "error": str(exc),
                }
            results[op] = result
    return results


def _single_worker(r, dq, rq):
    """Child-process entry point used by ``_run_single_isolated``.

    Evaluates request ``r`` with device queue ``dq`` and puts ``(op, result)``
    on ``rq``. Any failure becomes an error record, so the parent still gets a
    reply.
    """
    # CRITICAL: file descriptor 1 must stay clean for the MCP JSON-RPC protocol.
    # Pointing fd 1 at stderr captures Python print() as well as C-level
    # printf() from the backend.
    try:
        sys.stdout.flush()
        os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
    except Exception as e:
        # Fallback if dup2 fails (unlikely on POSIX): report and carry on.
        sys.stderr.write(f"Failed to redirect stdout: {e}\n")

    try:
        evaluated_op, outcome = _evaluate_op_process(r, dq)
        rq.put((evaluated_op, outcome))
    except Exception as e:
        failure = {
            "compiled": False,
            "correctness": None,
            "performance": None,
            "error": f"Critical Process Failure: {str(e)}",
        }
        rq.put((r.op, failure))


def _run_single_isolated(req: EvalRequest, device_id: int) -> Dict[str, Dict]:
    """Run a single request in an isolated process without Manager/Pool overhead.

    A spawned ``_single_worker`` child evaluates ``req`` on ``device_id``. The
    device id is handed over through a one-slot queue that mimics the device
    allocator. The child sends back ``(op, result)``.

    Returns:
        ``{op: result}``. If the child exits without sending anything, returns
        ``{req.op: <error record>}``.
    """
    import queue as _queue

    ctx = multiprocessing.get_context("spawn")
    # One-slot queue standing in for the device allocator queue.
    device_q = ctx.Queue()
    device_q.put(device_id)
    # Return channel for the child's (op, result) tuple.
    result_q = ctx.Queue()

    p = ctx.Process(target=_single_worker, args=(req, device_q, result_q))
    p.start()

    # Drain the result *before* join(). A child cannot exit until everything
    # it put on a multiprocessing.Queue is flushed to the pipe, so join-then-get
    # deadlocks once the pickled result exceeds the pipe buffer.
    payload = None
    while payload is None:
        try:
            payload = result_q.get(timeout=1.0)
        except _queue.Empty:
            if not p.is_alive():
                # Child is gone; whatever it sent should already be in the pipe.
                try:
                    payload = result_q.get(timeout=1.0)
                except _queue.Empty:
                    pass
                break
    p.join()

    if payload is not None:
        op, res = payload
        return {op: res}
    return {
        req.op: {
            "compiled": False,
            "correctness": None,
            "performance": None,
            "error": "Process died without returning result",
        }
    }


def _run_sequential(requests: Iterable[EvalRequest]) -> Dict[str, Dict]:
    results: Dict[str, Dict] = {}
    for req in requests:
        try:
            code_text = req.load_code()
            result = eval_single(code_text, req.op, req.language)
        except Exception as exc:
            result = {
                "compiled": False,
                "correctness": None,
                "performance": None,
                "error": str(exc),
            }
        results[req.op] = result
    return results


def _run_remote(
    items: Iterable[_RemoteWorkItem],
    remote_url: str,
    max_workers: int,
    desc: str = "Remote Eval",
    on_result: Optional[Callable[[_RemoteWorkItem, Dict], None]] = None,
    on_drop: Optional[Callable[[_RemoteWorkItem, Dict], None]] = None,
    desc_fn: Optional[Callable[[], str]] = None,
    last_width: int = 40,
    periodic_every_s: Optional[float] = None,
    periodic_fn: Optional[Callable[[], None]] = None,
) -> List[Tuple[_RemoteWorkItem, Dict]]:
    """POST every work item to ``<remote_url>/evaluate`` from a thread pool.

    Each item's code is passed through ``extract_first_code``; the raw text is
    used when that returns None. The code is sent as JSON ``{"code": ...}``
    with ``op`` and ``language`` as query parameters. Any exception while
    building or sending a request becomes an error record whose message
    starts with ``"Remote Error:"``; there is no per-request retry.

    Args:
        items: Work items to send.
        remote_url: Base URL of the evaluation server.
        max_workers: Thread-pool size (maximum concurrent HTTP requests).
        desc: Static progress-bar label, used when ``desc_fn`` is None.
        on_result: Called in the calling thread for every non-dropped result.
        on_drop: Called for results whose error starts with ``"Remote Error:"``.
        desc_fn: If given, re-evaluated to refresh the progress-bar label.
        last_width: Width the last op name is padded/truncated to in the postfix.
        periodic_every_s, periodic_fn: When both are set, ``periodic_fn`` runs
            once at least ``periodic_every_s`` seconds have passed since the
            last run. This is checked after every wake-up (roughly once per
            second). Its exceptions are printed, not raised.

    Returns:
        ``(item, result)`` pairs for every non-dropped item, in completion order.
    """
    from tqdm import tqdm
    import time

    collected: List[Tuple[_RemoteWorkItem, Dict]] = []

    def _send_request(item: _RemoteWorkItem):
        req = item.req
        try:
            source = req.load_code()
            endpoint = remote_url.rstrip("/") + "/evaluate"
            query = urllib.parse.urlencode({"op": req.op, "language": req.language})
            url = f"{endpoint}?{query}"

            extracted = extract_first_code(source, ["python", "cpp", req.language])
            body = json.dumps({"code": source if extracted is None else extracted}).encode("utf-8")
            http_req = urllib.request.Request(
                url,
                data=body,
                headers={"Content-Type": "application/json"},
            )

            http_timeout_s = float(
                os.environ.get(
                    "OP_EVAL_REMOTE_HTTP_TIMEOUT_S",
                    str(OP_EVAL_REMOTE_HTTP_TIMEOUT_S),
                )
            )
            with urllib.request.urlopen(http_req, timeout=http_timeout_s) as response:
                if response.status != 200:
                    raise RuntimeError(f"HTTP {response.status}: {response.read().decode()}")
                return item, json.loads(response.read())
        except Exception as e:
            # No per-request retry here; evaluate_requests() retries whole batches.
            return item, {
                "compiled": False,
                "correctness": None,
                "performance": None,
                "error": f"Remote Error: {str(e)}",
            }

    def _postfix(op_name: str, status: str) -> str:
        shown = (
            op_name[: max(0, last_width - 3)] + "..."
            if len(op_name) > last_width
            else op_name.ljust(last_width)
        )
        return f"Last: {shown} [{status}]"

    def _classify(res: Dict) -> str:
        # ERR = infrastructure failure (has 'error' field), FAIL = compilation/correctness issue.
        if res.get("error"):
            return "ERR"
        return "OK" if res.get("correctness") else "FAIL"

    last_periodic = 0.0

    def _maybe_run_periodic() -> None:
        nonlocal last_periodic
        if periodic_every_s is None or periodic_fn is None:
            return
        due = (time.monotonic() - last_periodic) >= float(periodic_every_s)
        if not due:
            return
        try:
            periodic_fn()
        except Exception as e:
            print(f"[WARN] Periodic save failed: {e}")
        last_periodic = time.monotonic()

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        outstanding = {executor.submit(_send_request, item) for item in items}
        start_desc = desc_fn() if desc_fn is not None else desc
        with tqdm(total=len(outstanding), desc=start_desc) as pbar:
            last_periodic = time.monotonic()
            while outstanding:
                finished, outstanding = concurrent.futures.wait(
                    outstanding,
                    timeout=1.0,
                    return_when=concurrent.futures.FIRST_COMPLETED,
                )

                for future in finished:
                    item, res = future.result()
                    op_name = item.req.op

                    if str(res.get("error", "")).startswith("Remote Error:"):
                        # Infrastructure/network failure: report as a drop, do not collect.
                        if desc_fn is not None:
                            pbar.set_description(desc_fn())
                        pbar.set_postfix_str(_postfix(op_name, "DROP"))
                        if on_drop is not None:
                            on_drop(item, res)
                    else:
                        if on_result is not None:
                            on_result(item, res)
                        collected.append((item, res))
                        status = _classify(res)
                        if desc_fn is not None:
                            pbar.set_description(desc_fn())
                        pbar.set_postfix_str(_postfix(op_name, status))
                    pbar.update(1)

                # Checked on every wake-up, whether or not anything finished.
                _maybe_run_periodic()

    return collected


def _write_results(out_file: Path, results: Dict) -> None:
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with open(out_file, "w") as f:
        json.dump(results, f, indent=2)


def build_requests_from_dir(
    *,
    input_dir: Path,
    language: str,
    ops: Optional[Sequence[str]] = None,
    run_name: str = "default",
    skip_dataset_check: bool = False,
) -> List[EvalRequest]:
    """Scan an output dir and build EvalRequest objects for available ops.

    Every ``*.txt`` file directly under ``input_dir`` is one op (file stem =
    op name), evaluated under ``run_name``. If that yields no requests, the
    ``run*`` subdirectories are scanned instead in numeric order (run2 before
    run10). Each subdirectory uses the run name ``f"{run_name}_{subdir}"``,
    e.g. ``gpt-5.2_run0``.

    Args:
        input_dir: Directory to scan.
        language: Language/backend recorded on every request.
        ops: If given and non-empty, only these op names are considered; each
            requested op that ends up without a request produces a warning.
        run_name: Base run name.
        skip_dataset_check: Keep ops missing from ``dataset`` instead of skipping
            them; they are recorded with ``metadata["has_reference"] = False``.

    Returns:
        The EvalRequest objects in scan order.
    """
    base_dir = Path(input_dir)
    requested = set(ops) if ops else None
    requests: List[EvalRequest] = []

    def _add_req(path: Path, r_name: str) -> None:
        op = path.stem
        if requested and op not in requested:
            return
        has_reference = op in dataset
        if not has_reference and not skip_dataset_check:
            print(f"[WARN] {op}: unknown operator (not in dataset); skipping build and evaluation")
            return
        if not has_reference:
            print(f"[INFO] {op}: unknown operator (not in dataset); proceeding with --skip-dataset-check")
        requests.append(
            EvalRequest(
                op=op,
                code_path=path,
                language=language,
                run_name=r_name,
                metadata={"has_reference": has_reference},
            )
        )

    # 1. Top-level *.txt files.
    for path in sorted(base_dir.glob("*.txt")):
        _add_req(path, run_name)

    # 2. If the top level yielded no requests, scan the run* subdirectories.
    if not requests:
        def _run_idx(p: Path) -> float:
            # Numeric suffix of "run<N>" so run2 < run10; other names sort last.
            match = re.search(r"run(\d+)", p.name)
            return int(match.group(1)) if match else float("inf")

        # Secondary sort by name for stability.
        run_dirs = sorted(
            (d for d in base_dir.glob("run*") if d.is_dir()),
            key=lambda p: (_run_idx(p), p.name),
        )
        for run_subdir in run_dirs:
            # Isolated run_name per subdir, e.g. "gpt-5.2" + "_run0" -> "gpt-5.2_run0".
            sub_run_name = f"{run_name}_{run_subdir.name}"
            for path in sorted(run_subdir.glob("*.txt")):
                _add_req(path, sub_run_name)

    missing = requested.difference(r.op for r in requests) if requested else set()
    for op in sorted(missing):
        print(f"[WARN] Requested op {op} not found under {base_dir}")
    return requests


def evaluate_requests(
    requests: Sequence[EvalRequest],
    *,
    runs: int = 1,
    build_workers: int = 4,
    device_ids: Optional[Sequence[int]] = None,
    device_offset: int = 0,
    result_path: Optional[Path] = None,
    enable_npu_parallelism: bool = False,
    progress_callback: Optional[Callable[[], None]] = None,
    remote_url: Optional[str] = None,
) -> Dict:
    """Evaluate a batch of requests; returns results keyed by op."""
    if not requests:
        return {}
    aggregated: Dict[str, List[Dict]] = {}
    
    # 1. Resume capability: Load existing results
    if result_path and result_path.exists():
        try:
            with open(result_path, "r") as f:
                existing_data = json.load(f)
            # Normalize to list internal representation
            for op, val in existing_data.items():
                if isinstance(val, dict):
                    if "error" in val:
                        continue
                    aggregated[op] = [val]
                elif isinstance(val, list):
                    filtered = [entry for entry in val if isinstance(entry, dict) and "error" not in entry]
                    if filtered:
                        aggregated[op] = filtered
            print(f"[INFO] Resuming from {result_path} with {len(aggregated)} ops already recorded.")
        except Exception as e:
            print(f"[WARN] Failed to load existing results from {result_path}, starting fresh. Error: {e}")

    # helper to check if requests span multiple run_names
    run_name_groups = {}
    for req in requests:
        run_name_groups.setdefault(req.run_name, []).append(req)
        
    is_multi_sample = len(run_name_groups) > 1 or runs > 1
    try:
        error_retry_rounds = max(0, int(os.environ.get("OP_EVAL_ERROR_RETRY_ROUNDS", "1")))
    except ValueError:
        error_retry_rounds = 1
    max_attempts = 1 + error_retry_rounds
        
    for current_run_name, group_requests in run_name_groups.items():
        # Validate language consistency within group
        language = group_requests[0].language
        if any(req.language != language for req in group_requests):
            raise ValueError("All requests must target the same language/backend")

    if remote_url:
        pending_items: List[_RemoteWorkItem] = []
        pending_batches: Dict[Tuple[str, int], int] = {}
        for current_run_name, group_requests in run_name_groups.items():
            for run in range(runs):
                # Filtering: Determine which requests need to be run
                needed_requests = []
                for req in group_requests:
                    already_done = False
                    expected_run = _infer_run_from_run_name(req.run_name, run)
                    if req.op in aggregated:
                        for res in aggregated[req.op]:
                            # Check match. 
                            # Legacy results might lack 'run_name', assuming default if missing?
                            # But we are in a strict resume mode now.
                            r_match = res.get("run") == expected_run
                            rn_match = res.get("run_name") == current_run_name
                            if r_match and rn_match:
                                already_done = True
                                break
                    if not already_done:
                        needed_requests.append(req)

                if not needed_requests:
                    continue

                pending_batches[(current_run_name, run)] = len(needed_requests)
                for req in needed_requests:
                    pending_items.append(_RemoteWorkItem(req=req, run=run))

        if pending_items:
            total_requests = len(pending_items)
            total_batches = len(pending_batches)
            print(
                f"[INFO] Remote eval: {total_requests} requests across {total_batches} batches "
                f"(max {build_workers} concurrent) via {remote_url}"
            )
            completed_batches = {"count": 0}
            periodic_save_s = float(os.environ.get("OP_EVAL_PERIODIC_SAVE_S", "300"))

            def _mark_remote_complete(item: _RemoteWorkItem) -> None:
                batch_key = (item.req.run_name, item.run)
                remaining = pending_batches.get(batch_key)
                if remaining is None:
                    return
                remaining -= 1
                pending_batches[batch_key] = remaining
                if remaining == 0:
                    completed_batches["count"] += 1
                    if result_path:
                        final_state: Dict = aggregated if is_multi_sample else {op: vals[0] for op, vals in aggregated.items() if vals}
                        _write_results(Path(result_path), final_state)
                        if progress_callback:
                            try:
                                progress_callback()
                            except Exception as e:
                                print(f"[WARN] Progress callback failed: {e}")

            def _periodic_flush() -> None:
                if not result_path:
                    return
                final_state: Dict = aggregated if is_multi_sample else {op: vals[0] for op, vals in aggregated.items() if vals}
                _write_results(Path(result_path), final_state)
                print(f"[INFO] Periodic save: {len(aggregated)} ops -> {result_path}")
                if progress_callback:
                    try:
                        progress_callback()
                    except Exception as e:
                        print(f"[WARN] Progress callback failed: {e}")

            def _desc_fn() -> str:
                return f"Batches {completed_batches['count']}/{total_batches}"

            remaining_items = list(pending_items)
            for attempt in range(max_attempts):
                retry_items: List[_RemoteWorkItem] = []

                def _on_remote_result(item: _RemoteWorkItem, result: Dict) -> None:
                    if "error" in result:
                        if attempt < max_attempts - 1:
                            retry_items.append(item)
                        else:
                            _mark_remote_complete(item)
                        return
                    req = item.req
                    effective_run = _infer_run_from_run_name(req.run_name, item.run)

                    result["run"] = effective_run
                    result["run_name"] = req.run_name
                    result["code_path"] = str(req.code_path)
                    aggregated.setdefault(req.op, []).append(result)
                    _mark_remote_complete(item)

                def _on_remote_drop(item: _RemoteWorkItem, result: Dict) -> None:
                    if attempt < max_attempts - 1:
                        retry_items.append(item)
                    else:
                        _mark_remote_complete(item)

                _run_remote(
                    remaining_items,
                    remote_url,
                    max_workers=build_workers,
                    desc="Remote",
                    on_result=_on_remote_result,
                    on_drop=_on_remote_drop,
                    desc_fn=_desc_fn,
                    periodic_every_s=periodic_save_s,
                    periodic_fn=_periodic_flush,
                )

                if not retry_items:
                    break
                if attempt < max_attempts - 1:
                    sample = ", ".join(wi.req.op for wi in retry_items[:5])
                    print(
                        f"[WARN] {len(retry_items)} ops returned 'error' (attempt {attempt + 1}/{max_attempts}); retrying. "
                        f"Examples: {sample}"
                    )
                    remaining_items = retry_items
                else:
                    sample = ", ".join(wi.req.op for wi in retry_items[:5])
                    print(
                        f"[WARN] {len(retry_items)} ops still returned 'error' after {max_attempts} attempts; leaving for future retry. "
                        f"Examples: {sample}"
                    )
    else:
        # Calculate total sequential batches
        total_batches = len(run_name_groups) * runs
        batch_idx = 0

        for current_run_name, group_requests in run_name_groups.items():
            for run in range(runs):
                batch_idx += 1
                # Filtering: Determine which requests need to be run
                needed_requests = []
                for req in group_requests:
                    already_done = False
                    expected_run = _infer_run_from_run_name(req.run_name, run)
                    if req.op in aggregated:
                        for res in aggregated[req.op]:
                            # Check match. 
                            # Legacy results might lack 'run_name', assuming default if missing?
                            # But we are in a strict resume mode now.
                            r_match = res.get("run") == expected_run
                            rn_match = res.get("run_name") == current_run_name
                            if r_match and rn_match:
                                already_done = True
                                break
                    if not already_done:
                        needed_requests.append(req)

                if not needed_requests:
                    # All done for this group/run
                    continue

                if language in ASCEND_LANGUAGES:
                    prepare_ascend_workdir(current_run_name, language, run)

                remaining_requests = list(needed_requests)
                for attempt in range(max_attempts):
                    device_queue = None
                    manager = None

                    if language in ASCEND_LANGUAGES:
                        # Optimization for single-op evaluation (MCP/interactive)
                        # Avoids Manager overhead and potential deadlocks in nested process environments
                        if len(remaining_requests) == 1:
                            dev_id = (device_ids[0] if device_ids else 0) + device_offset
                            results = _run_single_isolated(remaining_requests[0], dev_id)
                        else:
                            manager = multiprocessing.Manager()
                            try:
                                device_queue = manager.Queue()
                                ids = device_ids if device_ids is not None else list(range(build_workers))

                                num_devices = len(ids)
                                if enable_npu_parallelism:
                                    concurrency = max(1, build_workers // num_devices)
                                else:
                                    concurrency = 1

                                for device_id in ids:
                                    for _ in range(concurrency):
                                        device_queue.put(device_id + device_offset)
                                results = _run_parallel(remaining_requests, device_queue, max_workers=build_workers)
                            finally:
                                if manager:
                                    manager.shutdown()
                    else:
                        results = _run_sequential(remaining_requests)

                    req_map = {req.op: req for req in remaining_requests}
                    retry_requests: List[EvalRequest] = []

                    for op, result in results.items():
                        if "error" in result:
                            if attempt < max_attempts - 1:
                                req = req_map.get(op)
                                if req is not None:
                                    retry_requests.append(req)
                            continue

                        req = req_map.get(op)
                        effective_run = _infer_run_from_run_name(req.run_name, run) if req else run
                        result["run"] = effective_run

                        if req:
                            result["run_name"] = req.run_name
                            result["code_path"] = str(req.code_path)

                        aggregated.setdefault(op, []).append(result)

                    if not retry_requests:
                        break
                    if attempt < max_attempts - 1:
                        sample = ", ".join(r.op for r in retry_requests[:5])
                        print(
                            f"[WARN] {len(retry_requests)} ops returned 'error' (attempt {attempt + 1}/{max_attempts}); retrying. "
                            f"Examples: {sample}"
                        )
                        remaining_requests = retry_requests
                    else:
                        sample = ", ".join(r.op for r in retry_requests[:5])
                        print(
                            f"[WARN] {len(retry_requests)} ops still returned 'error' after {max_attempts} attempts; leaving for future retry. "
                            f"Examples: {sample}"
                        )

                # Incremental Save
                if result_path:
                    final_state: Dict = aggregated if is_multi_sample else {op: vals[0] for op, vals in aggregated.items() if vals}
                    _write_results(Path(result_path), final_state)
                    if progress_callback:
                        try:
                            progress_callback()
                        except Exception as e:
                            print(f"[WARN] Progress callback failed: {e}")

    # Simplified return: if runs=1 (cli default) but we have multiple run_names (Pass@k), we have a list of results per op.
    # The original logic flattened it if runs=1 for repeated trials.
    # But for Pass@k (implicit runs), we WANT the list.
    # We should detect if this is a "multi-sample" scenario.
    is_multi_sample = len(run_name_groups) > 1 or runs > 1
    
    final_results: Dict = aggregated if is_multi_sample else {op: vals[0] for op, vals in aggregated.items()}
    if result_path:
        _write_results(Path(result_path), final_results)
    return final_results


def evaluate_directory(
    *,
    input_dir: Path,
    language: str,
    ops: Optional[Sequence[str]] = None,
    run_name: str = "default",
    runs: int = 1,
    build_workers: int = 4,
    device_ids: Optional[Sequence[int]] = None,
    device_offset: int = 0,
    result_path: Optional[Path] = None,
    enable_npu_parallelism: bool = False,
    skip_dataset_check: bool = False,
    progress_callback: Optional[Callable[[], None]] = None,
    remote_url: Optional[str] = None,
) -> Dict:
    """Convenience helper: build requests from a directory and evaluate them.

    Equivalent to ``build_requests_from_dir(...)`` followed by
    ``evaluate_requests(...)``. ``skip_dataset_check`` is forwarded to the
    directory scan. ``progress_callback`` and ``remote_url`` are forwarded to
    the evaluation. All three default to the callees' own defaults, so
    existing callers are unaffected.

    Returns:
        The result mapping produced by ``evaluate_requests``.
    """
    requests = build_requests_from_dir(
        input_dir=input_dir,
        language=language,
        ops=ops,
        run_name=run_name,
        skip_dataset_check=skip_dataset_check,
    )
    return evaluate_requests(
        requests,
        runs=runs,
        build_workers=build_workers,
        device_ids=device_ids,
        device_offset=device_offset,
        result_path=result_path,
        enable_npu_parallelism=enable_npu_parallelism,
        progress_callback=progress_callback,
        remote_url=remote_url,
    )


def evaluate_code(
    *,
    op: str,
    language: str,
    code: str,
    run_name: str = "inline",
    runs: int = 1,
    device_id: Optional[int] = None,
    result_path: Optional[Path] = None,
) -> Dict:
    """Evaluate an in-memory code snippet for ``op`` (no source file required).

    A warning is printed when ``op`` is not in the dataset. The snippet is
    wrapped in a single EvalRequest and evaluated with one build worker. When
    ``device_id`` is given, that device is the only candidate.

    Returns:
        The result mapping produced by ``evaluate_requests``.
    """
    has_reference = op in dataset
    if not has_reference:
        print(f"[WARN] {op}: no reference implementation found; will build only and skip evaluation")

    inline_request = EvalRequest(
        op=op,
        language=language,
        code=code,
        run_name=run_name,
        metadata={"has_reference": has_reference},
    )
    return evaluate_requests(
        [inline_request],
        runs=runs,
        build_workers=1,
        device_ids=None if device_id is None else [device_id],
        device_offset=0,
        result_path=result_path,
    )
