"""Evaluate a single solution.py against a JAXBench workload on TPU.

Called by eval.sh in the Gemini CLI workspace. Prints human-readable output
(Correct: true/false, Latency: X ms) and appends a JSON result line to
a trajectory file.

Usage:
    python eval_single.py <prob_id> <solution_py> <trajectory_jsonl> <counter_file> <budget>
"""
from __future__ import annotations

import json
import sys
import time

from autocomp.baselines.common import evaluate_many, load_baseline_code
from autocomp.backend.jaxbench.jaxbench_eval import JaxBenchEvalBackend


def main():
    if len(sys.argv) != 6:
        print(f"Usage: {sys.argv[0]} <prob_id> <solution.py> <trajectory.jsonl> <counter_file> <budget>",
              file=sys.stderr)
        sys.exit(1)

    prob_id, solution_path, trajectory_path, counter_path, budget_str = sys.argv[1:6]
    budget = int(budget_str)

    # Read and bump counter
    try:
        count = int(open(counter_path).read().strip())
    except (FileNotFoundError, ValueError):
        count = 0

    if count >= budget:
        print(f"BUDGET EXHAUSTED ({count}/{budget} evaluations used)")
        sys.exit(1)

    count += 1
    with open(counter_path, "w") as f:
        f.write(str(count))

    code = open(solution_path).read()
    prob, baseline_code = load_baseline_code(prob_id)

    backend = JaxBenchEvalBackend()
    t0 = time.perf_counter()
    results = evaluate_many(prob, [code], backend=backend)
    elapsed = time.perf_counter() - t0
    result = results[0]

    correct = result.get("correct", False)
    latency = result.get("latency")
    failure_type = result.get("failure_type")
    stdout_text = result.get("stdout", "")

    print(f"Correct: {'true' if correct else 'false'}")
    if correct and latency is not None:
        print(f"Latency: {latency:.3f} ms")
    else:
        error_lines = stdout_text.strip().splitlines()[-5:] if stdout_text else []
        print(f"Error: {failure_type or 'unknown'}")
        for line in error_lines:
            print(f"  {line}")

    profile = result.get("profile", "")
    if profile:
        print(f"\nProfiler summary:\n{profile}")

    print(f"\nEval {count}/{budget}")

    entry = {
        "eval_idx": count,
        "correct": correct,
        "latency": latency,
        "failure_type": failure_type,
        "elapsed_s": round(elapsed, 1),
    }
    with open(trajectory_path, "a") as f:
        f.write(json.dumps(entry) + "\n")


if __name__ == "__main__":
    main()
