# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.lightning_attn import (
    linear_decode_forward_triton)
from vllm.platforms import current_platform

# Parameter grids fed to the pytest.mark.parametrize decorators below.
# NUM_HEADS, HEAD_SIZES and DTYPES are used by every test in this file;
# BATCH_SIZES by the non-padding decode test and the full-sequence test;
# SEQ_LENGTHS only by the full-sequence lightning attention test.
NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]


def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
    """Sequential reference for the lightning attention recurrence.

    For every batch element and head the running KV state is advanced one
    token at a time,

        kv = exp(-ed[h]) * kv + outer(k_t, v_t)
        out_t = q_t @ kv

    The difference from the main implementation is that each step is
    processed sequentially instead of with parallelized Triton kernels.

    Args:
        q: Query tensor of shape [B, H, S, D].
        k: Key tensor of shape [B, H, S, D].
        v: Value tensor of shape [B, H, S, E].
        ed: Per-head decay rates. Any shape holding exactly H elements is
            accepted (for example [H], [H, 1, 1] or [1, H, 1, 1]).
        block_size: Unused by this reference; accepted so it can be called
            with the same arguments as the implementation under test.
        kv_history: Optional initial KV state of shape [B, H, D, E]. It is
            cloned and never modified. ``None`` starts from zeros.

    Returns:
        A tuple ``(output, kv)`` where ``output`` has shape [B, H, S, E]
        and ``kv`` has shape [B, H, 2, D, E]; both entries along dim 2
        hold the final KV state.

    Raises:
        ValueError: If ``ed`` does not hold exactly one value per head.
    """
    B, H, S, D = q.shape
    E = v.shape[-1]

    # Flatten to one decay factor per head so that every accepted layout
    # of `ed` is indexed the same way.
    decay = torch.exp(-ed).reshape(-1)
    if decay.numel() != H:
        raise ValueError(
            f"ed must hold one decay rate per head ({H}), "
            f"got shape {tuple(ed.shape)}")

    # Work on an independent copy so the caller's history stays untouched.
    if kv_history is None:
        kv_cache = torch.zeros((B, H, D, E), dtype=q.dtype, device=q.device)
    else:
        kv_cache = kv_history.clone()

    output = torch.zeros((B, H, S, E), dtype=q.dtype, device=q.device)

    for b in range(B):
        for step in range(S):
            for h in range(H):
                kv_outer = torch.outer(k[b, h, step], v[b, h, step])

                # Decay the previous state first, then add the new outer
                # product; the current token therefore sees itself undecayed.
                kv_cache[b, h] = decay[h] * kv_cache[b, h] + kv_outer

                output[b, h, step] = torch.matmul(q[b, h, step],
                                                  kv_cache[b, h])

    # Duplicate the final state along a new dim 2 to produce the
    # [B, H, 2, D, E] layout that the test compares against.
    kv_reshaped = kv_cache.unsqueeze(2)  # [B, H, 1, D, E]
    final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
                               dim=2)  # [B, H, 2, D, E]

    return output, final_kv_cache


def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
    """Reference implementation of one linear attention decode step.

    For every non-padding batch row ``b`` with cache slot ``s`` and every
    head ``h``:

        kv_new = outer(k, v) + exp(-slope_rate[h]) * kv_caches[s, h]
        out[b, h] = q @ kv_new
        kv_caches[s, h] = kv_new

    NOTE(review): the cache is addressed by slot id rather than by batch
    position on the assumption that the kernel under test does the same;
    confirm against linear_decode_forward_triton.

    Args:
        q: Query tensor of shape [B, H, 1, D].
        k: Key tensor of shape [B, H, 1, D].
        v: Value tensor of shape [B, H, 1, E].
        kv_caches: KV state indexed as [slot, H, D, E]. Updated in place.
        slope_rate: Per-head decay rates; any shape holding H elements.
        slot_idx: Integer tensor with one cache slot per batch row. A value
            of -1 marks a padding row, which is skipped.

    Returns:
        Tensor of shape [B, H * E]. Rows belonging to padding stay zero.
    """
    B, H, _, _ = q.shape
    E = v.shape[-1]
    output = torch.zeros(B, H * E, dtype=q.dtype, device=q.device)

    # One decay factor per head, computed once.
    decay = torch.exp(-slope_rate).reshape(-1)  # [H]

    # Single device-to-host transfer instead of one .item() per row.
    slots = slot_idx.tolist()

    for b in range(B):
        slot = slots[b]

        # Skip padding positions
        if slot == -1:
            continue

        for h in range(H):
            kv_outer = torch.outer(k[b, h, 0], v[b, h, 0])  # [D, E]

            # Apply decay to the state stored in this row's slot.
            kv_new = kv_outer + decay[h] * kv_caches[slot, h]

            output[b, h * E:(h + 1) * E] = torch.matmul(q[b, h, 0], kv_new)
            kv_caches[slot, h] = kv_new

    return output


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
    batch_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    """Check the Triton decode kernel against the PyTorch reference.

    Every batch row is mapped to the cache slot with the same index, so
    both the returned output and the in-place updated KV cache must agree
    with what reference_linear_decode produces on a cloned cache.
    """
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)
    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(batch_size,
                                   num_heads,
                                   head_size,
                                   head_size,
                                   dtype=dtype,
                                   device="cuda")

    # The kernel mutates kv_caches; the reference mutates this copy.
    kv_caches_copy = kv_caches.clone()

    slope_rate = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        slope_rate[h] = 0.1 * (h + 1)

    slot_idx = torch.arange(batch_size, device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)

    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)

    assert triton_output.shape == (batch_size, num_heads * head_size)

    # Inputs are scaled by `base` (0.01), so the tolerances have to sit well
    # below that scale; otherwise even an all-zero result would be accepted.
    # These values are sized for float32 and need relaxing if lower
    # precision dtypes are ever added to DTYPES.
    rtol, atol = 1e-3, 1e-6
    torch.testing.assert_close(triton_output,
                               reference_output,
                               rtol=rtol,
                               atol=atol)
    torch.testing.assert_close(kv_caches,
                               kv_caches_copy,
                               rtol=rtol,
                               atol=atol)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    """Compare the Triton decode kernel with the reference under padding.

    A fixed batch of four rows is used in which the third row has slot
    index -1. Only the rows whose slot index is not -1 are compared, both
    for the returned output and for the per-row KV cache entries.

    NOTE(review): the tolerance (0.15) is larger than the magnitude of the
    test data, which is scaled by 0.01, so these comparisons are very
    loose; consider tightening them.
    """
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    batch_size = 4
    scale = 0.01
    qkv_shape = (batch_size, num_heads, 1, head_size)
    cache_shape = (batch_size, num_heads, head_size, head_size)

    q = scale * torch.randn(qkv_shape, dtype=dtype)
    k = scale * torch.randn(qkv_shape, dtype=dtype)
    v = scale * torch.randn(qkv_shape, dtype=dtype)
    kv_caches = scale * torch.randn(cache_shape, dtype=dtype, device="cuda")

    # The reference works on its own copy of the cache.
    kv_caches_copy = kv_caches.clone()

    slope_rate = torch.tensor([0.1 * (h + 1) for h in range(num_heads)],
                              device="cuda")

    slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)
    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)

    # Rows with slot index -1 are excluded from every comparison.
    is_valid = slot_idx != -1
    flat_mask = is_valid.unsqueeze(1).expand(-1, num_heads * head_size)
    triton_valid = triton_output[flat_mask]
    reference_valid = reference_output[flat_mask]

    tolerance = 1.5e-1

    for row in range(batch_size):
        if not is_valid[row]:
            continue
        torch.testing.assert_close(kv_caches[row],
                                   kv_caches_copy[row],
                                   rtol=tolerance,
                                   atol=tolerance)

    torch.testing.assert_close(triton_valid,
                               reference_valid,
                               rtol=tolerance,
                               atol=tolerance)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
    batch_size: int,
    num_heads: int,
    head_size: int,
    seq_len: int,
    dtype: torch.dtype,
):
    """Compare lightning_attention with the sequential reference.

    Both implementations get the same q/k/v, the same per-head decay rates
    and equal (but separate) KV history tensors; outputs and returned KV
    caches are then compared.

    NOTE(review): the tolerance (0.15) is larger than the magnitude of the
    test data, which is scaled by 0.01, so these comparisons are very
    loose; consider tightening them.
    """
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    scale = 0.01
    qkv_shape = (batch_size, num_heads, seq_len, head_size)
    history_shape = (batch_size, num_heads, head_size, head_size)

    q = scale * torch.randn(qkv_shape, dtype=dtype)
    k = scale * torch.randn(qkv_shape, dtype=dtype)
    v = scale * torch.randn(qkv_shape, dtype=dtype)

    ed = torch.tensor([0.1 * (h + 1) for h in range(num_heads)],
                      device="cuda")

    kv_history = scale * torch.randn(
        history_shape, dtype=dtype, device="cuda")
    # Separate copy for the implementation under test.
    kv_history_clone = kv_history.clone()

    block_size = 256
    ref_output, ref_kv_cache = reference_lightning_attention(
        q, k, v, ed, block_size, kv_history)

    from vllm.model_executor.layers.lightning_attn import lightning_attention
    actual_output, actual_kv_cache = lightning_attention(
        q, k, v, ed, block_size, kv_history_clone)

    tolerance = 1.5e-1
    torch.testing.assert_close(ref_output,
                               actual_output,
                               rtol=tolerance,
                               atol=tolerance)
    torch.testing.assert_close(ref_kv_cache,
                               actual_kv_cache,
                               rtol=tolerance,
                               atol=tolerance)

    assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
    assert ref_kv_cache.shape == actual_kv_cache.shape
