
from transformers.models.llama.modeling_llama import repeat_kv
import torch
from .lsh_utils import indexing
from .select_utils import standard_dis_index


def knn(self, query_states, key_states, value_states, use_repeat_kv=True):
    """Shrink the key/value states to sink + top-k selected + recent entries.

    The key/value tensors are indexed as 4-D and split along dim 2 into three
    regions: the first ``self.sink_size`` positions, the last
    ``self.recent_size`` positions, and everything in between. Only the
    middle region is passed to ``standard_dis_index`` together with
    ``query_states`` and ``self.topk``; the positions it returns are gathered
    from both keys and values and placed between the sink and recent slices.

    NOTE(review): the commented shape hint in the original suggests
    ``key_states`` is laid out as (b, h, n, d) -- confirm against the caller.

    Args:
        query_states: Query tensor forwarded unchanged to
            ``standard_dis_index``.
        key_states: Key tensor, sliced along dim 2.
        value_states: Value tensor, sliced along dim 2 with the same indices
            as ``key_states``.
        use_repeat_kv: When true, both key and value states are first passed
            through ``repeat_kv`` with ``self.num_key_value_groups``.

    Returns:
        A ``(key_states, value_states)`` tuple, each the concatenation along
        dim -2 of the sink slice, the selected entries, and the recent slice,
        in that order.
    """
    if use_repeat_kv:
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

    sink_size = self.sink_size
    seq_len = key_states.shape[-2]
    # Address the recent window with an explicit non-negative start rather
    # than a negative index: when recent_size == 0, ``-0`` is just ``0``, so
    # ``[sink_size:-0]`` would be empty and ``[-0:]`` would be the whole
    # sequence. Clamping at 0 keeps the original behaviour when recent_size
    # exceeds the sequence length.
    recent_start = max(seq_len - self.recent_size, 0)

    # Only the middle region (between sink and recent) is searched.
    _, indices = standard_dis_index(
        key_states[:, :, sink_size:recent_start, :],
        query_states,
        self.topk,
        pool=True,
    )

    # The search ran on a slice that starts at sink_size, so shift the
    # returned positions back to offsets into the full tensors.
    # Presumably ``indexing`` gathers along the sequence axis -- TODO confirm.
    indices = indices + sink_size
    key_sub_states = indexing(key_states, indices)
    value_sub_states = indexing(value_states, indices)

    key_states = torch.cat(
        (
            key_states[:, :, :sink_size, :],
            key_sub_states,
            key_states[:, :, recent_start:, :],
        ),
        dim=-2,
    )
    value_states = torch.cat(
        (
            value_states[:, :, :sink_size, :],
            value_sub_states,
            value_states[:, :, recent_start:, :],
        ),
        dim=-2,
    )

    return key_states, value_states
