from __future__ import annotations
import os, csv, json, math, time, functools
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from contextlib import contextmanager
from datetime import datetime

import torch
import torch.nn as nn
import torch.distributed as dist

# Profiler consulted by the module-level compute timers below; None means "not profiling".
_ACTIVE_DECODE_PROFILER: "Optional[Profiler]" = None

def register_active_decode_profiler(prof) -> None:
    """Make *prof* the profiler that the module-level compute timers report to.

    Pass ``None`` to detach; the timers look this global up each time they are
    entered and do nothing while it is ``None``.
    """
    global _ACTIVE_DECODE_PROFILER
    _ACTIVE_DECODE_PROFILER = prof

@contextmanager
def _bucket_timer(bucket: str):
    """Time the ``with`` body on the GPU and file the result under *bucket*.

    Shared implementation of the module-level compute timers. CUDA events are
    recorded only when the registered profiler (``_ACTIVE_DECODE_PROFILER``)
    exists, is not disabled, is inside a ``time_decode`` region
    (``_active_measure``) and has ``cfg.detailed_profiling`` enabled; otherwise
    the body simply runs. The ``(start, end, bucket)`` triple is appended to the
    profiler's ``_iter_events`` even if the body raises, so the end event is
    always recorded before it is handed over.
    """
    prof = _ACTIVE_DECODE_PROFILER
    measuring = (
        prof is not None
        and not prof.disabled
        and prof._active_measure
        and prof.cfg.detailed_profiling
    )
    if not measuring:
        yield
        return
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    try:
        yield
    finally:
        end.record()
        prof._iter_events.append((start, end, bucket))

def attention_compute_timer():
    """Context manager attributing the enclosed code's GPU time to the ``Attn.compute`` bucket."""
    return _bucket_timer("Attn.compute")

def rope_compute_timer():
    """Context manager attributing the enclosed code's GPU time to the ``RoPE`` bucket."""
    return _bucket_timer("RoPE")

def _dist_ready() -> bool:
    """Return True only when torch.distributed is built in AND a process group is initialized."""
    if not dist.is_available():
        # Builds without distributed support: don't even query initialization.
        return False
    return dist.is_initialized()

def _rank_world() -> tuple[int, int]:
    """Return ``(rank, world_size)``, falling back to ``(0, 1)`` without a process group."""
    if not _dist_ready():
        return 0, 1
    return dist.get_rank(), dist.get_world_size()

def _mkdir(p: str) -> None:
    """Create directory *p* together with any missing parents; an existing directory is fine."""
    os.makedirs(name=p, exist_ok=True)

def _pct(sorted_vals: List[float], q: float) -> float:
    """Return the *q*-th percentile of an ascending list using the nearest-rank method.

    The 1-based rank is ``ceil(q * n / 100)`` clamped to ``[1, n]``, so
    ``q <= 0`` yields the minimum and ``q >= 100`` the maximum. Returns NaN
    for an empty list. *sorted_vals* must already be sorted ascending; this is
    not checked.
    """
    n = len(sorted_vals)
    if n == 0:
        return float("nan")
    # Multiply before dividing: q * n is exact for integral inputs, whereas q / 100
    # is usually inexact (0.07 * 100 == 7.000000000000001), which would push
    # ceil() one rank too high.
    rank = math.ceil(q * n / 100.0)
    return sorted_vals[min(max(rank, 1), n) - 1]

def _fmt(x: Any) -> Any:
    """Format *x* for reports as a 3-decimal string when it is numeric.

    Anything ``float()`` accepts becomes e.g. ``"1.500"``; ``None`` becomes
    ``""``; any other value is returned unchanged (best effort, never raises
    for unconvertible input).
    """
    if x is None:
        return ""
    try:
        return f"{float(x):.3f}"
    except Exception:
        # Deliberately broad (best-effort formatting), but unlike a bare
        # ``except:`` this no longer swallows KeyboardInterrupt / SystemExit.
        return x

def _sort_key(run: Dict[str, Any]) -> Tuple[int, int, int]:
    """Sort key ordering runs by (bsz, seqlen, declen) from ``run["meta"]``; missing fields count as 0."""
    meta = run["meta"]
    bsz, seqlen, declen = (int(meta.get(field, 0)) for field in ("bsz", "seqlen", "declen"))
    return bsz, seqlen, declen

@dataclass
class ProfilerConfig:
    """Settings controlling what a ``Profiler`` collects and where it writes results."""

    # Directory that save_all() writes its reports into; created when the Profiler is built.
    output_dir: str = "profiler_out"
    # When True, every rank other than 0 gets a disabled (no-op) Profiler.
    collect_on_rank0_only: bool = True
    # Call torch.cuda.synchronize() at the end of every timed decode iteration.
    strict_sync: bool = True
    # Enable the per-bucket breakdown (module/method wrapping, collective patching, compute timers).
    detailed_profiling: bool = False
    # Print a summary after each end_run().
    print_per_run: bool = True

class Profiler:
    """Decode-latency profiler with an optional per-component time breakdown.

    Intended call sequence (per the public methods below):
    1. Optionally call ``attach_model(model)``.
    2. For each configuration, call ``begin_run(...)``, then run one or more
       ``with profiler.time_decode(): ...`` blocks, then call ``end_run()``.
    3. Finally call ``save_all()`` to write the JSON / CSV / Markdown reports.

    When ``cfg.detailed_profiling`` is on, ``attach_model`` does two things:
    - it wraps matching submodule ``forward`` methods and fused helper methods
      so they record CUDA event pairs into named buckets;
    - it monkey-patches several ``torch.distributed`` collectives so their
      duration lands in the ``communication`` bucket.

    Time not attributed to any bucket is reported as ``Others``.
    """

    # Buckets that measured time can be attributed to, in report order.
    _BUCKETS: Tuple[str, ...] = (
        "embed", "Attn.qkv_proj", "Attn.compute", "Attn.out_proj",
        "mlp.gate_up_proj", "mlp.down_proj", "Norm", "RoPE", "lm_head", "communication",
    )
    # Columns shown in every report: all buckets plus the unattributed remainder.
    _REPORT_BUCKETS: Tuple[str, ...] = _BUCKETS + ("Others",)

    def __init__(self, cfg: Optional["ProfilerConfig"] = None):
        self.cfg = cfg or ProfilerConfig()
        _mkdir(self.cfg.output_dir)
        self.rank, self.world = _rank_world()
        # With collect_on_rank0_only, every non-zero rank becomes a no-op profiler.
        self.disabled = (self.cfg.collect_on_rank0_only and self.rank != 0)

        self._active_measure: bool = False  # True only while inside time_decode()
        self._current_meta: Dict[str, Any] = {}
        self._iters_elapsed_ms: List[float] = []
        # (start, end, bucket) CUDA event pairs recorded during the current iteration.
        self._iter_events: List[Tuple[torch.cuda.Event, torch.cuda.Event, str]] = []
        # Per-run accumulated milliseconds for each bucket.
        self._accum: Dict[str, float] = {k: 0.0 for k in self._BUCKETS}
        self._runs: List[Dict[str, Any]] = []

    def attach_model(self, model: nn.Module) -> None:
        """Instrument *model* so that its sub-layers report into the timing buckets.

        Only acts when the profiler is enabled and ``cfg.detailed_profiling`` is on.
        Modules are classified by dotted name and class name (``_classify_module``).
        Once a module is wrapped, its descendants are skipped, so nested modules
        (e.g. a LoRA/quantisation wrapper and its inner Linear) are never counted
        twice; ``end_run`` assumes the buckets do not overlap when deriving ``Others``.
        Also wraps known fused helper methods and patches the collectives.
        """
        if self.disabled or not self.cfg.detailed_profiling:
            if self.rank == 0 and not self.disabled:
                print("[Profiler] model attach skipped (detailed_profiling=False).")
            return

        wrapped_prefixes: List[str] = []
        for name, m in model.named_modules():
            # named_modules() yields a parent before its children, so an
            # instrumented ancestor is always registered before we get here.
            if any(p == "" or name.startswith(p + ".") for p in wrapped_prefixes):
                continue
            bucket = self._classify_module(name, m)
            if bucket is None:
                continue
            self._wrap_module_forward(m, bucket)
            wrapped_prefixes.append(name)

        for mod in model.modules():
            if callable(getattr(mod, "_fused_qkv_proj", None)):
                self._wrap_bound_method(mod, "_fused_qkv_proj", "Attn.qkv_proj")
            for meth in ("_fused_gate_up_silu_mul", "_fused_gate_up_proj"):
                if callable(getattr(mod, meth, None)):
                    self._wrap_bound_method(mod, meth, "mlp.gate_up_proj")

        self._patch_communication_ops()
        if self.rank == 0:
            print("[Profiler] model attached (detailed).")

    @staticmethod
    def _classify_module(name: str, module: nn.Module) -> Optional[str]:
        """Return the bucket for *module* (found at dotted path *name*), or None to leave it alone."""
        lname = name.lower()
        leaf = lname.rsplit(".", 1)[-1]
        cls = module.__class__.__name__.lower()

        # HF / Llama-style naming. Rotary "embeddings" (e.g. *RotaryEmbedding)
        # compute RoPE tables, not token embeddings, so they are excluded here.
        if lname.endswith("embed_tokens") or ("embedding" in cls and "rotary" not in cls):
            return "embed"
        if "lm_head" in lname:
            return "lm_head"
        if ("o_proj" in lname) or ("out_proj" in lname):
            return "Attn.out_proj"
        if "down_proj" in lname:
            return "mlp.down_proj"
        if any(s in lname for s in ("q_proj", "k_proj", "v_proj")):
            return "Attn.qkv_proj"
        if ("rmsnorm" in cls) or isinstance(module, nn.LayerNorm):
            return "Norm"

        # gpt-fast-style short names, matched on the last path component only so
        # unrelated paths merely containing e.g. "wo" or "output" are not captured.
        if "wqkv" in leaf:
            return "Attn.qkv_proj"
        if leaf == "wo":
            return "Attn.out_proj"
        if "w13" in leaf:
            return "mlp.gate_up_proj"
        if leaf == "w2":
            return "mlp.down_proj"
        if leaf == "output":
            return "lm_head"
        return None

    def begin_run(self, *, bsz: int, declen: int, seqlen: int, label: str = "decode") -> None:
        """Start a new run: record its metadata and reset per-run timings and bucket totals."""
        if self.disabled:
            return
        self._current_meta = {
            "label": label, "bsz": int(bsz), "declen": int(declen), "seqlen": int(seqlen),
            "rank": self.rank, "world": self.world,
            "started_at": datetime.now().isoformat(timespec="seconds"),
        }
        self._iters_elapsed_ms.clear()
        self._iter_events.clear()
        for k in self._accum:
            self._accum[k] = 0.0

    @contextmanager
    def time_decode(self):
        """Time one decode iteration (the ``with`` body) with CUDA events.

        Appends the iteration's elapsed milliseconds to the current run. When
        ``cfg.detailed_profiling`` is on, it also folds the bucket events
        recorded during the body into the per-run totals. No-op when disabled.
        """
        if self.disabled:
            yield
            return
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()  # align to a clean start for the iteration
        self._active_measure = True
        start.record()
        try:
            yield
        finally:
            end.record()
            # Stop attributing work right away, even if a synchronisation below raises.
            self._active_measure = False
            if self.cfg.strict_sync:
                torch.cuda.synchronize()  # ensure all streams complete if needed
            end.synchronize()

            self._iters_elapsed_ms.append(float(start.elapsed_time(end)))
            if self.cfg.detailed_profiling:
                for s_ev, e_ev, bucket in self._iter_events:
                    dt = float(s_ev.elapsed_time(e_ev))
                    self._accum[bucket] = self._accum.get(bucket, 0.0) + dt
            self._iter_events.clear()

    def end_run(self) -> None:
        """Summarise the iterations timed since ``begin_run`` and store the result.

        Computes count/mean/min/max/p50/p90/p99 latency and a throughput figure.
        With detailed profiling it also computes the average ms per bucket:
        - ``Others`` is the non-negative unattributed remainder;
        - ``recon_error_ms`` is the mean minus the sum of all reported buckets.
          It goes negative when bucket times overlap, for instance when a
          patched collective runs inside an instrumented module.
        """
        if self.disabled:
            return
        vals = sorted(self._iters_elapsed_ms)
        n = len(vals)
        total_ms = sum(vals)
        mean = total_ms / n if n > 0 else 0.0
        stats: Dict[str, Any] = {"count": n}
        if n > 0:
            # NOTE(review): assumes one timed iteration produces bsz * declen tokens;
            # confirm against what callers wrap in time_decode().
            tokens = self._current_meta["bsz"] * self._current_meta["declen"]
            stats.update({
                "mean_ms": mean, "min_ms": vals[0], "max_ms": vals[-1],
                "p50_ms": _pct(vals, 50), "p90_ms": _pct(vals, 90), "p99_ms": _pct(vals, 99),
                "throughput_tok_s": tokens / (mean / 1000.0) if mean > 0 else float("nan"),
            })

        if self.cfg.detailed_profiling:
            known_sum = sum(self._accum[k] for k in self._BUCKETS)
            buckets_avg = {k: (v / n if n > 0 else 0.0) for k, v in self._accum.items()}
            buckets_avg["Others"] = max(0.0, total_ms - known_sum) / n if n > 0 else 0.0
            stats["recon_error_ms"] = mean - sum(buckets_avg[k] for k in self._REPORT_BUCKETS)
        else:
            buckets_avg = {k: 0.0 for k in self._accum}
            buckets_avg["Others"] = mean
            stats["recon_error_ms"] = 0.0

        pack = {"meta": dict(self._current_meta), "stats": stats, "buckets_avg_ms": buckets_avg}
        self._runs.append(pack)

        if self.cfg.print_per_run:
            self._print_run_summary(pack)

    def save_all(self) -> None:
        """Write every collected run to summary JSON/CSV and a Markdown report.

        Runs are sorted by (bsz, seqlen, declen) and written to
        ``cfg.output_dir``. Ranks other than 0 (only possible with
        ``collect_on_rank0_only=False``) get a ``_rank<N>`` filename suffix, so
        they never overwrite rank 0's files. Disabled profilers, and profilers
        with no runs, write nothing.
        """
        if self.disabled or not self._runs:
            # The return must not depend on rank: previously non-zero ranks fell
            # through here and clobbered rank 0's reports with empty ones.
            if self.rank == 0:
                print("[Profiler] No results to save.")
            return
        out = self.cfg.output_dir
        suffix = "" if self.rank == 0 else f"_rank{self.rank}"

        runs_sorted = sorted(self._runs, key=_sort_key)

        jpath = os.path.join(out, f"summary{suffix}.json")
        with open(jpath, "w") as f:
            json.dump(runs_sorted, f, indent=2)

        cpath = os.path.join(out, f"summary{suffix}.csv")
        stat_cols = ("mean_ms", "p50_ms", "p90_ms", "p99_ms", "throughput_tok_s", "recon_error_ms")
        with open(cpath, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["bsz", "seqlen", "declen", "iters", "mean_ms", "p50_ms", "p90_ms", "p99_ms",
                        "tok_per_s", "recon_error_ms", *self._REPORT_BUCKETS])
            for r in runs_sorted:
                m, s, b = r["meta"], r["stats"], r["buckets_avg_ms"]
                w.writerow([
                    m["bsz"], m["seqlen"], m["declen"], s.get("count", 0),
                    *(_fmt(s.get(k)) for k in stat_cols),
                    *(_fmt(b.get(k)) for k in self._REPORT_BUCKETS),
                ])

        mpath = os.path.join(out, f"report{suffix}.md")
        with open(mpath, "w") as f:
            f.write("# Decode Profiling Report (avg ms/iter)\n\n")
            f.write(f"- Generated at: {datetime.now().isoformat(timespec='seconds')}\n\n")
            for r in runs_sorted:
                m, s, b = r["meta"], r["stats"], r["buckets_avg_ms"]
                f.write(f"## bsz={m['bsz']}  seqlen={m['seqlen']}  declen={m['declen']}\n\n")
                f.write(f"- iters: **{s.get('count',0)}**, mean: **{_fmt(s.get('mean_ms'))} ms**, "
                        f"p50: {_fmt(s.get('p50_ms'))}, p90: {_fmt(s.get('p90_ms'))}, "
                        f"p99: {_fmt(s.get('p99_ms'))}, tok/s: **{_fmt(s.get('throughput_tok_s'))}**, "
                        f"recon_error: **{_fmt(s.get('recon_error_ms'))} ms**\n\n")
                if self.cfg.detailed_profiling:
                    f.write("| bucket | avg ms/iter | share |\n|--:|--:|--:|\n")
                    total_avg = float(s.get("mean_ms") or 0.0)
                    for key in self._REPORT_BUCKETS:
                        v = float(b.get(key, 0.0))
                        pct = (v / total_avg * 100.0) if total_avg > 0 else 0.0
                        f.write(f"| `{key}` | {v:.3f} | {pct:.1f}% |\n")
                    f.write("\n")

        if self.rank == 0:
            print(f"[Profiler] Saved:\n  {jpath}\n  {cpath}\n  {mpath}")
            print("[Profiler] Sorted overview (bsz, seqlen, declen):")
            for r in runs_sorted:
                m, s = r["meta"], r["stats"]
                print(f"  bsz={m['bsz']:>3}  seqlen={m['seqlen']:>6}  declen={m['declen']:>4}  "
                      f"mean={_fmt(s.get('mean_ms'))} ms  p90={_fmt(s.get('p90_ms'))} ms  p99={_fmt(s.get('p99_ms'))} ms")

    def _wrap_module_forward(self, module: nn.Module, bucket: str):
        """Replace ``module.forward`` with a version that records a CUDA event pair into *bucket*.

        Idempotent per module (guarded by ``__prof_wrapped_forward__``). Events are
        only recorded while a ``time_decode`` region is open with detailed
        profiling on; otherwise the original forward is called directly.
        """
        if not hasattr(module, "forward") or getattr(module, "__prof_wrapped_forward__", False):
            return
        orig = module.forward

        @functools.wraps(orig)
        def wrapped(*args, **kwargs):
            if not (self._active_measure and self.cfg.detailed_profiling):
                return orig(*args, **kwargs)
            s = torch.cuda.Event(enable_timing=True)
            e = torch.cuda.Event(enable_timing=True)
            s.record()
            out = orig(*args, **kwargs)
            e.record()
            self._iter_events.append((s, e, bucket))
            return out

        module.forward = wrapped  # type: ignore
        module.__prof_wrapped_forward__ = True

    def _wrap_bound_method(self, obj: Any, method_name: str, bucket: str):
        """Shadow ``obj.<method_name>`` with an instance attribute that times calls into *bucket*.

        No-op when the attribute is not callable or has already been wrapped.
        """
        orig = getattr(obj, method_name, None)
        if not callable(orig) or getattr(orig, "__prof_wrapped__", False):
            return

        @functools.wraps(orig)
        def wrapped(*args, **kwargs):
            if not (self._active_measure and self.cfg.detailed_profiling):
                return orig(*args, **kwargs)
            s = torch.cuda.Event(enable_timing=True)
            e = torch.cuda.Event(enable_timing=True)
            s.record()
            out = orig(*args, **kwargs)
            e.record()
            self._iter_events.append((s, e, bucket))
            return out

        wrapped.__prof_wrapped__ = True
        setattr(obj, method_name, wrapped)

    def _patch_communication_ops(self):
        """Patch the listed ``torch.distributed`` collectives that exist in this build (detailed mode only)."""
        if not self.cfg.detailed_profiling:
            return
        for name in ("all_reduce", "all_gather", "reduce_scatter", "reduce_scatter_tensor",
                     "all_to_all", "all_to_all_single", "broadcast"):
            if callable(getattr(dist, name, None)):
                self._wrap_comm_fn(name)

    def _wrap_comm_fn(self, name: str):
        """Monkey-patch ``torch.distributed.<name>`` to add its duration to the ``communication`` bucket.

        The duration is host wall-clock time between two ``torch.cuda.synchronize()``
        calls, measured only while a ``time_decode`` region is open. The wrapper is
        always built around the pristine (unpatched) function, so patching again
        (e.g. attaching more than once) replaces the previous wrapper. Otherwise
        wrappers would nest and each collective would be counted several times.
        """
        current = getattr(dist, name)
        orig = getattr(current, "__prof_orig__", current)

        @functools.wraps(orig)
        def wrapped(*args, **kwargs):
            if not self._active_measure:
                return orig(*args, **kwargs)
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            out = orig(*args, **kwargs)
            torch.cuda.synchronize()
            self._accum["communication"] += (time.perf_counter() - t0) * 1000.0
            return out

        wrapped.__prof_orig__ = orig
        setattr(dist, name, wrapped)

    def _print_run_summary(self, pack: Dict[str, Any]) -> None:
        """Print one run's latency summary (plus the bucket breakdown in detailed mode) on rank 0."""
        if self.rank != 0:
            return
        m, s, b = pack["meta"], pack["stats"], pack["buckets_avg_ms"]
        title = f"[Decode] bsz={m['bsz']}  seqlen={m['seqlen']}  declen={m['declen']}"
        line = "-" * max(20, len(title))
        print(line)
        print(title)
        print(f"iters={s.get('count',0)} | mean={_fmt(s.get('mean_ms'))} ms | p50={_fmt(s.get('p50_ms'))} | "
              f"p90={_fmt(s.get('p90_ms'))} | p99={_fmt(s.get('p99_ms'))} | tok/s={_fmt(s.get('throughput_tok_s'))} "
              f"| recon_error={_fmt(s.get('recon_error_ms'))} ms")
        if self.cfg.detailed_profiling:
            total_avg = float(s.get("mean_ms") or 0.0)
            print("breakdown (avg per iter):")
            for key in self._REPORT_BUCKETS:
                v = float(b.get(key, 0.0))
                pct = (v / total_avg * 100.0) if total_avg > 0 else 0.0
                print(f"  {key.ljust(22)} {v:8.3f} ms   {pct:6.1f}%")
        print(line)
