"""
JAXBench runner — executed on the TPU VM by autocomp.

Usage:
    python jaxbench_runner.py <workload.py> <impl_0.py> [impl_1.py ...]

For each implementation file, imports its `workload()` function, checks correctness
against the reference workload from the workload file, benchmarks using host-side
timing, and prints delimited JSON results to stdout.
"""
import importlib.util
import json
import multiprocessing
import os
import sys
import time
import traceback

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl  # noqa: F401 — available for implementations
from jax.experimental.pallas import tpu as pltpu  # noqa: F401

# Marker lines printed around each implementation's output on stdout.
DELIM_START = "===JAXBENCH_IMPL_START==="
DELIM_END = "===JAXBENCH_IMPL_END==="

# Benchmark loop sizes: calls made before the correctness check, and timed calls.
NUM_WARMUP = int(os.environ.get("AUTOCOMP_TPU_NUM_WARMUP", "5"))
NUM_TRIALS = int(os.environ.get("AUTOCOMP_TPU_NUM_TRIALS", "20"))
# Default tolerances for the allclose comparison against the reference output.
ATOL = float(os.environ.get("AUTOCOMP_JAXBENCH_ATOL", "3.125e-2"))
RTOL = float(os.environ.get("AUTOCOMP_JAXBENCH_RTOL", "1e-2"))
# Seconds a single implementation's worker process may run before being killed.
IMPL_TIMEOUT = int(os.environ.get("AUTOCOMP_JAXBENCH_IMPL_TIMEOUT", "120"))
# Profiling is opt-in: enabled only when the variable is exactly "1".
PROFILE = os.environ.get("AUTOCOMP_JAXBENCH_PROFILE", "0") == "1"
PROFILE_STEPS = int(os.environ.get("AUTOCOMP_JAXBENCH_PROFILE_STEPS", "10"))


def _load_module(path: str, name: str):
    """Import the Python source file at `path` as a module named `name`.

    The module is registered in `sys.modules` under `name` before its code
    runs, and that registration is removed again if execution fails so a
    half-initialised module is never left importable.

    Args:
        path: Filesystem path of the file to import.
        name: Module name to register the result under.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for `path`
            (for example, the file suffix is not a recognised source suffix).
        BaseException: Whatever the module's top-level code raises is
            propagated unchanged.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        # Without this check the failure would be an opaque AttributeError.
        raise ImportError(f"cannot create an import spec for {path!r} (module {name!r})")

    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Roll back our registration only; leave it alone if something else
        # replaced the entry while the module was executing.
        if sys.modules.get(name) is mod:
            del sys.modules[name]
        raise
    return mod


def _finite_float(x, default=0.0):
    """Return `float(x)`, or `default` if `x` is not convertible or not finite."""
    import math

    try:
        val = float(x)
    except (TypeError, ValueError, OverflowError):
        return default
    # NaN/inf would break both the sort key and the int() conversion below.
    return val if math.isfinite(val) else default


def _load_tool_json(raw_to_tool_data, xplane_path: str, tool: str):
    """Run one xprof tool over an xplane file and parse its JSON output.

    Returns `(data, None)` on success, or `([], error_text)` if the
    conversion or the JSON parsing raised.
    """
    try:
        raw, _ = raw_to_tool_data.xspace_to_tool_data([xplane_path], tool, {})
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode()
        return json.loads(raw), None
    except Exception as e:
        return [], str(e)


def _roofline_table_lines(rf_rows):
    """Format the (at most ten) roofline rows with the largest self-time."""
    # Deduplicate by operation name; the first row seen for a name wins.
    seen = set()
    uniq = []
    for r in rf_rows:
        name = r.get("operation")
        if not name or name in seen:
            continue
        seen.add(name)
        uniq.append(r)

    top = sorted(uniq, key=lambda r: -_finite_float(r.get("total_self_time")))[:10]
    if not top:
        return []

    lines = [
        "Top HLO ops by self-time:",
        f"  {'op':<36s} {'calls':>5s} {'time_ms':>8s} {'time%':>6s} "
        f"{'GF/s':>8s} {'compute%':>9s} {'mem_bw%':>8s} {'bound_by':>10s}",
    ]
    for r in top:
        name = str(r.get("operation") or "?")[:36]
        calls = int(_finite_float(r.get("occurrences")))
        t_ms = _finite_float(r.get("total_self_time"))
        t_pct = _finite_float(r.get("total_self_time_percent"))
        rate = _finite_float(r.get("measured_flop_rate"))
        # These two are scaled by 100 for display as percentages.
        comp_eff = _finite_float(r.get("compute_efficiency")) * 100.0
        mem_eff = _finite_float(r.get("max_mem_bw_utilization")) * 100.0
        bn = str(r.get("bound_by") or "")
        lines.append(
            f"  {name:<36s} {calls:>5d} {t_ms:>8.3f} {t_pct:>5.1f}% "
            f"{rate:>8.0f} {comp_eff:>8.1f}% {mem_eff:>7.1f}% {bn:>10s}"
        )
    return lines


def _summarize_profile(profile_dir: str) -> str:
    """Summarise a captured profile as compact text suitable for LLM feedback.

    Looks for `*.xplane.pb` files under `profile_dir`, converts one of them
    with xprof's `overview_page` and `roofline_model` tools, and renders
    utilisation figures plus a table of the top operations by self-time.

    Args:
        profile_dir: Directory that the profiler trace was written to.

    Returns:
        The multi-line summary. When nothing can be summarised, a one-line
        bracketed diagnostic starting with "[profile:" is returned instead
        (xprof not importable, no xplane file found, or both tools produced
        nothing usable).
    """
    import glob

    try:
        from xprof.convert import raw_to_tool_data
    except Exception as e:
        return f"[profile: xprof import failed: {e}]"

    xps = glob.glob(os.path.join(profile_dir, "**", "*.xplane.pb"), recursive=True)
    if not xps:
        return f"[profile: no .xplane.pb under {profile_dir}]"
    # glob order is arbitrary; pick deterministically if several files exist.
    xp = min(xps)

    ov, ov_err = _load_tool_json(raw_to_tool_data, xp, "overview_page")
    rf, rf_err = _load_tool_json(raw_to_tool_data, xp, "roofline_model")

    # Utilisation metrics: first occurrence of each key across overview blocks.
    perf = {}
    for block in ov if isinstance(ov, list) else []:
        p = block.get("p") if isinstance(block, dict) else None
        if not isinstance(p, dict):
            continue
        for k in ("mxu_utilization_percent",
                  "flop_rate_utilization_relative_to_roofline",
                  "memory_bw_utilization_relative_to_hw_limit",
                  "hbm_utilization_percent",
                  "device_duty_cycle_percent"):
            if k in p and k not in perf:
                perf[k] = p[k]

    # Roofline tables: hardware peaks live in "p", data rows in "cols"/"rows".
    rf_rows = []
    peak_flop = None
    peak_hbm_bw = None
    ridge = None
    for block in rf if isinstance(rf, list) else []:
        if not isinstance(block, dict):
            continue
        p = block.get("p")
        if isinstance(p, dict) and "peak_flop_rate" in p:
            peak_flop = p.get("peak_flop_rate")
            peak_hbm_bw = p.get("peak_hbm_bw")
            ridge = p.get("hbm_ridge_point")
        col_ids = [c.get("id") if isinstance(c, dict) else None
                   for c in (block.get("cols") or [])]
        for row in block.get("rows") or []:
            cells = (row.get("c") or []) if isinstance(row, dict) else []
            # Individual cells may be null in the table JSON; keep them as None.
            vals = [cell.get("v") if isinstance(cell, dict) else None for cell in cells]
            rf_rows.append(dict(zip(col_ids, vals)))

    lines = []
    if peak_flop:
        # NOTE(review): the "TPU v6e" / "bf16" labels are hard-coded; only the
        # numbers come from the profile — confirm if other chips are targeted.
        try:
            lines.append(
                f"TPU v6e peaks: {float(peak_flop)/1000:.0f} TFLOPS (bf16), "
                f"{float(peak_hbm_bw):.0f} GB/s HBM, ridge {float(ridge):.0f} FLOP/B"
            )
        except (TypeError, ValueError):
            # A missing or non-numeric peak just drops this one line.
            pass
    if perf:
        lines.append(
            "Utilization: "
            + ", ".join(f"{k.replace('_percent','').replace('_',' ')}={v}" for k, v in perf.items())
        )
    if rf_rows:
        lines.extend(_roofline_table_lines(rf_rows))

    if not lines:
        return f"[profile: empty; ov_err={ov_err} rf_err={rf_err} ov_len={len(ov) if isinstance(ov,list) else '?'} rf_len={len(rf) if isinstance(rf,list) else '?'}]"
    return "\n".join(lines).strip()


def _jit_kwargs(cfg: dict) -> dict:
    """Pick the `jax.jit` static-argument options out of a workload CONFIG.

    Only `static_argnums` and `static_argnames` are copied; list values are
    converted to tuples. Anything that is not a dict yields an empty result.
    """
    if not isinstance(cfg, dict):
        return {}

    picked = {}
    for key in ("static_argnums", "static_argnames"):
        if key not in cfg:
            continue
        value = cfg[key]
        picked[key] = tuple(value) if isinstance(value, list) else value
    return picked


def _eval_impl(impl_path: str, inputs, ref_out, atol=ATOL, rtol=RTOL, jit_kwargs=None):
    """Evaluate a single implementation file and return a result dict.

    The file is imported, its `workload` is jitted and run on `inputs`, the
    output is compared with `ref_out`, and on success the call is timed on
    the host with `block_until_ready`.

    Args:
        impl_path: Path of the implementation file; it must define `workload`.
        inputs: Positional arguments passed to the workload.
        ref_out: Reference output to compare against. The comparison uses
            array arithmetic, so a single array output is assumed.
        atol: Absolute tolerance for the closeness check.
        rtol: Relative tolerance for the closeness check.
        jit_kwargs: Extra keyword arguments forwarded to `jax.jit`.

    Returns:
        A dict with at least `correct`, `latency` (median ms, or None),
        `error` and `failure_type`. Diff statistics are added once the
        outputs could be compared; timing details and, when profiling is
        enabled, a `profile` summary are added on success.

    `failure_type` is one of: None (success), "import_error", "missing_workload",
    "runtime_error" (covers compilation errors and runtime exceptions),
    "correctness_error".
    """
    result = {"correct": False, "latency": None, "error": "", "failure_type": None}

    try:
        impl_mod = _load_module(impl_path, "impl")
    except (Exception, SystemExit):
        # A top-level sys.exit() in the implementation is reported like any
        # other import failure instead of silently ending the worker.
        result["error"] = traceback.format_exc()
        result["failure_type"] = "import_error"
        return result

    impl_workload = getattr(impl_mod, "workload", None)
    if impl_workload is None:
        result["error"] = "implementation does not define a workload() function"
        result["failure_type"] = "missing_workload"
        return result

    try:
        impl_fn = jax.jit(impl_workload, **(jit_kwargs or {}))
        # Always run at least once: the output is needed for the correctness
        # check below, and the first call must not land in the timed loop.
        for _ in range(max(NUM_WARMUP, 1)):
            impl_out = impl_fn(*inputs)
            jax.block_until_ready(impl_out)

        # Subtraction and allclose broadcast silently, so a wrongly shaped
        # output could otherwise pass the numeric comparison.
        ref_shape = jnp.shape(ref_out)
        impl_shape = jnp.shape(impl_out)
        if ref_shape != impl_shape:
            result["error"] = (
                f"correctness check failed (output shape {impl_shape} "
                f"does not match reference shape {ref_shape})"
            )
            result["failure_type"] = "correctness_error"
            return result

        abs_diff = jnp.abs(ref_out - impl_out)
        max_diff = float(jnp.max(abs_diff))
        # The 1e-6 floor avoids dividing by zero for all-zero references.
        ref_max = jnp.maximum(jnp.max(jnp.abs(ref_out)), 1e-6)
        global_rel_diff = float(max_diff / ref_max)
        elem_rel_diff = float(jnp.max(abs_diff / jnp.maximum(jnp.abs(ref_out), 1e-6)))

        result["max_diff"] = round(max_diff, 6)
        result["global_rel_diff"] = round(global_rel_diff, 6)
        result["elem_rel_diff"] = round(elem_rel_diff, 6)

        if not jnp.allclose(ref_out, impl_out, atol=atol, rtol=rtol):
            result["error"] = (
                f"correctness check failed (max_diff={max_diff:.6f}, "
                f"global_rel_diff={global_rel_diff:.6f}, "
                f"elem_rel_diff={elem_rel_diff:.6f})"
            )
            result["failure_type"] = "correctness_error"
            return result

        # Host-side timing with block_until_ready; at least one trial so the
        # median is always a real number.
        times_ms = []
        for _ in range(max(NUM_TRIALS, 1)):
            t0 = time.perf_counter()
            out = impl_fn(*inputs)
            jax.block_until_ready(out)
            times_ms.append((time.perf_counter() - t0) * 1000.0)

        result["correct"] = True
        result["latency"] = round(float(np.median(np.array(times_ms))), 3)
        result["timing_method"] = "wall_clock"
        result["all_times_ms"] = [round(t, 3) for t in times_ms]

        if PROFILE:
            # Profiling is best-effort: a failure here must not invalidate the
            # measurement, but it is reported rather than silently dropped.
            try:
                import tempfile
                with tempfile.TemporaryDirectory(prefix="autocomp_prof_") as prof_dir:
                    with jax.profiler.trace(prof_dir):
                        for step in range(PROFILE_STEPS):
                            with jax.profiler.StepTraceAnnotation("step", step_num=step):
                                out = impl_fn(*inputs)
                                jax.block_until_ready(out)
                    result["profile"] = _summarize_profile(prof_dir)
            except Exception as e:
                result["profile"] = f"[profile: capture failed: {type(e).__name__}: {e}]"

    except SystemExit:
        result["error"] = "SystemExit raised"
        result["failure_type"] = "runtime_error"
    except Exception:
        result["error"] = traceback.format_exc()
        result["failure_type"] = "runtime_error"

    return result


def _eval_worker(impl_path, workload_path, result_queue):
    """Evaluate one implementation against the reference workload.

    Run in a subprocess so we can hard-kill on timeout. Exactly one result
    dict is put on `result_queue`, unless the process dies first.

    Args:
        impl_path: Path of the implementation file to evaluate.
        workload_path: Path of the reference workload file. It must provide
            either `create_inputs()` + `workload()`, or `Model` +
            `get_inputs()` (optionally `get_init_inputs()`).
        result_queue: Queue the result dict is delivered on.
    """
    try:
        ref_mod = _load_module(workload_path, "jaxbench_ref")
        cfg = getattr(ref_mod, "CONFIG", {})
        if not isinstance(cfg, dict):
            # _jit_kwargs already ignores a non-dict CONFIG; do the same for
            # the tolerance lookups below instead of failing on `.get`.
            cfg = {}
        jit_kwargs = _jit_kwargs(cfg)

        if hasattr(ref_mod, "create_inputs") and hasattr(ref_mod, "workload"):
            inputs = ref_mod.create_inputs()
            ref_fn = jax.jit(ref_mod.workload, **jit_kwargs)
        elif hasattr(ref_mod, "Model") and hasattr(ref_mod, "get_inputs"):
            init_args = ref_mod.get_init_inputs() if hasattr(ref_mod, "get_init_inputs") else []
            model = ref_mod.Model(*init_args)
            inputs = ref_mod.get_inputs()
            ref_fn = jax.jit(model.forward, **jit_kwargs)
        else:
            result_queue.put({"correct": False, "latency": None, "error": "bad workload"})
            return

        ref_out = ref_fn(*inputs)
        jax.block_until_ready(ref_out)

        # Per-workload tolerances override the environment-level defaults.
        atol = cfg.get("atol", ATOL)
        rtol = cfg.get("rtol", RTOL)
        result = _eval_impl(impl_path, inputs, ref_out, atol=atol, rtol=rtol, jit_kwargs=jit_kwargs)
    except Exception:
        # Anything that goes wrong while preparing the reference is reported
        # as a failed result rather than crashing the worker.
        result = {"correct": False, "latency": None, "error": traceback.format_exc()}
    result_queue.put(result)


def main():
    """Evaluate every implementation given on the command line.

    Each implementation runs in its own spawned process, limited to
    IMPL_TIMEOUT seconds. For each one, a human-readable status line and the
    JSON-encoded result dict are printed between DELIM_START and DELIM_END.
    Exits with status 1 if fewer than two path arguments are supplied.
    """
    import queue  # only needed for the queue.Empty exception

    if len(sys.argv) < 3:
        print(f"Usage: {sys.argv[0]} <workload.py> <impl_0.py> [impl_1.py ...]",
              file=sys.stderr)
        sys.exit(1)

    workload_path = sys.argv[1]
    impl_paths = sys.argv[2:]
    ctx = multiprocessing.get_context("spawn")

    for impl_path in impl_paths:
        print(DELIM_START, flush=True)

        result_queue = ctx.Queue()
        proc = ctx.Process(target=_eval_worker,
                           args=(impl_path, workload_path, result_queue))
        proc.start()

        # Read the result BEFORE joining: a child that has put a large item on
        # the queue may not exit until the parent consumes it, so joining
        # first can turn a finished run into a spurious timeout.
        deadline = time.monotonic() + IMPL_TIMEOUT
        result = None
        timed_out = False
        while result is None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                timed_out = True
                break
            try:
                result = result_queue.get(timeout=min(remaining, 0.5))
            except queue.Empty:
                if proc.is_alive():
                    continue
                # The child is gone; allow for a result it queued just before
                # exiting, then stop waiting.
                try:
                    result = result_queue.get(timeout=1.0)
                except queue.Empty:
                    pass
                break

        if result is not None:
            # Give a worker that already delivered its result time to exit.
            proc.join(timeout=10)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=10)

        if result is None:
            if timed_out:
                result = {"correct": False, "latency": None,
                          "error": f"timed out after {IMPL_TIMEOUT}s"}
            else:
                result = {"correct": False, "latency": None,
                          "error": f"worker exited with code {proc.exitcode}"}

        if result.get("correct"):
            print(f"Latency: {result['latency']:.3f} ms")
        elif result.get("error"):
            print(f"FAIL: {result['error']}")

        print(json.dumps(result))
        print(DELIM_END, flush=True)


if __name__ == "__main__":
    main()
