from __future__ import annotations

import csv
import json
import math
import os
import time
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn


# Bucket canonicalization & order (edit here to change ordering)
# Canonical model-domain bucket names, in the order they should appear in
# per-run and global breakdowns. _order_and_compact() emits keys listed here
# first (when present); any bucket not listed is appended afterwards.
MODEL_BUCKET_ORDER = [
    "embed",
    "attn.qkv_proj",
    "attn.qkv_proj.lora",
    "attn.qk_norm",
    "attn.rope",
    "attn.compute",
    "attn.o_proj",
    "attn.o_proj.lora",
    "mlp.gate_up_proj",
    "mlp.gate_up_proj.lora",
    "mlp.silu_and_mul",
    "mlp.down_proj",
    "mlp.down_proj.lora",
    "norm",
    "lm_head",
    "sampler.linear.0",
    "sampler.linear.1",
    "sampler.norm",
    "communication",
]

# Canonical backend-domain bucket names, in report order. Names are stored
# without the "backend." prefix because _canon_bucket() strips it.
BACKEND_BUCKET_ORDER = [
    "pre_decode",
    "decode",
    "decode.sample",
    "interleave_mask_tokens",
    "pre_draft",
    "draft",
    "pre_draft_and_verify",
    "draft_and_verify",
    "sampler_draft",
    "evaluate_posterior",
    "collate_kv",
    "insert_kv",
    "delete_kv",
]


# Global active handle
_ACTIVE_DECODE_PROFILER: Optional["Profiler"] = None


def register_active_decode_profiler(prof: Optional["Profiler"]) -> None:
    """Install *prof* as the process-wide active profiler (``None`` clears it).

    The module-level timer helpers look the profiler up through this global,
    so instrumented code does not need a profiler reference passed in.

    Args:
        prof: Profiler instance to make active, or ``None`` to deactivate.
    """
    # `Optional["Profiler"]` rather than `"Profiler" | None`: the latter is a
    # `str | None` expression that raises TypeError as soon as the annotation
    # is actually evaluated (e.g. by typing.get_type_hints).
    global _ACTIVE_DECODE_PROFILER
    _ACTIVE_DECODE_PROFILER = prof


# timing contexts

class _NoopCtx:
    """Do-nothing context manager returned when no timing should be recorded."""

    __slots__ = ()

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc, tb):
        # A falsy result lets any exception raised in the `with` body propagate.
        return False


# One shared instance is enough: the class carries no state.
_NOOP = _NoopCtx()

class _CudaTimerCtx:
    """Bracket a region with two CUDA events and queue them on the profiler.

    No elapsed time is computed here. On exit the tuple
    ``("cuda", start_event, end_event, bucket)`` is appended to
    ``prof._iter_events``.
    """

    __slots__ = ("prof", "bucket", "s", "e")

    def __init__(self, prof, bucket: str):
        self.prof, self.bucket = prof, bucket
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        self.s, self.e = start_event, end_event

    def __enter__(self):
        self.s.record()
        return None

    def __exit__(self, exc_type, exc, tb):
        self.e.record()
        pending = ("cuda", self.s, self.e, self.bucket)
        self.prof._iter_events.append(pending)
        # Never suppress exceptions from the timed region.
        return False

class _CpuTimerCtx:
    """Wall-clock timer that queues ``("cpu", elapsed_ms, None, bucket)``.

    When ``prof.cfg.strict_sync`` is set, CUDA is synchronized right before
    the clock is read on both entry and exit.
    """

    __slots__ = ("prof", "bucket", "t0")

    def __init__(self, prof, bucket: str):
        self.prof, self.bucket = prof, bucket
        self.t0 = 0.0

    def _sync_if_strict(self):
        """Synchronize CUDA only when the profiler config requests it."""
        if self.prof.cfg.strict_sync:
            torch.cuda.synchronize()

    def __enter__(self):
        self._sync_if_strict()
        self.t0 = time.perf_counter()
        return None

    def __exit__(self, exc_type, exc, tb):
        self._sync_if_strict()
        elapsed_ms = (time.perf_counter() - self.t0) * 1e3
        self.prof._iter_events.append(("cpu", elapsed_ms, None, self.bucket))
        return False


def _maybe_timer(bucket: str, need_model: Optional[bool] = True):
    """Return a CUDA-event timer context for *bucket*, or the shared no-op.

    The no-op is returned when no profiler is registered, when the profiler
    is disabled, when no decode step is currently being measured, or while
    the current run is still a warmup run.

    Args:
        bucket: Bucket name the measurement is filed under.
        need_model: Which config switch gates the timer.
            ``True``  -> gated by ``cfg.model_profiling``;
            ``False`` -> gated by ``cfg.backend_profiling``;
            ``None``  -> ungated (rare).

    Returns:
        A ``_CudaTimerCtx`` bound to the active profiler, or ``_NOOP``.
    """
    prof = _ACTIVE_DECODE_PROFILER
    if prof is None or prof.disabled:
        return _NOOP
    if not (prof._active_measure and prof._run_ge_warmup):
        return _NOOP

    # Identity checks on purpose: only the literal True/False select a gate,
    # anything else (notably None) bypasses both.
    if need_model is True:
        gate_open = prof.cfg.model_profiling
    elif need_model is False:
        gate_open = prof.cfg.backend_profiling
    else:
        gate_open = True
    if not gate_open:
        return _NOOP
    return _CudaTimerCtx(prof, bucket)

def _maybe_cpu_timer(bucket: str, need_backend: bool = True):
    """Return a wall-clock timer context for *bucket*, or the shared no-op.

    The no-op is returned when no profiler is registered, the profiler is
    disabled, no step is being measured, the run is a warmup run, or
    *need_backend* is truthy while ``cfg.backend_profiling`` is off.
    """
    prof = _ACTIVE_DECODE_PROFILER
    if prof is None or prof.disabled:
        return _NOOP
    recording = prof._active_measure and prof._run_ge_warmup
    if not recording:
        return _NOOP
    if need_backend and not prof.cfg.backend_profiling:
        return _NOOP
    return _CpuTimerCtx(prof, bucket)

# Public timers to be imported by model/backend code
def attention_compute_timer():
    """Return a timer context for the "attn.compute" bucket (gated by cfg.model_profiling)."""
    return _maybe_timer("attn.compute", need_model=True)

def rope_compute_timer():
    """Return a timer context for the "attn.rope" bucket (gated by cfg.model_profiling)."""
    return _maybe_timer("attn.rope", need_model=True)

def bucket_timer(bucket: str):
    """Return a timer context for an arbitrary model bucket (gated by cfg.model_profiling)."""
    return _maybe_timer(bucket, need_model=True)

def backend_bucket_timer(bucket: str):
    """Backend-only wall-clock timer; gated by cfg.backend_profiling.

    The bucket name is passed through unchanged. Note that _canon_bucket()
    only files a measurement under the backend domain when the name starts
    with "backend." (case-insensitive).
    """
    return _maybe_cpu_timer(bucket, need_backend=True)


# Helper methods
def _mean_std(vals):
    """Return ``(mean, sample_std, count)`` over the non-None entries of *vals*.

    Entries are converted with ``float()``. With no usable entries the result
    is ``(nan, nan, 0)``; with exactly one, the std is reported as ``0.0``.
    Otherwise the standard deviation uses the ``n - 1`` denominator.
    """
    samples = [float(v) for v in vals if v is not None]
    count = len(samples)
    if not samples:
        return float("nan"), float("nan"), 0
    if count == 1:
        return samples[0], 0.0, 1
    mean = sum(samples) / count
    squared_dev = sum((s - mean) * (s - mean) for s in samples)
    variance = squared_dev / (count - 1)
    return mean, math.sqrt(variance), count

def _dist_ready() -> bool:
    """Tell whether torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()

def _rank_world() -> Tuple[int, int]:
    """Return ``(rank, world_size)``, falling back to ``(0, 1)`` without distributed."""
    if not _dist_ready():
        return 0, 1
    return dist.get_rank(), dist.get_world_size()

def _mkdir(p: str):
    """Create directory *p* (and parents); an existing directory is not an error."""
    os.makedirs(p, exist_ok=True)

def _percentile(sorted_vals: List[float], q: float) -> float:
    """Nearest-rank percentile of an ascending list; ``nan`` when it is empty.

    The rank ``ceil(q/100 * n) - 1`` is clamped into ``[0, n - 1]``, so
    ``q <= 0`` yields the first element and ``q >= 100`` the last.
    """
    if not sorted_vals:
        return float("nan")
    count = len(sorted_vals)
    rank = math.ceil(q / 100.0 * count) - 1
    # Clamp instead of raising for out-of-range q.
    clamped = max(0, min(count - 1, rank))
    return sorted_vals[clamped]

def _fmt(x: Any) -> Any:
    """Format *x* with three decimals when it converts to float.

    Values that cannot be converted are returned unchanged, except ``None``
    which becomes the empty string.
    """
    try:
        number = float(x)
    except Exception:
        return "" if x is None else x
    return f"{number:.3f}"

def _now_s() -> str:
    """Return the current local time as an ISO-8601 string with second precision."""
    from datetime import datetime

    stamp = datetime.now().replace(microsecond=0)
    return stamp.isoformat()

def _canon_bucket(bucket: str) -> Tuple[str, str]:
    """
    Map a raw bucket name to ``(canonical_name, domain)``.

    - The name is lowercased first.
    - A leading "backend." is stripped and the domain is "backend"; no alias
      rewriting is applied in that case.
    - Otherwise the domain is "model" and aliases are rewritten
      (attn.out_proj -> attn.o_proj, attn.qknorm -> attn.qk_norm).
    """
    lowered = bucket.lower()
    backend_prefix = "backend."
    if lowered.startswith(backend_prefix):
        return lowered[len(backend_prefix):], "backend"

    model_aliases = (
        ("attn.out_proj", "attn.o_proj"),
        ("attn.qknorm", "attn.qk_norm"),
    )
    for alias, canonical in model_aliases:
        lowered = lowered.replace(alias, canonical)
    return lowered, "model"

def _order_and_compact(avg_dict: Dict[str, float], order: List[str]) -> Dict[str, float]:
    """Return an insertion-ordered copy of *avg_dict* with preferred keys first.

    Keys named in *order* come first, in that order, with their values copied
    as-is. Every remaining key follows in *avg_dict*'s own iteration order,
    with its value converted to ``float``. Unlisted keys are kept
    individually; they are NOT merged into a single ``"others"`` entry.

    Args:
        avg_dict: Bucket name -> value mapping to reorder.
        order: Preferred key order; names missing from *avg_dict* are skipped.

    Returns:
        A new dict holding every key of *avg_dict*.
    """
    ordered: Dict[str, float] = {k: avg_dict[k] for k in order if k in avg_dict}
    for k, v in avg_dict.items():
        if k not in ordered:
            ordered[k] = float(v)
    return ordered

def generate_run_name(args: Any) -> str:
    """Build the run folder name from the runner arguments.

    The name concatenates, in order: the last path component of
    ``args.model_name`` and the dataset; a backend tag (with LoRA rank and
    draft length(s) for the "mtp" / "mtp_rectangular" backends); the TP size,
    batch size, prefix length and generation length; and the sampling
    settings, or "greedy" when ``args.temperature`` is not positive.
    """
    short_model = args.model_name.rsplit("/", 1)[-1]
    pieces = [f"{short_model}-{args.dataset}"]

    backend = args.backend
    if backend == "mtp":
        pieces.append(f"-mtp-r{args.lora_rank}-k{args.draft_length[0]}")
    elif backend == "mtp_rectangular":
        pieces.append(f"-mtp-r{args.lora_rank}-k{args.draft_length[0]}-{args.draft_length[1]}")
    else:
        pieces.append(f"-{args.backend}")

    pieces.append(
        f"-tp{len(args.rank_group)}-bsz{args.batch_size}-prefix{args.prefix_len}-gen{args.max_gen_len}"
    )

    sampling = args.temperature > 0
    pieces.append(
        f"-temp{args.temperature:.1f}-topp{args.top_p:.2f}-topk{args.top_k}" if sampling else "-greedy"
    )
    return "".join(pieces)

# Config
@dataclass
class ProfilerConfig:
    """Settings for ``Profiler``.

    Field order is part of the positional constructor interface; append new
    fields at the end.
    """
    output_dir: str = "profiler_out"     # parent directory; the run folder is created inside it
    collect_on_rank0_only: bool = True   # collect only on rank0 (other ranks get a disabled profiler)
    strict_sync: bool = True             # torch.cuda.synchronize() at step / CPU-timer boundaries (more accurate, more overhead)
    dist_barrier: bool = False           # dist.barrier() at step boundaries (outside the timed region)
    model_profiling: bool = False        # model/module breakdown
    backend_profiling: bool = False      # backend call breakdown
    num_total_runs: int = 10             # total number of runs (only for reporting)
    warmup_runs: int = 1                 # ignore first N runs
    print_per_run: bool = True           # print run summary
    run_name: Optional[str] = None       # folder name override; generated from runner args when unset
    kv_len_reduce: str = "mean"          # one of {"mean","max","p50","p90","sum"}; "median" is accepted as p50, unknown values fall back to mean
    kv_bins: List[int] = field(default_factory=lambda: [0, 512, 1024, 2048, 4096, 8192, 16384])  # bin edges; sorted on use, last bin is open-ended


# Profiler
class Profiler:
    """
    Minimal-overhead CUDA-event profiler for speculative decoding.
    - Per-step latency and accepted-token throughput (tok/s).
    - Optional model/LoRA/backend/comms breakdown.
    - Rank0 aggregation by default; TP-aware.
    - Global aggregation across multiple _run_mtp_loop calls.
    - Clean JSON/CSV/Markdown outputs.
    """

    # Construction
    def __init__(self, runner_args: Any, cfg: Optional[ProfilerConfig] = None):
        """Set up per-run buffers, global accumulators and the output directory.

        Args:
            runner_args: Namespace-like object. It is passed to
                ``generate_run_name`` when ``cfg.run_name`` is unset and is
                snapshotted with ``vars()`` into ``self.runner_args``.
            cfg: Profiler settings; a default ``ProfilerConfig`` is used when
                omitted.
        """
        self.cfg = cfg or ProfilerConfig()
        self.rank, self.world = _rank_world()
        # With collect_on_rank0_only, every rank except 0 becomes a disabled profiler.
        self.disabled = (self.cfg.collect_on_rank0_only and self.rank != 0)
        self.out_dir = os.path.join(self.cfg.output_dir, self.cfg.run_name or generate_run_name(runner_args))
        # Only an enabled profiler creates the directory.
        if not self.disabled: _mkdir(self.out_dir)
        self.runner_args = vars(runner_args)

        # Per-run State
        self._active_measure: bool = False          # True only while inside a time_decode() context
        self._iter_idx: int = 0
        self._current_meta: Dict[str, Any] = {}
        self._iters_elapsed_ms: List[float] = []
        self._iter_tokens: List[int] = []           # accepted tokens per step (sum over batch)
        self._iter_kv_lens: List[int] = []          # KV length (sequence length) at the *start* of the step
        self._iter_events: List[Tuple[str, Any, Any, str]] = []  # (etype,start,end,bucket)
        self._iters_model_sums: List[Dict[str, float]] = []      # per-step {model bucket: ms}
        self._iters_backend_sums: List[Dict[str, float]] = []    # per-step {backend bucket: ms}
        self._pending_step_tokens: int = 0                       # set via set_step_tokens(), consumed at step exit
        self._pending_step_kvlen: Optional[int] = None           # set via set_step_seq_len(), consumed at step exit

        # Per-run Results (list of dicts)
        self._runs: List[Dict[str, Any]] = []

        # Run-level warmup gating
        self._run_idx: int = 0
        self._run_ge_warmup: bool = False           # becomes True once _run_idx exceeds cfg.warmup_runs

        # Global Aggregation (across all runs in this process)
        self._g_lat_ms: List[float] = []
        self._g_tokens: int = 0
        self._g_steps: int = 0
        self._g_model_bucket_sum_ms: Dict[str, float] = {}
        self._g_backend_bucket_sum_ms: Dict[str, float] = {}
        # Global decode-length bucket aggregation
        # Each meta entry is (lo, hi, key); hi is None for the open-ended last bin.
        self._g_len_bins_meta: List[Tuple[Optional[int], Optional[int], str]] = self._build_len_bins()
        self._g_len_bins_data: Dict[str, Dict[str, Any]] = {
            meta[2]: {"lat_ms": [], "tokens": 0, "time_ms": 0.0, "steps": 0} for meta in self._g_len_bins_meta
        }

        if self.rank == 0 and not self.disabled:
            print(f"[Profiler] Output dir: {self.out_dir}")

    # Attach / instrument
    def attach_model(self, model: nn.Module, use_gated_lora: bool = False) -> None:
        """Instrument *model*'s submodules so their time is filed into buckets.

        Each entry of ``model.named_modules()`` is classified by its lowercased
        qualified name and class name. The first matching rule wins (every rule
        ends in ``continue``), so the order of the checks below matters.

        Args:
            model: Root module to walk.
            use_gated_lora: When True, the projection modules (wqkv / wo / w13 /
                w2) are only tagged with base/LoRA bucket names via
                ``_set_linear_buckets`` instead of having ``forward`` wrapped.
                Presumably the gated-LoRA linear reads those tags itself; that
                code is not visible here.

        Does nothing (apart from an optional rank-0 message) when the profiler
        is disabled or ``cfg.model_profiling`` is off.
        """
        if self.disabled or not self.cfg.model_profiling:
            if self.rank == 0 and not self.disabled:
                print("[Profiler] model attach skipped (model_profiling=False).")
            return

        for name, m in model.named_modules():
            lname = name.lower()
            cls = m.__class__.__name__.lower()

            # Embedding / LM head
            if lname.endswith("tok_embeddings") or "embedding" in cls:
                self._wrap_module_forward(m, "embed");         continue
            # NOTE(review): `"output" in lname` subsumes the endswith() test, so any
            # module whose qualified name contains "output" is treated as lm_head.
            if lname.endswith("output") or "output" in lname:
                self._wrap_module_forward(m, "lm_head");       continue

            # Attention projections (base vs LoRA split)
            if "wqkv" in lname and lname.endswith("wqkv"):
                if use_gated_lora: self._set_linear_buckets(m, "attn.qkv_proj", "attn.qkv_proj.lora"); continue
                else: self._wrap_module_forward(m, "attn.qkv_proj"); continue
            # "attn.out_proj*" names are rewritten to "attn.o_proj*" by _canon_bucket().
            if "wo" in lname and lname.endswith("wo"):
                if use_gated_lora: self._set_linear_buckets(m, "attn.out_proj", "attn.out_proj.lora"); continue
                else: self._wrap_module_forward(m, "attn.out_proj"); continue

            # MLP projections (base vs LoRA)
            if "w13" in lname and lname.endswith("w13"):
                if use_gated_lora: self._set_linear_buckets(m, "mlp.gate_up_proj", "mlp.gate_up_proj.lora"); continue
                else: self._wrap_module_forward(m, "mlp.gate_up_proj"); continue
            if "w2" in lname and lname.endswith("w2"):
                if use_gated_lora: self._set_linear_buckets(m, "mlp.down_proj", "mlp.down_proj.lora"); continue
                else: self._wrap_module_forward(m, "mlp.down_proj"); continue

            # Norms
            # q/k norms get their own bucket; other RMSNorm-class modules outside the
            # sampler go to the generic "norm" bucket.
            if lname.endswith("q_norm") or lname.endswith("k_norm"):
                self._wrap_module_forward(m, "attn.qk_norm");        continue
            elif "sampler" not in lname and "rmsnorm" in cls:
                self._wrap_module_forward(m, "norm");          continue

            # Sampler
            if "sampler." in lname:
                if ".layers." in lname and lname.endswith(".linear"):
                    # Layer index is the path component right after ".layers.";
                    # if it is not an integer the module is left unclassified here.
                    try:
                        idx = int(lname.split(".layers.")[1].split(".")[0])
                        self._wrap_module_forward(m, f"sampler.linear.{idx}"); continue
                    except Exception: pass
                if lname.endswith(".norm"):
                    self._wrap_module_forward(m, "sampler.norm"); continue

        # Defined outside the visible part of this file; by its name it hooks
        # communication ops (see the "communication" bucket) — TODO confirm.
        self._patch_communication_ops()
        if self.rank == 0: print("[Profiler] model attached.")

    def _set_linear_buckets(self, mod: nn.Module, base_bucket: str, lora_bucket: str) -> None:
        try:
            setattr(mod, "_prof_base_bucket", base_bucket)
            setattr(mod, "_prof_lora_bucket", lora_bucket)
        except Exception: pass

    def attach_backend(self, backend_obj: Any) -> None:
        if self.disabled or not self.cfg.backend_profiling:
            if self.rank == 0 and not self.disabled:
                print("[Profiler] backend attach skipped (backend_profiling=False).")
            return
        name_map = {
            "interleave_mask_tokens": "interleave_mask_tokens",
            "decode": "decode",
            "draft": "draft",
            "draft_and_verify": "draft_and_verify",
            "sampler_draft": "sampler_draft",
            "evaluate_posterior": "evaluate_posterior",
            "collate_kv": "collate_kv",
            "collate_accepted_kv_cache": "collate_kv",
            "pre_decode": "pre_decode",
            "pre_draft": "pre_draft",
            "pre_draft_and_verify": "pre_draft_and_verify",
            "insert_kv": "insert_kv",
            "delete_kv": "delete_kv",
            "setup_caches": "setup_caches",
        }
        for method_name, short in name_map.items():
            if hasattr(backend_obj, method_name):
                self._wrap_backend_method(backend_obj, method_name, f"backend.{short}")
        if self.rank == 0: print("[Profiler] backend attached.")


    # Length-bucket helpers
    # Length reduce & bins
    def _reduce_kv_len(self, x: Union[int, List[int], torch.Tensor]) -> int:
        """
        Reduce a batch 1D tensor/list of lengths to a single scalar according to cfg.kv_len_reduce.
        Accepts int, list/tuple, or torch.Tensor (cpu/cuda).
        """
        if isinstance(x, int): return int(x)
        if isinstance(x, (list, tuple)): t = torch.as_tensor(x)
        elif torch.is_tensor(x): t = x
        else:
            try: return int(x)
            except Exception: return 0
        if t.numel() == 0: return 0
        if t.is_cuda: t = t.detach().to("cpu", non_blocking=True)
        t = t.to(torch.float32)

        mode = (self.cfg.kv_len_reduce or "mean").lower()
        if mode == "mean": v = t.mean()
        elif mode == "max": v = t.max()
        elif mode in ("p50", "median"): v = torch.quantile(t, 0.5)
        elif mode == "p90": v = torch.quantile(t, 0.9)
        elif mode == "sum": v = t.sum()
        else: v = t.mean()
        return int(v.item())


    def _build_len_bins(self) -> List[Tuple[Optional[int], Optional[int], str]]:
        edges = sorted(int(x) for x in self.cfg.kv_bins)
        metas: List[Tuple[Optional[int], Optional[int], str]] = []
        for i in range(len(edges) - 1):
            lo, hi = edges[i], edges[i + 1]
            metas.append((lo, hi, f"[{lo},{hi})"))
        # last bin: [last_edge, +inf)
        last = edges[-1] if edges else 0
        metas.append((last, None, f"[{last},∞)"))
        return metas

    def _bin_key_for_len(self, L: int) -> str:
        for lo, hi, key in self._g_len_bins_meta:
            if hi is None:
                if L >= int(lo or 0):
                    return key
            else:
                if int(lo or 0) <= L < int(hi):
                    return key
        # fallback (shouldn't happen)
        return "unbinned"


    # Run lifecycle
    def begin_run(self, *, bsz: int, label: str = "decode") -> None:
        if self.disabled: return
        self._current_meta = {
            "label": label, "bsz": int(bsz), "rank": self.rank, "world": self.world,
            "started_at": _now_s(), "cfg": asdict(self.cfg),
        }
        self._iters_elapsed_ms.clear()
        self._iter_tokens.clear()
        self._iter_events.clear()
        self._iters_model_sums.clear()
        self._iters_backend_sums.clear()
        self._iter_idx = 0
        self._iter_kv_lens.clear()
        self._run_idx += 1
        self._run_ge_warmup = (self._run_idx > int(self.cfg.warmup_runs))

    def time_decode(self):
        """Return a context manager that times one decode step with CUDA events.

        On entry it marks the profiler as actively measuring, clears the
        per-step event list and records the start event. On exit it records the
        end event and, for non-warmup runs only, appends the step latency, the
        pending token count and the pending KV length, then folds the queued
        bucket events into per-step model/backend sums.
        """
        class _StepCtx:
            __slots__ = ("prof", "s", "e")
            def __init__(self, prof:"Profiler"):
                self.prof = prof
                self.s = torch.cuda.Event(enable_timing=True)
                self.e = torch.cuda.Event(enable_timing=True)
            def __enter__(self):
                # Optional sync/barrier happen before the start event is recorded,
                # so they are not part of the measured interval.
                if self.prof.cfg.strict_sync: torch.cuda.synchronize()
                if self.prof.cfg.dist_barrier and _dist_ready(): dist.barrier()  # outside timing
                self.prof._active_measure = True
                self.prof._iter_events.clear()
                self.prof._iter_idx += 1
                self.s.record()
                return None
            def __exit__(self, exc_type, exc, tb):
                self.e.record()
                # From here on the bucket timers (_maybe_timer / _maybe_cpu_timer) return no-ops.
                self.prof._active_measure = False
                if self.prof.cfg.strict_sync: torch.cuda.synchronize()
                # Wait for the end event so elapsed_time() below can be queried.
                self.e.synchronize()
                if self.prof.cfg.dist_barrier and _dist_ready(): dist.barrier()  # outside timing

                if self.prof._run_ge_warmup:
                    step_ms = float(self.s.elapsed_time(self.e))
                    self.prof._iters_elapsed_ms.append(step_ms)
                    # Token count supplied via set_step_tokens(); reset after use.
                    tok = getattr(self.prof, "_pending_step_tokens", 0)
                    self.prof._iter_tokens.append(int(tok))
                    self.prof._pending_step_tokens = 0

                    # record KV length seen at the *start* of this step; default to 0 if not provided
                    kvlen = getattr(self.prof, "_pending_step_kvlen", None)
                    self.prof._iter_kv_lens.append(int(kvlen if kvlen is not None else 0))
                    self.prof._pending_step_kvlen = None

                    if (self.prof.cfg.model_profiling or self.prof.cfg.backend_profiling) and self.prof._iter_events:
                        lm: Dict[str, float] = {}   # model-domain bucket -> ms for this step
                        lb: Dict[str, float] = {}   # backend-domain bucket -> ms for this step
                        for etype, a, b, bucket in self.prof._iter_events:
                            # "cuda" events carry (start_event, end_event); "cpu" events carry
                            # the already-measured milliseconds in `a`.
                            dt = float(a.elapsed_time(b)) if etype == "cuda" else float(a)
                            key, domain = _canon_bucket(bucket)
                            if domain == "backend":
                                lb[key] = lb.get(key, 0.0) + dt
                            else:
                                lm[key] = lm.get(key, 0.0) + dt
                        if lm: self.prof._iters_model_sums.append(lm)
                        if lb: self.prof._iters_backend_sums.append(lb)

                self.prof._iter_events.clear()
                # Exceptions from the decode step are never suppressed.
                return False
        return _StepCtx(self)

    def set_step_tokens(self, n: int):
        if self.disabled: return
        try: self._pending_step_tokens = int(n)
        except Exception: self._pending_step_tokens = 0

    def set_step_seq_len(self, kv_len: Union[int, List[int], torch.Tensor]):
        """
        Record the KV-cache length (sequence length) *at the start of the step*.
        Call this once per decode step before the step finishes.
        """
        if self.disabled: return
        try: self._pending_step_kvlen = int(self._reduce_kv_len(kv_len))
        except Exception: self._pending_step_kvlen = 0

    def end_run(self) -> None:
        if self.disabled: return

        # --- skip warmup runs entirely ---
        if not self._run_ge_warmup:
            # clear pending state to be safe
            self._iters_elapsed_ms.clear()
            self._iter_tokens.clear()
            self._iter_events.clear()
            self._iters_model_sums.clear()
            self._iters_backend_sums.clear()
            if self.cfg.print_per_run and self.rank == 0:
                print(f"[Profiler] warmup run #{self._run_idx} skipped from aggregation.")
            return

        # Per-run stats
        vals = sorted(self._iters_elapsed_ms)
        n = len(vals); stats = {"count": n}
        mean = sum(vals)/n if n>0 else 0.0
        secs_total = sum(self._iters_elapsed_ms)/1000.0
        tok_total = int(sum(self._iter_tokens))
        tp = (tok_total/secs_total) if secs_total>0 else float("nan")
        if n>0:
            stats.update({
                "mean_ms": mean, "min_ms": vals[0], "max_ms": vals[-1],
                "p50_ms": _percentile(vals,50), "p90_ms": _percentile(vals,90), "p99_ms": _percentile(vals,99),
                "throughput_tok_s": tp, "tokens_total": tok_total, "time_total_ms": sum(self._iters_elapsed_ms),
            })

        # Per-run buckets (avg per step)
        buckets_model_avg: Dict[str, float] = {}
        buckets_backend_avg: Dict[str, float] = {}

        if n > 0 and self._iters_model_sums:
            agg_m: Dict[str, float] = {}
            for d in self._iters_model_sums:
                for k, v in d.items():
                    agg_m[k] = agg_m.get(k, 0.0) + v
            avg_m = {k: (v / n) for k, v in agg_m.items()}
            buckets_model_avg = _order_and_compact(avg_m, MODEL_BUCKET_ORDER)

        if n > 0 and self._iters_backend_sums:
            agg_b: Dict[str, float] = {}
            for d in self._iters_backend_sums:
                for k, v in d.items():
                    agg_b[k] = agg_b.get(k, 0.0) + v
            avg_b = {k: (v / n) for k, v in agg_b.items()}
            buckets_backend_avg = _order_and_compact(avg_b, BACKEND_BUCKET_ORDER)

        # Per-run decode-length buckets
        if n > 0 and len(self._iter_kv_lens) == n:
            # assemble per-bin aggregations
            per_bin_lat: Dict[str, List[float]] = {meta[2]: [] for meta in self._g_len_bins_meta}
            per_bin_tok: Dict[str, int] = {meta[2]: 0 for meta in self._g_len_bins_meta}
            per_bin_time_ms: Dict[str, float] = {meta[2]: 0.0 for meta in self._g_len_bins_meta}
            per_bin_steps: Dict[str, int] = {meta[2]: 0 for meta in self._g_len_bins_meta}
            for step_ms, step_tok, L in zip(self._iters_elapsed_ms, self._iter_tokens, self._iter_kv_lens):
                key = self._bin_key_for_len(int(L))
                per_bin_lat[key].append(float(step_ms))
                per_bin_tok[key] += int(step_tok)
                per_bin_time_ms[key] += float(step_ms)
                per_bin_steps[key] += 1

            def _mk_stat(ls: List[float], tok: int, tms: float, steps: int) -> Dict[str, Any]:
                ls_sorted = sorted(ls)
                m = (sum(ls_sorted)/steps) if steps>0 else float("nan")
                return {
                    "steps": steps,
                    "mean_ms": m,
                    "p50_ms": _percentile(ls_sorted, 50),
                    "p90_ms": _percentile(ls_sorted, 90),
                    "p99_ms": _percentile(ls_sorted, 99),
                    "time_total_ms": tms,
                    "tokens_total": tok,
                    "throughput_tok_s": (tok / (tms/1000.0)) if tms>0 else float("nan"),
                }

            len_bucket_stats: Dict[str, Dict[str, Any]] = {}
            for lo, hi, key in self._g_len_bins_meta:
                len_bucket_stats[key] = _mk_stat(per_bin_lat[key], per_bin_tok[key], per_bin_time_ms[key], per_bin_steps[key])
                # global accumulation
                gb = self._g_len_bins_data[key]
                gb["lat_ms"].extend(per_bin_lat[key])
                gb["tokens"] += per_bin_tok[key]
                gb["time_ms"] += per_bin_time_ms[key]
                gb["steps"] += per_bin_steps[key]

        pack = {
            "meta": dict(self._current_meta),
            "stats": stats,
            "buckets_model_avg_ms": buckets_model_avg,
            "buckets_backend_avg_ms": buckets_backend_avg,
            "decode_length_buckets": {
                "bins": [{"key": key, "range": [lo, hi] if hi is not None else [lo, None]} for lo, hi, key in self._g_len_bins_meta],
                "stats": len_bucket_stats,
            }
        }

        self._runs.append(pack)
        if self.cfg.print_per_run:
            self._print_run_summary(pack)

        # ---- Global accumulation (across runs) ----
        self._g_lat_ms.extend(self._iters_elapsed_ms)
        self._g_tokens += tok_total
        self._g_steps += n
        if self._iters_model_sums:
            for d in self._iters_model_sums:
                for k, v in d.items():
                    self._g_model_bucket_sum_ms[k] = self._g_model_bucket_sum_ms.get(k, 0.0) + float(v)
        if self._iters_backend_sums:
            for d in self._iters_backend_sums:
                for k, v in d.items():
                    self._g_backend_bucket_sum_ms[k] = self._g_backend_bucket_sum_ms.get(k, 0.0) + float(v)

    # Save
    def save_all(self) -> None:
        if self.disabled:
            if self.rank == 0: print("[Profiler] Disabled (rank gated). Nothing to save.")
            return

        # Global summary (single configuration across many runs)
        g_vals = sorted(self._g_lat_ms)
        g_n = len(g_vals)
        g_mean = sum(g_vals)/g_n if g_n>0 else 0.0
        g_secs_total = sum(self._g_lat_ms)/1000.0
        g_tp = (self._g_tokens/g_secs_total) if g_secs_total>0 else float("nan")

        # Global bucket averages per step
        g_buckets_model_avg: Dict[str, float] = {}
        g_buckets_backend_avg: Dict[str, float] = {}

        if g_n > 0 and self._g_model_bucket_sum_ms:
            avgm = {k: (v / g_n) for k, v in self._g_model_bucket_sum_ms.items()}
            g_buckets_model_avg = _order_and_compact(avgm, MODEL_BUCKET_ORDER)

        if g_n > 0 and self._g_backend_bucket_sum_ms:
            avgb = {k: (v / g_n) for k, v in self._g_backend_bucket_sum_ms.items()}
            g_buckets_backend_avg = _order_and_compact(avgb, BACKEND_BUCKET_ORDER)
        
        # Across-run mean±std (unweighted)
        run_lat_means = [r["stats"].get("mean_ms") for r in self._runs if "mean_ms" in r["stats"]]
        run_tp_means = [r["stats"].get("throughput_tok_s") for r in self._runs if "throughput_tok_s" in r["stats"]]

        lat_mean, lat_std, lat_n = _mean_std(run_lat_means)
        tp_mean, tp_std, tp_n = _mean_std(run_tp_means)

        # Buckets: collect per-run averages when present
        bucket_samples_model: Dict[str, List[float]] = {}
        bucket_samples_backend: Dict[str, List[float]] = {}

        for r in self._runs:
            bm = r.get("buckets_model_avg_ms", {}) or {}
            for k, v in bm.items():
                bucket_samples_model.setdefault(k, []).append(float(v))
            bb = r.get("buckets_backend_avg_ms", {}) or {}
            for k, v in bb.items():
                bucket_samples_backend.setdefault(k, []).append(float(v))

        def _stats_map(d: Dict[str, List[float]]) -> Dict[str, Dict[str, float]]:
            out = {}
            for k, lst in d.items():
                m, s, n = _mean_std(lst)
                out[k] = {"mean": m, "std": s, "n_runs": n}
            return out

        buckets_runs_stats_model = _stats_map(bucket_samples_model)
        buckets_runs_stats_backend = _stats_map(bucket_samples_backend)

        global_summary = {
            "meta": {
                "rank": self.rank, "world": self.world,
                "started_at": self._runs[0]["meta"]["started_at"] if self._runs else _now_s(),
                "cfg": asdict(self.cfg),
                "output_dir": self.out_dir,
            },
            "stats": {
                "steps_total": g_n,
                "time_total_ms": sum(self._g_lat_ms),
                "tokens_total": self._g_tokens,
                "mean_ms": g_mean,
                "p50_ms": _percentile(g_vals,50), "p90_ms": _percentile(g_vals,90), "p99_ms": _percentile(g_vals,99),
                "throughput_tok_s": g_tp,
            },
            "buckets_model_avg_ms": g_buckets_model_avg,
            "buckets_backend_avg_ms": g_buckets_backend_avg,
            "runs": self._runs,  # keep per-run packs for transparency
            "runs_aggregate": {
                "latency_ms": {"mean": lat_mean, "std": lat_std, "n_runs": lat_n},
                "throughput_tok_s": {"mean": tp_mean, "std": tp_std, "n_runs": tp_n},
                "buckets_model_avg_ms": buckets_runs_stats_model,
                "buckets_backend_avg_ms": buckets_runs_stats_backend,
            }
        }

        # Global decode-length buckets
        def _mk_global_len_stats() -> Dict[str, Any]:
            out = {}
            for lo, hi, key in self._g_len_bins_meta:
                d = self._g_len_bins_data[key]
                ls_sorted = sorted(d["lat_ms"])
                steps = int(d["steps"])
                tms = float(d["time_ms"])
                tok = int(d["tokens"])
                mean_ms = (sum(ls_sorted)/steps) if steps>0 else float("nan")
                out[key] = {
                    "steps": steps,
                    "mean_ms": mean_ms,
                    "p50_ms": _percentile(ls_sorted, 50),
                    "p90_ms": _percentile(ls_sorted, 90),
                    "p99_ms": _percentile(ls_sorted, 99),
                    "time_total_ms": tms,
                    "tokens_total": tok,
                    "throughput_tok_s": (tok / (tms/1000.0)) if tms>0 else float("nan"),
                    "range": [lo, hi] if hi is not None else [lo, None],
                }
            return out
        global_summary["decode_length_buckets"] = {
            "bins": [{"key": key, "range": [lo, hi] if hi is not None else [lo, None]} for lo, hi, key in self._g_len_bins_meta],
            "stats": _mk_global_len_stats(),
        }

        # Write files
        _mkdir(self.out_dir)
        jpath = os.path.join(self.out_dir, "summary.json")
        with open(jpath, "w") as f: json.dump(global_summary, f, indent=2)

        # Overview markdown
        mpath = os.path.join(self.out_dir, "report.md")
        with open(mpath, "w") as f:
            f.write("# Speculative Decoding — Global Profile\n\n")
            f.write(f"- Output dir: `{self.out_dir}`\n")
            cfg = asdict(self.cfg)
            f.write(f"- Config: `model_profiling={cfg['model_profiling']}`  `backend_profiling={cfg['backend_profiling']}`  "
                    f"`num_runs(warmup/total)={cfg['warmup_runs']}/{cfg['num_total_runs']}`  `strict_sync={cfg['strict_sync']}`  `dist_barrier={cfg['dist_barrier']}`  "
                    f"`kv_bins={cfg['kv_bins']}`  `kv_len_reduce={cfg['kv_len_reduce']}`\n\n")

            # Global summary table
            bsz = self._runs[0]["meta"]["bsz"]
            s = global_summary["stats"]
            f.write("## Global Summary\n\n")
            f.write("| num decoding steps | mean (ms) | p50 | p90 | p99 | tok/s | tokens | mean generated tokens | time (s) |\n|--:|--:|--:|--:|--:|--:|--:|--:|--:|\n")
            f.write("| {steps} | {mean} | {p50} | {p90} | {p99} | {tp} | {tok} | {mean_gen_tok} | {secs} |\n\n".format(
                steps=s["steps_total"],
                mean=_fmt(s["mean_ms"]), p50=_fmt(s["p50_ms"]), p90=_fmt(s["p90_ms"]), p99=_fmt(s["p99_ms"]),
                tp=_fmt(s["throughput_tok_s"]), tok=int(s["tokens_total"]), mean_gen_tok=_fmt(s['tokens_total'] / (s['steps_total'] * bsz)), secs=_fmt(s["time_total_ms"]/1000.0)
            ))

            # model buckets
            if g_buckets_model_avg:
                f.write("## Model Bucket Breakdown (avg (ms) / step)\n\n")
                f.write("| bucket | avg (ms) | share |\n|--:|--:|--:|\n")
                for k in MODEL_BUCKET_ORDER + (["others"] if "others" in g_buckets_model_avg else []):
                    if k in g_buckets_model_avg:
                        v = g_buckets_model_avg[k]
                        pct = (v / float(s["mean_ms"] or 1e-9)) * 100.0 if s["mean_ms"] else 0.0
                        f.write(f"| `{k}` | {float(v):.3f} | {pct:.1f}% |\n")
                f.write("\n")

            # Backend buckets
            if g_buckets_backend_avg:
                f.write("## Backend Bucket Breakdown (avg (ms) / step)\n\n")
                f.write("| bucket | avg (ms) | share |\n|--:|--:|--:|\n")
                for k in BACKEND_BUCKET_ORDER + (["others"] if "others" in g_buckets_backend_avg else []):
                    if k in g_buckets_backend_avg:
                        v = g_buckets_backend_avg[k]
                        pct = (v / float(s["mean_ms"] or 1e-9)) * 100.0 if s["mean_ms"] else 0.0
                        f.write(f"| `{k}` | {float(v):.3f} | {pct:.1f}% |\n")
                f.write("\n")

            # Per-run quick view
            if self._runs:
                f.write("## Runs\n\n")
                f.write("| run | bsz | num decoding steps | total generated tokens | mean (ms) | tok/s | mean generated tokens |\n|--:|--:|--:|--:|--:|--:|--:|\n")
                for i, r in enumerate(self._runs):
                    rm, rs = r["meta"], r["stats"]
                    f.write(
                        f"| {i} | {rm.get('bsz','?')} | {rs.get('count',0)} | {int(rs.get('tokens_total', 0))} | "
                        f"{_fmt(rs.get('mean_ms'))} | {_fmt(rs.get('throughput_tok_s'))} | {_fmt(rs.get('tokens_total', 0) / (rs.get('count', 1) * bsz))}|\n"
                    )
                f.write("\n")
            
            # Runs aggregate section
            ra = global_summary.get("runs_aggregate", {})
            if ra:
                f.write("## Across-run (unweighted) — mean ± std\n\n")
                lat = ra.get("latency_ms", {}); tp = ra.get("throughput_tok_s", {})
                f.write("| metric | mean | std | n_runs |\n|--|--:|--:|--:|\n")
                f.write(f"| latency (ms/num decoding step) | { _fmt(lat.get('mean')) } | { _fmt(lat.get('std')) } | { lat.get('n_runs', 0) } |\n")
                f.write(f"| throughput (tok/s) | { _fmt(tp.get('mean')) } | { _fmt(tp.get('std')) } | { tp.get('n_runs', 0) } |\n\n")

                bstats_m = ra.get("buckets_model_avg_ms", {})
                if bstats_m:
                    f.write("### Model buckets — mean ± std across runs\n\n")
                    f.write("| bucket | mean (ms) | std (ms) | n_runs |\n|--:|--:|--:|--:|\n")
                    for k in MODEL_BUCKET_ORDER + sorted([x for x in bstats_m.keys() if x not in MODEL_BUCKET_ORDER and x != "others"]):
                        if k in bstats_m:
                            d = bstats_m[k]
                            f.write(f"| `{k}` | { _fmt(d.get('mean')) } | { _fmt(d.get('std')) } | { d.get('n_runs', 0) } |\n")
                    if "others" in bstats_m:
                        d = bstats_m["others"]
                        f.write(f"| `others` | { _fmt(d.get('mean')) } | { _fmt(d.get('std')) } | { d.get('n_runs', 0) } |\n")
                    f.write("\n")

                bstats_b = ra.get("buckets_backend_avg_ms", {})
                if bstats_b:
                    f.write("### Backend buckets — mean ± std across runs\n\n")
                    f.write("| bucket | mean (ms) | std (ms) | n_runs |\n|--:|--:|--:|--:|\n")
                    for k in BACKEND_BUCKET_ORDER + sorted([x for x in bstats_b.keys() if x not in BACKEND_BUCKET_ORDER and x != "others"]):
                        if k in bstats_b:
                            d = bstats_b[k]
                            f.write(f"| `{k}` | { _fmt(d.get('mean')) } | { _fmt(d.get('std')) } | { d.get('n_runs', 0) } |\n")
                    if "others" in bstats_b:
                        d = bstats_b["others"]
                        f.write(f"| `others` | { _fmt(d.get('mean')) } | { _fmt(d.get('std')) } | { d.get('n_runs', 0) } |\n")
                    f.write("\n")

            # Decode-length buckets (global)
            if "decode_length_buckets" in global_summary:
                f.write("## Decode-length Buckets (KV length at step start)\n\n")
                f.write("| bin | range | steps | mean (ms) | p50 (ms) | p90 (ms) | p99 (ms) | tok/s | tokens | mean generated tokens | time (s) |\n|--:|--|--:|--:|--:|--:|--:|--:|--:|--:|--:|\n")
                gdb = global_summary["decode_length_buckets"]["stats"]
                # keep the declared order
                for lo, hi, key in self._g_len_bins_meta:
                    s = gdb.get(key, {})
                    rng = f"[{lo},{hi})" if hi is not None else f"[{lo},∞)"
                    f.write("| {key} | {rng} | {steps} | {mean} | {p50} | {p90} | {p99} | {tp} | {tok} | {mean_gen_tok} | {secs} |\n".format(
                        key=key, rng=rng, steps=int(s.get("steps",0)),
                        mean=_fmt(s.get("mean_ms")), p50=_fmt(s.get("p50_ms")),
                        p90=_fmt(s.get("p90_ms")), p99=_fmt(s.get("p99_ms")),
                        tp=_fmt(s.get("throughput_tok_s")), tok=int(s.get("tokens_total",0)),
                        mean_gen_tok=_fmt(s.get("tokens_total",0) / (1 if s.get("steps", 0) == 0 else s.get("steps",0) * bsz)),
                        secs=_fmt((s.get("time_total_ms",0.0))/1000.0)
                    ))
                f.write("\n")

        if self.rank == 0:
            print(f"[Profiler] Saved:\n  {jpath}\n  {mpath}")
            print(f"[Profiler] Global: steps={g_n} mean={_fmt(g_mean)}ms tok/s={_fmt(g_tp)}")

    def save_config(self, filename: str = "config.json", extra: Optional[Dict[str, Any]] = None) -> None:
        """Write the runner arguments and profiler config to a JSON file.

        The payload contains ``"runner_args"`` and ``"profiler_cfg"`` (the
        config dataclass as a dict, plus ``rank`` and ``world``), and an
        ``"extra"`` entry when *extra* is non-empty.  Runner-argument and extra
        values that are not JSON primitives (or lists/tuples of them) are
        stored as ``str(value)``.

        Args:
            filename: Name of the file created inside ``self.out_dir``.
            extra: Optional additional key/value pairs stored under ``"extra"``.

        Does nothing when the profiler is disabled.
        """
        if self.disabled:
            return

        cfg = asdict(self.cfg)
        cfg["rank"] = self.rank
        cfg["world"] = self.world

        def _to_jsonable(v):
            # Primitives pass through, sequences are converted element-wise,
            # anything else falls back to its string form.
            if isinstance(v, (str, int, float, bool)) or v is None:
                return v
            if isinstance(v, (list, tuple)):
                return [_to_jsonable(x) for x in v]
            return str(v)

        def _public_items(items):
            # Drop callables and underscore-prefixed names from the dump.
            return {
                k: _to_jsonable(v)
                for k, v in items
                if not callable(v) and not k.startswith("_")
            }

        if hasattr(self.runner_args, "__dict__"):
            args_dict = _public_items(vars(self.runner_args).items())
        elif isinstance(self.runner_args, dict):
            args_dict = _public_items(self.runner_args.items())
        else:
            args_dict = {"__raw__": _to_jsonable(self.runner_args)}

        payload = {"runner_args": args_dict, "profiler_cfg": cfg}
        if extra:
            payload["extra"] = {k: _to_jsonable(v) for k, v in extra.items()}

        # The config may be saved before any report has been written, so the
        # output directory is not guaranteed to exist yet.
        if self.out_dir:
            os.makedirs(self.out_dir, exist_ok=True)
        jpath = os.path.join(self.out_dir, filename)
        with open(jpath, "w") as f:
            json.dump(payload, f, indent=2)
        if self.rank == 0:
            print(f"[Profiler] Saved config: {jpath}")

    def _wrap_module_forward(self, module: nn.Module, bucket: str) -> None:
        """Swap ``module.forward`` for a wrapper that brackets it with CUDA events.

        A module is instrumented at most once: the first call sets
        ``module._prof_wrapped`` and any later call returns immediately.  The
        wrapper forwards straight to the original ``forward`` unless the active
        profiler is registered, enabled, currently measuring, past warmup and
        has ``cfg.model_profiling`` set; in that case it appends
        ``("cuda", start_event, end_event, bucket)`` to ``_iter_events``.
        """
        already_wrapped = getattr(module, "_prof_wrapped", False)
        if already_wrapped:
            return
        module._prof_wrapped = True
        inner = module.forward

        def wrapped(*args, **kwargs):
            active = _ACTIVE_DECODE_PROFILER
            should_time = (
                active is not None
                and not active.disabled
                and active._active_measure
                and active._run_ge_warmup
                and active.cfg.model_profiling
            )
            if not should_time:
                return inner(*args, **kwargs)

            start_evt = torch.cuda.Event(enable_timing=True)
            stop_evt = torch.cuda.Event(enable_timing=True)
            start_evt.record()
            result = inner(*args, **kwargs)
            stop_evt.record()
            active._iter_events.append(("cuda", start_evt, stop_evt, bucket))
            return result

        module.forward = wrapped  # type: ignore

    def _wrap_backend_method(self, obj: Any, method_name: str, bucket: str) -> None:
        """Replace ``obj.<method_name>`` with a wrapper that records wall-clock time.

        The wrapper calls the original directly unless the active profiler is
        registered, enabled, currently measuring, past warmup and has
        ``cfg.backend_profiling`` set.  When timing, it appends
        ``("cpu", elapsed_ms, None, bucket)`` to the profiler's ``_iter_events``
        even if the wrapped call raises.

        Args:
            obj: Object whose attribute is replaced.
            method_name: Name of the callable attribute on *obj*.
            bucket: Bucket label stored with each timing event.
        """
        orig = getattr(obj, method_name)
        # guard duplicate wraps: wrapping an already-instrumented method would
        # nest timers and log one event per layer for a single call.
        if getattr(orig, "_prof_wrapped", False):
            return

        def wrapped(*args, **kwargs):
            prof = _ACTIVE_DECODE_PROFILER
            if (prof is None) or prof.disabled or (not prof._active_measure) or (not prof._run_ge_warmup) or (not prof.cfg.backend_profiling):
                return orig(*args, **kwargs)
            t0 = time.perf_counter()
            try:
                return orig(*args, **kwargs)
            finally:
                # Recorded in ``finally`` so a failing call is still accounted for.
                dt_ms = (time.perf_counter() - t0) * 1e3
                prof._iter_events.append(("cpu", dt_ms, None, bucket))

        wrapped._prof_wrapped = True  # type: ignore[attr-defined]
        setattr(obj, method_name, wrapped)

    def _patch_communication_ops(self):
        """Monkey-patch selected ``torch.distributed`` collectives with CUDA-event timers.

        Patches ``all_reduce``, ``reduce_scatter_tensor``, ``all_gather`` and
        ``broadcast`` on the ``torch.distributed`` module (names missing from
        the module are skipped).  Each wrapper calls the original directly
        unless the active profiler is registered, enabled, currently measuring,
        past warmup and has ``cfg.model_profiling`` set; when timing, it appends
        ``("cuda", start_event, end_event, "communication")`` to the profiler's
        ``_iter_events``.

        Does nothing when ``_dist_ready()`` is false.  The patch is applied to
        the shared module object, so it is made idempotent: functions already
        carrying the ``_prof_wrapped`` marker are left alone.
        """
        if not _dist_ready():
            return

        def _wrap(name: str):
            if not hasattr(dist, name):
                return
            orig = getattr(dist, name)
            # guard duplicate wraps: ``dist`` is process-global, so a second
            # call (e.g. from another profiler instance) would otherwise nest
            # wrappers and log every collective more than once.
            if getattr(orig, "_prof_wrapped", False):
                return

            def wrapped(*args, **kwargs):
                prof = _ACTIVE_DECODE_PROFILER
                if (prof is None) or prof.disabled or (not prof._active_measure) or (not prof._run_ge_warmup) or (not prof.cfg.model_profiling):
                    return orig(*args, **kwargs)
                s = torch.cuda.Event(enable_timing=True)
                e = torch.cuda.Event(enable_timing=True)
                s.record()
                out = orig(*args, **kwargs)
                e.record()
                prof._iter_events.append(("cuda", s, e, "communication"))
                return out

            wrapped._prof_wrapped = True  # type: ignore[attr-defined]
            setattr(dist, name, wrapped)

        for n in ("all_reduce", "reduce_scatter_tensor", "all_gather", "broadcast"):
            _wrap(n)

    # Pretty printing
    def _print_run_summary(self, pack: Dict[str, Any]) -> None:
        """Print a short console digest of one run's summary ``pack``.

        Emits a headline (batch size, step count, mean latency, tok/s).  When
        bucket averages are present it then prints, for the model and backend
        groups, the five largest buckets with their share.  The share
        denominator is ``mean_ms`` when that is a finite positive number,
        otherwise the sum of the group's own values.
        """
        m = pack.get("meta") or {}
        s = pack.get("stats") or {}
        bm = pack.get("buckets_model_avg_ms") or {}
        bb = pack.get("buckets_backend_avg_ms") or {}

        steps = int(s.get("count") or 0)
        mean = s.get("mean_ms", float("nan"))
        tok_s = s.get("throughput_tok_s", float("nan"))

        # Headline for the run.
        bsz_label = m.get('bsz', '?')
        print(f"[Profiler] run (bsz={bsz_label} num decoding steps={steps}): mean={_fmt(mean)} ms, tok/s={_fmt(tok_s)}")

        # Nothing more to show when neither breakdown was collected.
        if not (bm or bb):
            print("[Profiler] (no bucket breakdown enabled)")
            return

        # Shared denominator for shares; None means "use the group's own sum".
        denom = None
        if isinstance(mean, (int, float)):
            mean_f = float(mean)
            if math.isfinite(mean_f) and mean_f > 0:
                denom = mean_f

        def _top_brief(d: Dict[str, float], label: str):
            if not d:
                return
            if denom is not None:
                total = denom
            else:
                total = sum(float(v) for v in d.values()) or 1.0
            ranked = sorted(d.items(), key=lambda kv: kv[1], reverse=True)
            parts = []
            for name, val in ranked[:5]:
                ms = float(val)
                parts.append(f"{name}:{ms:.2f}ms({(ms/total*100.0):.0f}%)")
            brief = ", ".join(parts)
            print(f"[Profiler] {label} top: {brief}")

        for label, buckets in (("model", bm), ("backend", bb)):
            _top_brief(buckets, label)