"""Cursory audit of winning generated kernels per (method, benchmark) cell.

For every (method, benchmark) cell reported in trajectory_data.json we:
  1. Locate the best (fastest-correct) kernel across all samples for that cell.
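  2. Read its source (`code.py`, or `code_<i>.py` for Best-of-N/Autocomp
     eval dirs) and scan for patterns typical of kernels that exploit the
     "baseline outputs zeros" correctness-gate bug:
       - Constant-zero output: `return jnp.zeros_like(...)` or similar.
       - `pl.pallas_call` with a zero-sized out_shape (leading 0 dim) or an
         empty grid.
       - `return x` passthrough (no computation).
       - Output produced entirely from an uninitialized VMEM scratch
         (listed for manual review; the static scan does not detect this).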
  3. Flag cells where speedup x (the XLA baseline's achieved fraction of
     TPU v6e peak, 918 TFLOPS bf16) reaches >= 95% of peak, which is
     physically implausible, especially for workloads that are not
     matmul-heavy.

This is a CURSORY scan -- it does not re-execute anything on the TPU. It
gives us a short list of cells that are most likely to need re-generation
after the baseline fix.
"""
from __future__ import annotations

import json
import math
import os
import pathlib
import re
from collections import defaultdict


# Root of the checkout that contains autocomp/, JAXBench/ and
# jaxbench-overleaf/. Set the TPU_REPO environment variable to point at a
# different checkout instead of editing this file; an unset or empty value
# falls back to the default below.
REPO = pathlib.Path(os.environ.get("TPU_REPO") or "/path/to/tpu")
BASELINES_FLASH = REPO / "autocomp/output/baselines-flash"
AUTOCOMP_SWEEP = REPO / "autocomp/output/jaxbench-sweep-flash"
TRAJ = REPO / "jaxbench-overleaf/Figures/trajectory_data.json"
FLOPS_CACHE = REPO / "JAXBench/benchmark/flops_cache.json"

# TPU v6e peak bf16.
PEAK_TFLOPS = 918.0

# Method display name -> root directory of that method's results. The
# iteration order of this dict is also the row order of the printed report.
METHODS = {
    "Iterative": BASELINES_FLASH / "iterative",
    "Iterative+ctx": BASELINES_FLASH / "iterative_context",
    "Best-of-N": BASELINES_FLASH / "best_of_n",
    "Autocomp": AUTOCOMP_SWEEP,
}


# ---------------------------------------------------------------------------
# Winning-kernel lookup
# ---------------------------------------------------------------------------

def _read_json(p: pathlib.Path) -> dict | None:
    try:
        return json.loads(p.read_text())
    except Exception:
        return None


def _is_correct(res: dict) -> bool:
    return bool(res.get("correct")) and res.get("latency") is not None


def _best_iterative_style(bench_dir: pathlib.Path) -> tuple[pathlib.Path, float] | None:
    """Return ``(code_path, latency)`` of the fastest correct turn in *bench_dir*.

    Layout used by Iterative / Iterative+ctx under baselines-flash/<m>, and
    by Best-of-N as a fallback when it has no ``eval/`` dir:
    ``chain_*/turn_*/{code.py,result.json}``. Returns None when no turn has
    a correct result. ``code_path`` is not checked for existence.

    Result files are visited in sorted path order so latency ties resolve to
    the same kernel on every run (``glob`` order is filesystem-dependent);
    the first kernel reaching the minimum latency wins.
    """
    best: tuple[pathlib.Path, float] | None = None
    for result_path in sorted(bench_dir.glob("chain_*/turn_*/result.json")):
        res = _read_json(result_path)
        if not res or not _is_correct(res):
            continue
        lat = float(res["latency"])
        if best is None or lat < best[1]:
            # The kernel source sits next to its result.json.
            best = (result_path.parent / "code.py", lat)
    return best


def _best_best_of_n(bench_dir: pathlib.Path) -> tuple[pathlib.Path, float] | None:
    """Return ``(code_path, latency)`` of the fastest correct Best-of-N sample.

    Best-of-N runs have ``eval/code_<i>_result.txt`` (parsed as JSON) next to
    ``eval/code_<i>.py``; result files with ``_full`` in the name are
    skipped. When ``eval/`` is absent, falls back to the chain/turn layout
    handled by ``_best_iterative_style``. Returns None if no sample is
    correct.

    Result files are visited in sorted order so latency ties resolve to the
    same kernel on every run (``glob`` order is filesystem-dependent).
    """
    eval_dir = bench_dir / "eval"
    if not eval_dir.is_dir():
        return _best_iterative_style(bench_dir)  # fall back if structured differently
    best: tuple[pathlib.Path, float] | None = None
    for rp in sorted(eval_dir.glob("code_*_result.txt")):
        if "_full" in rp.name:
            continue
        res = _read_json(rp)
        if not res or not _is_correct(res):
            continue
        # Recover the sample index <i> to locate the matching code_<i>.py.
        m = re.search(r"code_(\d+)_", rp.name)
        if not m:
            continue
        lat = float(res["latency"])
        if best is None or lat < best[1]:
            best = (eval_dir / f"code_{m.group(1)}.py", lat)
    return best


def _best_autocomp(bench: str) -> tuple[pathlib.Path, float] | None:
    """Return ``(code_path, latency)`` of the fastest correct Autocomp kernel.

    Autocomp has two phases per bench, each in its own run dir under
    ``AUTOCOMP_SWEEP``:
        <bench>_baseline            (optimize phase)
        <bench>_baseline_translate  (translate phase)
    The winning kernel is the fastest correct result across every
    ``eval-results-iter-*/code_<i>_result.txt`` in both phases (files with
    ``_full`` in the name are skipped). Returns None if nothing qualifies.

    Directories and files are visited in sorted order so latency ties
    resolve to the same kernel on every run (``glob`` order is
    filesystem-dependent); the first kernel reaching the minimum wins.
    """
    best: tuple[pathlib.Path, float] | None = None
    for suffix in ("_baseline", "_baseline_translate"):
        run_dir = AUTOCOMP_SWEEP / f"{bench}{suffix}"
        if not run_dir.is_dir():
            continue
        for it_dir in sorted(run_dir.glob("eval-results-iter-*")):
            for rp in sorted(it_dir.glob("code_*_result.txt")):
                if "_full" in rp.name:
                    continue
                res = _read_json(rp)
                if not res or not _is_correct(res):
                    continue
                # Recover the sample index <i> to locate code_<i>.py.
                m = re.search(r"code_(\d+)_", rp.name)
                if not m:
                    continue
                lat = float(res["latency"])
                if best is None or lat < best[1]:
                    best = (it_dir / f"code_{m.group(1)}.py", lat)
    return best


def winning_kernel(method: str, bench: str) -> tuple[pathlib.Path | None, float | None]:
    """Locate *method*'s winning (fastest correct) kernel for *bench*.

    Autocomp is searched by benchmark name; every other method is searched
    under ``METHODS[method] / bench``, which must be a directory. Best-of-N
    uses its eval/ layout and the rest use the chain/turn layout.

    Returns ``(code_path, latency)``, or ``(None, None)`` when nothing
    correct was found. A non-Autocomp *method* missing from ``METHODS``
    raises KeyError.
    """
    if method == "Autocomp":
        found = _best_autocomp(bench)
    else:
        bench_dir = METHODS[method] / bench
        if not bench_dir.is_dir():
            found = None
        elif method == "Best-of-N":
            found = _best_best_of_n(bench_dir)
        else:
            found = _best_iterative_style(bench_dir)

    if found is None:
        return None, None
    code_path, latency = found
    return code_path, latency


# ---------------------------------------------------------------------------
# Code-pattern heuristics
# ---------------------------------------------------------------------------

# Very cheap syntactic matches. Each returns a short reason string if the
# pattern is triggered; else None.

_NOOP_PATTERNS: list[tuple[re.Pattern, str]] = [
    (re.compile(r"\breturn\s+jnp\.zeros_like\(", re.MULTILINE),
     "returns jnp.zeros_like(...)"),
    (re.compile(r"\breturn\s+jnp\.zeros\(", re.MULTILINE),
     "returns jnp.zeros(...)"),
    (re.compile(r"\breturn\s+np\.zeros\(", re.MULTILINE),
     "returns np.zeros(...)"),
    # Very short workload body: a return statement only.
    (re.compile(r"def\s+workload\([^)]*\)[^:]*:\s*\n\s*return\s+\w+\s*$",
                re.MULTILINE),
     "workload is a bare `return x` passthrough"),
    # Writes a literal all-zeros block into the output ref inside a kernel.
    (re.compile(r"out_ref\[\.\.\.\]\s*=\s*jnp\.zeros\(", re.MULTILINE),
     "kernel writes jnp.zeros(...) into output ref"),
    (re.compile(r"out_ref\[\.\.\.\]\s*=\s*jnp\.zeros_like\(", re.MULTILINE),
     "kernel writes jnp.zeros_like(...) into output ref"),
    (re.compile(r"o_ref\[\.\.\.\]\s*=\s*jnp\.zeros\(", re.MULTILINE),
     "kernel writes jnp.zeros(...) into output ref"),
]


def _scan_code_patterns(code_src: str) -> list[str]:
    hits: list[str] = []
    for pat, reason in _NOOP_PATTERNS:
        if pat.search(code_src):
            hits.append(reason)
    # pl.pallas_call with empty grid or zero-dim out_shape.
    if re.search(r"grid\s*=\s*\(\s*\)", code_src):
        hits.append("pallas_call has empty grid=()")
    if re.search(r"grid\s*=\s*\(\s*0\b", code_src):
        hits.append("pallas_call has leading-zero grid dim")
    # ShapeDtypeStruct((0,...)) -> zero-sized output.
    if re.search(r"ShapeDtypeStruct\(\s*\(\s*0\b", code_src):
        hits.append("out_shape has leading 0 dim")
    return hits


# ---------------------------------------------------------------------------
# Implausible-perf heuristic
# ---------------------------------------------------------------------------

def load_mxu_floor() -> dict[str, float]:
    """Return per-benchmark achieved MXU% of the *XLA baseline* so we can
    compute speedup * baseline_MXU upper-bounding."""
    data = _read_json(TRAJ) or {}
    baselines_ms = data.get("baselines", {})
    flops_cache = _read_json(FLOPS_CACHE) or {}
    mxu_floor: dict[str, float] = {}
    for bench, lat_ms in baselines_ms.items():
        fl = flops_cache.get(bench, {})
        f = fl.get("xla_flops", 0) or 0
        if lat_ms and f > 0:
            tflops = f / (lat_ms * 1e-3) / 1e12
            mxu_floor[bench] = tflops / PEAK_TFLOPS
    return mxu_floor


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

# Projected MXU utilisation (fraction of PEAK_TFLOPS) at or above which a
# cell is flagged as implausible.
IMPLAUSIBLE_MXU_FRACTION = 0.95


def _audit_cell(
    method: str,
    bench: str,
    speedup: float | None,
    base_mxu: float | None,
) -> dict:
    """Build the audit record for one (method, benchmark) cell.

    The record holds the reported speedup, the winning kernel's path and
    latency, ``projected_mxu_pct`` (only when it can be computed), and a
    list of human-readable flag strings.
    """
    code_path, latency = winning_kernel(method, bench)
    cell: dict = {
        "speedup": speedup,
        "code_path": str(code_path) if code_path else None,
        "latency_ms": latency,
    }
    flags: list[str] = []

    # Implausible MXU: speedup scaled by the XLA baseline's fraction of peak.
    if speedup and base_mxu is not None and speedup > 0:
        proj_mxu = speedup * base_mxu
        cell["projected_mxu_pct"] = proj_mxu * 100
        if proj_mxu >= IMPLAUSIBLE_MXU_FRACTION:
            flags.append(f"implausible MXU {proj_mxu*100:.0f}%")

    # Static code scan; an unreadable source file contributes no code flags.
    if code_path and code_path.is_file():
        try:
            src = code_path.read_text()
        except (OSError, UnicodeDecodeError):
            src = ""
        flags.extend(f"code:{h}" for h in _scan_code_patterns(src))

    cell["flags"] = flags
    return cell


def _print_report(
    benches: list[str],
    results: dict[tuple[str, str], dict],
) -> None:
    """Print the per-cell table grouped by benchmark, then the flagged cells."""
    print(f"{'benchmark':<50s} {'method':<16s} {'speedup':>10s} "
          f"{'proj_MXU':>9s}  flags")
    print("-" * 130)
    flagged_cells: list[tuple[str, str, dict]] = []
    for bench in benches:
        for method in METHODS:
            c = results[(method, bench)]
            sp = c.get("speedup")
            mxu = c.get("projected_mxu_pct")
            # Only a missing speedup renders as "-"; 0.0 is printed as-is.
            sp_s = f"{sp:.2f}x" if sp is not None else "   -"
            mxu_s = f"{mxu:>6.1f}%" if mxu is not None else "     -"
            flag_s = ", ".join(c.get("flags", []))
            if c.get("flags"):
                flagged_cells.append((method, bench, c))
            print(f"{bench:<50s} {method:<16s} {sp_s:>10s} {mxu_s:>9s}  {flag_s}")

    print("\n" + "=" * 80)
    print(f"FLAGGED CELLS: {len(flagged_cells)}")
    print("=" * 80)
    for method, bench, c in flagged_cells:
        print(f"  {method:<16s} {bench:<50s}")
        for f in c["flags"]:
            print(f"      - {f}")
        if c.get("code_path"):
            print(f"      code: {c['code_path']}")


def main() -> None:
    """Audit every KernelBench (method, benchmark) cell and print a report."""
    traj = _read_json(TRAJ) or {}
    per_bench = traj.get("per_bench_speedups", {})
    baselines_ms = traj.get("baselines", {})
    mxu_floor = load_mxu_floor()

    # Only audit the 33 KernelBench L2 benchmarks (others have hand-authored
    # non-degenerate baselines).
    benches = [b for b in sorted(baselines_ms) if re.match(r"\d+k_", b)]

    results: dict[tuple[str, str], dict] = {}
    for bench in benches:
        for method in METHODS:
            results[(method, bench)] = _audit_cell(
                method,
                bench,
                per_bench.get(method, {}).get(bench),
                mxu_floor.get(bench),
            )

    _print_report(benches, results)


if __name__ == "__main__":
    main()
