import torch
from block_sparse_attn import (
    block_sparse_attn_func,
    flash_attn_varlen_func,
)
from .attn_pooling_kernel_varlen import attn_pooling_qk_varlen
from transformers.models.llama.modeling_llama import repeat_kv
from einops import rearrange

def generate_qkv(q, k, v):
    """Pack unpadded batched q/k/v into the flat "varlen" layout.

    Every sequence in the batch is treated as full length (no padding), so
    the cumulative sequence offsets are simply multiples of the sequence
    length.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)

    Returns:
        A 7-tuple ``(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k)``:

        - ``q_unpad``: ``(batch_size * seqlen_q, nheads, d)``;
          ``k_unpad`` / ``v_unpad``: ``(batch_size * seqlen_k, nheads_k, d)``.
          All three are detached from the input graph and re-marked as
          leaves with ``requires_grad=True``.
        - ``cu_seqlens_q`` / ``cu_seqlens_k``: int32 offsets of length
          ``batch_size + 1`` on the input's device, i.e.
          ``[0, s, 2*s, ..., batch_size*s]``.
        - ``max_seqlen_q`` / ``max_seqlen_k``: the Python ints ``seqlen_q`` /
          ``seqlen_k``.
    """
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    # (b, s, h, d) -> (b*s, h, d). reshape only copies when the input is not
    # contiguous (e.g. a transposed view), exactly like the einops pattern.
    q_unpad = q.reshape(batch_size * seqlen_q, nheads, d)
    k_unpad = k.reshape(batch_size * seqlen_k, nheads_k, d)
    v_unpad = v.reshape(batch_size * seqlen_k, nheads_k, d)

    # Multiply instead of using arange(step=seqlen) so zero-length sequences
    # (step 0) do not raise.
    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seqlen_q
    cu_seqlens_k = torch.arange(batch_size + 1, dtype=torch.int32, device=k.device) * seqlen_k

    return (
        q_unpad.detach().requires_grad_(),
        k_unpad.detach().requires_grad_(),
        v_unpad.detach().requires_grad_(),
        cu_seqlens_q,
        cu_seqlens_k,
        seqlen_q,
        seqlen_k,
    )

def auxhead_attention(query_states, key_states, value_states,
    gamma=0.95, stride=8, causal=True, block_size=128, bias=0.0,
    sparse_file=None):
    """Block-sparse attention whose block mask comes from a cheap head-averaged estimate.

    Pipeline:

    1. Budget estimation: the last ``block_size`` queries of every head are
       scored exactly against all keys. That attention is averaged over those
       queries and average-pooled into ``block_size``-wide key blocks. From it,
       each head gets a budget ratio: (number of top blocks whose cumulative
       mass stays <= ``gamma`` * total, plus one) / number of blocks.
    2. Score estimation: queries and keys are averaged across heads and
       subsampled every ``stride`` tokens. ``attn_pooling_qk_varlen`` turns
       them into block-level scores shared by all heads.
    3. Masking + attention:
       - each head keeps the key blocks whose estimated score reaches its
         budget-derived top-k threshold;
       - some blocks are always forced on (see below);
       - ``block_sparse_attn_func`` computes attention over the kept blocks.

    Args:
        query_states: ``(batch, q_heads, seq_len, head_dim)``.
        key_states: ``(batch, kv_heads, seq_len_k, head_dim)``; expanded to
            ``q_heads`` with ``repeat_kv(key_states, q_heads // kv_heads)``.
            The score estimate reuses the query length for the keys, so
            ``seq_len_k`` presumably must equal ``seq_len`` — TODO confirm.
        value_states: same layout as ``key_states``.
        gamma: fraction of the measured attention mass each head's budget covers.
        stride: token subsampling for the score estimate; one of 1, 2, 4, 8.
        causal: forwarded to ``block_sparse_attn_func`` as ``is_causal``.
            When true, the block mask is also restricted to its lower triangle.
        block_size: attention block size; only 128 is supported.
        bias: added to every head's budget ratio (result clamped to 1.0)
            before it is turned into a block count.
        sparse_file: optional path. If given, one line
            ``<q_seq_len><TAB><density>`` is appended. ``density`` is the
            number of kept blocks divided by the lower-triangular block count
            per (batch, head).

    Returns:
        Whatever ``block_sparse_attn_func`` returns for the packed layout
        built by ``generate_qkv``. Presumably ``(batch*seq_len, q_heads,
        head_dim)`` — TODO confirm against the kernel.
    """
    # config check
    assert block_size == 128
    assert stride in [1, 2, 4, 8]
    batch_size, query_head_num, q_seq_len, head_size = query_states.shape
    key_head_num = key_states.shape[1]
    scaling = head_size ** -0.5
    group_num = query_head_num // key_head_num

    # --- 1. per-head budget from the last query block -----------------------
    true_attn = (scaling * query_states[:, :, -block_size:, :]
                 @ repeat_kv(key_states, group_num).transpose(-1, -2)).softmax(dim=-1).mean(dim=-2, keepdim=True)
    true_attn = torch.avg_pool1d(true_attn.squeeze(-2), kernel_size=block_size, stride=block_size, padding=0, ceil_mode=True)
    sort_score = true_attn.sort(descending=True, dim=-1).values
    cum_sum = torch.cumsum(sort_score, dim=-1)
    del true_attn
    # budget: (batch, q_heads, 1) ratio of key blocks needed per head.
    budget = (torch.sum(cum_sum <= gamma * sort_score.sum(-1, keepdim=True), dim=-1, keepdim=True) + 1) / sort_score.size(-1)

    # --- 2. head-averaged, strided block-score estimate ---------------------
    aux_query_states = query_states.mean(dim=1, keepdim=True)[..., ::stride, :]
    aux_key_states = key_states.mean(dim=1, keepdim=True)[..., ::stride, :]
    seq_len = aux_query_states.size(2)
    aux_query_states = aux_query_states.transpose(1, 2).reshape(batch_size * seq_len, 1, head_size)
    aux_key_states = aux_key_states.transpose(1, 2).reshape(batch_size * seq_len, 1, head_size)
    aux_cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, seq_len, device=aux_query_states.device, dtype=torch.int32)
    # Every packed sequence has length seq_len, so no device round-trip is
    # needed to find the maximum.
    aux_score = attn_pooling_qk_varlen(aux_query_states,
                                       aux_key_states,
                                       cu_seqlens=aux_cu_seqlens,
                                       max_seqlen=seq_len,
                                       sm_scale=scaling,
                                       block_size=block_size // stride)

    # --- 3. per-head thresholds -> block mask -------------------------------
    # aux_score is assumed to be (batch, 1, n_blocks, n_blocks): it is
    # broadcast across all query heads below.
    # NOTE(review): confirm that shape against attn_pooling_qk_varlen.
    n_blocks = aux_score.size(-1)
    max_ratio = min(torch.max(budget.flatten()) + bias, 1.0)
    max_count = min(int(max_ratio * n_blocks) + 1, n_blocks)
    candi_score = torch.topk(aux_score, k=max_count, dim=-1, sorted=True).values
    # Per-head number of kept blocks; always <= max_count because each
    # head's budget is <= the max budget.
    # NOTE(review): only batch 0's budget is used, for every batch element —
    # TODO confirm this is intended when batch_size > 1.
    head_ratio = (budget[0, :, 0] + bias).clamp(max=1.0)
    head_count = ((head_ratio * n_blocks).long() + 1).clamp(min=1, max=n_blocks)
    # The head_count-th largest score per (batch, row) becomes that head's
    # threshold. The gather gives (batch, 1, n_blocks, q_heads); the permute
    # turns it into (batch, q_heads, n_blocks, 1).
    threshold_score = candi_score[..., head_count - 1].permute(0, 3, 2, 1)
    # Broadcasting replaces the explicit expand + contiguous copy of
    # aux_score. .contiguous() keeps the mask layout dense for the kernel.
    block_sparse_mask = (aux_score >= threshold_score).contiguous()
    del aux_score, candi_score, threshold_score, aux_query_states, aux_key_states
    # Always keep: the last two query-block rows, the first key block and the
    # diagonal blocks.
    block_sparse_mask[..., -2:, :] = True
    block_sparse_mask[..., :, 0] = True
    block_sparse_mask.diagonal(dim1=-2, dim2=-1).fill_(True)
    if causal:
        block_sparse_mask.tril_()

    if sparse_file is not None:
        mask_blocks = block_sparse_mask.size(-1)
        lower_blocks = (mask_blocks + 1) * mask_blocks / 2
        # Normalize by the lower triangle of every (batch, head) mask.
        # Write a plain float, not a device-tensor repr.
        density = block_sparse_mask.sum().item() / (lower_blocks * query_head_num * batch_size)
        with open(sparse_file, 'a', encoding='utf-8') as f:
            f.write(f"{q_seq_len}\t{density}\n")

    # --- 4. block-sparse attention over the kept blocks ---------------------
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
    ) = generate_qkv(query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2))
    # NOTE(review): every head gets mask type 1. If the kernel treats
    # positive values as 1-based indices into block_sparse_mask's head dim,
    # all heads would share head 0's mask rather than their own. Verify
    # against the block_sparse_attn_func docs before relying on per-head
    # thresholds.
    head_mask_type = torch.ones(query_head_num, device=query_states.device, dtype=torch.int32)
    attn_output = block_sparse_attn_func(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, head_mask_type,
                                         None,  # presumably streaming_info; no streaming heads are used here
                                         block_sparse_mask,
                                         max_seqlen_q,
                                         max_seqlen_k,
                                         p_dropout=0.0,
                                         softmax_scale=scaling,
                                         is_causal=causal,
                                         exact_streaming=False,
                                         return_attn_probs=False)
    return attn_output



