import math
import torch
import triton

from .full_backward import reference_attn_flash_bwd
from ..common import calc_dims


def sparse_attn_bwd_ref_impl(
        query_states, key_states, sparse_indices, value_states, output, grad_output, L,
        block_size_q: int, block_size_k: int, B_r: int, B_c: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int):
    """
    Reference sparse FlashAttention backward pass (plain PyTorch, tiled like the kernel).

    For every query block the attended keys are split into three parts, each handled by
    its own loop below:

    * start sinks: key blocks ``[0, start_sink_tokens // block_size_k)``, causally masked,
      with keys inside the query block's end-sink window masked out (those are counted by
      the end-sink part instead);
    * end sinks: the ``end_sink_tokens`` keys ending at the end of the query block rounded
      up to a multiple of ``block_size_k``, causally masked;
    * top-k: the ``top_k`` key blocks ``start_sink_tokens // block_size_k + sparse_indices``.
      No mask is applied to these. NOTE(review): this assumes the selected blocks are fully
      causally visible and disjoint from both sink regions -- confirm against the forward pass.

    The sink token counts are presumably expected to be multiples of ``block_size_k``
    (they are floor-divided by it below) -- TODO confirm.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k) key-block indices, relative to
        the first key block after the start sink tokens
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param output: (bsz, num_heads, q_len, value_dim) forward output O
    :param grad_output: (bsz, num_heads, q_len, value_dim) upstream gradient dO
    :param L: (bsz, num_heads, q_len) row-wise log-sum-exp from the forward pass, such that
        P = exp(Q K^T / sqrt(head_dim) - L)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param B_r: number of rows in the flash block (multiple of block_size_q)
    :param B_c: number of columns in the flash block (multiple of block_size_k)
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :return: (dQ, dK, dV) with the shapes of query_states, key_states and value_states
    """
    device = query_states.device
    dtype = query_states.dtype

    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    assert k_len >= q_len
    assert B_r % block_size_q == 0
    assert B_c % block_size_k == 0

    (q_block_offset, q_blocks, q_start_padding, q_end_padding, k_blocks) = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    T_r = triton.cdiv(q_blocks * block_size_q, B_r)  # number of query tiles
    T_c = triton.cdiv(top_k * block_size_k, B_c)     # number of top-k key tiles

    # Pad the sequence axes to whole blocks and split them into (blocks, block_size).
    query_states = torch.nn.functional.pad(query_states, (0, 0, q_start_padding, q_end_padding))
    query_states = query_states.reshape(bsz, num_heads, q_blocks, block_size_q, head_dim)
    key_states = torch.nn.functional.pad(key_states, (0, 0, 0, k_blocks * block_size_k - k_len))
    key_states = key_states.reshape(bsz, num_heads, k_blocks, block_size_k, head_dim)
    value_states = torch.nn.functional.pad(value_states, (0, 0, 0, k_blocks * block_size_k - k_len))
    value_states = value_states.reshape(bsz, num_heads, k_blocks, block_size_k, value_dim)

    # Materialize the top_k selected key/value blocks of every query block.
    sparse_indices = sparse_indices.to(torch.int64)
    gathered_keys = torch.gather(
        key_states.unsqueeze(2).expand(-1, -1, q_blocks, -1, -1, -1), 3,
        start_sink_tokens // block_size_k
        + sparse_indices[:, :, :, :, None, None].expand(-1, -1, -1, -1, block_size_k, head_dim)
    )  # (bsz, num_heads, q_blocks, top_k, block_size_k, head_dim)
    gathered_values = torch.gather(
        value_states.unsqueeze(2).expand(-1, -1, q_blocks, -1, -1, -1), 3,
        start_sink_tokens // block_size_k
        + sparse_indices[:, :, :, :, None, None].expand(-1, -1, -1, -1, block_size_k, value_dim)
    )  # (bsz, num_heads, q_blocks, top_k, block_size_k, value_dim)

    # Padded query rows get Q = O = dO = 0 and L = 0, hence D = 0 and dS = 0 there:
    # they contribute no gradient.
    output = torch.nn.functional.pad(output, (0, 0, q_start_padding, q_end_padding))
    output = output.reshape(bsz, num_heads, q_blocks, block_size_q, value_dim)
    grad_output = torch.nn.functional.pad(grad_output, (0, 0, q_start_padding, q_end_padding))
    grad_output = grad_output.reshape(bsz, num_heads, q_blocks, block_size_q, value_dim)
    L = torch.nn.functional.pad(L, (q_start_padding, q_end_padding))
    L = L.reshape(bsz, num_heads, q_blocks, block_size_q)
    # D = rowsum(dO * O), as in the FlashAttention backward pass.
    D = (grad_output * output).sum(dim=4)  # (bsz, num_heads, q_blocks, block_size_q)

    dQ = torch.zeros(bsz, num_heads, q_blocks, block_size_q, head_dim, device=device, dtype=dtype)
    dK = torch.zeros(bsz, num_heads, k_blocks, block_size_k, head_dim, device=device, dtype=dtype)
    dV = torch.zeros(bsz, num_heads, k_blocks, block_size_k, value_dim, device=device, dtype=dtype)

    # Handle start sink tokens
    for j in range(triton.cdiv(start_sink_tokens, B_c)):
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(start_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
        c_block_len = c_block_end - c_block_begin

        # The key tile and its element positions do not depend on the query tile.
        K_j = key_states[:, :, c_block_begin:c_block_end]
        V_j = value_states[:, :, c_block_begin:c_block_end]
        k_elem_indices = (
            c_block_begin * block_size_k + torch.arange(c_block_len * block_size_k, device=device)
        ).reshape(1, 1, 1, 1, c_block_len, block_size_k)

        for i in range(T_r):
            r_block_begin = i * B_r // block_size_q
            r_block_end = min(q_blocks, r_block_begin + B_r // block_size_q)
            r_block_len = r_block_end - r_block_begin

            Q_i = query_states[:, :, r_block_begin:r_block_end] / math.sqrt(head_dim)
            dO_i = grad_output[:, :, r_block_begin:r_block_end]  # (bsz, num_heads, r_block_len, block_size_q, value_dim)
            L_i = L[:, :, r_block_begin:r_block_end]             # (bsz, num_heads, r_block_len, block_size_q)
            D_i = D[:, :, r_block_begin:r_block_end]             # (bsz, num_heads, r_block_len, block_size_q)

            S_ij = torch.einsum('bhrqd, bhckd -> bhrqck', Q_i, K_j)

            # Apply causal mask
            q_elem_indices = (
                (q_block_offset + r_block_begin) * block_size_q
                + torch.arange(r_block_len * block_size_q, device=device)
            ).reshape(1, 1, r_block_len, block_size_q, 1, 1)
            S_ij = torch.where(k_elem_indices > q_elem_indices, float('-inf'), S_ij)

            # Keys from the start of each query block's end-sink window onward are handled
            # by the end-sink loop; mask them here so they are not counted twice.
            sink_end_start_indices = (
                (q_block_offset + r_block_begin + 1 + torch.arange(r_block_len, device=device)) * block_size_q
                + (block_size_k - 1)
            ) // block_size_k * block_size_k - end_sink_tokens
            S_ij = torch.where(k_elem_indices >= sink_end_start_indices[None, None, :, None, None, None], float('-inf'), S_ij)

            P_ij = torch.exp(S_ij - L_i[:, :, :, :, None, None])
            dV_j = torch.einsum('bhrqck, bhrqv -> bhckv', P_ij, dO_i)
            dP_ij = torch.einsum('bhrqv, bhckv -> bhrqck', dO_i, V_j)
            dS_ij = P_ij * (dP_ij - D_i[:, :, :, :, None, None])
            # Accumulate in place into the dQ rows of this query tile.
            dQ[:, :, r_block_begin:r_block_end] += torch.einsum('bhrqck, bhckd -> bhrqd', dS_ij, K_j)
            dK_j = torch.einsum('bhrqck, bhrqd -> bhckd', dS_ij, Q_i)

            dK[:, :, c_block_begin:c_block_end] += dK_j
            dV[:, :, c_block_begin:c_block_end] += dV_j

    # Handle end sink tokens (the window differs per query block, hence per-row gathers)
    for j in range(triton.cdiv(end_sink_tokens, B_c)):
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(end_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
        c_block_len = c_block_end - c_block_begin

        for i in range(T_r):
            r_block_begin = i * B_r // block_size_q
            r_block_end = min(q_blocks, r_block_begin + B_r // block_size_q)
            r_block_len = r_block_end - r_block_begin

            # First key position of each query block's end-sink window (may be negative).
            sink_end_start_indices = (
                (q_block_offset + r_block_begin + 1 + torch.arange(r_block_len, device=device)) * block_size_q
                + (block_size_k - 1)
            ) // block_size_k * block_size_k - end_sink_tokens

            keys_sink_indices = (
                sink_end_start_indices[:, None] // block_size_k + c_block_begin
                + torch.arange(c_block_len, device=device)
            )  # (r_block_len, c_block_len)
            # Block indices outside [0, k_blocks) refer to no keys. Mark them invalid and point
            # them at block 0 so the gather/scatter below stays in bounds; their scores are set
            # to -inf, so they contribute zero gradient.
            keys_sink_indices_valid = (
                (keys_sink_indices >= 0) & (keys_sink_indices < k_blocks)
            )  # (r_block_len, c_block_len)
            keys_sink_indices = torch.where(keys_sink_indices_valid, keys_sink_indices, 0)

            K_j = key_states[:, :, keys_sink_indices]  # (bsz, num_heads, r_block_len, c_block_len, block_size_k, head_dim)
            K_j = torch.where(keys_sink_indices_valid[None, None, :, :, None, None], K_j, 0)
            V_j = value_states[:, :, keys_sink_indices]  # (bsz, num_heads, r_block_len, c_block_len, block_size_k, value_dim)
            V_j = torch.where(keys_sink_indices_valid[None, None, :, :, None, None], V_j, 0)

            Q_i = query_states[:, :, r_block_begin:r_block_end] / math.sqrt(head_dim)
            dO_i = grad_output[:, :, r_block_begin:r_block_end]  # (bsz, num_heads, r_block_len, block_size_q, value_dim)
            L_i = L[:, :, r_block_begin:r_block_end]             # (bsz, num_heads, r_block_len, block_size_q)
            D_i = D[:, :, r_block_begin:r_block_end]             # (bsz, num_heads, r_block_len, block_size_q)

            S_ij = torch.einsum('bhrqd, bhrckd -> bhrqck', Q_i, K_j)

            # Apply causal mask, then mask the invalid (clamped) blocks.
            k_elem_indices = (
                (keys_sink_indices * block_size_k)[None, None, :, None, :, None]
                + torch.arange(block_size_k, device=device)
            )
            q_elem_indices = (
                (q_block_offset + r_block_begin) * block_size_q
                + torch.arange(r_block_len * block_size_q, device=device)
            ).reshape(1, 1, r_block_len, block_size_q, 1, 1)
            S_ij = torch.where(k_elem_indices > q_elem_indices, float('-inf'), S_ij)
            S_ij[(~keys_sink_indices_valid)[None, None, :, None, :, None].expand_as(S_ij)] = float('-inf')

            P_ij = torch.exp(S_ij - L_i[:, :, :, :, None, None])
            dV_j = torch.einsum('bhrqck, bhrqv -> bhrckv', P_ij, dO_i)
            dP_ij = torch.einsum('bhrqv, bhrckv -> bhrqck', dO_i, V_j)
            dS_ij = P_ij * (dP_ij - D_i[:, :, :, :, None, None])
            # Accumulate in place into the dQ rows of this query tile.
            dQ[:, :, r_block_begin:r_block_end] += torch.einsum('bhrqck, bhrckd -> bhrqd', dS_ij, K_j)
            dK_j = torch.einsum('bhrqck, bhrqd -> bhrckd', dS_ij, Q_i)

            dK_j = dK_j.reshape(bsz, num_heads, r_block_len * c_block_len, block_size_k, head_dim)
            dV_j = dV_j.reshape(bsz, num_heads, r_block_len * c_block_len, block_size_k, value_dim)

            # Different query blocks can map to the same key block, so scatter with a sum.
            indices = keys_sink_indices.reshape(1, 1, r_block_len * c_block_len, 1, 1)
            dK.scatter_reduce_(2, indices.expand(bsz, num_heads, -1, block_size_k, head_dim), dK_j, "sum")
            dV.scatter_reduce_(2, indices.expand(bsz, num_heads, -1, block_size_k, value_dim), dV_j, "sum")

    # Handle top-k (no mask applied; see the NOTE(review) in the docstring)
    for j in range(T_c):
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(top_k, c_block_begin + B_c // block_size_k)
        c_block_len = c_block_end - c_block_begin

        for i in range(T_r):
            r_block_begin = i * B_r // block_size_q
            r_block_end = min(q_blocks, r_block_begin + B_r // block_size_q)
            r_block_len = r_block_end - r_block_begin

            K_j = gathered_keys[:, :, r_block_begin:r_block_end, c_block_begin:c_block_end]    # (bsz, num_heads, r_block_len, c_block_len, block_size_k, head_dim)
            V_j = gathered_values[:, :, r_block_begin:r_block_end, c_block_begin:c_block_end]  # (bsz, num_heads, r_block_len, c_block_len, block_size_k, value_dim)

            Q_i = query_states[:, :, r_block_begin:r_block_end] / math.sqrt(head_dim)
            dO_i = grad_output[:, :, r_block_begin:r_block_end]  # (bsz, num_heads, r_block_len, block_size_q, value_dim)
            L_i = L[:, :, r_block_begin:r_block_end]             # (bsz, num_heads, r_block_len, block_size_q)
            D_i = D[:, :, r_block_begin:r_block_end]             # (bsz, num_heads, r_block_len, block_size_q)

            S_ij = torch.einsum('bhrqd, bhrckd -> bhrqck', Q_i, K_j)

            P_ij = torch.exp(S_ij - L_i[:, :, :, :, None, None])
            dV_j = torch.einsum('bhrqck, bhrqv -> bhrckv', P_ij, dO_i)
            dP_ij = torch.einsum('bhrqv, bhrckv -> bhrqck', dO_i, V_j)
            dS_ij = P_ij * (dP_ij - D_i[:, :, :, :, None, None])
            # Accumulate in place into the dQ rows of this query tile.
            dQ[:, :, r_block_begin:r_block_end] += torch.einsum('bhrqck, bhrckd -> bhrqd', dS_ij, K_j)
            dK_j = torch.einsum('bhrqck, bhrqd -> bhrckd', dS_ij, Q_i)

            dK_j = dK_j.reshape(bsz, num_heads, r_block_len * c_block_len, block_size_k, head_dim)
            dV_j = dV_j.reshape(bsz, num_heads, r_block_len * c_block_len, block_size_k, value_dim)

            # Different query blocks can select the same key block, so scatter with a sum.
            indices = (
                start_sink_tokens // block_size_k
                + sparse_indices[:, :, r_block_begin:r_block_end, c_block_begin:c_block_end]
            ).reshape(bsz, num_heads, r_block_len * c_block_len, 1, 1)
            dK.scatter_reduce_(2, indices.expand(-1, -1, -1, block_size_k, head_dim), dK_j, "sum")
            dV.scatter_reduce_(2, indices.expand(-1, -1, -1, block_size_k, value_dim), dV_j, "sum")

    # Undo the padding; dQ still needs the 1/sqrt(head_dim) factor of dS/dQ.
    dQ = dQ.reshape(bsz, num_heads, q_blocks * block_size_q, head_dim)
    dQ = dQ[:, :, q_start_padding:q_start_padding + q_len, :] / math.sqrt(head_dim)
    dK = dK.reshape(bsz, num_heads, k_blocks * block_size_k, head_dim)[:, :, :k_len]
    dV = dV.reshape(bsz, num_heads, k_blocks * block_size_k, value_dim)[:, :, :k_len]

    return dQ, dK, dV


def sparse_attn_bwd_ref(
        query_states, key_states, sparse_indices, value_states, output, grad_output, L,
        block_size_q: int, block_size_k: int, B_r: int, B_c: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int):
    """
    Reference sparse FlashAttention backward pass with a dense prefix.

    The query sequence is split at an index ``split``. The query blocks before it are
    differentiated with the dense reference ``reference_attn_flash_bwd``. These are the
    blocks numbered below ``cdiv(start_sink_tokens + top_k * block_size_k + end_sink_tokens,
    block_size_q) - 1`` in the global block numbering, presumably because they cannot yet
    attend to that many keys. The remaining queries go through ``sparse_attn_bwd_ref_impl``.
    The dQ pieces are concatenated along the query axis, and the dK/dV contributions of both
    parts are summed.

    NOTE(review): ``sparse_indices`` is handed unchanged to the sparse part, so it presumably
    covers only the sparse query blocks -- confirm with the forward pass.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param output: (bsz, num_heads, q_len, value_dim)
    :param grad_output: (bsz, num_heads, q_len, value_dim)
    :param L: (bsz, num_heads, q_len)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param B_r: number of rows in the flash block
    :param B_c: number of columns in the flash block
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :return: (dQ, dK, dV)
    """
    _, _, q_len, _ = query_states.size()
    _, _, k_len, _ = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    assert end_sink_tokens >= block_size_q
    block_offset, num_q_blocks, start_pad, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Number of keys a query block attends to once it is in the sparse regime.
    attended_budget = start_sink_tokens + top_k * block_size_k + end_sink_tokens
    first_sparse_block = min(
        num_q_blocks,
        triton.cdiv(attended_budget, block_size_q) - 1 - block_offset,
    )
    # Convert the padded block boundary into an index into the unpadded query axis.
    split = max(0, first_sparse_block * block_size_q - start_pad)

    dense_dq, dense_dk, dense_dv = reference_attn_flash_bwd(
        query_states[:, :, :split], key_states, value_states,
        output[:, :, :split], grad_output[:, :, :split], L[:, :, :split],
        B_r=B_r, B_c=B_c, query_offset=query_offset,
    )
    sparse_dq, sparse_dk, sparse_dv = sparse_attn_bwd_ref_impl(
        query_states[:, :, split:], key_states, sparse_indices, value_states,
        output[:, :, split:], grad_output[:, :, split:], L[:, :, split:],
        block_size_q, block_size_k, B_r, B_c, split + query_offset,
        start_sink_tokens, end_sink_tokens,
    )

    return (
        torch.cat([dense_dq, sparse_dq], dim=2),
        sparse_dk + dense_dk,
        sparse_dv + dense_dv,
    )
