from hip.models.hip_attention.attention2_draft_prefetch import (
    block_sparse_attention, 
    HiPAttentionArgs, 
    load_checkouts
)
import torch

seq_len = 131072
head_dim = 128
num_heads = 32
num_heads_kv = 8
bsz = 2

_, _, _, _, cos, sin = load_checkouts(
    idx=0, 
    window=40, 
    seq_len=seq_len, 
    return_cos_sin=True, 
    derope=True,
    dtype=torch.bfloat16
)

def test_index(target_idx):
    device = 0
    dtype = torch.bfloat16
    q = torch.zeros((bsz, 1, num_heads, head_dim), device=device, dtype=dtype)
    k = torch.zeros((bsz, seq_len, num_heads, head_dim), device=device, dtype=dtype)
    v = torch.zeros((bsz, seq_len, num_heads, head_dim), device=device, dtype=dtype)
    seq_lens = torch.tensor([[seq_len], [seq_len]], dtype=torch.int64, device=device)
    
    args = HiPAttentionArgs(
        mask_k=512,
        block_size_q=1,
        block_stride_q=1,
        block_size_k=8,
        block_stride_k=1,
        sink_token_size=1024,
        sliding_window_size=1024,
        rope_cos=cos,
        rope_sin=sin,
    )
    
    indices = torch.zeros((bsz * num_heads, 1, 1), device=device, dtype=torch.int32)
    indices.fill_(target_idx // 8 * 8)
    ks = torch.zeros((bsz * num_heads, 1), device=device, dtype=torch.int32)
    ks.fill_(1)
    ks_count = ks.unsqueeze(-1)
    ks_start_end = torch.cat([torch.zeros_like(ks.unsqueeze(-1)), ks.unsqueeze(-1)], dim=-1)
    
    q[:, :, :, :] = (torch.arange(0, head_dim, device=device, dtype=dtype) / head_dim)[None, None, None, :]
    k[:, target_idx:target_idx+1, :, :] = (torch.arange(0, head_dim, device=device, dtype=dtype) / head_dim)[None, None, None, :]
    v[:, :, :, 0] = torch.arange(0, seq_len, device=device, dtype=dtype)[None, :, None]
    v[:, target_idx, :, -1] = 1
    
    out = block_sparse_attention(
        q=q, k=k, v=v, 
        seq_lens=seq_lens,
        indices=indices,
        ks=ks,
        ks_count=ks_count,
        ks_start_end=ks_start_end,
        args=args, 
        EXTEND_BACKEND='streaming', 
        model_context_length=131072,
    )
    
    lookup_idx = out[0, 0, 0, 0].item()
    lookup_canary = out[0, 0, 0, -1].item()
    print(target_idx, lookup_idx, lookup_canary)

for i in range(0, seq_len, 371):
    test_index(i)