#### from https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_hh/modify_llama.py
# from https://github1s.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py
from typing import Optional, Tuple, Union

import copy
import torch
from torch import nn
import torch.nn.functional as F

from transformers.models.llama.configuration_llama import LlamaConfig
# from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    LlamaForCausalLM,
    LlamaLinearScalingRotaryEmbedding
)
from flash_attn.flash_attn_interface import flash_attn_func

# from algorithms.utils import apply_rotary_pos_emb_single
# from algorithms.kcentergreedy import kcenter_greedy, get_tensors_from_ids
# from algorithms.kmeans import MultiKMeans, _kpp
from algorithms.kcenter import kcenter_greedy


class KVCacheHyper:
    """One-shot KV-cache compressor: cluster centres of old tokens plus a recent window.

    The first time a cache longer than ``cache_size`` is seen, the tokens that
    precede the last ``recent_size`` positions are reduced to ``num_clusters``
    entries chosen by ``kcenter_greedy`` (run on the keys), and the matching
    values are gathered with the returned indices.  After that single
    compression the instance passes caches through unchanged until
    ``_clean()`` is called.

    Attributes:
        num_clusters: number of centres requested from ``kcenter_greedy``.
        recent_size: number of trailing positions kept verbatim.
        cache_size: caches with at most this many positions are never touched.
        layer_idx: layer number, only used in ``__repr__``.
        method: compression method; only ``'kcenter'`` is implemented.
        dist, seed_select: forwarded verbatim to ``kcenter_greedy``.
        cluster_sizes, arranged_mask: never written by this class; kept as
            ``None`` placeholders because other code may read them.
        is_preprocessed: True once a compression has been performed.
    """

    def __init__(self, num_clusters, recent_size, cache_size, layer_idx=-1, method='kcenter', dist='euclidean', seed_select='greedy'):
        self.recent_size = recent_size
        self.layer_idx = layer_idx
        self.num_clusters = num_clusters
        self.cache_size = cache_size
        self.method = method
        self.cluster_sizes = None
        self.arranged_mask = None
        self.is_preprocessed = False
        self.dist = dist
        self.seed_select = seed_select

    def __call__(self, past_key_values):
        """Return ``past_key_values``, compressed if this is the first oversized cache.

        Args:
            past_key_values: ``None`` or a ``(key, value)`` pair whose sequence
                axis is dimension 2 (assumes both tensors are
                ``(bsz, num_heads, seq_len, head_dim)`` -- the shape is
                unpacked that way below).

        Returns:
            ``None`` when given ``None``; the input object itself when the
            cache fits in ``cache_size`` or a compression already happened;
            otherwise a new ``(key, value)`` tuple holding the cluster centres
            followed by the recent window along dimension 2.

        Raises:
            NotImplementedError: compression is required but ``method`` is not
                ``'kcenter'``.
        """
        if past_key_values is None:
            return None

        key_cached = past_key_values[0]
        seq_len = key_cached.size(2)
        # Small caches and already-compressed caches are passed through as-is.
        if seq_len <= self.cache_size or self.is_preprocessed:
            return past_key_values

        if self.method != 'kcenter':
            raise NotImplementedError(f"Unsupported KV cache compression method: {self.method!r}")

        value_cached = past_key_values[1]
        bsz, num_heads, _, head_dim = key_cached.shape

        # Split with an explicit index instead of negative slicing:
        # `x[-0:]` is the whole tensor and `x[:-0]` is empty, which would
        # swap the two windows when recent_size == 0.
        split = max(seq_len - self.recent_size, 0)
        key_old, key_recent = key_cached[:, :, :split, :], key_cached[:, :, split:, :]
        value_old, value_recent = value_cached[:, :, :split], value_cached[:, :, split:, :]

        key_centers, center_ids = kcenter_greedy(
            key_old,
            self.num_clusters,
            return_radius=False,
            dist=self.dist,
            seed_select=self.seed_select,
        )
        # presumably center_ids is (bsz, num_heads, num_clusters); it is
        # broadcast over head_dim so the values of the chosen tokens are picked.
        gather_index = center_ids.unsqueeze(-1).expand(bsz, num_heads, self.num_clusters, head_dim)
        value_centers = torch.gather(value_old, 2, gather_index)

        self.is_preprocessed = True
        return (
            torch.cat((key_centers, key_recent), dim=2),
            torch.cat((value_centers, value_recent), dim=2),
        )

    def _clean(self):
        """Reset the state so the next oversized cache is compressed again."""
        self.is_preprocessed = False
        self.arranged_mask = None

    def __repr__(self):
        return "\n".join(
            (
                "KVCacheHyper(",
                f"  layer_id: {self.layer_idx}",
                f"  is_preprocessed: {self.is_preprocessed}",
                f"  method: {self.method}",
                f"  num_clusters: {self.num_clusters}",
                f"  recent_size: {self.recent_size}",
                ")",
            )
        )



def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    """Apply rotary position embedding to one tensor (query *or* key).

    ``cos`` and ``sin`` are squeezed on dims 1 then 0, indexed by
    ``position_ids`` and given a new axis at dim 1 before being combined with
    ``x`` and ``rotate_half(x)``.

    Assumes ``x`` is ``(bs, heads, seq_len, dim)`` and ``position_ids`` is
    ``(bs, seq_len)`` -- TODO confirm against callers.
    """
    # squeeze(1)/squeeze(0) only remove those dims when they have size 1.
    cos_table = cos.squeeze(1).squeeze(0)
    sin_table = sin.squeeze(1).squeeze(0)

    # Pick the rows for the requested positions, then add an axis at dim 1
    # so the result broadcasts against x.
    cos_at_pos = cos_table[position_ids].unsqueeze(1)
    sin_at_pos = sin_table[position_ids].unsqueeze(1)

    return (x * cos_at_pos) + (rotate_half(x) * sin_at_pos)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis.

    Same result as ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``:
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With ``n_rep == 1``
    the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then
    # fold it back into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)


class ClustergenLlamaAttention(nn.Module):
    """Llama multi-head attention with a per-layer ``KVCacheHyper`` cache compressor.

    Differences from a plain attention layer that are visible in this class:

    * Keys are stored in the cache *before* rotary embedding is applied; on
      every call all cached keys are rotated with positions
      ``0 .. kv_seq_len - 1`` (their index inside the cache).
    * The cache returned from ``forward`` is passed through ``self.kv_cache``,
      which may compress it.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Build the projections, the rotary embedding and the cache compressor.

        Besides the standard Llama fields, ``config`` must provide ``hh_size``,
        ``recent_size``, ``cache_size``, ``method``, ``dist`` and
        ``seed_select``, which are forwarded to ``KVCacheHyper``.

        Raises:
            ValueError: ``hidden_size`` is not divisible by the number of heads.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.layer_idx = layer_idx

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

        # `hh_size` from the config is used as the number of cluster centres.
        self.kv_cache = KVCacheHyper(
            num_clusters=config.hh_size,
            recent_size=config.recent_size,
            cache_size=config.cache_size,
            layer_idx=layer_idx,
            method=config.method,
            dist=config.dist,
            seed_select=config.seed_select
        )

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

        Raises:
            ValueError: ``rope_scaling["type"]`` is neither ``"linear"`` nor
                ``"dynamic"``.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                # This class is not part of the module-level import list, so
                # it is imported here; without this the branch raised NameError.
                from transformers.models.llama.modeling_llama import (
                    LlamaDynamicNTKScalingRotaryEmbedding,
                )

                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``tensor`` to ``(bsz, num_heads, seq_len, head_dim)`` (contiguous)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _clean_cache(self):
        """Reset the layer's cache compressor so the next long cache is compressed again."""
        self.kv_cache._clean()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention for one layer and return the (possibly compressed) cache.

        Args:
            hidden_states: ``(bsz, q_len, hidden_size)`` input.
            attention_mask: only used by the explicit-matmul branch; it is not
                passed to ``flash_attn_func``.
            position_ids: query positions. A tensor with at most one element
                is overwritten in place with ``kv_seq_len - 1``.
            past_key_value: ``(key, value)`` from the previous step; the keys
                are expected to be un-rotated, as this method returns them.
            output_attentions: when True, attention weights are returned if
                the branch that ran computed them, otherwise ``None``.
            use_cache: when False, ``None`` is handed to the cache compressor
                and returned as the cache.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``.

        Raises:
            ValueError: shape mismatches detected in the explicit-matmul branch.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Tensor-parallel-style projection: split the weights, project per slice, concatenate.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total number of key positions: new tokens plus whatever is in the cache.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # Single-position query: place it at the last slot of the current cache.
        # NOTE(review): this writes into the caller's tensor in place.
        if not position_ids.nelement() > 1:
            position_ids[0][0] = kv_seq_len - 1

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        ### Shift Pos: query pos is min(cache_size, idx)
        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
        ###

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # Cached here, before the rotation below, so the stored keys stay un-rotated.
        past_key_value = (key_states, value_states) if use_cache else None

        ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
        key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
        ###

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # The flash-attention branch never produces weights; start from None so
        # `output_attentions=True` returns None instead of raising UnboundLocalError.
        attn_weights = None

        # NOTE(review): `q_len < 1` is false for every non-empty query, so the
        # flash-attention branch is the one that runs in practice -- confirm
        # this is intentional before changing the condition.
        if q_len < 1:

            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.head_dim**0.5

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask

            # Bias the logits of the first `num_clusters` keys by log(cluster size).
            if self.config.method == 'weighted_kcenter' and self.kv_cache.cluster_sizes is not None:
                attn_weights[:,:,:,:self.kv_cache.num_clusters] += self.kv_cache.cluster_sizes.log().to(query_states.dtype).unsqueeze(2)

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2).contiguous()
        else:
            # flash_attn_func receives tensors with the head and sequence axes swapped.
            attn_output = flash_attn_func(query_states.transpose(1,2), key_states.transpose(1,2), value_states.transpose(1,2), causal=True)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        # Let the compressor shrink the cache (it returns None when given None).
        past_key_value = self.kv_cache(past_key_value)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value



class ClustergenLlamaForCausalLM(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose decoder layers all use ``ClustergenLlamaAttention``."""

    def __init__(self, config):
        """Build the stock model, then replace ``self_attn`` in every decoder layer."""
        super().__init__(config)
        for layer_idx, decoder_layer in enumerate(self.model.layers):
            decoder_layer.self_attn = ClustergenLlamaAttention(config, layer_idx)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        When a cache is present only the last token of ``input_ids`` is kept.
        ``inputs_embeds`` is used only when it is given and no cache exists
        yet; otherwise ``input_ids`` is used. ``position_ids`` and
        ``use_cache`` are taken from ``kwargs`` as they are.
        """
        # With a cache, earlier tokens are already covered by it.
        if past_key_values:
            input_ids = input_ids[:, -1:]

        first_step_with_embeds = inputs_embeds is not None and past_key_values is None
        if first_step_with_embeds:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["position_ids"] = kwargs.get("position_ids", None)
        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        return model_inputs