import torch
import torch.nn.functional as F
import torch.nn as nn
import math


# perform qk calculation and get indices
# this version will not update in inference mode
def ema(input_tensor, dim, alpha=None):
    """
    Exponential moving average of `input_tensor` along `dim`.

    The first slice along `dim` seeds the average; every later slice is folded in as
    `alpha * slice + (1 - alpha) * running`. When `alpha` is None it defaults to
    `2 / (N + 1)`, N being the size of `dim`. The result has `dim` removed.
    """
    steps = input_tensor.size(dim)
    smoothing = 2. / (steps + 1) if alpha is None else alpha

    slices = input_tensor.unbind(dim)
    # clone so the result never aliases the storage of the input
    running = slices[0].clone()
    for step in range(1, steps):
        running = smoothing * slices[step] + (1 - smoothing) * running

    return running


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis, i.e. the same result as
    torch.repeat_interleave(x, dim=1, repeats=n_rep): (batch, num_key_value_heads, seqlen, head_dim)
    becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim). With n_rep == 1 the input
    tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a singleton axis after the head axis, broadcast it to n_rep, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)

def merge_kv(key_states, value_states, indices, window_size, merge):
    """
    Fold evicted key/value entries into the retained ones (merge methods in LOOK-M).

    Each dropped entry is matched to the retained key with the highest cosine similarity
    and averaged into that slot; the same slot assignment is applied to the values.

    Args:
        key_states: keys, unpacked as [bsz, num_heads, k_len, head_dim].
        value_states: values, gathered with the same indices as the keys.
        indices: gather index along dim 2 of the entries to keep
            ([bsz, num_heads, topk_len, head_dim] per the original annotations).
        window_size: number of trailing positions that are always retained.
        merge: merge strategy; only "pivot" is implemented.

    Returns:
        (keys, values), each [bsz, num_heads, topk_len + window_size, head_dim], laid out as
        selected entries first and the recent window last. Keys and values share this layout
        so that slot i of the keys always corresponds to slot i of the values.

    Raises:
        ValueError: if `merge` is not a supported strategy.
    """
    # TODO: other merge strategies (average merge, weight merge)
    if merge != "pivot":
        raise ValueError('Merge method not supported')

    bsz, num_heads, k_len, head_dim = key_states.shape

    # kv-selected
    selected_keys = key_states.gather(dim=2, index=indices)  # [bsz, num_heads, topk_len, head_dim]
    selected_values = value_states.gather(dim=2, index=indices)  # [bsz, num_heads, topk_len, head_dim]

    # kv-recent
    recent_keys = key_states[:, :, -window_size:, :]
    recent_values = value_states[:, :, -window_size:, :]

    # kv-drop
    # NOTE(review): membership is tested against the indices of ALL batches and heads at once,
    # so a position counts as dropped only if no head selected it, and positions inside the
    # recent window are not excluded. The original shape comments suggest a per-head drop set
    # without the window was intended -- confirm before changing; behavior is kept as it was.
    positions = torch.arange(k_len, device=key_states.device)
    drop_positions = positions[~torch.isin(positions, indices)]
    drop_indices = drop_positions[None, None, :, None].expand(bsz, num_heads, -1, head_dim)
    k_hh_pruned = key_states.gather(dim=2, index=drop_indices)  # [bsz, num_heads, drop_len, head_dim]
    v_hh_pruned = value_states.gather(dim=2, index=drop_indices)  # [bsz, num_heads, drop_len, head_dim]

    ##### apply merge #####
    # Retained entries. Keys and values MUST be concatenated in the same order, otherwise the
    # key-derived slot indices below would scatter values into the wrong slots.
    k_hh_recent = torch.cat([selected_keys, recent_keys], dim=2)  # [bsz, num_heads, topk_len+window_size, head_dim]
    v_hh_recent = torch.cat([selected_values, recent_values], dim=2)  # [bsz, num_heads, topk_len+window_size, head_dim]

    # cosine similarity between every dropped key and every retained key
    similarity = F.normalize(k_hh_pruned, dim=-1) @ F.normalize(k_hh_recent, dim=-1).transpose(-1, -2)
    max_indices = similarity.max(dim=-1).indices  # [bsz, num_heads, drop_len]

    # pivot merge: average each dropped entry with its best-matching retained entry, then
    # mean-reduce everything that landed on the same slot together with the slot itself.
    merged_indices = max_indices.unsqueeze(-1).repeat(1, 1, 1, head_dim)
    k_hh_selected = torch.gather(input=k_hh_recent, dim=2, index=merged_indices)
    k_hh_merged = (k_hh_pruned + k_hh_selected) / 2
    k_hh_recent = torch.scatter_reduce(input=k_hh_recent, dim=2, index=merged_indices, src=k_hh_merged, reduce='mean', include_self=True)  # include_self=True seems decrease the performance
    v_hh_selected = torch.gather(input=v_hh_recent, dim=2, index=merged_indices)
    v_hh_merged = (v_hh_pruned + v_hh_selected) / 2
    v_hh_recent = torch.scatter_reduce(input=v_hh_recent, dim=2, index=merged_indices, src=v_hh_merged, reduce='mean', include_self=True)

    return k_hh_recent, v_hh_recent


class RestKVCluster():
    """
    Prompt-phase KV-cache compressor.

    Every key position outside the trailing `window_size` positions is scored from the
    attention it receives from the last `window_size` queries (as a / (1 - a)), multiplied by
    the L2 distance between the attention output and that position's value vector (optionally
    after projecting both with `Wo`). The top-scoring positions plus the trailing window are
    kept; optionally the rest is merged into them via `merge_kv`.
    """

    def __init__(self, num_hidden_layers = 32, window_size=64, max_capacity_prompt=256 + 64, kernel_size=5, pooling='avgpool', beta = 20, num_layers = 80, layer_idx=None, merge=None, use_wo=True, use_norm=False, use_ema=False, alpha=0.5, metric_mode="after", tau=1.0, scale=2000, use_pyramid=False):
        # NOTE: `num_layers` is accepted but never read.
        # SnapKV
        self.window_size = window_size  # trailing positions that are always kept
        self.max_capacity_prompt = max_capacity_prompt  # total kept positions, window included
        assert self.max_capacity_prompt - self.window_size > 0

        # PyramidKV
        self.layer_idx = layer_idx  # position of this layer in the per-layer budget schedule
        self.num_hidden_layers = num_hidden_layers
        self.steps = -1  # not read anywhere in this class
        self.beta = beta  # divisor that yields the smallest per-layer budget

        # restkv
        self.kernel_size = kernel_size  # pooling kernel (lower bound for 'adaptive' pooling)
        self.pooling = pooling  # 'adaptive' | 'avgpool' | 'maxpool'; anything else: no pooling
        self.merge = merge  # merge strategy handed to merge_kv, or None to simply evict
        self.use_wo = use_wo  # measure distances after projecting with Wo
        self.use_norm = use_norm  # "before" mode only: divide by the output norm
        self.use_ema = use_ema  # aggregate over window queries with ema() instead of mean
        self.use_pyramid = use_pyramid
        self.alpha = alpha  # smoothing factor passed to ema()
        self.tau = tau  # exponent on the distance ("after" mode with use_wo only)
        self.metric_mode = metric_mode  # "before": score per query then aggregate; else aggregate first
        self.scale = scale  # divisor turning the shift score into kernel size / shift ('adaptive')

    def reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool', merge = None):
        """Overwrite window, capacity, pooling and merge settings; other settings are left untouched."""
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling
        self.merge = merge

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups, Wo):
        """
        Compress the prompt keys/values down to the configured budget.

        Args:
            key_states, value_states: tensors indexed as [bsz, num_heads, q_len, head_dim].
            query_states: unpacked as [bsz, num_heads, q_len, head_dim]; must have the same
                length as `key_states` along dim -2.
            attention_mask: unused; a causal mask for the trailing window is built internally.
            num_key_value_groups: unused.
            Wo: weight reshaped to (-1, num_heads, head_dim) when `use_wo` is set
                (presumably the attention output projection weight -- confirm with the caller).

        Returns:
            (key_states, value_states). The inputs are returned untouched when
            q_len < max_capacity_prompt; otherwise the selected positions followed by the
            trailing window (or the output of `merge_kv` when `merge` is set).
        """
        # check if prefix phase
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape

        # Short prompts already fit: return before doing any budget or attention work.
        if q_len < self.max_capacity_prompt:
            return key_states, value_states

        attn_weights = self._window_attention(query_states, key_states)  # [bsz, num_heads, window_size, q_len]

        if self.metric_mode == "before":
            important_score = self._score_per_query(attn_weights, value_states, Wo, num_heads, head_dim, q_len)
        else:
            important_score = self._score_mean_query(attn_weights, value_states, Wo, num_heads, head_dim)

        attn_cache = self._pool_scores(important_score, attn_weights, num_heads)

        if self.use_pyramid:
            budget = self._pyramid_budget(q_len)
        else:
            budget = self.max_capacity_prompt - self.window_size
        indices = attn_cache.topk(budget, dim=-1).indices  # [bsz, num_heads, budget]
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)    # [bsz, num_heads, budget, head_dim]

        if self.merge is not None:
            key_states, value_states = merge_kv(key_states, value_states, indices, self.window_size, self.merge)
            return key_states, value_states

        k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)  # [bsz, num_heads, budget, head_dim]
        v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)    # [bsz, num_heads, budget, head_dim]
        k_cur = key_states[:, :, -self.window_size:, :] # [bsz, num_heads, window_size, head_dim]
        v_cur = value_states[:, :, -self.window_size:, :]   # [bsz, num_heads, window_size, head_dim]
        key_states = torch.cat([k_past_compress, k_cur], dim = 2)   # [bsz, num_heads, budget + window_size, head_dim]
        value_states = torch.cat([v_past_compress, v_cur], dim = 2) # [bsz, num_heads, budget + window_size, head_dim]
        return key_states, value_states

    def _pyramid_budget(self, q_len):
        """Number of non-window positions this layer keeps under the PyramidKV schedule."""
        base = self.max_capacity_prompt - self.window_size
        min_num = base // self.beta
        max_num = base * 2 - min_num

        # The largest per-layer budget cannot exceed the positions that actually exist.
        if max_num >= q_len - self.window_size:
            max_num = q_len - self.window_size
            min_num = base * 2 - max_num

        if self.num_hidden_layers == 1:
            # A single layer has no other layer to trade budget with (and the step size
            # below would divide by zero), so fall back to the flat budget.
            return base

        steps = (max_num - min_num) // (self.num_hidden_layers - 1)
        return max_num - self.layer_idx * steps

    def _window_attention(self, query_states, key_states):
        """Softmax attention of the last `window_size` queries over all keys, causal inside the window."""
        head_dim = query_states.shape[-1]
        attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim) # [bsz, num_heads, window_size, q_len]
        mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)    # [window_size, window_size]
        mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
        # zero on and below the diagonal, dtype-min above it
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)    # [window_size, window_size]

        attn_weights[:, :, -self.window_size:, -self.window_size:] += mask[None, None, :, :]    # mask the diagonal part
        return nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  # [bsz, num_heads, window_size, q_len]

    def _score_per_query(self, attn_weights, value_states, Wo, num_heads, head_dim, q_len):
        """"before" mode: compute the metric for each window query, then aggregate over queries."""
        attn_outputs = torch.matmul(attn_weights, value_states) # [bsz, num_heads, window_size, head_dim]
        important_score = attn_weights / (1 - attn_weights) # [bsz, num_heads, window_size, q_len]
        if self.use_wo:
            Wo = Wo.reshape(-1, num_heads, head_dim)
            projection_outputs = torch.einsum("bhwd,chd->bhwc", attn_outputs, Wo) # [bsz, num_heads, window_size, C_out]
            projection_values = torch.einsum("bhld,chd->bhlc", value_states, Wo) # [bsz, num_heads, q_len, C_out]
            # Customize the size of each block, which can be adjusted according to available video memory.
            chunk_size = 1024 * 4
            norm_result = []
            for i in range(0, q_len, chunk_size):
                q_len_chunk = slice(i, min(i + chunk_size, q_len))
                norm_chunk = torch.norm(projection_outputs.unsqueeze(3) - projection_values[:, :, q_len_chunk].unsqueeze(2), p=2, dim=-1)  # Calculate the L2 norm of each block.
                norm_result.append(norm_chunk)

            important_score *= torch.cat(norm_result, dim=-1)
            if self.use_norm:
                important_score /= torch.norm(projection_outputs, p=2, dim=-1).unsqueeze(-1) # [bsz, num_heads, window_size, q_len]
        else:
            important_score *= torch.norm((attn_outputs.unsqueeze(3) - value_states.unsqueeze(2)), p=2, dim=-1) # [bsz, num_heads, window_size, q_len]
            if self.use_norm:
                important_score /= torch.norm(attn_outputs, p=2, dim=-1).unsqueeze(-1) # [bsz, num_heads, window_size, q_len]

        # only positions before the window compete for the budget
        past_scores = important_score[:, :, -self.window_size:, : -self.window_size]
        if self.use_ema:
            return ema(past_scores, dim=-2, alpha=self.alpha) # [bsz, num_heads, q_len - window_size]
        return past_scores.mean(dim = -2) # [bsz, num_heads, q_len - window_size]

    def _score_mean_query(self, attn_weights, value_states, Wo, num_heads, head_dim):
        """Default mode: aggregate the attention over window queries first, then compute the metric once."""
        past_weights = attn_weights[:, :, -self.window_size:, : -self.window_size]
        if self.use_ema:
            attn_weights_mean = ema(past_weights, dim=-2, alpha=self.alpha)  # [bsz, num_heads, q_len - window_size]
        else:
            attn_weights_mean = past_weights.mean(dim=-2)    # [bsz, num_heads, q_len - window_size]
        attn_outputs_mean = torch.einsum("bhl,bhld->bhd", attn_weights_mean, value_states[:, :, :-self.window_size]) # [bsz, num_heads, head_dim]
        important_score = attn_weights_mean / (1 - attn_weights_mean)   # [bsz, num_heads, q_len - window_size]
        if self.use_wo:
            Wo = Wo.reshape(-1, num_heads, head_dim)
            projection_outputs = torch.einsum("bhd,chd->bhc", attn_outputs_mean, Wo) # [bsz, num_heads, C_out]
            projection_values = torch.einsum("bhld,chd->bhlc", value_states, Wo) # [bsz, num_heads, q_len, C_out]
            important_score *= torch.pow(torch.norm((projection_outputs.unsqueeze(2) - projection_values[:, :, :-self.window_size]), p=2, dim=-1), self.tau)   # [bsz, num_heads, q_len - window_size]
        else: # Only consider value states
            important_score *= torch.norm((attn_outputs_mean.unsqueeze(2) - value_states[:, :, :-self.window_size]), p=2, dim=-1)   # [bsz, num_heads, q_len - window_size]
        return important_score

    def _pool_scores(self, important_score, attn_weights, num_heads):
        """Smooth the per-position scores along the sequence according to `self.pooling`."""
        if self.pooling == 'adaptive':
            # Per head: mean difference between the argsort indices of the most-attended keys
            # for the first and the second half of the window queries (batch element 0 only).
            t_shift_score = (attn_weights[0,:,:self.window_size//2].sort()[1][:,:,self.window_size-self.max_capacity_prompt:] - attn_weights[0,:,self.window_size//2:].sort()[1][:,:,self.window_size-self.max_capacity_prompt:]).float().mean(dim=(1,2))
            kernel_sizes = 2 * (t_shift_score.abs() // self.scale) + 1
            shifts = torch.where(t_shift_score > 0, t_shift_score // self.scale, t_shift_score // self.scale + 1)
            tmp = []
            for i in range(num_heads):
                kernel_size = max(self.kernel_size, int(kernel_sizes[i]))
                shift = int(shifts[i])
                # asymmetric padding moves the pooling window by `shift` positions
                padding_left = kernel_size//2 - shift
                padding_right = kernel_size//2 + shift
                padded = F.pad(important_score[:,i], (padding_left, padding_right), mode='replicate')
                tmp.append(F.avg_pool1d(padded, kernel_size=kernel_size, stride=1))
            return torch.stack(tmp, dim=1)
        if self.pooling == 'avgpool':
            return F.avg_pool1d(important_score, kernel_size=self.kernel_size, padding=self.kernel_size//2, stride=1)  # [bsz, num_heads, q_len - window_size]
        if self.pooling == 'maxpool':
            return F.max_pool1d(important_score, kernel_size=self.kernel_size, padding=self.kernel_size//2, stride=1)  # [bsz, num_heads, q_len - window_size]
        return important_score

    
def init_restkv(self, num_hidden_layers):
    """
    Build a fresh `RestKVCluster` and attach it to `self` as `self.kv_cluster`.

    `self` is expected to expose `config` and `layer_idx`. On the first call (no
    `kv_cluster` attribute yet) any compression setting missing from `self.config` is
    written back to the config with a default, so every field read below is guaranteed to
    exist. The defaults for the restkv-specific fields mirror those of
    `RestKVCluster.__init__`. The cluster itself is rebuilt on every call.

    Args:
        num_hidden_layers: number of layers, forwarded to the cluster for the pyramid
            budget schedule.
    """
    if not hasattr(self, "kv_cluster"):
        config_defaults = {
            'window_size': 32,
            'max_capacity_prompt': 4096,
            'merge': None,
            'kernel_size': 5,
            'pooling': 'avgpool',
            'use_wo': True,
            'use_ema': False,
            'use_pyramid': False,
            'alpha': 0.5,
        }
        for field, default in config_defaults.items():
            if not hasattr(self.config, field):
                setattr(self.config, field, default)

    self.kv_cluster = RestKVCluster(
        num_hidden_layers = num_hidden_layers,
        layer_idx = self.layer_idx,
        window_size = self.config.window_size,
        max_capacity_prompt = self.config.max_capacity_prompt,
        kernel_size = self.config.kernel_size,
        pooling = self.config.pooling,
        merge = self.config.merge,
        use_wo = self.config.use_wo,
        use_ema = self.config.use_ema,
        use_pyramid = self.config.use_pyramid,
        alpha = self.config.alpha,
        # the remaining knobs are fixed here rather than read from the config
        metric_mode = "after",
        use_norm = False,
        scale = 2000,
        tau = 1.0,
    )

