from typing import Any, Dict, Optional, Tuple
import torch
import os
import torch.nn.functional as F
import torch.nn as nn
import math
from infomap import Infomap
import numpy as np
import pandas as pd
import csv
from accelerate import Accelerator
from scipy import stats
from scipy.spatial.distance import cosine
accelerator = Accelerator()
import sys
pi = math.pi

def get_head_pattern_attn_entropy(attn_weights,key_states,aerfa=0.5,beta=0.65,relative=0,reletive_entropy=0):
    """Select heads by last-query attention spread, then refine by key entropy.

    Each head is scored with the standard deviation of the attention row that
    belongs to the last query position.  Heads passing the ``relative``
    threshold rule are kept; their key states are then scored by the
    spectral-entropy helpers, and ``reletive_entropy`` decides what is done
    with those scores.

    Args:
        attn_weights: Attention tensor.  A 4-D input has its leading axis
            squeezed when that axis has size 1; the result is indexed as
            ``[:, -1, :]`` -- presumably (heads, q_len, kv_len) for a single
            sample, TODO confirm against callers.
        key_states: 4-D tensor whose axis 1 is indexed by the head mask.
        aerfa: Threshold.  Absolute when ``relative == 0``; a fraction of the
            min..max range of the scores when ``relative == 1``.
        beta: Forwarded to the entropy helpers, where it is used as the number
            of leading singular values to keep (a slice bound, so callers
            should pass an int; the 0.65 default is not usable as one).
        relative: 0 = absolute threshold, 1 = min/max-relative threshold,
            2 = keep the 16 highest-scoring heads (fewer if fewer exist).
        reletive_entropy: Indexable configuration (list or dict).  Index 0 is
            the mode:
              * 0 -- keep the ``reletive_entropy[1]`` highest-entropy heads;
              * 1 -- return the entropy values instead of a mask;
              * 2 -- compute entropy through the plotting helper, passing
                indices 1..4 (draw mode, layer index, dataset name and, when
                present, ``my_type``), then keep the 16 highest-entropy heads.

    Returns:
        A 1-D boolean mask over heads, or (mode 1) a 1-D tensor with the
        entropy of each pre-selected head.  When no head passes the first
        threshold the all-False mask is returned immediately.

    Raises:
        ValueError: if the mask is not 1-D, or ``relative`` / the mode is not
            one of the supported values.
        TypeError: if ``reletive_entropy`` is not indexable (e.g. the default
            ``0``) and at least one head passed the first threshold.
    """
    # Drop only the batch axis.  A bare squeeze() would also collapse a
    # size-1 query or head axis and break the indexing below.
    if attn_weights.dim() == 4:
        attn_weights = attn_weights.squeeze(0)
    attn_weights_std = attn_weights[:, -1, :].std(dim=-1)

    if relative == 1:
        attn_max = attn_weights_std.max(dim=-1)[0]
        attn_min = attn_weights_std.min(dim=-1)[0]
        mask = attn_weights_std >= attn_min + (attn_max - attn_min) * aerfa
    elif relative == 2:
        # topk raises when k exceeds the number of heads, so clamp it.
        k = min(16, attn_weights_std.shape[-1])
        mask = attn_weights_std >= attn_weights_std.topk(k, dim=-1, largest=True)[0][-1]
    elif relative == 0:
        mask = attn_weights_std >= aerfa
    else:
        raise ValueError(f"unsupported relative mode: {relative!r}")

    if mask.dim() != 1:
        raise ValueError("mask.dim() != 1")
    true_positions = torch.nonzero(mask, as_tuple=False).squeeze(-1)
    num_selected = true_positions.shape[0]
    if num_selected == 0:
        return mask

    key_new_states = key_states[:, mask, :, :].view(
        key_states.shape[0], num_selected, -1, key_states.shape[-1]
    )

    try:
        mode = reletive_entropy[0]
    except TypeError as err:
        raise TypeError(
            "reletive_entropy must be an indexable config such as [mode, topk]"
        ) from err

    def as_flat_tensor(values):
        # The entropy helpers return NumPy arrays; topk/comparison need a tensor.
        return torch.from_numpy(values).to(key_states.device).reshape(-1)

    if mode == 0:
        entropy = as_flat_tensor(get_entropy_no_group_head_svd(key_new_states, beta))
        k = min(reletive_entropy[1], entropy.shape[0])
        mask_new = entropy >= entropy.topk(k, dim=-1, largest=True)[0][-1]
    elif mode == 1:
        return as_flat_tensor(get_entropy_no_group_head_svd(key_new_states, beta))
    elif mode == 2:
        # my_type only feeds some output paths; not every caller supplies it.
        try:
            my_type = reletive_entropy[4]
        except (KeyError, IndexError):
            my_type = None
        entropy = as_flat_tensor(
            get_entropy_no_group_head_svd_draw_picture(
                key_new_states, beta,
                reletive_entropy[1], reletive_entropy[2], reletive_entropy[3], my_type,
            )
        )
        k = min(16, entropy.shape[0])
        mask_new = entropy >= entropy.topk(k, dim=-1, largest=True)[0][-1]
    else:
        raise ValueError(f"unsupported reletive_entropy mode: {mode!r}")

    # Among the pre-selected heads, keep only those that also pass the entropy cut.
    mask[true_positions] = mask_new.to(mask.device)
    return mask

def normalize(R):
    """Center ``R`` along dim -2 and scale every row along dim -1 to unit L2 norm.

    Args:
        R: Floating-point tensor with at least two dimensions.

    Returns:
        Tensor of the same shape.  Rows that equal the mean (zero vector after
        centering) are returned as zero vectors instead of NaN, so a single
        degenerate row cannot poison downstream reductions.
    """
    centered = R - R.mean(dim=-2, keepdim=True)
    norms = torch.norm(centered, p=2, dim=-1, keepdim=True)
    # Dividing a zero row by 1 keeps it zero; dividing by its 0 norm gives NaN.
    safe_norms = torch.where(norms > 0, norms, torch.ones_like(norms))
    return centered / safe_norms

def cal_cov_no_group(R):
    """Return the averaged Gram matrix of the unit-normalized rows of ``R``.

    ``R`` must be 4-D (the einsum spec names four axes).  Every vector along
    the last axis is scaled to unit L2 norm; for each pair of leading indices
    the outer products of those vectors are summed over axis -2 and divided
    by the size of that axis, giving a square matrix over the last axis.
    """
    unit_rows = F.normalize(R, dim=-1)
    row_count = unit_rows.shape[-2]
    gram = torch.einsum('bhji,bhjk->bhik', unit_rows, unit_rows)
    return gram / row_count

def cal_entropy_no_group_svd(A,topk):
    """Normalized spectral entropy of every square matrix stored in ``A``.

    Each matrix is divided by its trace, a small multiple of the identity is
    added, and the singular values ``s`` of the result are reduced to
    ``-sum(s * log(s))`` over the ``topk`` largest values.  The sum is divided
    by ``log(dim)`` where ``dim`` is the matrix size.

    Args:
        A: Torch tensor of shape (..., dim, dim).  All leading axes are
            flattened, so both a (heads, dim, dim) stack and the
            (batch, heads, dim, dim) output of ``cal_cov_no_group`` work.
        topk: Number of leading singular values that enter the sum.

    Returns:
        1-D float64 NumPy array with one entropy value per matrix, in the
        row-major order of the leading axes.
    """
    dim = A.shape[-1]
    # Cast inside torch: NumPy has no bfloat16, and .numpy() refuses tensors
    # that still require grad.
    A_np = A.detach().cpu().to(torch.float64).numpy().reshape(-1, A.shape[-2], dim)
    # (n,) -> (n, 1, 1) so every matrix is divided by its own trace.
    traces_np = np.trace(A_np, axis1=-2, axis2=-1)[:, np.newaxis, np.newaxis]

    epsilon = 1e-10  # keeps the singular values strictly positive for log()
    eig_val_np = np.linalg.svd(A_np / traces_np + epsilon * np.eye(dim), compute_uv=False)
    leading = eig_val_np[:, :topk]
    entropy_np = -np.nansum(leading * np.log(leading), axis=-1)
    return entropy_np / math.log(dim)

def cal_entropy_no_group_svd_draw_picture(A,svd_n,draw_model,layer_idx,dataname,my_type):
    """Compute normalized spectral entropy and optionally dump plotting data to CSV.

    The entropy is computed as in ``cal_entropy_no_group_svd``: every matrix
    is divided by its trace, a small multiple of the identity is added, and
    the ``svd_n`` largest singular values ``s`` are reduced to
    ``-sum(s * log(s)) / log(dim)``.

    Depending on ``draw_model`` a CSV named ``<dataname>_eigenvalues.csv`` is
    written below ``<cwd>/draw_picture`` (directories are created on demand):

      * ``"draw_svd_trend"`` -- ``svdn_trend/llama2-13B/<my_type>/csv/``:
        one row of singular values per matrix.
      * ``"draw_thermodynamic_chart"`` --
        ``thermodynamic_chart/llama2-13B/<my_type>/csv/``: one row holding
        the entropy of every matrix; the file is truncated when
        ``layer_idx`` is 0 (or negative) and appended to otherwise.
      * ``"draw_accumulated_energy"`` -- ``accumulated_energy/``: only when
        ``layer_idx == 0``, one row per matrix with the cumulative share of
        squared singular values.

    Any other ``draw_model`` writes nothing.

    Args:
        A: Torch tensor of shape (..., dim, dim); leading axes are flattened.
        svd_n: Number of leading singular values used for the entropy.
        draw_model: Output mode, see above.
        layer_idx: Layer index controlling truncate/append and the
            accumulated-energy dump.
        dataname: Dataset name used as the file-name prefix.
        my_type: Sub-directory label for the first two modes.

    Returns:
        1-D float64 NumPy array with one normalized entropy per matrix.
    """
    dim = A.shape[-1]
    # Cast inside torch: NumPy has no bfloat16, and .numpy() refuses tensors
    # that still require grad.
    A_np = A.detach().cpu().to(torch.float64).numpy().reshape(-1, A.shape[-2], dim)
    # (n,) -> (n, 1, 1) so every matrix is divided by its own trace.
    traces_np = np.trace(A_np, axis1=-2, axis2=-1)[:, np.newaxis, np.newaxis]

    epsilon = 1e-10  # keeps the singular values strictly positive for log()
    eig_val_np = np.linalg.svd(A_np / traces_np + epsilon * np.eye(dim), compute_uv=False)
    leading = eig_val_np[:, :svd_n]
    normalized_entropy = -np.nansum(leading * np.log(leading), axis=-1) / math.log(dim)

    out_root = os.path.join(os.getcwd(), 'draw_picture')
    csv_name = dataname + "_eigenvalues.csv"

    if draw_model == "draw_svd_trend":
        out_dir = os.path.join(out_root, 'svdn_trend', 'llama2-13B', str(my_type), 'csv')
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, csv_name), 'w', newline='') as csvfile:
            csv_writer = csv.writer(csvfile)
            # One row per matrix, however many there are.
            for singular_values in eig_val_np:
                csv_writer.writerow([f"{x:.20f}" for x in singular_values])
    elif draw_model == "draw_thermodynamic_chart":
        out_dir = os.path.join(out_root, 'thermodynamic_chart', 'llama2-13B', str(my_type), 'csv')
        os.makedirs(out_dir, exist_ok=True)
        # Layer 0 starts a fresh file; later layers append their row.
        mode = 'a' if layer_idx > 0 else 'w'
        with open(os.path.join(out_dir, csv_name), mode, newline='') as csvfile:
            csv.writer(csvfile).writerow([f"{x:.20f}" for x in normalized_entropy])
    elif draw_model == "draw_accumulated_energy":
        if layer_idx == 0:
            out_dir = os.path.join(out_root, 'accumulated_energy')
            os.makedirs(out_dir, exist_ok=True)
            with open(os.path.join(out_dir, csv_name), 'w', newline='') as csvfile:
                csv_writer = csv.writer(csvfile)
                # eig_val_np already holds the singular values of every
                # regularised matrix, so no second SVD is needed.
                for singular_values in eig_val_np:
                    energy = singular_values ** 2
                    csv_writer.writerow(np.cumsum(energy) / np.sum(energy))

    return normalized_entropy

def get_entropy_no_group_svd(key,topk):
    """Return the spectral entropy of ``key`` summed along axis 0.

    ``key`` is centered and row-normalized with ``normalize``, turned into
    averaged Gram matrices by ``cal_cov_no_group``, and scored by
    ``cal_entropy_no_group_svd`` using the ``topk`` leading singular values.
    The resulting entropy array is then summed over its first axis.
    """
    gram = cal_cov_no_group(normalize(key))
    entropy = cal_entropy_no_group_svd(gram, topk)
    return entropy.sum(axis=0)

def get_entropy_no_group_head_svd(key,topk):
    """Return the un-summed spectral entropy array for ``key``.

    Pipeline: ``normalize`` -> ``cal_cov_no_group`` ->
    ``cal_entropy_no_group_svd`` with the ``topk`` leading singular values.
    The array produced by the last step is returned unchanged.
    """
    return cal_entropy_no_group_svd(cal_cov_no_group(normalize(key)), topk)

def get_entropy_no_group_head_svd_draw_picture(key,svd_n,draw_model,layer_idx,dataname,my_type):
    """Return the spectral entropy array for ``key`` via the plotting helper.

    ``key`` goes through ``normalize`` and ``cal_cov_no_group``; the result
    and all remaining arguments are handed unchanged to
    ``cal_entropy_no_group_svd_draw_picture``, whose return value is passed
    straight back to the caller.
    """
    gram = cal_cov_no_group(normalize(key))
    return cal_entropy_no_group_svd_draw_picture(
        gram, svd_n, draw_model, layer_idx, dataname, my_type
    )

class UncompCluster():
    def __init__(self, num_hidden_layers = 32, window_size = 8, max_capacity_prompt = 512, 
                 kernel_size = 5, pooling = 'avgpool', beta = 20,  layer_idx=None , manager = None):
        self.manager = manager
        self.bsz = manager.bsz 
        self.layer_idx = layer_idx
        self.num_hidden_layers = manager.num_hidden_layers
        self.num_attention_heads = manager.num_attention_heads
        self.window_size = window_size
            
        if manager.method_name in manager.delet_head_set:
            self.select_topk = manager.select_topk 
            self.head_indices1 = manager.head_datas[self.layer_idx]
            self.recent_indices_generate=[torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,self.select_topk ,-1)]
        elif "group3" in manager.method_name:
            self.head_indices1 = manager.head_datas[self.layer_idx]
            self.recent_indices = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,self.num_hidden_layers,-1)]
            self.recent_indices_generate = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,11,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,10,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,11,-1)
                                            ]
        elif "group4" in manager.method_name:
            self.head_indices1 = manager.head_datas[self.layer_idx]
            self.recent_indices = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,self.num_hidden_layers,-1)]
            self.recent_indices_generate = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,8,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,8,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,8,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,8,-1),
                                            ]
        elif "group5" in manager.method_name:
            self.head_indices1 = manager.head_datas[self.layer_idx]
            self.recent_indices = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,self.num_hidden_layers,-1)]
            self.recent_indices_generate = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,7,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,6,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,6,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,6,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,7,-1)
                                            ]
        elif "group8" in manager.method_name:
            self.head_indices1 = manager.head_datas[self.layer_idx]
            self.recent_indices = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,self.num_hidden_layers,-1)]
            self.recent_indices_generate = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,4,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,4,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,4,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,4,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,4,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,4,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,4,-1),
                                            torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,4,-1),
                                            ]
        else:
            self.head_indices1 = manager.head_datas[self.layer_idx]
            self.recent_indices = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,self.num_hidden_layers,-1)]
            self.recent_indices_generate = [torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,self.num_attention_heads//2,-1),torch.arange(-self.window_size,0,device='cuda').view(1, 1, -1).expand(self.bsz,self.num_hidden_layers//2,-1)]

        self.beta = beta
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling


    def update_kv(self, key_states, query_states, value_states, attn_weights_now, attn_weights_now_all):
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        
        manager = self.manager
        num_hidden_layers = manager.num_hidden_layers
        num_attention_heads = manager.num_attention_heads
        num_group_heads_num = num_attention_heads//2
        method_name = self.manager.method_name
        min_num = (self.max_capacity_prompt - self.window_size) // self.beta
        max_num = (self.max_capacity_prompt - self.window_size) * 2 - min_num
        
        if max_num >= q_len - self.window_size:
            max_num = q_len - self.window_size 
            min_num = (self.max_capacity_prompt - self.window_size) * 2 - max_num
        steps = (max_num - min_num) // self.num_hidden_layers 
        max_capacity_prompt = max_num - self.layer_idx * steps
        # draw picture
        if "draw_svd_trend" in method_name:
                attn_weights = attn_weights_now
                params = {0:2,1:"draw_svd_trend",2:self.layer_idx,3:manager.dataset,4:my_type}
                head_pattern = get_head_pattern_attn_entropy(attn_weights,states,0,32,0,params)
                revise_key_states = key_states
                revise_value_states = value_states
        elif "draw_thermodynamic_chart" in method_name:
            attn_weights = attn_weights_now
            params = {0:2,1:"draw_thermodynamic_chart",2:self.layer_idx,3:manager.dataset,4:my_type}
            head_pattern = get_head_pattern_attn_entropy(attn_weights,query_states,0,32,0,params)
            revise_key_states = key_states
            revise_value_states = value_states
        elif method_name == "draw_accumulated_energy":
            attn_weights = attn_weights_now
            params = {0:2,1:"draw_accumulated_energy",2:self.layer_idx,3:manager.dataset}
            head_pattern = get_head_pattern_attn_entropy(attn_weights,key_states,0,128,0,params)
            revise_key_states = key_states
            revise_value_states = value_states
        elif method_name == "draw_entropy":
            if self.layer_idx == 0:
                manager.entropys = []
            attn_weights = attn_weights_now
            entropy = get_entropy_no_group_svd(query_states,128)
            revise_key_states = key_states
            revise_value_states = value_states
            manager.entropys.append(entropy)
            if self.layer_idx == num_hidden_layers-1:
                datapath = os.path.join(os.getcwd(), 'draw_picture', 'draw_layer_trend')
                datapath = datapath + "/" + manager.dataset + "_eigenvalues.csv"
                with open(datapath, 'w', newline='') as csvfile:
                    csv_writer = csv.writer(csvfile)
                    csv_writer.writerow(manager.entropys)
        elif "draw_effective_rank" in method_name:
            if self.layer_idx == 0:
                manager.entropys = []
            attn_weights = attn_weights_now
            entropy = get_entropy_no_group_head_svd(states,128)
            entropy=np.exp(entropy).sum(axis=0)
            revise_key_states = key_states
            revise_value_states = value_states
            manager.entropys.append(entropy)
            if self.layer_idx == num_hidden_layers-1:
                datapath = os.path.join(os.getcwd(), 'draw_picture', 'draw_layer_draw_effective_rank_trend')
                datapath = datapath + f"/llama2-13B/{my_type}/csv/" + manager.dataset + "_eigenvalues.csv"
                with open(datapath, 'w', newline='') as csvfile:
                    csv_writer = csv.writer(csvfile)
                    csv_writer.writerow(manager.entropys)
        # type search
        elif method_name == "head_type_search_2":
            attn_weights = attn_weights_now
            select_topk = num_group_heads_num
            svdn = 32 
            head_pattern = get_head_pattern_attn_entropy(attn_weights,query_states,0,svdn,0,[0,select_topk])
            svdn = "svd" + str(svdn)
            filename = "Your save path"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            filename = filename + "head_type_search_layer" + str(self.layer_idx) + ".csv"
            if manager.sample_time == 0:
                mode = 'w'
            else:
                mode = 'a'
            with open(filename, mode, newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(head_pattern.to(torch.int8).tolist())
            return key_states, value_states
        elif method_name == "head_type_search_3":
            attn_weights = attn_weights_now
            group_num = 3
            select_topk = num_attention_heads//group_num
            svdn = 32
            entropy = get_head_pattern_attn_entropy(attn_weights,query_states,0,svdn,0,[1,select_topk])
            sorted_indices = torch.argsort(entropy)
            labels = torch.empty_like(entropy, dtype=torch.long)
            labels[sorted_indices[:11]] = 0   
            labels[sorted_indices[11:21]] = 1  
            labels[sorted_indices[21:]] = 2    
            svdn = "svd" + str(svdn)
            filename = "Your save path"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            filename = filename + "head_type_search_layer" + str(self.layer_idx) + ".csv"
            if manager.sample_time == 0:
                mode = 'w'
            else:
                mode = 'a'
            with open(filename, mode, newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(labels.to(torch.int8).tolist())
            return key_states, value_states
        elif method_name == "head_type_search_4":
            attn_weights = attn_weights_now
            group_num = 4
            select_topk = num_attention_heads//group_num
            svdn = 32
            entropy = get_head_pattern_attn_entropy(attn_weights,query_states,0,svdn,0,[1,select_topk])
            sorted_indices = torch.argsort(entropy)
            labels = torch.empty_like(entropy, dtype=torch.long)
            labels[sorted_indices[:8]] = 0   
            labels[sorted_indices[8:16]] = 1  
            labels[sorted_indices[16:24]] = 2    
            labels[sorted_indices[24:]] = 3
            
            svdn = "svd" + str(svdn)
            filename = "Your save path"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            filename = filename + "head_type_search_layer" + str(self.layer_idx) + ".csv"
            if manager.sample_time == 0:
                mode = 'w'
            else:
                mode = 'a'
            with open(filename, mode, newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(labels.to(torch.int8).tolist())
            return key_states, value_states
        elif method_name == "head_type_search_5":
            attn_weights = attn_weights_now
            group_num = 5
            select_topk = num_attention_heads//group_num
            svdn = 32
            entropy = get_head_pattern_attn_entropy(attn_weights,query_states,0,svdn,0,[1,select_topk])
            sorted_indices = torch.argsort(entropy)
            labels = torch.empty_like(entropy, dtype=torch.long)
            labels[sorted_indices[:7]] = 0   
            labels[sorted_indices[7:13]] = 1  
            labels[sorted_indices[13:19]] = 2    
            labels[sorted_indices[19:25]] = 3    
            labels[sorted_indices[25:]] = 4 
            
            svdn = "svd" + str(svdn)
            filename = "Your save path"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            filename = filename + "head_type_search_layer" + str(self.layer_idx) + ".csv"
            if manager.sample_time == 0:
                mode = 'w'
            else:
                mode = 'a'
            with open(filename, mode, newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(labels.to(torch.int8).tolist())
            return key_states, value_states
        elif method_name == "head_type_search_8":
            attn_weights = attn_weights_now
            group_num = 8
            select_topk = num_attention_heads//group_num
            svdn = 32
            entropy = get_head_pattern_attn_entropy(attn_weights,query_states,0,svdn,0,[1,select_topk])
            sorted_indices = torch.argsort(entropy)
            labels = torch.empty_like(entropy, dtype=torch.long)
            for i in range(8):
                labels[sorted_indices[0+4*i:4*(i+1)]] = i

            svdn = "svd" + str(svdn)
            filename = "Your save path"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            filename = filename + "head_type_search_layer" + str(self.layer_idx) + ".csv"
            if manager.sample_time == 0:
                mode = 'w'
            else:
                mode = 'a'
            with open(filename, mode, newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(labels.to(torch.int8).tolist())
            return key_states, value_states
        # uncomp
        elif method_name in manager.hidden_delete_stage_and_ours: 
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            self.cache_size = max_capacity_prompt
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            top_k = max_capacity_prompt - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            indices1 = indices.sort(dim=-1).values
            recent_indices = self.recent_indices_generate[0]+q_len
            indices_1 = torch.cat([indices1[:,self.head_indices1[-num_group_heads_num:],:],recent_indices],dim=-1)
            indices_expanded  = indices_1.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices_1.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_key_states=key_states[:,self.head_indices1[-num_group_heads_num:],:,:].gather(dim = 2, index = indices_expanded)
            revise_value_states=value_states[:,self.head_indices1[-num_group_heads_num:],:,:].gather(dim = 2, index = indices_expanded)
            self.attn_weights_1=attn_weights[:, self.head_indices1[-num_group_heads_num:], :,:].gather(dim = -1, index = indices_attn)

            max_capacity_prompt_2 = max_capacity_prompt//2
            top_k2 = max_capacity_prompt_2 - self.window_size
            indices_2 = indices[:,:,:top_k2].sort(dim=-1).values
            indices_2 = torch.cat([indices_2[:,self.head_indices1[:num_group_heads_num],:],recent_indices],dim=-1)
            indices_expanded_2  = indices_2.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_2 = indices_2.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_2 = value_states[:,self.head_indices1[:num_group_heads_num],:,:].gather(dim=2,index=indices_expanded_2)
            revise_key_states_2 = key_states[:,self.head_indices1[:num_group_heads_num],:,:].gather(dim=2, index=indices_expanded_2)
            self.attn_weights_2 = attn_weights[:, self.head_indices1[:num_group_heads_num], :,:].gather(dim = -1, index = indices_attn_2)
            key1 = revise_key_states
            key2 = revise_key_states_2
            value1 = revise_value_states
            value2 = revise_value_states_2
            revise_key_states = [key1, key2]
            revise_value_states = [value1, value2]
            self.head_pattern = [self.head_indices1[-num_group_heads_num:], self.head_indices1[:num_group_heads_num]]
            self.attn_weights = [self.attn_weights_1, self.attn_weights_2]
            self.cache_size = [max_capacity_prompt, max_capacity_prompt_2]              
        elif method_name in manager.ahead_500_equal_code: 
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            if self.window_size > q_len:
                self.window_size = q_len
            if self.max_capacity_prompt == 86:
                max_capacity_prompts = [32,96]
            else:
                max_capacity_prompts = [max_capacity_prompt,max_capacity_prompt//2]
            if self.layer_idx == num_hidden_layers-1:
                print("max_capacity_prompts",max_capacity_prompts)
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')

            top_k = max_capacity_prompts[1] - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            indices1 = indices.sort(dim=-1).values
            for i in range(len(self.recent_indices_generate)):
                self.recent_indices_generate[i] = self.recent_indices_generate[i].to(indices1.device)
            self.head_indices1 = self.head_indices1.to(indices1.device)
            recent_indices = self.recent_indices_generate[0]+q_len
            num_heads = num_attention_heads // 2
            indices_1 = torch.cat([indices1[:,self.head_indices1[-num_heads:],:],recent_indices],dim=-1)
            indices_expanded  = indices_1.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices_1.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_key_states=key_states[:,self.head_indices1[-num_heads:],:,:].gather(dim = 2, index = indices_expanded)
            revise_value_states=value_states[:,self.head_indices1[-num_heads:],:,:].gather(dim = 2, index = indices_expanded)
            self.attn_weights_1=attn_weights[:, self.head_indices1[-num_heads:], :,:].gather(dim = -1, index = indices_attn)

            max_capacity_prompt_2 = max_capacity_prompts[0]
            top_k2 = max_capacity_prompt_2 - self.window_size
            indices = attn_cache.topk(top_k2, dim=-1).indices
            indices_2 = indices.sort(dim=-1).values
            indices_2 = torch.cat([indices_2[:,self.head_indices1[:num_heads],:],recent_indices],dim=-1)
            indices_expanded_2  = indices_2.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_2 = indices_2.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_2 = value_states[:,self.head_indices1[:num_heads],:,:].gather(dim=2,index=indices_expanded_2)
            revise_key_states_2 = key_states[:,self.head_indices1[:num_heads],:,:].gather(dim=2, index=indices_expanded_2)
            self.attn_weights_2 = attn_weights[:, self.head_indices1[:num_heads], :,:].gather(dim = -1, index = indices_attn_2)
            key1 = revise_key_states
            key2 = revise_key_states_2
            value1 = revise_value_states
            value2 = revise_value_states_2
            revise_key_states = [key1, key2]
            revise_value_states = [value1, value2]
            self.head_pattern = [self.head_indices1[-num_heads:], self.head_indices1[:num_heads]]
            self.attn_weights = [self.attn_weights_1, self.attn_weights_2]
            self.cache_size = [max_capacity_prompts[1], max_capacity_prompts[0]]         
        # extreme_compressibility
        elif method_name in manager.extreme_compressibility_equal_code :
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            self.cache_size = max_capacity_prompt
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            top_k = max_capacity_prompt - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            indices1 = indices.sort(dim=-1).values
            recent_indices = self.recent_indices_generate[0]+q_len
            indices_1 = torch.cat([indices1[:,self.head_indices1[-num_group_heads_num:],:],recent_indices],dim=-1)
            indices_expanded  = indices_1.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices_1.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_key_states=key_states[:,self.head_indices1[-num_group_heads_num:],:,:].gather(dim = 2, index = indices_expanded)
            revise_value_states=value_states[:,self.head_indices1[-num_group_heads_num:],:,:].gather(dim = 2, index = indices_expanded)
            self.attn_weights_1=attn_weights[:, self.head_indices1[-num_group_heads_num:], :,:].gather(dim = -1, index = indices_attn)
            
            if "128" in method_name:
                nums = 4
            elif "64" in method_name:
                nums = 8
            elif "32" in method_name:
                nums = 16
            elif "16" in method_name:
                nums = 32
            elif "12" in method_name:
                nums = 42
            elif "10" in method_name:
                nums = 51
            max_capacity_prompt_2 = max_capacity_prompt//nums
            if q_len < max_capacity_prompt_2:
                max_capacity_prompt = q_len
            if max_capacity_prompt_2 < self.window_size:
                max_capacity_prompt_2 = self.window_size
            top_k2 = max_capacity_prompt_2 - self.window_size
            indices_2 = indices[:,:,:top_k2].sort(dim=-1).values
            indices_2 = torch.cat([indices_2[:,self.head_indices1[:num_group_heads_num],:],recent_indices],dim=-1)
            indices_expanded_2  = indices_2.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_2 = indices_2.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_2 = value_states[:,self.head_indices1[:num_group_heads_num],:,:].gather(dim=2,index=indices_expanded_2)
            revise_key_states_2 = key_states[:,self.head_indices1[:num_group_heads_num],:,:].gather(dim=2, index=indices_expanded_2)
            self.attn_weights_2 = attn_weights[:, self.head_indices1[:num_group_heads_num], :,:].gather(dim = -1, index = indices_attn_2)
            key1 = revise_key_states
            key2 = revise_key_states_2
            value1 = revise_value_states
            value2 = revise_value_states_2
            revise_key_states = [key1, key2]
            revise_value_states = [value1, value2]
            self.head_pattern = [self.head_indices1[-num_group_heads_num:], self.head_indices1[:num_group_heads_num]]
            self.attn_weights = [self.attn_weights_1, self.attn_weights_2]
            self.cache_size = [max_capacity_prompt, max_capacity_prompt_2] 
        elif method_name in manager.delete_head_equal_code:
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            self.cache_size = max_capacity_prompt
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            top_k = max_capacity_prompt - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            recent_indices = self.recent_indices_generate[0]+q_len
            select_topk = self.select_topk
            
            indices_1 = torch.cat([indices[:,self.head_indices1[-select_topk:],:],recent_indices],dim=-1)
            indices_expanded  = indices_1.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices_1.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_key_states=key_states[:,self.head_indices1[-select_topk:],:,:].gather(dim = 2, index = indices_expanded)
            revise_value_states=value_states[:,self.head_indices1[-select_topk:],:,:].gather(dim = 2, index = indices_expanded)
            self.attn_weights_1=attn_weights[:, self.head_indices1[-select_topk:], :,:].gather(dim = -1, index = indices_attn)
            self.head_pattern = self.head_indices1[-select_topk:]
            self.attn_weights = self.attn_weights_1
            self.cache_size = max_capacity_prompt
            attn_1 = key_states[:,self.head_indices1[:32-select_topk],:,:].squeeze(0)[:,-8:,:].view(32-select_topk,-1)
            attn_2 = key_states[:,self.head_indices1[-select_topk:],:,:].squeeze(0)[:,-8:,:].view(select_topk,-1)
            similarity_matrix = torch.nn.functional.cosine_similarity(attn_1.unsqueeze(1), attn_2.unsqueeze(0), dim=2)
            max_similarity_indices = torch.argmax(similarity_matrix, dim=1)
            self.similarity = max_similarity_indices
        # multi_group
        elif method_name in manager.ahead_500_equal_code_group3: 
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            
            if self.max_capacity_prompt == 86:
                max_capacity_prompts = [32,64,96]
            
            if self.max_capacity_prompt == 512:
                max_capacity_prompts = [32,384,736]
                
            max_capacity_prompts = [min(prompt, q_len) for prompt in max_capacity_prompts]
            if self.layer_idx == 0:
                print("max_capacity_prompts",max_capacity_prompts)
            max_capacity_prompt = max_capacity_prompts[0]
            top_k = max_capacity_prompt - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            indices1 = indices.sort(dim=-1).values
            recent_indices = self.recent_indices_generate[0]+q_len
            indices_1 = torch.cat([indices1[:,self.head_indices1[-11:],:],recent_indices],dim=-1)
            indices_expanded  = indices_1.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices_1.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_key_states=key_states[:,self.head_indices1[-11:],:,:].gather(dim = 2, index = indices_expanded)
            revise_value_states=value_states[:,self.head_indices1[-11:],:,:].gather(dim = 2, index = indices_expanded)
            attn_weights_1=attn_weights[:, self.head_indices1[-11:], :,:].gather(dim = -1, index = indices_attn)

            max_capacity_prompt = max_capacity_prompts[1]
            top_k2 = max_capacity_prompt - self.window_size
            indices_2 = indices[:,:,:top_k2].sort(dim=-1).values
            recent_indices_1 = self.recent_indices_generate[1]+q_len
            indices_2 = torch.cat([indices_2[:,self.head_indices1[11:-11],:],recent_indices_1],dim=-1)
            indices_expanded_2  = indices_2.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_2 = indices_2.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_2 = value_states[:,self.head_indices1[11:-11],:,:].gather(dim=2,index=indices_expanded_2)
            revise_key_states_2 = key_states[:,self.head_indices1[11:-11],:,:].gather(dim=2, index=indices_expanded_2)
            attn_weights_2 = attn_weights[:, self.head_indices1[11:-11], :,:].gather(dim = -1, index = indices_attn_2)
            
            max_capacity_prompt = max_capacity_prompts[2]
            top_k3 = max_capacity_prompt - self.window_size
            indices_3 = indices[:,:,:top_k3].sort(dim=-1).values
            indices_3 = torch.cat([indices_3[:,self.head_indices1[:11],:],recent_indices],dim=-1)
            indices_expanded_3  = indices_3.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_3 = indices_3.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_3 = value_states[:,self.head_indices1[:11],:,:].gather(dim=2,index=indices_expanded_3)
            revise_key_states_3 = key_states[:,self.head_indices1[:11],:,:].gather(dim=2, index=indices_expanded_3)
            attn_weights_3 = attn_weights[:, self.head_indices1[:11], :,:].gather(dim = -1, index = indices_attn_3)
            
            key1 = revise_key_states
            key2 = revise_key_states_2
            key3 = revise_key_states_3
            value1 = revise_value_states
            value2 = revise_value_states_2
            value3 = revise_value_states_3
            revise_key_states = [key1, key2, key3]
            revise_value_states = [value1, value2, value3]
            self.head_pattern = [self.head_indices1[-11:], self.head_indices1[11:-11], self.head_indices1[:11]]
            self.attn_weights = [attn_weights_1, attn_weights_2, attn_weights_3]
            self.cache_size = max_capacity_prompts   
        elif method_name in manager.ahead_500_equal_code_group4: 
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            
            if self.max_capacity_prompt == 86:
                max_capacity_prompts = [96,75,53,32]
            elif self.max_capacity_prompt == 512:
                max_capacity_prompts = [736,502,266,32]
            else:
                max_capacity_prompts = []
                max_cap = max_capacity_prompt
                min_cap = max_capacity_prompt // 2
                allowance = (max_cap-min_cap) // 3
                for i in range(4):
                    max_capacity_prompts.append(max_cap - allowance*i)
            max_capacity_prompts = [min(prompt, q_len) for prompt in max_capacity_prompts]
            if self.layer_idx == 0:
                print("max_capacity_prompts",max_capacity_prompts)
            max_capacity_prompt = max_capacity_prompts[0]
            top_k = max_capacity_prompt - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            indices1 = indices.sort(dim=-1).values
            recent_indices = self.recent_indices_generate[0]+q_len
            indices_1 = torch.cat([indices1[:,self.head_indices1[-8:],:],recent_indices],dim=-1)
            indices_expanded  = indices_1.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices_1.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_key_states=key_states[:,self.head_indices1[-8:],:,:].gather(dim = 2, index = indices_expanded)
            revise_value_states=value_states[:,self.head_indices1[-8:],:,:].gather(dim = 2, index = indices_expanded)
            attn_weights_1=attn_weights[:, self.head_indices1[-8:], :,:].gather(dim = -1, index = indices_attn)

            max_capacity_prompt = max_capacity_prompts[1]
            top_k2 = max_capacity_prompt - self.window_size
            indices_2 = indices[:,:,:top_k2].sort(dim=-1).values
            recent_indices = self.recent_indices_generate[1]+q_len
            recent_indices_1 = self.recent_indices_generate[1]+q_len
            indices_2 = torch.cat([indices_2[:,self.head_indices1[16:-8],:],recent_indices_1],dim=-1)
            indices_expanded_2  = indices_2.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_2 = indices_2.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_2 = value_states[:,self.head_indices1[16:-8],:,:].gather(dim=2,index=indices_expanded_2)
            revise_key_states_2 = key_states[:,self.head_indices1[16:-8],:,:].gather(dim=2, index=indices_expanded_2)
            attn_weights_2 = attn_weights[:, self.head_indices1[16:-8], :,:].gather(dim = -1, index = indices_attn_2)
            
            max_capacity_prompt = max_capacity_prompts[2]
            top_k3 = max_capacity_prompt - self.window_size
            recent_indices = self.recent_indices_generate[2]+q_len
            indices_3 = indices[:,:,:top_k3].sort(dim=-1).values
            indices_3 = torch.cat([indices_3[:,self.head_indices1[8:16],:],recent_indices],dim=-1)
            indices_expanded_3  = indices_3.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_3 = indices_3.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_3 = value_states[:,self.head_indices1[8:16],:,:].gather(dim=2,index=indices_expanded_3)
            revise_key_states_3 = key_states[:,self.head_indices1[8:16],:,:].gather(dim=2, index=indices_expanded_3)
            attn_weights_3 = attn_weights[:, self.head_indices1[8:16], :,:].gather(dim = -1, index = indices_attn_3)
            
            max_capacity_prompt = max_capacity_prompts[3]
            top_k4 = max_capacity_prompt - self.window_size
            indices_4 = indices[:,:,:top_k4].sort(dim=-1).values
            recent_indices = self.recent_indices_generate[3]+q_len
            indices_4 = torch.cat([indices_4[:,self.head_indices1[:8],:],recent_indices],dim=-1)
            indices_expanded_4  = indices_4.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_4 = indices_4.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_4 = value_states[:,self.head_indices1[:8],:,:].gather(dim=2,index=indices_expanded_4)
            revise_key_states_4 = key_states[:,self.head_indices1[:8],:,:].gather(dim=2, index=indices_expanded_4)
            attn_weights_4 = attn_weights[:, self.head_indices1[:8], :,:].gather(dim = -1, index = indices_attn_4)
            
            key1,key2,key3,key4 = revise_key_states,revise_key_states_2,revise_key_states_3,revise_key_states_4
            value1,value2,value3,value4 = revise_value_states,revise_value_states_2,revise_value_states_3,revise_value_states_4
            revise_key_states = [key1, key2, key3, key4]
            revise_value_states = [value1, value2, value3, value4]
            self.head_pattern = [self.head_indices1[-8:], self.head_indices1[16:-8], self.head_indices1[8:16], self.head_indices1[:8]]
            self.attn_weights = [attn_weights_1, attn_weights_2, attn_weights_3, attn_weights_4]
            self.cache_size = max_capacity_prompts  
        elif method_name in manager.ahead_500_equal_code_group5: 
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            if self.max_capacity_prompt == 512:
                max_capacity_prompts = [736,560,384,208,32]
            if self.max_capacity_prompt == 86:
                max_capacity_prompts = [96,80,64,48,32]    
            max_capacity_prompts = [min(prompt, q_len) for prompt in max_capacity_prompts]
            if self.layer_idx == 0:
                print("max_capacity_prompts",max_capacity_prompts)  
            max_capacity_prompt = max_capacity_prompts[0]
            top_k = max_capacity_prompt - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            indices1 = indices.sort(dim=-1).values
            recent_indices = self.recent_indices_generate[0]+q_len
            indices_1 = torch.cat([indices1[:,self.head_indices1[-7:],:],recent_indices],dim=-1)
            indices_expanded  = indices_1.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices_1.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_key_states=key_states[:,self.head_indices1[-7:],:,:].gather(dim = 2, index = indices_expanded)
            revise_value_states=value_states[:,self.head_indices1[-7:],:,:].gather(dim = 2, index = indices_expanded)
            attn_weights_1=attn_weights[:, self.head_indices1[-7:], :,:].gather(dim = -1, index = indices_attn)

            max_capacity_prompt = max_capacity_prompts[1]
            recent_indices = self.recent_indices_generate[1]+q_len
            top_k2 = max_capacity_prompt - self.window_size
            indices_2 = indices[:,:,:top_k2].sort(dim=-1).values
            recent_indices_1 = self.recent_indices_generate[1]+q_len
            indices_2 = torch.cat([indices_2[:,self.head_indices1[19:-7],:],recent_indices_1],dim=-1)
            indices_expanded_2  = indices_2.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_2 = indices_2.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_2 = value_states[:,self.head_indices1[19:-7],:,:].gather(dim=2,index=indices_expanded_2)
            revise_key_states_2 = key_states[:,self.head_indices1[19:-7],:,:].gather(dim=2, index=indices_expanded_2)
            attn_weights_2 = attn_weights[:, self.head_indices1[19:-7], :,:].gather(dim = -1, index = indices_attn_2)
            
            max_capacity_prompt = max_capacity_prompts[2]
            recent_indices = self.recent_indices_generate[2]+q_len
            top_k3 = max_capacity_prompt - self.window_size
            indices_3 = indices[:,:,:top_k3].sort(dim=-1).values
            indices_3 = torch.cat([indices_3[:,self.head_indices1[13:19],:],recent_indices],dim=-1)
            indices_expanded_3  = indices_3.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_3 = indices_3.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_3 = value_states[:,self.head_indices1[13:19],:,:].gather(dim=2,index=indices_expanded_3)
            revise_key_states_3 = key_states[:,self.head_indices1[13:19],:,:].gather(dim=2, index=indices_expanded_3)
            attn_weights_3 = attn_weights[:, self.head_indices1[13:19], :,:].gather(dim = -1, index = indices_attn_3)
            
            max_capacity_prompt = max_capacity_prompts[3]
            recent_indices = self.recent_indices_generate[3]+q_len
            top_k4 = max_capacity_prompt - self.window_size
            indices_4 = indices[:,:,:top_k4].sort(dim=-1).values
            indices_4 = torch.cat([indices_4[:,self.head_indices1[7:13],:],recent_indices],dim=-1)
            indices_expanded_4  = indices_4.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_4 = indices_4.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_4 = value_states[:,self.head_indices1[7:13],:,:].gather(dim=2,index=indices_expanded_4)
            revise_key_states_4 = key_states[:,self.head_indices1[7:13],:,:].gather(dim=2, index=indices_expanded_4)
            attn_weights_4 = attn_weights[:, self.head_indices1[7:13], :,:].gather(dim = -1, index = indices_attn_4)
            
            max_capacity_prompt = max_capacity_prompts[4]
            recent_indices = self.recent_indices_generate[4]+q_len
            top_k5 = max_capacity_prompt - self.window_size
            indices_5 = indices[:,:,:top_k5].sort(dim=-1).values
            indices_5 = torch.cat([indices_5[:,self.head_indices1[:7],:],recent_indices],dim=-1)
            indices_expanded_5  = indices_5.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn_5 = indices_5.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            revise_value_states_5 = value_states[:,self.head_indices1[:7],:,:].gather(dim=2,index=indices_expanded_5)
            revise_key_states_5 = key_states[:,self.head_indices1[:7],:,:].gather(dim=2, index=indices_expanded_5)
            attn_weights_5 = attn_weights[:, self.head_indices1[:7], :,:].gather(dim = -1, index = indices_attn_5)
            
            key1,key2,key3,key4,key5 = revise_key_states,revise_key_states_2,revise_key_states_3,revise_key_states_4,revise_key_states_5
            value1,value2,value3,value4,value5 = revise_value_states,revise_value_states_2,revise_value_states_3,revise_value_states_4,revise_value_states_5
            revise_key_states = [key1, key2, key3, key4, key5]
            revise_value_states = [value1, value2, value3, value4, value5]
            self.head_pattern = [self.head_indices1[-7:], self.head_indices1[19:-7], self.head_indices1[13:19], self.head_indices1[7:13], self.head_indices1[:7]]
            self.attn_weights = [attn_weights_1, attn_weights_2, attn_weights_3, attn_weights_4, attn_weights_5]
            self.cache_size = max_capacity_prompts         
        elif method_name in manager.ahead_500_equal_code_group8: 
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            
            
            if self.max_capacity_prompt and "new" in method_name == 512:
                max_capacity_prompts = [32,132,232,332,436,536,636,736]
            elif self.max_capacity_prompt == 86:
                max_capacity_prompts = [32,41,50,59,69,78,87,96]
            else:
                max_capacity_prompts = []
                max_cap = max_capacity_prompt
                min_cap = max_capacity_prompt // 2
                allowance = (max_cap-min_cap) // 7
                for i in range(7,-1,-1):
                    max_capacity_prompts.append(max_cap - allowance*i)
            
            max_capacity_prompts = [min(prompt, q_len) for prompt in max_capacity_prompts]
            
            revise_key_states = []
            revise_value_states = []
            self.head_pattern = []
            self.attn_weights = []
            if self.layer_idx == 0:
                print("max_capacity_prompts",max_capacity_prompts)
            new_max_capacity_prompts = []
            for i in range(7,-1,-1):
                max_capacity_prompt = max_capacity_prompts[i]
                new_max_capacity_prompts.append(max_capacity_prompt)
                top_k = max_capacity_prompt - self.window_size
                indices = attn_cache.topk(top_k, dim=-1).indices
                indices1 = indices.sort(dim=-1).values
                recent_indices = self.recent_indices_generate[i]+q_len
                indices_1 = torch.cat([indices1[:,self.head_indices1[4*i:4*(i+1)],:],recent_indices],dim=-1)
                indices_expanded  = indices_1.unsqueeze(-1).expand(-1, -1, -1, head_dim)
                indices_attn = indices_1.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
                revise_key_states_single=key_states[:,self.head_indices1[4*i:4*(i+1)],:,:].gather(dim = 2, index = indices_expanded)
                revise_value_states_single=value_states[:,self.head_indices1[4*i:4*(i+1)],:,:].gather(dim = 2, index = indices_expanded)
                attn_weights_1=attn_weights[:, self.head_indices1[4*i:4*(i+1)], :,:].gather(dim = -1, index = indices_attn)

                revise_key_states.append(revise_key_states_single)
                revise_value_states.append(revise_value_states_single)
                self.attn_weights.append(attn_weights_1)
                self.head_pattern.append(self.head_indices1[4*i:4*(i+1)])
            self.cache_size = new_max_capacity_prompts  
        # other methods
        elif method_name in manager.chai:
            revise_key_states = key_states
            revise_value_states = value_states
        elif "pyramidkv_generate" in method_name:
            self.beta = 20
            my_max = int(self.max_capacity_prompt * 1.5)
            min_num = self.max_capacity_prompt // self.beta
            max_num = my_max - min_num
            if max_num >= q_len:
                max_num = q_len
                min_num = my_max - max_num
            steps = (max_num - min_num) // self.num_hidden_layers 
            max_capacity_prompt = min_num + (num_hidden_layers-self.layer_idx) * steps
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            if max_capacity_prompt < self.window_size:
                max_capacity_prompt = self.window_size
            self.cache_size = max_capacity_prompt
            attn_weights = attn_weights_now
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            top_k = max_capacity_prompt - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            indices = indices.sort(dim=-1).values
            indices_expanded  = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            attn_weights_compress = attn_weights[:, :, -self.window_size:, :-self.window_size].gather(dim = -1, index = indices_attn)
            k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices_expanded)
            v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices_expanded)
            k_cur = key_states[:, :, -self.window_size:, :]
            v_cur = value_states[:, :, -self.window_size:, :]
            revise_key_states = torch.cat([k_past_compress, k_cur], dim = 2)
            revise_value_states = torch.cat([v_past_compress, v_cur], dim = 2)
            attn_weights_cur= attn_weights[:,:,-self.window_size:,-self.window_size:]
            self.attn_weights = torch.cat([attn_weights_compress,attn_weights_cur],dim=-1)
            if revise_key_states.shape != (bsz, num_heads, max_capacity_prompt, head_dim):
                print("revise_key_states.shape",revise_key_states.shape)
                print("max_capacity_prompt",max_capacity_prompt)
                raise ValueError("revise_key_states.shape != (bsz, num_heads, max_capacity_prompt, head_dim)")
        elif "snapkv" in method_name:
            max_capacity_prompt = self.max_capacity_prompt
            if q_len < max_capacity_prompt:
                max_capacity_prompt = q_len
            self.cache_size = max_capacity_prompt
            attn_weights = attn_weights_now_all
            attn_weights_sum = attn_weights[:, :, :, :-self.window_size].sum(dim = -2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            top_k = max_capacity_prompt - self.window_size
            indices = attn_cache.topk(top_k, dim=-1).indices
            indices = indices.sort(dim=-1).values
            indices_expanded  = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            attn_weights_compress = attn_weights[:, :, -self.window_size:, :-self.window_size].gather(dim = -1, index = indices_attn)
            k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices_expanded)
            v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices_expanded)
            k_cur = key_states[:, :, -self.window_size:, :]
            v_cur = value_states[:, :, -self.window_size:, :]
            revise_key_states = torch.cat([k_past_compress, k_cur], dim = 2)
            revise_value_states = torch.cat([v_past_compress, v_cur], dim = 2)
            attn_weights_cur= attn_weights[:,:,-self.window_size:,-self.window_size:]
            self.attn_weights = torch.cat([attn_weights_compress,attn_weights_cur],dim=-1)
            if revise_key_states.shape != (bsz, num_heads, max_capacity_prompt, head_dim):
                print("revise_key_states.shape",revise_key_states.shape)
                print("max_capacity_prompt",max_capacity_prompt)
                raise ValueError("revise_key_states.shape != (bsz, num_heads, max_capacity_prompt, head_dim)")
        elif "H2O" in method_name:
            attn_weights = attn_weights_now_all
            recent_size = self.recent_size = self.max_capacity_prompt*508//512
            hh_size = self.hh_size = self.max_capacity_prompt*4//512
            if hh_size==0:
                hh_size = self.hh_size = 1
                recent_size = self.recent_size = self.max_capacity_prompt - hh_size
            if q_len <= recent_size+hh_size:
                recent_size = self.recent_size = q_len - hh_size
            self.cache_size = recent_size + hh_size
            self.attn_weights_sum = attn_weights.sum(0).sum(1)
            select_hh_scores = self.attn_weights_sum[:, :q_len - recent_size]
            _, keep_topk = torch.topk(select_hh_scores, hh_size, dim=-1)
            keep_topk = keep_topk.sort().values 
            keep_recent = torch.arange(q_len - recent_size, q_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
            keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
            mask = torch.zeros(self.attn_weights_sum.shape, dtype=torch.bool).to(attn_weights.device)
            mask = mask.scatter(-1, keep_idx, 1)
            revise_key_states = key_states.squeeze()[mask].view(bsz, num_heads, -1, head_dim)
            revise_value_states = value_states.squeeze()[mask].view(bsz, num_heads, -1, head_dim)
            self.attn_weights_sum = self.attn_weights_sum[mask].view(num_heads, self.cache_size)
        return revise_key_states, revise_value_states

    def update_kv_generate(
        self,
        past_key_value: Tuple[torch.Tensor, torch.Tensor],
        new_attn_weights: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Prune one layer's key/value cache back to ``self.cache_size`` entries.

        The strategy is chosen by ``self.manager.method_name``:

        * If it is listed in ``manager.group_sampling``, the stored attention
          map ``self.attn_weights`` is extended by one row and one column with
          ``new_attn_weights``; the scores of the last ``self.window_size``
          query rows over the older key positions are summed, pooled
          (``self.pooling`` / ``self.kernel_size``) and the
          ``cache_size - window_size`` best older positions are kept together
          with the last ``window_size`` positions.
        * Otherwise, ``new_attn_weights`` is reduced over its first and third
          axes and accumulated onto ``self.attn_weights_sum``; per head, the
          ``self.hh_size`` best-scored positions among the older ones plus the
          last ``self.recent_size`` positions are kept.

        Args:
            past_key_value: Cache object. Despite the tuple annotation it is
                accessed through its ``key_cache`` / ``value_cache``
                containers, which are written at ``layer_idx``.
            new_attn_weights: Attention weights of the current step.
                NOTE(review): presumably (bsz, num_heads, 1, q_len) - confirm
                against the caller.
            key_states: Keys of shape (bsz, num_heads, q_len, head_dim).
                NOTE(review): assumed to already contain the newest token,
                i.e. q_len == cache_size + 1 - confirm against the caller.
            value_states: Values, laid out like ``key_states``.
            layer_idx: Index of the cache slot to overwrite.
            cache_kwargs: Unused; kept for interface compatibility.

        Returns:
            None. Results are written to ``past_key_value`` and to
            ``self.attn_weights`` / ``self.attn_weights_sum``.

        Raises:
            ValueError: In the group-sampling branch, if ``self.pooling`` is
                neither ``'avgpool'`` nor ``'maxpool'``, or if the pruned keys
                do not have shape (bsz, num_heads, cache_size, head_dim).
        """
        manager = self.manager
        bsz, num_heads, q_len, head_dim = key_states.shape

        if manager.method_name in manager.group_sampling:
            window = self.window_size
            cache_size = self.cache_size
            prev_attn = self.attn_weights

            # Grow the stored map by one query row and one key column. The new
            # column stays zero for the old rows; the new row is this step's
            # attention.
            now_attn_weights = torch.zeros(
                (
                    prev_attn.shape[0],
                    prev_attn.shape[1],
                    prev_attn.shape[2] + 1,
                    prev_attn.shape[3] + 1,
                ),
                device=prev_attn.device,
            )
            now_attn_weights[:, :, :-1, :-1] = prev_attn
            now_attn_weights[:, :, -1:, :] = new_attn_weights

            # Score each older key position by the attention it receives from
            # the last `window` query rows.
            attn_weights_sum = now_attn_weights[..., -window:, :-window].sum(-2)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
            else:
                raise ValueError('Pooling method not supported')

            indices = attn_cache.topk(cache_size - window, dim=-1).indices
            indices_expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices.unsqueeze(-2).expand(-1, -1, window, -1)

            attn_weights_compress = now_attn_weights[:, :, -window:, :-window].gather(dim=-1, index=indices_attn)
            attn_weights_cur = now_attn_weights[:, :, -window:, -window:]
            k_past_compress = key_states[:, :, :-window, :].gather(dim=2, index=indices_expanded)
            v_past_compress = value_states[:, :, :-window, :].gather(dim=2, index=indices_expanded)
            revise_key_states = torch.cat([k_past_compress, key_states[:, :, -window:, :]], dim=2)
            revise_value_states = torch.cat([v_past_compress, value_states[:, :, -window:, :]], dim=2)

            # Validate before touching the cache so a failure cannot leave the
            # cache and self.attn_weights in a half-updated state.
            if revise_key_states.shape != (bsz, num_heads, cache_size, head_dim):
                print("revise_key_states.shape",revise_key_states.shape)
                print("max_capacity_prompt",cache_size)
                raise ValueError("revise_key_states.shape != (bsz, num_heads, max_capacity_prompt, head_dim)")

            self.attn_weights = torch.cat([attn_weights_compress, attn_weights_cur], dim=-1)
            past_key_value.key_cache[layer_idx] = revise_key_states
            past_key_value.value_cache[layer_idx] = revise_value_states
            return None

        recent_size, hh_size = self.recent_size, self.hh_size

        # Accumulate this step's attention onto the running per-position score.
        # The newest position (last column) has no history yet.
        scores = new_attn_weights.sum(0).sum(1)
        scores[:, :-1] += self.attn_weights_sum

        # Per head: best `hh_size` positions among the older ones, plus the
        # last `recent_size` positions.
        boundary = q_len - recent_size
        keep_topk = torch.topk(scores[:, :boundary], hh_size, dim=-1).indices
        keep_topk = keep_topk.sort().values
        keep_recent = torch.arange(boundary, q_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
        mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, keep_idx, 1)

        # Index the (num_heads, q_len) axes directly instead of calling a
        # blanket squeeze(): squeeze() also drops a singleton head or feature
        # axis, which makes the mask shape mismatch.
        past_key_value.key_cache[layer_idx] = key_states[:, mask].view(bsz, num_heads, -1, head_dim)
        past_key_value.value_cache[layer_idx] = value_states[:, mask].view(bsz, num_heads, -1, head_dim)
        self.attn_weights_sum = scores[mask].view(num_heads, self.cache_size)
        
    def update_head_kv_generate(
        self,
        past_key_value: Tuple[torch.Tensor, torch.Tensor],
        new_attn_weights_all: torch.Tensor,
        key_states_all: torch.Tensor,
        value_states_all: torch.Tensor,
        layer_idx: int,
        padding: list,
    ):
        method_name = self.manager.method_name
        len1,len2,len3 = len(key_states_all),len(value_states_all),len(new_attn_weights_all)
        if len1 != len2 or len2 != len3  : 
            print(f"len1:{len1} len2:{len2} len3:{len3}")
            raise ValueError("The length of key_states_all, value_states_all, new_attn_weights_all should be the same.")
        if method_name in self.manager.delet_head_set:
            attn_weights = self.attn_weights
            new_attn_weights,key_states,value_states = new_attn_weights_all,key_states_all,value_states_all
            bsz,num_heads,q_len,head_dim = key_states.size()
            now_attn_weights = torch.nn.functional.pad(attn_weights, (0, 1, 0, 1), mode='constant', value=0).to(torch.float32)
            now_attn_weights[:,:,-1:,:] = new_attn_weights[...,:now_attn_weights.shape[-1]] 
            attn_weights_sum = now_attn_weights[...,-self.window_size:,:-self.window_size].sum(-2)
            
            cache_size = self.cache_size
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            indices  = attn_cache.topk(cache_size-self.window_size, dim=-1).indices
            bsz,num_heads,q_len,_= key_states.shape
            recent_indices = self.recent_indices_generate[0]+q_len
            indices = torch.cat([indices,recent_indices],dim=-1)
            indices_expanded  = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            indices_attn = indices.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
            attn_weights_compress = now_attn_weights[...,-self.window_size:,:].gather(dim = -1, index = indices_attn)
            revise_key_states = key_states.gather(dim = 2, index = indices_expanded)
            revise_value_states = value_states.gather(dim = 2, index = indices_expanded)

            self.attn_weights = attn_weights_compress
            past_key_value.key_cache[layer_idx] = revise_key_states
            past_key_value.value_cache[layer_idx] = revise_value_states
            
            if revise_key_states.shape != (bsz, num_heads, cache_size, head_dim):
                print("revise_key_states.shape",revise_key_states.shape)
                print("max_capacity_prompt",cache_size)
                raise ValueError("revise_key_states.shape != (bsz, num_heads, max_capacity_prompt, head_dim)")
        elif method_name in self.manager.head_set:
            for i,(key_states,value_states,new_attn_weights) in enumerate(zip(key_states_all,value_states_all,new_attn_weights_all)): 
                bsz,num_heads,q_len,head_dim = key_states.size()
                attn_weights = self.attn_weights[i]
                now_attn_weights = torch.nn.functional.pad(attn_weights, (0, 1, 0, 1), mode='constant', value=0).to(torch.float32)
                now_attn_weights[:,:,-1:,:] = new_attn_weights[...,:now_attn_weights.shape[-1]] 
                attn_weights_sum = now_attn_weights[...,-self.window_size:,:-self.window_size].sum(-2)
                if attn_weights_sum.shape[1] == 0:
                    self.attn_weights[i] = self.attn_weights[i] 
                    past_key_value.key_cache[layer_idx][i] = past_key_value.key_cache[layer_idx][i][:,:,:-1,:]
                    past_key_value.value_cache[layer_idx][i] = past_key_value.value_cache[layer_idx][i][:,:,:-1,:]
                    continue
                cache_size = self.cache_size[i]
                if self.pooling == 'avgpool':
                    attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
                elif self.pooling == 'maxpool':
                    attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
                else:
                    raise ValueError('Pooling method not supported')
                indices  = attn_cache.topk(cache_size-self.window_size, dim=-1).indices
                bsz,num_heads,q_len,_= key_states.shape
                recent_indices = self.recent_indices_generate[i]+q_len
                indices = torch.cat([indices,recent_indices],dim=-1)
                indices_expanded  = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
                indices_attn = indices.unsqueeze(-2).expand(-1, -1, self.window_size, -1)
                attn_weights_compress = now_attn_weights[...,-self.window_size:,:].gather(dim = -1, index = indices_attn)
                revise_key_states = key_states.gather(dim = 2, index = indices_expanded)
                revise_value_states = value_states.gather(dim = 2, index = indices_expanded)
                self.attn_weights[i] = attn_weights_compress
                past_key_value.key_cache[layer_idx][i] = revise_key_states
                past_key_value.value_cache[layer_idx][i] = revise_value_states
                
                if revise_key_states.shape != (bsz, num_heads, cache_size, head_dim):
                        print("revise_key_states.shape",revise_key_states.shape)
                        print("max_capacity_prompt",cache_size)
                        raise ValueError("revise_key_states.shape != (bsz, num_heads, max_capacity_prompt, head_dim)")
        
        return None
    
    def update_past_key_value(
        self,
        past_key_value,
        key_states: tuple,
        value_states: tuple,
        layer_idx: int,
        mode: int,
    ):
        """Append new key/value states to the cache entry of one layer.

        If the cache holds no entry for ``layer_idx`` yet, ``key_states`` and
        ``value_states`` are appended unchanged. Otherwise the new states are
        indexed along dimension 1 with ``self.head_pattern`` and concatenated
        onto the cached tensors along dimension -2:

        * ``mode == 1``: the layer entry is a single tensor and
          ``self.head_pattern`` is used as one index.
        * any other ``mode``: the layer entry is a sequence of per-group
          tensors and ``self.head_pattern[i]`` selects the slice for group ``i``.

        Args:
            past_key_value: Object exposing ``key_cache`` and ``value_cache``
                lists, mutated in place.
            key_states: New key states (indexed as a 4-D tensor when an entry
                already exists).
            value_states: New value states, handled like ``key_states``.
            layer_idx: Position of the layer inside the cache lists.
            mode: ``1`` for the single-tensor layout, otherwise grouped.

        Returns:
            The pair ``(key_cache[layer_idx], value_cache[layer_idx])`` after
            the update.
        """
        key_cache = past_key_value.key_cache

        # First visit of this layer: store the incoming states as they are.
        if len(key_cache) <= layer_idx:
            key_cache.append(key_states)
            past_key_value.value_cache.append(value_states)
            return past_key_value.key_cache[layer_idx], past_key_value.value_cache[layer_idx]

        value_cache = past_key_value.value_cache

        if mode == 1:
            selected_heads = self.head_pattern
            key_cache[layer_idx] = torch.cat(
                [key_cache[layer_idx], key_states[:, selected_heads, :, :]], dim=-2
            )
            value_cache[layer_idx] = torch.cat(
                [value_cache[layer_idx], value_states[:, selected_heads, :, :]], dim=-2
            )
        else:
            key_groups = key_cache[layer_idx]
            value_groups = value_cache[layer_idx]
            for group_idx in range(len(key_groups)):
                group_heads = self.head_pattern[group_idx]
                key_groups[group_idx] = torch.cat(
                    [key_groups[group_idx], key_states[:, group_heads, :, :]], dim=-2
                )
                group_heads = self.head_pattern[group_idx]
                value_groups[group_idx] = torch.cat(
                    [value_groups[group_idx], value_states[:, group_heads, :, :]], dim=-2
                )

        return past_key_value.key_cache[layer_idx], past_key_value.value_cache[layer_idx]

def init_uncomp(self, num_hidden_layers,manager):
    """Build an ``UncompCluster`` and store it on ``self`` as ``kv_cluster``.

    When ``self`` has no ``kv_cluster`` attribute yet, any of the settings
    ``window_size`` (32), ``max_capacity_prompt`` (512), ``kernel_size`` (5)
    and ``pooling`` ('avgpool') missing from ``self.config`` are filled in
    with the defaults shown. The cluster itself is constructed on every call,
    replacing a previously attached one.

    Args:
        self: Object providing ``config`` and ``layer_idx``.
        num_hidden_layers: Forwarded to ``UncompCluster``.
        manager: Forwarded to ``UncompCluster``.
    """
    settings = self.config

    if not hasattr(self, "kv_cluster"):
        fallbacks = (
            ("window_size", 32),
            ("max_capacity_prompt", 512),
            ("kernel_size", 5),
            ("pooling", "avgpool"),
        )
        # Only fill in what the config does not already define.
        for option, fallback in fallbacks:
            if not hasattr(settings, option):
                setattr(settings, option, fallback)

    # Rebuilt unconditionally, so any state held by an earlier cluster is dropped.
    cluster = UncompCluster(
        num_hidden_layers=num_hidden_layers,
        layer_idx=self.layer_idx,
        window_size=settings.window_size,
        max_capacity_prompt=settings.max_capacity_prompt,
        kernel_size=settings.kernel_size,
        pooling=settings.pooling,
        manager=manager,
    )
    self.kv_cluster = cluster
