import itertools
import math

import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode


def next_power_of_2(n):
    """Return the smallest power of two that is greater than or equal to ``n``.

    The result is computed with exact integer arithmetic. The previous
    ``math.log(n, 2)`` formulation is a floating-point quotient that can land
    slightly above an integer for exact powers of two (for example ``2**29``),
    which made ``ceil`` round up to the *next* power.

    Args:
        n: A positive number (in this file, a tensor dimension).

    Returns:
        An ``int`` power of two ``p`` such that ``p >= n``; ``1`` for any
        ``0 < n <= 1``.

    Raises:
        ValueError: If ``n`` is not strictly positive.
    """
    if n <= 0:
        raise ValueError("next_power_of_2 requires a positive value")
    # ceil() keeps non-integer inputs rounding upwards; for ints it is a no-op.
    # (m - 1).bit_length() is the exponent of the smallest power of two >= m.
    return 1 << (math.ceil(n) - 1).bit_length()


@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    # Single decode step of lightning attention for one (batch, head) pair.
    #
    # Per program, with ratio = exp(-S[head]):
    #     KV  <- ratio * KV + outer(k, v)      (written back to KV in place)
    #     Out <- sum_d q[d] * KV[d, :]
    #
    # Q, K        : pointers read at offset off_bh * n * d, length d
    # V, Out      : pointers accessed at offset off_bh * n * e, length e
    # KV          : pointer accessed at offset off_bh * d * e as a (d, e) tile
    # S           : pointer indexed by head (one scalar per head)
    # b           : not referenced in the kernel body
    # h           : number of heads, used to recover the head index
    # n           : only used as a factor in the Q/K/V/Out offsets
    # d, e        : extents used for tl.arange and for all strides; the caller
    #               in this file passes sizes padded to a power of two
    # d_original,
    # e_original  : valid (unpadded) extents; lanes at or beyond them are masked
    #
    # NOTE(review): every size is a tl.constexpr, so a launch with a new value
    # (e.g. a new batch size) presumably triggers a fresh specialization —
    # confirm against the Triton JIT documentation.

    # Program id along axis 0 enumerates flattened (batch, head) pairs.
    off_bh = tl.program_id(0)
    off_h = off_bh % h

    # Element offsets of this program's rows/tile inside each buffer.
    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    kv_offset = off_bh * d * e

    # Per-head decay factor applied to the previous KV state.
    s = tl.load(S + off_h)
    ratio = tl.exp(-s)

    d_idx = tl.arange(0, d)
    e_idx = tl.arange(0, e)

    # Create masks for original dimensions
    d_mask = d_idx < d_original
    e_mask = e_idx < e_original

    # Load with masking
    # Masked-off lanes read as 0.0, so padding never contributes to the result.
    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)

    # Load KV with 2D masking
    # The tile is addressed row-major with row stride e (the padded extent).
    kv = tl.load(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        mask=(d_mask[:, None] & e_mask[None, :]),
        other=0.0,
    )

    # Compute outer product using element-wise operations
    k_v_prod = k[:, None] * v[None, :]
    # Decay the old state and accumulate the new key/value contribution.
    kv = ratio * kv + k_v_prod

    # Store KV with 2D masking
    # Written back to the same addresses it was loaded from (in-place update).
    tl.store(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        kv.to(KV.dtype.element_ty),
        mask=(d_mask[:, None] & e_mask[None, :]),
    )

    # Compute matrix-vector multiplication using element-wise operations and reduction
    # Reduces over the d axis (axis 0), leaving one value per e lane.
    o = tl.sum(q[:, None] * kv, axis=0)

    # Store output with masking
    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)


def triton_lightning_attn_decode(q, k, v, kv, s):
    """Run one lightning-attention decode step through the Triton kernel.

    The inputs are copied into scratch buffers whose trailing dimensions are
    rounded up to a power of two, the kernel is launched with one program per
    (batch, head) pair, and the results are sliced back to the original sizes.

    Args:
        q: 4-D tensor; its shape is unpacked as (b, h, n, d) and n must be 1.
        k: Tensor copied into a buffer shaped like the padded ``q``.
        v: Tensor whose last dimension ``e`` sets the value/output width.
        kv: Previous state, copied into a float32 (b, h, d_pad, e_pad) buffer.
        s: Per-head slope tensor passed to the kernel unchanged.

    Returns:
        Tuple ``(o, kv_out)``: views into the padded output buffer and the
        padded float32 state buffer, restricted to the unpadded extents.
    """
    batch, heads, seq_len, qk_dim = q.shape
    v_dim = v.shape[-1]
    assert seq_len == 1, "Sequence length must be 1 in decode mode"

    # The kernel indexes with power-of-two extents; the true sizes are passed
    # separately so it can mask the padding.
    qk_pad = next_power_of_2(qk_dim)
    v_pad = next_power_of_2(v_dim)

    def _padded_copy(src, padded_width, valid_width):
        # Uninitialised buffer; only the leading `valid_width` lanes are filled.
        buf = torch.empty(
            batch, heads, seq_len, padded_width, dtype=src.dtype, device=src.device
        )
        buf[..., :valid_width] = src
        return buf

    q_buf = _padded_copy(q, qk_pad, qk_dim)
    k_buf = _padded_copy(k, qk_pad, qk_dim)
    v_buf = _padded_copy(v, v_pad, v_dim)

    # The running state is always held in float32, whatever dtype `kv` has.
    state_buf = torch.empty(
        batch, heads, qk_pad, v_pad, dtype=torch.float32, device=kv.device
    )
    state_buf[..., :qk_dim, :v_dim] = kv

    out_buf = torch.empty(
        batch, heads, seq_len, v_pad, dtype=v.dtype, device=v.device
    )

    launch_grid = (batch * heads, 1)
    _decode_kernel[launch_grid](
        q_buf,
        k_buf,
        v_buf,
        state_buf,
        out_buf,
        s,
        b=batch,
        h=heads,
        n=seq_len,
        d=qk_pad,
        d_original=qk_dim,
        e=v_pad,
        e_original=v_dim,
    )

    # Drop the padding again before handing results back.
    return out_buf[..., :v_dim], state_buf[..., :qk_dim, :v_dim]


def lightning_attention_decode_naive(q, k, v, past_kv, slope):
    """Reference lightning-attention decode written with plain PyTorch ops.

    For every position ``t`` along dim 2 of ``q`` the state is updated as
    ``state = exp(-slope) * state + k_t^T v_t`` and the output for that
    position is ``q_t @ state`` computed in float32.

    Args:
        q: 4-D tensor; dim 2 is iterated over.
        k: Tensor sliced along dim 2 in lockstep with ``q``.
        v: Tensor sliced along dim 2 in lockstep with ``q``.
        past_kv: Initial state tensor.
        slope: Decay slopes; ``exp(-slope)`` is broadcast against the state.

    Returns:
        Tuple ``(output, state)``: per-position outputs concatenated along
        dim -2 and cast to ``q``'s dtype, and the final state.
    """
    out_dtype = q.dtype
    decay = torch.exp(-slope)  # [h, 1, 1]
    _, _, num_steps, _ = q.shape

    state = past_kv
    step_outputs = []
    for t in range(num_steps):
        step = slice(t, t + 1)
        # Outer product of the current key and value rows.
        outer_kv = torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, step],
            v[:, :, step],
        )
        state = decay * state.to(torch.float32) + outer_kv
        step_outputs.append(
            torch.einsum(
                "... n e, ... e d -> ... n d",
                q[:, :, step].to(torch.float32),
                state.to(torch.float32),
            )
        )

    return torch.cat(step_outputs, dim=-2).to(out_dtype), state


def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
    """Forward all arguments, unchanged and in order, to the sgl_kernel op.

    Callers in this file pass pre-allocated ``output`` and ``new_kv`` tensors
    and read the results from them after the call.

    Returns:
        Whatever ``lightning_attention_decode`` returns.
    """
    result = lightning_attention_decode(
        q,
        k,
        v,
        past_kv,
        slope,
        output,
        new_kv,
    )
    return result


def calculate_diff(batch_size):
    """Compare the naive, sgl_kernel and Triton decode paths on random data.

    Random bfloat16 q/k/v and float32 state/slope tensors are created on the
    CUDA device, each implementation is run on its own clones of them, and a
    single line is printed saying whether every output and updated state
    agrees with the naive reference within ``atol=rtol=1e-2``.

    Args:
        batch_size: Leading dimension of the generated tensors.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16
    num_heads, head_dim, seq_len = 64, 96, 1

    # Drawn in q, k, v order so the random stream is consumed deterministically.
    qkv_shape = (batch_size, num_heads, seq_len, head_dim)
    q, k, v = (
        torch.randn(*qkv_shape, device=device, dtype=dtype) for _ in range(3)
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    def fresh_inputs():
        # Each implementation gets private copies of every input.
        return tuple(t.clone() for t in (q, k, v, past_kv, slope))

    naive_out, naive_kv = lightning_attention_decode_naive(*fresh_inputs())

    # The sgl_kernel path writes into caller-provided buffers.
    kernel_out = torch.empty_like(naive_out)
    kernel_kv = torch.empty_like(naive_kv)
    lightning_attention_decode_kernel(*fresh_inputs(), kernel_out, kernel_kv)

    triton_out, triton_kv = triton_lightning_attn_decode(*fresh_inputs())

    comparisons = (
        (naive_out, kernel_out),
        (naive_out, triton_out),
        (naive_kv, kernel_kv),
        (naive_kv, triton_kv),
    )
    everything_matches = all(
        torch.allclose(reference, candidate, atol=1e-2, rtol=1e-2)
        for reference, candidate in comparisons
    )

    print(
        "✅ All implementations match"
        if everything_matches
        else "❌ Implementations differ"
    )


# Batch sizes swept by the benchmark: every integer from 1 to 64 inclusive.
batch_size_range = list(range(1, 65))
# One single-element tuple per benchmark configuration.
configs = [(bs,) for bs in batch_size_range]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel", "triton"],
        line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="lightning-attention-decode-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time one decode step of the selected implementation.

    Args:
        batch_size: Leading dimension of the generated inputs.
        provider: ``"naive"``, ``"kernel"`` or ``"triton"``.

    Returns:
        Tuple ``(median, low, high)`` in microseconds, where ``low`` and
        ``high`` are the 0.2 and 0.8 quantiles. The order matches what
        ``perf_report`` unpacks as ``(y, y_min, y_max)``.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    # Median first, then the lower and upper quantiles.
    quantiles = [0.5, 0.2, 0.8]

    # Every provider is timed on fresh clones of the inputs, as before.
    if provider == "naive":

        def run():
            return lightning_attention_decode_naive(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            )

    elif provider == "kernel":
        # Output buffers are allocated once, outside the timed region.
        output = torch.empty(
            batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
        )
        new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)

        def run():
            return lightning_attention_decode_kernel(
                q.clone(),
                k.clone(),
                v.clone(),
                past_kv.clone(),
                slope.clone(),
                output,
                new_kv,
            )

    elif provider == "triton":

        def run():
            return triton_lightning_attn_decode(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            )

    else:
        # Previously this fell through to an UnboundLocalError on `ms`.
        raise ValueError(f"Unknown provider: {provider!r}")

    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

    # Times (not throughputs) are reported, so the lower quantile is the
    # minimum and the upper quantile the maximum; convert ms -> us.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
        help="Path to save lightning attention decode benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(batch_size=4)

    # Make sure the target directory exists before results are written to it;
    # an empty path means "do not save".
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)

    # Run performance benchmark and honour --save_path, which was previously
    # parsed but never used.
    benchmark.run(print_data=True, save_path=args.save_path)
