import torch
from torch import nn
import triton
import triton.language as tl

from jetengine_ext.utils.context import get_context
from jetengine_ext.engine.sequence import RunType
from jetengine_ext.kernels.triton.attention import sparse_attn_varlen
from jetengine_ext.kernels.triton.attention import fused_kv_cache_attention



@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    """Copy one token's key row and value row into the cache.

    The program instance with id ``i`` on axis 0 reads ``D`` consecutive
    elements starting at ``i * key_stride`` (resp. ``i * value_stride``)
    and writes them to elements ``[slot * D, slot * D + D)`` of the key
    and value caches, where ``slot`` is entry ``i`` of the slot mapping.

    Args:
        key_ptr: Base pointer of the key buffer.
        key_stride: Element stride between consecutive tokens of the key buffer.
        value_ptr: Base pointer of the value buffer.
        value_stride: Element stride between consecutive tokens of the
            value buffer.
        k_cache_ptr: Base pointer of the key cache.
        v_cache_ptr: Base pointer of the value cache.
        slot_mapping_ptr: Base pointer of the per-token slot indices.
        D: Compile-time number of elements copied per token. Triton's
            ``tl.arange`` requires a power-of-two extent.

    All loads and stores are unmasked and the slot value is used as is.
    NOTE(review): nothing here rejects a negative or out-of-range slot, and
    ``slot * D`` is computed in the slot mapping's integer type -- confirm
    that callers only pass valid slots and that the product cannot overflow.
    """
    token = tl.program_id(0)
    lane = tl.arange(0, D)

    # Source rows: D contiguous elements per token in each input buffer.
    src_k = token * key_stride + lane
    src_v = token * value_stride + lane
    k_row = tl.load(key_ptr + src_k)
    v_row = tl.load(value_ptr + src_v)

    # Destination: the slot's row, shared by both caches.
    slot = tl.load(slot_mapping_ptr + token)
    dst = slot * D + lane
    tl.store(k_cache_ptr + dst, k_row)
    tl.store(v_cache_ptr + dst, v_row)


def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Write every token's key/value row into the KV cache.

    Launches ``store_kvcache_kernel`` with one program per token.

    Args:
        key: Tensor of shape ``(N, num_heads, head_dim)`` whose last two
            dimensions are contiguous.
        value: Tensor with the same stride requirements as ``key``.
        k_cache: Key cache; ``stride(1)`` must equal ``num_heads * head_dim``.
        v_cache: Value cache with the same stride requirement as ``k_cache``.
        slot_mapping: One destination slot index per token (``N`` entries).

    Raises:
        AssertionError: If any of the layout requirements above is violated.
    """
    num_tokens, n_heads, dim_per_head = key.shape
    row_width = n_heads * dim_per_head

    # The kernel copies each token as one flat run of `row_width` elements,
    # so the (heads, head_dim) dimensions of the inputs must be packed.
    assert all(t.stride(-1) == 1 for t in (key, value))
    assert all(t.stride(1) == dim_per_head for t in (key, value))
    # Each cache slot must likewise span exactly `row_width` elements.
    assert all(c.stride(1) == row_width for c in (k_cache, v_cache))
    assert slot_mapping.numel() == num_tokens

    grid = (num_tokens,)
    store_kvcache_kernel[grid](
        key,
        key.stride(0),
        value,
        value.stride(0),
        k_cache,
        v_cache,
        slot_mapping,
        row_width,
    )
    
def store_part_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    is_last_denoise_step: torch.Tensor,
):
    """Write key/value rows to the cache after scaling them per sequence.

    ``is_last_denoise_step`` holds one value per sequence. It is expanded to
    one value per token using the sequence lengths derived from
    ``cu_seqlens_q`` and multiplied into ``key`` and ``value``. Every token
    is then written to its slot, so with a 0/1 flag the rows of sequences
    flagged 0 are stored as zeros rather than skipped.

    Args:
        key: Tensor of shape ``(N, num_heads, head_dim)`` whose last two
            dimensions are contiguous.
        value: Tensor with the same stride requirements as ``key``.
        k_cache: Key cache; ``stride(1)`` must equal ``num_heads * head_dim``.
        v_cache: Value cache with the same stride requirement as ``k_cache``.
        slot_mapping: One destination slot index per token (``N`` entries).
        cu_seqlens_q: Cumulative query lengths; consecutive differences give
            the number of tokens per sequence.
        is_last_denoise_step: Per-sequence multiplier (presumably 0/1 --
            verify against caller).

    Raises:
        AssertionError: If any of the layout requirements above is violated.
    """
    num_tokens, n_heads, dim_per_head = key.shape
    row_width = n_heads * dim_per_head

    # Layout checks are made on the incoming tensors, before masking.
    assert all(t.stride(-1) == 1 for t in (key, value))
    assert all(t.stride(1) == dim_per_head for t in (key, value))
    assert all(c.stride(1) == row_width for c in (k_cache, v_cache))
    assert slot_mapping.numel() == num_tokens

    # Expand the per-sequence flag to a per-token column that broadcasts
    # over the (heads, head_dim) dimensions.
    tokens_per_seq = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    per_token_flag = torch.repeat_interleave(is_last_denoise_step, tokens_per_seq)
    store_mask = per_token_flag[:, None, None]

    masked_key = store_mask * key
    masked_value = store_mask * value

    # The kernel itself does no masking: it stores whatever rows it is given.
    grid = (num_tokens,)
    store_kvcache_kernel[grid](
        masked_key,
        masked_key.stride(0),
        masked_value,
        masked_value.stride(0),
        k_cache,
        v_cache,
        slot_mapping,
        row_width,
    )


class Attention(nn.Module):
    """Base attention layer: stores head geometry and KV-cache handles.

    ``k_cache`` and ``v_cache`` start out as the same empty tensor; code
    in this file treats a cache with ``numel() == 0`` as "not allocated".
    """

    def __init__(self, num_heads, head_dim, scale, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # One shared empty placeholder serves both caches until real
        # cache tensors are assigned to these attributes.
        placeholder = torch.tensor([])
        self.k_cache = placeholder
        self.v_cache = placeholder

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Do nothing and return ``None``; subclasses provide the computation."""
        return None

class BlockAttention(Attention):
    """Attention layer that dispatches on the current run type.

    For ``RunType.PREFILL`` it calls ``sparse_attn_varlen`` on the fresh
    q/k/v; for any other run type it calls ``fused_kv_cache_attention``,
    which additionally receives the KV cache and block tables.
    """

    def __init__(self, num_heads, head_dim, scale, num_kv_heads):
        super().__init__(num_heads, head_dim, scale, num_kv_heads)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Update the KV cache as required and compute attention.

        Args:
            q: Query tensor, viewed as ``(-1, num_heads, head_dim)``.
            k: Key tensor, viewed as ``(-1, num_kv_heads, head_dim)``.
            v: Value tensor, viewed as ``(-1, num_kv_heads, head_dim)``.

        Returns:
            The attention output viewed as ``(-1, num_heads * head_dim)``.
        """
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)

        ctx = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        is_prefill = ctx.run_type == RunType.PREFILL

        # Prefill: every token's K/V goes into the cache (if one is allocated).
        if is_prefill and k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, ctx.slot_mapping)

        # If at least one sequence is on its last denoise step, store again
        # with the per-sequence flags applied by store_part_kvcache.
        # NOTE(review): this check is independent of the prefill one above,
        # so both stores run when run_type is PREFILL and any flag is set;
        # the second store rewrites the same slots with flag-scaled rows.
        # Confirm that combination cannot occur or is intended.
        last_step_flags = ctx.is_last_denoise_step
        if any(last_step_flags) and k_cache.numel() and v_cache.numel():
            flag_tensor = torch.tensor(last_step_flags).int().to(k.device)
            store_part_kvcache(
                k,
                v,
                k_cache,
                v_cache,
                ctx.slot_mapping,
                ctx.cu_seqlens_q,
                flag_tensor,
            )

        if is_prefill:
            out = sparse_attn_varlen(
                q,
                k,
                v,
                cu_seqlens_q=ctx.cu_seqlens_q,
                cu_seqlens_k=ctx.cu_seqlens_k,
                staircase_size=ctx.block_length,
            )
        else:
            out = fused_kv_cache_attention(
                q,
                k,
                v,
                k_cache,
                v_cache,
                block_tables=ctx.block_tables,
                cu_seqlens_q=ctx.cu_seqlens_q,
                cu_seqlens_k=ctx.cu_seqlens_k,
                block_len=ctx.block_length,
            )

        # Merge the head dimensions back into one feature dimension.
        return out.view(-1, self.num_heads * self.head_dim)

        
