import math
import torch
import triton


def reference_attn_flash_bwd(
        query_states, key_states, value_states, output, grad_output, L,
        B_r: int, B_c: int, query_offset: int, begin_offset: int = 0):
    """
    Reference FlashAttention backward pass for causal attention, computed tile by tile.

    The loops mirror the tiled backward schedule: an outer loop over key/value column
    blocks of ``B_c`` positions and an inner loop over query row blocks of ``B_r``
    positions. The block sizes only change the tiling; up to floating-point rounding
    the gradients do not depend on them.

    All arithmetic runs in ``torch.promote_types(query_states.dtype, torch.float32)``.
    As a result, half/bfloat16 inputs are accumulated in float32, and inputs of mixed
    dtypes (e.g. a float32 ``L`` with half Q/K/V) are accepted. The returned gradients
    are cast back to ``query_states.dtype``.

    Key position ``k`` is masked out for query row ``q`` when
    ``begin_offset + k > begin_offset + query_offset + q``.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param output: (bsz, num_heads, q_len, value_dim) forward attention output
    :param grad_output: (bsz, num_heads, q_len, value_dim) gradient w.r.t. ``output``
    :param L: (bsz, num_heads, q_len) per-row normalizer used as P = exp(S - L), where
              S = Q K^T / sqrt(head_dim); presumably the log-sum-exp saved by the forward pass
    :param B_r: number of rows (query positions) in the flash block
    :param B_c: number of columns (key positions) in the flash block
    :param query_offset: offset of the query positions relative to the key positions
    :param begin_offset: offset added to both the query and key positions
    :return: (dQ, dK, dV) with the shapes of query_states, key_states and value_states
    :raises ValueError: if ``B_r`` or ``B_c`` is not positive
    """
    if B_r <= 0 or B_c <= 0:
        raise ValueError(f"block sizes must be positive, got B_r={B_r}, B_c={B_c}")

    device = query_states.device
    dtype = query_states.dtype
    # Accumulate in at least float32, and give every einsum operand one dtype
    # (einsum expects matching operand dtypes).
    compute_dtype = torch.promote_types(dtype, torch.float32)

    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    scale = math.sqrt(head_dim)

    # Fold the softmax scale into Q once instead of re-dividing for every (i, j) tile.
    Q = query_states.to(compute_dtype) / scale
    K = key_states.to(compute_dtype)
    V = value_states.to(compute_dtype)
    dO = grad_output.to(compute_dtype)
    lse = L.to(compute_dtype)

    # D = rowsum(dO * O): the per-row term of the softmax backward, dS = P * (dP - D).
    D = (dO * output.to(compute_dtype)).sum(dim=3)  # (bsz, num_heads, q_len)

    dQ = torch.zeros(bsz, num_heads, q_len, head_dim, device=device, dtype=compute_dtype)
    dK = torch.zeros(bsz, num_heads, k_len, head_dim, device=device, dtype=compute_dtype)
    dV = torch.zeros(bsz, num_heads, k_len, value_dim, device=device, dtype=compute_dtype)

    # Absolute positions used by the causal mask.
    q_positions = begin_offset + query_offset + torch.arange(q_len, device=device)
    k_positions = begin_offset + torch.arange(k_len, device=device)

    for c_begin in range(0, k_len, B_c):
        c_end = min(k_len, c_begin + B_c)

        K_j = K[:, :, c_begin:c_end]  # (bsz, num_heads, c_block_len, head_dim)
        V_j = V[:, :, c_begin:c_end]  # (bsz, num_heads, c_block_len, value_dim)
        # Views into dK / dV: the in-place updates below accumulate straight into the outputs.
        dK_j = dK[:, :, c_begin:c_end]
        dV_j = dV[:, :, c_begin:c_end]

        for r_begin in range(0, q_len, B_r):
            r_end = min(q_len, r_begin + B_r)

            Q_i = Q[:, :, r_begin:r_end]     # (bsz, num_heads, r_block_len, head_dim), pre-scaled
            dO_i = dO[:, :, r_begin:r_end]   # (bsz, num_heads, r_block_len, value_dim)
            L_i = lse[:, :, r_begin:r_end]   # (bsz, num_heads, r_block_len)
            D_i = D[:, :, r_begin:r_end]     # (bsz, num_heads, r_block_len)

            S_ij = torch.einsum('bhrd, bhcd -> bhrc', Q_i, K_j)

            # Causal mask: a key strictly after the query position contributes nothing.
            causal_mask = k_positions[None, c_begin:c_end] > q_positions[r_begin:r_end, None]
            S_ij.masked_fill_(causal_mask, float('-inf'))

            P_ij = torch.exp(S_ij - L_i[..., None])
            dV_j += torch.einsum('bhrc, bhrv -> bhcv', P_ij, dO_i)
            dP_ij = torch.einsum('bhrv, bhcv -> bhrc', dO_i, V_j)
            dS_ij = P_ij * (dP_ij - D_i[..., None])
            # The 1/sqrt(head_dim) factor for dQ is applied once after the loops.
            dQ[:, :, r_begin:r_end] += torch.einsum('bhrc, bhcd -> bhrd', dS_ij, K_j)
            # Q_i is already scaled, so this is dS^T Q / sqrt(head_dim).
            dK_j += torch.einsum('bhrc, bhrd -> bhcd', dS_ij, Q_i)

    dQ /= scale
    return dQ.to(dtype), dK.to(dtype), dV.to(dtype)
