import math
import torch
import triton
import einx
import torch.utils.checkpoint

from .full_forward import reference_attn_flash
from ..common import calc_dims


def sparse_attn_ref_impl(query_states, key_states, sparse_indices, value_states,
                         block_size_q: int, block_size_k: int, query_offset: int,
                         start_sink_tokens: int, end_sink_tokens: int):
    """
    Reference (non-tiled) sparse attention.

    Queries are grouped into blocks of ``block_size_q`` and keys/values into blocks
    of ``block_size_k``. Each query block attends to three groups of key blocks whose
    scores are concatenated and normalised by one softmax:

    * start sinks: the first ``start_sink_tokens // block_size_k`` key blocks,
    * top-k: the key blocks named by ``sparse_indices``; the indices are shifted by
      the number of start-sink blocks before gathering,
    * end sinks: ``end_sink_tokens // block_size_k`` consecutive key blocks starting
      at a per-query-block position derived from the end of that query block.

    The causal mask (key position <= query position) is applied to the two sink
    groups only; the gathered top-k scores are not masked in this function.
    Queries are scaled by ``1 / sqrt(head_dim)`` before the dot products.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k) or
        tuple(tensor(bsz, num_heads, q_blocks, top_k, n_candidates), tensor(bsz, num_heads, q_blocks, n_candidates))
        for soft masking (used for training). In the tuple form the first tensor
        mixes the candidate scores into top_k slots and the second holds the
        candidate block indices; negative candidate indices are treated as invalid.
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :return: output (bsz, num_heads, q_len, value_dim)
    """
    device = query_states.device
    dtype = query_states.dtype
    # Finite stand-in for -inf: the most negative value representable in `dtype`.
    NEG_INF = torch.finfo(dtype).min

    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    is_soft_mask = isinstance(sparse_indices, tuple)
    if is_soft_mask:
        # From here on `sparse_indices` holds the candidate indices only.
        permute, sparse_indices = sparse_indices
        _, _, _, top_k, n_candidates = permute.size()
    else:
        _, _, _, top_k = sparse_indices.size()

    # NOTE(review): calc_dims is defined elsewhere; presumably it aligns the query
    # range to block_size_q boundaries given query_offset — confirm against ..common.
    (q_block_offset, q_blocks, q_start_padding, q_end_padding, k_blocks) = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Pad along the sequence axis so queries/keys/values split into whole blocks,
    # then expose the block axis. Queries are pre-scaled by 1/sqrt(head_dim).
    query_states = torch.nn.functional.pad(query_states, (0, 0, q_start_padding, q_end_padding))
    query_states = query_states.reshape(bsz, num_heads, q_blocks, block_size_q, head_dim) / math.sqrt(head_dim)
    key_states = torch.nn.functional.pad(key_states, (0, 0, 0, k_blocks * block_size_k - k_len))
    key_states = key_states.reshape(bsz, num_heads, k_blocks, block_size_k, head_dim)
    value_states = torch.nn.functional.pad(value_states, (0, 0, 0, k_blocks * block_size_k - k_len))
    value_states = value_states.reshape(bsz, num_heads, k_blocks, block_size_k, value_dim)

    # Position of every (padded) query element, counted from q_block_offset * block_size_q;
    # shaped to broadcast against (bsz, num_heads, q_blocks, block_size_q, key_blocks, block_size_k).
    q_elem_indices = (
        q_block_offset * block_size_q + torch.arange(q_blocks * block_size_q, device=device)
    ).reshape(1, 1, q_blocks, block_size_q, 1, 1)
    # Per query block: the block's exclusive end position rounded up to a multiple of
    # block_size_k, minus end_sink_tokens. This is the first key position of the
    # end-sink window and may be negative for early query blocks.  Shape: (q_blocks,)
    sink_end_start_indices = (
        (q_block_offset + 1 + torch.arange(q_blocks, device=device)) * block_size_q + (block_size_k - 1)
    ) // block_size_k * block_size_k - end_sink_tokens

    # Handle sink tokens - start
    keys_sink_start = key_states[:, :, :start_sink_tokens // block_size_k]
    attn_weights_sink_start = torch.einsum('nhiqd, nhjkd -> nhiqjk', query_states, keys_sink_start)
    # Key positions of the start sinks; the reshape requires start_sink_tokens to be
    # a multiple of block_size_k.
    start_sink_k = torch.arange(start_sink_tokens, device=device).reshape(1, 1, 1, 1, start_sink_tokens // block_size_k, block_size_k)
    # Causal mask: a query may only see keys at or before its own position.
    attn_weights_sink_start = torch.where(start_sink_k <= q_elem_indices, attn_weights_sink_start, NEG_INF)
    # Also mask start-sink keys at or beyond the start of the end-sink window
    # (presumably so that a key covered by both windows is not counted twice — confirm).
    attn_weights_sink_start = torch.where(start_sink_k < sink_end_start_indices[None, None, :, None, None, None], attn_weights_sink_start, NEG_INF)

    # Handle sink tokens - end
    keys_sink_indices = sink_end_start_indices[:, None] // block_size_k + torch.arange(end_sink_tokens // block_size_k, device=device)  # (q_blocks, end_sink_tokens // block_size_k)
    keys_sink_indices_valid = (keys_sink_indices >= 0)  # (q_blocks, end_sink_tokens // block_size_k)
    # Clamp negative block indices to 0 so the gather below is well-formed; the
    # corresponding keys/values are zeroed and their scores set to NEG_INF.
    keys_sink_indices = torch.where(keys_sink_indices_valid, keys_sink_indices, 0)

    keys_sink_end = key_states[:, :, keys_sink_indices]  # (bsz, num_heads, q_blocks, end_sink_tokens // block_size_k, block_size_k, head_dim)
    keys_sink_end = torch.where(keys_sink_indices_valid[None, None, :, :, None, None], keys_sink_end, 0)
    attn_weights_sink_end = torch.einsum('nhiqd, nhijkd -> nhiqjk', query_states, keys_sink_end)
    # Key position of every element in the gathered end-sink blocks.
    end_sink_k = (
        (keys_sink_indices * block_size_k)[None, None, :, None, :, None]
        + torch.arange(block_size_k, device=device)
    )  # (1, 1, q_blocks, 1, end_sink_tokens // block_size_k, block_size_k)
    # Causal mask, then mask the blocks whose index was invalid (negative).
    attn_weights_sink_end = torch.where(end_sink_k <= q_elem_indices, attn_weights_sink_end, NEG_INF)
    attn_weights_sink_end = torch.where(keys_sink_indices_valid[None, None, :, None, :, None], attn_weights_sink_end, NEG_INF)

    # Handle top-k keys
    if is_soft_mask:
        # Negative candidate indices mark invalid entries: gather block 0 in their
        # place and zero the gathered keys afterwards.
        invalid_indices = sparse_indices < 0  # (bsz, num_heads, q_blocks, n_candidates)
        sparse_indices = torch.masked_fill(sparse_indices, invalid_indices, 0)
        sparse_indices = sparse_indices.to(torch.int64)
        # Candidate indices are relative to the first block after the start sinks.
        gathered_keys = einx.get_at(
            'n h [i] k d, n h q c -> n h q c k d',
            key_states, start_sink_tokens // block_size_k + sparse_indices)
        gathered_keys = torch.where(~invalid_indices[:, :, :, :, None, None], gathered_keys, 0)

        # Scores of all candidates, mixed into top_k slots by `permute`.
        attn_weights = einx.dot(
            '... q d, ... t c, ... c k d -> ... q t k',
            query_states, permute, gathered_keys,
        )  # (bsz, num_heads, q_blocks, block_size_q, top_k, block_size_k)

        # For the value gather below, each top_k slot uses the single candidate
        # with the largest `permute` weight.
        top_k_indices = permute.argmax(dim=-1)  # (bsz, num_heads, q_blocks, top_k)
        sparse_indices = sparse_indices.gather(3, top_k_indices)

    else:
        sparse_indices = sparse_indices.to(torch.int64)
        # Indices are relative to the first block after the start sinks.
        gathered_keys = einx.get_at(
            'n h [i] k d, n h q t -> n h q t k d',
            key_states, start_sink_tokens // block_size_k + sparse_indices)

        attn_weights = torch.einsum('nhiqd, nhitkd -> nhiqtk', query_states, gathered_keys)
        # (bsz, num_heads, q_blocks, block_size_q, top_k, block_size_k)

    # Key-block order along dim 4: start sinks, top-k, end sinks. The values are
    # concatenated in the same order further down.
    attn_weights = torch.cat([attn_weights_sink_start, attn_weights, attn_weights_sink_end], dim=4)

    # Flatten the (key block, position in block) axes into one key axis.
    attn_weights = attn_weights.reshape(
        bsz, num_heads, q_blocks, block_size_q,
        (start_sink_tokens // block_size_k + top_k + end_sink_tokens // block_size_k) * block_size_k)
    # Softmax is computed in float32 and the result cast back to the query dtype.
    attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32).to(query_states.dtype)

    # At this point `sparse_indices` is (bsz, num_heads, q_blocks, top_k) in both modes.
    gathered_values = einx.get_at(
        'n h [i] k d, n h q t -> n h q t k d',
        value_states, start_sink_tokens // block_size_k + sparse_indices)

    # Start-sink values are shared by all query blocks; expand them over the q_blocks axis.
    values_sink_start = value_states[:, :, None, :start_sink_tokens // block_size_k].expand(-1, -1, q_blocks, -1, -1, -1)
    values_sink_end = value_states[:, :, keys_sink_indices]
    values_sink_end = torch.where(keys_sink_indices_valid[None, None, :, :, None, None], values_sink_end, 0)
    gathered_values = torch.cat([values_sink_start, gathered_values, values_sink_end], dim=3)

    gathered_values = gathered_values.reshape(
        bsz, num_heads, q_blocks,
        (start_sink_tokens // block_size_k + top_k + end_sink_tokens // block_size_k) * block_size_k, value_dim)

    attn_output = torch.einsum('nhiqk, nhikd -> nhiqd', attn_weights, gathered_values)

    # Merge the block axes back into one sequence axis and drop the padded rows.
    attn_output = attn_output.reshape(bsz, num_heads, q_blocks * block_size_q, value_dim)
    attn_output = attn_output[:, :, q_start_padding:q_start_padding + q_len, :]

    return attn_output


def sparse_attn_ref(query_states, key_states, sparse_indices, value_states,
                    block_size_q: int, block_size_k: int, query_offset: int,
                    start_sink_tokens: int, end_sink_tokens: int):
    """
    Reference sparse attention with a dense prefix.

    The leading ``cutoff`` queries are handled by ``reference_attn_flash`` over the
    full keys/values; the remaining queries go through ``sparse_attn_ref_impl``.
    Both results are concatenated along the query axis.

    ``sparse_indices`` is forwarded unchanged to ``sparse_attn_ref_impl``, which
    only receives the queries from ``cutoff`` onward.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k) or
        tuple(tensor(bsz, num_heads, q_blocks, top_k, n_candidates), tensor(bsz, num_heads, q_blocks, n_candidates))
        for soft masking (used for training)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
        (asserted to be at least ``block_size_q``)
    :return: output (bsz, num_heads, q_len, value_dim)
    """
    # Only the sequence lengths and top_k are needed here; the unpacking also
    # checks the rank of each tensor.
    _, _, q_len, _ = query_states.size()
    _, _, k_len, _ = value_states.size()
    if isinstance(sparse_indices, tuple):
        _, _, _, top_k, _ = sparse_indices[0].size()
    else:
        _, _, _, top_k = sparse_indices.size()

    assert end_sink_tokens >= block_size_q
    dims = calc_dims(q_len, k_len, block_size_q, block_size_k, query_offset)
    q_block_offset, q_blocks, q_start_padding, _, _ = dims

    # Number of key tokens one query block attends to in the sparse pattern.
    attended_tokens = start_sink_tokens + top_k * block_size_k + end_sink_tokens
    # Index (within this call) of the first query block handled sparsely.
    first_sparse_block = triton.cdiv(attended_tokens, block_size_q) - 1 - q_block_offset
    first_sparse_block = min(q_blocks, first_sparse_block)

    # Translate the block index into a count of un-padded query rows.
    cutoff = max(0, first_sparse_block * block_size_q - q_start_padding)

    # Dense attention for the prefix; the second return value is not needed here.
    dense_queries = query_states[:, :, :cutoff]
    dense_output, _ = reference_attn_flash(
        dense_queries, key_states, value_states,
        32, 32, query_offset
    )

    # Sparse attention for the rest, with the query offset advanced past the prefix.
    sparse_queries = query_states[:, :, cutoff:]
    sparse_output = sparse_attn_ref_impl(
        sparse_queries, key_states, sparse_indices, value_states,
        block_size_q, block_size_k, cutoff + query_offset,
        start_sink_tokens, end_sink_tokens,
    )

    return torch.cat([dense_output, sparse_output], dim=2)


def sparse_attn_flash_ref_impl(
        query_states, key_states, sparse_indices, value_states,
        block_size_q: int, block_size_k: int, B_r: int, B_c: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int):
    """
    Reference sparse flashattention (tiled, online-softmax formulation).

    Query rows are processed in tiles of ``B_r`` rows. For each tile the running
    maximum ``m``, running normaliser ``l`` and unnormalised output ``O`` are updated
    over three groups of key tiles, in this order: start sinks, end sinks, top-k
    blocks. The causal mask (key position <= query position) is applied to the two
    sink groups only. Queries are scaled by ``1 / sqrt(head_dim)``.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k) or
        tuple(tensor(bsz, num_heads, q_blocks, top_k, n_candidates), tensor(bsz, num_heads, q_blocks, n_candidates))
        for soft masking
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param B_r: number of rows in the flash block (must be a multiple of block_size_q)
    :param B_c: number of columns in the flash block (must be a multiple of block_size_k)
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :return: tuple (O, L):
        O: output (bsz, num_heads, q_len, value_dim)
        L: ``m + log(l)`` per query row, (bsz, num_heads, q_len)
    """
    device = query_states.device
    dtype = query_states.dtype
    # Finite stand-in for -inf: the most negative value representable in `dtype`.
    NEG_INF = torch.finfo(dtype).min

    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    is_soft_mask = isinstance(sparse_indices, tuple)
    if is_soft_mask:
        # From here on `sparse_indices` holds the candidate indices only.
        permute, sparse_indices = sparse_indices
        _, _, _, top_k, n_candidates = permute.size()
    else:
        permute = None
        _, _, _, top_k = sparse_indices.size()

    assert B_r % block_size_q == 0
    assert B_c % block_size_k == 0

    # NOTE(review): calc_dims is defined elsewhere; presumably it aligns the query
    # range to block_size_q boundaries given query_offset — confirm against ..common.
    (q_block_offset, q_blocks, q_start_padding, q_end_padding, k_blocks) = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Number of row tiles over the padded queries and column tiles over the top-k keys.
    T_r = triton.cdiv(q_blocks * block_size_q, B_r)
    T_c = triton.cdiv(top_k * block_size_k, B_c)

    # Pad along the sequence axis so everything splits into whole blocks, then
    # expose the block axis.
    query_states = torch.nn.functional.pad(query_states, (0, 0, q_start_padding, q_end_padding))
    query_states = query_states.reshape(bsz, num_heads, q_blocks, block_size_q, head_dim)
    key_states = torch.nn.functional.pad(key_states, (0, 0, 0, k_blocks * block_size_k - k_len))
    key_states = key_states.reshape(bsz, num_heads, k_blocks, block_size_k, head_dim)
    value_states = torch.nn.functional.pad(value_states, (0, 0, 0, k_blocks * block_size_k - k_len))
    value_states = value_states.reshape(bsz, num_heads, k_blocks, block_size_k, value_dim)

    sparse_indices = sparse_indices.to(torch.int64)

    # Result buffers, filled one row tile at a time.
    O = torch.zeros(bsz, num_heads, q_blocks, block_size_q, value_dim, dtype=dtype, device=device)
    L = torch.zeros(bsz, num_heads, q_blocks, block_size_q, 1, dtype=dtype, device=device)

    for i in range(T_r):
        # Query-block range [r_block_begin, r_block_end) covered by this row tile.
        r_block_begin = i * B_r // block_size_q
        r_block_end = min(q_blocks, r_block_begin + B_r // block_size_q)
        r_block_len = r_block_end - r_block_begin
        Q_i = query_states[:, :, r_block_begin:r_block_end] / math.sqrt(head_dim)
        # (bsz, num_heads, r_block_len, block_size_q, head_dim)

        # Position of every query element in this tile, shaped to broadcast against
        # (bsz, num_heads, r_block_len, block_size_q, key_blocks, block_size_k).
        q_elem_indices = (
            (q_block_offset + r_block_begin) * block_size_q
            + torch.arange(r_block_len * block_size_q, device=device)
        ).reshape(1, 1, r_block_len, block_size_q, 1, 1)
        # Per query block: the block's exclusive end position rounded up to a multiple
        # of block_size_k, minus end_sink_tokens. This is the first key position of the
        # end-sink window and may be negative.  Shape: (r_block_len,)
        sink_end_start_indices = (
            (q_block_offset + r_block_begin + 1 + torch.arange(r_block_len, device=device)) * block_size_q
            + (block_size_k - 1)
        ) // block_size_k * block_size_k - end_sink_tokens

        # Online-softmax state. l starts at 1 (not 0): on the first update it is
        # multiplied by exp(NEG_INF - m_new), which underflows to 0 whenever m_new
        # is an ordinary (unmasked) score.
        O_ij = torch.zeros(bsz, num_heads, r_block_len, block_size_q, value_dim, dtype=dtype, device=device)
        l_ij = torch.ones(bsz, num_heads, r_block_len, block_size_q, 1, dtype=dtype, device=device)
        m_ij = torch.full((bsz, num_heads, r_block_len, block_size_q, 1), NEG_INF, dtype=dtype, device=device)

        # Handle sink tokens - start
        for j in range(triton.cdiv(start_sink_tokens, B_c)):
            # Key-block range [c_block_begin, c_block_end) of this column tile.
            c_block_begin = j * B_c // block_size_k
            c_block_end = min(start_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
            c_block_len = c_block_end - c_block_begin

            K_j = key_states[:, :, c_block_begin:c_block_end]
            V_j = value_states[:, :, c_block_begin:c_block_end]
            # K_j: (bsz, num_heads, c_block_len, block_size_k, head_dim)
            # V_j: (bsz, num_heads, c_block_len, block_size_k, value_dim)
            V_j = V_j.reshape(bsz, num_heads, c_block_len * block_size_k, value_dim)

            S_ij = torch.einsum('nhiqd, nhjkd -> nhiqjk', Q_i, K_j)

            # Apply causal mask
            k_elem_indices = (
                c_block_begin * block_size_k + torch.arange(c_block_len * block_size_k, device=device)
            ).reshape(1, 1, 1, 1, c_block_len, block_size_k)
            S_ij[(k_elem_indices > q_elem_indices).expand_as(S_ij)] = NEG_INF
            # Also mask start-sink keys at or beyond the start of the end-sink window
            # (presumably so that a key covered by both windows is not counted twice — confirm).
            S_ij[(k_elem_indices >= sink_end_start_indices[None, None, :, None, None, None]).expand_as(S_ij)] = NEG_INF

            S_ij = S_ij.reshape(bsz, num_heads, r_block_len, block_size_q, c_block_len * block_size_k)
            # (bsz, num_heads, r_block_len, block_size_q, c_block_len * block_size_k)

            # Online-softmax update: rescale the running sums by exp(m_old - m_new)
            # and add this tile's contribution.
            m_ijm1 = m_ij
            m_ij = torch.maximum(m_ij, S_ij.amax(4, keepdim=True))
            P_tilde_ij = torch.exp(S_ij - m_ij)  # (bsz, num_heads, r_block_len, block_size_q, c_block_len * block_size_k)
            m_diff_exp = torch.exp(m_ijm1 - m_ij)
            l_ij = m_diff_exp * l_ij + P_tilde_ij.sum(4, keepdim=True)
            O_ij = m_diff_exp * O_ij + torch.einsum('nhirc, nhcd -> nhird', P_tilde_ij, V_j)

        # Handle sink tokens - end
        for j in range(triton.cdiv(end_sink_tokens, B_c)):
            # Block range [c_block_begin, c_block_end) inside the end-sink window.
            c_block_begin = j * B_c // block_size_k
            c_block_end = min(end_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
            c_block_len = c_block_end - c_block_begin

            # Key-block indices differ per query block because the window start does.
            keys_sink_indices = (
                sink_end_start_indices[:, None] // block_size_k + c_block_begin
                + torch.arange(c_block_len, device=device)
            )
            keys_sink_indices_valid = (keys_sink_indices >= 0)  # (r_block_len, c_block_len)
            # Clamp negative block indices to 0 so the gather is well-formed; the
            # corresponding keys/values are zeroed and their scores set to NEG_INF.
            keys_sink_indices = torch.where(keys_sink_indices_valid, keys_sink_indices, 0)

            K_j = key_states[:, :, keys_sink_indices]
            K_j = torch.where(keys_sink_indices_valid[None, None, :, :, None, None], K_j, 0)
            V_j = value_states[:, :, keys_sink_indices]
            V_j = torch.where(keys_sink_indices_valid[None, None, :, :, None, None], V_j, 0)
            # K_j: (bsz, num_heads, r_block_len, c_block_len, block_size_k, head_dim)
            # V_j: (bsz, num_heads, r_block_len, c_block_len, block_size_k, value_dim)
            V_j = V_j.reshape(bsz, num_heads, r_block_len, c_block_len * block_size_k, value_dim)

            S_ij = torch.einsum('nhiqd, nhijkd -> nhiqjk', Q_i, K_j)

            # Apply causal mask
            k_elem_indices = (
                (keys_sink_indices * block_size_k)[None, None, :, None, :, None]
                + torch.arange(block_size_k, device=device)
            )
            S_ij[(k_elem_indices > q_elem_indices).expand_as(S_ij)] = NEG_INF
            # Mask the blocks whose index was invalid (negative before clamping).
            S_ij[(~keys_sink_indices_valid)[None, None, :, None, :, None].expand_as(S_ij)] = NEG_INF

            S_ij = S_ij.reshape(bsz, num_heads, r_block_len, block_size_q, c_block_len * block_size_k)
            # (bsz, num_heads, r_block_len, block_size_q, c_block_len * block_size_k)

            # Same online-softmax update as above; V_j carries a per-query-block axis here.
            m_ijm1 = m_ij
            m_ij = torch.maximum(m_ij, S_ij.amax(4, keepdim=True))
            P_tilde_ij = torch.exp(S_ij - m_ij)  # (bsz, num_heads, r_block_len, block_size_q, c_block_len * block_size_k)
            m_diff_exp = torch.exp(m_ijm1 - m_ij)
            l_ij = m_diff_exp * l_ij + P_tilde_ij.sum(4, keepdim=True)
            O_ij = m_diff_exp * O_ij + torch.einsum('nhirc, nhicd -> nhird', P_tilde_ij, V_j)

        # Handle top-k
        # Each column tile is evaluated by sparse_attn_inner under activation
        # checkpointing (reentrant variant).
        for j in range(T_c):
            O_ij, l_ij, m_ij = torch.utils.checkpoint.checkpoint(
                sparse_attn_inner,
                O_ij, l_ij, m_ij,
                Q_i, key_states, permute, sparse_indices, value_states,
                j, r_block_begin, r_block_end, is_soft_mask,
                start_sink_tokens, block_size_k, B_c, top_k,
                use_reentrant=True,
            )

        # Normalise the accumulated output and form m + log(l) for this row tile.
        O_i = (1 / l_ij) * O_ij
        L_i = m_ij + torch.log(l_ij)

        O[:, :, r_block_begin:r_block_end] = O_i
        L[:, :, r_block_begin:r_block_end] = L_i

    # Merge the block axes back into one sequence axis and drop the padded rows.
    O = O.reshape(bsz, num_heads, q_blocks * block_size_q, value_dim)
    O = O[:, :, q_start_padding:q_start_padding + q_len, :]

    # Same for L; indexing with 0 also removes the trailing singleton axis.
    L = L.reshape(bsz, num_heads, q_blocks * block_size_q, 1)
    L = L[:, :, q_start_padding:q_start_padding + q_len, 0]

    return O, L


def sparse_attn_inner(
        O_ij, l_ij, m_ij,
        Q_i, key_states, permute, sparse_indices, value_states,
        j, r_block_begin, r_block_end, is_soft_mask: bool,
        start_sink_tokens: int, block_size_k: int, B_c: int, top_k: int,
):
    """
    One online-softmax update over the ``j``-th column tile of the top-k key blocks.

    :param O_ij: running unnormalised output for the current row tile
    :param l_ij: running softmax normaliser for the current row tile
    :param m_ij: running row maximum for the current row tile
    :param Q_i: queries of the current row tile, indexed as
        (bsz, num_heads, query block, position in block, head_dim) by the einsum patterns below
    :param key_states: keys split into blocks; axis 2 is the block axis that is gathered from
    :param permute: soft-mask mixing weights (only read when ``is_soft_mask`` is true)
    :param sparse_indices: block indices relative to the first block after the start sinks;
        in soft-mask mode these are candidate indices and negative values mark invalid entries
    :param value_states: values split into blocks, gathered the same way as the keys
    :param j: index of the column tile
    :param r_block_begin: first query block of the row tile
    :param r_block_end: one past the last query block of the row tile
    :param is_soft_mask: whether ``permute`` / candidate indices are in use
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param block_size_k: key block size
    :param B_c: number of columns in the flash block
    :param top_k: number of top-k slots per query block
    :return: updated (O_ij, l_ij, m_ij)
    """
    # Top-k slots [slot_lo, slot_hi) handled by this column tile.
    slot_lo = j * B_c // block_size_k
    slot_hi = min(top_k, slot_lo + B_c // block_size_k)
    rows = slice(r_block_begin, r_block_end)
    # Gathered indices are shifted past the start-sink blocks.
    sink_blocks = start_sink_tokens // block_size_k

    if not is_soft_mask:
        block_ids = sparse_indices[:, :, rows, slot_lo:slot_hi]
        keys = einx.get_at(
            'n h [i] k d, n h q t -> n h q t k d',
            key_states,
            sink_blocks + block_ids
        )  # (bsz, num_heads, r_block_len, c_block_len, block_size_k, head_dim)

        scores = einx.dot('n h i q d, n h i t k d -> n h i q (t k)', Q_i, keys)
        # (bsz, num_heads, r_block_len, block_size_q, c_block_len * block_size_k)
    else:
        # All candidates of the row tile are gathered; invalid (negative) ones are
        # redirected to block 0 and their keys zeroed afterwards.
        candidates = sparse_indices[:, :, rows, :]
        is_invalid = candidates < 0
        candidates = torch.masked_fill(candidates, is_invalid, 0)
        keys = einx.get_at(
            'n h [i] k d, n h q c -> n h q c k d',
            key_states,
            sink_blocks + candidates
        )  # (bsz, num_heads, r_block_len, n_candidates, block_size_k, head_dim)
        keys = torch.where(~is_invalid[:, :, :, :, None, None], keys, 0)

        # Mix candidate scores into this tile's top-k slots.
        mix = permute[:, :, rows, slot_lo:slot_hi, :]
        scores = einx.dot('... q d, ... t c, ... c k d -> ... q (t k)', Q_i, mix, keys)

        # Values come from the single highest-weighted candidate of each slot.
        best_candidate = mix.argmax(dim=-1)
        block_ids = candidates.gather(3, best_candidate)

    values = einx.get_at(
        'n h [i] k d, n h q c -> n h q (c k) d',
        value_states,
        sink_blocks + block_ids
    )  # (bsz, num_heads, r_block_len, c_block_len * block_size_k, value_dim)

    # Online-softmax update: rescale the running sums by exp(m_old - m_new) and
    # add this tile's contribution.
    m_new = torch.maximum(m_ij, scores.amax(4, keepdim=True))
    probs = torch.exp(scores - m_new)
    rescale = torch.exp(m_ij - m_new)
    l_new = rescale * l_ij + probs.sum(4, keepdim=True)
    O_new = rescale * O_ij + torch.einsum('nhirc, nhicd -> nhird', probs, values)

    return O_new, l_new, m_new


def sparse_attn_flash_ref(
        query_states, key_states, sparse_indices, value_states,
        block_size_q: int, block_size_k: int, B_r: int, B_c: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int):
    """
    Reference sparse flashattention with a dense prefix.

    The leading ``cutoff`` queries are handled by ``reference_attn_flash`` over the
    full keys/values; the remaining queries go through ``sparse_attn_flash_ref_impl``.
    Both pairs of results are concatenated along dim 2.

    ``sparse_indices`` is forwarded unchanged to ``sparse_attn_flash_ref_impl``,
    which only receives the queries from ``cutoff`` onward.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k), or a tuple whose first
        element is a 5-D tensor with top_k on axis 3 (soft masking)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param B_r: number of rows in the flash block (sparse part only)
    :param B_c: number of columns in the flash block (sparse part only)
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
        (asserted to be at least ``block_size_q``)
    :return: tuple of (attention output, second output), each the dense-prefix result
        followed by the sparse result along dim 2
    """
    # Only the sequence lengths and top_k are needed here; the unpacking also
    # checks the rank of each tensor.
    _, _, q_len, _ = query_states.size()
    _, _, k_len, _ = value_states.size()
    if isinstance(sparse_indices, tuple):
        _, _, _, top_k, _ = sparse_indices[0].size()
    else:
        _, _, _, top_k = sparse_indices.size()

    assert end_sink_tokens >= block_size_q
    dims = calc_dims(q_len, k_len, block_size_q, block_size_k, query_offset)
    q_block_offset, q_blocks, q_start_padding, _, _ = dims

    # Number of key tokens one query block attends to in the sparse pattern.
    attended_tokens = start_sink_tokens + top_k * block_size_k + end_sink_tokens
    # Index (within this call) of the first query block handled sparsely.
    first_sparse_block = triton.cdiv(attended_tokens, block_size_q) - 1 - q_block_offset
    first_sparse_block = min(q_blocks, first_sparse_block)

    # Translate the block index into a count of un-padded query rows.
    cutoff = max(0, first_sparse_block * block_size_q - q_start_padding)

    # Dense attention for the prefix (fixed 32 x 32 tiles, independent of B_r / B_c).
    dense_queries = query_states[:, :, :cutoff]
    dense_output, dense_L = reference_attn_flash(
        dense_queries, key_states, value_states,
        32, 32, query_offset
    )

    # Sparse attention for the rest, with the query offset advanced past the prefix.
    sparse_queries = query_states[:, :, cutoff:]
    sparse_output, sparse_L = sparse_attn_flash_ref_impl(
        sparse_queries, key_states, sparse_indices, value_states,
        block_size_q, block_size_k, B_r, B_c, cutoff + query_offset,
        start_sink_tokens, end_sink_tokens,
    )

    merged_output = torch.cat([dense_output, sparse_output], dim=2)
    merged_L = torch.cat([dense_L, sparse_L], dim=2)
    return (merged_output, merged_L)
