# -*- coding: utf-8 -*-

import triton
from triton import language as tl
import torch

import os
import math


@triton.jit
def sc_l_chunk_s1_block_fwd_kernel(Q, K, V, offset, index, sm_scale, O, L, M,
                                   stride_qz, stride_qh, stride_qm, stride_qd,
                                   stride_kz, stride_kh, stride_kn, stride_kd,
                                   stride_vz, stride_vh, stride_vn, stride_vd,
                                   stride_fz, stride_fh, stride_fk,
                                   stride_iz, stride_ih, stride_im,
                                   stride_oz, stride_oh, stride_om, stride_od,
                                   stride_lz, stride_lh, stride_lm,
                                   stride_mz, stride_mh, stride_mm,
                                   Z, H,
                                   N_CTX: tl.constexpr,
                                   BLOCK_M: tl.constexpr,
                                   BLOCK_N: tl.constexpr,
                                   BLOCK_D: tl.constexpr):
    """Partial (unmasked) attention of the queries routed to one key chunk.

    Grid: axis 0 is the key chunk ``pid_n`` (keys ``pid_n * BLOCK_N`` ..
    ``pid_n * BLOCK_N + BLOCK_N - 1``), axis 1 is the flattened (batch, head)
    index ``pid_zh``.

    The queries routed to this chunk are the row positions
    ``index[z, h, i_start:i_end]`` with ``i_start = offset[z, h, pid_n - 1]``
    (0 for the first chunk) and ``i_end = offset[z, h, pid_n]``. They are
    gathered in tiles of BLOCK_M rows. For every such query row the kernel
    writes, with no causal mask:

    * ``M`` -- row max of the scores scaled by ``sm_scale * log2(e)``,
    * ``L`` -- row sum of ``exp2(scaled_score - M)``,
    * ``O`` -- the *unnormalized* ``exp2(scaled_score - M) @ V_chunk``.

    These are plain stores, not accumulations. Each query row is therefore
    presumably routed to at most one chunk (TODO: confirm against the
    producer of ``index``). The diag kernel later merges this partial state
    with the in-chunk contribution and normalizes O.

    K and V are loaded without boundary checks, so the chunk must lie entirely
    inside N_CTX. The host wrapper only launches complete chunks.
    """
    pid_n = tl.program_id(0)
    pid_zh = tl.program_id(1)
    off_z = pid_zh // H
    off_h = pid_zh % H
    # Batch/head base offset, in int64 to avoid 32-bit overflow.
    # NOTE: Q's strides are reused below for K, V and O, so all four tensors
    # must share Q's batch/head layout.
    qkv_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # K is viewed transposed as (BLOCK_D, BLOCK_N), so q @ k needs no tl.trans.
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(BLOCK_D, N_CTX),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_vn, stride_vd),
        offsets=(pid_n * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )

    # No boundary_check: this chunk is assumed to be fully inside N_CTX.
    k = tl.load(K_block_ptr)
    v = tl.load(V_block_ptr)

    # log2(e) is folded into the scale so that exp2 can replace exp.
    qk_scale = sm_scale.to(tl.float32) * 1.44269504

    # offset holds per-chunk end positions into index. Its last dimension is
    # read with an implicit unit stride (stride_fk is unused).
    offset_ptr = offset + off_z.to(tl.int64) * stride_fz + off_h.to(tl.int64) * stride_fh
    if pid_n == 0:
        i_start = 0
    else:
        i_start = tl.load(offset_ptr + pid_n - 1)
    i_end = tl.load(offset_ptr + pid_n)
    i_numel = i_end - i_start

    # No query was routed to this chunk: nothing to write.
    if i_numel == 0:
        return

    # index's last dimension is also read with an implicit unit stride
    # (stride_im is unused).
    index_ptr = index + off_z.to(tl.int64) * stride_iz + off_h.to(tl.int64) * stride_ih
    L_ptr = L + off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh
    M_ptr = M + off_z.to(tl.int64) * stride_mz + off_h.to(tl.int64) * stride_mh
    Q_ptr = Q + qkv_offset
    O_ptr = O + qkv_offset

    for start_i in range(0, i_numel, BLOCK_M):
        offs_i = start_i + tl.arange(0, BLOCK_M)
        mask_i = offs_i < i_numel
        # Row positions of the queries in this tile.
        idx = tl.load(index_ptr + i_start + offs_i, mask=mask_i, other=0)
        offs_d = tl.arange(0, BLOCK_D)
        # NOTE(review): padded rows are filled with -inf, which can turn their
        # scores into NaN. Every store below is masked with mask_i, so those
        # rows are never written; other=0 would avoid the NaNs altogether.
        q = tl.load(Q_ptr + idx[:, None] * stride_qm + offs_d[None, :] * stride_qd, mask=mask_i[:, None], other=-float("inf"))
        qk = tl.dot(q, k).to(tl.float32)
        # Row max of the scaled scores. max(x) * s == max(x * s) only for
        # s > 0, so this assumes sm_scale > 0.
        m = tl.max(qk, 1) * qk_scale
        qk = qk * qk_scale - m[:, None]
        p = tl.math.exp2(qk)
        l = tl.sum(p, 1)
        # Unnormalized output; the diag kernel rescales it and divides by l.
        o = tl.dot(p.to(V.dtype.element_ty), v)
        tl.store(O_ptr + idx[:, None] * stride_om + offs_d[None, :] * stride_od, o.to(O.dtype.element_ty), mask=mask_i[:, None])
        tl.store(L_ptr + idx * stride_lm, l.to(L.dtype.element_ty), mask=mask_i)
        tl.store(M_ptr + idx * stride_mm, m.to(M.dtype.element_ty), mask=mask_i)


@triton.jit
def sc_l_chunk_s1_diag_fwd_kernel(Q, K, V, O, L, M, sm_scale, R,
                                  stride_qz, stride_qh, stride_qm, stride_qd,
                                  stride_kz, stride_kh, stride_kn, stride_kd,
                                  stride_vz, stride_vh, stride_vn, stride_vd,
                                  stride_oz, stride_oh, stride_om, stride_od,
                                  stride_lz, stride_lh, stride_lm,
                                  stride_mz, stride_mh, stride_mm,
                                  stride_rz, stride_rh, stride_rm,
                                  Z, H,
                                  N_CTX: tl.constexpr,
                                  BLOCK_M: tl.constexpr,
                                  BLOCK_M_I: tl.constexpr,
                                  BLOCK_N: tl.constexpr,
                                  BLOCK_D: tl.constexpr):
    """In-chunk (diagonal) attention, merged with the block-kernel partials.

    Grid: axis 0 is the chunk ``pid_mn`` (the host wrapper launches
    ``cdiv(N_CTX, chunk_size)`` programs), axis 1 is the flattened
    (batch, head) index.

    The chunk's BLOCK_M query rows are processed in sub-tiles of BLOCK_M_I
    rows against the same chunk's BLOCK_N keys. A strict causal mask applies
    (local row > local column), so a query never attends to its own key.

    The new scores are merged with the partial state already stored in O
    (unnormalized output), L (row sum) and M (row max) using the
    online-softmax rescale ``alpha = exp2(m_old - m_new)``. O is then
    normalized in place, and ``R = m + log2(l)`` (log2-domain log-sum-exp)
    is written for the backward pass. Boundary checks with zero padding
    handle a trailing partial chunk.
    """
    # Diagonal tile: the query and key chunks have the same length.
    tl.static_assert(BLOCK_M == BLOCK_N)

    pid_mn = tl.program_id(0)
    pid_zh = tl.program_id(1)
    off_z = pid_zh // H
    off_h = pid_zh % H
    # Q's strides are reused for K, V and O.
    qkv_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    # L's strides are reused for M and R, so all three must share a layout.
    lmr_offset = off_z.to(tl.int64) * stride_lz + off_h.to(tl.int64) * stride_lh

    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_qm, stride_qd),
        offsets=(pid_mn * BLOCK_M, 0),
        block_shape=(BLOCK_M_I, BLOCK_D),
        order=(1, 0)
    )
    # K is viewed transposed as (BLOCK_D, BLOCK_N) for q @ k.
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(BLOCK_D, N_CTX),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_mn * BLOCK_N),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_vn, stride_vd),
        offsets=(pid_mn * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )
    O_block_ptr = tl.make_block_ptr(
        base=O + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_om, stride_od),
        offsets=(pid_mn * BLOCK_M, 0),
        block_shape=(BLOCK_M_I, BLOCK_D),
        order=(1, 0)
    )
    L_block_ptr = tl.make_block_ptr(
        base=L + lmr_offset,
        shape=(N_CTX, ),
        strides=(stride_lm, ),
        offsets=(pid_mn * BLOCK_M, ),
        block_shape=(BLOCK_M_I, ),
        order=(0, )
    )
    M_block_ptr = tl.make_block_ptr(
        base=M + lmr_offset,
        shape=(N_CTX, ),
        strides=(stride_mm, ),
        offsets=(pid_mn * BLOCK_M, ),
        block_shape=(BLOCK_M_I, ),
        order=(0, )
    )
    R_block_ptr = tl.make_block_ptr(
        base=R + lmr_offset,
        shape=(N_CTX, ),
        strides=(stride_rm, ),
        offsets=(pid_mn * BLOCK_M, ),
        block_shape=(BLOCK_M_I, ),
        order=(0, )
    )

    # log2(e) is folded into the scale so that exp2 can replace exp.
    qk_scale = sm_scale.to(tl.float32) * 1.44269504

    # Zero-padded keys beyond N_CTX always land in masked columns (col > row).
    k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option='zero')
    v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option='zero')

    for start_m in range(0, BLOCK_M, BLOCK_M_I):
        # Strict causal mask in chunk-local coordinates.
        mask = (start_m + tl.arange(0, BLOCK_M_I)[:, None]) > tl.arange(0, BLOCK_N)[None, :]
        # Push masked scores down by 1e6 so they do not win the row max.
        # p is also multiplied by mask below, so they contribute exactly 0.
        neg_mask = -1.0e6 * (1.0 - mask)

        q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option='zero')
        qk = tl.dot(q, k).to(tl.float32)
        qk = qk * qk_scale + neg_mask

        # Running max from the block kernel. The host initializes M to -inf,
        # which stays in place for rows the block kernel did not touch.
        m_i = tl.load(M_block_ptr, boundary_check=(0, ), padding_option='zero')
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]

        p = tl.math.exp2(qk) * mask
        l_ij = tl.sum(p, 1)
        l_i = tl.load(L_block_ptr, boundary_check=(0, ), padding_option='zero')

        # Rescale the previous partial sums to the new max.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij

        o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option='zero')
        o = o * alpha[:, None]

        o += tl.dot(p.to(V.dtype.element_ty), v)
        # The 1e-7 avoids 0/0 for rows that see no key at all (l_i == 0).
        o = o / (l_i[:, None] + 1e-7)
        tl.store(O_block_ptr, o.to(O.dtype.element_ty), boundary_check=(0, 1))

        m_i = m_ij
        # Log2-domain log-sum-exp. It is -inf when l_i == 0.
        r_i = m_i + tl.math.log2(l_i)
        tl.store(R_block_ptr, r_i.to(R.dtype.element_ty), boundary_check=(0, ))

        # Move to the next BLOCK_M_I rows of the same chunk.
        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M_I, 0))
        M_block_ptr = tl.advance(M_block_ptr, (BLOCK_M_I, ))
        L_block_ptr = tl.advance(L_block_ptr, (BLOCK_M_I, ))
        R_block_ptr = tl.advance(R_block_ptr, (BLOCK_M_I, ))
        O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M_I, 0))


def sc_l_chunk_s1_fwd(Q, K, V, offset, index, sm_scale, chunk_size):
    """Forward pass of chunked top-1 attention: block kernel, then diag kernel.

    Args:
        Q, K, V: (Z, H, N_CTX, DIM) tensors of identical shape. DIM is used
            directly as BLOCK_D and must be one of 2, 4, ..., 256.
        offset: (Z, H, N_CTX // chunk_size) end offsets into ``index``. The
            queries routed to key chunk k are
            ``index[..., offset[..., k - 1]:offset[..., k]]`` (start 0 for k == 0).
        index: (Z, H, N_CTX - chunk_size) query row positions, grouped by the
            key chunk they were routed to.
        sm_scale: softmax scale applied to the Q @ K^T scores.
        chunk_size: chunk length; also BLOCK_N of both kernels.

    Returns:
        (O, R). O has Q's shape and dtype and holds the normalized attention
        output. R is float32 (Z, H, N_CTX) and holds the per-row log2-domain
        log-sum-exp ``m + log2(l)`` consumed by the backward pass.

    Raises:
        AssertionError: on rank or shape mismatches.
    """
    assert (len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4 and len(offset.shape) == 3 and len(index.shape) == 3)
    assert (Q.shape[0] == K.shape[0] and K.shape[0] == V.shape[0] and V.shape[0] == offset.shape[0] and offset.shape[0] == index.shape[0])
    assert (Q.shape[1] == K.shape[1] and K.shape[1] == V.shape[1] and V.shape[1] == offset.shape[1] and offset.shape[1] == index.shape[1])
    assert (Q.shape[2] == K.shape[2] and K.shape[2] == V.shape[2])
    assert (Q.shape[2] >= chunk_size)
    assert (Q.shape[2] // chunk_size == offset.shape[2])
    assert (Q.shape[2] == index.shape[2] + chunk_size)
    assert (Q.shape[3] == K.shape[3] and K.shape[3] == V.shape[3])
    assert (Q.shape[3] in [2, 4, 8, 16, 32, 64, 128, 256])

    # The kernels address K, V and O using Q's batch/head strides, and read
    # offset/index with an implicit unit stride on their last dimension.
    # Contiguous tensors of equal shape share one layout (and O below is
    # allocated contiguous), so normalize the inputs here. This is a no-op
    # for tensors that are already contiguous.
    Q = Q.contiguous()
    K = K.contiguous()
    V = V.contiguous()
    offset = offset.contiguous()
    index = index.contiguous()

    Z = Q.shape[0]
    H = Q.shape[1]
    N_CTX = Q.shape[2]
    DIM = Q.shape[3]
    BLOCK_D = DIM
    BLOCK_MN = chunk_size
    BLOCK_N = chunk_size

    BLOCK_M1 = chunk_size // 1     # query tile of the block kernel
    BLOCK_M2_I = chunk_size // 1   # inner query tile of the diag kernel
    num_stages_1 = 4
    num_warps_1 = 8
    num_stages_2 = 4
    num_warps_2 = 8                # On H100 / H800 a Triton bug causes a layout error; set num_warps_2 to 4 there.

    O = torch.zeros((Z, H, N_CTX, DIM), dtype=Q.dtype, device=Q.device)
    # Partial softmax state: L = row sums, M = row maxima (starting at -inf).
    # R receives the final log2-domain log-sum-exp.
    L = torch.zeros((Z, H, N_CTX), dtype=torch.float32, device=Q.device)
    M = torch.ones((Z, H, N_CTX), dtype=torch.float32, device=Q.device) * -torch.inf
    R = torch.zeros((Z, H, N_CTX), dtype=torch.float32, device=Q.device)

    # Block kernel: one program per complete key chunk. When N_CTX is a
    # multiple of chunk_size the last chunk is not launched, presumably because
    # a query is never routed to its own chunk (the diag kernel covers that).
    n_full_chunks = N_CTX // BLOCK_N
    n_block_programs = n_full_chunks - 1 if N_CTX % BLOCK_N == 0 else n_full_chunks
    grid = (n_block_programs, Z * H, 1)

    sc_l_chunk_s1_block_fwd_kernel[grid](Q, K, V, offset, index, sm_scale, O, L, M,
                                         *Q.stride(), *K.stride(), *V.stride(),
                                         *offset.stride(), *index.stride(),
                                         *O.stride(), *L.stride(), *M.stride(),
                                         Z, H, N_CTX, BLOCK_M1, BLOCK_N, BLOCK_D, num_stages=num_stages_1, num_warps=num_warps_1)

    # Diag kernel: one program per chunk, including a trailing partial chunk.
    grid = (triton.cdiv(N_CTX, chunk_size), Z * H, 1)

    sc_l_chunk_s1_diag_fwd_kernel[grid](Q, K, V, O, L, M, sm_scale, R,
                                        *Q.stride(), *K.stride(), *V.stride(), *O.stride(),
                                        *L.stride(), *M.stride(), *R.stride(),
                                        Z, H, N_CTX, BLOCK_MN, BLOCK_M2_I, BLOCK_MN, BLOCK_D, num_stages=num_stages_2, num_warps=num_warps_2)

    return O, R


@triton.jit
def sc_l_chunk_s1_block_bwd_kernel(Q, K, V, O, dO, R, offset, index, sm_scale, dQ, dK, dV,
                                   stride_qz,  stride_qh,  stride_qm,  stride_qd,
                                   stride_kz,  stride_kh,  stride_kn,  stride_kd,
                                   stride_vz,  stride_vh,  stride_vn,  stride_vd,
                                   stride_oz,  stride_oh,  stride_om,  stride_od,
                                   stride_doz, stride_doh, stride_dom, stride_dod,
                                   stride_rz,  stride_rh,  stride_rm,
                                   stride_fz,  stride_fh,  stride_fk,
                                   stride_iz,  stride_ih,  stride_im,
                                   stride_dqz, stride_dqh, stride_dqm, stride_dqd,
                                   stride_dkz, stride_dkh, stride_dkn, stride_dkd,
                                   stride_dvz, stride_dvh, stride_dvn, stride_dvd,
                                   Z, H,
                                   N_CTX: tl.constexpr,
                                   BLOCK_M: tl.constexpr,
                                   BLOCK_N: tl.constexpr,
                                   BLOCK_D: tl.constexpr):
    """Backward of the block kernel: gradients through one routed key chunk.

    Grid: axis 0 is the key chunk ``pid_n``, axis 1 is the flattened
    (batch, head) index.

    For the queries routed to this chunk (``index[offset[pid_n - 1]:offset[pid_n]]``,
    with start 0 for the first chunk) the kernel recomputes
    ``p = exp2(scaled_score - R)`` and then:

    * stores dQ for those query rows. This is a plain store; the diag backward
      kernel later adds the in-chunk contribution.
    * accumulates dK / dV for the chunk in fp32 and stores them once at the end.

    If no query was routed to the chunk, the program returns early without
    writing dK / dV, so those keep the caller's initial values.
    """
    pid_n = tl.program_id(0)
    pid_zh = tl.program_id(1)
    off_z = pid_zh // H
    off_h = pid_zh % H
    # Q's strides are reused for K, V, O, dO, dQ, dK and dV.
    qkv_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # K and V are both viewed transposed as (BLOCK_D, BLOCK_N).
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(BLOCK_D, N_CTX),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(BLOCK_D, N_CTX),
        strides=(stride_vd, stride_vn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1)
    )
    dK_block_ptr = tl.make_block_ptr(
        base=dK + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_n * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )
    dV_block_ptr = tl.make_block_ptr(
        base=dV + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_n * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )

    # No boundary_check: only complete chunks are launched.
    k = tl.load(K_block_ptr)
    v = tl.load(V_block_ptr)

    # log2(e) is folded into the scale, matching the forward's exp2 domain.
    qk_scale = sm_scale.to(tl.float32) * 1.44269504

    # offset / index are read with an implicit unit stride on their last dim.
    offset_ptr = offset + off_z.to(tl.int64) * stride_fz + off_h.to(tl.int64) * stride_fh
    if pid_n == 0:
        i_start = 0
    else:
        i_start = tl.load(offset_ptr + pid_n - 1)
    i_end = tl.load(offset_ptr + pid_n)
    i_numel = i_end - i_start

    if i_numel == 0:
        return

    index_ptr = index + off_z.to(tl.int64) * stride_iz + off_h.to(tl.int64) * stride_ih
    R_ptr = R + off_z.to(tl.int64) * stride_rz + off_h.to(tl.int64) * stride_rh
    Q_ptr = Q + qkv_offset
    O_ptr = O + qkv_offset
    dQ_ptr = dQ + qkv_offset
    dO_ptr = dO + qkv_offset

    dk = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_D], dtype=tl.float32)

    for start_i in range(0, i_numel, BLOCK_M):
        offs_i = start_i + tl.arange(0, BLOCK_M)
        mask_i = offs_i < i_numel
        idx = tl.load(index_ptr + i_start + offs_i, mask=mask_i, other=0)
        offs_d = tl.arange(0, BLOCK_D)
        q = tl.load(Q_ptr + idx[:, None] * stride_qm + offs_d[None, :] * stride_qd, mask=mask_i[:, None], other=0)
        qk = tl.dot(q, k).to(tl.float32) * qk_scale
        r = tl.load(R_ptr + idx * stride_rm, mask=mask_i, other=0)
        # Recomputed probabilities; padded rows are zeroed via mask_i so they
        # contribute nothing to dk / dv.
        p = tl.math.exp2(qk - r[:, None]) * mask_i[:, None]
        do = tl.load(dO_ptr + idx[:, None] * stride_dom + offs_d[None, :] * stride_dod, mask=mask_i[:, None], other=0)
        dv += tl.dot(tl.trans(p).to(V.dtype.element_ty), do)
        dp = tl.dot(do, v).to(tl.float32)
        o = tl.load(O_ptr + idx[:, None] * stride_om + offs_d[None, :] * stride_od, mask=mask_i[:, None], other=0)
        # Row-wise D = sum(dO * O). The product and the reduction are done in
        # fp32, because multiplying and summing BLOCK_D terms in fp16/bf16
        # loses precision.
        d = tl.sum(do.to(tl.float32) * o.to(tl.float32), 1)
        ds = p * (dp - d[:, None])
        ds = ds.to(V.dtype.element_ty)
        # d(score)/dq carries the sm_scale factor.
        dq = tl.dot(ds, tl.trans(k)).to(tl.float32) * sm_scale
        dk += tl.dot(tl.trans(ds), q)
        tl.store(dQ_ptr + idx[:, None] * stride_dqm + offs_d[None, :] * stride_dqd, dq.to(dQ.dtype.element_ty), mask=mask_i[:, None])

    dk *= sm_scale
    tl.store(dK_block_ptr, dk.to(dK.dtype.element_ty))
    tl.store(dV_block_ptr, dv.to(dV.dtype.element_ty))


@triton.jit
def sc_l_chunk_s1_diag_bwd_kernel(Q, K, V, O, dO, R, sm_scale, dQ, dK, dV,
                                  stride_qz,  stride_qh,  stride_qm,  stride_qd,
                                  stride_kz,  stride_kh,  stride_kn,  stride_kd,
                                  stride_vz,  stride_vh,  stride_vn,  stride_vd,
                                  stride_oz,  stride_oh,  stride_om,  stride_od,
                                  stride_doz, stride_doh, stride_dom, stride_dod,
                                  stride_rz,  stride_rh,  stride_rm,
                                  stride_dqz, stride_dqh, stride_dqm, stride_dqd,
                                  stride_dkz, stride_dkh, stride_dkn, stride_dkd,
                                  stride_dvz, stride_dvh, stride_dvn, stride_dvd,
                                  Z, H,
                                  N_CTX: tl.constexpr,
                                  BLOCK_M: tl.constexpr,
                                  BLOCK_M_I: tl.constexpr,
                                  BLOCK_N: tl.constexpr,
                                  BLOCK_D: tl.constexpr):
    """Backward of the diagonal (in-chunk) attention for one chunk.

    Grid: axis 0 is the chunk ``pid_mn``, axis 1 is the flattened
    (batch, head) index. The kernel uses the same strict causal mask as the
    forward diag kernel (local row > local column) and recomputes
    ``p = exp2(scaled_score - R)`` in BLOCK_M_I-row sub-tiles. It then:

    * adds the in-chunk dQ contribution to the dQ rows already present
      (load, add, store),
    * adds the in-chunk dK / dV contribution to the values already in dK / dV
      (loaded once before the loop, stored once after it).

    Boundary checks with zero padding handle a trailing partial chunk.
    """
    # Diagonal tile: the query and key chunks have the same length.
    tl.static_assert(BLOCK_M == BLOCK_N)

    pid_mn = tl.program_id(0)
    pid_zh = tl.program_id(1)
    off_z = pid_zh // H
    off_h = pid_zh % H
    # Q's strides are reused for K, V, O, dO, dQ, dK and dV.
    qkv_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    r_offset = off_z.to(tl.int64) * stride_rz + off_h.to(tl.int64) * stride_rh

    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_qm, stride_qd),
        offsets=(pid_mn * BLOCK_M, 0),
        block_shape=(BLOCK_M_I, BLOCK_D),
        order=(1, 0)
    )
    # K and V are both viewed transposed as (BLOCK_D, BLOCK_N).
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(BLOCK_D, N_CTX),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_mn * BLOCK_N),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(BLOCK_D, N_CTX),
        strides=(stride_vd, stride_vn),
        offsets=(0, pid_mn * BLOCK_N),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(0, 1)
    )
    O_block_ptr = tl.make_block_ptr(
        base=O + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_om, stride_od),
        offsets=(pid_mn * BLOCK_M, 0),
        block_shape=(BLOCK_M_I, BLOCK_D),
        order=(1, 0)
    )
    dO_block_ptr = tl.make_block_ptr(
        base=dO + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_dom, stride_dod),
        offsets=(pid_mn * BLOCK_M, 0),
        block_shape=(BLOCK_M_I, BLOCK_D),
        order=(1, 0)
    )
    R_block_ptr = tl.make_block_ptr(
        base=R + r_offset,
        shape=(N_CTX, ),
        strides=(stride_rm, ),
        offsets=(pid_mn * BLOCK_M, ),
        block_shape=(BLOCK_M_I, ),
        order=(0, )
    )
    dQ_block_ptr = tl.make_block_ptr(
        base=dQ + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_dqm, stride_dqd),
        offsets=(pid_mn * BLOCK_M, 0),
        block_shape=(BLOCK_M_I, BLOCK_D),
        order=(1, 0)
    )
    dK_block_ptr = tl.make_block_ptr(
        base=dK + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_mn * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )
    dV_block_ptr = tl.make_block_ptr(
        base=dV + qkv_offset,
        shape=(N_CTX, BLOCK_D),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_mn * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )

    # log2(e) is folded into the scale, matching the forward's exp2 domain.
    qk_scale = sm_scale.to(tl.float32) * 1.44269504

    k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option='zero')
    v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option='zero')
    # Start from the block-kernel gradients and accumulate in fp32.
    dk = tl.load(dK_block_ptr, boundary_check=(0, 1), padding_option='zero').to(tl.float32)
    dv = tl.load(dV_block_ptr, boundary_check=(0, 1), padding_option='zero').to(tl.float32)

    for start_m in range(0, BLOCK_M, BLOCK_M_I):
        # Strict causal mask in chunk-local coordinates, as in the forward.
        mask = (start_m + tl.arange(0, BLOCK_M_I)[:, None]) > tl.arange(0, BLOCK_N)[None, :]
        neg_mask = -1.0e6 * (1.0 - mask)

        q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option='zero')
        qk = tl.dot(q, k).to(tl.float32) * qk_scale + neg_mask

        r = tl.load(R_block_ptr, boundary_check=(0, ), padding_option='zero')
        p = tl.math.exp2(qk - r[:, None]) * mask

        do = tl.load(dO_block_ptr, boundary_check=(0, 1), padding_option='zero')
        dv += tl.dot(tl.trans(p).to(V.dtype.element_ty), do)

        dp = tl.dot(do, v).to(tl.float32)

        o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option='zero')
        # Row-wise D = sum(dO * O). The product and the reduction are done in
        # fp32, because multiplying and summing BLOCK_D terms in fp16/bf16
        # loses precision.
        d = tl.sum(do.to(tl.float32) * o.to(tl.float32), 1)
        ds = p * (dp - d[:, None])
        ds = ds.to(V.dtype.element_ty)

        # Add the in-chunk contribution to the dQ rows already present.
        dq = tl.load(dQ_block_ptr, boundary_check=(0, 1), padding_option='zero').to(tl.float32)
        dq += tl.dot(ds, tl.trans(k)).to(tl.float32) * sm_scale

        dk += tl.dot(tl.trans(ds), q).to(tl.float32) * sm_scale

        tl.store(dQ_block_ptr, dq.to(dQ.dtype.element_ty), boundary_check=(0, 1))

        # Move to the next BLOCK_M_I rows of the same chunk.
        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M_I, 0))
        R_block_ptr = tl.advance(R_block_ptr, (BLOCK_M_I, ))
        dO_block_ptr = tl.advance(dO_block_ptr, (BLOCK_M_I, 0))
        O_block_ptr = tl.advance(O_block_ptr, (BLOCK_M_I, 0))
        dQ_block_ptr = tl.advance(dQ_block_ptr, (BLOCK_M_I, 0))

    tl.store(dK_block_ptr, dk.to(dK.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dV_block_ptr, dv.to(dV.dtype.element_ty), boundary_check=(0, 1))


def sc_l_chunk_s1_bwd(Q, K, V, O, dO, R, offset, index, sm_scale, chunk_size):
    """Backward pass of chunked top-1 attention: block kernel, then diag kernel.

    Args:
        Q, K, V, O, dO: (Z, H, N_CTX, DIM) tensors of identical shape. O and R
            are the outputs of ``sc_l_chunk_s1_fwd``; dO is the output gradient.
        R: (Z, H, N_CTX) per-row log2-domain log-sum-exp from the forward.
        offset: (Z, H, N_CTX // chunk_size) end offsets into ``index``.
        index: (Z, H, N_CTX - chunk_size) query row positions grouped by
            routed key chunk.
        sm_scale: softmax scale used in the forward.
        chunk_size: chunk length; also BLOCK_N of both kernels.

    Returns:
        (dQ, dK, dV), each with Q's shape and dtype.

    Raises:
        AssertionError: on rank or shape mismatches.
    """
    assert (len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4 and len(O.shape) == 4 and len(dO.shape) == 4)
    assert (len(R.shape) == 3 and len(offset.shape) == 3 and len(index.shape) == 3)
    assert (Q.shape[0] == K.shape[0] and K.shape[0] == V.shape[0] and V.shape[0] == O.shape[0] and O.shape[0] == dO.shape[0])
    assert (dO.shape[0] == R.shape[0] and R.shape[0] == offset.shape[0] and offset.shape[0] == index.shape[0])
    assert (Q.shape[1] == K.shape[1] and K.shape[1] == V.shape[1] and V.shape[1] == O.shape[1] and O.shape[1] == dO.shape[1])
    assert (dO.shape[1] == R.shape[1] and R.shape[1] == offset.shape[1] and offset.shape[1] == index.shape[1])
    assert (Q.shape[2] == K.shape[2] and K.shape[2] == V.shape[2] and V.shape[2] == O.shape[2] and O.shape[2] == dO.shape[2] and dO.shape[2] == R.shape[2])
    assert (Q.shape[2] >= chunk_size)
    assert (Q.shape[2] // chunk_size == offset.shape[2])
    assert (Q.shape[2] == index.shape[2] + chunk_size)
    assert (Q.shape[3] == K.shape[3] and K.shape[3] == V.shape[3] and V.shape[3] == O.shape[3] and O.shape[3] == dO.shape[3])
    assert (Q.shape[3] in [2, 4, 8, 16, 32, 64, 128, 256])

    # The kernels address K, V, O, dO and the gradient buffers using Q's
    # batch/head strides, and read offset/index with an implicit unit stride
    # on their last dimension. The gradient buffers below are contiguous, so
    # bring every input to the same contiguous layout. This is a no-op for
    # tensors that already are contiguous.
    Q = Q.contiguous()
    K = K.contiguous()
    V = V.contiguous()
    O = O.contiguous()
    dO = dO.contiguous()
    offset = offset.contiguous()
    index = index.contiguous()

    Z = Q.shape[0]
    H = Q.shape[1]
    N_CTX = Q.shape[2]
    DIM = Q.shape[3]
    BLOCK_D = DIM
    BLOCK_MN = chunk_size
    BLOCK_N = chunk_size

    BLOCK_M1 = chunk_size // 2       # On H100 / H800 with Triton 3.1.0 this should be chunk_size // 2 (observed with dim=64, chunk=128).
    BLOCK_M2_I = chunk_size // 4     # On H100 / H800 with Triton 3.1.0 this should be chunk_size // 4 (observed with dim=64, chunk=128).
    num_stages_1 = 4
    num_warps_1 = 8
    num_stages_2 = 4
    num_warps_2 = 8                  # Triton 2.3.1 bug (fixed in 3.1.0): if N_CTX is not a multiple of chunk_size, set num_warps_2 to 4.

    # Zero-initialized: the block kernel skips chunks with no routed queries,
    # and the diag kernel accumulates into these buffers.
    dQ = torch.zeros((Z, H, N_CTX, DIM), dtype=Q.dtype, device=Q.device)
    dK = torch.zeros((Z, H, N_CTX, DIM), dtype=Q.dtype, device=Q.device)
    dV = torch.zeros((Z, H, N_CTX, DIM), dtype=Q.dtype, device=Q.device)

    # Same block-kernel grid as the forward pass.
    n_full_chunks = N_CTX // BLOCK_N
    n_block_programs = n_full_chunks - 1 if N_CTX % BLOCK_N == 0 else n_full_chunks
    grid = (n_block_programs, Z * H, 1)

    sc_l_chunk_s1_block_bwd_kernel[grid](Q, K, V, O, dO, R, offset, index, sm_scale, dQ, dK, dV,
                                         *Q.stride(), *K.stride(), *V.stride(),
                                         *O.stride(), *dO.stride(), *R.stride(),
                                         *offset.stride(), *index.stride(),
                                         *dQ.stride(), *dK.stride(), *dV.stride(),
                                         Z, H, N_CTX, BLOCK_M1, BLOCK_N, BLOCK_D, num_stages=num_stages_1, num_warps=num_warps_1)

    grid = (triton.cdiv(N_CTX, chunk_size), Z * H, 1)

    sc_l_chunk_s1_diag_bwd_kernel[grid](Q, K, V, O, dO, R, sm_scale, dQ, dK, dV,
                                        *Q.stride(), *K.stride(), *V.stride(),
                                        *O.stride(), *dO.stride(), *R.stride(),
                                        *dQ.stride(), *dK.stride(), *dV.stride(),
                                        Z, H, N_CTX, BLOCK_MN, BLOCK_M2_I, BLOCK_MN, BLOCK_D, num_stages=num_stages_2, num_warps=num_warps_2)

    return dQ, dK, dV


def load_chunk_utils():
    """JIT-build and load the ``chunk_utils`` C++/CUDA extension.

    The sources are resolved relative to this file's directory, so the build
    does not depend on the caller's working directory.

    Returns:
        The loaded extension module. This file uses its ``chunk_count`` and
        ``chunk_pos`` ops.
    """
    from torch.utils.cpp_extension import load
    dir_path = os.path.dirname(os.path.realpath(__file__))
    chunk_utils = load(name='chunk_utils',
                       # A bare 'include' is resolved against the process's current
                       # working directory, unlike the sources below. The copy next
                       # to this file is added so the build works from any CWD; the
                       # original entry is kept for backward compatibility.
                       extra_include_paths=['include', os.path.join(dir_path, 'include')],
                       sources=[os.path.join(dir_path, './cuda/chunk_utils.cpp'), os.path.join(dir_path, './cuda/chunk_function.cu'), os.path.join(dir_path, './cuda/stream_manager.cpp')],
                       verbose=True)
    return chunk_utils


# Importing this module builds/loads the chunk_utils extension (chunk_count,
# chunk_pos) used by the attention autograd functions below.
chunk_utils = load_chunk_utils()


class _attention(torch.autograd.Function):
    """Chunked top-1 attention backed by the Triton kernels in this file.

    Each query position in [chunk_size, N_CTX) is routed, per (batch, head),
    to one key chunk chosen by ``sc_top_1_idx``, which is defined elsewhere
    (presumably a top-1 selection; confirm against its definition). Every
    query also attends, strictly causally, to the earlier positions of its
    own chunk.
    """

    @staticmethod
    def forward(ctx, Q, K, V, chunk_size, sm_scale):
        """Compute the attention output.

        Args:
            Q, K, V: (Z, H, N_CTX, D) tensors with identical strides.
            chunk_size: chunk length used for routing and tiling.
            sm_scale: softmax scale.

        Returns:
            O with Q's shape and dtype.
        """
        assert Q.stride() == K.stride() == V.stride()

        Z = Q.shape[0]
        H = Q.shape[1]
        N_CTX = Q.shape[2]

        # Only complete chunks are passed to the routing function.
        K_n = (N_CTX // chunk_size) * chunk_size
        K_floor = K[:, :, 0:K_n, :].contiguous()

        # One chunk id per query position in [chunk_size, N_CTX), per (z, h) row.
        chunk_idx = sc_top_1_idx(Q, K_floor, sm_scale, chunk_size).reshape(Z * H, N_CTX - chunk_size)
        # chunk_count is presumably a per-row histogram of the chosen chunk ids.
        # Its inclusive prefix sums become the per-chunk end offsets read by
        # the kernels.
        chunk_count = torch.zeros(Z * H, N_CTX // chunk_size, device=chunk_idx.device, dtype=torch.int32)
        chunk_utils.chunk_count(chunk_idx, chunk_count)
        chunk_count = chunk_count.long()
        chunk_cum_count = torch.cumsum(chunk_count, dim=-1).int()
        # The kernels slice each (z, h) row of `index` with these offsets, so
        # every row, not just one, must account for all N_CTX - chunk_size queries.
        assert bool((chunk_cum_count[:, -1] == N_CTX - chunk_size).all())
        chunk_pos = torch.empty((Z * H, N_CTX - chunk_size), device=chunk_idx.device, dtype=torch.long)
        # clone(): presumably chunk_pos mutates its offsets argument (TODO: confirm).
        chunk_utils.chunk_pos(chunk_cum_count.clone(), chunk_idx, chunk_pos)
        chunk_pos = chunk_pos.reshape(Z, H, N_CTX - chunk_size)

        offset = chunk_cum_count.reshape(Z, H, N_CTX // chunk_size)
        index = chunk_pos
        # Shift by chunk_size: the first chunk_size rows are never routed, so
        # positions are presumably relative to row chunk_size.
        index.add_(chunk_size)

        O, R = sc_l_chunk_s1_fwd(Q, K, V, offset, index, sm_scale, chunk_size)
        # Row 0 sees no key (strict causal mask, never routed), so its R is
        # m + log2(0) = -inf. It is reset to 0 so the backward's
        # exp2(score - R) stays finite; p is masked to 0 there anyway.
        R[:, :, 0] = 0

        ctx.save_for_backward(Q, K, V, O, R, offset, index)
        ctx.sm_scale = sm_scale
        ctx.chunk_size = chunk_size
        return O

    @staticmethod
    def backward(ctx, dO):
        """Return (dQ, dK, dV, None, None), matching forward's inputs."""
        Q, K, V, O, R, offset, index = ctx.saved_tensors
        dO = dO.contiguous()
        assert Q.stride() == K.stride() == V.stride() == O.stride() == dO.stride()
        dQ, dK, dV = sc_l_chunk_s1_bwd(Q, K, V, O, dO, R, offset, index, ctx.sm_scale, ctx.chunk_size)
        return dQ, dK, dV, None, None


# Functional entry point: attention(Q, K, V, chunk_size, sm_scale) -> O.
attention = _attention.apply


@triton.jit
def small_dim_vec_norm_fwd_kernel(X, Y, C, stride_xym, stride_xyd, stride_cm, M, BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    """Row-wise L2 normalization (forward) over BLOCK_M rows per program.

    Each valid row r (r < M) gets ``Y[r] = X[r] / (||X[r]|| + 1e-7)``, and
    ``||X[r]||`` is stored to ``C[r]``. X and Y are addressed with the same
    (stride_xym, stride_xyd) pair. The norm is computed in fp32.
    """
    row_ids = tl.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
    col_ids = tl.arange(0, BLOCK_D)
    row_ok = row_ids < M
    tile_mask = row_ok[:, None]
    tile_offs = row_ids[:, None] * stride_xym + col_ids[None, :] * stride_xyd

    vals = tl.load(X + tile_offs, mask=tile_mask, other=0).to(tl.float32)

    norms = tl.sqrt(tl.sum(vals * vals, 1))
    unit = vals / (norms[:, None] + 1e-7)

    tl.store(Y + tile_offs, unit.to(Y.dtype.element_ty), mask=tile_mask)
    tl.store(C + row_ids * stride_cm, norms.to(C.dtype.element_ty), mask=row_ok)


@triton.jit
def small_dim_vec_norm_bwd_kernel(X, C, dY, dX, stride_xym, stride_xyd, stride_cm, M, BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    """Row-wise L2 normalization (backward) over BLOCK_M rows per program.

    With c = ||x|| taken from C, each valid row gets
    ``dx = dy / (c + 1e-7) - (dy . x) * x / (c^3 + 1e-7)``.
    X, dY and dX are addressed with the same (stride_xym, stride_xyd) pair,
    and the math is done in fp32.
    """
    row_ids = tl.program_id(0).to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
    col_ids = tl.arange(0, BLOCK_D)
    row_ok = row_ids < M
    tile_mask = row_ok[:, None]
    tile_offs = row_ids[:, None] * stride_xym + col_ids[None, :] * stride_xyd

    vals = tl.load(X + tile_offs, mask=tile_mask, other=0).to(tl.float32)
    norms = tl.load(C + row_ids * stride_cm, mask=row_ok, other=0).to(tl.float32)
    grad_out = tl.load(dY + tile_offs, mask=tile_mask, other=0).to(tl.float32)

    proj = tl.sum(grad_out * vals, 1)
    grad_in = grad_out / (norms[:, None] + 1e-7) - proj[:, None] * vals / (norms * norms * norms + 1e-7)[:, None]

    tl.store(dX + tile_offs, grad_in.to(dX.dtype.element_ty), mask=tile_mask)


def small_dim_vec_norm_fwd(X):
    """L2-normalize every row of a 2-D tensor with small_dim_vec_norm_fwd_kernel.

    Args:
        X: (M, DIM) tensor. DIM is used directly as the kernel's BLOCK_D, so it
            must be a power of two (``tl.arange`` requirement).

    Returns:
        (Y, C). Y has X's shape and dtype and holds ``X / (row_norm + 1e-7)``.
        C is a float32 vector of length M holding the row norms.

    Raises:
        AssertionError: if X is not 2-D or DIM is not a power of two.
    """
    assert (len(X.shape) == 2)
    # The kernel addresses X and Y with the same strides, taken from X, while
    # Y is allocated contiguous. A non-contiguous X would therefore send the
    # stores into Y to wrong (possibly out-of-bounds) offsets. This is a
    # no-op for an already contiguous X.
    X = X.contiguous()

    M = X.shape[0]
    DIM = X.shape[1]
    assert (DIM > 0 and (DIM & (DIM - 1)) == 0)
    BLOCK_D = DIM
    BLOCK_M = 128
    num_stages = 4
    num_warps = 8

    Y = torch.zeros((M, DIM), dtype=X.dtype, device=X.device)
    C = torch.zeros(M, dtype=torch.float32, device=X.device)

    grid = (triton.cdiv(M, BLOCK_M), 1, 1)
    small_dim_vec_norm_fwd_kernel[grid](X, Y, C,
                                        X.stride(0), X.stride(1),
                                        C.stride(0),
                                        M, BLOCK_M, BLOCK_D,
                                        num_stages=num_stages, num_warps=num_warps)

    return Y, C


def small_dim_vec_norm_bwd(X, C, dY):
    """Gradient of ``small_dim_vec_norm_fwd`` with respect to X.

    Args:
        X: (M, DIM) forward input. DIM must be a power of two.
        C: float row norms of length M, as returned by the forward.
        dY: (M, DIM) gradient of the normalized output.

    Returns:
        dX with X's shape and dtype.

    Raises:
        AssertionError: on rank or shape mismatches.
    """
    assert (len(X.shape) == 2)
    # The kernel reads dY with X's strides and C for every row < M.
    assert (dY.shape == X.shape)
    assert (len(C.shape) == 1 and C.shape[0] == X.shape[0])
    # X, dY and dX are all addressed with X's strides, and dX is allocated
    # contiguous, so X and dY are brought to the same contiguous layout.
    # This is a no-op for tensors that already are contiguous.
    X = X.contiguous()
    dY = dY.contiguous()

    M = X.shape[0]
    DIM = X.shape[1]
    BLOCK_D = DIM
    BLOCK_M = 128
    num_stages = 4
    num_warps = 8

    dX = torch.zeros((M, DIM), dtype=X.dtype, device=X.device)

    grid = (triton.cdiv(M, BLOCK_M), 1, 1)
    small_dim_vec_norm_bwd_kernel[grid](X, C, dY, dX,
                                        X.stride(0), X.stride(1),
                                        C.stride(0),
                                        M, BLOCK_M, BLOCK_D,
                                        num_stages=num_stages, num_warps=num_warps)

    return dX


class _small_dim_vec_norm(torch.autograd.Function):
    """Autograd bridge for the small-dim vector-norm kernels.

    A contiguous input of any rank is flattened to ``(-1, D)`` rows, where
    ``D`` is its last dimension, passed through ``small_dim_vec_norm_fwd``,
    and reshaped back to the input's shape.
    """

    @staticmethod
    def forward(ctx, X):
        """Return the kernel output shaped like ``X``; save ``X`` and the per-row ``C``."""
        assert (X.is_contiguous())
        in_shape, in_stride = X.shape, X.stride()
        width = X.shape[-1]
        rows = X.view(-1, width)
        out, per_row = small_dim_vec_norm_fwd(rows)
        ctx.save_for_backward(X, per_row)
        ctx.ori_shape, ctx.ori_stride = in_shape, in_stride
        return out.view(in_shape)

    @staticmethod
    def backward(ctx, dY):
        """Chain ``dY`` through ``small_dim_vec_norm_bwd``; its shape and stride must match the forward input."""
        X, per_row = ctx.saved_tensors
        grad_out = dY.contiguous()
        assert (grad_out.shape == ctx.ori_shape and grad_out.stride() == ctx.ori_stride)
        width = grad_out.shape[-1]
        grad_in = small_dim_vec_norm_bwd(X.view(-1, width), per_row, grad_out.view(-1, width))
        return grad_in.view(ctx.ori_shape)


# Functional entry point: ``small_dim_vec_norm(X)`` applies the autograd op defined above.
small_dim_vec_norm = _small_dim_vec_norm.apply


class _attention_with_cos_norm(torch.autograd.Function):
    """Chunk-routed attention over row-normalized queries and keys.

    Forward:
      1. Normalize ``Q_`` and ``K_`` along the last dimension with
         ``small_dim_vec_norm_fwd``. NOTE(review): presumably L2
         normalization -- the backward kernel's formula
         ``dy / c - (dy . x) x / c**3`` is the gradient of ``x / ||x||``;
         confirm against the forward kernel.
      2. For every query row past the first chunk, pick one preceding key
         chunk with ``sc_top_1_idx``.
      3. Build routing tensors with ``chunk_utils.chunk_count`` and
         ``chunk_utils.chunk_pos``. These are presumably a per-chunk histogram
         of the chosen indices (its per-row total is checked to equal the
         number of routed queries) and each query's position in the grouped
         order.
      4. Run ``sc_l_chunk_s1_fwd`` with ``offset`` (cumulative counts) and
         ``index`` (positions shifted by ``chunk_size``).

    Backward recomputes the normalization (only the raw inputs are saved),
    runs ``sc_l_chunk_s1_bwd`` and chains the Q/K gradients through
    ``small_dim_vec_norm_bwd``.

    NOTE(review): ``chunk_utils`` is not imported in the visible part of this
    file; confirm it is brought into scope elsewhere.
    """

    @staticmethod
    def forward(ctx, Q_, K_, V, chunk_size, sm_scale):
        """Compute the routed attention output.

        Args:
            Q_, K_: contiguous 4-D tensors of identical shape (Z, H, N_CTX, D).
            V: value tensor; must share Q/K's strides.
            chunk_size: routing chunk length.
            sm_scale: scale applied to the query-key dot products.

        Returns:
            The output ``O`` from ``sc_l_chunk_s1_fwd``.
        """
        assert (Q_.is_contiguous() and K_.is_contiguous())
        assert (Q_.shape == K_.shape)

        ori_shape = Q_.shape
        ori_stride = Q_.stride()
        # Normalize the flattened (-1, D) row views, then restore the 4-D shape.
        Q, _ = small_dim_vec_norm_fwd(Q_.view(-1, Q_.shape[-1]))
        K, _ = small_dim_vec_norm_fwd(K_.view(-1, K_.shape[-1]))
        Q = Q.view(ori_shape)
        K = K.view(ori_shape)

        assert Q.stride() == K.stride() == V.stride()

        Z = Q.shape[0]
        H = Q.shape[1]
        N_CTX = Q.shape[2]
        n_chunks = N_CTX // chunk_size
        # Query rows [chunk_size, N_CTX) are the ones that get routed.
        n_routed = N_CTX - chunk_size

        # Only whole chunks can be selected; drop any trailing partial chunk of K.
        K_n = n_chunks * chunk_size
        K_floor = K[:, :, 0:K_n, :].contiguous()

        chunk_idx = sc_top_1_idx(Q, K_floor, sm_scale, chunk_size).reshape(Z * H, n_routed)
        chunk_count = torch.zeros(Z * H, n_chunks, device=chunk_idx.device, dtype=torch.int32)
        chunk_utils.chunk_count(chunk_idx, chunk_count)
        chunk_count = chunk_count.long()
        chunk_cum_count = torch.cumsum(chunk_count, dim=-1).int()
        # Every (z, h) row must account for all routed queries, not only the last row.
        assert bool((chunk_cum_count[:, -1] == n_routed).all())
        chunk_pos = torch.empty((Z * H, n_routed), device=chunk_idx.device, dtype=torch.long)
        # A copy of the cumulative counts is passed; presumably the helper
        # consumes/modifies it -- confirm against chunk_utils.
        chunk_utils.chunk_pos(chunk_cum_count.clone(), chunk_idx, chunk_pos)
        chunk_pos = chunk_pos.reshape(Z, H, n_routed)

        # offset[z, h, c]: inclusive cumulative sum of chunk_count over chunks 0..c.
        offset = chunk_cum_count.reshape(Z, H, n_chunks)
        # Routed entry i corresponds to query row i + chunk_size; shift by
        # chunk_size in place (presumably to absolute query positions).
        index = chunk_pos
        index.add_(chunk_size)

        O, R = sc_l_chunk_s1_fwd(Q, K, V, offset, index, sm_scale, chunk_size)
        # NOTE(review): clears index 0 along R's third dimension before saving it
        # for backward; R's meaning is defined by sc_l_chunk_s1_fwd (not visible here).
        R[:, :, 0] = 0

        ctx.save_for_backward(Q_, K_, V, O, R, offset, index)
        ctx.ori_shape = ori_shape
        ctx.ori_stride = ori_stride
        ctx.sm_scale = sm_scale
        ctx.chunk_size = chunk_size
        return O

    @staticmethod
    def backward(ctx, dO):
        """Return gradients for (Q_, K_, V); chunk_size and sm_scale get None."""
        Q_, K_, V, O, R, offset, index = ctx.saved_tensors

        # Recompute the normalized inputs and their per-row C instead of saving them.
        Q, Q_C = small_dim_vec_norm_fwd(Q_.view(-1, Q_.shape[-1]))
        K, K_C = small_dim_vec_norm_fwd(K_.view(-1, K_.shape[-1]))
        Q = Q.view(ctx.ori_shape)
        K = K.view(ctx.ori_shape)

        dO = dO.contiguous()
        assert Q.stride() == K.stride() == V.stride() == O.stride() == dO.stride()
        dQ, dK, dV = sc_l_chunk_s1_bwd(Q, K, V, O, dO, R, offset, index, ctx.sm_scale, ctx.chunk_size)

        # dQ/dK are w.r.t. the normalized Q/K; chain them back to the raw inputs.
        dQ = dQ.contiguous()
        dK = dK.contiguous()
        assert (dQ.shape == ctx.ori_shape and dQ.stride() == ctx.ori_stride)
        assert (dK.shape == ctx.ori_shape and dK.stride() == ctx.ori_stride)
        dQ_ = small_dim_vec_norm_bwd(Q_.view(-1, dQ.shape[-1]), Q_C, dQ.view(-1, dQ.shape[-1]))
        dK_ = small_dim_vec_norm_bwd(K_.view(-1, dK.shape[-1]), K_C, dK.view(-1, dK.shape[-1]))
        dQ_ = dQ_.view(ctx.ori_shape)
        dK_ = dK_.view(ctx.ori_shape)

        return dQ_, dK_, dV, None, None


# Functional entry point:
# ``attention_with_cos_norm(Q_, K_, V, chunk_size, sm_scale)`` applies the op above.
attention_with_cos_norm = _attention_with_cos_norm.apply


@triton.jit
def sc_top_1_idx_kernel(Q, K, I, sm_scale, 
                        stride_qz, stride_qh, stride_qm, stride_qd, 
                        stride_kz, stride_kh, stride_kc, stride_kd, 
                        stride_iz, stride_ih, stride_im, 
                        Z, H, 
                        N_CTX: tl.constexpr, 
                        CHUNK_SIZE: tl.constexpr, 
                        BLOCK_M: tl.constexpr, 
                        BLOCK_N: tl.constexpr, 
                        BLOCK_D: tl.constexpr):
    # For each (z, h) and each query row m >= CHUNK_SIZE, find the earlier key
    # chunk with the largest log-sum-exp of sm_scale * (q_m . k) over that
    # chunk's keys, and write its chunk index to I at position m - CHUNK_SIZE.
    #
    # Grid: axis 0 walks blocks of BLOCK_M query rows, axis 1 walks Z * H.
    # Q and K are addressed as (row, d) via stride_qm/stride_qd and
    # stride_kc/stride_kd; I is addressed per (z, h) via stride_im.
    tl.static_assert(BLOCK_N == CHUNK_SIZE)
    # MAX_K below is computed once per block from its first row and applied to
    # every row in it; BLOCK_M <= CHUNK_SIZE (with power-of-two sizes) keeps a
    # block inside a single chunk so that value is right for all its rows.
    tl.static_assert(BLOCK_M <= CHUNK_SIZE)
    
    pid_m = tl.program_id(0)
    # Rows in the first chunk have no earlier chunk to choose from and no slot
    # in I (I starts at row CHUNK_SIZE), so such blocks do nothing.
    if pid_m * BLOCK_M < CHUNK_SIZE:
        return
    
    pid_zh = tl.program_id(1)
    off_z = pid_zh // H
    off_h = pid_zh % H
    
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offs_m < N_CTX
    offs_d = tl.arange(0, BLOCK_D)
    
    # Base pointers for this (z, h); 64-bit offsets avoid overflow on large tensors.
    Q_ptr = Q + off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    K_ptr = K + off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
    I_ptr = I + off_z.to(tl.int64) * stride_iz + off_h.to(tl.int64) * stride_ih
    
    # Rows >= N_CTX are masked without `other`, so their q values are
    # unspecified; they are never stored (the store uses the same mask).
    q = tl.load(Q_ptr + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd, mask=mask_m[:, None])
    
    # Running best score per row and its chunk index (-1 until a chunk is seen).
    m_val = tl.full((BLOCK_M, ), value=-float("inf"), dtype=tl.float32)
    m_idx = tl.full((BLOCK_M, ), value=-1, dtype=tl.int32)
    
    # Number of whole chunks strictly before the chunk holding this block.
    MAX_K = (pid_m * BLOCK_M) // CHUNK_SIZE
    
    for start_k in range(0, MAX_K):
        # K rows are loaded unmasked: the caller must supply at least
        # MAX_K * CHUNK_SIZE rows of K.
        offs_k = start_k * CHUNK_SIZE + tl.arange(0, CHUNK_SIZE)
        k = tl.load(K_ptr + offs_k[:, None] * stride_kc + offs_d[None, :] * stride_kd)
        s = tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale.to(tl.float32)
        # Numerically stable log-sum-exp over the chunk's keys: l = m + log(sum(exp(s - m))).
        m = tl.max(s, 1)
        e = tl.exp(s - m[:, None])
        es = tl.sum(e, 1)
        les = tl.log(es)
        l = les + m
        # Strict '>' keeps the earliest chunk on ties.
        m_lm = l > m_val
        m_val = tl.maximum(m_val, l)
        # Branch-free select: start_k where m_lm is true, otherwise the previous index.
        m_idx = m_idx * (1 - m_lm) + start_k * m_lm
    
    # Query row m is stored at I[..., m - CHUNK_SIZE].
    offs_i = offs_m - CHUNK_SIZE
    mask_i = mask_m
    
    tl.store(I_ptr + offs_i * stride_im, m_idx.to(I.dtype.element_ty), mask=mask_i)


def sc_top_1_idx(Q, K, sm_scale, chunk_size):
    """Pick, for each query row past the first chunk, the best earlier key chunk.

    For every (z, h) and query row ``m`` in ``[chunk_size, N_CTX)``, the kernel
    scores each whole key chunk lying strictly before the chunk that contains
    ``m``. The score is the log-sum-exp of ``sm_scale * (q_m . k)`` over that
    chunk's keys. The kernel returns the highest-scoring chunk index, and the
    earliest chunk wins ties.

    Args:
        Q: 4-D query tensor of shape (Z, H, N_CTX, D).
        K: 4-D key tensor of shape (Z, H, N_K, D). Only the first
            ``(N_CTX // chunk_size) * chunk_size`` rows are read, so ``K`` may be
            truncated to whole chunks.
        sm_scale: scalar multiplier applied to the dot products.
        chunk_size: chunk length; must be a positive power of two.

    Returns:
        ``torch.long`` tensor of shape (Z, H, N_CTX - chunk_size); entry ``i``
        belongs to query row ``i + chunk_size``.
    """
    assert (len(Q.shape) == 4 and len(K.shape) == 4)
    assert (Q.shape[0] == K.shape[0])
    assert (Q.shape[1] == K.shape[1])
    assert (Q.shape[3] == K.shape[3])
    # NOTE(review): tl.dot generally requires operand dimensions >= 16, so the
    # 2/4/8 entries may not compile -- confirm against the Triton version in use.
    assert (Q.shape[3] in [2, 4, 8, 16, 32, 64, 128, 256])

    Z = Q.shape[0]
    H = Q.shape[1]
    N_CTX = Q.shape[2]
    CHUNK_SIZE = chunk_size
    D = Q.shape[3]

    # tl.arange(0, CHUNK_SIZE) in the kernel only accepts power-of-two extents.
    assert (CHUNK_SIZE > 0 and (CHUNK_SIZE & (CHUNK_SIZE - 1)) == 0)
    # The output covers query rows [CHUNK_SIZE, N_CTX).
    assert (N_CTX >= CHUNK_SIZE)
    # The kernel reads K rows [0, MAX_K * CHUNK_SIZE) unmasked, and MAX_K never
    # exceeds (N_CTX - 1) // CHUNK_SIZE. Only the whole chunks of K are
    # therefore required; K need not have N_CTX rows.
    assert (K.shape[2] >= (N_CTX // CHUNK_SIZE) * CHUNK_SIZE)

    # One query block per chunk (kernel requires BLOCK_M <= CHUNK_SIZE and BLOCK_N == CHUNK_SIZE).
    BLOCK_M = CHUNK_SIZE
    BLOCK_N = CHUNK_SIZE
    BLOCK_D = D

    I = torch.full((Z, H, N_CTX - CHUNK_SIZE), -1, dtype=torch.long, device=Q.device)

    grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H, 1)

    sc_top_1_idx_kernel[grid](Q, K, I, sm_scale,
                              Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
                              K.stride(0), K.stride(1), K.stride(2), K.stride(3),
                              I.stride(0), I.stride(1), I.stride(2),
                              Z, H, N_CTX, CHUNK_SIZE, BLOCK_M, BLOCK_N, BLOCK_D, num_stages=4, num_warps=4)

    return I


def main():
    """Script entry point; currently a no-op."""
    pass


if __name__ == '__main__':
    main()
