#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Comprehensive Attention Benchmark Script

Compares performance between:
1. The custom Triton implementation with a learnable per-head bias
2. PyTorch's native SDPA (Scaled Dot-Product Attention)
3. Flash Attention v2 (the `flash_attn` package, when installed)
4. The reference Flash Attention Triton implementation

The benchmark methods can time either the forward pass or a forward + backward
pass; the default sweep run by main() times the forward pass only.
"""

import argparse
import sys
import time
from dataclasses import dataclass
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

# Import your triton implementation
sys.path.append("/home/ubuntu/projects/torchtitan")
from torchtitan.models.llama3.model.reference_flash_attention import reference_attention
from torchtitan.models.llama3.model.triton_attention import triton_attention_with_bias

# Try to import flash attention
try:
    from flash_attn.flash_attn_interface import (
        flash_attn_qkvpacked_func as flash_attn_func,
    )

    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False
    print(
        "Warning: Flash Attention not available. Install with: pip install flash-attn"
    )

# Pick the default device once at import time; BenchmarkConfig uses it as its default.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Global seed so randomness outside the per-config reseeding is reproducible too.
torch.manual_seed(42)


@dataclass
class BenchmarkConfig:
    """Configuration for a single attention benchmark run.

    AttentionBenchmark.generate_inputs builds q/k/v with shape
    (batch_size, seq_len, num_heads, head_dim) from these fields.
    """

    batch_size: int
    seq_len: int
    num_heads: int
    head_dim: int
    # Causal masking flag; forwarded to SDPA, flash-attn and the reference
    # implementation (not passed to the custom Triton kernel call).
    causal: bool = True
    dtype: torch.dtype = torch.float16  # dtype of the generated q/k/v tensors
    # Defaults to the module-level DEVICE ("cuda" when available at import, else "cpu").
    device: str = DEVICE


class AttentionBenchmark:
    """Benchmark suite for attention implementations.

    Every ``benchmark_*`` method builds inputs for one ``BenchmarkConfig``,
    runs ``warmup_steps`` untimed iterations, then ``benchmark_steps`` timed
    iterations, and returns ``{"time_ms": <mean ms per iteration>,
    "implementation": <tag>}``. ``mode="fwd"`` times the forward pass only;
    any other value times a forward pass followed by a backward pass.
    """

    def __init__(self, warmup_steps: int = 10, benchmark_steps: int = 100):
        """Configure iteration counts.

        Args:
            warmup_steps: Untimed iterations run before measuring.
            benchmark_steps: Timed iterations averaged into ``time_ms``.

        Raises:
            ValueError: If ``benchmark_steps`` is less than 1 (the mean would
                otherwise divide by zero).
        """
        if benchmark_steps < 1:
            raise ValueError(f"benchmark_steps must be >= 1, got {benchmark_steps}")
        self.warmup_steps = warmup_steps
        self.benchmark_steps = benchmark_steps
        self.results = {}  # kept for compatibility; not written by this class

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _synchronize(device: str) -> None:
        """Wait for queued CUDA work on ``device``; no-op for non-CUDA devices."""
        # startswith() so "cuda:1" is synchronized too, not only "cuda".
        if str(device).startswith("cuda"):
            torch.cuda.synchronize(device)

    @staticmethod
    def _make_step(forward_fn, mode: str):
        """Return the callable to time: ``forward_fn`` for "fwd", else fwd + bwd."""
        if mode == "fwd":
            return forward_fn

        def forward_backward():
            out = forward_fn()
            grad_out = torch.randn_like(out)
            # retain_graph=True: some inputs are views (transposes) created once
            # outside the timed loop, so their autograd nodes are shared by the
            # graph of every iteration.
            out.backward(grad_out, retain_graph=True)

        return forward_backward

    def _time_ms(self, step_fn, device: str) -> float:
        """Return the mean wall-clock milliseconds per call of ``step_fn``."""
        for _ in range(self.warmup_steps):
            step_fn()
            self._synchronize(device)

        self._synchronize(device)  # drain pending work before starting the clock
        start_time = time.perf_counter()
        for _ in range(self.benchmark_steps):
            step_fn()
        # GPU kernels launch asynchronously; wait for them before stopping the clock.
        self._synchronize(device)
        end_time = time.perf_counter()

        return (end_time - start_time) / self.benchmark_steps * 1000

    # --------------------------------------------------------------- benchmarks

    def generate_inputs(self, config: BenchmarkConfig) -> Tuple[torch.Tensor, ...]:
        """Return reproducible random ``(q, k, v)`` leaf tensors.

        Each tensor has shape ``(batch_size, seq_len, num_heads, head_dim)``,
        uses ``config.dtype`` / ``config.device`` and has ``requires_grad=True``.
        The RNG is reseeded on every call, so every implementation that uses
        this helper sees the same data.
        """
        torch.manual_seed(42)  # For reproducible results
        shape = (config.batch_size, config.seq_len, config.num_heads, config.head_dim)
        # The generator is consumed in order, so q, k, v are drawn in that order.
        q, k, v = (
            torch.randn(
                shape, dtype=config.dtype, device=config.device, requires_grad=True
            )
            for _ in range(3)
        )
        return q, k, v

    def benchmark_triton_custom(
        self, config: BenchmarkConfig, mode: str = "fwd"
    ) -> Dict:
        """Benchmark the custom Triton implementation with a per-head bias."""
        q, k, v = self.generate_inputs(config)
        sm_scale = 1.0 / (config.head_dim**0.5)

        # One bias value per head; created without requires_grad, so autograd
        # does not track it.
        bias_params = torch.randn(
            config.num_heads, dtype=torch.float32, device=config.device
        )

        # NOTE(review): config.causal is not passed here; confirm whether
        # triton_attention_with_bias applies causal masking on its own.
        def forward_fn():
            return triton_attention_with_bias(
                q, k, v, bias_params=bias_params, sm_scale=sm_scale
            )

        time_ms = self._time_ms(self._make_step(forward_fn, mode), config.device)
        return {"time_ms": time_ms, "implementation": "triton_custom"}

    def benchmark_pytorch_sdpa(
        self, config: BenchmarkConfig, mode: str = "fwd"
    ) -> Dict:
        """Benchmark PyTorch's native SDPA."""
        q, k, v = self.generate_inputs(config)
        sm_scale = 1.0 / (config.head_dim**0.5)

        # SDPA expects (batch, num_heads, seq_len, head_dim).
        q_sdpa, k_sdpa, v_sdpa = (t.transpose(1, 2) for t in (q, k, v))

        def forward_fn():
            return F.scaled_dot_product_attention(
                q_sdpa, k_sdpa, v_sdpa, is_causal=config.causal, scale=sm_scale
            )

        time_ms = self._time_ms(self._make_step(forward_fn, mode), config.device)
        return {"time_ms": time_ms, "implementation": "pytorch_sdpa"}

    def benchmark_flash_attention(
        self, config: BenchmarkConfig, mode: str = "fwd"
    ) -> Dict:
        """Benchmark Flash Attention v2 (packed-QKV entry point).

        Returns ``time_ms=inf`` with an ``"error"`` entry when flash-attn is not
        installed.
        """
        if not HAS_FLASH:
            return {
                "time_ms": float("inf"),
                "implementation": "flash_attn",
                "error": "Not available",
            }

        torch.manual_seed(42)  # reproducible inputs, like generate_inputs()
        # Packed layout: (batch, seq_len, 3, num_heads, head_dim).
        qkv = torch.randn(
            config.batch_size,
            config.seq_len,
            3,
            config.num_heads,
            config.head_dim,
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )

        def forward_fn():
            return flash_attn_func(qkv, causal=config.causal)

        time_ms = self._time_ms(self._make_step(forward_fn, mode), config.device)
        return {"time_ms": time_ms, "implementation": "flash_attn"}

    def benchmark_reference_triton(
        self, config: BenchmarkConfig, mode: str = "fwd"
    ) -> Dict:
        """Benchmark the reference Flash Attention Triton implementation."""
        q, k, v = self.generate_inputs(config)
        sm_scale = 1.0 / (config.head_dim**0.5)

        # The reference kernel takes (batch, num_heads, seq_len, head_dim).
        q_ref, k_ref, v_ref = (t.transpose(1, 2) for t in (q, k, v))

        def forward_fn():
            return reference_attention(q_ref, k_ref, v_ref, config.causal, sm_scale)

        time_ms = self._time_ms(self._make_step(forward_fn, mode), config.device)
        return {"time_ms": time_ms, "implementation": "reference_triton"}

    def calculate_flops(self, config: BenchmarkConfig, mode: str = "fwd") -> float:
        """Estimate the FLOPs of one attention call for ``config``.

        Each of the two matmuls (``Q @ K^T`` and ``P @ V``) costs
        ``2 * B * H * S * S * D`` FLOPs (one multiply and one add per
        multiply-accumulate). A simplified ``B * H * S * S`` softmax term is
        added. Causal masking roughly halves the work. ``mode="bwd"`` scales
        by 2.5, the conventional backward-only estimate. Note that the ``bwd``
        benchmark mode times forward and backward together.
        """
        B, S, H, D = (
            config.batch_size,
            config.seq_len,
            config.num_heads,
            config.head_dim,
        )

        flops_per_matmul = 2 * B * H * S * S * D
        softmax_flops = B * H * S * S  # simplified
        total_flops = 2 * flops_per_matmul + softmax_flops

        if config.causal:
            total_flops *= 0.5  # Roughly half due to causal masking

        if mode == "bwd":
            total_flops *= 2.5  # Approximation for backward pass

        return float(total_flops)

    def run_comprehensive_benchmark(self, configs: List[BenchmarkConfig]) -> Dict:
        """Time every implementation (forward pass) for each config.

        Returns a dict with ``"configs"``, ``"flops"`` (forward estimate) and one
        list of mean times in ms per implementation key. A failed run is
        recorded as ``inf`` and does not stop the sweep.
        """
        results = {
            "configs": [],
            "triton_custom": [],
            "pytorch_sdpa": [],
            "flash_attn": [],
            "reference_triton": [],
            "flops": [],
        }
        # (result key, benchmark method, label used in failure messages)
        implementations = (
            ("triton_custom", self.benchmark_triton_custom, "Triton custom"),
            ("pytorch_sdpa", self.benchmark_pytorch_sdpa, "PyTorch SDPA"),
            ("flash_attn", self.benchmark_flash_attention, "Flash Attention"),
            ("reference_triton", self.benchmark_reference_triton, "Reference Triton"),
        )

        for config in configs:
            print(
                f"Benchmarking: B={config.batch_size}, S={config.seq_len}, "
                f"H={config.num_heads}, D={config.head_dim}"
            )

            results["configs"].append(config)
            results["flops"].append(self.calculate_flops(config, "fwd"))

            for key, bench_fn, label in implementations:
                try:
                    time_ms = bench_fn(config, "fwd")["time_ms"]
                except Exception as e:
                    # Best-effort sweep: record the failure (OOM, unsupported
                    # shape/device, ...) and move on to the next implementation.
                    print(f"{label} failed: {e}")
                    time_ms = float("inf")
                results[key].append(time_ms)

        return results


# The reference Flash Attention Triton implementation (`reference_attention`) is imported from reference_flash_attention.py at the top of this file.


def create_benchmark_configs(
    max_configs: "int | None" = 20, device: str = DEVICE
) -> List[BenchmarkConfig]:
    """Create the sweep of causal attention benchmark configurations.

    Configurations are generated in nested order batch_size -> seq_len ->
    head_dim -> num_heads. A config is skipped when its per-tensor element
    count (B * S * H * D) exceeds 2**28, to avoid OOM.

    Args:
        max_configs: Keep only the first ``max_configs`` configs (default 20,
            for reasonable runtime); ``None`` keeps the whole sweep. With the
            lists below, the first 20 configs all have batch_size=1 and
            seq_len <= 2048.
        device: Device stored on every config (defaults to the module DEVICE).

    Returns:
        The list of ``BenchmarkConfig`` objects.
    """
    import itertools

    # Standard configurations
    batch_sizes = [1, 4, 8]
    seq_lens = [128, 512, 1024, 2048, 4096]
    head_dims = [64, 128]
    num_heads_options = [8, 16, 32]

    configs = [
        BenchmarkConfig(
            batch_size=batch_size,
            seq_len=seq_len,
            num_heads=num_heads,
            head_dim=head_dim,
            causal=True,
            device=device,
        )
        for batch_size, seq_len, head_dim, num_heads in itertools.product(
            batch_sizes, seq_lens, head_dims, num_heads_options
        )
        # Skip very large configurations to avoid OOM
        if batch_size * seq_len * num_heads * head_dim <= 2**28
    ]

    # Slicing with None returns the full list.
    return configs[:max_configs]


def plot_results(results: Dict, save_path: str = "attention_benchmark_results.png"):
    """Plot achieved TFLOPS versus sequence length for each implementation.

    TFLOPS are computed from ``results["flops"]`` and each implementation's
    timings in milliseconds. Failed runs (non-finite or non-positive times)
    are plotted as NaN, so they leave a gap instead of a 0 that a
    log-scaled axis cannot represent. The figure is saved to ``save_path``
    and then closed.

    Args:
        results: Dict produced by ``AttentionBenchmark.run_comprehensive_benchmark``.
        save_path: Output image path.
    """
    import math

    seq_lens = [config.seq_len for config in results["configs"]]

    def tflops_for(key: str) -> List[float]:
        """Convert the timings stored under ``key`` into TFLOPS (NaN on failure)."""
        return [
            flops / (time_ms * 1e-3) / 1e12
            if math.isfinite(time_ms) and time_ms > 0
            else float("nan")
            for flops, time_ms in zip(results["flops"], results[key])
        ]

    # NOTE: several configs can share a seq_len (differing in B/H/D), so points
    # at the same x position come from different shapes.
    fig = plt.figure(figsize=(12, 8))
    plt.plot(
        seq_lens,
        tflops_for("triton_custom"),
        "ro-",
        label="Triton Custom (with bias)",
        linewidth=2,
    )
    plt.plot(
        seq_lens, tflops_for("pytorch_sdpa"), "bo-", label="PyTorch SDPA", linewidth=2
    )
    if HAS_FLASH:
        plt.plot(
            seq_lens,
            tflops_for("flash_attn"),
            "go-",
            label="Flash Attention v2",
            linewidth=2,
        )
    plt.plot(
        seq_lens,
        tflops_for("reference_triton"),
        "mo-",
        label="Reference Triton",
        linewidth=2,
    )

    plt.xlabel("Sequence Length")
    plt.ylabel("Performance (TFLOPS)")
    plt.title("Attention Implementation Performance Comparison")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xscale("log", base=2)
    plt.yscale("log")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.close(fig)  # release the figure; it is not reused after saving
    print(f"Results saved to {save_path}")


def print_summary_table(results: Dict):
    """Print a fixed-width table of per-config timings for every implementation.

    Times are shown in milliseconds with two decimals; failed runs (stored as
    ``inf``) are shown as ``FAIL``.
    """

    def fmt(value):
        if value == float("inf"):
            return "FAIL"
        return f"{value:.2f}ms"

    heavy_rule = "=" * 80
    print("\n" + heavy_rule)
    print("PERFORMANCE SUMMARY")
    print(heavy_rule)
    print(
        f"{'Config':<25} {'Triton Custom':<15} {'PyTorch SDPA':<15} {'Flash Attn':<15} {'Ref Triton':<15}"
    )
    print("-" * 80)

    # Column order must match the header above.
    columns = ("triton_custom", "pytorch_sdpa", "flash_attn", "reference_triton")
    for idx, cfg in enumerate(results["configs"]):
        label = f"B{cfg.batch_size}xS{cfg.seq_len}xH{cfg.num_heads}xD{cfg.head_dim}"
        cells = " ".join(f"{fmt(results[col][idx]):<15}" for col in columns)
        print(f"{label:<25} {cells}")


def main():
    """Parse CLI arguments, run the benchmark sweep, and write table, plot and JSON."""
    import json
    import math

    parser = argparse.ArgumentParser(description="Benchmark attention implementations")
    parser.add_argument(
        "--warmup", type=int, default=10, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--benchmark", type=int, default=50, help="Number of benchmark iterations"
    )
    parser.add_argument(
        "--output", type=str, default="attention_benchmark", help="Output file prefix"
    )
    parser.add_argument("--device", type=str, default="cuda", help="Device to run on")

    args = parser.parse_args()

    # startswith() so explicit indices such as "cuda:1" also fall back cleanly.
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    print("Starting Attention Implementation Benchmark")
    print(f"Device: {args.device}")
    print(f"Warmup steps: {args.warmup}")
    print(f"Benchmark steps: {args.benchmark}")
    print(f"Flash Attention available: {HAS_FLASH}")

    # Create benchmark suite
    benchmark = AttentionBenchmark(
        warmup_steps=args.warmup, benchmark_steps=args.benchmark
    )

    # Generate configurations
    configs = create_benchmark_configs()
    # The configs carry BenchmarkConfig's default device (module-level DEVICE);
    # apply the requested one, otherwise --device would never take effect.
    for config in configs:
        config.device = args.device
    print(f"Testing {len(configs)} configurations")

    # Run benchmark
    results = benchmark.run_comprehensive_benchmark(configs)

    # Print summary
    print_summary_table(results)

    # Save and plot results
    plot_results(results, f"{args.output}_plot.png")

    # Save raw results. Failed runs are stored as inf, which json.dump would
    # write as the non-standard token `Infinity`; emit null instead.
    results_copy = results.copy()
    for key in ("triton_custom", "pytorch_sdpa", "flash_attn", "reference_triton"):
        results_copy[key] = [t if math.isfinite(t) else None for t in results[key]]
    # Convert configs to dict for JSON serialization
    results_copy["configs"] = [
        {
            "batch_size": c.batch_size,
            "seq_len": c.seq_len,
            "num_heads": c.num_heads,
            "head_dim": c.head_dim,
            "causal": c.causal,
            "dtype": str(c.dtype),
            "device": c.device,
        }
        for c in results["configs"]
    ]

    with open(f"{args.output}_results.json", "w", encoding="utf-8") as f:
        json.dump(results_copy, f, indent=2)

    print(f"\nBenchmark complete! Results saved to {args.output}_*")


# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
