# -*- coding: utf-8 -*-

import triton
from triton import language as tl
import torch

import os
import math

#####################################################################################################################

def load_chunk_utils():
    """JIT-compile and import the ``chunk_utils`` C++/CUDA extension.

    The three source files are resolved relative to the directory that holds
    this module, then built with ``torch.utils.cpp_extension.load`` in verbose
    mode.

    Returns:
        The loaded extension module.
    """
    from torch.utils.cpp_extension import load

    module_dir = os.path.dirname(os.path.realpath(__file__))
    relative_sources = (
        './cuda/chunk_utils.cpp',
        './cuda/chunk_function.cu',
        './cuda/stream_manager.cpp',
    )
    source_paths = [os.path.join(module_dir, rel) for rel in relative_sources]

    # NOTE(review): 'include' is resolved against the current working
    # directory, unlike the sources above — confirm this is intended.
    extension = load(
        name='chunk_utils',
        extra_include_paths=['include'],
        sources=source_paths,
        verbose=True,
    )
    return extension


chunk_utils = load_chunk_utils()

#####################################################################################################################

@triton.jit
def pre_sc_top_1_idx_kernel(Q, C, I, 
                            stride_qz, stride_qh, stride_qm, stride_qd, 
                            stride_cz, stride_ch, stride_cc, stride_cd, 
                            stride_iz, stride_ih, stride_im, 
                            Z, H, N_CTX, N_CEN, 
                            CHUNK_SIZE: tl.constexpr, 
                            BLOCK_M: tl.constexpr, 
                            BLOCK_C: tl.constexpr, 
                            BLOCK_D: tl.constexpr):
    # Top-1 chunk-center selection for the prefill path.
    #
    # Program (pid_m, pid_zh) handles query rows [pid_m * BLOCK_M, pid_m * BLOCK_M + BLOCK_M)
    # of head (z, h) = (pid_zh // H, pid_zh % H). It scores each row (dot
    # product) against the centers C[z, h, 0:MAX_C], where
    # MAX_C = (pid_m * BLOCK_M) // CHUNK_SIZE. It then writes the index of the
    # highest-scoring center for query row m to I[z, h, m - CHUNK_SIZE].
    # Tiles that start inside the first chunk have no earlier center and are
    # skipped; this is why I has CHUNK_SIZE fewer rows than Q.
    # Z and N_CEN are accepted but not read inside this kernel.
    tl.static_assert(BLOCK_M <= CHUNK_SIZE)
    
    pid_m = tl.program_id(0)
    # No earlier chunk to choose from for tiles starting in the first chunk.
    if pid_m * BLOCK_M < CHUNK_SIZE:
        return
    
    pid_zh = tl.program_id(1)
    off_z = pid_zh // H
    off_h = pid_zh % H
    
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offs_m < N_CTX
    offs_d = tl.arange(0, BLOCK_D)
    
    Q_ptr = Q + off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    C_ptr = C + off_z.to(tl.int64) * stride_cz + off_h.to(tl.int64) * stride_ch
    I_ptr = I + off_z.to(tl.int64) * stride_iz + off_h.to(tl.int64) * stride_ih
    
    # Rows past N_CTX are loaded without a fill value; their results are never
    # stored (see the masked store at the end).
    q = tl.load(Q_ptr + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd, mask=mask_m[:, None])
    
    # Running maximum score and matching center index, kept per (query row,
    # lane) where lane = center index modulo BLOCK_C.
    m_val = tl.full((BLOCK_M, BLOCK_C), value=-float("inf"), dtype=tl.float32)
    m_idx = tl.full((BLOCK_M, BLOCK_C), value=-1, dtype=tl.int32)
    
    # Number of centers this tile may look at (centers 0 .. MAX_C - 1).
    MAX_C = (pid_m * BLOCK_M) // CHUNK_SIZE
    BLOCK_C_RANGE = tl.arange(0, BLOCK_C)
    
    # Full BLOCK_C-wide tiles of centers: no masking needed.
    for start_c in range(0, (MAX_C // BLOCK_C) * BLOCK_C, BLOCK_C):
        offs_c = start_c + BLOCK_C_RANGE 
        c = tl.load(C_ptr + offs_c[:, None] * stride_cc + offs_d[None, :] * stride_cd)
        s = tl.dot(q, tl.trans(c)).to(tl.float32)
        m_s = s > m_val
        m_val = tl.maximum(m_val, s)
        # Branch-free select: take offs_c where the new score is strictly larger.
        m_idx = m_idx * (1 - m_s) + offs_c * m_s
    
    # Remaining MAX_C % BLOCK_C centers (at most one partial tile). Lanes at or
    # beyond MAX_C may produce non-finite dot products; tl.where forces them
    # to -inf so they can never win.
    for start_c in range((MAX_C // BLOCK_C) * BLOCK_C, MAX_C, BLOCK_C):
        offs_c = start_c + BLOCK_C_RANGE
        mask_c = offs_c < MAX_C
        c = tl.load(C_ptr + offs_c[:, None] * stride_cc + offs_d[None, :] * stride_cd, mask=mask_c[:, None], other=-float("inf"))
        s = tl.dot(q, tl.trans(c)).to(tl.float32)
        s = tl.where(mask_c[None, :], s, -float("inf"))
        m_s = s > m_val
        m_val = tl.maximum(m_val, s)
        m_idx = m_idx * (1 - m_s) + offs_c * m_s
    
    # Reduce across lanes: find the winning lane per row, then extract its
    # center index with a one-hot mask followed by a row sum.
    m_i = tl.argmax(m_val, 1)
    m_s = (tl.expand_dims(m_i, 1) == tl.expand_dims(BLOCK_C_RANGE, 0))
    m_idx = m_idx * m_s
    m_idx = tl.sum(m_idx, 1)
    
    # Query row m maps to output row m - CHUNK_SIZE (the first chunk has no entry).
    offs_i = offs_m - CHUNK_SIZE
    mask_i = mask_m
    
    tl.store(I_ptr + offs_i * stride_im, m_idx.to(I.dtype.element_ty), mask=mask_i)


def pre_sc_top_1_idx(Q, C, chunk_size):
    """Choose, for every query after the first chunk, one earlier chunk center.

    Args:
        Q: 4-D query tensor; dims 0, 1 and 3 must match ``C``. ``Q.shape[2]``
            (sequence length) must be at least ``chunk_size`` and at least
            ``C.shape[2] * chunk_size``. ``Q.shape[3]`` must be a power of two
            in [2, 256].
        C: 4-D tensor of chunk centers, indexed along dim 2.
        chunk_size: tokens per chunk. Also used as the kernel's query tile;
            half of it is the center tile.

    Returns:
        ``torch.long`` tensor of shape ``(Z, H, N_CTX - chunk_size)``. Entry
        ``[z, h, i]`` is the center index picked by query ``i + chunk_size``.
    """
    assert len(Q.shape) == 4 and len(C.shape) == 4
    assert Q.shape[0] == C.shape[0]
    assert Q.shape[1] == C.shape[1]
    assert Q.shape[2] >= C.shape[2] * chunk_size
    assert Q.shape[2] >= chunk_size
    assert Q.shape[3] == C.shape[3]
    assert Q.shape[3] in [2, 4, 8, 16, 32, 64, 128, 256]

    batch, heads, seq_len, head_dim = Q.shape
    num_centers = C.shape[2]
    query_tile = chunk_size
    center_tile = chunk_size // 2

    # -1 marks "no selection"; the kernel overwrites every valid entry.
    picked = torch.full((batch, heads, seq_len - chunk_size), -1, dtype=torch.long, device=Q.device)

    launch_grid = (triton.cdiv(seq_len, query_tile), batch * heads, 1)
    pre_sc_top_1_idx_kernel[launch_grid](
        Q, C, picked,
        *Q.stride(),
        *C.stride(),
        *picked.stride(),
        batch, heads, seq_len, num_centers,
        chunk_size, query_tile, center_tile, head_dim,
        num_stages=4, num_warps=4,
    )
    return picked


@triton.jit
def pre_sc_l_chunk_s1_block_fwd_kernel(Q, K, V, offset, index, sm_scale, O, L, M, 
                                       stride_qz, stride_qh, stride_qm, stride_qd, 
                                       stride_kz, stride_kh, stride_kn, stride_kd, 
                                       stride_vz, stride_vh, stride_vn, stride_vd, 
                                       stride_fz, stride_fh, stride_fk, 
                                       stride_iz, stride_ih, stride_im, 
                                       stride_oz, stride_oh, stride_om, stride_od, 
                                       stride_lz, stride_lh, stride_lm, 
                                       stride_mz, stride_mh, stride_mm, 
                                       Z, H, N_CTX, 
                                       BLOCK_M: tl.constexpr, 
                                       BLOCK_N: tl.constexpr, 
                                       BLOCK_D: tl.constexpr):
    # Off-diagonal ("block") pass of the chunked prefill attention.
    #
    # Program (pid_n, pid_zh) loads key/value chunk pid_n of head
    # (z, h) = (pid_zh // H, pid_zh % H). It attends to that chunk from every
    # query row listed in index[z, h, offset[z, h, pid_n - 1] : offset[z, h, pid_n]]
    # (the list for chunk 0 starts at 0).
    # For each such query row idx it writes, with s_j = (q . k_j) * sm_scale * log2(e):
    #   M[idx] = m = max_j s_j
    #   L[idx] = sum_j exp2(s_j - m)
    #   O[idx] = sum_j exp2(s_j - m) * V_j          (not normalised)
    # Every tensor is addressed with its OWN strides: O, for example, need not
    # share Q's layout, and offset/index need not be unit-stride.
    pid_n = tl.program_id(0)
    pid_zh = tl.program_id(1)
    off_z = pid_zh // H
    off_h = pid_zh % H
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh

    # Locate this chunk's slice of `index`: it starts where the previous
    # chunk's slice ends.
    offset_ptr = offset + off_z.to(tl.int64) * stride_fz + off_h.to(tl.int64) * stride_fh
    if pid_n == 0:
        i_start = 0
    else:
        i_start = tl.load(offset_ptr + (pid_n - 1) * stride_fk)
    i_end = tl.load(offset_ptr + pid_n * stride_fk)
    i_numel = i_end - i_start

    # No query selected this chunk: exit before touching K/V.
    if i_numel == 0:
        return

    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_D, N_CTX),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_vn, stride_vd),
        offsets=(pid_n * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )
    
    k = tl.load(K_block_ptr)
    v = tl.load(V_block_ptr)
    
    # 1.44269504 = log2(e): scores are kept in the base-2 domain for exp2.
    qk_scale = sm_scale.to(tl.float32) * 1.44269504
    
    index_ptr = index + off_z.to(tl.int64) * stride_iz + off_h.to(tl.int64) * stride_ih
    L_ptr = L + off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
    M_ptr = M + off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
    Q_ptr = Q + q_offset
    O_ptr = O + o_offset
    offs_d = tl.arange(0, BLOCK_D)
    
    # Process the selected query rows BLOCK_M at a time (gathered by idx).
    for start_i in range(0, i_numel, BLOCK_M):
        offs_i = start_i + tl.arange(0, BLOCK_M)
        mask_i = offs_i < i_numel
        idx = tl.load(index_ptr + (i_start + offs_i) * stride_im, mask=mask_i, other=0)
        # Padded rows (mask_i False) compute garbage that is never stored.
        q = tl.load(Q_ptr + idx[:, None] * stride_qm + offs_d[None, :] * stride_qd, mask=mask_i[:, None], other=-float("inf"))
        qk = tl.dot(q, k).to(tl.float32)
        m = tl.max(qk, 1) * qk_scale
        qk = qk * qk_scale - m[:, None]
        p = tl.math.exp2(qk)
        l = tl.sum(p, 1)
        o = tl.dot(p.to(V.dtype.element_ty), v)
        tl.store(O_ptr + idx[:, None] * stride_om + offs_d[None, :] * stride_od, o.to(O.dtype.element_ty), mask=mask_i[:, None])
        tl.store(L_ptr + idx * stride_lm, l.to(L.dtype.element_ty), mask=mask_i)
        tl.store(M_ptr + idx * stride_mm, m.to(M.dtype.element_ty), mask=mask_i)


@triton.jit
def pre_sc_l_chunk_s1_diag_fwd_kernel(Q, K, V, O, L, M, sm_scale, 
                                      stride_qz, stride_qh, stride_qm, stride_qd, 
                                      stride_kz, stride_kh, stride_kn, stride_kd, 
                                      stride_vz, stride_vh, stride_vn, stride_vd, 
                                      stride_oz, stride_oh, stride_om, stride_od, 
                                      stride_lz, stride_lh, stride_lm, 
                                      stride_mz, stride_mh, stride_mm, 
                                      Z, H, N_CTX, 
                                      BLOCK_M: tl.constexpr, 
                                      BLOCK_M_I: tl.constexpr, 
                                      BLOCK_N: tl.constexpr, 
                                      BLOCK_D: tl.constexpr):
    # Diagonal pass of the chunked prefill attention.
    #
    # Program (pid_mn, pid_zh) attends query chunk pid_mn to key/value chunk
    # pid_mn of head (z, h), using a strictly causal mask: row i sees keys
    # j < i of the same chunk, so the diagonal is excluded.
    # It merges the result into the partial (O, L, M) state already held in
    # memory, with the usual online-softmax rescaling, and writes back the
    # normalised output O / (L + 1e-7).
    # Every tensor is addressed with its OWN batch/head strides (int64), so O,
    # K, V, L and M need not share Q's layout.
    tl.static_assert(BLOCK_M == BLOCK_N)
    
    pid_mn = tl.program_id(0)
    pid_zh = tl.program_id(1)
    off_z = pid_zh // H
    off_h = pid_zh % H
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    l_offset = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
    m_offset = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
    
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_qm, stride_qd),
        offsets=(pid_mn * BLOCK_M, 0),
        block_shape=(BLOCK_M_I, BLOCK_D),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_D, N_CTX),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_mn * BLOCK_N),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_vn, stride_vd),
        offsets=(pid_mn * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )
    O_block_ptr = tl.make_block_ptr(
        base=O + o_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_om, stride_od),
        offsets=(pid_mn * BLOCK_M, 0),
        block_shape=(BLOCK_M_I, BLOCK_D),
        order=(1, 0)
    )
    L_block_ptr = tl.make_block_ptr(
        base=L + l_offset,
        shape=(N_CTX, ),
        strides=(stride_lm, ),
        offsets=(pid_mn * BLOCK_M, ),
        block_shape=(BLOCK_M_I, ),
        order=(0, )
    )
    M_block_ptr = tl.make_block_ptr(
        base=M + m_offset,
        shape=(N_CTX, ),
        strides=(stride_mm, ),
        offsets=(pid_mn * BLOCK_M, ),
        block_shape=(BLOCK_M_I, ),
        order=(0, )
    )
    
    # 1.44269504 = log2(e): scores are kept in the base-2 domain for exp2.
    qk_scale = sm_scale.to(tl.float32) * 1.44269504
    
    k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option='zero')
    v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option='zero')
    
    for start_m in range(0, BLOCK_M, BLOCK_M_I):
        # Strictly causal within the chunk: query row > key column.
        mask = (start_m + tl.arange(0, BLOCK_M_I)[:, None]) > tl.arange(0, BLOCK_N)[None, :]
        neg_mask = -1.0e6 * (1.0 - mask)
    
        q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option='zero')
        qk = tl.dot(q, k).to(tl.float32)
        qk = qk * qk_scale + neg_mask
        
        # Merge with the running max left in M by earlier passes.
        m_i = tl.load(M_block_ptr, boundary_check=(0, ), padding_option='zero')
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
        
        p = tl.math.exp2(qk) * mask
        l_ij = tl.sum(p, 1)
        l_i = tl.load(L_block_ptr, boundary_check=(0, ), padding_option='zero')
        
        # Rescale the stored partial sums to the new running max.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        
        o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option='zero')
        o = o * alpha[:, None]
        
        o += tl.dot(p.to(V.dtype.element_ty), v)
        # The epsilon avoids division by zero for rows with no visible key.
        o = o / (l_i[:, None] + 1e-7)
        tl.store(O_block_ptr, o.to(O.dtype.element_ty), boundary_check=(0, 1))
        
        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M_I, 0))
        M_block_ptr = tl.advance(M_block_ptr, (BLOCK_M_I, ))
        L_block_ptr = tl.advance(L_block_ptr, (BLOCK_M_I, ))
        O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M_I, 0))


def pre_sc_l_chunk_s1_fwd(Q, K, V, offset, index, sm_scale, chunk_size):
    """Run the two-pass chunked prefill attention.

    Args:
        Q, K, V: 4-D tensors of identical shape ``(Z, H, N_CTX, D)``; ``D``
            must be a power of two in [2, 256].
        offset: optional ``(Z, H, N_CTX // chunk_size)`` tensor; with
            ``index`` it describes which queries selected which chunk.
        index: optional ``(Z, H, N_CTX - chunk_size)`` tensor of query rows.
        sm_scale: softmax scale.
        chunk_size: tokens per chunk.

    The off-diagonal pass runs only when both ``offset`` and ``index`` are
    given. The diagonal pass always runs.

    Returns:
        ``(Z, H, N_CTX, D)`` tensor in ``Q.dtype``.
    """
    assert len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4
    assert Q.shape[0] == K.shape[0] and K.shape[0] == V.shape[0]
    assert Q.shape[1] == K.shape[1] and K.shape[1] == V.shape[1]
    assert Q.shape[2] == K.shape[2] and K.shape[2] == V.shape[2]
    assert Q.shape[3] == K.shape[3] and K.shape[3] == V.shape[3]
    assert Q.shape[3] in [2, 4, 8, 16, 32, 64, 128, 256]

    if offset is not None:
        assert len(offset.shape) == 3
        assert V.shape[0] == offset.shape[0]
        assert V.shape[1] == offset.shape[1]
        assert Q.shape[2] // chunk_size == offset.shape[2]

    if index is not None:
        assert len(index.shape) == 3
        assert V.shape[0] == index.shape[0]
        assert V.shape[1] == index.shape[1]
        assert Q.shape[2] == index.shape[2] + chunk_size

    batch, heads, seq_len, head_dim = Q.shape

    # Off-diagonal (block) pass launch configuration.
    block_query_tile = chunk_size // 2
    block_stages, block_warps = 4, 8
    # Diagonal pass launch configuration.
    # NOTE: on H100 / H800, Triton has a bug that produces a layout error
    # here; on those GPUs the diagonal pass's num_warps should be set to 4.
    diag_query_tile = chunk_size // 1
    diag_stages, diag_warps = 4, 8

    out = torch.zeros((batch, heads, seq_len, head_dim), dtype=Q.dtype, device=Q.device)
    row_sum = torch.zeros((batch, heads, seq_len), dtype=torch.float32, device=Q.device)
    row_max = torch.full((batch, heads, seq_len), -torch.inf, dtype=torch.float32, device=Q.device)

    if offset is not None and index is not None:
        # One program per key chunk; a trailing full chunk is left out,
        # i.e. N_CTX // chunk_size programs, minus one when chunk_size
        # divides N_CTX.
        block_grid = ((seq_len - 1) // chunk_size, batch * heads, 1)
        pre_sc_l_chunk_s1_block_fwd_kernel[block_grid](
            Q, K, V, offset, index, sm_scale, out, row_sum, row_max,
            *Q.stride(),
            *K.stride(),
            *V.stride(),
            *offset.stride(),
            *index.stride(),
            *out.stride(),
            *row_sum.stride(),
            *row_max.stride(),
            batch, heads, seq_len, block_query_tile, chunk_size, head_dim,
            num_stages=block_stages, num_warps=block_warps,
        )

    diag_grid = (triton.cdiv(seq_len, chunk_size), batch * heads, 1)
    pre_sc_l_chunk_s1_diag_fwd_kernel[diag_grid](
        Q, K, V, out, row_sum, row_max, sm_scale,
        *Q.stride(),
        *K.stride(),
        *V.stride(),
        *out.stride(),
        *row_sum.stride(),
        *row_max.stride(),
        batch, heads, seq_len, chunk_size, diag_query_tile, chunk_size, head_dim,
        num_stages=diag_stages, num_warps=diag_warps,
    )

    return out


class _pre_attention(torch.autograd.Function):
    """Forward-only autograd wrapper for the chunked sparse prefill attention.

    ``forward(Q, K, V, K_chunk_center, chunk_size, sm_scale)`` returns the
    attention output. ``backward`` does not propagate gradients: it returns
    ``None`` for each of the six forward inputs.
    """

    @staticmethod
    def forward(ctx, Q, K, V, K_chunk_center, chunk_size, sm_scale):
        """Compute the prefill attention output.

        Args:
            ctx: autograd context (unused).
            Q, K, V: 4-D tensors indexed as ``(Z, H, N_CTX, D)`` that must all
                have identical strides.
            K_chunk_center: chunk centers; required when ``N_CTX > chunk_size``.
            chunk_size: tokens per chunk.
            sm_scale: softmax scale.

        Returns:
            The tensor produced by ``pre_sc_l_chunk_s1_fwd``.
        """
        assert Q.stride() == K.stride() == V.stride()

        Z = Q.shape[0]
        H = Q.shape[1]
        N_CTX = Q.shape[2]

        offset, index = None, None

        if N_CTX > chunk_size:
            assert (K_chunk_center is not None)

            # For each query after the first chunk: the chunk center it selected.
            chunk_idx = pre_sc_top_1_idx(Q, K_chunk_center, chunk_size).reshape(Z * H, N_CTX - chunk_size)
            # NOTE(review): chunk_utils.chunk_count presumably histograms
            # chunk_idx per (z, h) row into chunk_count — confirm in the extension.
            chunk_count = torch.zeros(Z * H, N_CTX // chunk_size, device=chunk_idx.device, dtype=torch.int32)
            chunk_utils.chunk_count(chunk_idx, chunk_count)
            chunk_count = chunk_count.long()
            chunk_cum_count = torch.cumsum(chunk_count, dim=-1).int()
            # Every (z, h) row must account for all N_CTX - chunk_size queries.
            assert bool((chunk_cum_count[:, -1] == N_CTX - chunk_size).all())
            # NOTE(review): chunk_utils.chunk_pos presumably fills, per chunk,
            # the positions of the queries that selected it — confirm.
            chunk_pos = torch.empty((Z * H, N_CTX - chunk_size), device=chunk_idx.device, dtype=torch.long)
            chunk_utils.chunk_pos(chunk_cum_count.clone(), chunk_idx, chunk_pos)
            chunk_pos = chunk_pos.reshape(Z, H, N_CTX - chunk_size)

            offset = chunk_cum_count.reshape(Z, H, N_CTX // chunk_size)
            index = chunk_pos
            # Shift to absolute query rows: row i of `index` refers to query i + chunk_size.
            index.add_(chunk_size)

        O = pre_sc_l_chunk_s1_fwd(Q, K, V, offset, index, sm_scale, chunk_size)

        return O

    @staticmethod
    def backward(ctx, dO):
        """Return no gradients.

        Autograd requires exactly one entry per ``forward`` input: Q, K, V,
        K_chunk_center, chunk_size and sm_scale.
        """
        return None, None, None, None, None, None


pre_attention = _pre_attention.apply


#####################################################################################################################


@triton.jit
def small_dim_vec_norm_fwd_kernel(X, Y, stride_xym, stride_xyd, M, BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    # Row-wise L2 normalisation: Y[m, :] = X[m, :] / (||X[m, :]||_2 + 1e-7).
    #
    # Each program handles BLOCK_M consecutive rows; rows >= M are masked out.
    # Every row has exactly BLOCK_D columns, with no column mask.
    # X and Y are addressed with the SAME pair of strides, so the caller must
    # pass two tensors with identical layout. The norm is computed in float32.
    pid_m = tl.program_id(0)
    
    off = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = off < M
    off_xy = off[:, None] * stride_xym + tl.arange(0, BLOCK_D)[None, :] * stride_xyd
    
    x = tl.load(X + off_xy, mask=mask[:, None], other=0).to(tl.float32)
    
    c = tl.sqrt(tl.sum(x * x, 1))
    # The epsilon keeps all-zero rows at zero instead of producing NaN.
    y = x / (c[:, None] + 1e-7)
    
    tl.store(Y + off_xy, y.to(Y.dtype.element_ty), mask=mask[:, None])


def small_dim_vec_norm_fwd(X):
    """L2-normalise ``X`` along its last dimension.

    Each row ``x`` of ``X`` flattened to ``(-1, X.shape[-1])`` becomes
    ``x / (||x||_2 + 1e-7)``. The norm is computed in float32 by the kernel
    and the result is cast back to ``X.dtype``.

    Args:
        X: tensor whose last dimension is used as the kernel's ``BLOCK_D``
            (a ``tl.arange`` length). It presumably must be a power of two —
            TODO confirm against Triton's constraints.

    Returns:
        A new tensor with the same shape, dtype and device as ``X``.
    """
    ori_shape = X.shape

    # Flatten to 2-D rows using `reshape(...).contiguous()` rather than `view`:
    #  * `view` raises for inputs that cannot be flattened in place
    #    (e.g. transposed tensors);
    #  * the kernel addresses X and Y with ONE shared pair of strides, and Y
    #    below is contiguous. A strided view (e.g. a column slice) would make
    #    the kernel write Y with X's row stride, i.e. out of bounds.
    # Both calls are no-ops for already-contiguous input.
    X = X.reshape(-1, X.shape[-1]).contiguous()

    M = X.shape[0]
    DIM = X.shape[1]
    BLOCK_D = DIM
    BLOCK_M = 128
    num_stages = 4
    num_warps = 8

    Y = torch.zeros((M, DIM), dtype=X.dtype, device=X.device)

    grid = (triton.cdiv(M, BLOCK_M), 1, 1)
    small_dim_vec_norm_fwd_kernel[grid](X, Y, 
                                        X.stride(0), X.stride(1), 
                                        M, BLOCK_M, BLOCK_D, 
                                        num_stages=num_stages, num_warps=num_warps)

    return Y.view(ori_shape)


@triton.jit
def sc_top_1_idx_kernel(Q, C, I, 
                        stride_qz, stride_qh, stride_qm, stride_qd, 
                        stride_cz, stride_ch, stride_cc, stride_cd, 
                        stride_iz, stride_ih, stride_im, 
                        Z, H, N_CEN, F_N_CEN, 
                        BLOCK_C: tl.constexpr, 
                        BLOCK_D: tl.constexpr):
    # Top-1 chunk-center selection for a single query row (decode path).
    #
    # One program per (z, h) = (pid_zh // H, pid_zh % H). It scores query row
    # Q[z, h, 0] against centers C[z, h, 0:N_CEN] and stores the index of the
    # best one at I[z, h, 0].
    # The query is padded to 16 rows; rows 1..15 are loaded as 0. This is
    # presumably to meet tl.dot's minimum tile size — TODO confirm. Because
    # the padding rows are zero, summing the dot products over rows recovers
    # row 0's scores for in-range centers.
    # Z, F_N_CEN and stride_im are accepted but not read here.
    
    pid_zh = tl.program_id(0)
    off_z = pid_zh // H
    off_h = pid_zh % H
    
    offs_m = tl.arange(0, 16)
    mask_m = offs_m < 1
    offs_d = tl.arange(0, BLOCK_D)
    
    Q_ptr = Q + off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    C_ptr = C + off_z.to(tl.int64) * stride_cz + off_h.to(tl.int64) * stride_ch
    I_ptr = I + off_z.to(tl.int64) * stride_iz + off_h.to(tl.int64) * stride_ih
    
    q = tl.load(Q_ptr + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd, mask=mask_m[:, None], other=0)
    
    # Running maximum score and matching center index per lane
    # (lane = center index modulo BLOCK_C).
    m_val = tl.full((1, BLOCK_C), value=-float("inf"), dtype=tl.float32)
    m_idx = tl.full((1, BLOCK_C), value=-1, dtype=tl.int32)
    
    BLOCK_C_RANGE = tl.arange(0, BLOCK_C)
    
    # Full BLOCK_C-wide tiles of centers: no masking needed.
    for start_c in range(0, (N_CEN // BLOCK_C) * BLOCK_C, BLOCK_C):
        offs_c = start_c + BLOCK_C_RANGE 
        c = tl.load(C_ptr + offs_c[:, None] * stride_cc + offs_d[None, :] * stride_cd)
        s = tl.dot(q, tl.trans(c)).to(tl.float32)
        s = tl.sum(s, 0)[None, :]
        m_s = s > m_val
        m_val = tl.maximum(m_val, s)
        # Branch-free select: take offs_c where the new score is strictly larger.
        m_idx = m_idx * (1 - m_s) + offs_c * m_s
    
    # Remaining N_CEN % BLOCK_C centers. Lanes at or beyond N_CEN may be
    # non-finite after the dot/sum; tl.where forces them to -inf.
    for start_c in range((N_CEN // BLOCK_C) * BLOCK_C, N_CEN, BLOCK_C):
        offs_c = start_c + BLOCK_C_RANGE
        mask_c = offs_c < N_CEN
        c = tl.load(C_ptr + offs_c[:, None] * stride_cc + offs_d[None, :] * stride_cd, mask=mask_c[:, None], other=-float("inf"))
        s = tl.dot(q, tl.trans(c)).to(tl.float32)
        s = tl.sum(s, 0)[None, :]
        s = tl.where(mask_c[None, :], s, -float("inf"))
        m_s = s > m_val
        m_val = tl.maximum(m_val, s)
        m_idx = m_idx * (1 - m_s) + offs_c * m_s
    
    # Reduce across lanes: winning lane via argmax, then extract its center
    # index with a one-hot mask and a sum.
    m_i = tl.argmax(m_val, 1)
    m_s = (tl.expand_dims(m_i, 1) == tl.expand_dims(BLOCK_C_RANGE, 0))
    m_idx = m_idx * m_s
    m_idx = tl.sum(m_idx, 1)
    
    # Single selected index written to I[z, h, 0].
    tl.store(I_ptr[None, :], m_idx.to(I.dtype.element_ty))


def sc_top_1_idx(Q, C, chunk_size, N_CEN):
    """Select the best chunk center for a single (decode-step) query.

    Args:
        Q: ``(Z, H, 1, D)`` query; ``D`` must be a power of two in [2, 256].
        C: 4-D center buffer matching ``Q`` in dims 0, 1 and 3. Only its first
            ``N_CEN`` entries along dim 2 are scored.
        chunk_size: tokens per chunk; centers are scanned ``chunk_size // 2``
            at a time.
        N_CEN: number of valid centers in ``C``.

    Returns:
        ``torch.long`` tensor of shape ``(Z, H, 1)`` with the chosen center
        index (pre-filled with -1 before the kernel runs).
    """
    assert Q.shape[2] == 1
    assert len(Q.shape) == 4 and len(C.shape) == 4
    assert Q.shape[0] == C.shape[0]
    assert Q.shape[1] == C.shape[1]
    assert Q.shape[3] == C.shape[3]
    assert Q.shape[3] in [2, 4, 8, 16, 32, 64, 128, 256]

    batch, heads, _, head_dim = Q.shape
    buffer_centers = C.shape[2]
    center_tile = chunk_size // 2

    chosen = torch.full((batch, heads, 1), -1, dtype=torch.long, device=Q.device)

    sc_top_1_idx_kernel[(batch * heads, 1, 1)](
        Q, C, chosen,
        *Q.stride(),
        *C.stride(),
        *chosen.stride(),
        batch, heads, N_CEN, buffer_centers, center_tile, head_dim,
        num_stages=4, num_warps=4,
    )
    return chosen


@triton.jit
def sc_l_chunk_s1_block_fwd_kernel(Q, K, V, index, sm_scale, O, L, M, 
                                   stride_qz, stride_qh, stride_qm, stride_qd, 
                                   stride_kz, stride_kh, stride_kn, stride_kd, 
                                   stride_vz, stride_vh, stride_vn, stride_vd, 
                                   stride_iz, stride_ih, stride_im, 
                                   stride_oz, stride_oh, stride_om, stride_od, 
                                   stride_lz, stride_lh, stride_lm, 
                                   stride_mz, stride_mh, stride_mm, 
                                   Z, H, N_CTX, F_N_CTX, 
                                   CHUNK_SIZE: tl.constexpr, 
                                   BLOCK_D: tl.constexpr):
    # Decode-step attention of query row Q[z, h, 0] over the single K/V chunk
    # idx = index[z, h, 0]. One program per (z, h) = (pid_zh // H, pid_zh % H).
    # With s_j = (q . k_j) * sm_scale * log2(e), it writes:
    #   M[z, h, 0] = m = max_j s_j
    #   L[z, h, 0] = sum_j exp2(s_j - m)
    #   O[z, h, 0] = sum_j exp2(s_j - m) * V_j          (not normalised)
    # Batch/head offsets are kept in int64: a large batched KV cache can
    # exceed 2**31 elements, and truncating to int32 would address the wrong
    # memory. Every tensor is addressed with its own strides.
    pid_zh = tl.program_id(0)
    off_z = pid_zh // H
    off_h = pid_zh % H
    
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
    i_offset = off_z.to(tl.int64) * stride_iz + off_h.to(tl.int64) * stride_ih
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    l_offset = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
    m_offset = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
    
    # The single query row is padded to 16 rows (rows 1..15 read as 0).
    offs_m = tl.arange(0, 16)
    mask_m = offs_m < 1
    offs_d = tl.arange(0, BLOCK_D)
    
    Q_ptr = Q + q_offset
    I_ptr = index + i_offset
    
    q = tl.load(Q_ptr + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd, mask=mask_m[:, None], other=0)
    
    # Chunk id is small; int32 is enough for the block-pointer offset.
    idx = tl.load(I_ptr).to(tl.int32)

    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_D, F_N_CTX),
        strides=(stride_kd, stride_kn),
        offsets=(0, idx * CHUNK_SIZE),
        block_shape=(BLOCK_D, CHUNK_SIZE),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(F_N_CTX, BLOCK_D),
        strides=(stride_vn, stride_vd),
        offsets=(idx * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, BLOCK_D),
        order=(1, 0)
    )
    
    k = tl.load(K_block_ptr)
    v = tl.load(V_block_ptr)
    
    # 1.44269504 = log2(e): scores are kept in the base-2 domain for exp2.
    qk_scale = sm_scale.to(tl.float32) * 1.44269504
    
    qk = tl.dot(q, k).to(tl.float32)
    m = tl.max(qk, 1) * qk_scale
    qk = qk * qk_scale - m[:, None]
    p = tl.math.exp2(qk)
    l = tl.sum(p, 1)
    o = tl.dot(p.to(V.dtype.element_ty), v)
    
    O_ptr = O + o_offset
    L_ptr = L + l_offset
    M_ptr = M + m_offset
    
    # Only padded row 0 is real; mask_m drops the other 15 rows.
    tl.store(O_ptr + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od, o.to(O.dtype.element_ty), mask=mask_m[:, None])
    tl.store(L_ptr + offs_m * stride_lm, l.to(L.dtype.element_ty), mask=mask_m)
    tl.store(M_ptr + offs_m * stride_mm, m.to(M.dtype.element_ty), mask=mask_m)


@triton.jit
def sc_l_chunk_s1_diag_fwd_kernel(Q, K, V, O, L, M, sm_scale, 
                                  stride_qz, stride_qh, stride_qm, stride_qd, 
                                  stride_kz, stride_kh, stride_kn, stride_kd, 
                                  stride_vz, stride_vh, stride_vn, stride_vd, 
                                  stride_oz, stride_oh, stride_om, stride_od, 
                                  stride_lz, stride_lh, stride_lm, 
                                  stride_mz, stride_mh, stride_mm, 
                                  Z, H, N_CTX, F_N_CTX, 
                                  CHUNK_SIZE: tl.constexpr, 
                                  BLOCK_D: tl.constexpr):
    # Decode-step attention of query row Q[z, h, 0] over the chunk containing
    # position N_CTX - 1 (chunk cdiv(N_CTX, CHUNK_SIZE) - 1).
    # Only keys at positions < N_CTX - 1 are visible. The result is merged
    # into the partial (O, L, M) state already in memory (online-softmax
    # rescaling) and O is written back normalised by (L + 1e-7).
    # Batch/head offsets are kept in int64 (a large batched KV cache can
    # exceed 2**31 elements), and each tensor uses its own strides.
    pid_n = tl.cdiv(N_CTX, CHUNK_SIZE) - 1
    pid_zh = tl.program_id(0)
    off_z = pid_zh // H
    off_h = pid_zh % H
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    l_offset = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
    m_offset = off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
    
    # Q/O/L/M have one real row; block_shape pads to 16 rows, and
    # boundary_check + zero padding keeps the extra rows inert.
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(1, BLOCK_D),
        strides=(stride_qm, stride_qd),
        offsets=(0, 0),
        block_shape=(16, BLOCK_D),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_D, F_N_CTX),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_n * CHUNK_SIZE),
        block_shape=(BLOCK_D, CHUNK_SIZE),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(F_N_CTX, BLOCK_D),
        strides=(stride_vn, stride_vd),
        offsets=(pid_n * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, BLOCK_D),
        order=(1, 0)
    )
    O_block_ptr = tl.make_block_ptr(
        base=O + o_offset,
        shape=(1, BLOCK_D),
        strides=(stride_om, stride_od),
        offsets=(0, 0),
        block_shape=(16, BLOCK_D),
        order=(1, 0)
    )
    L_block_ptr = tl.make_block_ptr(
        base=L + l_offset,
        shape=(1, ),
        strides=(stride_lm, ),
        offsets=(0, ),
        block_shape=(16, ),
        order=(0, )
    )
    M_block_ptr = tl.make_block_ptr(
        base=M + m_offset,
        shape=(1, ),
        strides=(stride_mm, ),
        offsets=(0, ),
        block_shape=(16, ),
        order=(0, )
    )
    
    # 1.44269504 = log2(e): scores are kept in the base-2 domain for exp2.
    qk_scale = sm_scale.to(tl.float32) * 1.44269504
    
    k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option='zero')
    v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option='zero')
    
    # Visible keys: absolute positions strictly before N_CTX - 1.
    mask = (pid_n * CHUNK_SIZE + tl.arange(0, CHUNK_SIZE)) < N_CTX - 1
    neg_mask = -1.0e6 * (1.0 - mask)

    q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option='zero')
    qk = tl.dot(q, k).to(tl.float32)
    qk = qk * qk_scale + neg_mask[None, :]
    
    # Merge with the running max left in M by the block pass (if any).
    m_i = tl.load(M_block_ptr, boundary_check=(0, ), padding_option='zero')
    m_ij = tl.maximum(m_i, tl.max(qk, 1))
    qk -= m_ij[:, None]
    
    p = tl.math.exp2(qk) * mask[None, :]
    l_ij = tl.sum(p, 1)
    l_i = tl.load(L_block_ptr, boundary_check=(0, ), padding_option='zero')
    
    # Rescale the stored partial sums to the new running max.
    alpha = tl.math.exp2(m_i - m_ij)
    l_i = l_i * alpha + l_ij
    
    o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option='zero')
    o = o * alpha[:, None]
    
    o += tl.dot(p.to(V.dtype.element_ty), v)
    # The epsilon avoids division by zero when no key is visible.
    o = o / (l_i[:, None] + 1e-7)
    tl.store(O_block_ptr, o.to(O.dtype.element_ty), boundary_check=(0, 1))


def sc_l_chunk_s1_fwd(Q, K, V, index, sm_scale, chunk_size, N_CTX):
    """Decode-step attention for one query over its selected chunk and its own chunk.

    Args:
        Q: ``(Z, H, 1, D)`` query; ``D`` must be a power of two in [2, 256].
        K, V: 4-D key/value buffers matching ``Q`` in dims 0, 1 and 3, with
            ``K.shape[2] == V.shape[2] > 0``. That length is passed to the
            kernels as ``F_N_CTX``.
        index: optional ``(Z, H, 1)`` selected-chunk tensor; when ``None``,
            only the diagonal pass runs.
        sm_scale: softmax scale.
        chunk_size: tokens per chunk.
        N_CTX: current context length handed to the kernels.

    Returns:
        ``(Z, H, 1, D)`` tensor in ``Q.dtype``.
    """
    assert Q.shape[2] == 1
    assert K.shape[2] > 0 and V.shape[2] > 0

    assert len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4
    assert Q.shape[0] == K.shape[0] and K.shape[0] == V.shape[0]
    assert Q.shape[1] == K.shape[1] and K.shape[1] == V.shape[1]
    assert K.shape[2] == V.shape[2]
    assert Q.shape[3] == K.shape[3] and K.shape[3] == V.shape[3]
    assert Q.shape[3] in [2, 4, 8, 16, 32, 64, 128, 256]

    if index is not None:
        assert len(index.shape) == 3
        assert Q.shape[2] == index.shape[2]
        assert V.shape[0] == index.shape[0]
        assert V.shape[1] == index.shape[1]

    batch, heads, _, head_dim = Q.shape
    cache_len = K.shape[2]

    block_stages, block_warps = 4, 8
    # NOTE: on H100 / H800, Triton has a bug that produces a layout error
    # here; on those GPUs the diagonal pass's num_warps should be set to 4.
    diag_stages, diag_warps = 4, 8

    out = torch.zeros((batch, heads, 1, head_dim), dtype=Q.dtype, device=Q.device)
    row_sum = torch.zeros((batch, heads, 1), dtype=torch.float32, device=Q.device)
    row_max = torch.full((batch, heads, 1), -torch.inf, dtype=torch.float32, device=Q.device)

    launch_grid = (batch * heads, 1, 1)

    if index is not None:
        sc_l_chunk_s1_block_fwd_kernel[launch_grid](
            Q, K, V, index, sm_scale, out, row_sum, row_max,
            *Q.stride(),
            *K.stride(),
            *V.stride(),
            *index.stride(),
            *out.stride(),
            *row_sum.stride(),
            *row_max.stride(),
            batch, heads, N_CTX, cache_len, chunk_size, head_dim,
            num_stages=block_stages, num_warps=block_warps,
        )

    sc_l_chunk_s1_diag_fwd_kernel[launch_grid](
        Q, K, V, out, row_sum, row_max, sm_scale,
        *Q.stride(),
        *K.stride(),
        *V.stride(),
        *out.stride(),
        *row_sum.stride(),
        *row_max.stride(),
        batch, heads, N_CTX, cache_len, chunk_size, head_dim,
        num_stages=diag_stages, num_warps=diag_warps,
    )

    return out


class _attention(torch.autograd.Function):
    """Forward-only autograd wrapper for single-query (decode) chunked attention.

    ``forward(Q, K, V, K_chunk_center, chunk_size, sm_scale, N_CTX, N_CEN)``
    returns the attention output. ``backward`` does not propagate gradients:
    it returns ``None`` for each of the eight forward inputs.
    """

    @staticmethod
    def forward(ctx, Q, K, V, K_chunk_center, chunk_size, sm_scale, N_CTX, N_CEN):
        """Compute the decode-step attention output.

        Args:
            ctx: autograd context (unused).
            Q: query with ``Q.shape[2] == 1``.
            K, V: key/value buffers with identical strides and
                ``K.shape[2] > 0``.
            K_chunk_center: chunk centers; read only when ``N_CEN > 0``.
            chunk_size: tokens per chunk.
            sm_scale: softmax scale.
            N_CTX: current context length.
            N_CEN: number of valid chunk centers; when 0, no chunk selection
                runs and only the current chunk is attended.

        Returns:
            The tensor produced by ``sc_l_chunk_s1_fwd``.
        """
        assert K.stride() == V.stride()

        assert (Q.shape[2] == 1)
        assert (K.shape[2] > 0)

        index = None

        if N_CEN > 0:
            index = sc_top_1_idx(Q, K_chunk_center, chunk_size, N_CEN)

        O = sc_l_chunk_s1_fwd(Q, K, V, index, sm_scale, chunk_size, N_CTX)

        return O

    @staticmethod
    def backward(ctx, dO):
        """Return no gradients.

        Autograd requires exactly one entry per ``forward`` input: Q, K, V,
        K_chunk_center, chunk_size, sm_scale, N_CTX and N_CEN.
        """
        return None, None, None, None, None, None, None, None


attention = _attention.apply


def main():
    """Script entry point; it currently does nothing and returns None."""
    return None


if __name__ == '__main__':
    main()
