"""Cache XLA cost_analysis FLOPs for every JAXBench workload.

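For each benchmark under $JAXBENCH/benchmark/*/baseline.py:
  1. Import the module (which provides its CONFIG, create_inputs and workload).
  2. If the module defines get_flops(), use that value and do not compile
     anything (matches run_all.py's precedence).
  3. Otherwise call create_inputs (bf16), then lower + compile workload via
     jax.jit and pull 'flops' from cost_analysis. Modules that set `_skip_jit`
     are not compiled and report 0.
  4. Also read any `flops = ...` value defined in the module (priority kernels)
     for cross-checking; it is stored separately as "manual_flops".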

Writes JSON: { bench_name: {"xla_flops": int, "manual_flops": float|None,
                             "has_get_flops": bool, "error": str|None} }.
"""
from __future__ import annotations
import importlib.util
import json
import pathlib
import sys
import traceback

import jax
import jax.numpy as jnp


# Directory whose immediate subdirectories are scanned for a baseline.py.
BENCHMARK_ROOT = pathlib.Path("/path/to/JAXBench/benchmark")
# Destination of the JSON results written by main().
OUT_PATH = pathlib.Path("/path/to/JAXBench/benchmark/flops_cache.json")


def _load_module(bench_dir: pathlib.Path):
    """Import ``bench_dir / "baseline.py"`` as a standalone module.

    The module is registered in ``sys.modules`` under the name
    ``jaxbench_<directory name>`` and returned after it has been executed.

    Args:
        bench_dir: Directory containing the benchmark's ``baseline.py``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for the file.
        Exception: Anything raised while executing the benchmark module is
            propagated unchanged.
    """
    source = bench_dir / "baseline.py"
    spec = importlib.util.spec_from_file_location(
        f"jaxbench_{bench_dir.name}", source
    )
    # spec_from_file_location() returns None when it cannot pick a loader;
    # fail with a clear message instead of an AttributeError further down.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {source}")

    mod = importlib.util.module_from_spec(spec)
    # Register before executing, as in the importlib "import a source file
    # directly" recipe, so the module is visible in sys.modules while it runs.
    sys.modules[spec.name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module behind for later lookups.
        sys.modules.pop(spec.name, None)
        raise
    return mod


def _manual_flops(mod) -> float | None:
    """Return the module-level ``flops`` value as a float, if usable.

    Priority kernels may define ``flops = ...`` at module scope. The value is
    returned as a ``float``; ``None`` is returned when the attribute is
    missing, is ``None``, or cannot be converted by ``float()``.
    """
    raw = getattr(mod, "flops", None)
    try:
        return None if raw is None else float(raw)
    except Exception:
        # Best effort: an unconvertible value is treated as "no manual count".
        return None


def _xla_flops(mod) -> int:
    """Return the FLOP count for a benchmark module.

    Follows run_all.py's precedence: prefer the module's ``get_flops()``,
    else fall back to XLA's ``cost_analysis()``.

    Resolution order:
      1. ``mod.get_flops()`` if the module defines it (nothing is compiled).
      2. ``0`` if the module sets a truthy ``_skip_jit``; inputs are not
         built because the workload is never compiled.
      3. The ``"flops"`` entry of ``cost_analysis()`` for the workload
         compiled with bf16 inputs, or ``0`` when no cost data is available.

    Args:
        mod: Benchmark module exposing ``create_inputs`` and ``workload``,
            and optionally ``get_flops`` / ``_skip_jit``.

    Returns:
        The FLOP count truncated to ``int``.
    """
    get_flops = getattr(mod, "get_flops", None)
    if get_flops is not None:
        return int(get_flops())

    # Check the opt-out before create_inputs() so nothing is allocated for
    # a workload that will not be compiled anyway.
    if getattr(mod, "_skip_jit", False):
        return 0

    inputs = mod.create_inputs(dtype=jnp.bfloat16)
    compiled = jax.jit(mod.workload).lower(*inputs).compile()
    cost = compiled.cost_analysis()
    # Depending on the JAX version, cost_analysis() yields a dict, a
    # sequence of per-computation dicts, or None when unavailable.
    if isinstance(cost, (list, tuple)):
        cost = cost[0] if cost else None
    if not cost:
        return 0
    return int(cost.get("flops", 0))


def main():
    """Compute FLOP counts for every benchmark and write them to OUT_PATH.

    Every subdirectory of BENCHMARK_ROOT that contains a ``baseline.py`` is
    processed in sorted name order. A failure in one benchmark is recorded
    in that benchmark's ``"error"`` field (and its traceback printed) without
    stopping the run. One summary line is printed per benchmark, and the
    collected results are written as indented JSON at the end.
    """
    results: dict[str, dict] = {}

    candidates = [d.name for d in BENCHMARK_ROOT.iterdir() if d.is_dir()]
    candidates.sort()

    for name in candidates:
        bench_dir = BENCHMARK_ROOT / name
        baseline = bench_dir / "baseline.py"
        if not baseline.exists():
            continue

        # Defaults stay in place for whichever fields are not reached
        # before an exception is raised.
        entry: dict = dict(
            xla_flops=0,
            manual_flops=None,
            has_get_flops=False,
            error=None,
        )
        try:
            mod = _load_module(bench_dir)
            entry["manual_flops"] = _manual_flops(mod)
            entry["has_get_flops"] = getattr(mod, "get_flops", None) is not None
            entry["xla_flops"] = _xla_flops(mod)
        except Exception as exc:
            entry["error"] = f"{type(exc).__name__}: {exc}"
            traceback.print_exc()
        results[name] = entry

        x, m = entry["xla_flops"], entry["manual_flops"]
        err = entry["error"] or ""
        print(f"{name:55s} xla={x:>14d} manual={m!s:>14s} {err}")

    payload = json.dumps(results, indent=2)
    OUT_PATH.write_text(payload)
    print(f"\nWrote {OUT_PATH}")


# Run the cache generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
