import argparse
import os
import sys
sys.path.append("/home")
import time
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Callable

import torch
import torch.nn as nn


def parse_int_list(csv_or_space: str) -> List[int]:
    """Parse integers separated by commas and/or whitespace, e.g. "1,2 3" -> [1, 2, 3].

    Empty fields (repeated separators, leading/trailing blanks) are skipped; a token that
    is not an integer literal raises ``ValueError`` from ``int``.
    """
    tokens = csv_or_space.replace(",", " ").split()
    return list(map(int, tokens))


class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm fallback: ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``.

    Normalises over the last dimension. The mean of squares is computed in the input's
    own dtype (there is no upcast).
    NOTE(review): native ``nn.RMSNorm`` may upcast half/bf16 internally, so numbers from
    this fallback can differ slightly — confirm if exact parity matters.
    """

    def __init__(self, hidden_dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Learnable per-feature scale, initialised to ones (identity scaling).
        self.weight = nn.Parameter(torch.ones(hidden_dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        inv_rms = (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
        return x * inv_rms * self.weight


def build_rmsnorm(hidden_dim: int, eps: float, device: torch.device, dtype: torch.dtype) -> nn.Module:
    """Create an RMSNorm on ``device``/``dtype``, preferring PyTorch's native ``nn.RMSNorm``.

    Native ``nn.RMSNorm`` exists from PyTorch 2.4 on. When looking it up or constructing it
    raises ``AttributeError``, this falls back to :class:`SimpleRMSNorm`.
    """
    try:
        module: nn.Module = getattr(nn, "RMSNorm")(hidden_dim, eps=eps, elementwise_affine=True)
    except AttributeError:
        module = SimpleRMSNorm(hidden_dim, eps=eps)
    return module.to(device=device, dtype=dtype)


def build_tiled_rmsnorm(hidden_dim: int, eps: float, device: torch.device, dtype: torch.dtype) -> nn.Module:
    """Create torchtitan's ``TiledRMSNorm`` on ``device``/``dtype``.

    Raises:
        RuntimeError: if importing ``TiledRMSNorm`` from ``torchtitan.models.llama3.model``
            fails for any reason; the original exception is chained as ``__cause__``.
    """
    try:
        from torchtitan.models.llama3.model import TiledRMSNorm  # type: ignore
    except Exception as import_err:
        raise RuntimeError("TiledRMSNorm is not available. Ensure torchtitan.models.llama3.model is importable.") from import_err
    return TiledRMSNorm(hidden_dim, eps=eps).to(device=device, dtype=dtype)


def maybe_compile(module: nn.Module, enable: bool, dynamic: bool) -> nn.Module:
    """Optionally wrap ``module`` with ``torch.compile``.

    Returns ``module`` itself when ``enable`` is false. Otherwise compiles with
    ``fullgraph=False``, ``mode="max-autotune"`` and the given ``dynamic`` flag.

    Raises:
        RuntimeError: if compilation is requested but this PyTorch build lacks ``torch.compile``.
    """
    if enable:
        if not hasattr(torch, "compile"):
            raise RuntimeError("torch.compile not available in this PyTorch build.")
        # "max-autotune" favours runtime performance over compile time.
        return torch.compile(module, fullgraph=False, dynamic=dynamic, mode="max-autotune")
    return module


@dataclass
class RunResult:
    """Measurements for one sequence length; one output row of the profiler.

    Memory values are in MB (bytes / 1024**2) and times in milliseconds. ``main``
    initialises every measurement to 0.0, so fields that were not measured (e.g. backward
    without ``--backward``, reserved memory on CPU) stay 0.0.
    """

    # Sequence length L of the (batch_size, L, hidden_dim) input built in main.
    seqlen: int
    # Forward: CUDA peak allocated memory, or process RSS growth on CPU (MB).
    fwd_peak_allocated_mb: float
    # Forward: CUDA peak reserved memory (MB); 0.0 on CPU.
    fwd_peak_reserved_mb: float
    # Forward elapsed time (ms).
    fwd_ms: float
    # Backward: CUDA peak allocated memory, or process RSS growth on CPU (MB).
    bwd_peak_allocated_mb: float
    # Backward: CUDA peak reserved memory (MB); 0.0 on CPU.
    bwd_peak_reserved_mb: float
    # Backward elapsed time (ms).
    bwd_ms: float


def format_mb(bytes_val: int) -> float:
    """Convert a byte count to MB, where 1 MB = 1024 * 1024 bytes (i.e. MiB)."""
    bytes_per_mb = 1024.0 * 1024.0
    return float(bytes_val) / bytes_per_mb


def measure_gpu_peak_memory(func, sync) -> Tuple[float, float, float]:
    """Run ``func`` once and report its CUDA-event time and absolute peak memory.

    The allocator cache is emptied and the peak counters reset beforehand, so the peaks
    reflect this call only. They are absolute values, so tensors already live before the
    call (inputs, parameters) are included.

    Args:
        func: zero-argument callable to measure.
        sync: zero-argument synchronize function (e.g. ``torch.cuda.synchronize``).

    Returns:
        (elapsed_ms, peak_allocated_mb, peak_reserved_mb)
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    sync()

    begin, finish = [torch.cuda.Event(enable_timing=True) for _ in range(2)]
    begin.record()
    func()
    finish.record()
    sync()  # events must have completed before elapsed_time can be read

    return (
        float(begin.elapsed_time(finish)),
        format_mb(torch.cuda.max_memory_allocated()),
        format_mb(torch.cuda.max_memory_reserved()),
    )


def measure_forward_gpu(call_norm: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, sync) -> Tuple[float, float, float]:
    """Measure a single forward pass ``call_norm(x)`` with autograd disabled.

    Delegates timing and peak-memory collection to ``measure_gpu_peak_memory``.

    Returns:
        (elapsed_ms, peak_allocated_mb, peak_reserved_mb)
    """
    def _forward_only() -> None:
        # torch.no_grad(): no autograd graph is recorded for this forward.
        with torch.no_grad():
            _ = call_norm(x)

    return measure_gpu_peak_memory(_forward_only, sync)


def measure_backward_gpu(call_norm: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, sync) -> Tuple[float, float, float]:
    """Time one backward pass through ``call_norm`` and report CUDA peak memory during it.

    The forward pass and scalar loss are built outside the timed window. Only
    ``loss.backward()`` is timed, using CUDA events recorded on the current stream. Peak
    counters are reset just before backward, so the peaks are absolute values. They include
    everything already live at that point: ``x``, the output, anything autograd saved, and
    the parameters.

    Args:
        call_norm: applies the norm to ``x``; must record an autograd graph.
        x: leaf input with ``requires_grad=True``. Its ``.grad`` is reset to ``None`` first,
            so the measurement does not depend on earlier calls.
        sync: zero-argument device synchronize function.

    Returns:
        (elapsed_ms, peak_allocated_mb, peak_reserved_mb)
    """
    # A grad left from an earlier call would sit in the measured baseline and make this
    # backward accumulate into it, so repeated measurements would differ from the first one.
    x.grad = None

    # Build the graph first (not measured).
    out = call_norm(x)
    # Clone to avoid cudagraph output being overwritten by subsequent runs
    out = out.clone()
    loss = out.sum() / out.numel()

    # Prepare the counters the same way measure_gpu_peak_memory does: release cached blocks,
    # then reset peaks. Forward and backward reserved-memory numbers are then comparable.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    sync()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    loss.backward()
    end_event.record()
    sync()
    elapsed_ms = float(start_event.elapsed_time(end_event))
    peak_allocated_mb = format_mb(torch.cuda.max_memory_allocated())
    peak_reserved_mb = format_mb(torch.cuda.max_memory_reserved())
    return elapsed_ms, peak_allocated_mb, peak_reserved_mb


# Column names of the per-seqlen result table (printed and written to CSV).
_RESULT_COLUMNS = [
    "seqlen",
    "fwd_peak_allocated_MB",
    "fwd_peak_reserved_MB",
    "fwd_ms",
    "bwd_peak_allocated_MB",
    "bwd_peak_reserved_MB",
    "bwd_ms",
]


def _format_result_row(r: RunResult) -> List[str]:
    """Render one result row: seqlen as-is, MB with 2 decimals, ms with 3 decimals."""
    return [
        str(r.seqlen),
        f"{r.fwd_peak_allocated_mb:.2f}",
        f"{r.fwd_peak_reserved_mb:.2f}",
        f"{r.fwd_ms:.3f}",
        f"{r.bwd_peak_allocated_mb:.2f}",
        f"{r.bwd_peak_reserved_mb:.2f}",
        f"{r.bwd_ms:.3f}",
    ]


def _parse_args() -> argparse.Namespace:
    """Define and parse the profiler's command-line interface."""
    parser = argparse.ArgumentParser(description="Profile RMSNorm peak memory vs sequence length")
    parser.add_argument("--seqlens", type=str, default="131072 65536 32768 16384 8192 4096 2048 1024", help="Space- or comma-separated sequence lengths")
    parser.add_argument("--hidden-dim", type=int, default=4096)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--eps", type=float, default=1e-6)
    # default=None means "choose per device". It is not something a user can type, so it is
    # not listed in choices; listing it would only show "{None,...}" in --help.
    parser.add_argument("--dtype", type=str, default=None, choices=["bf16", "fp16", "fp32"], help="Computation dtype. Default: bf16 on CUDA else fp32")
    parser.add_argument("--device", type=str, default=None, choices=["cuda", "cpu"], help="Target device. Default: cuda if available else cpu")
    parser.add_argument("--norm", type=str, default="rmsnorm", choices=["rmsnorm", "tiled"], help="Norm implementation to profile")
    parser.add_argument("--compile", action="store_true", help="Enable torch.compile for RMSNorm/TiledRMSNorm")
    parser.add_argument("--dynamic", action="store_true", help="Enable dynamic shapes for torch.compile")
    parser.add_argument("--backward", action="store_true", help="Include backward pass in measurement")
    parser.add_argument("--warmup", type=int, default=1, help="Warmup iterations per seqlen (not measured)")
    parser.add_argument("--repeats", type=int, default=1, help="Measured repeats per seqlen; we report peak across repeats")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--csv", type=str, default=None, help="Optional path to write CSV results")
    return parser.parse_args()


def main() -> None:
    """Profile an RMSNorm's time and peak memory across sequence lengths.

    For each requested length L, a ``(batch_size, L, hidden_dim)`` input is built and
    warmed up. The forward pass is then measured under ``torch.no_grad``; with
    ``--backward``, the backward pass is measured too. The maximum over ``--repeats``
    runs is reported.

    On CUDA, memory columns are the caching allocator's absolute peaks, so they include
    every tensor live at the time (the input, the parameters). On CPU they are the growth
    of process RSS across the call. That requires ``psutil``; without it only timings are
    reported and the memory columns stay 0.0. Results are printed and, with ``--csv``,
    written to a file.
    """
    args = _parse_args()
    torch.manual_seed(args.seed)

    if args.device is None:
        device_str = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device_str = args.device
    device = torch.device(device_str)
    on_cuda = device_str == "cuda"

    if args.dtype is None:
        dtype = torch.bfloat16 if on_cuda else torch.float32
    else:
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.dtype]
    dtype_name = str(dtype).replace("torch.", "")

    sync = torch.cuda.synchronize if on_cuda else (lambda: None)  # no-op on CPU
    seqlens: List[int] = parse_int_list(args.seqlens)
    repeats = max(1, args.repeats)

    # Build the module and optionally compile it.
    if args.norm == "tiled":
        norm = build_tiled_rmsnorm(args.hidden_dim, args.eps, device, dtype)
    else:
        norm = build_rmsnorm(args.hidden_dim, args.eps, device, dtype)
    norm = maybe_compile(norm, args.compile, args.dynamic)
    norm.train(args.backward)

    mark_cudagraph_step = (
        args.compile
        and on_cuda
        and hasattr(torch, "compiler")
        and hasattr(torch.compiler, "cudagraph_mark_step_begin")
    )

    def call_norm(x: torch.Tensor) -> torch.Tensor:
        # Compiled with mode="max-autotune", the module may run through CUDA graphs, so a new
        # step is marked before each call. This is best-effort: failures are ignored.
        if mark_cudagraph_step:
            try:
                torch.compiler.cudagraph_mark_step_begin()  # type: ignore[attr-defined]
            except Exception:
                pass
        return norm(x)

    def make_input(seqlen: int) -> torch.Tensor:
        return torch.randn(args.batch_size, seqlen, args.hidden_dim, device=device, dtype=dtype, requires_grad=args.backward)

    def clear_grads(*tensors: torch.Tensor) -> None:
        # Drop parameter grads and the given tensors' grads so they are not part of the
        # baseline of the next measurement.
        norm.zero_grad(set_to_none=True)
        for t in tensors:
            t.grad = None

    def run_once(x: torch.Tensor, do_backward: bool) -> None:
        if do_backward:
            out = call_norm(x)
            loss = out.sum() / out.numel()
            loss.backward()
        else:
            with torch.no_grad():
                _ = call_norm(x)

    def forward_no_grad(x: torch.Tensor) -> torch.Tensor:
        # Same autograd mode as the CUDA forward measurement (measure_forward_gpu).
        with torch.no_grad():
            return call_norm(x)

    # CPU memory is approximated from process RSS; psutil is optional.
    proc = None
    if not on_cuda:
        try:
            import psutil  # type: ignore
        except ImportError:
            pass  # timings only; memory columns stay 0.0
        else:
            proc = psutil.Process(os.getpid())

    def timed_cpu(fn: Callable[..., object], *fn_args: object) -> Tuple[float, float]:
        """Run fn(*fn_args) once; return (wall-clock ms, RSS growth in MB clamped at 0)."""
        rss_before = proc.memory_info().rss if proc is not None else 0
        t0 = time.perf_counter()
        result = fn(*fn_args)
        t1 = time.perf_counter()
        # Sample RSS while the result is still alive: it is part of the call's footprint.
        rss_after = proc.memory_info().rss if proc is not None else 0
        del result
        return (t1 - t0) * 1000.0, max(0.0, format_mb(rss_after - rss_before))

    # Warm up compilation at the largest shape. Without --dynamic, torch.compile specializes
    # on shape, so every other length is still compiled during its own per-length warmup.
    if args.compile and on_cuda and seqlens:
        warmup_x = make_input(max(seqlens))
        for _ in range(max(1, args.warmup)):
            run_once(warmup_x, args.backward)
            sync()
        # Free the largest-shape input (and its grad) now. Kept alive, it would inflate every
        # peak-memory reading below.
        del warmup_x
        clear_grads()
        torch.cuda.empty_cache()

    results: List[RunResult] = []
    for seqlen in seqlens:
        x = make_input(seqlen)

        # Per-length warmup (not measured).
        for _ in range(max(0, args.warmup)):
            run_once(x, args.backward)
            sync()
        # Warmup backward passes leave grads on x and on the parameters. Drop them so the
        # forward measurement's baseline holds only x and the module state.
        clear_grads(x)

        fwd_ms = fwd_alloc_mb = fwd_rsrv_mb = 0.0
        bwd_ms = bwd_alloc_mb = bwd_rsrv_mb = 0.0

        for _ in range(repeats):
            if on_cuda:
                e_ms, p_alloc, p_rsrv = measure_forward_gpu(call_norm, x, sync)
            else:
                e_ms, p_alloc = timed_cpu(forward_no_grad, x)
                p_rsrv = 0.0  # no reserved-memory counter on CPU
            fwd_ms = max(fwd_ms, e_ms)
            fwd_alloc_mb = max(fwd_alloc_mb, p_alloc)
            fwd_rsrv_mb = max(fwd_rsrv_mb, p_rsrv)

        if args.backward:
            # x already requires grad (make_input uses requires_grad=args.backward), and the
            # no-grad forward above recorded no graph. Reusing it avoids keeping a second
            # full-size input alive during the backward measurement.
            for _ in range(repeats):
                # Start every repeat from the same state. Otherwise repeats 2+ would also
                # count the grads accumulated by the previous repeat.
                clear_grads(x)
                if on_cuda:
                    e_ms, p_alloc, p_rsrv = measure_backward_gpu(call_norm, x, sync)
                else:
                    out = call_norm(x)
                    loss = out.sum() / out.numel()
                    e_ms, p_alloc = timed_cpu(loss.backward)
                    p_rsrv = 0.0
                    del out, loss
                bwd_ms = max(bwd_ms, e_ms)
                bwd_alloc_mb = max(bwd_alloc_mb, p_alloc)
                bwd_rsrv_mb = max(bwd_rsrv_mb, p_rsrv)

        # Release this length's tensors and grads before moving on, so none of them is
        # counted in the next length's peaks.
        clear_grads(x)
        del x
        if on_cuda:
            torch.cuda.empty_cache()
            sync()

        results.append(
            RunResult(
                seqlen,
                fwd_alloc_mb,
                fwd_rsrv_mb,
                fwd_ms,
                bwd_alloc_mb,
                bwd_rsrv_mb,
                bwd_ms,
            )
        )

    # Print results.
    print(
        f"device={device_str} dtype={dtype_name} hidden_dim={args.hidden_dim} batch_size={args.batch_size} "
        f"norm={args.norm} compile={args.compile} dynamic={args.dynamic} backward={args.backward}"
    )
    print(", ".join(_RESULT_COLUMNS))
    for r in results:
        print(", ".join(_format_result_row(r)))

    # Optional CSV.
    if args.csv:
        import csv

        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            # "norm" is the last metadata column so the earlier columns keep their positions.
            writer.writerow(["device", "dtype", "hidden_dim", "batch_size", "compile", "dynamic", "backward", "norm"])
            writer.writerow([device_str, dtype_name, args.hidden_dim, args.batch_size, args.compile, args.dynamic, args.backward, args.norm])
            writer.writerow([])
            writer.writerow(_RESULT_COLUMNS)
            writer.writerows(_format_result_row(r) for r in results)


# Run the profiler only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()


