import torch
import torch.nn.functional as F
import math
from transformers.models.llama.modeling_llama import repeat_kv

# def standard_dis_index(self, data, queries, k, norm=1, pool=False, kernel_size=5, sum_over_heads=False):
#     inner_product = torch.matmul(queries, data.transpose(-1, -2))
#     inner_product = inner_product[:, :, 0, :]
#     if self.probe_context is not None:
#         norm_scores = F.softmax(inner_product, -1)
#         self.attention_socres = norm_scores[:, :,self.probe_context[0]:self.probe_context[0]+self.probe_context[1]].sum(dim=-1).squeeze()
#     if sum_over_heads:
#         inner_product = torch.sum(inner_product, dim=1, keepdim=True)
#     if pool:
#         inner_product = F.avg_pool1d(
#             inner_product, kernel_size=kernel_size, padding=kernel_size//2, stride=1)
#     top_k = torch.topk(inner_product, k, dim=-1)
#     indices = top_k.indices
#     distances = top_k.values
#     if norm != 1:
#         distances = distances / norm
#     return distances, indices


# def find_context(self, query_states, key_states, print_idx_dis=False):
#     b, h, n, d = key_states.shape
#     # if self.indecies is None and self.layer_idx == self.select_layer_idx:
#     if True:
#         assert b == 1
#         key_states_repeat = repeat_kv(
#             key_states, self.num_key_value_groups)
#         query_last_states = query_states[:, :, -1:, :]
#         _, indices = standard_dis_index(self, key_states_repeat, query_last_states, min(
#             self.topk, n), pool=True, sum_over_heads=True)
#         self.indecies = indices
#         if print_idx_dis:
#             print(self.layer_idx, torch.min(torch.abs(indices-62383)))
#     return

def new_standard_dis_index(self, data, queries, k, pool, norm=1, kernel_size=7, sum_over_heads=False, window_size=1):
    """Score every key by query-averaged softmax attention and pick the top ones.

    Args:
        self: Object providing ``config`` (a mapping, or a falsy value) and
            ``probe_context`` (``None`` or a ``(start, length)`` pair).  The
            attributes ``all_attention_scores`` and, when probing,
            ``attention_socres`` are written on it.
        data: Key tensor.  Must be 4-D; presumably laid out as
            (batch, heads, num_keys, head_dim) -- confirm against the caller.
        queries: Query tensor with the same leading axes and head_dim as ``data``.
        k: Selection budget; ``k - window_size`` positions are returned.
        pool: ``'avg_pool'`` or ``'max_pool'`` to smooth the scores along the
            key axis; any other value disables pooling.
        norm: Divisor applied to the returned scores when it is not 1.
        kernel_size: Pooling kernel width.
        sum_over_heads: When true, collapse the head axis by summation (only
            over ``config['heads']`` when that entry is set).
        window_size: Number of slots subtracted from ``k``.

    Returns:
        Tuple ``(distances, indices, inner_product)``: the top scores, their
        positions along the key axis, and the (possibly head-summed and
        pooled) score tensor they were taken from.
    """
    num_keys = data.shape[-2]
    scores = torch.matmul(queries, data.transpose(-1, -2)) / math.sqrt(queries.shape[-1])
    scores = F.softmax(scores, dim=-1)
    # Average over the query axis, then drop that (now singleton) axis.
    inner_product = torch.mean(scores, dim=-2, keepdim=True)[:, :, 0, :]

    if self.probe_context is not None:
        start, length = self.probe_context[0], self.probe_context[1]
        # Attention mass each head puts on the probed span.
        self.attention_socres = inner_product[:, :, start:start + length].sum(dim=-1).squeeze()
    self.all_attention_scores = inner_product

    if sum_over_heads:
        # Guard the lookup so a missing config falls back to "all heads",
        # matching standard_dis_index.
        if self.config and self.config['heads']:
            head_ids = self.config['heads']
            inner_product = torch.sum(inner_product[:, head_ids, :], dim=1, keepdim=True)
        else:
            inner_product = torch.sum(inner_product, dim=1, keepdim=True)

    if pool == 'avg_pool':
        inner_product = F.avg_pool1d(inner_product, kernel_size=kernel_size, padding=kernel_size // 2, stride=1)
    elif pool == 'max_pool':
        inner_product = F.max_pool1d(inner_product, kernel_size=kernel_size, padding=kernel_size // 2, stride=1)
    # With an even kernel the symmetric padding yields one extra trailing
    # element; trim it so every index refers to a real key position.
    inner_product = inner_product[..., :num_keys]

    top_k = torch.topk(inner_product, k - window_size, dim=-1)
    indices = top_k.indices
    distances = top_k.values
    if norm != 1:
        distances = distances / norm
    return distances, indices, inner_product


def new_find_context(self, query_states, key_states, print_idx_dis=False):
    """Pick context positions from the keys that precede the trailing window.

    The last ``window_size`` queries score all keys except the last
    ``window_size`` ones; the chosen positions are stored in ``self.indecies``
    and the score tensor in ``self.pooled_attention``.

    Args:
        self: Object providing ``config`` (mapping or falsy), ``topk``,
            ``num_key_value_groups`` and ``layer_idx``.
        query_states: Query tensor, sliced on its third axis.
        key_states: 4-D key tensor whose first axis must have size 1.
        print_idx_dis: When true, print the layer index and the smallest
            distance between any selected index and the hard-coded debug
            position 62383.

    Returns:
        None.  Results are stored on ``self``.
    """
    window_size = self.config['window_size'] if self.config else 16
    kernel_size = self.config['kernel_size'] if self.config else 16
    # Same fallback as find_context, instead of subscripting a missing config.
    pool = self.config['pool'] if self.config else 'avg_pool'
    b, h, n, d = key_states.shape
    assert b == 1
    key_states_repeat = repeat_kv(
        key_states, self.num_key_value_groups)
    select_key_states = key_states_repeat[:, :, :-window_size, :]
    select_query_states = query_states[:, :, -window_size:, :]
    # The window keys are already removed above, so use the scorer that does
    # not mask/strip a window itself.  standard_dis_index would strip a second
    # window and could request more entries than remain.
    _, indices, inner_product = new_standard_dis_index(
        self, select_key_states, select_query_states, min(self.topk, n),
        pool=pool, kernel_size=kernel_size, sum_over_heads=True,
        window_size=window_size)
    self.indecies = indices
    self.pooled_attention = inner_product.squeeze()
    if print_idx_dis:
        print(self.layer_idx, torch.min(torch.abs(indices - 62383)))
    return

def standard_dis_index(self, data, queries, k, pool, norm=1, kernel_size=7, sum_over_heads=False, window_size=1):
    """Score the keys before the trailing window and pick the top ones.

    The trailing ``window_size`` x ``window_size`` block of the score matrix
    gets a causal mask; after the softmax the window's own key columns are
    dropped and the remaining columns are averaged over the queries.

    Args:
        self: Object providing ``config`` (a mapping, or a falsy value) and
            ``probe_context`` (``None`` or a ``(start, length)`` pair).  The
            attributes ``all_attention_scores`` and, when probing,
            ``attention_socres`` are written on it.
        data: Key tensor.  Must be 4-D; presumably laid out as
            (batch, heads, num_keys, head_dim) -- confirm against the caller.
        queries: Query tensor; needs at least ``window_size`` rows on its
            third axis for the mask to fit.
        k: Selection budget; ``k - window_size`` positions are returned.
        pool: ``'avg_pool'`` or ``'max_pool'`` to smooth the scores along the
            key axis; any other value disables pooling.
        norm: Divisor applied to the returned scores when it is not 1.
        kernel_size: Pooling kernel width.
        sum_over_heads: When true, collapse the head axis by summation (only
            over ``config['heads']`` when that entry is set).
        window_size: Kept for interface compatibility.  The value actually
            used is ``config['window_size']``, or 16 when config is falsy.

    Returns:
        Tuple ``(distances, indices, inner_product)``: the top scores, their
        positions along the key axis, and the (possibly head-summed and
        pooled) score tensor they were taken from.
    """
    # The configured window always wins over the argument.
    window_size = self.config['window_size'] if self.config else 16
    attn_weights = torch.matmul(queries, data.transpose(-1, -2)) / math.sqrt(queries.shape[-1])

    # Causal mask for the window block, built in the score dtype so the fill
    # value is always representable.
    mask = torch.full(
        (window_size, window_size),
        torch.finfo(attn_weights.dtype).min,
        dtype=attn_weights.dtype,
        device=attn_weights.device,
    )
    mask_cond = torch.arange(window_size, device=attn_weights.device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(window_size, 1), 0)
    attn_weights[:, :, -window_size:, -window_size:] += mask[None, None, :, :]

    inner_product = F.softmax(attn_weights, dim=-1)
    # Keep only the columns before the window, averaged over the window queries.
    inner_product = inner_product[:, :, -window_size:, :-window_size].mean(dim=-2)
    num_keys = inner_product.shape[-1]

    if self.probe_context is not None:
        start, length = self.probe_context[0], self.probe_context[1]
        # Attention mass each head puts on the probed span.
        self.attention_socres = inner_product[:, :, start:start + length].sum(dim=-1).squeeze()
    self.all_attention_scores = inner_product

    if sum_over_heads:
        if self.config and self.config['heads']:
            head_ids = self.config['heads']
            inner_product = torch.sum(inner_product[:, head_ids, :], dim=1, keepdim=True)
        else:
            inner_product = torch.sum(inner_product, dim=1, keepdim=True)

    if pool == 'avg_pool':
        inner_product = F.avg_pool1d(inner_product, kernel_size=kernel_size, padding=kernel_size // 2, stride=1)
    elif pool == 'max_pool':
        inner_product = F.max_pool1d(inner_product, kernel_size=kernel_size, padding=kernel_size // 2, stride=1)
    # With an even kernel the symmetric padding yields one extra trailing
    # element; trim it so every index refers to a real key position.
    inner_product = inner_product[..., :num_keys]

    # Select over the whole tensor.  The window columns were removed above,
    # and slicing here would cut the leading (batch) axis instead.
    top_k = torch.topk(inner_product, k - window_size, dim=-1)
    indices = top_k.indices
    distances = top_k.values
    if norm != 1:
        distances = distances / norm
    return distances, indices, inner_product


def find_context(self, query_states, key_states, print_idx_dis=False):
    """Choose context positions for this layer and cache them on ``self``.

    The last ``window_size`` queries are scored against all keys via
    ``standard_dis_index``; the chosen positions go to ``self.indecies`` and
    the score tensor (squeezed) to ``self.pooled_attention``.

    Args:
        self: Object providing ``config`` (mapping or falsy), ``topk``,
            ``num_key_value_groups`` and ``layer_idx``.
        query_states: Query tensor, sliced on its third axis.
        key_states: 4-D key tensor whose first axis must have size 1.
        print_idx_dis: When true, print the layer index and the smallest
            distance between any selected index and the hard-coded debug
            position 62383.

    Returns:
        None.  Results are stored on ``self``.
    """
    cfg = self.config
    if cfg:
        window_size = cfg['window_size']
        kernel_size = cfg['kernel_size']
        pool = cfg['pool']
    else:
        # Defaults used when no configuration is attached.
        window_size, kernel_size, pool = 16, 16, 'avg_pool'

    batch, _, seq_len, _ = key_states.shape
    assert batch == 1

    # Presumably expands the KV heads to match the query heads -- see
    # transformers' repeat_kv.
    expanded_keys = repeat_kv(key_states, self.num_key_value_groups)
    recent_queries = query_states[:, :, -window_size:, :]
    budget = min(self.topk, seq_len)

    _scores, selected, pooled = standard_dis_index(
        self,
        expanded_keys,
        recent_queries,
        budget,
        pool=pool,
        kernel_size=kernel_size,
        sum_over_heads=True,
        window_size=window_size,
    )
    self.indecies = selected
    self.pooled_attention = pooled.squeeze()

    if print_idx_dis:
        print(self.layer_idx, torch.min(torch.abs(selected - 62383)))
    return None
