#### from https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_hh/modify_llama.py
# from https://github1s.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py
from typing import Optional, Tuple, Union

import copy
import warnings

import torch
from torch import nn
import torch.nn.functional as F

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb, repeat_kv

from algorithms.utils import apply_rotary_pos_emb_single
# from algorithms.kcentergreedy import kcenter_greedy, get_tensors_from_ids
from algorithms.kmeans import MultiKMeans, _kpp
from algorithms.kcenter import kcenter_greedy


class KVCacheHyper:
    """Compress a KV cache once it overflows, by clustering its older part.

    When the cached sequence first grows past ``cache_size``, every position
    except the last ``recent_size`` ones is replaced by ``num_clusters``
    representative key/value pairs chosen by ``method``. The recent window is
    kept verbatim. Compression happens at most once; afterwards the cache is
    returned unchanged (and keeps growing) until ``_clean`` is called.

    Args:
        num_clusters: number of representative positions kept for the old part.
        recent_size: number of most recent positions kept uncompressed.
        cache_size: sequence length above which compression is triggered.
        layer_idx: index of the owning layer (only used in ``__repr__``).
        method: ``'kcenter'``, ``'weighted_kcenter'`` or ``'kmeans'``.
    """

    def __init__(self, num_clusters, recent_size, cache_size, layer_idx=-1, method='kcenter'):
        self.recent_size = recent_size
        self.layer_idx = layer_idx
        self.num_clusters = num_clusters
        self.cache_size = cache_size
        # NOTE(review): the original code hinted at num_clusters + recent_size == cache_size,
        # but this is not enforced.
        self.method = method
        # Per-center cluster sizes ('weighted_kcenter' only); the attention layer reads
        # this to add log(cluster size) to the logits of the center positions.
        self.cluster_sizes = None
        # Radius reported by kcenter_greedy ('weighted_kcenter' only).
        self.radius = None
        # Cluster-index template built by the 'kmeans' method.
        self.arranged_mask = None
        # True once the current sequence's cache has been compressed.
        self.is_preprocessed = False

    def _gather_values(self, values, center_ids):
        """Select the value vectors at the chosen center indices along the sequence axis."""
        bsz, num_heads, _, head_dim = values.shape
        index = center_ids.unsqueeze(-1).expand(bsz, num_heads, self.num_clusters, head_dim)
        return torch.gather(values, 2, index)

    def __call__(self, past_key_values):
        """Return ``past_key_values``, compressed if it has just overflowed.

        Args:
            past_key_values: ``(key, value)`` tensors unpacked as
                ``(bsz, num_heads, seq_len, head_dim)``, or ``None``.

        Returns:
            ``None`` for ``None`` input. The input object itself when no
            compression is due. Otherwise a new ``(key, value)`` tuple whose
            sequence axis is ``num_clusters`` centers followed by the recent
            window.

        Raises:
            NotImplementedError: if ``method`` is unknown and compression is due.
        """
        if past_key_values is None:
            return None
        key_cached, value_cached = past_key_values[0], past_key_values[1]
        seq_len = key_cached.size(2)
        if seq_len <= self.cache_size or self.is_preprocessed:
            return past_key_values

        # Use an explicit split index: the slices ``[-recent_size:]`` / ``[:-recent_size]``
        # would select everything / nothing when recent_size == 0.
        split = max(seq_len - self.recent_size, 0)
        if split < self.num_clusters:
            # Fewer old positions than requested centers: nothing sensible to cluster.
            return past_key_values

        key_old, value_old = key_cached[:, :, :split, :], value_cached[:, :, :split, :]
        key_recent, value_recent = key_cached[:, :, split:, :], value_cached[:, :, split:, :]

        if self.method == 'kcenter':
            key_centers, center_ids = kcenter_greedy(key_old, self.num_clusters, return_radius=False)
            value_centers = self._gather_values(value_old, center_ids)

        elif self.method == 'weighted_kcenter':
            key_centers, center_ids, cluster_sizes, radius = kcenter_greedy(
                key_old, self.num_clusters, return_radius=True
            )
            value_centers = self._gather_values(value_old, center_ids)
            self.cluster_sizes = cluster_sizes
            self.radius = radius

        elif self.method == 'kmeans':
            kmeans = MultiKMeans(n_clusters=self.num_clusters, mode='euclidean', init_method='kmeans++')
            # NOTE(review): only batch element 0 is clustered, so this path effectively
            # assumes bsz == 1 (for bsz > 1 the concatenation below fails on dim 0).
            labels = kmeans.fit_predict(key_old[0].float())
            key_centers = kmeans.centroids.unsqueeze(0).to(dtype=key_cached.dtype)

            # One-hot membership mask (heads, clusters, positions), L1-normalized per
            # cluster so that each value center is the mean of its members' values
            # (an empty cluster yields a zero vector).
            expanded_labels = labels[:, None].expand(-1, self.num_clusters, -1)
            self.arranged_mask = torch.arange(self.num_clusters, device=value_cached.device)[None, :, None]
            mask = (expanded_labels == self.arranged_mask).to(value_cached.dtype)
            mask = torch.nn.functional.normalize(mask, p=1, dim=-1)
            torch.nan_to_num_(mask)
            value_centers = mask @ value_old
        else:
            raise NotImplementedError

        self.is_preprocessed = True
        return (
            torch.cat((key_centers, key_recent), axis=2),
            torch.cat((value_centers, value_recent), axis=2),
        )

    def _clean(self):
        """Reset all per-sequence state so the next sequence starts uncompressed."""
        self.is_preprocessed = False
        self.arranged_mask = None
        # Stale cluster sizes would otherwise bias the next sequence's attention logits.
        self.cluster_sizes = None
        self.radius = None

    def __repr__(self):
        return f"KVCacheHyper(\n"+\
            f"  layer_id: {self.layer_idx}\n"+\
            f"  is_preprocessed: {self.is_preprocessed}\n"+\
            f"  method: {self.method}\n"+\
            f"  num_clusters: {self.num_clusters}\n"+\
            f"  recent_size: {self.recent_size}\n"+\
            f")"


class LlamaAttentionHyper(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    LLaMA attention with a clustered KV cache (``KVCacheHyper``). Keys are
    cached *before* the rotary embedding is applied. On every call, each key
    receives the rotary position of its index inside the (possibly compressed)
    cache, and a single-token decode query is placed at the last cache index.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.layer_idx = layer_idx

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

        # The config carries the compression settings; hh_size is used as the number of clusters.
        self.kv_cache = KVCacheHyper(
            num_clusters=config.hh_size,
            recent_size=config.recent_size,
            cache_size=config.cache_size,
            layer_idx=layer_idx,
            method=config.method
        )

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``."""
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                # Local import: this class is only needed for dynamic NTK scaling.
                from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding

                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")


    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _clean_cache(self):
        """Reset the per-sequence state of the compressed KV cache."""
        self.kv_cache._clean()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend over the cached plus new keys/values and return the updated cache.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``. ``past_key_value``
            is the un-rotated ``(key, value)`` cache, compressed by ``self.kv_cache``
            when it overflows, or ``None`` if ``use_cache`` is False.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        ### Shift Pos: the query position is its index inside the (possibly compressed) cache.
        if position_ids is None:
            position_ids = torch.arange(
                kv_seq_len - q_len, kv_seq_len, dtype=torch.long, device=hidden_states.device
            ).unsqueeze(0)
        elif position_ids.nelement() <= 1:
            # Single-token decode step: put the query at the last cache slot. Build a new
            # tensor instead of writing into the caller's (shared) position_ids.
            position_ids = torch.full_like(position_ids, kv_seq_len - 1)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
        ###

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # Cache keys *before* rotary so positions can be reassigned after compression.
        past_key_value = (key_states, value_states) if use_cache else None

        ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
        key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
        ###

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.head_dim**0.5

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        cluster_sizes = self.kv_cache.cluster_sizes
        if self.config.method == 'weighted_kcenter' and cluster_sizes is not None:
            # The first num_clusters cache slots are cluster centers; weight each by
            # log(cluster size). cluster_sizes is computed on the un-repeated KV heads
            # (presumably (bsz, num_key_value_heads, num_clusters) — TODO confirm), so
            # expand it the same way repeat_kv expands the keys.
            cluster_bias = cluster_sizes.log().to(query_states.dtype)
            if self.num_key_value_groups > 1:
                cluster_bias = cluster_bias.repeat_interleave(self.num_key_value_groups, dim=1)
            attn_weights[:, :, :, :self.kv_cache.num_clusters] += cluster_bias.unsqueeze(2)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        # Compress the cache handed back to the caller (no-op until it overflows).
        past_key_value = self.kv_cache(past_key_value)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value