from asym_kv.streaming_llm.kv_cache import StartRecentKVCache
import torch
from asym_kv.streaming_llm.modify_llama import enable_llama_pos_shift_attention
from  asym_kv.util.pred_utils import greedy_generate, post_process
from tqdm import tqdm
import json
import torch.distributed as dist
import torch.nn.functional as F
from typing import Optional, Tuple
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    repeat_kv,
    rotate_half
)
import types
import math
import torch.nn as nn
from  asym_kv.method.compress_forward.compress_mistral import enable_mistral_compress_attention
from  asym_kv.method.compress_forward.compress_llama import enable_llama_pos_shift_compress_attention_442
from  asym_kv.method.compress_forward.compress_qwen2 import enable_qwen2_compress_attention
from  asym_kv.method.compress_forward.compress_llama_433 import enable_llama_pos_shift_compress_attention_433
from  asym_kv.method.compress_forward.compress_gemma import enable_gemma_compress_attention
def enable_compress_attention(model_name, model):
    """Dispatch ``model`` to the family-specific ``enable_*_compress_attention`` helper.

    The family is picked by case-insensitive substring match on ``model_name``,
    checked in the order llama, mistral, qwen, gemma (first hit wins).

    Raises:
        ValueError: if ``model_name`` contains none of the supported family names.
    """
    family = model_name.lower()
    if "llama" in family:
        enable_llama_pos_shift_compress_attention_442(model)
        return
    if "mistral" in family:
        enable_mistral_compress_attention(model)
        return
    if "qwen" in family:
        enable_qwen2_compress_attention(model)
        return
    if "gemma" in family:
        enable_gemma_compress_attention(model)
        return
    raise ValueError(f"Unsupported model: {model_name}")

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``.
    This definition shadows the 4-D ``repeat_kv`` imported from ``transformers``.
    It accepts any trailing shape, so both per-slot tensors of shape
    ``(batch, num_key_value_heads, seqlen)`` (how it is used in this file) and
    ``(batch, num_key_value_heads, seqlen, head_dim)`` tensors work.

    Returns:
        A tensor of shape ``(batch, num_key_value_heads * n_rep, *rest)``. When
        ``n_rep == 1`` the input tensor itself is returned.
    """
    if n_rep == 1:
        return hidden_states
    batch, num_key_value_heads, *rest = hidden_states.shape
    # Insert a group axis right after the head axis, broadcast it (no copy), then
    # fold it into the head axis so each kv head appears n_rep times consecutively.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, *rest)
    return expanded.reshape(batch, num_key_value_heads * n_rep, *rest)


def compress_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Average every group of ``n_rep`` consecutive heads along dim 1.

    Maps ``(batch, num_attention_heads, *rest)`` to
    ``(batch, num_attention_heads // n_rep, *rest)``. Output head ``h`` is the mean
    of input heads ``h * n_rep`` .. ``h * n_rep + n_rep - 1``, the reducing
    counterpart of ``repeat_kv``. Trailing dimensions are carried through
    unchanged. It therefore works for query tensors
    ``(batch, heads, seqlen, head_dim)`` and for attention maps whose last axis is
    the key length.

    Raises:
        ValueError: if ``num_attention_heads`` is not divisible by ``n_rep``.
    """
    batch, num_attention_heads, *rest = hidden_states.shape
    if num_attention_heads % n_rep:
        raise ValueError(
            f"num_attention_heads ({num_attention_heads}) is not divisible by n_rep ({n_rep})"
        )
    num_key_value_heads = num_attention_heads // n_rep
    # Splitting only dim 1 into (kv_heads, n_rep) is always a valid view, even for
    # non-contiguous inputs, and groups consecutive heads together.
    grouped = hidden_states.view(batch, num_key_value_heads, n_rep, *rest)
    return grouped.mean(dim=2)

class KVCompressCache(StartRecentKVCache):
    """KV cache that shrinks its middle region by merging adjacent slot pairs.

    Relies on ``start_size``, ``cache_size``, ``k_seq_dim``/``v_seq_dim`` and the
    ``k_slice``/``v_slice`` helpers, which are not defined here (presumably provided
    by ``StartRecentKVCache``).

    Each layer is handled as a ``(k, v, l)`` triple. ``l`` has ``k``'s shape without
    its last dimension and acts as a per-slot token count: it starts at ones, the
    counts of a merged pair are added, and counts are clipped at ``_MAX_COUNT``.

    When the sequence length exceeds ``cache_size``, the first ``start_size`` slots
    are kept untouched. Inside the remaining "middle" region, the
    ``seq_len - cache_size`` lowest-scoring adjacent pairs ``(j, j + 1)`` are merged
    into slot ``j`` and slot ``j + 1`` is dropped. The public methods differ in how
    a merged key/value is formed. Merging is done out of place, so the caller's
    cache (and hessian) tensors are not modified.

    NOTE(review): merging indexes the sequence axis as dim 2 for k, v and l, while
    slicing uses ``self.k_seq_dim``/``self.v_seq_dim``. This assumes both are 2.
    NOTE(review): pair scoring uses ``.squeeze(0)`` / ``[0]`` on the batch axis, so it
    presumably assumes batch size 1 — confirm against callers.
    """

    # Position-decay scale: exp(idx / _GAMMA) equals 512 at idx == 4096.
    _GAMMA = 4096 / math.log(512)
    # Upper bound applied to the per-slot token counts after every merge step.
    _MAX_COUNT = 5

    def formalize_past_key_values(self, past_key_values):
        """Convert per-layer ``(k, v)`` pairs into ``(k, v, l)`` triples with ``l`` all ones.

        ``l`` has ``k``'s shape minus its last dimension, on ``k``'s device and dtype.
        """
        return tuple(
            (k, v, torch.ones(k.size()[:-1], device=k.device, dtype=k.dtype))
            for k, v in past_key_values
        )

    # ------------------------------------------------------------------ helpers

    def _pair_attention(self, attn, mid_l):
        """Return count-weighted attention mass per adjacent middle pair, ``(batch, heads, n - 1)``.

        ``attn`` is indexed as ``attn[:, :, :, start_size:]`` and summed over dim -2
        (presumably the query axis). Each slot is weighted by its token count
        (repeated from kv heads to attention heads). Adjacent slots are then summed.
        """
        scores = attn[:, :, :, self.start_size:].sum(dim=-2)
        counts = repeat_kv(mid_l, scores.shape[1] // mid_l.shape[1])
        scores = scores * counts[:, :, : scores.shape[2]]
        return scores[:, :, :-1] + scores[:, :, 1:]

    def _merge_scores(self, attn, mid_l):
        """Return one score per adjacent middle pair; the lowest scores are merged first.

        The pair attention mass (summed over heads) is divided by the position decay
        ``exp(idx / _GAMMA)`` (idx starting at 1) and multiplied by the pair's total
        token count, summed over kv heads of batch element 0.
        """
        weight = self._pair_attention(attn, mid_l).sum(dim=1).squeeze(0)
        counts = mid_l.sum(dim=-2)
        counts = (counts[:, :-1] + counts[:, 1:])[0]
        positions = torch.arange(1, len(weight) + 1).float()
        decay = torch.exp(positions / self._GAMMA).to(weight.device)
        return weight / decay * counts

    def _select_pairs(self, scores, seq_len):
        """Return indices of the ``seq_len - cache_size`` lowest ``scores``."""
        return scores.topk(seq_len - self.cache_size, dim=-1, largest=False).indices

    @staticmethod
    def _keep_mask(length, merge_idx, device):
        """Boolean mask over ``length`` middle slots that drops slot ``j + 1`` for each merged ``j``."""
        keep = torch.ones(length, dtype=torch.bool, device=device)
        keep[merge_idx + 1] = False
        return keep

    @staticmethod
    def _collapse(tensor, merged, merge_idx, keep):
        """Write ``merged`` into ``tensor`` at ``merge_idx`` (dim 2) and keep only ``keep`` slots.

        Out of place: ``tensor`` (usually a view into the caller's cache) is left
        unchanged. ``merged`` is cast to ``tensor``'s dtype.
        """
        source = merged.index_select(2, merge_idx).to(tensor.dtype)
        return tensor.index_copy(2, merge_idx, source)[:, :, keep]

    def _collapse_layer(self, mid_k, mid_v, mid_l, ke, ve, merge_idx, keep):
        """Apply one layer's merges to k, v and the counts ``l``; return the three new middles."""
        le = mid_l[:, :, :-1] + mid_l[:, :, 1:]
        new_l = torch.clip(self._collapse(mid_l, le, merge_idx, keep), max=self._MAX_COUNT)
        new_k = self._collapse(mid_k, ke, merge_idx, keep)
        new_v = self._collapse(mid_v, ve, merge_idx, keep)
        return new_k, new_v, new_l

    def _assemble(self, past_key_values, new_mid):
        """Prepend each layer's untouched first ``start_size`` slots to its compressed middle."""
        return [
            [
                torch.cat([self.k_slice(k, 0, self.start_size), new_k], dim=self.k_seq_dim),
                torch.cat([self.v_slice(v, 0, self.start_size), new_v], dim=self.v_seq_dim),
                torch.cat([l[:, :, : self.start_size], new_l], dim=2),
            ]
            for (k, v, l), (new_k, new_v, new_l) in zip(past_key_values, new_mid)
        ]

    def _middle(self, k, v, l, seq_len):
        """Return the middle (post-``start_size``) slices of k, v and l."""
        return (
            self.k_slice(k, self.start_size, seq_len),
            self.v_slice(v, self.start_size, seq_len),
            l[:, :, self.start_size:seq_len],
        )

    # ---------------------------------------------------------------- public API

    def __call__(self, past_key_values, attns, num_key_value_groups, hessian_diagnoal=None, return_Cache=False):
        """Compress every layer by merging its lowest-scoring adjacent middle pairs.

        Args:
            past_key_values: per-layer ``(k, v)`` pairs or ``(k, v, l)`` triples.
            attns: per-layer attention maps (see ``_pair_attention`` for indexing).
            num_key_value_groups: unused; the group size is derived from shapes.
            hessian_diagnoal: optional per-layer ``(hk, hv)`` tensors laid out like
                ``k``/``v``. When given, merged keys are averages weighted by
                ``hk ** 2``, and the hessian tensors are compressed alongside the cache.
            return_Cache: unused.

        Returns:
            ``None`` if ``past_key_values`` is None. If no compression is needed, the
            (possibly formalized) cache and ``hessian_diagnoal`` unchanged. Otherwise
            a pair ``(cache, hessian)``: ``cache`` is a list of per-layer ``[k, v, l]``
            lists, and ``hessian`` is a list of per-layer ``[hk, hv]`` lists (empty
            when ``hessian_diagnoal`` is None).
        """
        if past_key_values is None:
            return None
        if len(past_key_values[0]) == 2:
            past_key_values = self.formalize_past_key_values(past_key_values)

        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values, hessian_diagnoal

        new_mid = []
        new_hessian_mid = []
        for i, (k, v, l) in enumerate(past_key_values):
            mid_k, mid_v, mid_l = self._middle(k, v, l, seq_len)
            merge_idx = self._select_pairs(self._merge_scores(attns[i], mid_l), seq_len)

            if hessian_diagnoal is None:
                ke = (mid_k[:, :, :-1, :] + mid_k[:, :, 1:, :]) / 2
                ve = (mid_v[:, :, :-1, :] + mid_v[:, :, 1:, :]) / 2
            else:
                mid_hk = self.k_slice(hessian_diagnoal[i][0], self.start_size, seq_len)
                mid_hv = self.v_slice(hessian_diagnoal[i][1], self.start_size, seq_len)
                # NOTE(review): 1e-21 underflows to 0 in fp16; assumes fp32/bf16 hessians.
                epsilon = 1e-21
                hk_sq = mid_hk.pow(2)
                k1, k2 = mid_k[:, :, :-1, :], mid_k[:, :, 1:, :]
                hk1, hk2 = hk_sq[:, :, :-1, :], hk_sq[:, :, 1:, :]
                ke = 1 / (hk1 + hk2 + epsilon) * (k1 * hk1 + k2 * hk2)
                # NOTE(review): values are summed here but averaged in the branch above
                # — confirm which one the patched attention expects.
                ve = mid_v[:, :, :-1, :] + mid_v[:, :, 1:, :]
                hke = mid_hk[:, :, :-1, :] + mid_hk[:, :, 1:, :]
                hve = mid_hv[:, :, :-1, :] + mid_hv[:, :, 1:, :]

            keep = self._keep_mask(mid_k.shape[2], merge_idx, mid_k.device)
            new_mid.append(self._collapse_layer(mid_k, mid_v, mid_l, ke, ve, merge_idx, keep))
            if hessian_diagnoal is not None:
                new_hessian_mid.append((
                    self._collapse(mid_hk, hke, merge_idx, keep),
                    self._collapse(mid_hv, hve, merge_idx, keep),
                ))

        new_cache = self._assemble(past_key_values, new_mid)
        new_hessian = [
            [
                torch.cat([hessian_diagnoal[i][0][:, :, : self.start_size], hk], dim=2),
                torch.cat([hessian_diagnoal[i][1][:, :, : self.start_size], hv], dim=2),
            ]
            for i, (hk, hv) in enumerate(new_hessian_mid)
        ]
        return new_cache, new_hessian

    def min_indices(self, past_key_values, attns, num_key_value_groups, hessian_diagnoal=None, return_Cache=False):
        """Return, per layer, the indices of the adjacent middle pairs that would be merged.

        Pairs are ranked by count-weighted attention mass only. Unlike ``__call__``,
        no position decay or pair-count factor is applied.
        ``num_key_value_groups``, ``hessian_diagnoal`` and ``return_Cache`` are unused.

        Returns:
            ``None`` if ``past_key_values`` is None; one empty list per layer if no
            compression is needed; otherwise one index tensor per layer.
        """
        if past_key_values is None:
            return None
        if len(past_key_values[0]) == 2:
            past_key_values = self.formalize_past_key_values(past_key_values)

        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return [[] for _ in past_key_values]

        indices = []
        for i, (_, _, l) in enumerate(past_key_values):
            mid_l = l[:, :, self.start_size:seq_len]
            weight = self._pair_attention(attns[i], mid_l).sum(dim=1).squeeze(0)
            indices.append(self._select_pairs(weight, seq_len))
        return indices

    def fs(self, past_key_values, attns, q, num_key_value_groups, attn, position):
        """Compress every layer, fitting each merged key by least squares against the queries.

        For a pair ``(k1, k2)`` the merged key ``ke`` solves, in the least-squares
        sense, ``Q @ ke = log((exp(Q @ k1) + exp(Q @ k2)) / 2)``. Here ``Q`` is
        ``q[i]`` averaged down to the kv head count. Merged values are summed and
        counts added. The merged keys keep the cache's original dtype.

        NOTE(review): scores are unscaled (no 1/sqrt(head_dim)) and no rotary
        re-embedding is applied — confirm intended.

        Args:
            q: per-layer query tensors, presumably ``(batch, heads, q_len, head_dim)``.
            num_key_value_groups, attn, position: unused.

        Returns:
            ``None`` if ``past_key_values`` is None; the (possibly formalized) input if
            no compression is needed; otherwise a list of per-layer ``[k, v, l]`` lists.
        """
        if past_key_values is None:
            return None
        if len(past_key_values[0]) == 2:
            past_key_values = self.formalize_past_key_values(past_key_values)

        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values

        new_mid = []
        for i, (k, v, l) in enumerate(past_key_values):
            mid_k, mid_v, mid_l = self._middle(k, v, l, seq_len)
            merge_idx = self._select_pairs(self._merge_scores(attns[i], mid_l), seq_len)

            # Average query heads down to the kv head count so they line up with mid_k
            # (this assignment was previously commented out, leaving ``qi`` undefined).
            qi = compress_kv(q[i], q[i].shape[1] // mid_l.shape[1]).float()
            keys = mid_k.float()
            s1 = torch.matmul(qi, keys[:, :, :-1, :].transpose(-2, -1))
            s2 = torch.matmul(qi, keys[:, :, 1:, :].transpose(-2, -1))
            # log((exp(s1) + exp(s2)) / 2) without overflowing exp.
            target = torch.logaddexp(s1, s2) - math.log(2)
            # pinv(qi) == (qi^T qi)^-1 qi^T when qi has full column rank, and still
            # gives the minimum-norm solution when q_len < head_dim makes it singular.
            ke = torch.matmul(torch.linalg.pinv(qi), target).transpose(-2, -1)
            ve = mid_v[:, :, :-1, :] + mid_v[:, :, 1:, :]

            keep = self._keep_mask(mid_k.shape[2], merge_idx, mid_k.device)
            new_mid.append(self._collapse_layer(mid_k, mid_v, mid_l, ke, ve, merge_idx, keep))

        return self._assemble(past_key_values, new_mid)

    @staticmethod
    def gaussian_kernel(x1, x2, sigma):
        """Return ``exp(-||x1 - x2||^2 / (2 * sigma^2))``, reducing over the last dim.

        Static because it does not use ``self``.
        """
        distance_sq = torch.sum((x1 - x2) ** 2, dim=-1)
        return torch.exp(-distance_sq / (2 * sigma ** 2))

    def key_adjacent(self, past_key_values, attns):
        """Compress every layer, merging each pair as an attention-weighted average.

        Keys and values of a pair are averaged with weights equal to each slot's
        attention mass: query heads are averaged down to the kv head count, then
        summed over dim -2. Where both weights are zero the plain mean is used
        instead of producing NaN.

        Returns:
            ``None`` if ``past_key_values`` is None; the (possibly formalized) input if
            no compression is needed; otherwise a list of per-layer ``[k, v, l]`` lists.
        """
        if past_key_values is None:
            return None
        if len(past_key_values[0]) == 2:
            past_key_values = self.formalize_past_key_values(past_key_values)

        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values

        new_mid = []
        for i, (k, v, l) in enumerate(past_key_values):
            mid_k, mid_v, mid_l = self._middle(k, v, l, seq_len)
            merge_idx = self._select_pairs(self._merge_scores(attns[i], mid_l), seq_len)

            mass = compress_kv(
                attns[i][:, :, :, self.start_size:], attns[i].shape[1] // mid_l.shape[1]
            ).sum(dim=-2)
            a1 = mass[:, :, :-1].unsqueeze(-1)
            a2 = mass[:, :, 1:].unsqueeze(-1)
            total = a1 + a2
            ke = self._weighted_pair(mid_k, a1, a2, total)
            ve = self._weighted_pair(mid_v, a1, a2, total)

            keep = self._keep_mask(mid_k.shape[2], merge_idx, mid_k.device)
            new_mid.append(self._collapse_layer(mid_k, mid_v, mid_l, ke, ve, merge_idx, keep))

        return self._assemble(past_key_values, new_mid)

    @staticmethod
    def _weighted_pair(x, a1, a2, total):
        """Weighted average of adjacent slots of ``x`` (dim 2); plain mean where ``total`` is 0."""
        x1, x2 = x[:, :, :-1, :], x[:, :, 1:, :]
        weighted = (x1 * a1 + x2 * a2) / total
        fallback = ((x1 + x2) / 2).to(weighted.dtype)
        return torch.where(total > 0, weighted, fallback)
