from nsa_lib.ops.parallel import parallel_nsa
import torch 
import triton

# Problem sizes for the fixed-length demo.
S = 16  # number of selected blocks per query (last dim of block_indices)
B, T, H, HQ, D = 4, 2048, 4, 64, 64  # q has HQ heads; k and v have H heads
block_size = 64
window_size = 64
dtype = torch.float16
device = 'cuda'

q = torch.randn((B, T, HQ, D), dtype=dtype, device=device).requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device=device).requires_grad_(True)
# Gate tensors with one value per (batch row, token, query head).
g_slc = torch.rand((B, T, HQ), dtype=dtype, device=device).requires_grad_(True)
g_swa = torch.rand((B, T, HQ), dtype=dtype, device=device).requires_grad_(True)
g_cmp = torch.rand((B, T, HQ), dtype=dtype, device=device).requires_grad_(True)


def _random_block_indices(batch, seq_len, kv_heads, n_sel, blk):
    """Return random, sorted block indices of shape (batch, seq_len, kv_heads, n_sel) on the CPU.

    For token position ``t``, up to ``n_sel`` distinct block ids are drawn from
    ``range(max(1, cdiv(t, blk)))``. Unused slots keep the fill value ``seq_len``.
    That value is larger than any drawn id, so after sorting it sits at the end of
    the last dim. The sentinel is presumably skipped by the kernel (TODO confirm).
    """
    # Fill on the CPU. Slice-assigning into a CUDA tensor costs one host-to-device
    # copy per (b, t, h), which means tens of thousands of tiny transfers.
    indices = torch.full((batch, seq_len, kv_heads, n_sel), seq_len, dtype=torch.long)
    for b in range(batch):
        for t in range(seq_len):
            n_blocks = max(1, triton.cdiv(t, blk))
            for h in range(kv_heads):
                chosen = torch.randperm(n_blocks)[:n_sel]
                indices[b, t, h, :len(chosen)] = chosen
    return indices.sort(-1).values


# Randomly generated block indices, moved to the device once.
block_indices = _random_block_indices(B, T, H, S, block_size).to(device)
block_counts = torch.randint(1, S + 1, (B, T, H), device=device)

o = parallel_nsa(
    q=q,  # (B, T, HQ, D)
    k=k,  # (B, T, H, D)
    v=v,  # (B, T, H, D)
    g_cmp=g_cmp,  # (B, T, HQ)
    g_slc=g_slc,  # (B, T, HQ)
    g_swa=g_swa,  # (B, T, HQ)
    block_indices=block_indices,  # (B, T, H, S)
    block_counts=block_counts,  # (B, T, H)
    block_size=block_size,
    window_size=window_size,
)

"""
# variable-length inputs are supported as well
# randomly split the sequence into N segments
N, T = 4, 2048
offsets = torch.cat([
    torch.tensor([0], dtype=torch.long),
    torch.arange(16, T)[torch.randperm(T - 1)[:N-1]],
    torch.tensor([T], dtype=torch.long)
], 0).cuda().sort()[0]
# seq-first required for inputs with variable lengths
q = torch.rand((1, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
k = torch.rand((1, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
v = torch.rand((1, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
g_slc = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.rand((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)

# randomly generated block indices
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device=device)
for b in range(B):
    for t in range(T):
        for h in range(H):
            i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
            block_indices[b, t, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device=device)

parallel_nsa(
    q=q,
    k=k,
    v=v,
    g_cmp=g_cmp,
    g_slc=g_slc,
    g_swa=g_swa,
    block_indices=block_indices,
    block_counts=block_counts,
    block_size=block_size,
    window_size=window_size,
    cu_seqlens=offsets
)
"""