from typing import Optional

import math
import torch
import triton
import triton.language as tl


def forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, u: torch.Tensor, qk_scale: float):
    r"""
    Run the chunked delta flash-attention forward for one chunk of ``C`` queries.

    q, v: [B, C, D]
    k, u: [B, T, D] with T a multiple of C

    The kernel overwrites ``u[:, T - C:]`` in place.
    Returns ``(w, lse)``: ``w`` is [B, C, C] in ``q``'s dtype and ``lse`` is
    [B, C] in float32.
    """
    batch, chunk, head_dim = q.size()
    k_batch, seq_len, k_dim = k.size()
    assert batch == k_batch
    assert head_dim == k_dim
    assert u.size(1) == seq_len
    assert seq_len % chunk == 0

    w = torch.empty(batch, chunk, chunk, device=q.device, dtype=q.dtype)
    lse = torch.empty(batch, chunk, device=q.device, dtype=torch.float)
    delta_flash_attn_compileable(q, k, v, u, w, lse, qk_scale)
    return w, lse


def backward_naive(ndo: torch.Tensor, dw: torch.Tensor, lse: torch.Tensor, q: torch.Tensor, k: torch.Tensor, u: torch.Tensor, qk_scale: float):
    r"""
    Reference (pure PyTorch) backward of the chunked delta-attention forward.

    The forward probabilities are ``p = exp2(q @ k^T * qk_scale - lse)`` under a
    strict causal mask. Chunk row ``i`` sits at position ``T - C + i`` and sees
    keys ``0 .. T - C + i - 1``, which is ``tril(T - C - 1)`` on the [C, T]
    matrix. The first ``T - C`` columns of ``p`` multiply ``u``; the last ``C``
    columns are ``w``.

    ndo: [B, C, D] = -do
    dw: [B, C, C]
    lse: [B, C]
    q: [B, C, D]
    k: [B, T, D]
    u: [B, T-C, D] (not read when T == C)

    Returns ``(dq, dk, du)``; ``du`` is ``None`` when ``T == C``.

    NOTE(review): ``dq``/``dk`` are taken w.r.t. the raw scores without the
    chain-rule factor ``qk_scale * ln(2)``. They therefore equal the true
    gradients only when ``qk_scale == log2(e)``. Presumably the caller
    rescales — confirm.
    """
    T = k.shape[1]
    C = q.shape[1]
    n_prefix = T - C
    # p[B, C, T] = exp2([B, C, D] @ [B, D, T] * qk_scale - lse), causally masked.
    p = torch.exp2(torch.bmm(q, k.transpose(1, 2)) * qk_scale - lse[:, :, None]).tril(n_prefix - 1)
    if n_prefix > 0:
        # du[B, T-C, D] = [B, T-C, C] @ [B, C, D]
        du = -torch.bmm(p[:, :, :n_prefix].transpose(1, 2), ndo)
        # dp[B, C, T] = cat(-ndo @ u^T  [B, C, T-C],  dw [B, C, C])
        dp = torch.cat([-torch.bmm(ndo, u.transpose(1, 2)), dw], dim=-1)
    else:
        du = None
        dp = dw
    # Softmax backward: da = p * (dp - rowsum(p * dp)).
    row_dot = torch.sum(p * dp, dim=-1, keepdim=True)
    da = p * (dp - row_dot)
    # dq[B, C, D] = [B, C, T] @ [B, T, D]
    dq = torch.bmm(da, k)
    # dk[B, T, D] = [B, T, C] @ [B, C, D]
    dk = torch.bmm(da.transpose(1, 2), q)
    return dq, dk, du


def backward_u_chunk(q: torch.Tensor, k: torch.Tensor, lse: torch.Tensor, grad_v: torch.Tensor, fa_scale: float):
    r"""
    Apply column-normalised attention weights to ``grad_v``.

    For each query row ``c`` and key ``t`` the weight is
    ``exp2(q_c . k_t * fa_scale - lse[t])``. ``lse`` is indexed by the key
    position, so each column of ``q @ k^T`` is normalised. The weights are then
    multiplied into ``grad_v``; no causal mask is applied.

    q: [B, C, D]
    k: [B, T, D]
    lse: [B, T]
    grad_v: [B, T, D]

    Returns a [B, C, D] tensor with ``q``'s dtype.
    """
    batch, chunk, head_dim = q.size()
    _, seq_len, _ = k.size()
    grad_u = torch.empty_like(q)

    def _grid(meta):
        # One program per BLOCK_C query rows per batch element.
        return (triton.cdiv(chunk, meta['BLOCK_C']), batch)

    backward_u_chunk_kernel[_grid](
        grad_u, grad_u.stride(0), grad_u.stride(1), grad_u.stride(2),
        q, q.stride(0), q.stride(1), q.stride(2),
        k, k.stride(0), k.stride(1), k.stride(2),
        grad_v, grad_v.stride(0), grad_v.stride(1), grad_v.stride(2),
        lse, lse.stride(0), lse.stride(1),
        batch, seq_len, chunk, head_dim, fa_scale,
    )
    return grad_u


def backward_k_naive(k, u, lse, grad_v, qk_scale, fa_scale):
    r"""
    Reference (pure PyTorch) gradient w.r.t. ``k`` of a strictly lower-triangular
    self-attention over ``k``.

    p  = tril(exp2(k @ k^T * fa_scale - lse[:, :, None]), -1)
    dp = tril(-grad_v @ u^T, -1)
    dA = p * (dp - rowsum(p * dp))

    Because the scores are ``A = K K^T``, the result is
    ``(dA + dA^T) @ (k * qk_scale)``.

    k, u, grad_v: [B, T, D]; lse: [B, T]. Returns [B, T, D].
    """
    grad_p = torch.tril(torch.bmm(-grad_v, u.transpose(1, 2)), diagonal=-1)
    scores = torch.bmm(k, k.transpose(1, 2)) * fa_scale
    probs = torch.tril(torch.exp2(scores - lse.unsqueeze(-1)), diagonal=-1)
    centred = grad_p - (probs * grad_p).sum(dim=-1, keepdim=True)
    grad_scores = probs * centred
    symmetric = grad_scores + grad_scores.transpose(1, 2)
    return torch.bmm(symmetric, k * qk_scale)


def backward_k(
    k: torch.Tensor, u: torch.Tensor, lse: torch.Tensor,
    grad_v: torch.Tensor,
    qk_scale: float, fa_scale: float
):
    r"""
    Gradient w.r.t. ``k`` of a strictly lower-triangular self-attention over ``k``.

    Since the scores are ``A = K K^T``, ``grad_k = (grad_A + grad_A^T) K``.
    The work is split into two kernel launches:

    1. ``backward_p_row_sum_kernel`` computes ``row_dot = rowsum(p * dp)``.
       Here ``p = exp2(k k^T * fa_scale - lse)`` and ``dp = tril(-grad_v u^T, -1)``.
    2. ``backward_k_kernel`` forms ``grad_A = p * (dp - row_dot)`` and
       multiplies ``(grad_A + grad_A^T)`` by ``k`` scaled by ``qk_scale``.

    ``backward_k_naive`` computes the same quantity in plain PyTorch.

    k, u, grad_v: [B, T, D]
    lse: [B, T]
    Returns ``grad_k`` with the same shape and dtype as ``k``.
    """
    batch, seq_len, head_dim = k.size()

    def _grid(meta):
        # BLOCK_C comes from whichever kernel is being launched, so one grid
        # function serves both launches.
        return (triton.cdiv(seq_len, meta['BLOCK_C']), batch)

    row_dot = torch.empty_like(lse)
    backward_p_row_sum_kernel[_grid](
        row_dot, row_dot.stride(0), row_dot.stride(1),
        k, k.stride(0), k.stride(1), k.stride(2),
        grad_v, grad_v.stride(0), grad_v.stride(1), grad_v.stride(2),
        u, u.stride(0), u.stride(1), u.stride(2),
        lse, lse.stride(0), lse.stride(1),
        batch, seq_len, head_dim,
        fa_scale,
    )

    grad_k = torch.empty_like(k)
    backward_k_kernel[_grid](
        grad_k, grad_k.stride(0), grad_k.stride(1), grad_k.stride(2),
        k, k.stride(0), k.stride(1), k.stride(2),
        grad_v, grad_v.stride(0), grad_v.stride(1), grad_v.stride(2),
        u, u.stride(0), u.stride(1), u.stride(2),
        lse, lse.stride(0), lse.stride(1),
        row_dot, row_dot.stride(0), row_dot.stride(1),
        batch, seq_len, head_dim,
        fa_scale, qk_scale,
    )
    return grad_k

# @torch.library.custom_op("deltattn::flash_attn_fwd", mutates_args={"u", "w", "lse"})
def delta_flash_attn_compileable(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        u: torch.Tensor,
        w: torch.Tensor,
        lse: torch.Tensor,
        qk_scale: float,
) -> None:
    r"""
    Launch ``flash_attn_kernel`` for one chunk; every result is written in place.

    q and v are of shape [B, C, D]
    k and u are of shape [B, T, D]; ``u[:, T - C:]`` is overwritten
    w is [ B, C, C ] (output)
    lse is [ B, C ] (output)

    Raises:
        ValueError: if ``w`` or ``lse`` is not contiguous along its last
            dimension. The kernel addresses that dimension with an implicit
            unit stride, so any other layout would be written to the wrong
            addresses.
    """
    B, C, D = q.size()
    _B, T, _D = k.size()
    if lse.stride(-1) != 1:
        raise ValueError("lse must be contiguous along its last dimension")
    if w.stride(-1) != 1:
        raise ValueError("w must be contiguous along its last dimension")

    def grid(META):
        # One program per BLOCK_C query rows per batch element.
        return (
            triton.cdiv(C, META['BLOCK_C']),
            B
        )
    flash_attn_kernel[grid](
            q,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            k,
            k.stride(0),
            k.stride(1),
            k.stride(2),
            v,
            v.stride(0),
            v.stride(1),
            v.stride(2),
            u,
            u.stride(0),
            u.stride(1),
            u.stride(2),
            w,
            w.stride(0),
            w.stride(1),
            lse,
            lse.stride(0),
            B, T, C, D, qk_scale)


def _config_delta_flash_attn():
    """Autotune search space for ``flash_attn_kernel`` (currently a single point)."""
    block_c_choices = [128]
    block_t_choices = [64]
    stage_choices = [3]
    warp_choices = [8]
    return [
        triton.Config({'BLOCK_C': bc, 'BLOCK_T': bt}, num_stages=stages, num_warps=warps)
        for bc in block_c_choices
        for bt in block_t_choices
        for stages in stage_choices
        for warps in warp_choices
    ]



# Forward kernel for one chunk of delta attention.
# Grid: (cdiv(C, BLOCK_C), B). Each program owns BLOCK_C query rows of one
# batch element.
#   Pass 1 walks all T keys. It keeps online base-2 softmax statistics
#     (rowmax, rowsum) of q @ k^T * qk_scale under a strict causal mask, where
#     chunk row i sits at position T - C + i. It accumulates p @ u only over the
#     first T - C keys.
#   Between the passes it stores lse = rowmax + log2(rowsum) and writes
#     u[T - C + row] = v[row] - (normalised p @ u_prefix) in place.
#   Pass 2 walks the last C keys, recomputes the normalised probabilities and
#     stores them into w [B, C, C].
# NOTE(review): the block pointers are loaded/stored without boundary checks.
# This presumes C is a multiple of BLOCK_C, and T and T - C are multiples of
# BLOCK_T. The causal mask is only applied to blocks starting at or after
# T - C. lse is addressed without a column stride, so it must be contiguous
# along C.
@triton.autotune(
    configs=_config_delta_flash_attn(),
    key=['C', 'D']
)
@triton.jit
def flash_attn_kernel(
    q_ptr,
    stride_qh,
    stride_qc,
    stride_qd,
    k_ptr,
    stride_kh,
    stride_kt,
    stride_kd,
    v_ptr,
    stride_vh,
    stride_vc,
    stride_vd,
    u_ptr,
    stride_uh,
    stride_ut,
    stride_ud,
    w_ptr,
    stride_wh,
    stride_wc,
    lse_ptr,
    stride_lse_r,
    B, T, C, D: tl.constexpr,
    qk_scale: float,
    BLOCK_C: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    pid_c = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)

    rowid_block = tl.arange(0, BLOCK_C) + pid_c * BLOCK_C
    colid_block = tl.arange(0, BLOCK_T)

    # rowmax starts at -inf. The initial rowsum value is irrelevant: on the first
    # iteration alpha = exp2(-inf - rowmax_i) = 0 wipes it out.
    rowmax = tl.zeros([BLOCK_C], dtype=tl.float32) - float('inf')
    rowsum = tl.zeros([BLOCK_C], dtype=tl.float32) + 1
    acc = tl.zeros([BLOCK_C, D], dtype=tl.float32)
    q_blk_ptr = tl.make_block_ptr(
        base=q_ptr + pid_b * stride_qh,
        shape=(C, D),
        strides=(stride_qc, stride_qd),
        offsets=(pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )  # [BLOCK_C, D]
    q = tl.load(q_blk_ptr)

    # Pass 1: online softmax statistics over all T keys.
    for kv_i in range(0, T, BLOCK_T):
        # k is read transposed as a [D, BLOCK_T] tile.
        k_blk_ptr = tl.make_block_ptr(
            base=k_ptr + pid_b * stride_kh,
            shape=(D, T),
            strides=(stride_kd, stride_kt),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_T),
            order=(0, 1),
        )
        k = tl.load(k_blk_ptr)
        qk = tl.dot(q, k) * qk_scale

        if kv_i >= T - C:
            # Strict causal mask: key position kv_i + col must be < T - C + row.
            # -1e6 makes exp2 underflow to 0.
            mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
            qk = tl.where(mask, -1e6, qk)

        rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
        qk -= rowmax_i[:, None]
        p = tl.math.exp2(qk)

        # Rescale the running sum/accumulator to the new row maximum.
        rowsum_i = tl.sum(p, axis=1)
        alpha = tl.math.exp2(rowmax - rowmax_i)
        rowsum = rowsum * alpha + rowsum_i
        acc = acc * alpha[:, None]
        rowmax = rowmax_i

        # Only the prefix keys multiply known u rows. The last C rows of u are
        # what this kernel writes below.
        if kv_i < T - C:
            u_blk_ptr = tl.make_block_ptr(
                base=u_ptr + pid_b * stride_uh,
                shape=(T, D),
                strides=(stride_ut, stride_ud),
                offsets=(kv_i, 0),
                block_shape=(BLOCK_T, D),
                order=(1, 0),
            )
            u = tl.load(u_blk_ptr)
            acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)

    # Base-2 log-sum-exp per query row (assumes lse is contiguous along C).
    lse = rowmax + tl.math.log2(rowsum)
    lse_block_ptr = lse_ptr + stride_lse_r * pid_b + rowid_block
    tl.store(lse_block_ptr, lse)

    v_ptr = tl.make_block_ptr(
        base=v_ptr + pid_b * stride_vh,
        shape=(C, D),
        strides=(stride_vc, stride_vd),
        offsets=(pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )
    # u[T - C + row] = v[row] - softmax(qk)[:, :T - C] @ u[:T - C]
    acc = acc / rowsum[:, None]
    v = tl.load(v_ptr)
    u = v - acc.to(v_ptr.dtype.element_ty)
    u_block_ptr = tl.make_block_ptr(
        base=u_ptr + pid_b * stride_uh,
        shape=(T, D),
        strides=(stride_ut, stride_ud),
        offsets=(T - C + pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )
    tl.store(u_block_ptr, u)

    # Pass 2: normalised probabilities over the last C keys, stored into w.
    for kv_i in range(T - C, T, BLOCK_T):
        k_blk_ptr = tl.make_block_ptr(
            base=k_ptr + pid_b * stride_kh,
            shape=(D, T),
            strides=(stride_kd, stride_kt),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_T),
            order=(0, 1),
        )
        k = tl.load(k_blk_ptr)
        qk = tl.dot(q, k) * qk_scale

        mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
        qk = tl.where(mask, -1e6, qk)
        qk -= rowmax[:, None]
        p = tl.math.exp2(qk) / rowsum[:, None]
        # Column offset inside the [C, C] block; w's last dim must have stride 1.
        w_blk_ptr = tl.make_block_ptr(
            base=w_ptr + pid_b * stride_wh,
            shape=(C, C),
            strides=(stride_wc, 1),
            offsets=(pid_c * BLOCK_C, kv_i - (T - C)),
            block_shape=(BLOCK_C, BLOCK_T),
            order=(1, 0)
        )
        tl.store(w_blk_ptr, p.to(w_blk_ptr.dtype.element_ty))


def _config_backward_u_chunk():
    """Autotune search space for ``backward_u_chunk_kernel`` (currently a single point)."""
    block_c_choices = [128]
    block_t_choices = [64]
    stage_choices = [3]
    warp_choices = [8]
    return [
        triton.Config({'BLOCK_C': bc, 'BLOCK_T': bt}, num_stages=stages, num_warps=warps)
        for bc in block_c_choices
        for bt in block_t_choices
        for stages in stage_choices
        for warps in warp_choices
    ]


# Computes out[c] = sum_t exp2(q_c . k_t * fa_scale - lse[t]) * v[t].
# Grid: (cdiv(C, BLOCK_C), B); each program handles BLOCK_C query rows of one
# batch element. lse is indexed by the key position t, so the probabilities are
# normalised per column. No causal mask is applied.
# NOTE(review): block pointers are used without boundary checks, which presumes
# C is a multiple of BLOCK_C and T a multiple of BLOCK_T.
@triton.autotune(
    configs=_config_backward_u_chunk(),
    key=['C', 'D']
)
@triton.jit
def backward_u_chunk_kernel(
    o_ptr,
    stride_oh,
    stride_oc,
    stride_od,
    q_ptr,
    stride_qh,
    stride_qc,
    stride_qd,
    k_ptr,
    stride_kh,
    stride_kt,
    stride_kd,
    v_ptr,
    stride_vh,
    stride_vt,
    stride_vd,
    lse_ptr,
    stride_lse_h,
    stride_lse_t,
    B, T, C, D: tl.constexpr,
    fa_scale,
    BLOCK_C: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    pid_c = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)

    acc = tl.zeros([BLOCK_C, D], dtype=tl.float32)

    q_blk_ptr = tl.make_block_ptr(
        base=q_ptr + pid_b * stride_qh,
        shape=(C, D),
        strides=(stride_qc, stride_qd),
        offsets=(pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )  # [BLOCK_C, D]
    q = tl.load(q_blk_ptr)

    for kv_i in range(0, T, BLOCK_T):
        # k is read transposed as a [D, BLOCK_T] tile.
        k_blk_ptr = tl.make_block_ptr(
            base=k_ptr + pid_b * stride_kh,
            shape=(D, T),
            strides=(stride_kd, stride_kt),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_T),
            order=(0, 1),
        )
        k = tl.load(k_blk_ptr)

        lse_blk_ptr = tl.make_block_ptr(
            base=lse_ptr + pid_b * stride_lse_h,
            shape=(T,),
            strides=(stride_lse_t,),
            offsets=(kv_i,),
            block_shape=(BLOCK_T,),
            order=(0,),
        )
        lse = tl.load(lse_blk_ptr)
        # Scores are computed once per tile; lse broadcasts along rows, which
        # normalises each key column.
        qk = tl.dot(q, k) * fa_scale
        p = tl.math.exp2(qk - lse[None, :])

        v_blk_ptr = tl.make_block_ptr(
            base=v_ptr + pid_b * stride_vh,
            shape=(T, D),
            strides=(stride_vt, stride_vd),
            offsets=(kv_i, 0),
            block_shape=(BLOCK_T, D),
            order=(1, 0),
        )
        v = tl.load(v_blk_ptr)
        acc = tl.dot(p.to(v_ptr.dtype.element_ty), v, acc)

    o_blk_ptr = tl.make_block_ptr(
        base=o_ptr + pid_b * stride_oh,
        shape=(C, D),
        strides=(stride_oc, stride_od),
        offsets=(pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )
    tl.store(o_blk_ptr, acc.to(o_ptr.dtype.element_ty))


def _config_backward_p_row_sum():
    """Autotune search space for ``backward_p_row_sum_kernel`` (currently a single point)."""
    block_c_choices = [128]
    block_t_choices = [64]
    stage_choices = [4]
    warp_choices = [8]
    return [
        triton.Config({'BLOCK_C': bc, 'BLOCK_T': bt}, num_stages=stages, num_warps=warps)
        for bc in block_c_choices
        for bt in block_t_choices
        for stages in stage_choices
        for warps in warp_choices
    ]


# Row-wise dot product for the softmax backward of a strictly lower-triangular
# self-attention over k:
#   row_dot[t] = sum_{j < t} p[t, j] * dp[t, j]
# where p[t, j] = exp2(k_t . k_j * fa_scale - lse[t]) and dp[t, j] = -grad_v_t . u_j.
# Grid: (cdiv(T, BLOCK_C), B). Each program handles BLOCK_C rows and walks key
# columns in BLOCK_T steps up to the end of its own row block.
# NOTE(review): block pointers are used without boundary checks, which presumes
# T is a multiple of BLOCK_C and BLOCK_C a multiple of BLOCK_T.
@triton.autotune(
    configs=_config_backward_p_row_sum(),
    key=['T', 'D']
)
@triton.jit
def backward_p_row_sum_kernel(
    row_dot_ptr,
    stride_row_dot_h,
    stride_row_dot_t,
    k_ptr,
    stride_kh,
    stride_kt,
    stride_kd,
    grad_v_ptr,
    stride_grad_vh,
    stride_grad_vt,
    stride_grad_vd,
    u_ptr,
    stride_uh,
    stride_ut,
    stride_ud,
    lse_ptr,
    stride_lse_h,
    stride_lse_t,
    B, T, D: tl.constexpr,
    fa_scale,
    BLOCK_C: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    pid_c = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)

    rowid_block = tl.arange(0, BLOCK_C) + pid_c * BLOCK_C
    colid_block = tl.arange(0, BLOCK_T)

    acc = tl.zeros([BLOCK_C], dtype=tl.float32)

    # Row-side operands are loaded once per program: k rows, their lse and -grad_v.
    k_row_blk_ptr = tl.make_block_ptr(
        base=k_ptr + pid_b * stride_kh,
        shape=(T, D),
        strides=(stride_kt, stride_kd),
        offsets=(pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )  # [BLOCK_C, D]
    k_row = tl.load(k_row_blk_ptr)
    lse_blk_ptr = tl.make_block_ptr(
        base=lse_ptr + pid_b * stride_lse_h,
        shape=(T,),
        strides=(stride_lse_t,),
        offsets=(pid_c * BLOCK_C,),
        block_shape=(BLOCK_C,),
        order=(0,),
    )
    lse = tl.load(lse_blk_ptr)
    grad_v_blk_ptr = tl.make_block_ptr(
        base=grad_v_ptr + pid_b * stride_grad_vh,
        shape=(T, D),
        strides=(stride_grad_vt, stride_grad_vd),
        offsets=(pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )  # [BLOCK_C, D]
    grad_v_row = -tl.load(grad_v_blk_ptr)

    # Only columns up to the end of this row block can be below the diagonal.
    for kv_i in range(0, (pid_c + 1) * BLOCK_C, BLOCK_T):
        k_blk_ptr = tl.make_block_ptr(
            base=k_ptr + pid_b * stride_kh,
            shape=(D, T),
            strides=(stride_kd, stride_kt),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_T),
            order=(0, 1),
        )
        k = tl.load(k_blk_ptr)
        qk = tl.dot(k_row, k) * fa_scale
        p = tl.math.exp2(qk - lse[:, None])

        u_blk_ptr = tl.make_block_ptr(
            base=u_ptr + pid_b * stride_uh,
            shape=(D, T),
            strides=(stride_ud, stride_ut),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_T),
            order=(0, 1),
        )
        ut = tl.load(u_blk_ptr)
        dp = tl.dot(grad_v_row, ut)
        # Tiles that reach the diagonal: zero out col >= row (strict lower triangle).
        if kv_i + BLOCK_T >= pid_c * BLOCK_C:
            mask = (rowid_block[:, None] <= colid_block[None, :] + kv_i)
            p = tl.where(mask, 0., p)
            dp = tl.where(mask, 0., dp)
        acc += tl.sum(p * dp, axis=1)
    row_dot_block_ptr = tl.make_block_ptr( 
        base=row_dot_ptr + pid_b * stride_row_dot_h,
        shape=(T,),
        strides=(stride_row_dot_t,),
        offsets=(pid_c * BLOCK_C,),
        block_shape=(BLOCK_C,),
        order=(0,),
    )
    tl.store(row_dot_block_ptr, acc)


def _config_backward_k():
    """Autotune search space for ``backward_k_kernel`` (currently a single point)."""
    block_c_choices = [64]
    stage_choices = [4]
    warp_choices = [4]
    return [
        triton.Config({'BLOCK_C': bc}, num_stages=stages, num_warps=warps)
        for bc in block_c_choices
        for stages in stage_choices
        for warps in warp_choices
    ]


# Gradient w.r.t. k of the strictly lower-triangular self-attention
#   A  = p * (dp - row_dot),   p = exp2(k @ k^T * fa_scale - lse[:, None]),
#   dp = -grad_v @ u^T,        row_dot produced by backward_p_row_sum_kernel,
# computed as grad_k = (A + A^T) @ k * qk_scale, because k appears on both
# sides of k @ k^T.
# Grid: (cdiv(T, BLOCK_C), B). Each program produces BLOCK_C rows of grad_k by
# visiting three regions:
#   - the lower-left blocks of A;
#   - the diagonal block (masked, contributes A + A^T);
#   - the upper-right blocks of A^T.
# row_dot is kept in float32. dp - row_dot is a cancellation-prone difference,
# so rounding row_dot to k's element type (e.g. bf16) would lose accuracy.
# NOTE(review): block pointers are used without boundary checks, which presumes
# T is a multiple of BLOCK_C.
@triton.autotune(
    configs=_config_backward_k(),
    key=['T', 'D']
)
@triton.jit
def backward_k_kernel(
    grad_k_ptr,
    stride_grad_kh,
    stride_grad_kt,
    stride_grad_kd,
    k_ptr,
    stride_kh,
    stride_kt,
    stride_kd,
    grad_v_ptr,
    stride_grad_vh,
    stride_grad_vt,
    stride_grad_vd,
    u_ptr,
    stride_uh,
    stride_ut,
    stride_ud,
    lse_ptr,
    stride_lse_h,
    stride_lse_t,
    row_dot_ptr,
    stride_row_dot_h,
    stride_row_dot_t,
    B, T,
    D: tl.constexpr,
    fa_scale: tl.constexpr,
    qk_scale: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    pid_c = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    block_i = tl.arange(0, BLOCK_C)

    acc = tl.zeros([BLOCK_C, D], dtype=tl.float32)

    # Row-side operands for this program's BLOCK_C rows.
    k_row_blk_ptr = tl.make_block_ptr(
        base=k_ptr + pid_b * stride_kh,
        shape=(T, D),
        strides=(stride_kt, stride_kd),
        offsets=(pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )  # [BLOCK_C, D]
    k_row = tl.load(k_row_blk_ptr)
    lse_blk_ptr = tl.make_block_ptr(
        base=lse_ptr + pid_b * stride_lse_h,
        shape=(T,),
        strides=(stride_lse_t,),
        offsets=(pid_c * BLOCK_C,),
        block_shape=(BLOCK_C,),
        order=(0,),
    )
    lse = tl.load(lse_blk_ptr)
    grad_v_blk_ptr = tl.make_block_ptr(
        base=grad_v_ptr + pid_b * stride_grad_vh,
        shape=(T, D),
        strides=(stride_grad_vt, stride_grad_vd),
        offsets=(pid_c * BLOCK_C, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )  # [BLOCK_C, D]
    grad_v_row = -tl.load(grad_v_blk_ptr)
    row_dot_blk_ptr = tl.make_block_ptr(
        base=row_dot_ptr + pid_b * stride_row_dot_h,
        shape=(T,),
        strides=(stride_row_dot_t,),
        offsets=(pid_c * BLOCK_C,),
        block_shape=(BLOCK_C,),
        order=(0,),
    )
    # Keep row_dot at its stored precision (see header comment).
    row_dot_row = tl.load(row_dot_blk_ptr)

    # Lower-left blocks of A: key blocks strictly before this row block, so no
    # mask is needed. A^T is zero here because A is strictly lower-triangular.
    for kv_i in range(0, pid_c * BLOCK_C, BLOCK_C):
        k_blk_ptr = tl.make_block_ptr(
            base=k_ptr + pid_b * stride_kh,
            shape=(D, T),
            strides=(stride_kd, stride_kt),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_C),
            order=(0, 1),
        )
        kt = tl.load(k_blk_ptr)
        qk = tl.dot(k_row, kt) * fa_scale
        p = tl.math.exp2(qk - lse[:, None])

        u_blk_ptr = tl.make_block_ptr(
            base=u_ptr + pid_b * stride_uh,
            shape=(D, T),
            strides=(stride_ud, stride_ut),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_C),
            order=(0, 1),
        )
        ut = tl.load(u_blk_ptr)
        dp = tl.dot(grad_v_row, ut)
        da = p * (dp - row_dot_row[:, None])
        k = tl.trans(kt, 1, 0)
        acc = tl.dot(da.to(k.dtype), k, acc)

    # The block on the diagonal: strict lower-triangular mask, and it
    # contributes both A and A^T.
    qk = tl.dot(k_row, tl.trans(k_row, 1, 0)) * fa_scale
    p = tl.math.exp2(qk - lse[:, None])
    u_blk_ptr = tl.make_block_ptr(
        base=u_ptr + pid_b * stride_uh,
        shape=(D, T),
        strides=(stride_ud, stride_ut),
        offsets=(0, pid_c * BLOCK_C),
        block_shape=(D, BLOCK_C),
        order=(0, 1),
    )
    ut = tl.load(u_blk_ptr)
    dp = tl.dot(grad_v_row, ut)
    dpm = dp - row_dot_row[:, None]
    mask = block_i[None, :] < block_i[:, None]
    p = tl.where(mask, p, 0.)
    dpm = tl.where(mask, dpm, 0.)
    da = p * dpm
    daat = da + tl.trans(da, 1, 0)
    acc = tl.dot(daat.to(k_row.dtype), k_row, acc)

    # Upper-right blocks of A^T
    # (K_i * K_r^T)^T = K_r * K_j^T
    # (grad_v * u_T)^T = u * grad_v_t^T
    # Note that lse and row_dot are on column directions
    nu = -tl.trans(ut, 1, 0)
    for kv_i in range((pid_c + 1) * BLOCK_C, T, BLOCK_C):
        k_blk_ptr = tl.make_block_ptr(
            base=k_ptr + pid_b * stride_kh,
            shape=(D, T),
            strides=(stride_kd, stride_kt),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_C),
            order=(0, 1),
        )
        kt = tl.load(k_blk_ptr)
        lse_blk_ptr = tl.make_block_ptr(
            base=lse_ptr + pid_b * stride_lse_h,
            shape=(T,),
            strides=(stride_lse_t,),
            offsets=(kv_i,),
            block_shape=(BLOCK_C,),
            order=(0,),
        )
        lse = tl.load(lse_blk_ptr)
        qk = tl.dot(k_row, kt) * fa_scale
        p = tl.math.exp2(qk - lse[None, :])

        grad_vt_blk_ptr = tl.make_block_ptr(
            base=grad_v_ptr + pid_b * stride_grad_vh,
            shape=(D, T),
            strides=(stride_grad_vd, stride_grad_vt),
            offsets=(0, kv_i),
            block_shape=(D, BLOCK_C),
            order=(0, 1),
        )
        grad_vt = tl.load(grad_vt_blk_ptr)
        row_dot_blk_ptr = tl.make_block_ptr(
            base=row_dot_ptr + pid_b * stride_row_dot_h,
            shape=(T,),
            strides=(stride_row_dot_t,),
            offsets=(kv_i,),
            block_shape=(BLOCK_C,),
            order=(0,),
        )
        row_dot = tl.load(row_dot_blk_ptr)
        dp = tl.dot(nu, grad_vt)
        da = p * (dp - row_dot[None, :])
        k = tl.trans(kt, 1, 0)
        acc = tl.dot(da.to(k.dtype), k, acc)

    grad_k_blk_ptr = tl.make_block_ptr(
        base=grad_k_ptr + pid_b * stride_grad_kh,
        shape=(T, D),
        strides=(stride_grad_kt, stride_grad_kd),
        offsets=(BLOCK_C * pid_c, 0),
        block_shape=(BLOCK_C, D),
        order=(1, 0),
    )  # [BLOCK_C, D]
    acc = acc * qk_scale
    tl.store(grad_k_blk_ptr, acc.to(grad_k_ptr.dtype.element_ty))


def naive_delta_fa(q, k, v, u, scale):
    r"""
    Reference (pure PyTorch) forward for one chunk, mirroring ``forward``.

    q, v: [B, C, D]; k, u: [B, T, D]. The masked softmax comes from
    ``tril_softmax.naive_tril_softmax_2``. ``u[:, T - C:]`` is overwritten in
    place with ``v - w_prefix @ u[:, :T - C]``.

    Returns ``(w, None, None)`` where ``w`` is the [B, C, C] diagonal block.
    """
    _, seq_len, _ = k.size()
    _, chunk, _ = q.size()
    start = seq_len - chunk
    from tril_softmax import naive_tril_softmax_2
    probs = naive_tril_softmax_2(torch.bmm(q, k.transpose(1, 2)) * scale)
    new_rows = v - torch.bmm(probs[:, :, :start], u[:, :start, :])
    u[:, start:start + chunk, :].copy_(new_rows)
    return probs[:, :, start:start + chunk], None, None


def test_delta_fa(device='cuda', dtype=torch.float32):
    r"""
    Compare the Triton forward against ``naive_delta_fa`` on random inputs.

    Checks the [B, C, C] weight block ``w`` and the in-place update of ``u``.
    Requires a CUDA device and the project-local ``utils`` / ``tril_softmax``
    modules.
    """
    from utils import check_close
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)
    B = 2
    T = 4096
    C = 256
    D = 128
    k = torch.randn(B, T, D)
    u = torch.randn(B, T, D)
    u2 = u.clone()
    q = torch.randn(B, C, D)
    v = torch.randn(B, C, D)
    # `forward` writes u[:, T - C:] in place and returns (w, lse).
    w, _lse = forward(q, k, v, u, 1)
    wref, _, _ = naive_delta_fa(q, k, v, u2, 1)
    # w is the [C, C] diagonal block of the [C, T] attention matrix. There, the
    # strict causal mask tril(T - C - 1) becomes tril(-1).
    check_close(wref.tril(-1), w.tril(-1), name='w')
    check_close(u2, u, name='u')


def tune_delta_fa(device='cuda', dtype=torch.bfloat16):
    r"""
    Benchmark the Triton forward (``forward``) and print its time in ms.

    Requires a CUDA device and the project-local ``utils.get_time`` helper.
    """
    from utils import get_time
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)
    B = 2 * 32
    T = 8192
    C = 512
    D = 128
    k = torch.randn(B, T, D)
    u = torch.randn(B, T, D)
    q = torch.randn(B, C, D)
    v = torch.randn(B, C, D)
    t = get_time(lambda _: forward(q, k, v, u, 1))
    print(f'fa time {t * 1e3:.3f} ms')


if __name__ == "__main__":
    # Benchmarks by default; call test_delta_fa() instead for a correctness check.
    tune_delta_fa()
