import warnings
from typing import Optional, Union

import torch
import triton
import triton.language as tl
from einops import rearrange

from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_lens, prepare_token_indices
from fla.ops.utils import mean_pooling
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from nsa_lib.ops.utils import _bitonic_merge
import pdb
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
    # NOTE: ImportWarning is ignored by Python's default warning filters, so this
    # message is only shown when warnings are enabled (e.g. `python -W default`).
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning
    )
    # Bind BOTH imported names so callers can test `is None` instead of hitting
    # a NameError on whichever one was forgotten.
    flash_attn_func = None
    flash_attn_varlen_func = None

@triton.jit
def parallel_nsa_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    block_indices,
    block_counts,
    offsets,
    token_indices,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    """Block-sparse attention forward for a single query position.

    Each program handles one query position (program_id 0), one BV-wide tile of
    the value dimension (program_id 1) and one (batch, KV head) pair
    (program_id 2 == i_b * H + i_h). It computes attention for the G query heads
    i_h * G .. i_h * G + G - 1 that share KV head i_h, restricted to the key
    blocks (of BS tokens each) listed in ``block_indices``.

    Layouts implied by the pointer arithmetic below (token-major, contiguous):
      q: [tokens, HQ, K]   k: [tokens, H, K]   v: [tokens, H, V]
      o: [tokens, HQ, V]   lse: [tokens, HQ]
      block_indices: [tokens, H, S] (block ids; multiplied by BS to get a start token)
      block_counts:  [tokens, H]    (only read when USE_BLOCK_COUNTS)

    Only one [G, BK] tile of q starting at column 0 is loaded, so BK must cover K.

    For step-through debugging, run under TRITON_INTERPRET=1 and insert a
    breakpoint locally; none is left here because the Triton compiler cannot
    lower a Python ``pdb`` call.
    """
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length mode: program_id(0) indexes (sequence id, local position)
        # pairs in token_indices; offsets holds per-sequence [bos, eos) boundaries.
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        # Fixed-length mode: every batch element owns T consecutive tokens.
        bos, eos = i_b * T, i_b * T + T

    # Move k/v to the first token of this sequence and KV head i_h.
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    # Row of S selected block ids for (this token, KV head i_h).
    block_indices += (bos + i_t) * H*S + i_h * S

    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos + i_t) * H + i_h)
    else:
        NS = S

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    # The scaled Q tile is loaded once and reused for every selected block.
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
    # Running (unnormalized) output accumulator. [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)

    # Online-softmax state: running row max b_m and running sum of exp b_acc. [G]
    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([G], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        # Skip negative ids (presumably padding) and blocks starting after the
        # query position; the block containing i_t is masked per column below.
        if i_s <= i_t and i_s >= 0:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BS, BV]
            b_v = tl.load(p_v, boundary_check=(0, 1))

            # [G, BK] @ [BK, BS] -> [G, BS]
            b_s = tl.dot(b_q, b_k)
            # Causal mask inside the block: keys after the query position get -inf.
            b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf'))

            # [G] new running max; b_r rescales the previous accumulators.
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = tl.exp(b_mp - b_m)
            # [G, BS]
            b_p = tl.exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)
            # [G, BV]
            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
    # NOTE(review): if no selected block passes the filter above, b_acc stays 0
    # and o becomes NaN (0 / 0) with lse = -inf; callers presumably always select
    # at least one valid block — confirm.
    b_o = b_o / b_acc[:, None]
    # Natural-log log-sum-exp of the scaled scores.
    b_m += tl.log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))
    
def run_debug_kernel(device='cuda'):
    """Launch ``parallel_nsa_fwd_kernel`` once on small random inputs and print the results.

    The group size used here (G = HQ // H = 4) is presumably below what ``tl.dot``
    accepts in compiled mode, so this harness is probably meant to be run with
    TRITON_INTERPRET=1 — TODO confirm.

    Args:
        device: torch device on which all tensors are allocated (default ``'cuda'``).

    Returns:
        Tuple ``(o, lse)``: the attention output [B, T, HQ, V] and the
        log-sum-exp [B, T, HQ] written by the kernel.
    """
    # Small shapes so the kernel is easy to inspect.
    B, T, H, HQ, K, V, S, BS = 1, 64, 2, 8, 16, 16, 4, 32
    G = HQ // H              # query heads per KV head
    BK = K                   # the kernel loads one K tile, so BK must cover K
    BV = V
    NV = (V + BV - 1) // BV  # number of value tiles (grid axis 1)

    q = torch.randn(B, T, HQ, K, dtype=torch.float32, device=device)
    k = torch.randn(B, T, H, K, dtype=torch.float32, device=device)
    v = torch.randn(B, T, H, V, dtype=torch.float32, device=device)

    # Dummy selection for KV head 0: entry [t, s] = min(t // BS, T // BS - 1 - s).
    # Negative ids are skipped by the kernel; KV head 1 keeps all-zero ids (block 0).
    block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=device)
    t_idx = torch.arange(T, device=device)[:, None]
    s_idx = torch.arange(S, device=device)[None, :]
    block_indices[0, :, 0, :] = torch.minimum(t_idx // BS, (T // BS) - 1 - s_idx).to(torch.int32)

    o = torch.zeros(B, T, HQ, V, dtype=torch.float32, device=device)
    lse = torch.zeros(B, T, HQ, dtype=torch.float32, device=device)
    scale = 1.0

    # Grid axes must match the kernel's program_id usage:
    # 0 -> query position, 1 -> value tile, 2 -> batch * KV head.
    # All B * H programs are needed or query heads beyond the first KV group are never written.
    grid = (T, NV, B * H)
    parallel_nsa_fwd_kernel[grid](
        q, k, v, o, lse, scale, block_indices,  # q: [B,T,HQ,K] k: [B,T,H,K] v: [B,T,H,V]
        None, None, None, T,  # block_counts / offsets / token_indices: unused with both flags False
        H, HQ, G, K, V, S, BS, BK, BV,
        False, False,  # USE_OFFSETS, USE_BLOCK_COUNTS
    )
    print("Output o:", o.cpu().numpy())
    print("Logsumexp lse:", lse.cpu().numpy())
    return o, lse

if __name__ == "__main__":
    # Script entry point: run the kernel debug harness only when executed directly.
    run_debug_kernel()