import torch
import triton
import triton.language as tl
# from .utils import _strides, get_padded_headsize
from typing import Optional, Sequence, Tuple, Union
from HybridTensor.triton.triton_flashattn_decode import maybe_contiguous

def get_padded_headsize(size):
    """Round a head dimension up to the next power of two, never below 16.

    The kernels tile the head dimension with this padded width. Memory is not
    padded; the extra lanes are masked inside the kernel.
    """
    next_pow2 = 1 << (size - 1).bit_length()
    return max(16, next_pow2)


def _strides(x: torch.Tensor, *stride_names: str):
    if x is None:
        return {f"stride_{s}": 0 for i, s in enumerate(stride_names)}

    assert x.ndim == len(stride_names)
    return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}


@triton.jit
def _fwd_kernel_splitK(
    Q,
    K,
    V,
    sm_scale,
    Out_splitK,  # [rows, split_k, M_ceil, BLOCK_DMODEL]; this program writes row program_id(1)
    Metadata,  # [rows, 2, split_k, M_ceil]: m_i at offset 0, l_i at offset stride_m2
    batch_group_index, # [B, TOP_K_GROUPS]: selected KV-head group ids per batch element
    K_new,
    V_new,
    Cache_seqlens,
    Cache_batch_idx,
    Alibi_slopes,
    stride_qz, stride_qm, stride_qg, stride_qh, stride_qd,
    stride_kz, stride_kn, stride_kg, stride_kh, stride_kd,
    stride_vz, stride_vn, stride_vg, stride_vh, stride_vd,
    stride_osk_zhg, stride_osk_s, stride_osk_m,  stride_osk_d,
    stride_mzhg, stride_m2, stride_ms, stride_mm,
    # strides of batch_group_index along (batch, selected-group)
    stride_bgz, stride_bgk,
    stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d,
    stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d,
    stride_az, stride_ah,
    Z,
    N_CTX_Q,
    N_CTX_K,
    N_CTX_NEW,
    BLOCK_N_PER_SPLIT,
    H_q: tl.constexpr,  # query heads (per query group)
    H_kv: tl.constexpr, # K/V heads
    G_q: tl.constexpr,  # number of query groups; off_g_q = off_zhg % G_q
    TOP_K_GROUPS: tl.constexpr, # number of KV-head groups selected per batch element
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BOUNDS_CHECKS_N: tl.constexpr,
    USE_CACHE_SEQLENs: tl.constexpr,
    USE_CACHE_BATCH_IDX: tl.constexpr,
    NEW_KV: tl.constexpr,
    IS_GQA: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_ALIBI: tl.constexpr,
):
    """Split-K flash-decoding forward pass restricted to selected KV-head groups.

    Program ids:
      * program_id(0): block of BLOCK_M query rows.
      * program_id(1): off_zhg = z * (HEAD_RATIO * TOP_K_GROUPS) + k * HEAD_RATIO + r.
        The query head actually processed is batch_group_index[z, k] * HEAD_RATIO + r.
      * program_id(2): split index. The split covers KV positions
        [splitk_idx * BLOCK_N_PER_SPLIT, min((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)).

    The unnormalized accumulator is stored to Out_splitK row off_zhg (grid
    order, not the logical head index). The running row max m_i and row sum
    l_i are stored to Metadata. Scores are scaled by sm_scale * log2(e) and
    exponentiated with exp2, so m_i is in log2 units. _splitK_reduce merges
    the splits.

    When BOUNDS_CHECKS_N is False, the K/V loads are not bounds-checked, so
    the caller must ensure every BLOCK_N tile inside [lo, hi) is full.
    """
    # Padding
    PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
    if PADDED_HEAD:
        # lanes of the padded head dim that correspond to real data
        d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL

    start_m = tl.program_id(0).to(tl.int32)
    off_zhg = tl.program_id(1).to(tl.int64)
    splitk_idx = tl.program_id(2).to(tl.int64)
    HEAD_RATIO: tl.constexpr = H_q // H_kv  # query heads per K/V head (assumes exact division)
    
    # off_z = off_zhg // (H_q * G_q)
    # off_h_q = (off_zhg // G_q) % H_q
    
    # off_z = off_zhg // (H_q * TOP_K_GROUPS) # index of current batch
    # off_k = (off_zhg //H_q) % TOP_K_GROUPS  # index of selected group
    # off_h_q = off_zhg  % H_q # index of current head within group (0 to H_q-1)
    
    off_z = off_zhg // (HEAD_RATIO * TOP_K_GROUPS) # index of current batch
    off_k = (off_zhg //HEAD_RATIO) % TOP_K_GROUPS  # slot in batch_group_index[off_z, :]
    off_g_q = off_zhg % G_q # G_q always 1, off_g_q = 0
    
    # load the KV-head group id selected for this (batch, slot)
    batch_group_index_ptr = batch_group_index + off_z * stride_bgz + off_k * stride_bgk
    group_idx = tl.load(batch_group_index_ptr, mask=True, other=0).to(tl.int64)
    
    # off_h_q = off_zhg  % HEAD_RATIO #  index of current head within group (0 to H_q-1)
    # absolute query-head index: selected group * HEAD_RATIO + head offset inside that group
    off_h_q = group_idx * HEAD_RATIO + off_zhg % HEAD_RATIO
    
    # logical_off_zhg = off_z * (G_q * HEAD_RATIO) + off_g_q * HEAD_RATIO + off_h_q
    
    

    # pick batch index
    if USE_CACHE_BATCH_IDX:
        # row of the K/V cache used by this batch element
        cache_batch_idx = tl.load(Cache_batch_idx + off_z)
    else:
        cache_batch_idx = off_z

    # Load ALiBi slope if enabled
    if USE_ALIBI:
        a_offset = off_z * stride_az + off_h_q * stride_ah
        alibi_slope = tl.load(Alibi_slopes + a_offset)
    else:
        alibi_slope = None

    # [lo, hi) is the KV range handled by this split; empty when lo >= hi
    lo = splitk_idx * BLOCK_N_PER_SPLIT
    if USE_CACHE_SEQLENs:
        cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
        if NEW_KV:
            kv_len = cache_seqlen_last_idx + N_CTX_NEW
        else:
            kv_len = cache_seqlen_last_idx
    else:
        kv_len = N_CTX_K
    hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
    
    if IS_GQA:
        # several query heads share one K/V head
        k_head_idx = off_h_q // HEAD_RATIO
        v_head_idx = k_head_idx
    else:
        k_head_idx = off_h_q
        v_head_idx = off_h_q

    # calculate base offset
    k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg
    v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg

    # Copy new Keys and Values into Cache
    # NOTE(review): every program that maps to the same (batch, K/V head) performs this
    # copy, including programs of other splits and other query heads. No explicit barrier
    # separates these stores from the K/V loads below. Confirm that concurrent identical
    # writes and store->load visibility are acceptable here.
    if NEW_KV:
        knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g
        
        # Determine the starting position for new data in the cache
        if USE_CACHE_SEQLENs:
            start_idx = tl.load(Cache_seqlens + off_z)
        else:
            start_idx = N_CTX_K - N_CTX_NEW

        # Copy new Keys (tile laid out as [D, N])
        for i in range(0, N_CTX_NEW, BLOCK_N):
            # Load from K_new
            k_new_block = tl.load(
                knew_base +
                tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
                (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
                 mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
                     (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
                other=0
            )
            
            # Store to K at positions start_idx + i ...
            tl.store(
                k_base +
                tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd +
                (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn,
                k_new_block,
                 mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
                     (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
            )

        # Copy new Values (tile laid out as [N, D])
        vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g
        for i in range(0, N_CTX_NEW, BLOCK_N):
            # Load from V_new
            v_new_block = tl.load(
                vnew_base +
                (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n +
                tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d,
                mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
                     (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
                other=0
            )
            
            # Store to V at positions start_idx + i ...
            tl.store(
                v_base + 
                (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn +
                tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd,
                v_new_block,
                 mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
                     (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
            )

    # Q_block_ptr = tl.make_block_ptr(
    #     base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg,
    #     shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL),
    #     strides=(stride_qm, stride_qd),
    #     offsets=(start_m * BLOCK_M, 0),
    #     block_shape=(BLOCK_M, BLOCK_DMODEL),
    #     order=(1, 0),
    # )

    # K_block_ptr = tl.make_block_ptr(
    #     base=k_base,
    #     shape=(ACTUAL_BLOCK_DMODEL, hi),
    #     strides=(stride_kd, stride_kn),
    #     offsets=(0, lo),
    #     block_shape=(BLOCK_DMODEL, BLOCK_N),
    #     order=(0, 1),
    # )
    # V_block_ptr = tl.make_block_ptr(
    #     base=v_base,
    #     shape=(hi, ACTUAL_BLOCK_DMODEL),
    #     strides=(stride_vn, stride_vd),
    #     offsets=(lo, 0),
    #     block_shape=(BLOCK_N, BLOCK_DMODEL),
    #     order=(1, 0),
    # )

    # Q tile: [BLOCK_M, BLOCK_DMODEL] rows of the selected query head
    Q_block_ptr = tl.make_block_ptr(
        base=Q + off_h_q.to(tl.int64) * stride_qh + off_z.to(tl.int64) * stride_qz + off_g_q.to(tl.int64) * stride_qg, # base calc uses int64
        shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qd),
        offsets=((start_m * BLOCK_M).to(tl.int32), 0), # <<< Cast offset to int32
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )

    # Ensure shapes/bounds used in K/V pointers are int32 where appropriate
    # K tile is transposed: [BLOCK_DMODEL, BLOCK_N], so q @ k needs no tl.trans
    K_block_ptr = tl.make_block_ptr(
        base=k_base,
        shape=(ACTUAL_BLOCK_DMODEL, hi.to(tl.int32)), # <<< Ensure hi is int32 for shape
        strides=(stride_kd, stride_kn),
        offsets=(0, lo.to(tl.int32)), # <<< Ensure lo is int32 for offset
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    # V tile: [BLOCK_N, BLOCK_DMODEL]
    V_block_ptr = tl.make_block_ptr(
        base=v_base,
        shape=(hi.to(tl.int32), ACTUAL_BLOCK_DMODEL), # <<< Ensure hi is int32 for shape
        strides=(stride_vn, stride_vd),
        offsets=(lo.to(tl.int32), 0), # <<< Ensure lo is int32 for offset
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )


    # quantization is not used by this kernel; load_k_v_group receives None
    K_scale_shift_block_ptr = None
    V_scale_shift_block_ptr = None

    # initialize running row max (m_i) and row sum (l_i)
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)

    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)  # noqa: F821

    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    # NOTE(review): only dim 0 is bounds-checked, so when the head is padded the lanes
    # >= ACTUAL_BLOCK_DMODEL are read from memory and then zeroed by d_mask.
    q = tl.load(  # noqa: F821
        tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
    q = (q * qk_scale).to(q.dtype)
    if PADDED_HEAD:
        q = tl.where(d_mask[None, :], q, 0.0)

    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        k, v = load_k_v_group(
            K_block_ptr,
            V_block_ptr,
            K_scale_shift_block_ptr,
            V_scale_shift_block_ptr,
            BOUNDS_CHECKS_N,
            1,
            BLOCK_DMODEL,
            ACTUAL_BLOCK_DMODEL,
            Q.dtype.element_ty,
            0,
        )
        if PADDED_HEAD:
            k = tl.where(d_mask[:, None], k, 0.0)
            v = tl.where(d_mask[None, :], v, 0.0)

        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)  # noqa: F821

        if USE_ALIBI:
            row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            col_idx = start_n + tl.arange(0, BLOCK_N)
            
            # Compute relative positions (query rows aligned to the end of the KV sequence)
            relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
            relative_pos = tl.abs(relative_pos)
            
            # Compute ALiBi bias (scaled by log2(e) to match the exp2 domain)
            alibi_bias = -1 * alibi_slope * relative_pos
            qk += (alibi_bias * 1.44269504)

        # Apply causal mask if IS_CAUSAL is True
        if IS_CAUSAL:
            row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            col_idx = start_n + tl.arange(0, BLOCK_N)
            
            # create a N_CTX_Q x kv_len causal mask (bottom-right aligned)
            col_offset = N_CTX_Q - kv_len
            causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :])

            # Apply the mask
            qk = tl.where(causal_mask, qk, float("-inf"))

        # TODO: This is slow, and only needed at the last iteration.
        # Maybe we can unroll the last iteration instead?
        if BOUNDS_CHECKS_N:
            # mask key columns at or beyond hi
            qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))

        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        if IS_CAUSAL:
            # fully masked rows keep m = -inf; avoid (-inf) - (-inf) = NaN
            alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
        else:
            alpha = tl.math.exp2(m_i - m_i_new)
        # cause of nan because subtracting infs
        if IS_CAUSAL:
            qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
        else:
            qk = qk - m_i_new[:, None] 
        
        p = tl.math.exp2(qk)

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        p = p.to(Q.dtype.element_ty)

        # -- scale and update acc --
        acc *= alpha[:, None]
        acc += tl.dot(p.to(v.dtype), v)
        
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    # write back O (unnormalized; division by l_i happens in _splitK_reduce)
    O_block_ptr = tl.make_block_ptr(
        # base=Out_splitK + logical_off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
        base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
        shape=(N_CTX_Q, BLOCK_DMODEL),
        strides=(stride_osk_m, 1),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    tl.store(
        tl.advance(O_block_ptr, (0, 0)),
        acc,
        boundary_check=(0, ),
    )
    # Write metadata for split-K reduction
    # (the M axis is addressed with unit stride here; stride_mm is not used)
    # Metadata_ptr = (Metadata + logical_off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + tl.arange(0, BLOCK_M))
    Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + tl.arange(0, BLOCK_M))
    
    tl.store(Metadata_ptr, m_i)
    tl.store(Metadata_ptr + stride_m2, l_i)


@triton.jit
def load_k_v_group(
    K_block_ptr,
    V_block_ptr,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    dtype: tl.constexpr,
    group_id: tl.constexpr,
):
    """Load one K tile and one V tile for the head-dim group ``group_id``.

    The K block pointer is shifted along its first (head-dim) axis and the V
    block pointer along its second (head-dim) axis by
    ``ACTUAL_BLOCK_DMODEL * group_id``. When ``BOUNDS_CHECKS_N`` is set, the
    key/sequence axis (K dim 1, V dim 0) is bounds-checked. The scale/shift
    pointers, PACKED_PER_VAL, BLOCK_DMODEL and dtype are accepted but unused.
    """
    group_offset = ACTUAL_BLOCK_DMODEL * group_id
    k_tile_ptr = tl.advance(K_block_ptr, (group_offset, 0))
    v_tile_ptr = tl.advance(V_block_ptr, (0, group_offset))

    if BOUNDS_CHECKS_N:
        k_tile = tl.load(k_tile_ptr, boundary_check=(1, ))
        v_tile = tl.load(v_tile_ptr, boundary_check=(0, ))
    else:
        k_tile = tl.load(k_tile_ptr)
        v_tile = tl.load(v_tile_ptr)

    return k_tile, v_tile


@triton.jit
def cast_uint32_to_half2(scale_shift):
    """Split a packed integer into two bit-reinterpreted float16 values.

    Returns ``(scale, shift)``: ``scale`` comes from the low 16 bits and
    ``shift`` from the bits shifted down by 16. Both are truncated to uint16
    and then bitcast (not converted) to float16.
    """
    low_half = (scale_shift & 0xFFFF).to(tl.uint16).to(tl.float16, bitcast=True)
    high_half = (scale_shift >> 16).to(tl.uint16).to(tl.float16, bitcast=True)
    return low_half, high_half


@triton.jit
def dequantize(
    x_,
    scale,
    shift,
    PACKED_PER_VAL: tl.constexpr = 8,
):
    """Unpack 4-bit values from the integer tile ``x_`` and return ``q * scale + shift``.

    ``x_`` is a 2-D tile [BLOCK_N, BLOCK_DMODEL_PACKED]. Nibble j of every
    element (bits [4*j, 4*j + 4)) is extracted, giving a
    [BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL] float16 tile.
    """
    # PACKED_PER_VAL is the number of values packed into
    # each element x_. For example, for int4 quantization
    # and x_ of type int32, PACKED_PER_VAL is 8.

    BLOCK_N: tl.constexpr = x_.shape[0]
    BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
    # bit offset of each nibble inside one packed element
    offsets = tl.arange(0, PACKED_PER_VAL) * 4
    quant_offset = (x_[:, None, :] >> offsets[None, :, None])  # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)

    # NOTE(review): tl.view does not promise to preserve element order (tl.reshape does);
    # confirm the resulting column order matches how the values were packed.
    quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL))
    # Trick - instead of converting int4 to float16 we view it as float16
    # and then multiply by 32768 * 512 == 2**24.
    # A uint16 bit pattern q in [0, 15] bitcast to float16 is the subnormal q * 2**-24,
    # so the two multiplications below restore q * scale exactly.
    quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
    quant_offset = (quant_offset * 32768.0).to(tl.float16)
    scale_512 = scale * 512

    dequant = quant_offset * scale_512 + shift
    return dequant


@triton.jit
def _splitK_reduce(
    Out_splitK,  # [rows, split_k, M_ceil, D]; rows in forward-grid order
    Metadata,  # [rows, 2, split_k, M_ceil] contains [mi, li]
    Out,  # addressed via stride_oz / stride_oh / stride_og / stride_om / stride_ok
    LSE,  # [B * G * H, M] (logical head order)
    batch_group_index, # [B, top_k_groups]
    stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_k,
    stride_mzhg, stride_m2, stride_ms, stride_mm,
    stride_bgz, stride_bgk,
    stride_oz, stride_oh, stride_og, stride_om, stride_ok,
    stride_lse_zhg, stride_lse_m,
    M_ceil: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    H: tl.constexpr,    # Number of query heads
    G: tl.constexpr,    # Number of query groups, fixed to 1
    H_kv: tl.constexpr,  # Number of key/value heads
    TOP_K_GROUPS: tl.constexpr,
    split_k: tl.constexpr,
    splitK_pow2: tl.constexpr,
    use_mask: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    """Merge the split-K partial results for one selected head, query row and D-chunk.

    program_id(0) indexes the rows of Out_splitK / Metadata in the order the
    forward kernel wrote them: off_zhg = z * (HEAD_RATIO * TOP_K_GROUPS) +
    k * HEAD_RATIO + r. The real query head is
    batch_group_index[z, k] * HEAD_RATIO + r. program_id(1) is the query row
    and program_id(2) picks the BLOCK_SIZE-wide chunk of the head dimension.

    The partial maxima are in log2 units, so they are rescaled with exp2. LSE
    is converted back to natural log by dividing by log2(e). Both Out and LSE
    are written at the logical (batch, group, head) position, so LSE matches
    a [B, G, H, M] reshape. Assumes H is divisible by H_kv.

    Splits that saw no keys (m = -inf) get weight 0. A row with no keys at all
    yields output 0 and LSE -inf instead of NaN. This guard is applied for
    every mask type; IS_CAUSAL is kept only for signature compatibility.
    """
    off_zhg = tl.program_id(0).to(tl.int64)
    off_m = tl.program_id(1).to(tl.int64)
    off_k = tl.program_id(2).to(tl.int64)
    HEAD_RATIO: tl.constexpr = H // H_kv  # query heads per K/V head

    off_g = off_zhg % G
    off_z = off_zhg // (HEAD_RATIO * TOP_K_GROUPS)  # batch index
    off_bk = (off_zhg // HEAD_RATIO) % TOP_K_GROUPS  # slot in batch_group_index[off_z, :]

    # selected KV-head group for this (batch, slot)
    batch_group_index_ptr = batch_group_index + off_z * stride_bgz + off_bk * stride_bgk
    group_idx = tl.load(batch_group_index_ptr, mask=True, other=0).to(tl.int64)

    # absolute query-head index and its logical row in a [B, G, H] flattening
    off_h = group_idx * HEAD_RATIO + off_zhg % HEAD_RATIO
    logical_off_zhg = off_z * (G * H) + off_g * H + off_h

    spk_idx = tl.arange(0, splitK_pow2)
    kidx = tl.arange(0, BLOCK_SIZE)

    # partial results live in forward-grid order (row off_zhg)
    Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm)
    o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m
             + stride_osk_s * spk_idx[:, None]
             + (off_k * BLOCK_SIZE + kidx[None, :]) * stride_osk_k)

    # read max values, sums and partial outputs of each split
    if use_mask:
        # splitK_pow2 > split_k: padded split slots contribute nothing
        spk_mask = spk_idx < split_k
        l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf"))
        l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0)
        acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0)
    else:
        l_m = tl.load(Metadata_ptr)
        l_sum = tl.load(Metadata_ptr + stride_m2)
        acc = tl.load(o_ptr)

    g_m = tl.max(l_m, axis=0)

    # Empty splits (m = -inf) get weight 0. If all splits are empty, g_m is -inf and
    # l_m - g_m is NaN; the comparison is then False and the weight is still 0.
    l_m_offset = l_m - g_m
    alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)

    l_sum *= alpha
    g_sum = tl.sum(l_sum, axis=0)
    acc = acc * alpha[:, None]

    # Avoid division by zero for rows that saw no keys
    g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0)
    acc_out = tl.sum(acc, axis=0) / g_sum_safe

    # Store output at the logical head position
    Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m
               + (off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) * stride_ok)
    tl.store(Out_ptr, acc_out)

    # Store lse (natural log) at the logical head row so a [B, G, H, M] reshape is correct
    l_ptrs = LSE + logical_off_zhg * stride_lse_zhg + off_m * stride_lse_m
    lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m)
    tl.store(l_ptrs, lse)


def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    """Quantize the last dimension of ``k`` to packed int4 with per-group scale/shift.

    The last dimension is split into ``num_groups`` equal groups. Each group is
    mapped linearly from its [min, max] range onto the int4 range [0, 15],
    rounding to nearest.

    Output layout along the last dimension (as bytes, then viewed as int16):
      * ``num_groups`` headers of 4 bytes each: fp16 scale, then fp16 shift (= min);
      * the packed nibbles, two values per byte (even index in the low nibble,
        odd index in the high nibble), group after group.

    Args:
        k: floating-point tensor. Each group size must be even, and the total
            byte count ``4 * num_groups + k.shape[-1] // 2`` must be even for the
            final int16 view.
        num_groups: number of quantization groups along the last dimension.

    Returns:
        int16 tensor with the same leading dimensions as ``k``.
    """
    k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
    min_vals = torch.min(k, dim=-1, keepdim=True).values
    max_vals = torch.max(k, dim=-1, keepdim=True).values
    scale_k = ((max_vals - min_vals) / 15).to(torch.float16)
    shift_k = min_vals.to(torch.float16)

    # A constant group has scale 0. Dividing by 1 instead quantizes it to 0, which
    # dequantizes back to `shift`, rather than computing 0 / 0 = NaN and casting
    # NaN to uint8 (undefined).
    safe_scale = torch.where(scale_k == 0, torch.ones_like(scale_k), scale_k)

    # +0.5 followed by truncation in the uint8 cast rounds to nearest
    in_bytes = ((k - shift_k.expand(k.shape)) / safe_scale.expand(k.shape)) + 0.5
    in_bytes = in_bytes.to(torch.uint8)
    in_int4 = in_bytes & 0xF
    in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
    scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1)
    k_quant = torch.concat(
        [
            scale_shift.flatten(start_dim=-2),
            in_int4_packed.flatten(start_dim=-2),
        ],
        dim=-1,
    ).view(torch.int16)
    return k_quant

def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    """Decode the packed int4 format produced by ``quantize_kv_int4`` to float16.

    The last dimension of ``quant_k`` starts with ``num_groups`` 4-byte headers
    (fp16 scale, fp16 shift), followed by the packed nibbles. Low nibbles are
    even output positions and high nibbles are odd positions. Each value is
    decoded as ``nibble * scale + shift`` with its group's scale/shift.
    """
    raw_bytes = quant_k.view(torch.int16).view(torch.uint8)
    header_len = 4 * num_groups

    header = raw_bytes[..., 0:header_len]
    header = header.reshape(*header.shape[:-1], num_groups, 4)
    group_scale = header[..., 0:2].view(torch.float16)
    group_shift = header[..., 2:4].view(torch.float16)

    payload = raw_bytes[..., header_len:]
    payload = payload.reshape(*payload.shape[:-1], num_groups, -1)
    low_nibbles = payload & 0xF
    high_nibbles = (payload & 0xF0) >> 4

    grouped_shape = low_nibbles.shape
    scale_b = group_scale.expand(grouped_shape)
    shift_b = group_shift.expand(grouped_shape)
    even_vals = low_nibbles.to(torch.float16) * scale_b + shift_b
    odd_vals = high_nibbles.to(torch.float16) * scale_b + shift_b

    interleaved = torch.empty(
        (*even_vals.shape[:-1], 2 * even_vals.shape[-1]),
        dtype=torch.float16,
        device=quant_k.device,
    )
    interleaved[..., 0::2] = even_vals
    interleaved[..., 1::2] = odd_vals
    return interleaved.reshape(*grouped_shape[:-2], -1)


def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
    """Pick how many KV splits to use for split-K decoding.

    Start from ``max(Mk, 1024) // (B * H)``. Halve until each split covers at
    least 64 keys, then halve again until ``B * H * G * splits`` drops below
    1024. The result is clamped to [1, 512].
    """
    min_keys_per_split = 64
    program_budget = 1024
    max_splits = 512

    n_splits = max(Mk, 1024) // max(B * H, 1)  # NOTE: B * H may be 0
    while n_splits > 0 and Mk / n_splits < min_keys_per_split:
        n_splits //= 2
    while B * H * G * n_splits >= program_budget:
        n_splits //= 2
    return max(min(n_splits, max_splits), 1)

def select_gqa(q, k, v, 
               sm_scale, causal,
               alibi_slopes, layout,
               cache_seqlens, cache_batch_idx,
               new_kv, k_new, v_new,
               batch_group_index):
    """Split-K decode attention computed only for selected KV-head groups.

    For every batch element ``z``, only the query heads belonging to the KV
    heads listed in ``batch_group_index[z]`` are computed. ``out`` is
    zero-initialized, so heads that are not selected stay zero.

    Args:
        q, k, v: query / key-cache / value-cache tensors in ``layout``.
        sm_scale: softmax scale applied to ``q @ k^T``.
        causal: apply a bottom-right aligned causal mask.
        alibi_slopes: optional ALiBi slopes, or None.
        layout: "bshd", "bhsd" or "bsghd" (applies to q/k/v and k_new/v_new).
        cache_seqlens: optional per-batch valid KV length.
        cache_batch_idx: optional per-batch row index into the KV cache.
        new_kv: when true, ``k_new``/``v_new`` are written into the cache first.
        k_new, v_new: new keys/values (only used when ``new_kv``).
        batch_group_index: integer tensor ``[batch_size, top_k_groups]`` of
            KV-head ids to compute for each batch element.

    Returns:
        ``(out, lse)``. ``out`` is ``[B, S_q, H_q, D]`` for "bshd" input and
        ``[B, G * H_q, S_q, D]`` otherwise. ``lse`` has shape ``[B, G, H_q, S_q]``.
        Its buffer is pre-filled with -inf, so rows the reduce kernel does not
        write are -inf rather than uninitialized memory.

    Raises:
        ValueError: if ``layout`` is None, there are fewer query heads than K/V
            heads, the query heads are not a multiple of the K/V heads, or more
            groups are selected than exist.
    """
    # kernel config
    BLOCK_M = 16
    BLOCK_N = 64
    SPLIT_K = None

    # kernels expect "bsghd": add a singleton group axis
    original_layout = layout
    if layout == "bshd":
        q = q.unsqueeze(2)
        k = k.unsqueeze(2)
        v = v.unsqueeze(2)
        if new_kv:
            k_new = k_new.unsqueeze(2)
            v_new = v_new.unsqueeze(2)
        layout = "bsghd"
    elif layout == "bhsd":
        q = q.permute(0, 2, 1, 3).unsqueeze(2)
        k = k.permute(0, 2, 1, 3).unsqueeze(2)
        v = v.permute(0, 2, 1, 3).unsqueeze(2)
        if new_kv:
            k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2)
            v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2)
        layout = "bsghd"
    elif layout is None:
        raise ValueError("Layout not given")
    assert layout == "bsghd"

    # get dims
    batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
    _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape
    _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape
    assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}"

    # Handle MQA/GQA case
    if heads_per_group_q > heads_per_group_k:
        is_gqa = True
    elif heads_per_group_q < heads_per_group_k:
        raise ValueError("heads_per_group_q < heads_per_group_k")
    else:
        is_gqa = False
    if heads_per_group_q % heads_per_group_k != 0:
        # the kernels compute HEAD_RATIO = H_q // H_kv and assume exact division
        raise ValueError(
            f"heads_per_group_q ({heads_per_group_q}) must be a multiple of "
            f"heads_per_group_k ({heads_per_group_k})"
        )

    # each K/V head is a selectable "group" owning heads_per_actual_group query heads
    num_actual_groups = heads_per_group_k
    heads_per_actual_group = heads_per_group_q // num_actual_groups
    top_k_groups = batch_group_index.shape[1]
    if top_k_groups > num_actual_groups:
        # the split-K buffers below only have rows for num_actual_groups groups
        raise ValueError(
            f"top_k_groups ({top_k_groups}) exceeds the number of KV-head groups ({num_actual_groups})"
        )

    # get padded size
    dim_padded = get_padded_headsize(dim_k)

    if SPLIT_K is not None:
        split_k = SPLIT_K
    else:
        # Use heuristics
        split_k = get_split_k(batch_size, num_actual_groups, heads_per_actual_group, seqlen_k) # NOTE: should the split think about seqlens?

    total_rows = batch_size * num_actual_groups * heads_per_actual_group
    seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M
    out_splitk = torch.empty([total_rows, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device)
    metadata = torch.empty([total_rows, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device)
    # -inf is the log-sum-exp of an empty set, i.e. the value for heads that are not computed
    lse = torch.full((total_rows, seqlen_q), float("-inf"), device=q.device, dtype=torch.float32)

    grid_fwd = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * top_k_groups * heads_per_actual_group, split_k)

    num_warps = 1
    split_size = (seqlen_k + split_k - 1) // split_k
    use_cache_seqlens = cache_seqlens is not None
    # Unmasked K/V loads are only safe when every BLOCK_N tile of every split is full.
    # That needs both split_size and seqlen_k to be multiples of BLOCK_N; otherwise the
    # last tile of the last split reads past the end of the sequence.
    bounds_checks_n = (split_size % BLOCK_N) > 0 or (seqlen_k % BLOCK_N) > 0 or use_cache_seqlens

    # The kernels address batch_group_index with a unit stride along the group axis.
    batch_group_index = batch_group_index.contiguous()
    stride_bgz = batch_group_index.stride(0)   # = top_k_groups
    stride_bgk = 1

    # TODO: enable quantization
    _fwd_kernel_splitK[grid_fwd](
        Q=q,
        K=k,
        V=v,
        sm_scale=sm_scale,
        Out_splitK=out_splitk,
        Metadata=metadata,
        batch_group_index=batch_group_index,
        K_new = k_new,
        V_new = v_new,
        Cache_seqlens=cache_seqlens,
        Cache_batch_idx=cache_batch_idx,
        Alibi_slopes=alibi_slopes,
        **_strides(q, "qz", "qm", "qg", "qh", "qd"),
        **_strides(k, "kz", "kn", "kg", "kh", "kd"),
        **_strides(v, "vz", "vn", "vg", "vh", "vd"),
        **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"),
        **_strides(metadata, "mzhg", "m2", "ms", "mm"),
        stride_bgz=stride_bgz, stride_bgk=stride_bgk,
        **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"),
        **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"),
        **_strides(alibi_slopes, "az", "ah"),
        Z=batch_size,
        H_q=heads_per_group_q,
        H_kv=heads_per_group_k,
        G_q=n_group_q,
        TOP_K_GROUPS = int(top_k_groups),
        N_CTX_Q=seqlen_q,
        N_CTX_K=seqlen_k,
        N_CTX_NEW=k_new.shape[1] if new_kv else None,
        BLOCK_N_PER_SPLIT=split_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=dim_padded,
        ACTUAL_BLOCK_DMODEL=dim_k,
        BOUNDS_CHECKS_N=bounds_checks_n,
        USE_CACHE_SEQLENs=use_cache_seqlens,
        USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
        NEW_KV=new_kv,
        IS_GQA=is_gqa,
        IS_CAUSAL=causal,
        USE_ALIBI=False if alibi_slopes is None else True,
        num_warps=num_warps,
        num_stages=1,
    )

    # zero-initialized: heads that are not selected stay 0
    out = torch.zeros((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype)

    # Merge together
    splitK_pow2 = triton.next_power_of_2(split_k)
    use_mask = splitK_pow2 > split_k
    if total_rows * seqlen_q >= 512:
        k_block_num = 1
    else:
        k_block_num = 2
    assert dim_padded % k_block_num == 0
    k_block_size = dim_padded // k_block_num
    grid_reduce = (batch_size * top_k_groups * heads_per_actual_group, seqlen_q, k_block_num)

    _splitK_reduce[grid_reduce](
        out_splitk, 
        metadata, 
        out, 
        lse, 
        batch_group_index,
        **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
        **_strides(metadata, "mzhg", "m2", "ms", "mm"), 
        stride_bgz=stride_bgz, stride_bgk=stride_bgk,
        **_strides(out, "oz", "om", "og", "oh", "ok"),
        **_strides(lse, "lse_zhg", "lse_m"), 
        M_ceil=seqlen_q_ceil, 
        BLOCK_SIZE=k_block_size, 
        G= n_group_q, 
        TOP_K_GROUPS = int(top_k_groups),
        H=heads_per_group_q,
        H_kv=heads_per_group_k,
        # TODO: Tune num_warps
        split_k=split_k, 
        splitK_pow2=splitK_pow2, 
        use_mask=use_mask,
        IS_CAUSAL=causal,
        num_warps=4)

    lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q])
    if seqlen_k == 0:
        out.zero_()

    # out is [B, S_q, G, H, D_pad]: merge the group and head axes into one head axis.
    out = out.reshape(batch_size, seqlen_q, n_group_q * heads_per_group_q, dim_padded)
    if original_layout != "bshd":
        # [B, S_q, H, D] -> [B, H, S_q, D]. A plain reshape would interleave the
        # sequence and head axes whenever seqlen_q > 1.
        out = out.permute(0, 2, 1, 3).contiguous()

    return out.narrow(-1, 0, dim_k), lse