# SPDX-License-Identifier: Apache-2.0
# fmt: off
# ruff: noqa: E501
import time

# Import DeepGEMM functions
import deep_gemm
import torch
import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)


# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
def per_token_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert tensor to FP8 format with per-token scaling."""
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
        torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
def per_block_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per 128x128 block.

    The input is zero-padded up to multiples of 128 in both dimensions and
    tiled into 128x128 blocks. Each block is rescaled so that its largest
    absolute value maps to 448.0 and is then cast to
    ``torch.float8_e4m3fn``. The padding is sliced off again before
    returning.

    Args:
        x: 2-D tensor of any size.

    Returns:
        A tuple ``(x_fp8, scales)``: ``x_fp8`` is a contiguous
        ``torch.float8_e4m3fn`` tensor with the shape of ``x``; ``scales``
        is float32 with one ``amax / 448.0`` entry per block, shaped
        ``(padded_rows // 128, padded_cols // 128)``.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    padded_shape = (ceil_div(rows, 128) * 128, ceil_div(cols, 128) * 128)
    padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x
    # View as (row_blocks, 128, col_blocks, 128) so that dims 1 and 3 span
    # the elements of a single block.
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    # Clamp from below so an all-zero block does not divide by zero.
    block_amax = blocks.abs().float().amax(dim=(1, 3),
                                           keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / block_amax)).to(torch.float8_e4m3fn)
    unpadded = quantized.view_as(padded)[:rows, :cols].contiguous()
    scales = (block_amax / 448.0).view(blocks.size(0), blocks.size(2))
    return unpadded, scales


def benchmark_shape(m: int,
                    n: int,
                    k: int,
                    warmup: int = 100,
                    repeat: int = 10000,
                    verbose: bool = False) -> dict:
    """Benchmark all implementations for a specific (m, n, k) shape.

    Random BF16 matrices ``A`` (m x k) and ``B`` (n x k) are created on the
    CUDA device and quantized to FP8 once, before any timing. Each
    implementation (DeepGEMM, vLLM Triton, vLLM CUTLASS) is first compared
    against the BF16 reference ``A @ B.t()`` and then timed over ``repeat``
    back-to-back calls, so the timings cover the GEMM call only and not the
    activation quantization.

    Args:
        m: Number of rows of ``A`` and of the output.
        n: Number of rows of ``B``, i.e. columns of the output.
        k: Shared inner dimension. ``per_token_cast_to_fp8`` asserts that it
            is a multiple of 128.
        warmup: Untimed calls made per implementation before timing.
        repeat: Timed calls per implementation; must be at least 1.
        verbose: Print per-shape progress, differences and timings.

    Returns:
        A dict with ``"shape"`` (``m``/``n``/``k``) and ``"implementations"``,
        which maps each implementation name to its ``time_ms``, ``time_us``,
        ``tflops``, ``gb_s`` and ``diff`` (versus ``"DeepGEMM"`` and versus
        ``"Reference"``). Non-DeepGEMM entries also carry
        ``speedup_vs_deepgemm``; the CUTLASS entry additionally carries
        ``speedup_vs_triton``.

    Raises:
        ValueError: If ``repeat`` is smaller than 1.
    """
    # Fail early instead of dividing by zero after all the GPU work is done.
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    if verbose:
        print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")

    # Create test tensors
    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)

    # Reference result in BF16
    torch.cuda.synchronize()
    C_ref = A @ B.t()

    # Pre-quantize B for all implementations
    # (weights can be pre-quantized offline)
    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)

    # Block size configuration
    block_size = [128, 128]

    # Pre-quantize A for all implementations. This happens once, outside the
    # timed loops, so the measurements below exclude activation quantization.
    A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
    A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
    A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
        A, block_size[1], column_major_scales=True)

    # === DeepGEMM Implementation ===
    def deepgemm_gemm():
        # Writes into the preallocated C_deepgemm buffer and returns it.
        deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
                                       (B_deepgemm, B_scale_deepgemm),
                                       C_deepgemm)
        return C_deepgemm

    # === vLLM Triton Implementation ===
    def vllm_triton_gemm():
        return w8a8_block_fp8_matmul(A_vllm,
                                     B_vllm,
                                     A_scale_vllm,
                                     B_scale_vllm,
                                     block_size,
                                     output_dtype=torch.bfloat16)

    # === vLLM CUTLASS Implementation ===
    def vllm_cutlass_gemm():
        return ops.cutlass_scaled_mm(A_vllm_cutlass,
                                     B_vllm.T,
                                     scale_a=A_scale_vllm_cutlass,
                                     scale_b=B_scale_vllm.T,
                                     out_dtype=torch.bfloat16)

    # Run correctness check first
    if verbose:
        print("Running correctness check...")
    C_deepgemm = deepgemm_gemm()
    C_vllm_triton = vllm_triton_gemm()
    C_vllm_cutlass = vllm_cutlass_gemm()

    deepgemm_diff = calc_diff(C_deepgemm, C_ref)
    vllm_triton_diff = calc_diff(C_vllm_triton, C_ref)
    vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref)

    if verbose:
        print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
        print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
        print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
        print("vLLM Triton vs DeepGEMM difference: "
              f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
        print("vLLM CUTLASS vs DeepGEMM difference: "
              f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")

    # Benchmark implementations
    implementations = {
        "DeepGEMM": deepgemm_gemm,
        "vLLM Triton": vllm_triton_gemm,
        "vLLM CUTLASS": vllm_cutlass_gemm
    }

    # Difference of each implementation against the BF16 reference.
    reference_diffs = {
        "DeepGEMM": deepgemm_diff,
        "vLLM Triton": vllm_triton_diff,
        "vLLM CUTLASS": vllm_cutlass_diff
    }

    benchmark_results = {
        "shape": {
            "m": m,
            "n": n,
            "k": k
        },
        "implementations": {}
    }

    for name, func in implementations.items():
        # Warmup
        for _ in range(warmup):
            func()
            torch.cuda.synchronize()

        # Timing loop: synchronize before starting the clock and again before
        # stopping it, so all queued GPU work falls inside the interval.
        # perf_counter() is monotonic and high-resolution, unlike time.time().
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(repeat):
            func()
        torch.cuda.synchronize()
        end = time.perf_counter()

        # Calculate timing and TFLOPS
        avg_time_ms = (end - start) / repeat * 1000
        avg_time_us = avg_time_ms * 1000
        tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12
        # Counts m*k + k*n + 2*m*n bytes: presumably one byte per FP8 input
        # element and two per BF16 output element; scale tensors are ignored.
        gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3)

        benchmark_results["implementations"][name] = {
            "time_ms": avg_time_ms,
            "time_us": avg_time_us,
            "tflops": tflops,
            "gb_s": gb_s,
            "diff": {
                "DeepGEMM":
                0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
                "Reference":
                reference_diffs[name]
            }
        }

        if verbose:
            print(
                f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
            )

    # Calculate speedups (a value > 1 means faster than DeepGEMM)
    baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
    for name, data in benchmark_results["implementations"].items():
        if name != "DeepGEMM":
            speedup = baseline / data["time_ms"]
            data["speedup_vs_deepgemm"] = speedup
            if verbose:
                print(f"DeepGEMM is {1/speedup:.2f}x "
                      f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")

    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
        "time_ms"]
    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
        "time_ms"]
    cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
    benchmark_results["implementations"]["vLLM CUTLASS"][
        "speedup_vs_triton"] = cutlass_vs_triton
    if verbose:
        print(
            f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
            f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
        )

    return benchmark_results


def format_table_row(values, widths):
    """Render one table row as ``| cell | cell | ... |``.

    Each value is formatted with its column width as the format spec, so the
    value's own default alignment applies. Values and widths are paired
    positionally with ``zip``; extras on either side are ignored.
    """
    cells = [
        format(value, f"{width}") for value, width in zip(values, widths)
    ]
    return "| " + " | ".join(cells) + " |"


def print_table(headers, rows, title=None):
    """Print a table with headers and rows.

    Args:
        headers: Column headers, one per column.
        rows: Sequence of rows; each row is indexable by column position.
            It is iterated more than once. May be empty, in which case only
            the header is printed between separator lines.
        title: Optional title printed on its own line above the table.
    """
    if title:
        print(f"\n{title}")

    # Each column is as wide as its longest cell or its header, whichever is
    # larger. ``default=0`` keeps an empty ``rows`` from raising ValueError
    # in the inner max().
    widths = [
        max(len(str(h)), max((len(str(row[i])) for row in rows), default=0))
        for i, h in enumerate(headers)
    ]

    # Create separator line
    separator = "+-" + "-+-".join("-" * w for w in widths) + "-+"

    # Print table
    print(separator)
    print(format_table_row(headers, widths))
    print(separator)
    for row in rows:
        print(format_table_row(row, widths))
    print(separator)


def format_speedup(value):
    """Return ``value`` as ``"<value>x faster"`` or ``"<value>x slower"``.

    The number is shown with two decimals. Only values strictly greater
    than 1.0 are labelled "faster"; exactly 1.0 is labelled "slower".
    """
    if value > 1.0:
        label = "faster"
    else:
        label = "slower"
    return f"{value:.2f}x {label}"


def _perf_row(result: dict, impl_name: str) -> list:
    """Return the shape and timing cells shared by every per-implementation table."""
    shape = result["shape"]
    impl_data = result["implementations"][impl_name]
    return [
        shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
        f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
    ]


def run_benchmarks(verbose: bool = False):
    """Run benchmarks for a set of common shapes.

    Benchmarks every (m, n, k) shape in a fixed list with
    ``benchmark_shape`` and prints one performance table per implementation,
    followed by tables of average performance, average speedups and average
    difference against the BF16 reference. Prints a message and returns
    early when CUDA is not available.

    Args:
        verbose: Forwarded to ``benchmark_shape`` to enable per-shape output.
    """
    print("===== STARTING FP8 GEMM BENCHMARK =====")

    # Make sure we're using the GPU
    if not torch.cuda.is_available():
        print("CUDA not available! Tests require GPU.")
        return

    # Print system information
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Triton version: {triton.__version__}")
    print(f"Using device: {torch.cuda.get_device_name()}")

    # Allow TF32 in matmul and cuDNN
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Set seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Define benchmark shapes (m, n, k)
    shapes = [
        # (64, 2112, 7168),
        (64, 24576, 1536),
        (64, 32768, 512),
        (64, 7168, 16384),
        (64, 4096, 7168),
        (64, 7168, 2048),
        # (128, 2112, 7168),
        (128, 24576, 1536),
        (128, 32768, 512),
        (128, 7168, 16384),
        (128, 4096, 7168),
        (128, 7168, 2048),
        # (4096, 2112, 7168),
        (4096, 24576, 1536),
        (4096, 32768, 512),
        (4096, 7168, 16384),
        (4096, 4096, 7168),
        (4096, 7168, 2048),
    ]

    all_results = [
        benchmark_shape(m, n, k, verbose=verbose) for m, n, k in shapes
    ]

    # Print results in a nicely formatted table
    print("\n===== PERFORMANCE COMPARISON =====")

    # Print DeepGEMM table
    deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"]
    deepgemm_rows = [_perf_row(result, "DeepGEMM") for result in all_results]

    print_table(deepgemm_headers,
                deepgemm_rows,
                title="DeepGEMM Implementation:")

    # Print vLLM Triton table
    triton_headers = [
        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
    ]
    triton_rows = []
    for result in all_results:
        impl_data = result["implementations"]["vLLM Triton"]
        speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
        triton_rows.append(
            _perf_row(result, "vLLM Triton") + [format_speedup(speedup)])

    print_table(triton_headers,
                triton_rows,
                title="vLLM Triton Implementation:")

    # Print vLLM CUTLASS table
    cutlass_headers = [
        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
        "vs Triton"
    ]
    cutlass_rows = []
    for result in all_results:
        impl_data = result["implementations"]["vLLM CUTLASS"]
        vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
        vs_triton = impl_data.get("speedup_vs_triton", 1.0)
        cutlass_rows.append(
            _perf_row(result, "vLLM CUTLASS") +
            [format_speedup(vs_deepgemm),
             format_speedup(vs_triton)])

    print_table(cutlass_headers,
                cutlass_rows,
                title="vLLM CUTLASS Implementation:")

    # Calculate and print averages (arithmetic mean over all shapes)
    print("\n===== AVERAGE PERFORMANCE =====")

    implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
    num_shapes = len(all_results)
    avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"]
    avg_rows = []

    for impl in implementations:
        per_shape = [result["implementations"][impl] for result in all_results]
        avg_tflops = sum(data["tflops"] for data in per_shape) / num_shapes
        avg_mem_bw = sum(data["gb_s"] for data in per_shape) / num_shapes
        avg_time = sum(data["time_ms"] for data in per_shape) / num_shapes
        avg_rows.append([
            impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
        ])

    print_table(avg_headers, avg_rows)

    # Calculate average speedups. Each ratio is (time of the second name) /
    # (time of the first name), so a value > 1 means the first is faster.
    avg_speedups = {
        "DeepGEMM vs vLLM Triton": 0,
        "DeepGEMM vs vLLM CUTLASS": 0,
        "vLLM CUTLASS vs vLLM Triton": 0
    }

    for result in all_results:
        timings = result["implementations"]
        deepgemm_time = timings["DeepGEMM"]["time_ms"]
        vllm_triton_time = timings["vLLM Triton"]["time_ms"]
        vllm_cutlass_time = timings["vLLM CUTLASS"]["time_ms"]

        avg_speedups[
            "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
        avg_speedups[
            "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
        avg_speedups[
            "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time

    print("\n===== AVERAGE SPEEDUPS =====")
    speedup_headers = ["Comparison", "Speedup"]
    speedup_rows = []
    for comparison, total in avg_speedups.items():
        avg_speedup = total / num_shapes
        status = "faster" if avg_speedup > 1 else "slower"
        speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"])

    print_table(speedup_headers, speedup_rows)

    # Average accuracy comparison
    print("\n===== ACCURACY COMPARISON =====")
    diff_headers = ["Implementation", "Avg Diff vs Reference"]
    diff_rows = []
    for impl in implementations:
        total_diff = sum(result["implementations"][impl]["diff"]["Reference"]
                         for result in all_results)
        diff_rows.append([impl, f"{total_diff / num_shapes:.6f}"])

    print_table(diff_headers, diff_rows)


# Script entry point: run the whole benchmark suite with per-shape verbose
# output disabled; only the summary tables are printed.
if __name__ == "__main__":
    run_benchmarks(verbose=False)
