"""Measure XLA baseline vs hand-tuned Pallas (optimized.py) for priority kernels.

Runs both variants through the full-50 JAXBench harness on the local TPU and
writes one JSON file per kernel plus a summary.json, each reporting the
apples-to-apples speedup = XLA latency / Pallas latency.

Usage:
    python -m autocomp.baselines.measure_handtuned \
        --output_dir /path/to/autocomp/output/handtuned \
        [--kernels 1p_Flash_Attention 2p_GQA_Attention ...]
"""
from __future__ import annotations

import argparse
import json
import pathlib
import time

from autocomp.backend.jaxbench.jaxbench_eval import JaxBenchEvalBackend
from autocomp.baselines.common import evaluate_many, load_baseline_code
from autocomp.common import logger


# Kernels measured when --kernels is not given: the 8 priority kernels that
# ship with a hand-tuned optimized.py.  Written one name per line and split on
# whitespace, which yields a plain list of the names in this order.
DEFAULT_KERNELS = """
    1p_Flash_Attention
    2p_GQA_Attention
    3p_MLA_Attention
    4p_Sparse_Attention
    6p_Paged_Attention
    7p_Ragged_Paged_Attention
    8p_GEMM
    11p_Megablox_GMM
""".split()


def measure(prob_id: str, backend: JaxBenchEvalBackend) -> dict:
    """Evaluate the XLA baseline and the hand-tuned Pallas variant of one kernel.

    Args:
        prob_id: Kernel identifier handed to ``load_baseline_code``.
        backend: Evaluation backend forwarded to ``evaluate_many``.

    Returns:
        A dict with, for each variant, the value the evaluator reported under
        ``"latency"``, its ``"correct"`` flag (``False`` when absent) and a
        failure label (``None`` when correct, otherwise the reported
        ``"failure_type"`` or ``"unknown"``), plus ``pallas_speedup_vs_xla``:
        XLA latency divided by Pallas latency, or ``None`` when either
        latency is missing or zero.
    """

    def _run_variant(prob_type):
        # Load this variant's source and evaluate it as the only candidate.
        problem, code = load_baseline_code(prob_id, prob_type=prob_type)
        return evaluate_many(problem, [code], backend=backend)[0]

    def _failure_label(outcome, is_correct):
        # Correct runs carry no failure; otherwise fall back to "unknown".
        if is_correct:
            return None
        return outcome.get("failure_type") or "unknown"

    # XLA baseline first, then the Pallas hand-tuned variant.
    xla_outcome = _run_variant("jaxbench-baseline")
    xla_latency = xla_outcome.get("latency")
    xla_ok = xla_outcome.get("correct", False)

    pallas_outcome = _run_variant("jaxbench-pallas")
    pallas_latency = pallas_outcome.get("latency")
    pallas_ok = pallas_outcome.get("correct", False)

    # A ratio is only defined when both latencies are present and non-zero.
    ratio = None
    if xla_latency and pallas_latency:
        ratio = xla_latency / pallas_latency

    return {
        "prob_id": prob_id,
        "xla_latency_ms": xla_latency,
        "xla_correct": xla_ok,
        "pallas_latency_ms": pallas_latency,
        "pallas_correct": pallas_ok,
        "pallas_speedup_vs_xla": ratio,
        "pallas_failure": _failure_label(pallas_outcome, pallas_ok),
        "xla_failure": _failure_label(xla_outcome, xla_ok),
    }


def main():
    """Run the measurement CLI.

    Parses ``--output_dir`` (required) and ``--kernels`` (defaults to
    ``DEFAULT_KERNELS``), measures every requested kernel with one shared
    ``JaxBenchEvalBackend``, and writes ``<kernel>.json`` per kernel plus a
    combined ``summary.json`` into the output directory (created if needed).

    A kernel whose measurement raises is recorded as a dict with
    ``prob_id``, ``error`` and ``elapsed_s``; the run then continues with the
    next kernel.
    """

    def _fmt(value, spec, unit):
        # Missing or zero measurements are reported as N/A instead of a
        # made-up number such as -1.000ms.
        return f"{value:{spec}}{unit}" if value else "N/A"

    p = argparse.ArgumentParser()
    p.add_argument("--output_dir", required=True)
    p.add_argument("--kernels", nargs="+", default=DEFAULT_KERNELS)
    args = p.parse_args()

    outdir = pathlib.Path(args.output_dir)
    outdir.mkdir(parents=True, exist_ok=True)

    backend = JaxBenchEvalBackend()
    results = []
    for k in args.kernels:
        logger.info("Measuring %s ...", k)
        t0 = time.perf_counter()
        try:
            r = measure(k, backend)
        except Exception as e:
            # Top-level boundary: keep going so one bad kernel does not lose
            # the measurements of the others.
            logger.exception("Failed on %s", k)
            r = {"prob_id": k, "error": str(e)}
        r["elapsed_s"] = round(time.perf_counter() - t0, 1)
        results.append(r)
        # Persist each kernel immediately so partial runs still leave results.
        (outdir / f"{k}.json").write_text(json.dumps(r, indent=2))
        if "error" not in r:
            # The failure path already logged a traceback and has no numbers
            # to show, so only successful measurements get a summary line.
            logger.info("  XLA=%s  Pallas=%s  speedup=%s",
                        _fmt(r.get("xla_latency_ms"), ".3f", "ms"),
                        _fmt(r.get("pallas_latency_ms"), ".3f", "ms"),
                        _fmt(r.get("pallas_speedup_vs_xla"), ".2f", "x"))

    (outdir / "summary.json").write_text(json.dumps({"results": results}, indent=2))
    logger.info("Wrote %s", outdir / "summary.json")


# Script entry point: run the CLI only when this module is executed directly
# (e.g. via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()
