import torch
import triton
import triton.language as tl


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"

@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr,  #
                    R_block_ptr,  #
                    A_block_ptr,  #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    # V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        is_last_block = start_n + BLOCK_N >= hi  # 判断是否是最后一个块
        remaining = hi - start_n
        mask = tl.arange(0, BLOCK_N) < remaining
        k = tl.load(K_block_ptr)

        # 计算qk时应用相同逻辑
        qk = tl.dot(q, k)
        if is_last_block:
            qk = tl.where(mask, qk, -float('inf'))  # 掩码无效位置
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk += tl.where(mask, 0, -1.0e6)
        max = tl.max(qk, 1) * qk_scale 
        m_ij = tl.maximum(m_i, max)
        qk = qk * qk_scale - m_ij[:, None]
        tl.store(tl.advance(R_block_ptr, (0, start_n // BLOCK_N)), max[:, None].to(q.dtype))
        m_i = m_ij
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    
    # -- update Po --
    if STAGE == 2:
        for start_n in range(0, (start_m + 1) * BLOCK_N, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            row_max = tl.load(R_block_ptr)
            xi = row_max - m_i[:, None]
            row_max = tl.exp2(xi)/l_i[:, None]
            col_max = tl.max(row_max, 0)
            col_max = col_max[:, None].to(q.dtype)
            tl.store(A_block_ptr, col_max)
            A_block_ptr = tl.advance(A_block_ptr, (0, 1))
            R_block_ptr = tl.advance(R_block_ptr, (0, 1))

    elif STAGE == 3: 
        for start_n in range(lo, hi, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            row_max = tl.load(R_block_ptr)
            xi = row_max - m_i[:, None]
            row_max = tl.exp2(xi)/l_i[:, None]
            col_max = tl.max(row_max, 0)
            col_max = col_max[:, None].to(q.dtype)
            tl.store(A_block_ptr, col_max)
            A_block_ptr = tl.advance(A_block_ptr, (0, 1))
            R_block_ptr = tl.advance(R_block_ptr, (0, 1))

    return acc, l_i, m_i


@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
              R, Po,
              stride_qz, stride_qh, stride_qm, stride_qk,  #
              stride_kz, stride_kh, stride_kn, stride_kk,  #
              stride_vz, stride_vh, stride_vk, stride_vn,  #
              stride_oz, stride_oh, stride_om, stride_on,  #
              stride_rz, stride_rh, stride_rm, stride_rn,  #
              stride_poz, stride_poh, stride_pom, stride_pon,  #
              Z, H, N_CTX,  #
              n_rep,  #
              HEAD_DIM: tl.constexpr,  #
              BLOCK_M: tl.constexpr,  #
              BLOCK_N: tl.constexpr,  #
              N_DOWNSAMPLE: tl.constexpr,  #
              STAGE: tl.constexpr  #
              ):
    # tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    off_kvh = off_h // n_rep
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_kvh.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_kvh.to(tl.int64) * stride_vh
    r_offset = off_z.to(tl.int64) * stride_rz + off_h.to(tl.int64) * stride_rh
    po_offset = off_z.to(tl.int64) * stride_poz + off_h.to(tl.int64) * stride_poh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )

    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + q_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    
    R_block_ptr = tl.make_block_ptr(
        base=R + r_offset,
        shape=(N_CTX, N_DOWNSAMPLE),
        strides=(stride_rm, stride_rn),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, 1),
        order=(0, 1),

    )
    A_block_ptr = tl.make_block_ptr(
        base=Po + po_offset,
        shape=(N_DOWNSAMPLE, N_DOWNSAMPLE),
        strides=(stride_pom, stride_pon),
        offsets=(start_m, 0),
        block_shape=(1, 1),
        order=(0, 1),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, None,  #
                                        R_block_ptr,  #
                                        A_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compielr to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, None,  #
                                        R_block_ptr,  #
                                        A_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))

class _attention_pooling(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale, block_size):
        assert block_size in {16, 32, 64, 128}
        orig_dtype = q.dtype
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        NUM_HEADS_Q, NUM_HEADS_K, NUM_HEADS_V = q.shape[1], k.shape[1], v.shape[1]
        assert NUM_HEADS_K == NUM_HEADS_V
        n_rep = NUM_HEADS_Q // NUM_HEADS_K
        o = torch.empty_like(q)
        BLOCK_N = block_size
        n_d = triton.cdiv(q.shape[2], BLOCK_N)
        R = torch.full((q.shape[0], q.shape[1], q.shape[2], n_d), -65504.0, device=q.device, dtype=q.dtype)
        Po = torch.zeros((q.shape[0], q.shape[1], n_d, n_d), device=q.device, dtype=q.dtype)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,  #
            R, Po,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            R.stride(0), R.stride(1), R.stride(2), R.stride(3),  #
            Po.stride(0), Po.stride(1), Po.stride(2), Po.stride(3),  #
            q.shape[0], q.shape[1],  #
            N_CTX=q.shape[2],  #
            n_rep=n_rep,  #
            HEAD_DIM=HEAD_DIM_K,  #
            STAGE=stage,  #
            BLOCK_M=block_size,
            BLOCK_N=block_size,
            N_DOWNSAMPLE=n_d,
            num_stages=3,
            num_warps=4,
            **extra_kern_args)
        Sum = torch.sum(Po, dim=-1, keepdim=True)
        Po.div_(Sum)
        o=o.to(orig_dtype)
        return o, Po

attn_with_pooling = _attention_pooling.apply