from typing import Tuple

import deep_gemm
import tilelang
import tilelang.language as T
import torch
import triton
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)

from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul


# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
def tl_gemm(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    """Build a TileLang prim_func for an FP8 GEMM with 128-wide scaling blocks.

    The returned function takes ``A`` of shape (M, K) and ``B`` of shape
    (N, K) and writes ``C`` of shape (M, N). ``B`` is consumed with
    ``transpose_B=True``, so the product is ``A @ B^T``. For every 128-wide
    K tile the partial product is multiplied by
    ``scales_a[row, k] * scales_b[n_tile, k]`` and added to a second
    accumulator.

    Args:
        M: Number of rows of A and C; baked into the buffer shapes.
        N: Number of rows of B and columns of C; baked into the buffer shapes.
        K: Reduction dimension shared by A and B.
        in_dtype: Element type of A and B. Only "e4m3_float8" is accepted.
        out_dtype: Element type of C. "bfloat16" or "float16".
        accum_dtype: Element type of the two accumulator fragments. It is not
            validated here.

    Returns:
        The ``T.prim_func`` named ``main``; callers in this file pass it to
        ``tilelang.compile``.

    Raises:
        AssertionError: If ``in_dtype`` or ``out_dtype`` is not one of the
            supported names.
    """
    assert in_dtype in [
        "e4m3_float8",
    ], "Currently only e4m3_float8 is supported"
    assert out_dtype in [
        "bfloat16",
        "float16",
    ], "Currently only bfloat16 and float16 are supported"

    # All three tile extents are 128, the same granularity the cast helpers in
    # this file use when producing scales.
    TILE_SIZE = (128, 128, 128)
    block_M = TILE_SIZE[0]
    block_N = TILE_SIZE[1]
    block_K = TILE_SIZE[2]

    A_shape = (M, K)
    # One scale per row of A per K tile.
    Scales_A_shape = (M, T.ceildiv(K, block_K))
    B_shape = (N, K)
    # One scale per (N tile, K tile) of B.
    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (block_M, block_N)

    @T.prim_func
    def main(
        A: T.Buffer(A_shape, in_dtype),
        scales_a: T.Buffer(Scales_A_shape, "float32"),
        B: T.Buffer(B_shape, in_dtype),
        scales_b: T.Buffer(Scales_B_shape, "float32"),
        C: T.Buffer((M, N), out_dtype),
    ):
        # bx walks the N tiles and by walks the M tiles (see the indexing of
        # A, B and C below).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
            # NOTE: (block_M) is a plain int, not a 1-tuple; one combined
            # scale is stored per row of the tile.
            Scale_C_shared = T.alloc_shared((block_M), "float32")
            # C_local receives the product of the current K tile only;
            # C_local_accum holds the scaled running sum across K tiles.
            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                # Load A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Load B into shared memory
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Load scale into shared memory
                # The B scale is a single value for this (N tile, K tile); it
                # is folded into the per-row A scale so that one multiply per
                # element is enough below.
                Scale_B = scales_b[bx, k]
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote to enable 2xAcc
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                # Reset the per-tile accumulator so the next K tile starts
                # from zero.
                T.clear(C_local)
            # TMA store
            # The accumulator is staged through C_shared (of out_dtype) before
            # being written to the output tile.
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n
    ), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per 128x128 block.

    The input is zero-padded up to multiples of 128 in both dimensions so that
    it can be viewed as a grid of 128x128 blocks; the padding is cut off again
    before the quantized data is returned.

    Args:
        x: 2-D tensor of any size.

    Returns:
        A pair ``(x_fp8, scales)``: ``x_fp8`` is a contiguous tensor with the
        shape of ``x`` and dtype ``torch.float8_e4m3fn``; ``scales`` has one
        entry per block, shape ``(ceil(rows / 128), ceil(cols / 128))``, and
        holds ``amax / 448`` of that block.

    Raises:
        AssertionError: If ``x`` is not 2-D.
    """
    assert x.dim() == 2
    rows, cols = x.shape

    padded_rows = ceil_div(rows, 128) * 128
    padded_cols = ceil_div(cols, 128) * 128
    padded = torch.zeros((padded_rows, padded_cols), dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x

    # Axes: (row block, row within block, column block, column within block).
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)

    # Largest magnitude per block, floored at 1e-4 to avoid dividing by zero.
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / block_amax)).to(torch.float8_e4m3fn)

    data = quantized.view_as(padded)[:rows, :cols].contiguous()
    scales = (block_amax / 448.0).view(blocks.size(0), blocks.size(2))
    return data, scales


def fp8_gemm_deepgemm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """DeepGEMM implementation of FP8 GEMM.

    Args:
        x_fp8: FP8 left-hand operand.
        x_scale: Scales for ``x_fp8``; callers in this file pass the output of
            ``get_col_major_tma_aligned_tensor``.
        y_fp8: FP8 right-hand operand.
        y_scale: Scales for ``y_fp8``.
        m: Number of output rows.
        n: Number of output columns.
        k: Reduction size. Unused here; kept so that all providers share one
            signature.

    Returns:
        A new ``(m, n)`` bfloat16 tensor filled in by DeepGEMM.
    """
    # Allocate the output next to the inputs rather than on the hard-coded
    # "cuda" device, which resolves to the current default GPU and may differ
    # from the device the operands live on.
    out = torch.empty((m, n), device=x_fp8.device, dtype=torch.bfloat16)

    # Run DeepGEMM kernel; it writes its result into `out`.
    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
    return out


def fp8_gemm_sglang(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """SGLang implementation of FP8 GEMM.

    ``m``, ``n`` and ``k`` are not used; they are accepted so that every
    provider wrapper in this file has the same signature.
    """
    # 128x128 blocks, matching the block size in per_block_cast_to_fp8.
    # The result is requested in bfloat16.
    return w8a8_block_fp8_matmul(
        x_fp8, y_fp8, x_scale, y_scale, [128, 128], torch.bfloat16
    )


def fp8_gemm_vllm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """vLLM implementation of FP8 GEMM.

    ``m``, ``n`` and ``k`` are not used; they are accepted so that every
    provider wrapper in this file has the same signature.
    """
    # 128x128 blocks, matching the block size in per_block_cast_to_fp8.
    # The result is requested in bfloat16.
    return vllm_w8a8_block_fp8_matmul(
        x_fp8, y_fp8, x_scale, y_scale, [128, 128], torch.bfloat16
    )


def calculate_diff(m: int, n: int, k: int):
    """Run all three FP8 GEMM providers on one random problem and report agreement.

    Random bfloat16 operands of shape (m, k) and (n, k) are quantized once,
    fed to DeepGEMM, SGLang and a freshly compiled TileLang kernel, and the
    outputs are compared pairwise. The function prints the first five values
    of row 0 of each output, the pairwise mean absolute differences, and a
    pass/fail summary based on ``torch.allclose`` with ``atol=rtol=1e-2``.
    Nothing is returned.
    """

    def mean_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
        return torch.abs(a - b).mean().item()

    def is_close(a: torch.Tensor, b: torch.Tensor) -> bool:
        return torch.allclose(a, b, atol=1e-2, rtol=1e-2)

    def mark(ok: bool) -> str:
        return "✅" if ok else "❌"

    activations = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    weights = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

    # Quantize once; every provider receives its own copy of the same data.
    x_fp8, x_scale = per_token_cast_to_fp8(activations.clone())
    y_fp8, y_scale = per_block_cast_to_fp8(weights.clone())
    x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())

    out_deepgemm = fp8_gemm_deepgemm(
        x_fp8.clone(),
        x_scale_col_major.clone(),
        y_fp8.clone(),
        y_scale.clone(),
        m,
        n,
        k,
    )
    out_sglang = fp8_gemm_sglang(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
    )

    # The TileLang kernel is specialized for this exact shape; out_idx=[-1]
    # selects the last buffer argument (C) as the returned output.
    kernel = tilelang.compile(
        tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32"), out_idx=[-1]
    )
    out_tilelang = kernel(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
    )

    diff_sglang_deepgemm = mean_abs_diff(out_deepgemm, out_sglang)
    diff_tilelang_deepgemm = mean_abs_diff(out_deepgemm, out_tilelang)
    diff_tilelang_sglang = mean_abs_diff(out_tilelang, out_sglang)

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
    print(f"SGLang output: {out_sglang[0, 0:5]}")
    print(f"TileLang output: {out_tilelang[0, 0:5]}")
    print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
    print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
    print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")

    sglang_deepgemm_match = is_close(out_deepgemm, out_sglang)
    tilelang_deepgemm_match = is_close(out_deepgemm, out_tilelang)
    tilelang_sglang_match = is_close(out_tilelang, out_sglang)

    if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
        print("✅ All implementations match\n")
        return

    print("❌ Some implementations differ:")
    print(f"  - SGLang vs DeepGEMM: {mark(sglang_deepgemm_match)}")
    print(f"  - TileLang vs DeepGEMM: {mark(tilelang_deepgemm_match)}")
    print(f"  - TileLang vs SGLang: {mark(tilelang_sglang_match)}\n")


def get_weight_shapes(tp_size):
    """Return the (n, k) weight shapes to benchmark for a tensor-parallel size.

    The result is the concatenation of three groups, in this order: shapes
    used as-is, shapes whose first dimension is floor-divided by ``tp_size``,
    and shapes whose second dimension is floor-divided by ``tp_size``.

    Args:
        tp_size: Tensor-parallel degree used as the divisor.

    Returns:
        A list of ``(n, k)`` tuples.
    """
    # cannot TP
    unsplit = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    split_on_n = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    split_on_k = [(7168, 18432), (7168, 16384), (7168, 2048)]

    return (
        list(unsplit)
        + [(n // tp_size, k) for n, k in split_on_n]
        + [(n, k // tp_size) for n, k in split_on_k]
    )


def create_benchmark_configs(tp_size):
    """List every (m, n, k, tp_size) combination to benchmark.

    Each weight shape from ``get_weight_shapes(tp_size)`` is paired with every
    batch size; batch sizes vary fastest.
    """
    batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
    return [
        (m, n, k, tp_size)
        for n, k in get_weight_shapes(tp_size)
        for m in batch_sizes
    ]


def get_benchmark(tp_size):
    """Build the Triton perf-report benchmark for one tensor-parallel size.

    Args:
        tp_size: Tensor-parallel degree; selects the shapes via
            ``create_benchmark_configs`` and is part of the plot name.

    Returns:
        The decorated ``benchmark`` callable; call its ``run`` method to
        execute all configurations.
    """
    all_configs = create_benchmark_configs(tp_size)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k", "tp_size"],
            x_vals=[list(config) for config in all_configs],
            line_arg="provider",
            line_vals=["deepgemm", "sglang", "tilelang"],
            line_names=["DeepGEMM", "SGLang", "TileLang"],
            styles=[("blue", "-"), ("red", "-"), ("green", "-")],
            # The reported values are microseconds (see the return below).
            ylabel="us",
            plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(m, n, k, tp_size, provider):
        """Time one provider on one (m, n, k) problem.

        Returns the median, 20th-percentile and 80th-percentile latency in
        microseconds, in the (value, min, max) order perf_report expects.
        """
        print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
        x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
        y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

        # Preprocess data before benchmarking
        x_fp8, x_scale = per_token_cast_to_fp8(x)
        y_fp8, y_scale = per_block_cast_to_fp8(y)
        x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())

        # do_bench returns one time in milliseconds per requested quantile, in
        # the order given here: median, 20th percentile, 80th percentile.
        quantiles = [0.5, 0.2, 0.8]

        # The timed callables take the prepared tensors directly. Cloning the
        # operands inside the timed region would add device-to-device copies
        # (hundreds of MB for the largest weights) to every measurement and
        # deflate the TFLOPS figure; the operands are only read by the GEMMs.
        if provider == "deepgemm":

            def run():
                return fp8_gemm_deepgemm(
                    x_fp8, x_scale_col_major, y_fp8, y_scale, m, n, k
                )

        elif provider == "sglang":

            def run():
                return fp8_gemm_sglang(x_fp8, x_scale, y_fp8, y_scale, m, n, k)

        else:  # tilelang
            # Compilation is shape-specific and happens outside the timed region.
            tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
            tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])

            def run():
                return tilelang_kernel(x_fp8, x_scale, y_fp8, y_scale)

        ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

        # Calculate TFLOPS
        flops = 2 * m * n * k  # multiply-adds
        tflops = flops / (ms * 1e-3) / 1e12

        # Print shape-specific results with TFLOPS
        print(f"Time: {ms * 1000:.2f} us, TFLOPS: {tflops:.2f}")
        # Convert ms to us. perf_report unpacks (value, min, max), so the
        # lower quantile must come before the upper one.
        return ms * 1000, min_ms * 1000, max_ms * 1000

    return benchmark


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/fp8_gemm/",
        help="Path to save fp8 gemm benchmark results",
    )
    # The correctness test is on by default. --run_correctness is kept so that
    # existing command lines keep working, but on its own it could never turn
    # the test off; --skip_correctness writes False to the same destination.
    parser.add_argument(
        "--run_correctness",
        action="store_true",
        default=True,
        help="Whether to run correctness test (enabled by default)",
    )
    parser.add_argument(
        "--skip_correctness",
        action="store_false",
        dest="run_correctness",
        help="Skip the correctness test and go straight to the benchmark",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=1,
        help="Tensor parallelism size to benchmark (default: 1)",
    )
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Run correctness tests on a few examples
    if args.run_correctness:
        print("Running correctness tests...")
        calculate_diff(64, 512, 7168)  # Small test
        calculate_diff(64, 7168, 16384)  # Medium test
        calculate_diff(64, 18432, 7168)  # Large test

    # Get the benchmark function with the specified tp_size
    benchmark = get_benchmark(args.tp_size)

    print(f"Running performance benchmark for TP size = {args.tp_size}...")
    benchmark.run(print_data=True, save_path=args.save_path)
