import torch
import torch.autograd
import triton
import einx

from ..utils.set_device import SetDevice
from ..common import calc_dims
from .mask_gen import mask_gen_triton
from .sparse_attn import sparse_attn_triton
from .sparse_attn_bwd import sparse_attn_bwd_triton
from .full_forward import attn_flash_triton
from .full_backward import attn_flash_bwd_triton


def mask_gen_orig(
        query_states, key_states,
        top_k: int, block_size_q: int, block_size_k: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int,
        group_size_q: int = 1, block_skip_q: int = 1):
    """
    Generate block-sparse key indices using the reference ``hip_masking`` implementation.

    The ``hip`` package is imported lazily, so it is only required when this path is taken.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states: (bsz, num_heads, k_len, head_dim)
    :param top_k: number of key blocks to select (``mask_k`` is ``top_k * block_size_k`` tokens)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query; only 0 is supported (asserted)
    :param start_sink_tokens: forwarded to ``hip_masking`` as ``sink_token_size``
    :param end_sink_tokens: forwarded as ``sliding_window_size``; must be >= ``block_size_q``
    :param group_size_q: forwarded as ``group_size_q``
    :param block_skip_q: forwarded as ``block_stride_q``
    :return: indices rearranged to (bsz, num_heads, kept_query_blocks, k). The leading query
        blocks before the computed sparse start are dropped. The last axis presumably holds
        ``top_k`` key-block ids; TODO confirm against ``hip_masking``.
    """
    from hip.models.hip_attention.attention2_draft_prefetch import hip_masking, HiPAttentionArgs

    batch, heads, q_len, _ = query_states.size()
    _, _, k_len, _ = key_states.size()

    assert end_sink_tokens >= block_size_q
    assert query_offset == 0
    q_block_offset, q_blocks, _, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    token_budget = top_k * block_size_k
    # Index of the first query block kept in the output, capped at the number of query blocks.
    first_sparse_block = min(
        q_blocks,
        triton.cdiv(token_budget, block_size_q) - 1 - q_block_offset + end_sink_tokens // block_size_q
    )

    # hip_masking consumes (n, i, h, d) tensors; this module stores (n, h, i, d).
    q_for_mask = einx.rearrange('n h i d -> n i h d', query_states)
    k_for_mask = einx.rearrange('n h i d -> n i h d', key_states)
    mask_args = HiPAttentionArgs(
        mask_k=token_budget,
        block_size_q=block_size_q,
        block_stride_q=block_skip_q,
        block_size_k=block_size_k,
        block_stride_k=1,
        sink_token_size=start_sink_tokens,
        sliding_window_size=end_sink_tokens,
        group_size_q=group_size_q,
    )
    # Only the first of the ten returned values (the key indices) is used here.
    (block_ids, _, _, _, _, _, _, _, _, _) = hip_masking(q_for_mask, k_for_mask, mask_args)

    block_ids = block_ids[:, first_sparse_block:]
    # Presumably converts token positions into key-block ids; done in place on the slice.
    block_ids //= block_size_k

    return einx.rearrange('(n h) q k -> n h q k', block_ids, n=batch, h=heads)


class HipAttention(torch.autograd.Function):
    """
    Autograd function for HiP block-sparse attention.

    The forward pass obtains (or reuses) block-sparse key indices and runs the sparse attention
    kernel. The backward pass feeds the same indices and the saved kernel outputs to the sparse
    backward kernel. Only query/key/value receive gradients.
    """

    @staticmethod
    def forward(  # noqa
            ctx, query_states, key_states, value_states,
            hip_block_size_q: int, hip_block_size_k: int, hip_top_k_elems: int, query_offset: int,
            start_sink_tokens: int, end_sink_tokens: int, sparse_indices=None,
            mask_gen_method='mine', group_size_q: int = 1, block_skip_q: int = 1):
        assert hip_top_k_elems % hip_block_size_k == 0
        num_key_blocks = hip_top_k_elems // hip_block_size_k

        if sparse_indices is None:
            if mask_gen_method == 'mine':
                # group_size_q / block_skip_q are not used by the Triton mask generator.
                sparse_indices, _ = mask_gen_triton(
                    query_states, key_states, num_key_blocks,
                    hip_block_size_q, hip_block_size_k, query_offset,
                    start_sink_tokens, end_sink_tokens
                )
            elif mask_gen_method == 'orig':
                sparse_indices = mask_gen_orig(
                    query_states, key_states, num_key_blocks,
                    hip_block_size_q, hip_block_size_k, query_offset,
                    start_sink_tokens, end_sink_tokens, group_size_q, block_skip_q
                )
                # Sink tokens were passed to hip_masking (sink_token_size), so the sparse kernels
                # are told there are none. Presumably this avoids attending to them twice; confirm.
                start_sink_tokens = 0
            else:
                raise ValueError(f"Unknown mask_gen_method: {mask_gen_method}")

        # The second kernel output is presumably a per-row log-sum-exp used by backward.
        output, lse = sparse_attn_triton(
            query_states, key_states, sparse_indices, value_states,
            hip_block_size_q, hip_block_size_k, query_offset,
            start_sink_tokens, end_sink_tokens
        )

        ctx.save_for_backward(query_states, key_states, sparse_indices, value_states, output, lse)
        ctx.block_size_q, ctx.block_size_k = hip_block_size_q, hip_block_size_k
        ctx.query_offset = query_offset
        ctx.start_sink_tokens, ctx.end_sink_tokens = start_sink_tokens, end_sink_tokens
        return output

    @staticmethod
    def backward(ctx, grad_output):  # noqa
        q, k, indices, v, out, lse = ctx.saved_tensors
        grad_q, grad_k, grad_v = sparse_attn_bwd_triton(
            q, k, indices, v, out, grad_output, lse,
            ctx.block_size_q, ctx.block_size_k, ctx.query_offset,
            ctx.start_sink_tokens, ctx.end_sink_tokens
        )
        # One entry per forward input: q/k/v get gradients, the 10 remaining args get None.
        return (grad_q, grad_k, grad_v) + (None,) * 10


class MixedHipAttention(torch.autograd.Function):
    """
    HiP sparse attention in the forward pass, full flash-attention gradient in the backward pass.

    The forward output is produced by the sparse HiP kernel for every query row. The backward
    pass ignores the sparse computation. It recomputes attention with ``attn_flash_triton`` over
    the query/key/value rows from ``split_offset`` onward and differentiates that instead. Rows
    before ``split_offset`` receive zero gradient.
    """

    @staticmethod
    def forward(  # noqa
            ctx, query_states, key_states, value_states,
            hip_block_size_q: int, hip_block_size_k: int, hip_top_k_elems: int, query_offset: int,
            split_offset: int, start_sink_tokens: int, end_sink_tokens: int, sparse_indices=None):
        assert hip_top_k_elems % hip_block_size_k == 0
        top_k_blocks = hip_top_k_elems // hip_block_size_k
        if sparse_indices is None:
            sparse_indices, _ = mask_gen_triton(
                query_states, key_states, top_k_blocks,
                hip_block_size_q, hip_block_size_k, query_offset,
                start_sink_tokens, end_sink_tokens
            )
        # The sparse kernel's second output is not needed: backward recomputes it densely.
        output, _ = sparse_attn_triton(
            query_states, key_states, sparse_indices, value_states,
            hip_block_size_q, hip_block_size_k, query_offset,
            start_sink_tokens, end_sink_tokens
        )
        # Backward recomputes output and log-sum-exp from scratch, so only the input tails are
        # saved. Saving views of the sparse output/L (as before) was dead weight: the L view
        # kept the whole forward buffer's storage alive until backward ran.
        ctx.save_for_backward(
            query_states[:, :, split_offset:],
            key_states[:, :, split_offset:],
            value_states[:, :, split_offset:],
        )
        ctx.query_offset = query_offset
        ctx.split_offset = split_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):  # noqa
        query_states, key_states, value_states = ctx.saved_tensors
        query_offset = ctx.query_offset
        split_offset = ctx.split_offset

        # The flash kernels take queries with an extra axis at dim 3 (squeezed again below).
        # Presumably this is a per-head group axis; TODO confirm against attn_flash_triton.
        q = query_states.unsqueeze(3)
        O, L, _ = attn_flash_triton(q, key_states, value_states, query_offset)
        (grad_query, grad_key, grad_value, _, _), _ = attn_flash_bwd_triton(
            q, key_states, value_states,
            O, grad_output[:, :, split_offset:].unsqueeze(3), L,
            query_offset=query_offset, begin_offset=split_offset
        )
        grad_query = grad_query.squeeze(3)

        # Zero gradient for rows before split_offset: pad dim -2 (sequence) at the front.
        front_pad = (0, 0, split_offset, 0)
        grad_query = torch.nn.functional.pad(grad_query, front_pad)
        grad_key = torch.nn.functional.pad(grad_key, front_pad)
        grad_value = torch.nn.functional.pad(grad_value, front_pad)
        # One entry per forward input: q/k/v get gradients, the 8 remaining args get None.
        return (grad_query, grad_key, grad_value) + (None,) * 8


def hip_attn(query_states, key_states, value_states,
             hip_block_size_q: int, hip_block_size_k: int, hip_top_k_elems: int, query_offset: int,
             start_sink_tokens: int, end_sink_tokens: int, sparse_indices=None, mask_gen_method='mine',
             group_size_q: int = 1, block_skip_q: int = 1):
    """
    Perform HiP attention in the forward pass.

    The backward pass uses the sparse backward kernel with the same sparsity pattern as the
    forward pass.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states: (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param hip_block_size_q: block size for the query
    :param hip_block_size_k: block size for the key
    :param hip_top_k_elems: number of top-k elements. Must be a multiple of `hip_block_size_k`.
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param sparse_indices: precomputed block-sparse key indices. If None, they are generated
        according to `mask_gen_method`. Note that `start_sink_tokens` is only reset to 0
        automatically when the 'orig' generator runs here. If you pass indices produced by
        `mask_gen_orig` yourself, pass `start_sink_tokens=0` as well.
    :param mask_gen_method: 'mine' (Triton mask generator) or 'orig' (reference `hip_masking`,
        requires the `hip` package). Ignored when `sparse_indices` is given.
    :param group_size_q: query group size forwarded to `hip_masking`; only used by 'orig'.
    :param block_skip_q: query block stride forwarded to `hip_masking`; only used by 'orig'.
    :return: output (bsz, num_heads, q_len, value_dim)
    :raises ValueError: if `sparse_indices` is None and `mask_gen_method` is not recognized.
    :raises AssertionError: if `hip_top_k_elems` is not a multiple of `hip_block_size_k`.
    """
    return HipAttention.apply(
        query_states, key_states, value_states,
        hip_block_size_q, hip_block_size_k, hip_top_k_elems, query_offset,
        start_sink_tokens, end_sink_tokens, sparse_indices,
        mask_gen_method, group_size_q, block_skip_q
    )


def mixed_attn(query_states, key_states, value_states,
               hip_block_size_q: int, hip_block_size_k: int, hip_top_k_elems: int, query_offset: int,
               split_offset: int, start_sink_tokens: int, end_sink_tokens: int, sparse_indices=None):
    """
    Perform HiP attention in the forward pass, and partially perform full attention in the backward pass.

    In the backward pass only rows from `split_offset` onward receive gradients. Earlier rows get
    zero gradient.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states: (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param hip_block_size_q: block size for the query
    :param hip_block_size_k: block size for the key
    :param hip_top_k_elems: number of top-k elements. Must be a multiple of `hip_block_size_k`.
    :param query_offset: offset of the query
    :param split_offset: offset to split the gradient and the output
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param sparse_indices: optional precomputed block-sparse key indices. If None (default), they
        are generated with the Triton mask generator, as before.
    :return: output (bsz, num_heads, q_len, value_dim)
    :raises AssertionError: if `hip_top_k_elems` is not a multiple of `hip_block_size_k`.
    """
    return MixedHipAttention.apply(
        query_states, key_states, value_states,
        hip_block_size_q, hip_block_size_k, hip_top_k_elems, query_offset, split_offset,
        start_sink_tokens, end_sink_tokens, sparse_indices
    )
