"""For each JAXBench workload, load it, run create_inputs+workload, and
report output statistics. Used to verify the baseline fix produced
numerically non-trivial reference outputs.
"""
import importlib.util
import pathlib
import sys
import traceback

import jax
import jax.numpy as jnp
import numpy as np


# Root directory holding one sub-directory per benchmark. pathlib does not
# expand environment variables, so a literal "$HOME/..." string would name a
# directory called "$HOME" under the current working directory; resolve the
# user's home directory explicitly instead.
BENCH_ROOT = pathlib.Path.home() / "benchmark"


def _load_baseline(name: str):
    """Import ``BENCH_ROOT/<name>/baseline.py`` as the module ``bench_<name>``.

    Raises:
        ImportError: if no import spec (or loader) can be built for the file.
        Exception: whatever executing the baseline module raises; the
            partially initialised module is removed from ``sys.modules``
            before the exception propagates.
    """
    path = BENCH_ROOT / name / "baseline.py"
    spec = importlib.util.spec_from_file_location(f"bench_{name}", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be resolved by name while
    # its own top-level code runs.
    sys.modules[spec.name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module registered after a failure.
        sys.modules.pop(spec.name, None)
        raise
    return mod


def run_one(name: str) -> dict:
    """Run one benchmark's baseline workload and summarise its output.

    The baseline module is loaded from ``BENCH_ROOT/<name>/baseline.py``, its
    ``create_inputs(dtype=jnp.bfloat16)`` result is passed positionally to
    ``workload`` (wrapped in ``jax.jit`` unless the module sets a truthy
    ``_skip_jit`` attribute), and the output is converted to a float32 NumPy
    array for the statistics.

    Args:
        name: Benchmark directory name under ``BENCH_ROOT``.

    Returns:
        On success, a dict with ``ok=True`` and the keys ``shape``, ``mean``,
        ``std``, ``min``, ``max`` and ``nonzero_frac`` (fraction of elements
        that compare unequal to zero). On any failure -- loading the module,
        building inputs, running the workload, or an output that lacks
        ``astype`` -- a dict with ``ok=False``, ``error``
        (``"<ExceptionType>: <message>"``) and ``trace`` (formatted
        traceback). Exceptions derived from ``Exception`` are never raised to
        the caller.
    """
    try:
        # Loading happens inside the try so that a broken baseline is
        # reported for this benchmark instead of aborting the whole sweep.
        mod = _load_baseline(name)
        inputs = mod.create_inputs(dtype=jnp.bfloat16)
        if getattr(mod, "_skip_jit", False):
            out = mod.workload(*inputs)
        else:
            out = jax.jit(mod.workload)(*inputs)
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        arr = np.asarray(out.astype(jnp.float32))
        nonzero_frac = float((arr != 0).mean())
        return {
            "ok": True,
            "shape": tuple(arr.shape),
            # Accumulate in float64: NumPy otherwise reduces float32 input in
            # float32, which loses precision on large arrays and can overflow
            # when squaring large-magnitude values for the std.
            "mean": float(arr.mean(dtype=np.float64)),
            "std": float(arr.std(dtype=np.float64)),
            "min": float(arr.min()),
            "max": float(arr.max()),
            "nonzero_frac": nonzero_frac,
        }
    except Exception as e:
        return {"ok": False, "error": f"{type(e).__name__}: {e}",
                "trace": traceback.format_exc()}


def main():
    """Print a table of output statistics for every benchmark in BENCH_ROOT.

    A benchmark is any immediate sub-directory of ``BENCH_ROOT`` containing a
    ``baseline.py`` file. Benchmarks are processed in sorted name order; each
    one yields either a statistics row or an ``ERROR:`` line carrying the
    message reported by ``run_one``.
    """
    names = []
    for entry in BENCH_ROOT.iterdir():
        if entry.is_dir() and (entry / "baseline.py").is_file():
            names.append(entry.name)
    names.sort()

    header = f"{'benchmark':<56s} {'std':>12s} {'nonzero%':>9s} {'mean':>12s} {'min':>12s} {'max':>12s}"
    print(header)
    print("-" * 120)

    for bench in names:
        stats = run_one(bench)
        if stats["ok"]:
            nonzero_pct = stats["nonzero_frac"] * 100
            fields = [
                f"{bench:<56s}",
                f"{stats['std']:>12.4g}",
                f"{nonzero_pct:>8.1f}%",
                f"{stats['mean']:>12.4g}",
                f"{stats['min']:>12.4g}",
                f"{stats['max']:>12.4g}",
            ]
            print(" ".join(fields))
        else:
            print(f"{bench:<56s} ERROR: {stats['error']}")


# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
