import argparse
import sys
import contextlib
from pathlib import Path

import torch
import matplotlib.pyplot as plt
import numpy as np


# ROOT is the parent of the directory that contains this file.
ROOT = Path(__file__).parent.parent
FLA_ROOT = ROOT / "flash-linear-attention"
# Put the flash-linear-attention directory at the front of sys.path (only if
# it is not already listed) so the `fla` import below is looked up there first.
if str(FLA_ROOT) not in sys.path:
    sys.path.insert(0, str(FLA_ROOT))

from fla.modules.rotary import SelectiveRoPE, SelectiveRoPEFast  # noqa: E402

# Import OGSelectiveRoPE from original_selective_rope.py, looked up via
# ROOT / "tests".
# NOTE(review): unlike the FLA_ROOT entry above, this insert is not guarded by
# a membership check, so re-executing the module adds a duplicate path entry.
sys.path.insert(0, str(ROOT / "tests"))
from original_selective_rope import OGSelectiveRoPE  # noqa: E402


def measure_runtime(module, B, T, H, D, iters, device, *, use_compile=True, warmup=10):
    """Return the average wall time, in seconds, of one forward+backward pass.

    Two random bfloat16 tensors ``q`` and ``k`` of shape ``(B, T, H, D)`` are
    created on ``device`` and fed to ``module(q, k)``. The loss is the sum of
    the first two outputs, and both the forward and backward pass run inside a
    bfloat16 autocast region for ``device``'s device type.

    Args:
        module: Callable module taking ``(q, k)`` and returning an indexable
            result whose items 0 and 1 are tensors. It is moved to ``device``.
        B, T, H, D: Sizes of the four dimensions of ``q`` and ``k``.
        iters: Number of timed iterations; must be positive.
        device: ``torch.device`` (or anything ``torch.device`` accepts).
        use_compile: Wrap ``module`` with ``torch.compile`` before timing.
        warmup: Number of untimed iterations run before the measurement.

    Returns:
        Elapsed time divided by ``iters``, in seconds.

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    import time  # local import: only needed for the non-CUDA timing path

    if iters <= 0:
        # Fail before doing any work instead of dividing by zero at the end.
        raise ValueError(f"iters must be a positive integer, got {iters}")

    device = torch.device(device)
    on_cuda = device.type == "cuda"

    torch.manual_seed(0)
    module = module.to(device=device)
    if use_compile:
        module = torch.compile(module)

    # Inputs are bfloat16; the module's parameter dtype is left untouched and
    # autocast selects the compute dtype.
    q = torch.randn(B, T, H, D, device=device, dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn(B, T, H, D, device=device, dtype=torch.bfloat16, requires_grad=True)

    def run_step():
        # Drop the input grads so every iteration does the same amount of work.
        q.grad = k.grad = None
        out = module(q, k)
        loss = out[0].sum() + out[1].sum()
        loss.backward()

    # Autocast must target the device actually in use; a hard-coded "cuda"
    # would not apply when the benchmark runs on CPU.
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        for _ in range(warmup):
            run_step()

        if on_cuda:
            # Drain the warmup work, then time with CUDA events.
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
        else:
            start_time = time.perf_counter()

        for _ in range(iters):
            run_step()

        if on_cuda:
            end.record()
            torch.cuda.synchronize()
            # elapsed_time() reports milliseconds.
            elapsed_s = start.elapsed_time(end) / 1000.0
        else:
            elapsed_s = time.perf_counter() - start_time

    # Average time per iteration.
    return elapsed_s / iters


def _time_or_none(module, B, T, H, D, iters, device):
    """Return the fwd+bwd time in ms/iter, or None if CUDA ran out of memory."""
    try:
        seconds = measure_runtime(module, B, T, H, D, iters, device)
    except torch.cuda.OutOfMemoryError:
        seconds = None
    if seconds is None:
        # Hand unused cached memory back before the next measurement.
        if device.type == "cuda":
            torch.cuda.empty_cache()
        return None
    return seconds * 1000  # seconds -> milliseconds


def main():
    """Benchmark three Selective RoPE implementations and plot the results.

    Parses ``--B``, ``--H``, ``--D``, ``--iters`` and ``--output`` from the
    command line, times ``OGSelectiveRoPE``, ``SelectiveRoPE`` and
    ``SelectiveRoPEFast`` over a fixed list of sequence lengths, prints the
    per-iteration runtimes and speedups, and saves a log-log plot to
    ``--output``. A configuration that runs out of CUDA memory is reported
    and left out of the plot instead of aborting the whole sweep.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--B", type=int, default=4, help="Batch size")
    parser.add_argument("--H", type=int, default=8, help="Number of heads")
    parser.add_argument("--D", type=int, default=128, help="Head dimension")
    parser.add_argument(
        "--iters", type=int, default=20, help="Number of iterations per benchmark"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="selective_rope_benchmark.png",
        help="Output plot filename",
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    print(f"Device: {device}, dtype: {dtype}")
    print(f"Config: B={args.B}, H={args.H}, D={args.D}, iters={args.iters}")

    # Sequence lengths to benchmark
    seq_lengths = [512, 1024, 2048, 4096, 8192, 16384, 32768]

    # Initialize modules
    # NOTE(review): `dtype` is only passed to OGSelectiveRoPE; the other two
    # modules are built without it -- confirm this asymmetry is intended.
    d_state = args.D  # Using head_dim as d_state for OGSelectiveRoPE
    mod_og = OGSelectiveRoPE(
        num_heads=args.H,
        d_state=d_state,
        d_conv=4,
        device=device,
        dtype=dtype,
        skip_conv_cumsum=True,
    )
    mod_py = SelectiveRoPE(
        head_dim=args.D,
        num_heads=args.H,
        d_conv=4,
        device=device,
        skip_conv_cumsum=True,
    )
    mod_tri = SelectiveRoPEFast(
        head_dim=args.D,
        num_heads=args.H,
        d_conv=4,
        device=device,
        skip_conv_cumsum=True,
    )

    # Sync weights between implementations (where possible), using
    # SelectiveRoPE as the source. Conv weights are deliberately not copied.
    with torch.no_grad():
        mod_og.phi_bias.data.copy_(mod_py.phi_bias.data)
        mod_tri.phi_bias.data.copy_(mod_py.phi_bias.data)
        if hasattr(mod_py.phi_proj, "weight"):
            mod_og.phi_proj.weight.data.copy_(mod_py.phi_proj.weight.data)
            mod_tri.phi_proj.weight.data.copy_(mod_py.phi_proj.weight.data)

    # (console label, plot label, matplotlib format string, module); the
    # order here is both the benchmark order and the plot order.
    implementations = [
        ("Original", "Original (OGSelectiveRoPE)", "o-", mod_og),
        ("Python", "Python (SelectiveRoPE)", "s-", mod_py),
        ("Triton", "Triton (SelectiveRoPEFast)", "^-", mod_tri),
    ]
    # Per-implementation runtimes in ms/iter, aligned with seq_lengths;
    # None marks a configuration that could not be measured.
    results = {name: [] for name, _, _, _ in implementations}

    print("Benchmarking across sequence lengths...")
    print("-" * 60)

    for T in seq_lengths:
        print(f"\nSequence length: {T}")

        # Adjust iterations for longer sequences to keep runtime reasonable
        iters = max(5, min(args.iters, int(args.iters * 1024 / T)))

        for name, _, _, module in implementations:
            runtime_ms = _time_or_none(
                module, args.B, T, args.H, args.D, iters, device
            )
            results[name].append(runtime_ms)
            tag = f"{name}:"
            if runtime_ms is None:
                shown = "OOM (skipped)"
            else:
                shown = f"{runtime_ms:.3f} ms/iter"
            print(f"  {tag:<15}{shown}")

        # Print speedups only when both values are available
        tri_ms = results["Triton"][-1]
        for baseline in ("Python", "Original"):
            base_ms = results[baseline][-1]
            if tri_ms is not None and base_ms is not None:
                speedup = base_ms / tri_ms
                print(f"  Speedup (Triton vs {baseline}): {speedup:.2f}x")

    # Create the plot
    print("\n" + "=" * 60)
    print("Creating plot...")

    # Prefer Times New Roman, falling back to the existing serif list
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.serif"] = ["Times New Roman"] + plt.rcParams["font.serif"]
    plt.rcParams["font.size"] = 12

    fig, ax = plt.subplots(figsize=(10, 6))

    # One curve per implementation, using only the points that were measured
    for name, plot_label, fmt, _ in implementations:
        points = [
            (length, ms)
            for length, ms in zip(seq_lengths, results[name])
            if ms is not None
        ]
        if not points:
            continue
        ax.plot(
            [length for length, _ in points],
            [ms for _, ms in points],
            fmt,
            label=plot_label,
            linewidth=2,
            markersize=8,
        )

    ax.set_xlabel("Sequence Length", fontsize=14)
    ax.set_ylabel("Runtime per Iteration (ms)", fontsize=14)
    title = f"Selective RoPE Benchmark (B={args.B}, H={args.H}, D={args.D})"
    ax.set_title(title, fontsize=16)

    # Both axes are logarithmic; the x-axis uses base 2 because the sequence
    # lengths double at every step.
    ax.set_xscale("log", base=2)
    ax.set_yscale("log")

    # Set x-axis ticks to show actual sequence lengths
    ax.set_xticks(seq_lengths)
    ax.set_xticklabels([str(s) for s in seq_lengths])

    ax.grid(True, alpha=0.3, linestyle="--")
    ax.legend(loc="best", fontsize=11)

    plt.tight_layout()
    plt.savefig(args.output, dpi=150, bbox_inches="tight")
    print(f"Plot saved to {args.output}")

    # Also display the plot if in interactive mode
    plt.show()


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
