import math
import torch
from torch.func import grad

from ..triton_impl.sparse_attn import sparse_attn_triton
from ..triton_impl.mask_gen import mask_gen_triton
from ..common import calc_dims


def ste_sparse_attn_ref(query_states, key_states, sparse_indices, value_states,
                        block_size_q: int, block_size_k: int, query_offset: int,
                        start_sink_tokens: int, end_sink_tokens: int, alt=False):
    """
    Reference sparse attention with a straight-through-estimator (STE) gradient.

    The forward value is causal attention restricted to the key blocks enabled by
    ``sparse_indices`` plus the start/end sink blocks. The gradient w.r.t. the
    attention scores (and hence the queries/keys) is that of the *dense* causal
    softmax, while the gradient w.r.t. ``value_states`` follows the sparse
    probabilities.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k) key-block indices,
        counted from the first block after the start-sink blocks
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param alt: if True, compute the same estimator through the hand-written
        ``SteAttentionProbs`` autograd Function instead of the ``detach`` trick
    :return: output (bsz, num_heads, q_len, value_dim)
    """
    device = query_states.device

    # The tuple unpacking also enforces that all three inputs are 4-D.
    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    (q_block_offset, q_blocks, q_start_padding, q_end_padding, k_blocks) = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # --- Block-level mask M[b, h, q_block, k_block] (1 = attend) ---
    # Start position of the end sinks for each query block: the end of that query
    # block, rounded up to a key-block boundary, minus end_sink_tokens.
    sink_end_start_indices = (
        (q_block_offset + 1 + torch.arange(q_blocks, device=device)) * block_size_q + (block_size_k - 1)
    ) // block_size_k * block_size_k - end_sink_tokens
    # One extra leading column is allocated and sliced off afterwards. With the +1
    # shift, an index of -1 lands either on that dummy column or on the last
    # start-sink block (which is set to 1 anyway), so it never enables a new block.
    M = torch.zeros((bsz, num_heads, q_blocks, k_blocks + 1), dtype=torch.int32, device=device)
    M = torch.scatter(M, 3, start_sink_tokens // block_size_k + sparse_indices.long() + 1, 1)[..., 1:]
    # Start-sink blocks are always attended.
    M[:, :, :, :start_sink_tokens // block_size_k] = 1
    # End-sink blocks: end_sink_tokens // block_size_k consecutive key blocks per query block.
    M[:, :, torch.arange(q_blocks, device=device)[:, None].expand(-1, end_sink_tokens // block_size_k),
      (sink_end_start_indices // block_size_k)[:, None] + torch.arange(end_sink_tokens // block_size_k, device=device)] = 1
    # Expand to element granularity and keep only the q_len real query rows.
    # NOTE(review): sink_end_start_indices treats row-block i as block
    # q_block_offset + i, yet this slice also offsets rows by
    # q_block_offset * block_size_q; the two agree only if q_block_offset == 0 or
    # calc_dims defines it differently -- confirm against calc_dims.
    M = M.repeat_interleave(block_size_q, 2).repeat_interleave(block_size_k, 3)
    M = M[:, :, q_block_offset * block_size_q + q_start_padding:q_block_offset * block_size_q + q_start_padding + q_len]

    # --- Scaled scores S: (bsz, num_heads, q_len, k_len) ---
    query_states = query_states / math.sqrt(head_dim)
    S = torch.einsum('nhrd, nhcd -> nhrc', query_states, key_states)

    # Causal mask: query row r sits at position query_offset + r and may only see
    # keys c <= that position.
    k_elem_indices = torch.arange(k_len, device=device)[None, :]
    q_elem_indices = query_offset + torch.arange(q_len, device=device)[:, None]
    causal_mask = (k_elem_indices <= q_elem_indices).expand_as(S)
    S = torch.where(causal_mask, S, float('-inf'))

    # --- Apply the sparse mask with a straight-through estimator ---
    if not alt:
        # Value == masked (sparse) softmax; gradient == dense causal softmax.
        fake_func = compute_P(S)
        masked_S = torch.where(M == 1, S, float('-inf'))
        P = fake_func + (compute_P(masked_S) - fake_func).detach()
        O = torch.einsum('nhrc, nhcd -> nhrd', P, value_states)
    else:
        O = SteAttentionProbs.apply(S, M, value_states)

    return O


def compute_P(S):
    """
    Numerically stable softmax along dim 3 of a score tensor.

    The row maximum is subtracted before exponentiating to avoid overflow.
    A row that is entirely -inf yields NaN, because (-inf) - (-inf) is NaN.

    :param S: scores, reduced over dim 3 (the key axis of the 4-D
        (bsz, num_heads, rows, cols) tensors used in this file)
    :return: probabilities with the same shape as S
    """
    row_max = S.amax(dim=3, keepdim=True)
    unnormalized = (S - row_max).exp()
    return unnormalized / unnormalized.sum(dim=3, keepdim=True)


class SteAttentionProbs(torch.autograd.Function):
    """
    Masked attention whose backward is a straight-through (dense softmax) estimator.

    Forward computes ``softmax(where(M == 1, S, -inf), dim=3) @ V``.
    Backward returns:
      * for ``S``: the gradient of the *dense* function ``softmax(S, dim=3) @ V``,
        i.e. the mask ``M`` is ignored for the score gradient;
      * for ``V``: the gradient of the sparse forward, ``P_hat^T @ dO``;
      * for ``M``: no gradient.
    """

    @staticmethod
    def forward(ctx, S, M, value_states):  # noqa
        S_hat = torch.where(M == 1, S, float('-inf'))
        P_hat = torch.softmax(S_hat, dim=3)
        O = torch.einsum('nhrc, nhcd -> nhrd', P_hat, value_states)
        # The forward output is deliberately not saved. Backward needs the output
        # of the *dense* softmax, which it recomputes from S.
        ctx.save_for_backward(S, value_states, P_hat)
        return O

    @staticmethod
    def backward(ctx, dO):  # noqa
        S, V, P_hat = ctx.saved_tensors
        # dV follows the sparse forward.
        dV = torch.einsum('nhrc, nhrd -> nhcd', P_hat, dO)
        dP = torch.einsum('nhrd, nhcd -> nhrc', dO, V)
        # Straight-through: differentiate the dense softmax instead of the masked one.
        P = torch.softmax(S, dim=3)
        O_dense = torch.einsum('nhrc, nhcd -> nhrd', P, V)
        # Softmax backward dS = P * (dP - sum_c P*dP), using
        # sum_c P*dP == sum_d dO*O_dense.
        dS = P * (dP - (dO * O_dense).sum(3, keepdim=True))
        return dS, None, dV


def _ste_sparse_attn_ref_loss(query_states, key_states, sparse_indices, value_states,
                              block_size_q, block_size_k, query_offset,
                              start_sink_tokens, end_sink_tokens, g):
    """
    Scalar surrogate ``sum(ste_sparse_attn_ref(...) * g)``.

    Its gradient w.r.t. an input is the vector-Jacobian product of
    ``ste_sparse_attn_ref`` with the upstream gradient ``g``.
    """
    out = ste_sparse_attn_ref(query_states, key_states, sparse_indices, value_states,
                              block_size_q, block_size_k, query_offset,
                              start_sink_tokens, end_sink_tokens)
    return (out * g).sum()


# Differentiates w.r.t. query_states (0), key_states (1) and value_states (3), and
# returns the tuple (grad_query, grad_key, grad_value). sparse_indices, the integer
# hyper-parameters and g are not differentiated.
ste_sparse_attn_ref_grad = grad(_ste_sparse_attn_ref_loss, argnums=(0, 1, 3))


class SteHipAttention(torch.autograd.Function):
    """
    HiP sparse attention with a straight-through-estimator backward.

    Forward runs the Triton kernel ``sparse_attn_triton``. If no ``sparse_indices``
    are supplied, it first generates them with ``mask_gen_triton``.

    Backward does not differentiate the kernel. Instead it differentiates the PyTorch
    reference ``ste_sparse_attn_ref`` (through ``ste_sparse_attn_ref_grad``), so
    queries/keys receive dense-causal-softmax gradients and values receive
    sparse-probability gradients. The integer hyper-parameters and
    ``sparse_indices`` receive no gradient.
    """

    @staticmethod
    def forward(  # noqa
            ctx, query_states, key_states, value_states,
            hip_block_size_q: int, hip_block_size_k: int, hip_top_k_elems: int, query_offset: int,
            start_sink_tokens: int, end_sink_tokens: int, sparse_indices=None):
        assert hip_top_k_elems % hip_block_size_k == 0
        if sparse_indices is None:
            top_k_blocks = hip_top_k_elems // hip_block_size_k
            sparse_indices, _ = mask_gen_triton(
                query_states, key_states, top_k_blocks,
                hip_block_size_q, hip_block_size_k, query_offset,
                start_sink_tokens, end_sink_tokens
            )
        # Only the first return value (the attention output) is used. The kernel's
        # other outputs are not needed, because backward re-derives gradients from
        # the reference implementation.
        output, _, _ = sparse_attn_triton(
            query_states, key_states, sparse_indices, value_states,
            hip_block_size_q, hip_block_size_k, query_offset,
            start_sink_tokens, end_sink_tokens
        )
        # Save only what backward reads, so unused kernel results are not kept alive.
        ctx.save_for_backward(query_states, key_states, sparse_indices, value_states)
        ctx.query_offset = query_offset
        ctx.hip_block_size_q = hip_block_size_q
        ctx.hip_block_size_k = hip_block_size_k
        ctx.start_sink_tokens = start_sink_tokens
        ctx.end_sink_tokens = end_sink_tokens
        return output

    @staticmethod
    def backward(ctx, grad_output):  # noqa
        with torch.no_grad():
            query_states, key_states, sparse_indices, value_states = ctx.saved_tensors
            grad_query, grad_key, grad_value = ste_sparse_attn_ref_grad(
                query_states, key_states, sparse_indices, value_states,
                ctx.hip_block_size_q, ctx.hip_block_size_k, ctx.query_offset,
                ctx.start_sink_tokens, ctx.end_sink_tokens, grad_output
            )
        # Exactly one entry per forward input (10 inputs): q, k, v, then None for
        # the six integer hyper-parameters and for sparse_indices.
        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None


def ste_hip_attn(query_states, key_states, value_states,
                 hip_block_size_q: int, hip_block_size_k: int, hip_top_k_elems: int, query_offset: int,
                 start_sink_tokens: int, end_sink_tokens: int, sparse_indices=None):
    """
    Functional entry point for HiP sparse attention with straight-through gradients.

    Delegates to ``SteHipAttention.apply``. The forward pass runs the HiP Triton
    sparse-attention kernel. The backward pass uses the gradients of the PyTorch
    reference implementation.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states: (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param hip_block_size_q: query block size
    :param hip_block_size_k: key block size
    :param hip_top_k_elems: number of top-k key elements; must be a multiple of
        ``hip_block_size_k`` (an AssertionError is raised otherwise)
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param sparse_indices: optional precomputed (bsz, num_heads, q_blocks, top_k)
        block indices; generated with ``mask_gen_triton`` when None
    :return: output (bsz, num_heads, q_len, value_dim)
    """
    hip_args = (
        hip_block_size_q, hip_block_size_k, hip_top_k_elems, query_offset,
        start_sink_tokens, end_sink_tokens, sparse_indices,
    )
    # autograd.Function.apply takes its arguments positionally, in forward()'s order.
    return SteHipAttention.apply(query_states, key_states, value_states, *hip_args)
