import torch
import torch.nn.functional as F
from torch import nn
import math
from transformers.models.llama.modeling_llama import rotate_half

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head ``n_rep`` times along the head axis.

    Produces the same values as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    an input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim), with the copies of one head
    stored next to each other. When ``n_rep == 1`` the input tensor itself is returned.
    """
    # Unpacking up front also enforces that the input is 4-D, even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a singleton axis right after the head axis, broadcast it to n_rep,
    # then fold it back into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)

def repeat_weight(weight, n_rep):
    """
    Duplicate every head of a per-head weight ``n_rep`` times along the leading axis.

    A weight of shape (num_key_value_heads, head_dim, sparse_dim) becomes
    (num_key_value_heads * n_rep, head_dim, sparse_dim), with the copies of one head
    stored next to each other. When ``n_rep == 1`` the input itself is returned.
    """
    # Unpacking up front also enforces that the weight is 3-D, even when n_rep == 1.
    kv_heads, in_dim, out_dim = weight.shape
    if n_rep == 1:
        return weight
    tiled = weight.unsqueeze(1).expand(kv_heads, n_rep, in_dim, out_dim)
    return tiled.reshape(kv_heads * n_rep, in_dim, out_dim)

class LlamaRope(nn.Module):
    """Parameter-free module that applies rotary position embedding to a single tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, x, cos, sin, position_ids, unsqueeze_dim=1):
        """
        Rotate ``x`` with the cos/sin rows selected by ``position_ids``.

        ``cos`` and ``sin`` are indexed with ``position_ids`` and get a singleton axis
        inserted at ``unsqueeze_dim`` so that they broadcast against ``x``
        (presumably over its head axis -- confirm against the caller).

        Returns ``x * cos + rotate_half(x) * sin``.
        """
        cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
        sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)
        return (x * cos_sel) + (rotate_half(x) * sin_sel)

def generate_sparse_mask_label(lowrank_q, lowrank_k, kv_spasity=16):
    """
    Build a boolean drop-mask over key positions for a single decoding step.

    Scaled dot-product scores between the query and every key are soft-maxed over
    the key axis and averaged over heads; the ``ceil(K / kv_spasity)`` best-scoring
    key positions are kept and all others are marked for dropping.

    Args:
        lowrank_q: query tensor of shape (B, Nq, 1, D).
        lowrank_k: key tensor of shape (B, Nk, K, D). The keys are not repeated
            here, so Nk must equal Nq or broadcast against it.
        kv_spasity: keep roughly one out of this many key positions.

    Returns:
        Bool tensor of shape (B, K): ``True`` marks a position to drop, ``False`` a
        position to keep. Positions tied with the k-th best score are all kept.
    """
    # BNSD layout
    B, Nq, S, D = lowrank_q.shape
    B, Nk, K, D = lowrank_k.shape
    assert S == 1
    kv_budget = (K + kv_spasity - 1) // kv_spasity
    if K < kv_budget:
        # Fewer keys than the budget: keep everything.
        return torch.zeros(B, K, dtype=torch.bool, device=lowrank_q.device)
    scores = torch.matmul(lowrank_q, lowrank_k.transpose(2, 3))
    scores = scores.float() / math.sqrt(D)
    scores = torch.nn.functional.softmax(scores, dim=-1).mean(1).view(B, K)
    # (B, 1): the smallest score that still belongs to the top-kv_budget set.
    kth_best = scores.topk(kv_budget, dim=-1).values[..., -1:]
    # Compare with >= so the k-th best key itself is kept; a strict > would keep only
    # kv_budget - 1 keys and would mask every key when kv_budget == 1.
    return ~(scores >= kth_best)

def generate_sparse_mask_label_single_head(lowrank_q, lowrank_k, kv_spasity=16):
    """
    Build a boolean drop-mask over key positions from a head-averaged query.

    The query is averaged over its head axis and scored against a head-less key
    tensor; the ``ceil(K / kv_spasity)`` best-scoring key positions are kept.

    Args:
        lowrank_q: query tensor of shape (B, Nq, 1, D); averaged over axis 1.
        lowrank_k: key tensor of shape (B, K, D).
        kv_spasity: keep roughly one out of this many key positions.

    Returns:
        Bool tensor of shape (B, K): ``True`` marks a position to drop, ``False`` a
        position to keep. Positions tied with the k-th best score are all kept.
    """
    # (B, Nq, S, D) -> (B, S, D)
    lowrank_q = lowrank_q.mean(1)

    B, S, D = lowrank_q.shape
    B, K, D = lowrank_k.shape
    assert S == 1
    kv_budget = (K + kv_spasity - 1) // kv_spasity
    if K < kv_budget:
        # Fewer keys than the budget: keep everything.
        return torch.zeros(B, K, dtype=torch.bool, device=lowrank_q.device)
    # Both operands are 3-D here, so the key axes to swap are 1 and 2
    # (transpose(2, 3) is out of range for a 3-D tensor).
    scores = torch.matmul(lowrank_q, lowrank_k.transpose(1, 2))
    scores = scores.float() / math.sqrt(D)
    scores = torch.nn.functional.softmax(scores, dim=-1).view(B, K)
    # (B, 1): the smallest score that still belongs to the top-kv_budget set.
    kth_best = scores.topk(kv_budget, dim=-1).values[..., -1:]
    # Compare with >= so the k-th best key itself is kept; a strict > would keep only
    # kv_budget - 1 keys and would mask every key when kv_budget == 1.
    return ~(scores >= kth_best)


class sparse_project(nn.Module):
    """
    Per-head linear projection from ``head_dim`` features to ``sparse_rank`` features.

    One (head_dim, sparse_rank) matrix is stored per head. In ``forward`` the stored
    heads are repeated when the input carries a multiple of ``num_head`` heads.
    """

    def __init__(self, num_head, head_dim, sparse_rank, bias=False):
        super().__init__()
        self.num_heads = num_head
        self.head_dim = head_dim
        self.sparse_rank = sparse_rank
        self.hidden_size = self.num_heads * self.head_dim
        # Allocated uninitialised: reset_parameters() is not invoked here, so real values
        # must come from an explicit reset_parameters() call or from loaded weights.
        self.weight = nn.Parameter(torch.empty(self.num_heads, self.head_dim, self.sparse_rank))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.num_heads, self.head_dim))
        else:
            self.register_parameter('bias', None)

    def reset_parameters(self):
        """Xavier-uniform initialise the weight and zero the bias when one exists."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        """
        Project ``x`` of shape (bsz, num_head, seq_len, head_dim) head by head.

        The stored weight is repeated ``num_head // self.num_heads`` times along its
        head axis (copies of one head adjacent) and multiplied from the right, so
        the last axis of the result has ``sparse_rank`` features.
        """
        # The 4-way unpack also enforces a 4-D input.
        _, input_heads, _, _ = x.shape
        repeats = input_heads // self.num_heads
        # TODO: the bias parameter is not applied yet.
        # NOTE: an earlier variant repeated the input heads with repeat_kv (using
        # self.num_heads // num_head) and multiplied by the unrepeated weight.
        return torch.matmul(x, repeat_weight(self.weight, repeats))
    
def generate_sparse_indices_label(lowrank_q, lowrank_k, kv_sparsity=16, withsoftmax=True):
    """
    Pick the key positions with the highest head-averaged attention score.

    Args:
        lowrank_q: query tensor of shape (B, Nq, 1, D).
        lowrank_k: key tensor of shape (B, Nk, K, D); repeated to Nq heads with repeat_kv.
        kv_sparsity: 0 keeps all K positions; values up to 32 are a ratio (the budget
            is ceil(K / kv_sparsity)); larger values are used directly as the budget.
        withsoftmax: average soft-maxed scores over heads when true, raw scaled
            scores otherwise.

    Returns:
        ``(indices, scores)``. ``indices`` holds the selected positions of batch
        element 0 only, in ascending order; ``scores`` is the (B, K) head-averaged
        score tensor. If K is smaller than the budget, returns ``(arange(K), None)``.
    """
    # Axes: batch, heads, query length (must be 1), feature dimension.
    _, Nq, S, _ = lowrank_q.shape
    B, Nk, K, D = lowrank_k.shape
    assert S == 1, f"low rank q shape:{lowrank_q.shape}, low_rank_k shape:{lowrank_k.shape}"

    # Work out how many key positions may be kept.
    if kv_sparsity <= 32:
        kv_budget = K if kv_sparsity == 0 else (K + kv_sparsity - 1) // kv_sparsity
    else:
        kv_budget = kv_sparsity
    if K < kv_budget:
        # Fewer keys than the budget: nothing to select, keep every position.
        return torch.arange(0, K, device=lowrank_q.device), None

    # Query/key similarity, scaled by the square root of the feature dimension.
    keys = repeat_kv(lowrank_k, Nq // Nk)
    logits = torch.matmul(lowrank_q, keys.transpose(2, 3)).float() / math.sqrt(D)
    if withsoftmax:
        logits = torch.nn.functional.softmax(logits, dim=-1)
    scores = logits.mean(1).view(B, K)

    # Best-scoring kv_budget positions per batch element.
    _, topk_indices = scores.topk(kv_budget, dim=-1)

    # Sorted indices of the first batch element, plus the full score tensor.
    return topk_indices.sort()[0][0], scores

def generate_sparse_indices_label_single_head(lowrank_q, lowrank_k, kv_sparsity=16, withsoftmax=True):
    """
    Pick the key positions with the highest raw scaled score for a single head.

    Args:
        lowrank_q: query tensor of shape (B, Nq, 1, D).
        lowrank_k: key tensor of shape (B, Nk, K, D). The keys are not repeated and
            the scores are viewed as (B, K), so this only works with one head.
        kv_sparsity: 0 keeps all K positions; values up to 32 are a ratio (the budget
            is ceil(K / kv_sparsity)); larger values are used directly as the budget.
        withsoftmax: accepted for signature parity with the sibling functions; unused.

    Returns:
        ``(indices, scores)``. ``indices`` holds the selected positions of batch
        element 0 only, in ascending order; ``scores`` is the (B, K) scaled score
        tensor. If K is smaller than the budget, returns ``(arange(K), None)``.
    """
    # Axes: batch, heads, query length (must be 1), feature dimension.
    _, Nq, S, _ = lowrank_q.shape
    B, Nk, K, D = lowrank_k.shape
    assert S == 1, f"low rank q shape:{lowrank_q.shape}, low_rank_k shape:{lowrank_k.shape}"

    # Work out how many key positions may be kept.
    if kv_sparsity <= 32:
        kv_budget = K if kv_sparsity == 0 else (K + kv_sparsity - 1) // kv_sparsity
    else:
        kv_budget = kv_sparsity
    if K < kv_budget:
        # Fewer keys than the budget: nothing to select, keep every position.
        return torch.arange(0, K, device=lowrank_q.device), None

    # Query/key similarity, scaled by the square root of the feature dimension.
    logits = torch.matmul(lowrank_q, lowrank_k.transpose(2, 3))
    scores = (logits.float() / math.sqrt(D)).view(B, K)

    # Best-scoring kv_budget positions per batch element.
    _, topk_indices = scores.topk(kv_budget, dim=-1)

    # Sorted indices of the first batch element, plus the full score tensor.
    return topk_indices.sort()[0][0], scores

def generate_sparse_indices_label_per_head(lowrank_q, lowrank_k, kv_sparsity=16, withsoftmax=True):
    """
    Pick, separately for every query head, the key positions with the highest score.

    Each query head is scored against the key/value head it shares under the
    repeat_kv layout (query head ``h`` uses key head ``h // (Nq // Nk)``).

    Args:
        lowrank_q: query tensor of shape (B, Nq, 1, D).
        lowrank_k: key tensor of shape (B, Nk, K, D), with Nq a multiple of Nk.
        kv_sparsity: 0 keeps all K positions; values up to 32 are a ratio (the budget
            is ceil(K / kv_sparsity)); larger values are used directly as the budget.
        withsoftmax: accepted for signature parity with the sibling functions; unused.

    Returns:
        ``(indices, scores)``. ``indices`` has shape (Nq, budget) and holds the
        selected positions of batch element 0 only, ascending per head; ``scores``
        is the (B, Nq, K) scaled score tensor. If K is smaller than the budget,
        returns ``(arange(K), None)``.
    """
    # Axes: batch, heads, query length (must be 1), feature dimension.
    B, Nq, S, D = lowrank_q.shape
    B, Nk, K, D = lowrank_k.shape
    assert S == 1, f"low rank q shape:{lowrank_q.shape}, low_rank_k shape:{lowrank_k.shape}"

    # Work out how many key positions may be kept.
    if kv_sparsity <= 32:
        if kv_sparsity == 0:
            kv_budget = K
        else:
            kv_budget = (K + kv_sparsity - 1) // kv_sparsity
    else:
        kv_budget = kv_sparsity
    if K < kv_budget:
        # Fewer keys than the budget: nothing to select, keep every position.
        return torch.arange(0, K, device=lowrank_q.device), None

    # Group the query heads by the key head they share instead of materialising
    # repeated keys: (B, Nq, 1, D) -> (B, Nk, n_rep, D). The grouping is KV-major,
    # i.e. the same pairing that repeating each key head n_rep times would give.
    n_rep = Nq // Nk
    q_grouped = lowrank_q.reshape(B, Nk, n_rep, D)
    scores = torch.matmul(q_grouped, lowrank_k.transpose(2, 3))  # (B, Nk, n_rep, K)
    scores = scores.float() / math.sqrt(D)

    # One score row per query head. Viewing as (B, Nq // Nk, K) would only be
    # valid when Nk == 1, so the full head count is used.
    scores = scores.reshape(B, Nq, K)

    # Best-scoring kv_budget positions for every head.
    topk_values, topk_indices = scores.topk(kv_budget, dim=-1)

    # Sorted per-head indices of the first batch element, plus the score tensor.
    return topk_indices.sort()[0][0], scores

def generate_sparse_indices_label_withrope(lowrank_q, lowrank_k, kv_sparsity=16):
    """
    Pick the key positions with the highest soft-maxed score for a single head.

    Args:
        lowrank_q: query tensor of shape (B, Nq, 1, D).
        lowrank_k: key tensor of shape (B, Nk, K, D). The scores are viewed as
            (B, K) without any head reduction, so this only works with one head.
        kv_sparsity: 0 keeps all K positions; otherwise the budget is
            ceil(K / kv_sparsity).

    Returns:
        Tensor of shape (B, budget) with the selected positions in ascending order,
        or ``None`` if K is smaller than the budget.
    """
    # Axes: batch, heads, query length (must be 1), feature dimension.
    _, Nq, S, _ = lowrank_q.shape
    B, Nk, K, D = lowrank_k.shape
    assert S == 1, f"low rank q shape:{lowrank_q.shape}, low_rank_k shape:{lowrank_k.shape}"

    # Work out how many key positions may be kept.
    kv_budget = K if kv_sparsity == 0 else (K + kv_sparsity - 1) // kv_sparsity
    if K < kv_budget:
        # Fewer keys than the budget: nothing to select.
        return None

    # Query/key similarity, scaled by the square root of the feature dimension.
    logits = torch.matmul(lowrank_q, lowrank_k.transpose(2, 3)).float() / math.sqrt(D)
    scores = torch.nn.functional.softmax(logits, dim=-1).view(B, K)

    # Best-scoring kv_budget positions per batch element, returned in ascending order.
    _, topk_indices = scores.topk(kv_budget, dim=-1)
    return topk_indices.sort()[0]

def generate_sparse_indices_label_tp_head(lowrank_q, lowrank_k, kv_sparsity=16, bias=None, withsoftmax=True):
    """
    Pick the best-scoring key positions per key/value head slot.

    The scores come from ``origin_attention`` (called without a mask), which reduces
    the query heads down to Nk score rows per batch element.

    Args:
        lowrank_q: query tensor of shape (B, Nq, 1, D).
        lowrank_k: key tensor of shape (B, Nk, K, D).
        kv_sparsity: 0 keeps all K positions; otherwise the budget is
            ceil(K / kv_sparsity).
        bias: optional additive bias forwarded to ``origin_attention``.
        withsoftmax: forwarded to ``origin_attention``.

    Returns:
        Tensor of shape (Nk, budget) with the selected positions of batch element 0
        only, ascending per row, or ``None`` if K is smaller than the budget.
    """
    # Axes: batch, heads, query length (must be 1), feature dimension.
    _, Nq, S, _ = lowrank_q.shape
    B, Nk, K, D = lowrank_k.shape
    assert S == 1, f"low rank q shape:{lowrank_q.shape}, low_rank_k shape:{lowrank_k.shape}"

    # Work out how many key positions may be kept.
    kv_budget = K if kv_sparsity == 0 else (K + kv_sparsity - 1) // kv_sparsity
    if K < kv_budget:
        # Fewer keys than the budget: nothing to select.
        return None

    # Scoring is delegated to origin_attention; an earlier inline version of the
    # same computation (with a hard-coded Nk = 8) was replaced by this call.
    head_scores = origin_attention(lowrank_q, lowrank_k, None, bias, withsoftmax).view(B, Nk, K)

    # Best-scoring kv_budget positions per row; sorted indices of the first batch element.
    _, topk_indices = head_scores.topk(kv_budget, dim=-1)
    return topk_indices.sort()[0][0]
def origin_attention(q, k, mask, bias=None, withsoftmax=True):
    """
    Compute scaled attention scores and average them down to Nk rows.

    Args:
        q: query tensor of shape (B, Nq, S, D).
        k: key tensor of shape (B, Nk, K, D); repeated to Nq heads with repeat_kv.
        mask: optional additive mask, added after the bias.
        bias: optional additive bias, added in place to the scaled scores.
        withsoftmax: soft-max over the key axis before averaging when true.

    Returns:
        Tensor of shape (B, Nk, S, K): the scores viewed as (B, Nq // Nk, Nk, S, K)
        and averaged over axis 1.
    """
    _, Nq, S, _ = q.shape
    B, Nk, K, D = k.shape
    keys = repeat_kv(k, Nq // Nk)
    logits = torch.matmul(q, keys.transpose(2, 3)) / math.sqrt(D)
    if bias is not None:
        # In-place add: the result keeps the dtype and shape of the score tensor.
        logits += bias
    if mask is not None:
        logits = logits + mask
    # NOTE(review): repeat_kv stores the copies of one key head next to each other
    # (query head h uses key head h // n_rep), whereas this view splits the head axis
    # as (n_rep, Nk) and therefore averages heads h, h + Nk, h + 2 * Nk, ... together.
    # An alternative view(B, Nk, -1, S, K).transpose(1, 2) was left disabled here;
    # confirm which grouping is intended.
    grouped = logits.view(B, -1, Nk, S, K)
    if withsoftmax:
        grouped = torch.nn.functional.softmax(grouped, dim=-1)
    return grouped.mean(1)

def overlap_ratio(tensor1, tensor2):
    """
    Average, over the leading axis, the share of ``tensor1`` entries found in ``tensor2``.

    For every row ``i`` both rows are turned into sets and
    ``len(set1 & set2) / len(set1)`` is computed; the mean over all rows of
    ``tensor1`` is returned as a Python float.
    """
    head_count = tensor1.shape[0]
    total = 0
    for head in range(head_count):
        reference = set(tensor1[head].tolist())
        candidate = set(tensor2[head].tolist())
        total += len(reference.intersection(candidate)) / len(reference)
    return total / head_count


def overlap_ratio_by_tensor(tensor1, tensor2, kv_sparsity=16):
    """
    Compare the probability mass two score tensors put on their shared top-k keys.

    Both inputs are squeezed on axis 0 and soft-maxed over the last axis. For each
    head the top ``ceil(kv_len / kv_sparsity)`` positions of both tensors are
    intersected, and the ratio ``tensor1 mass / tensor2 mass`` on the shared
    positions is computed. The mean ratio over heads is returned.

    Args:
        tensor1: score tensor, (num_heads, kv_len) after ``squeeze(0)``.
        tensor2: score tensor of the same layout.
        kv_sparsity: the top-k size is ceil(kv_len / kv_sparsity).

    Returns:
        0-d tensor holding the head-averaged mass ratio.

    Raises:
        ZeroDivisionError: if the two top-k sets of some head are disjoint, since
            the ratio is undefined in that case.
    """
    probs1 = tensor1.squeeze(0).softmax(dim=-1)
    probs2 = tensor2.squeeze(0).softmax(dim=-1)
    num_heads, kv_len = probs1.shape[0], probs1.shape[1]
    kv_budget = (kv_len + kv_sparsity - 1) // kv_sparsity
    _, topk_indices_1 = probs1.topk(kv_budget, dim=-1)
    _, topk_indices_2 = probs2.topk(kv_budget, dim=-1)

    ratio_sum = 0
    for head in range(num_heads):
        shared = set(topk_indices_1[head].tolist()) & set(topk_indices_2[head].tolist())
        if not shared:
            raise ZeroDivisionError(
                f"top-k sets of head {head} do not overlap; mass ratio is undefined"
            )
        # Build the index directly as int64: going through the default float tensor
        # type first can round large positions (notably with a half-precision default).
        shared_idx = torch.tensor(sorted(shared), dtype=torch.int64)
        ratio_sum += probs1[head, shared_idx].sum() / probs2[head, shared_idx].sum()
    return ratio_sum / num_heads
def overlap_ratio_single_head(tensor1, tensor2):
    """
    Return the share of distinct ``tensor1`` entries that also occur in ``tensor2``.

    Both inputs are converted to sets via ``tolist()``; the result is
    ``len(set1 & set2) / len(set1)`` as a Python float.
    """
    reference = set(tensor1.tolist())
    candidate = set(tensor2.tolist())
    return len(reference.intersection(candidate)) / len(reference)

def percent_scores(indices, true_indices, true_scores):
    """
    Return the share of soft-maxed score mass covered by ``indices``.

    Args:
        indices: selected key positions; duplicates are counted once.
        true_indices: reference positions. Accepted for interface compatibility;
            the result does not depend on it.
        true_scores: raw scores, expected as (1, K) or (K,); axis 0 is squeezed and
            the soft-max is taken over the last axis.

    Returns:
        Tensor holding ``sum(softmax(true_scores)[indices]) / sum(softmax(true_scores))``.
    """
    selected = sorted(set(indices.tolist()))
    # The soft-max axis is given explicitly: relying on the implicit axis is
    # deprecated and warns on every call. For 1-D and 2-D inputs the last axis is
    # the one the implicit choice resolved to.
    probs = torch.nn.functional.softmax(true_scores.squeeze(0), dim=-1)
    selected_idx = torch.tensor(selected, dtype=torch.int64)
    # Reduce over axis 0, matching element-wise accumulation over the first axis.
    return probs[selected_idx].sum(dim=0) / probs.sum(dim=0)

def generate_sparse_indices_label_mean(lowrank_q: torch.Tensor, lowrank_k: torch.Tensor, kv_sparsity=16):
    """
    Pick key positions from group-averaged queries, taking the best score over heads.

    The query heads are viewed as (B, Nk, Nq // Nk, S, D) and averaged within each
    group, scored against the matching key head, soft-maxed over the key axis, and
    reduced over the Nk heads with a maximum before the top-k selection.

    Args:
        lowrank_q: query tensor of shape (B, Nq, 1, D).
        lowrank_k: key tensor of shape (B, Nk, K, D).
        kv_sparsity: 0 keeps all K positions; otherwise the budget is
            ceil(K / kv_sparsity).

    Returns:
        Tensor of shape (B, budget) with the selected positions in ascending order,
        or ``None`` if K is smaller than the budget.
    """
    # Axes: batch, heads, query length (must be 1), feature dimension.
    _, Nq, S, _ = lowrank_q.shape
    B, Nk, K, D = lowrank_k.shape
    assert S == 1, f"low rank q shape:{lowrank_q.shape}, low_rank_k shape:{lowrank_k.shape}"

    # Work out how many key positions may be kept.
    kv_budget = K if kv_sparsity == 0 else (K + kv_sparsity - 1) // kv_sparsity
    if K < kv_budget:
        # Fewer keys than the budget: nothing to select.
        return None

    # Average the query heads that belong to the same key head: -> (B, Nk, S, D).
    q_mean = lowrank_q.view(B, Nk, -1, S, D).mean(2)

    # Query/key similarity, scaled by the square root of the feature dimension.
    logits = torch.matmul(q_mean, lowrank_k.transpose(2, 3)).float() / math.sqrt(D)

    # Soft-max per head, then keep the largest probability any head assigns to a key.
    scores = torch.nn.functional.softmax(logits, dim=-1).amax(1).view(B, K)

    # Best-scoring kv_budget positions per batch element, returned in ascending order.
    _, topk_indices = scores.topk(kv_budget, dim=-1)
    return topk_indices.sort()[0]

def generate_sparse_indices_label_all_head(lowrank_q, lowrank_k, kv_sparsity=16):
    """
    Pick the key positions with the highest raw scaled score from head-less inputs.

    Args:
        lowrank_q: query tensor of shape (B, 1, D).
        lowrank_k: key tensor of shape (B, K, D).
        kv_sparsity: 0 keeps all K positions; otherwise the budget is
            ceil(K / kv_sparsity).

    Returns:
        Tensor of shape (B, budget) with the selected positions in ascending order,
        or ``None`` if K is smaller than the budget.
    """
    # Axes: batch, query length (must be 1), feature dimension.
    _, S, _ = lowrank_q.shape
    B, K, D = lowrank_k.shape
    assert S == 1, f"low rank q shape:{lowrank_q.shape}, low_rank_k shape:{lowrank_k.shape}"

    # Work out how many key positions may be kept.
    kv_budget = K if kv_sparsity == 0 else (K + kv_sparsity - 1) // kv_sparsity
    if K < kv_budget:
        # Fewer keys than the budget: nothing to select.
        return None

    # Query/key similarity, scaled by the square root of the feature dimension.
    logits = torch.matmul(lowrank_q, lowrank_k.transpose(1, 2)).float() / math.sqrt(D)

    # Best-scoring kv_budget positions per batch element, returned in ascending order.
    _, topk_indices = logits.view(B, K).topk(kv_budget, dim=-1)
    return topk_indices.sort()[0]

    
# def generate_sparse_indices_label_single_head(lowrank_q, lowrank_k, kv_sparsity=16):
#     # 获取输入张量的维度：Batch, Num queries, S (应为1), Dimension
#     B, Nq, S, D = lowrank_q.shape
#     B, K, D = lowrank_k.shape
#     assert S == 1, f"low rank q shape:{lowrank_q.shape}, low_rank_k shape:{lowrank_k.shape}"
#     # import pdb;pdb.set_trace()
#     lowrank_k = lowrank_k[:, None, :, :].expand(B, Nq, K, D)

#     # 计算每个查询的预算
#     if kv_sparsity == 0:
#         kv_budget = K
#     else:
#         kv_budget = (K + kv_sparsity - 1) // kv_sparsity
#     if K < kv_budget:
#         return None  # 如果K小于预算，则无需处理，返回None或适当的空结构

#     # 计算query和key之间的得分（相似度）
#     scores = torch.matmul(lowrank_q, lowrank_k.transpose(2, 3))
#     scores = scores.float() / math.sqrt(D)  # 缩放因子，通常为特征维度的平方根
#     scores = torch.nn.functional.softmax(scores, dim=-1).mean(1).view(B, K)

#     # 获取每个批次中分数最高的kv_budget个索引
#     topk_values, topk_indices = scores.topk(kv_budget, dim=-1)

#     # 返回稀疏索引
#     return topk_indices.sort()[0]