import torch
import math
from torch.onnx.symbolic_opset9 import true_divide
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
from fused_recover import recover_rope, gen_embeds
from unpack_dequant import unpack_dequant
import triton
import pca_topk as G

from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
from new_pack import *


import torch
from fused_recover import recover_rope
from flash_attn import flash_attn_func

def generate_sparse_indices_label_single_head(lowrank_q, lowrank_k, kv_budget, withsoftmax=True):
    """Return the indices of the ``kv_budget`` highest-scoring keys per batch row.

    Scores are plain dot products between the low-rank query and every low-rank key.

    Args:
        lowrank_q: projected query. ``matmul(lowrank_q, lowrank_k^T)`` must yield exactly
            ``B * K`` elements (e.g. shape ``(B, 1, 1, D)``). This is an assumed layout;
            TODO confirm against callers.
        lowrank_k: projected keys, unpacked as ``(B, Nk, K, D)``.
        kv_budget: number of key positions to keep per batch row.
        withsoftmax: currently unused; kept for interface compatibility.

    Returns:
        ``(B, kv_budget)`` index tensor from ``torch.topk``, ordered by descending score.
    """
    batch, _, num_keys, _ = lowrank_k.shape

    # Similarity of the query against every key, flattened to one row per batch entry.
    keys_t = lowrank_k.transpose(2, 3)
    flat_scores = torch.matmul(lowrank_q, keys_t).view(batch, num_keys)

    # Only the positions matter; the score values themselves are discarded.
    _, selected = flat_scores.topk(kv_budget, dim=-1)
    return selected

import torch.utils.benchmark as benchmark

def fused_recover_fa(
    query_states:torch.Tensor, 
    key_cache_high:torch.Tensor,
    key_cache_low:torch.Tensor,
    value_cache_high:torch.Tensor,
    value_cache_low:torch.Tensor,
    value_scale:torch.Tensor,
    value_mn:torch.Tensor,
    recover_weight:torch.Tensor, 
    recover_bias:torch.Tensor,
    cos:torch.Tensor, 
    sin:torch.Tensor, 
    high_index:torch.Tensor,
    sparsity:int,
    num_kv_head:int,
    vbit:int,
    group_size:int,
    attn_mask:torch.Tensor,
):
    """Sparse attention over a low-rank key cache and a quantized value cache.

    The attended sequence is built along dim 2 as ``[selected low tokens | high tokens]``:

    * All high tokens are kept. Their keys come from ``key_cache_high`` via
      ``recover_rope``, and their values are copied from ``value_cache_high``.
    * From the low cache, ``key_select_length`` tokens are chosen by top-k scoring on the
      trailing ``key_low_rank // 2`` rank dimensions. Their keys come from ``recover_rope``
      using the matching rows of ``recover_weight``. Their values come from
      ``unpack_dequant``.

    The total attended length is ``(key_high_length + key_low_length) // sparsity``.

    Args:
        query_states: unpacked as ``(bsz, q_len, num_head, head_dim)``. The scoring
            reshape appears to assume ``q_len == 1`` (decode step); TODO confirm.
        key_cache_high / key_cache_low: unpacked as ``(_, _, length, rank)``.
        value_cache_high: assumed ``(bsz, num_kv_head, key_high_length, head_dim)``.
            This matches the concatenation layout the original code sketched; TODO confirm.
        value_cache_low, value_scale, value_mn: passed through to ``unpack_dequant``.
        recover_weight: low-rank -> full projection rows; the low cache uses the last
            ``key_low_rank`` of the first ``key_high_rank`` rows.
        cos, sin, high_index: passed through to ``recover_rope``.
        sparsity: divisor controlling the total attended length.
        num_kv_head: number of KV heads (GQA grouping factor is ``num_head // num_kv_head``).
        recover_bias, vbit, group_size, attn_mask: not used in this function.

    Returns:
        The output of ``flash_attn_func`` (causal attention).
    """
    bsz, q_len, num_head, head_dim = query_states.shape
    _, _, key_high_length, key_high_rank = key_cache_high.shape
    _, _, key_low_length, key_low_rank = key_cache_low.shape
    # Only the trailing half of the low-rank dimensions is used for token scoring.
    topk_rank = key_low_rank // 2

    # Collapse the query heads into one scoring query by averaging over groups.
    # NOTE(review): view(bsz, G, -1).mean(1) averages head h with h + num_kv_head, ...
    # (strided grouping), whereas HF repeat_kv maps head h to kv head h // G.
    # Confirm which head layout recover_weight was trained against.
    query_states_to_sparse = query_states.view(bsz, num_head // num_kv_head, -1).mean(dim=1)
    # Project the averaged query into the same trailing low-rank subspace as key_sparse.
    query_sparse = torch.matmul(
        query_states_to_sparse, recover_weight[-topk_rank:, :].transpose(0, 1)
    ).view(bsz, -1, q_len, topk_rank)
    key_sparse = key_cache_low[..., -topk_rank:]

    # High tokens are always attended, so they consume part of the overall budget.
    key_select_length = (key_high_length + key_low_length) // sparsity - key_high_length
    topk_index = generate_sparse_indices_label_single_head(query_sparse, key_sparse, key_select_length)

    total_length = key_select_length + key_high_length
    k_recover = torch.empty(bsz, num_kv_head, total_length, head_dim, dtype=query_states.dtype, device=query_states.device)
    # High tokens are written after the selected low tokens (offset key_select_length).
    recover_rope(key_cache_high, high_index, recover_weight, sin, cos, num_kv_head, head_dim, k_recover, key_select_length)
    recover_rope(key_cache_low, topk_index, recover_weight.narrow(0, key_high_rank - key_low_rank, key_low_rank), sin, cos, num_kv_head, head_dim, k_recover, 0)

    v_dequant = torch.empty(bsz, num_kv_head, total_length, head_dim, dtype=query_states.dtype, device=query_states.device)
    # copy_ (not name assignment) is required so the high-token values actually land in
    # v_dequant; otherwise this slice stays uninitialised torch.empty memory.
    v_dequant.narrow(2, key_select_length, key_high_length).copy_(value_cache_high)
    # Dequantize the selected low tokens into the front of the buffer (offset 0).
    unpack_dequant(value_cache_low, topk_index, value_scale, value_mn, head_dim, v_dequant, 0)

    # flash_attn_func expects (batch, seqlen, nheads, headdim).
    attn_output = flash_attn_func(
                query_states,
                k_recover.transpose(1, 2),
                v_dequant.transpose(1, 2),
                causal=True,
            )
    return attn_output
def fused_recover_fa_attention(
    query_states:torch.Tensor, 
    key_cache_high:torch.Tensor,
    key_cache_low:torch.Tensor,
    value_cache_high:torch.Tensor,
    value_cache_low:torch.Tensor,
    value_scale:torch.Tensor,
    value_mn:torch.Tensor,
    recover_weight:torch.Tensor, 
    recover_bias:torch.Tensor,
    cos:torch.Tensor, 
    sin:torch.Tensor, 
    high_index:torch.Tensor,
    sparsity:int,
    num_kv_head:int,
    vbit:int,
    group_size:int,
    attn_mask:torch.Tensor,
    high_length:int,
    low_length: int
):
    """Sparse attention over low-rank/quantized caches with explicitly passed lengths.

    Builds keys and values along dim 2 as ``[selected low tokens | high tokens]``:

    * ``key_select_length = (high_length + low_length) // sparsity`` low tokens are chosen
      by top-k scoring on the trailing ``key_low_rank // 2`` rank dimensions.
    * All ``high_length`` high tokens are appended after them.

    Keys are produced by ``recover_rope``. Low-token values come from ``unpack_dequant``,
    and high-token values are copied from ``value_cache_high``.

    Args:
        query_states: unpacked as ``(bsz, q_len, num_head, head_dim)``. The scoring
            reshape appears to assume ``q_len == 1`` (decode step); TODO confirm.
        key_cache_high / key_cache_low: rank is read from the last dim.
        value_cache_high: assumed ``(bsz, num_kv_head, >= high_length, head_dim)`` with
            the valid tokens first; TODO confirm.
        value_cache_low, value_scale, value_mn: passed through to ``unpack_dequant``.
        recover_weight: low-rank -> full projection rows; the low cache uses the last
            ``key_low_rank`` of the first ``key_high_rank`` rows.
        cos, sin, high_index: passed through to ``recover_rope``.
        sparsity: divisor controlling the number of low tokens selected.
        num_kv_head: number of KV heads (GQA grouping factor is ``num_head // num_kv_head``).
        high_length, low_length: valid token counts in the high / low caches.
        recover_bias, vbit, group_size, attn_mask: not used in this function.

    Returns:
        The output of ``flash_attn_func`` (causal attention).
    """
    bsz, q_len, num_head, head_dim = query_states.shape
    key_high_rank = key_cache_high.shape[-1]
    key_low_rank = key_cache_low.shape[-1]
    # Only the trailing half of the low-rank dimensions is used for token scoring.
    topk_rank = key_low_rank // 2

    # Collapse the query heads into one scoring query by averaging over groups.
    # NOTE(review): view(bsz, G, -1).mean(1) averages head h with h + num_kv_head, ...
    # (strided grouping), whereas HF repeat_kv maps head h to kv head h // G.
    # Confirm which head layout recover_weight was trained against.
    query_states_to_sparse = query_states.view(bsz, num_head // num_kv_head, -1).mean(dim=1)
    # Project the averaged query into the same trailing low-rank subspace as key_sparse.
    query_sparse = torch.matmul(
        query_states_to_sparse, recover_weight[-topk_rank:, :].transpose(0, 1)
    ).view(bsz, -1, q_len, topk_rank)
    key_sparse = key_cache_low[..., -topk_rank:]

    key_select_length = (high_length + low_length) // sparsity
    topk_index = generate_sparse_indices_label_single_head(query_sparse, key_sparse, key_select_length)

    total_length = key_select_length + high_length
    k_recover = torch.empty(bsz, num_kv_head, total_length, head_dim, dtype=query_states.dtype, device=query_states.device)
    # High tokens are written after the selected low tokens (offset key_select_length).
    recover_rope(key_cache_high, high_index, recover_weight, sin, cos, num_kv_head, head_dim, k_recover, key_select_length)
    recover_rope(key_cache_low, topk_index, recover_weight.narrow(0, key_high_rank - key_low_rank, key_low_rank), sin, cos, num_kv_head, head_dim, k_recover, 0)

    v_dequant = torch.empty(bsz, num_kv_head, total_length, head_dim, dtype=query_states.dtype, device=query_states.device)
    # copy_ (not name assignment) is required so the high-token values actually land in
    # v_dequant; otherwise this slice stays uninitialised torch.empty memory.
    # Assumes the first high_length positions of value_cache_high are the valid ones.
    v_dequant.narrow(2, key_select_length, high_length).copy_(value_cache_high.narrow(2, 0, high_length))
    # Dequantize the selected low tokens into the front of the buffer (offset 0).
    unpack_dequant(value_cache_low, topk_index, value_scale, value_mn, head_dim, v_dequant, 0)

    # flash_attn_func expects (batch, seqlen, nheads, headdim).
    attn_output = flash_attn_func(
                query_states,
                k_recover.transpose(1, 2),
                v_dequant.transpose(1, 2),
                causal=True,
            )
    return attn_output
