"""Compare Perfetto device profiler vs wall-clock timing on TPU.

Run on the TPU VM:
    python3.11 compare_timing.py <workload.py>

Measures the same workload with both methods and reports:
- Measured kernel latency (median ms) for each method
- Total evaluation time (how long the benchmarking itself takes)
"""
import glob
import gzip
import importlib.util
import json
import os
import shutil
import sys
import time

import jax
import jax.numpy as jnp
import numpy as np

# Number of untimed calls made to the jitted workload before any measurement.
NUM_WARMUP = 5
# Number of measured calls per timing method; median/mean are taken over these.
NUM_TRIALS = 20


def _load_module(path, name):
    """Import the Python file at ``path`` and register it as module ``name``.

    The module is placed in ``sys.modules`` before it is executed so that
    code inside it can resolve itself by name during import.

    Args:
        path: Filesystem path of the Python source file to execute.
        name: Module name to register the loaded module under.

    Returns:
        The fully executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for ``path``
            (for example, the file has no recognised Python suffix).
        Exception: Whatever the module's own top-level code raises; in that
            case the ``sys.modules`` registration is rolled back first.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # spec_from_file_location returns None when no loader matches the path;
    # fail with a clear message instead of an AttributeError further down.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {path!r} as a Python module")

    mod = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(name)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module behind under this name.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return mod


def _extract_device_times(trace_dir):
    """Collect durations of jitted-function events from a Perfetto trace.

    Searches ``trace_dir`` recursively for ``perfetto_trace.json.gz`` and reads
    the first match. An event is kept when it is a dict with a positive
    ``dur`` whose ``name`` starts with ``jit_`` and contains ``(``.

    Args:
        trace_dir: Directory that the profiler wrote its trace into.

    Returns:
        A list of the raw ``dur`` values of the matching events, in file
        order, or ``None`` when no trace file or no matching event exists.
    """
    matches = glob.glob(f"{trace_dir}/**/perfetto_trace.json.gz", recursive=True)
    if len(matches) == 0:
        return None

    with gzip.open(matches[0], 'rt') as fh:
        payload = json.load(fh)

    # The trace is either a bare event list or a dict wrapping 'traceEvents'.
    if isinstance(payload, dict):
        records = payload.get('traceEvents', payload)
    else:
        records = payload

    def _is_jit_event(rec):
        # Non-dict entries and zero/negative durations are skipped first,
        # before the name is inspected.
        if not isinstance(rec, dict):
            return False
        if rec.get('dur', 0) <= 0:
            return False
        label = rec.get('name', '')
        return label.startswith('jit_') and '(' in label

    durations = [rec['dur'] for rec in records if _is_jit_event(rec)]
    return durations or None


def _prepare_workload(workload_path, module_name):
    """Load the workload file, build its inputs, jit it and run the warmup calls.

    Returns:
        ``(fn, inputs)``: the jitted ``workload`` function and the inputs
        returned by the module's ``create_inputs()``.
    """
    mod = _load_module(workload_path, module_name)
    inputs = mod.create_inputs()
    fn = jax.jit(mod.workload)
    for _ in range(NUM_WARMUP):
        jax.block_until_ready(fn(*inputs))
    return fn, inputs


def _describe_output(out):
    """Return the one-line shape/dtype summary printed for the workload output."""
    if hasattr(out, "shape") and hasattr(out, "dtype"):
        return f"Output shape: {out.shape}, dtype: {out.dtype}"
    # The workload may return a tuple/dict/other pytree; describe each leaf
    # instead of failing on a missing ``.shape`` attribute.
    leaves = jax.tree_util.tree_leaves(out)
    shapes = [getattr(leaf, "shape", ()) for leaf in leaves]
    dtypes = [str(getattr(leaf, "dtype", type(leaf).__name__)) for leaf in leaves]
    return f"Output shapes: {shapes}, dtypes: {dtypes}"


def main():
    """Benchmark one workload with wall-clock timing and with the Perfetto profiler.

    Reads the workload path from ``sys.argv[1]``. The workload module must
    expose ``create_inputs()`` and ``workload(*inputs)``. The full pipeline
    (load, jit, warmup, measure) is run once per method so that the total
    evaluation time of each method can be compared, and a report is printed
    to stdout.

    Exits with status 1 after printing a usage line when no path is given.
    """
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <workload.py>")
        sys.exit(1)

    workload_path = sys.argv[1]

    # === Full pipeline: Wall-clock ===
    full_t0 = time.perf_counter()
    fn, inputs = _prepare_workload(workload_path, "bench_mod")

    wall_times = []
    out = None
    for _ in range(NUM_TRIALS):
        t0 = time.perf_counter()
        out = fn(*inputs)
        # Block so the interval covers the computation, not just the dispatch.
        jax.block_until_ready(out)
        wall_times.append((time.perf_counter() - t0) * 1000.0)
    wall_total = time.perf_counter() - full_t0

    wall_arr = np.array(wall_times)
    wall_median = np.median(wall_arr)
    print(f"Workload: {workload_path}")
    print(_describe_output(out))
    print(f"Warmup: {NUM_WARMUP}, Trials: {NUM_TRIALS}")
    print()

    print("=== Wall-clock timing ===")
    print(f"  Kernel median: {wall_median:.3f} ms")
    print(f"  Kernel mean:   {np.mean(wall_arr):.3f} ms")
    print(f"  TOTAL time (load+jit+warmup+bench): {wall_total:.3f} s")
    print()

    # === Full pipeline: Perfetto ===
    # Re-do the full pipeline (fresh module, cleared JAX caches) so the total
    # time is comparable with the wall-clock run above.
    jax.clear_caches()
    full_t0 = time.perf_counter()
    fn2, inputs2 = _prepare_workload(workload_path, "bench_mod2")

    trace_dir = f"/tmp/compare_timing_trace_{os.getpid()}"
    if os.path.exists(trace_dir):
        shutil.rmtree(trace_dir)
    os.makedirs(trace_dir, exist_ok=True)

    try:
        with jax.profiler.trace(trace_dir, create_perfetto_link=False, create_perfetto_trace=True):
            for _ in range(NUM_TRIALS):
                with jax.named_scope('bench_kernel'):
                    out2 = fn2(*inputs2)
                jax.block_until_ready(out2)
        kernel_times_us = _extract_device_times(trace_dir)
    finally:
        # Always remove the trace, even when profiling or parsing fails.
        shutil.rmtree(trace_dir, ignore_errors=True)
    perfetto_total = time.perf_counter() - full_t0

    # Only the first NUM_TRIALS matching events are used; durations are
    # divided by 1000 to report them in ms alongside the wall-clock numbers.
    perfetto_ms = None
    if kernel_times_us:
        perfetto_ms = np.array(kernel_times_us[:NUM_TRIALS]) / 1000.0

    if perfetto_ms is not None:
        print(f"=== Perfetto device profiler ({len(kernel_times_us)} events) ===")
        print(f"  Kernel median: {np.median(perfetto_ms):.3f} ms")
        print(f"  Kernel mean:   {np.mean(perfetto_ms):.3f} ms")
    else:
        print("=== Perfetto device profiler: NO EVENTS FOUND ===")
    print(f"  TOTAL time (load+jit+warmup+bench+parse): {perfetto_total:.3f} s")
    print()

    # === Comparison ===
    print("=== Comparison ===")
    print(f"  Wall-clock total:  {wall_total:.3f} s")
    print(f"  Perfetto total:    {perfetto_total:.3f} s")
    print(f"  Overhead: {perfetto_total - wall_total:.3f} s ({(perfetto_total/wall_total - 1)*100:.0f}% slower)")
    if perfetto_ms is not None:
        perfetto_median = np.median(perfetto_ms)
        print(f"  Kernel latency ratio (perfetto/wall): {perfetto_median/wall_median:.3f}x")


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
