import math
import torch
import triton


def reference_attn_flash(
        query_states, key_states, value_states, B_r: int, B_c: int, query_offset: int):
    """
    Reference FlashAttention forward pass (blocked online softmax) with a causal mask.

    Query row ``r`` sits at absolute position ``query_offset + r`` and may attend to key
    column ``c`` only when ``c <= query_offset + r``. Scores are scaled by
    ``1 / sqrt(head_dim)``.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param B_r: number of query rows per flash block (must be positive)
    :param B_c: number of key/value columns per flash block (must be positive)
    :param query_offset: absolute position of the first query row, used by the causal mask
    :return: tuple ``(O, L)``:
        ``O`` -- attention output, (bsz, num_heads, q_len, value_dim);
        ``L`` -- per-row log-sum-exp of the scaled, masked scores, (bsz, num_heads, q_len).
        Both are returned in the dtype of ``query_states``. A row with no visible key
        (only possible when ``query_offset < 0``) gets ``O = 0`` and ``L = -inf``.
    :raises ValueError: if ``B_r`` or ``B_c`` is not positive.
    """
    if B_r <= 0 or B_c <= 0:
        raise ValueError(f"block sizes must be positive, got B_r={B_r}, B_c={B_c}")

    device = query_states.device
    dtype = query_states.dtype
    # Run the online-softmax recurrences in at least float32: fp16/bf16 running sums and
    # log-sum-exp values lose precision quickly. Outputs are cast back to `dtype`.
    acc_dtype = torch.promote_types(dtype, torch.float32)
    neg_inf = float('-inf')

    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    scale = 1.0 / math.sqrt(head_dim)

    O = torch.zeros(bsz, num_heads, q_len, value_dim, dtype=dtype, device=device)
    L = torch.zeros(bsz, num_heads, q_len, dtype=dtype, device=device)

    for r_block_begin in range(0, q_len, B_r):
        r_block_end = min(q_len, r_block_begin + B_r)
        r_block_len = r_block_end - r_block_begin

        # (bsz, num_heads, r_block_len, head_dim), pre-scaled by 1/sqrt(head_dim)
        Q_i = query_states[:, :, r_block_begin:r_block_end].to(acc_dtype) * scale
        # Absolute query positions of this row block, (r_block_len, 1); loop-invariant
        # over the column blocks below.
        q_elem_indices = query_offset + torch.arange(r_block_begin, r_block_end, device=device)[:, None]

        # Online-softmax state: unnormalized output, running row sum, running row max.
        O_ij = torch.zeros(bsz, num_heads, r_block_len, value_dim, dtype=acc_dtype, device=device)
        l_ij = torch.zeros(bsz, num_heads, r_block_len, dtype=acc_dtype, device=device)
        m_ij = torch.full((bsz, num_heads, r_block_len), neg_inf, dtype=acc_dtype, device=device)

        for c_block_begin in range(0, k_len, B_c):
            c_block_end = min(k_len, c_block_begin + B_c)

            K_j = key_states[:, :, c_block_begin:c_block_end].to(acc_dtype)
            # (bsz, num_heads, c_block_len, head_dim)
            V_j = value_states[:, :, c_block_begin:c_block_end].to(acc_dtype)
            # (bsz, num_heads, c_block_len, value_dim)

            S_ij = torch.einsum('nhrd, nhcd -> nhrc', Q_i, K_j)

            # Causal mask: a key may not lie after the query's absolute position.
            k_elem_indices = torch.arange(c_block_begin, c_block_end, device=device)[None, :]
            S_ij = S_ij.masked_fill(k_elem_indices > q_elem_indices, neg_inf)

            m_prev = m_ij
            m_ij = torch.maximum(m_prev, S_ij.amax(dim=-1))  # r
            # A row that has not seen any visible key yet still has m == -inf; shifting
            # by 0 instead keeps exp(-inf - (-inf)) = NaN out of P and the rescale factor.
            m_shift = torch.where(m_ij == neg_inf, torch.zeros_like(m_ij), m_ij)
            P_tilde_ij = torch.exp(S_ij - m_shift[..., None])  # r x c
            m_diff_exp = torch.exp(m_prev - m_shift)  # r
            l_ij = l_ij * m_diff_exp + P_tilde_ij.sum(dim=-1)  # r
            O_ij = O_ij * m_diff_exp[..., None] + torch.einsum('nhrc, nhcd -> nhrd', P_tilde_ij, V_j)

        # Rows with no visible key have l == 0 and O_ij == 0; divide by 1 there so they
        # yield O = 0 rather than 0 / 0 = NaN. Their L is -inf + log(0) = -inf.
        safe_l = torch.where(l_ij > 0, l_ij, torch.ones_like(l_ij))
        O_i = O_ij / safe_l[..., None]
        L_i = m_ij + torch.log(l_ij)

        O[:, :, r_block_begin:r_block_end] = O_i.to(dtype)
        L[:, :, r_block_begin:r_block_end] = L_i.to(dtype)

    return O, L
