"""Shared utilities for baseline comparisons.

Thin wrappers over existing Autocomp infra so baseline scripts stay small.
"""
from __future__ import annotations

import json
import pathlib
import time
from typing import Iterable

from autocomp.agents.llm_agent import extract
from autocomp.backend.jaxbench.jaxbench_eval import JaxBenchEvalBackend
from autocomp.common import LLMClient, logger
from autocomp.search.prob import Prob
from autocomp.search.search import load_initial_code


def load_baseline_code(prob_id: str, prob_type: str = "jaxbench-baseline") -> tuple[Prob, str]:
    """Build the problem descriptor and fetch its starting source code.

    The descriptor is constructed as ``Prob(prob_type, prob_id)`` and the
    code is looked up through ``load_initial_code`` for the "jaxbench"
    backend.

    Supported ``prob_type`` values:
      * "jaxbench-baseline" -- loads baseline.py (JAX/XLA).
      * "jaxbench-pallas"   -- loads optimized.py (hand-written Pallas).

    Returns:
        A ``(Prob, source code)`` pair.
    """
    problem = Prob(prob_type, prob_id)
    return problem, load_initial_code("jaxbench", problem)


def build_prompt(prob_id: str, baseline_code: str, prob_type: str = "jaxbench-baseline") -> str:
    """Assemble the minimal one-shot generation prompt.

    The prompt is made of five consecutive pieces: a role line, the
    benchmark/hardware context, the task description, the output
    constraints, and the starting code in a fenced block. Only the role
    and task pieces depend on ``prob_type``: "jaxbench-pallas" asks for an
    existing Pallas kernel to be optimized; any other value asks for the
    XLA/JAX baseline to be rewritten as a Pallas kernel.

    Args:
        prob_id: Benchmark identifier, shown on the "Benchmark:" line.
        baseline_code: Source code embedded verbatim in the fenced block.
        prob_type: Selects which prompt variant is produced.

    Returns:
        The full prompt text.
    """
    if prob_type == "jaxbench-pallas":
        role = (
            "You are optimizing a Pallas TPU kernel for TPU v6e using the Pallas "
            "programming model (`jax.experimental.pallas`).\n\n"
        )
        task = (
            "Below is an existing Pallas implementation. Optimize it to run faster "
            "while producing identical outputs.\n\n"
        )
    else:
        role = (
            "You are optimizing a JAX kernel for TPU v6e using the Pallas "
            "programming model.\n\n"
        )
        task = (
            "Below is the XLA/JAX baseline. Rewrite it as a Pallas kernel "
            "(`jax.experimental.pallas`) that produces identical outputs and "
            "runs faster than the baseline.\n\n"
        )

    # Pieces shared by both prompt variants.
    context = (
        f"Benchmark: {prob_id}\n"
        "Target hardware: TPU v6e (Trillium). VMEM = 128 MiB.\n\n"
    )
    constraints = (
        "Constraints:\n"
        "- Keep the public `workload(*inputs)` signature unchanged.\n"
        "- Produce a complete, self-contained Python file.\n"
        "- Output only Python code inside a single ```python ... ``` block.\n\n"
    )
    fenced_code = f"```python\n{baseline_code}\n```\n"

    return "".join((role, context, task, constraints, fenced_code))


def evaluate_many(
    prob: Prob,
    code_strs: list[str],
    backend: JaxBenchEvalBackend | None = None,
) -> list[dict]:
    """Run a batch of candidate sources through the evaluation backend.

    Args:
        prob: Problem the candidates belong to.
        code_strs: Candidate source codes, evaluated together in one call.
        backend: Backend to use; when omitted (or falsy) a fresh
            ``JaxBenchEvalBackend`` is created for this call.

    Returns:
        Whatever ``backend.evaluate_code`` returns -- one dict per candidate
        of the form {correct: bool, latency: float|None, stdout: str,
        stderr: str}.
    """
    if not backend:
        backend = JaxBenchEvalBackend()
    results = backend.evaluate_code(prob, code_strs, simulator=None)
    return results


def save_candidates(
    output_dir: pathlib.Path,
    responses: list[str],
    eval_results: list[dict],
) -> None:
    """Persist generated code and eval results in per-candidate files.

    For the i-th (response, result) pair the following files are written:

      * ``<output_dir>/candidates/candidate_{i}_full.txt`` -- the raw response
      * ``<output_dir>/candidates/candidate_{i}.py`` -- ``extract(response)``
      * ``<output_dir>/eval/code_{i}_result.txt`` -- the result dict as JSON

    Both sub-directories are created if missing. ``responses`` and
    ``eval_results`` are paired positionally with ``zip``, so if their
    lengths differ the surplus entries of the longer list are not written.

    Args:
        output_dir: Root directory of the run's artifacts.
        responses: Raw model responses, one per candidate.
        eval_results: Evaluation result dicts; must be JSON-serializable.
    """
    cand_dir = output_dir / "candidates"
    eval_dir = output_dir / "eval"
    cand_dir.mkdir(parents=True, exist_ok=True)
    eval_dir.mkdir(parents=True, exist_ok=True)
    for i, (resp, res) in enumerate(zip(responses, eval_results)):
        # Model output is arbitrary text: write it as UTF-8 explicitly so the
        # files do not depend on (or fail under) the platform's locale encoding.
        (cand_dir / f"candidate_{i}_full.txt").write_text(resp, encoding="utf-8")
        (cand_dir / f"candidate_{i}.py").write_text(extract(resp), encoding="utf-8")
        (eval_dir / f"code_{i}_result.txt").write_text(
            json.dumps(res, indent=2), encoding="utf-8"
        )


def write_summary(
    output_dir: pathlib.Path,
    prob_id: str,
    n: int,
    baseline_latency: float | None,
    eval_results: list[dict],
    runtime_s: float,
    extra: dict | None = None,
) -> dict:
    """Write ``summary.json`` aggregating the run and return its contents.

    A result counts towards ``n_correct`` only when it is marked correct
    *and* carries a latency. Every result not marked correct is tallied in
    ``failure_counts`` under its ``failure_type`` ("unknown" when absent or
    empty). A result that is correct but has no latency is therefore counted
    in neither.

    Args:
        output_dir: Directory receiving ``summary.json``; created if missing.
        prob_id: Benchmark identifier recorded in the summary.
        n: Number of candidates requested for the run (recorded as given).
        baseline_latency: Reference latency, or None when unavailable.
        eval_results: Per-candidate evaluation dicts.
        runtime_s: Wall-clock duration of the run; stored rounded to 0.1 s.
        extra: Optional additional entries merged into the summary; keys
            that collide with the built-in ones override them.

    Returns:
        The summary dict that was serialized to disk.
    """
    correct = [r for r in eval_results if r.get("correct") and r.get("latency") is not None]
    best = min((r["latency"] for r in correct), default=None)

    # Aggregate failure types across all eval_results.
    failure_counts: dict[str, int] = {}
    for r in eval_results:
        if r.get("correct"):
            continue
        ft = r.get("failure_type") or "unknown"
        failure_counts[ft] = failure_counts.get(ft, 0) + 1

    summary = {
        "prob_id": prob_id,
        "n": n,
        "n_correct": len(correct),
        "baseline_latency_ms": baseline_latency,
        "best_latency_ms": best,
        # Truthiness check also guards against dividing by a zero latency.
        "speedup_vs_xla": (baseline_latency / best) if (best and baseline_latency) else None,
        "runtime_s": round(runtime_s, 1),
        "failure_counts": failure_counts,
    }
    if extra:
        summary.update(extra)

    # The summary may be written before (or without) any per-candidate files,
    # so make sure the target directory exists instead of failing on write.
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    return summary
