import argparse
import os
import sys
import time
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def parse_int_list(csv_or_space: str) -> List[int]:
    """Parse integers separated by commas and/or whitespace.

    Commas are treated exactly like whitespace, so ``"1,2 3"`` and
    ``"1 2 3"`` produce the same list. An empty or all-separator string
    yields ``[]``. A token that is not a valid integer literal makes
    ``int`` raise ``ValueError``.
    """
    # Turn every comma into a space, then let split() collapse the runs.
    tokens = " ".join(csv_or_space.split(",")).split()
    return list(map(int, tokens))


class SimpleFFN(nn.Module):
    """Bias-free two-layer feed-forward block computing ``w2(act(w1(x)))``.

    ``w1`` projects from ``input_dim`` to ``hidden_dim`` and ``w2`` projects
    back to ``input_dim``. ``activation`` must be one of ``"silu"``,
    ``"relu"`` or ``"gelu"``; any other value raises ``ValueError``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, activation: str = "silu") -> None:
        super().__init__()
        # Layers are created before the activation check, in this order.
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, input_dim, bias=False)
        supported = ("silu", "relu", "gelu")
        if activation in supported:
            self.activation = activation
        else:
            raise ValueError("Unsupported activation: %s" % activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Apply the up-projection, the selected activation, then the down-projection."""
        # Anything that is not "silu" or "relu" falls through to GELU.
        act_fn = {"silu": F.silu, "relu": F.relu}.get(self.activation, F.gelu)
        hidden = act_fn(self.w1(x))
        return self.w2(hidden)


class SwiGLUFFN(nn.Module):
    """Gated feed-forward block in the LLaMA style: ``w2(silu(w1(x)) * w3(x))``.

    ``hidden_dim`` is used exactly as given; this class does no rounding or
    rescaling of the hidden width. All three projections are bias-free.
    """

    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # w1: gate projection, w2: output projection, w3: value projection.
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, input_dim, bias=False)
        self.w3 = nn.Linear(input_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Return the down-projection of the SiLU gate times the value projection."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)


def build_simple_ffn(input_dim: int, hidden_dim: int, device: torch.device, dtype: torch.dtype, gated: bool) -> nn.Module:
    """Build one of the locally defined FFN variants and move it to ``device``/``dtype``.

    Args:
        input_dim: Width of the module's input and output.
        hidden_dim: Width of the intermediate projection.
        device: Device the module is moved to.
        dtype: Dtype the module's parameters are cast to.
        gated: When true build ``SwiGLUFFN``; otherwise build ``SimpleFFN``
            with the SiLU activation.

    Returns:
        The constructed module, after ``.to(device=..., dtype=...)``.
    """
    if gated:
        ffn: nn.Module = SwiGLUFFN(input_dim, hidden_dim)
    else:
        ffn = SimpleFFN(input_dim, hidden_dim, activation="silu")
    return ffn.to(device=device, dtype=dtype)


def build_tiled_ffn(input_dim: int, hidden_dim: int, device: torch.device, dtype: torch.dtype, multiple_of: int = 256, ffn_dim_multiplier: Optional[float] = None) -> nn.Module:
    """Build torchtitan's ``TiledFFN`` and move it to ``device``/``dtype``.

    The class is imported lazily from ``torchtitan.models.llama3.model``. If
    that import fails, ``/home`` is appended to ``sys.path`` (unless already
    present) and the import is retried once. If the retry also fails, the
    ``sys.path`` entry added here is removed again, so a failed call leaves
    the import path as it found it.

    Args:
        input_dim: First positional argument forwarded to ``TiledFFN``.
        hidden_dim: Second positional argument forwarded to ``TiledFFN``.
        device: Device the module is moved to.
        dtype: Dtype the module's parameters are cast to.
        multiple_of: Forwarded unchanged to ``TiledFFN``; its meaning is
            defined by torchtitan (presumably a hidden-size rounding
            granularity — confirm against the torchtitan source).
        ffn_dim_multiplier: Forwarded unchanged to ``TiledFFN``.

    Returns:
        The constructed module, after ``.to(device=..., dtype=...)``.

    Raises:
        RuntimeError: If ``TiledFFN`` cannot be imported even after the retry.
    """
    try:
        from torchtitan.models.llama3.model import TiledFFN as ModelTiledFFN  # type: ignore
    except Exception:
        # Deliberately broad: importing torchtitan can fail with more than
        # ImportError (e.g. errors raised while its modules initialise).
        # Try adding repo root to sys.path if not importable.
        repo_root = "/home"
        path_added = repo_root not in sys.path
        if path_added:
            sys.path.append(repo_root)
        try:
            from torchtitan.models.llama3.model import TiledFFN as ModelTiledFFN  # type: ignore
        except Exception as retry_err:
            # Undo our own sys.path mutation so a failed call has no lasting
            # effect on later imports in this process.
            if path_added and repo_root in sys.path:
                sys.path.remove(repo_root)
            raise RuntimeError("TiledFFN is not available. Ensure torchtitan.models.llama3.model is importable.") from retry_err

    module = ModelTiledFFN(input_dim, hidden_dim, multiple_of, ffn_dim_multiplier)
    return module.to(device=device, dtype=dtype)


@dataclass
class RunResult:
    """Measurements recorded for one sequence length.

    Memory values are in units of 2**20 bytes and times are in milliseconds.
    Each value is the maximum observed across the measured runs for that
    sequence length. Backward fields stay at 0.0 when the backward pass is
    not measured.
    """
    # Sequence length these measurements belong to.
    seqlen: int
    # Forward pass: peak allocated memory. On CUDA this is the allocator's
    # peak; on CPU it is the process RSS growth.
    fwd_peak_allocated_mb: float
    # Forward pass: peak memory reserved by the CUDA allocator (0.0 on CPU).
    fwd_peak_reserved_mb: float
    # Forward pass: elapsed time.
    fwd_ms: float
    # Backward pass: peak allocated memory (same conventions as forward).
    bwd_peak_allocated_mb: float
    # Backward pass: peak memory reserved by the CUDA allocator (0.0 on CPU).
    bwd_peak_reserved_mb: float
    # Backward pass: elapsed time.
    bwd_ms: float


def format_mb(bytes_val: int) -> float:
    """Convert a byte count to megabytes, using 1 MB = 1024 * 1024 bytes."""
    bytes_per_mb = 1024.0 ** 2
    return float(bytes_val) / bytes_per_mb


def measure_gpu_peak_memory(func, sync) -> Tuple[float, float, float]:
    """Time ``func`` on the current CUDA device and report its peak memory.

    Args:
        func: Zero-argument callable that runs the workload to measure.
        sync: Zero-argument callable that waits for outstanding device work.

    Returns:
        ``(elapsed_ms, peak_allocated_mb, peak_reserved_mb)``. The peaks are
        absolute values since the statistics were reset, so memory that is
        already allocated when this function is called is included.
    """
    # Release cached blocks and restart the peak counters so the reported
    # peaks belong to this call only.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    sync()

    timer_begin = torch.cuda.Event(enable_timing=True)
    timer_end = torch.cuda.Event(enable_timing=True)
    timer_begin.record()
    func()
    timer_end.record()
    # Both events must have completed before elapsed_time() can be queried.
    sync()

    return (
        float(timer_begin.elapsed_time(timer_end)),
        format_mb(torch.cuda.max_memory_allocated()),
        format_mb(torch.cuda.max_memory_reserved()),
    )


def measure_forward_gpu(call_mod: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, sync) -> Tuple[float, float, float]:
    """Measure one gradient-free forward pass of ``call_mod`` on ``x``.

    Returns the ``(elapsed_ms, peak_allocated_mb, peak_reserved_mb)`` tuple
    produced by ``measure_gpu_peak_memory``.
    """
    @torch.no_grad()
    def forward_only() -> None:
        # The output is not needed; only its allocation footprint matters.
        _ = call_mod(x)

    return measure_gpu_peak_memory(forward_only, sync)


def measure_backward_gpu(call_mod: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, sync) -> Tuple[float, float, float]:
    """Measure the backward pass of ``call_mod`` on ``x`` on the current CUDA device.

    The forward pass and loss construction run first and are not timed. Only
    ``loss.backward()`` sits between the two timing events.

    Returns:
        ``(elapsed_ms, peak_allocated_mb, peak_reserved_mb)``. The peaks are
        absolute values since the reset, so they include the activations and
        other memory already held when the backward pass starts.
    """
    # Untimed forward pass that builds the autograd graph. The clone keeps
    # the result safe from being overwritten by a later cudagraph replay.
    result = call_mod(x).clone()
    loss = result.sum() / result.numel()

    # Restart peak tracking so the peak reflects the backward phase, on top
    # of the current baseline.
    torch.cuda.reset_peak_memory_stats()
    sync()

    timer_begin = torch.cuda.Event(enable_timing=True)
    timer_end = torch.cuda.Event(enable_timing=True)
    timer_begin.record()
    loss.backward()
    timer_end.record()
    sync()

    return (
        float(timer_begin.elapsed_time(timer_end)),
        format_mb(torch.cuda.max_memory_allocated()),
        format_mb(torch.cuda.max_memory_reserved()),
    )


def maybe_compile(module: nn.Module, enable: bool, dynamic: bool) -> nn.Module:
    """Optionally wrap ``module`` with ``torch.compile``.

    Args:
        module: Module to compile.
        enable: When false, ``module`` is returned untouched.
        dynamic: Forwarded as ``torch.compile``'s ``dynamic`` argument.

    Returns:
        ``module`` itself when compilation is disabled, otherwise the result
        of ``torch.compile`` called with ``fullgraph=False`` and
        ``mode="max-autotune"``.

    Raises:
        RuntimeError: If compilation is requested but this PyTorch build has
            no ``torch.compile``.
    """
    if enable:
        if hasattr(torch, "compile"):
            return torch.compile(module, fullgraph=False, dynamic=dynamic, mode="max-autotune")
        raise RuntimeError("torch.compile not available in this PyTorch build.")
    return module


def _make_rss_reader() -> Callable[[], Optional[int]]:
    """Return a callable giving this process's RSS in bytes, or ``None`` when unavailable."""
    try:
        import psutil  # type: ignore
    except ImportError:
        # psutil is optional; without it only timings are reported on CPU.
        return lambda: None

    try:
        proc = psutil.Process(os.getpid())
    except psutil.Error:
        return lambda: None

    def read_rss() -> Optional[int]:
        try:
            return int(proc.memory_info().rss)
        except psutil.Error:
            return None

    return read_rss


def _timed_cpu_step(step: Callable[[torch.Tensor], object], arg: torch.Tensor, read_rss: Callable[[], Optional[int]]) -> Tuple[float, float]:
    """Run ``step(arg)`` once on CPU and return ``(elapsed_ms, rss_growth_mb)``."""
    rss_before = read_rss()
    t0 = time.perf_counter()
    # Hold the result until the second RSS sample so its memory is counted.
    result = step(arg)
    t1 = time.perf_counter()
    rss_after = read_rss()
    del result
    if rss_before is None or rss_after is None:
        growth_mb = 0.0
    else:
        growth_mb = max(0.0, format_mb(rss_after - rss_before))
    return (t1 - t0) * 1000.0, growth_mb


def main() -> None:
    """Profile FFN peak memory and latency across sequence lengths.

    Parses the command line, builds the requested FFN (optionally compiled),
    and measures a gradient-free forward pass for every requested sequence
    length, plus the backward pass when ``--backward`` is given. On CUDA the
    peaks come from the CUDA allocator statistics. On CPU the process RSS
    growth is used when ``psutil`` is installed; otherwise only timings are
    reported. Each reported value is the maximum over ``--repeats`` runs.

    Results are printed as comma-separated rows and, when ``--csv`` is
    given, also written to that file.

    Tensors used only for warmup or for an earlier sequence length are
    released before the next measurement, because the reported peaks are
    absolute and would otherwise include that stale memory.
    """
    parser = argparse.ArgumentParser(description="Profile FFN peak memory vs sequence length")
    parser.add_argument("--seqlens", type=str, default="131072 65536 32768 16384 8192 4096 2048 1024", help="Space- or comma-separated sequence lengths")
    parser.add_argument("--in-dim", type=int, default=4096)
    parser.add_argument("--hidden-dim", type=int, default=14336)
    parser.add_argument("--batch-size", type=int, default=1)
    # None is only the "not given" default; it is not a selectable choice.
    parser.add_argument("--dtype", type=str, default=None, choices=["bf16", "fp16", "fp32"], help="Computation dtype. Default: bf16 on CUDA else fp32")
    parser.add_argument("--device", type=str, default=None, choices=["cuda", "cpu"], help="Target device. Default: cuda if available else cpu")
    parser.add_argument("--impl", type=str, default="simple", choices=["simple", "gated", "tiled"], help="FFN implementation: simple 2-layer, gated SwiGLU-like, or TiledFFN")
    parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
    parser.add_argument("--dynamic", action="store_true", help="Enable dynamic shapes for torch.compile")
    parser.add_argument("--backward", action="store_true", help="Include backward pass in measurement")
    parser.add_argument("--warmup", type=int, default=1, help="Warmup iterations per seqlen (not measured)")
    parser.add_argument("--repeats", type=int, default=1, help="Measured repeats per seqlen; we report peak across repeats")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--csv", type=str, default=None, help="Optional path to write CSV results")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    if args.device is None:
        device_str = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device_str = args.device
    device = torch.device(device_str)
    on_cuda = device_str == "cuda"

    if args.dtype is None:
        dtype = torch.bfloat16 if on_cuda else torch.float32
    else:
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.dtype]

    def sync() -> None:
        # No-op on CPU.
        if on_cuda:
            torch.cuda.synchronize()

    seqlens: List[int] = parse_int_list(args.seqlens)
    repeats = max(1, args.repeats)

    # Build module and optionally compile
    if args.impl == "tiled":
        mod = build_tiled_ffn(args.in_dim, args.hidden_dim, device, dtype)
    else:
        mod = build_simple_ffn(args.in_dim, args.hidden_dim, device, dtype, gated=(args.impl == "gated"))
    mod = maybe_compile(mod, args.compile, args.dynamic)
    mod.train(args.backward)

    def make_input(seqlen: int, requires_grad: bool) -> torch.Tensor:
        return torch.randn(args.batch_size, seqlen, args.in_dim, device=device, dtype=dtype, requires_grad=requires_grad)

    def cudagraph_mark_step_if_needed() -> None:
        if args.compile and on_cuda and hasattr(torch, "compiler") and hasattr(torch.compiler, "cudagraph_mark_step_begin"):
            try:
                torch.compiler.cudagraph_mark_step_begin()  # type: ignore[attr-defined]
            except Exception:
                # Best effort: the hint is optional and must not abort profiling.
                pass

    def call_mod(x: torch.Tensor) -> torch.Tensor:
        cudagraph_mark_step_if_needed()
        return mod(x)

    def backward_of_mean(out: torch.Tensor) -> None:
        (out.sum() / out.numel()).backward()

    def run_once(x: torch.Tensor, do_backward: bool) -> None:
        if do_backward:
            backward_of_mean(call_mod(x))
        else:
            with torch.no_grad():
                _ = call_mod(x)

    results: List[RunResult] = []

    # Warmup compile once with the largest expected shape to avoid multiple recompiles if dynamic is False
    if args.compile and on_cuda and seqlens:
        warmup_x = make_input(max(seqlens), args.backward)
        for _ in range(max(1, args.warmup)):
            run_once(warmup_x, args.backward)
            sync()
        if args.backward:
            mod.zero_grad(set_to_none=True)
        # Free the largest-shape input (and its .grad). Left alive it would
        # sit in the baseline of every peak measured below.
        del warmup_x
        torch.cuda.empty_cache()

    # Created once; only consulted on the CPU path.
    read_rss = None if on_cuda else _make_rss_reader()

    # Iterate over requested sequence lengths
    for L in seqlens:
        x = make_input(L, args.backward)
        x_bwd: Optional[torch.Tensor] = None

        # Per-length warmup
        for _ in range(max(0, args.warmup)):
            run_once(x, args.backward)
            sync()
        if args.backward:
            mod.zero_grad(set_to_none=True)
            # The warmup backward left a gradient on the input; drop it so
            # it is not counted in the forward measurement.
            x.grad = None

        fwd_peak_allocated_mb = 0.0
        fwd_peak_reserved_mb = 0.0
        fwd_ms = 0.0
        bwd_peak_allocated_mb = 0.0
        bwd_peak_reserved_mb = 0.0
        bwd_ms = 0.0

        if on_cuda:
            # Measure forward peak
            for _ in range(repeats):
                e_ms, p_alloc, p_rsrv = measure_forward_gpu(call_mod, x, sync)
                fwd_ms = max(fwd_ms, e_ms)
                fwd_peak_allocated_mb = max(fwd_peak_allocated_mb, p_alloc)
                fwd_peak_reserved_mb = max(fwd_peak_reserved_mb, p_rsrv)

            # Measure backward peak if requested
            if args.backward:
                x_bwd = make_input(L, True)
                for _ in range(repeats):
                    e_ms, p_alloc, p_rsrv = measure_backward_gpu(call_mod, x_bwd, sync)
                    bwd_ms = max(bwd_ms, e_ms)
                    bwd_peak_allocated_mb = max(bwd_peak_allocated_mb, p_alloc)
                    bwd_peak_reserved_mb = max(bwd_peak_reserved_mb, p_rsrv)
                mod.zero_grad(set_to_none=True)
        else:
            # CPU fallback: approximate via process RSS delta (peak not easily resettable per-iter).
            # Errors raised by the model propagate; only a missing or failing
            # psutil degrades the run to timings only.
            for _ in range(repeats):
                # Gradient-free, matching the CUDA forward measurement.
                with torch.no_grad():
                    e_ms, growth_mb = _timed_cpu_step(call_mod, x, read_rss)
                fwd_ms = max(fwd_ms, e_ms)
                fwd_peak_allocated_mb = max(fwd_peak_allocated_mb, growth_mb)

            if args.backward:
                # Backward measurement: run fresh forward, then time/backward delta
                x_bwd = make_input(L, True)
                for _ in range(repeats):
                    out_bwd = call_mod(x_bwd)
                    e_ms, growth_mb = _timed_cpu_step(backward_of_mean, out_bwd, read_rss)
                    bwd_ms = max(bwd_ms, e_ms)
                    bwd_peak_allocated_mb = max(bwd_peak_allocated_mb, growth_mb)
                    del out_bwd
                mod.zero_grad(set_to_none=True)

        # Clear grads and inputs between sizes to avoid accumulation. The
        # backward input is released too, so it cannot leak into the next
        # sequence length's baseline.
        if args.backward:
            mod.zero_grad(set_to_none=True)
        del x, x_bwd
        if on_cuda:
            torch.cuda.empty_cache()
            sync()

        results.append(
            RunResult(
                seqlen=L,
                fwd_peak_allocated_mb=fwd_peak_allocated_mb,
                fwd_peak_reserved_mb=fwd_peak_reserved_mb,
                fwd_ms=fwd_ms,
                bwd_peak_allocated_mb=bwd_peak_allocated_mb,
                bwd_peak_reserved_mb=bwd_peak_reserved_mb,
                bwd_ms=bwd_ms,
            )
        )

    # Print results
    impl_eff = args.impl
    dtype_name = str(dtype).replace('torch.', '')
    eff_hidden = None
    try:
        # Best effort: only modules exposing a Linear ``w1`` report this.
        if hasattr(mod, "w1") and isinstance(mod.w1, nn.Linear):
            eff_hidden = mod.w1.out_features
    except Exception:
        # Never lose collected results over a cosmetic header field.
        pass
    header = (
        f"device={device_str} dtype={dtype_name} in_dim={args.in_dim} hidden_dim={args.hidden_dim}"
        f" eff_hidden_dim={eff_hidden} batch_size={args.batch_size} impl={impl_eff} compile={args.compile} dynamic={args.dynamic} backward={args.backward}"
    )
    columns = ["seqlen", "fwd_peak_allocated_MB", "fwd_peak_reserved_MB", "fwd_ms", "bwd_peak_allocated_MB", "bwd_peak_reserved_MB", "bwd_ms"]
    # Format each row once; stdout and the CSV share the same cells.
    rows = [
        [
            str(r.seqlen),
            f"{r.fwd_peak_allocated_mb:.2f}",
            f"{r.fwd_peak_reserved_mb:.2f}",
            f"{r.fwd_ms:.3f}",
            f"{r.bwd_peak_allocated_mb:.2f}",
            f"{r.bwd_peak_reserved_mb:.2f}",
            f"{r.bwd_ms:.3f}",
        ]
        for r in results
    ]
    print(header)
    print(", ".join(columns))
    for row in rows:
        print(", ".join(row))

    # Optional CSV
    if args.csv:
        import csv
        with open(args.csv, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["device", "dtype", "in_dim", "hidden_dim", "eff_hidden_dim", "batch_size", "impl", "compile", "dynamic", "backward"])
            writer.writerow([device_str, dtype_name, args.in_dim, args.hidden_dim, eff_hidden, args.batch_size, impl_eff, args.compile, args.dynamic, args.backward])
            writer.writerow([])
            writer.writerow(columns)
            writer.writerows(rows)


# Run the profiler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


