import torch
from torch import nn
import triton
import triton.language as tl
import torch.nn.functional as F
from typing import Optional
from flash_attn import flash_attn_varlen_func
from wedlm.utils.context import get_context


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    if slot == -1:
        return
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key, value, k_cache, v_cache, slot_mapping):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    store_kvcache_kernel[(N,)](
        key,
        key.stride(0),
        value,
        value.stride(0),
        k_cache,
        v_cache,
        slot_mapping,
        D,
    )


class Attention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float,
        num_kv_heads: int,
        wedlm_window_size: Optional[int] = None,
        max_context_len: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])
        self.wedlm_window_size = wedlm_window_size
        self.max_context_len = max_context_len if max_context_len is not None else 4096

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        device = q.device

        # 1) Store KV Cache
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        # 2) Prefill
        if context.is_prefill:
            k_src, v_src = (
                (k_cache, v_cache) if context.block_tables is not None else (k, v)
            )
            return flash_attn_varlen_func(
                q,
                k_src,
                v_src,
                max_seqlen_q=context.max_seqlen_q,
                cu_seqlens_q=context.cu_seqlens_q,
                max_seqlen_k=context.max_seqlen_k,
                cu_seqlens_k=context.cu_seqlens_k,
                softmax_scale=self.scale,
                causal=True,
                block_table=context.block_tables,
            )

        if context.per_seq_wedlm_sizes is None:
            raise RuntimeError(
                "context.per_seq_wedlm_sizes is None inside Attention.forward (Decode mode)."
            )

        per_seq_wedlm_sizes = context.per_seq_wedlm_sizes

        cu_seqlens_q = F.pad(per_seq_wedlm_sizes.cumsum(0), (1, 0)).to(dtype=torch.int32)

        if context.max_seqlen_q > 0:
            max_seqlen_q = context.max_seqlen_q
        else:
            max_seqlen_q = torch.max(per_seq_wedlm_sizes).item()

        prefix_lens = context.context_lens
        k_lens = (prefix_lens + per_seq_wedlm_sizes).to(torch.int32)
        cu_seqlens_k = F.pad(k_lens.cumsum(dim=0), (1, 0)).to(torch.int32)

        return flash_attn_varlen_func(
            q,
            k_cache,
            v_cache,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_k=self.max_context_len,
            block_table=context.block_tables,
            softmax_scale=self.scale,
            causal=True,
        )
