import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
import warnings
from src.cache_utils import Cache, DynamicCache
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import (
    logging,
)
from src.kv_pruning_utils import init_snapkv
import numpy as np


# Module-level logger from transformers' logging utilities (imported above as
# `logging`); used below for one-time warnings via `logger.warning_once`.
logger = logging.get_logger(__name__)


def recover_cache(key_states_pruned, mask, layer_idx, ratio=0.4):
    """Scatter channel-pruned key states back into a dense, zero-filled tensor.

    ``mask`` marks, per batch element and head, which of the ``head_dim``
    channels were kept. The kept values in ``key_states_pruned`` are written
    with boolean-mask assignment over a ``(batch, heads, seq_len, head_dim)``
    mask, i.e. they are consumed in ``(batch, head, position, kept_channel)``
    row-major order. Every position keeps the same channels, so the sequence
    length is recovered as ``key_states_pruned.shape[-1] // mask.sum()``.

    Args:
        key_states_pruned: Kept key values. Its last dimension must equal
            ``seq_len * mask.sum()`` (presumably a flat 1-D tensor; confirm
            against the cache implementation).
        mask: Tensor of shape ``(batch, heads, head_dim)``; non-zero entries
            mark kept channels. It is converted to ``bool`` before indexing.
        layer_idx: Unused; kept for call-site compatibility.
        ratio: Unused; kept for call-site compatibility.

    Returns:
        Tensor of shape ``(batch, heads, seq_len, head_dim)`` with the same
        dtype and device as ``key_states_pruned``; pruned channels are zero.

    Raises:
        ValueError: If the mask keeps no channels, or the number of pruned
            values is not a multiple of the number of kept channels.
    """
    # A 0/1 integer mask would otherwise be treated as integer (gather-style)
    # indexing instead of a boolean selection.
    mask = mask.to(torch.bool)
    bsz, heads, head_dim = mask.shape

    kept_channels = int(mask.sum())
    num_values = int(key_states_pruned.shape[-1])
    if kept_channels == 0 or num_values % kept_channels:
        raise ValueError(
            f"Cannot recover cache: {num_values} pruned values do not split evenly "
            f"over {kept_channels} kept channels."
        )
    seq_len = num_values // kept_channels

    # Same channel selection at every position of the sequence.
    dense_mask = mask.unsqueeze(2).expand(-1, -1, seq_len, -1)
    recovered_key_states = torch.zeros(
        bsz,
        heads,
        seq_len,
        head_dim,
        dtype=key_states_pruned.dtype,
        device=key_states_pruned.device,
    )
    recovered_key_states[dense_mask] = key_states_pruned
    return recovered_key_states


def calculate_high_dim_contribution_score(
    q_pairs: torch.Tensor,
    k_pairs: torch.Tensor,
    high_dim_ratio: float = 0.5,
    interleaved: bool = False,
) -> torch.Tensor:
    """Fraction of each query-key dot product contributed by the "high" RoPE pairs.

    The element-wise product ``q * k`` is summed within each 2-D RoPE
    rotation pair, giving one contribution per pair (their total is the full
    dot product ``q . k``). The score is the summed contribution of the last
    ``high_dim_ratio`` fraction of pairs divided by the full dot product.

    Assuming transformers' Llama RoPE convention, the rotation frequency
    decreases with pair index, so the last pairs are the low-frequency
    ("high") dimensions.

    Args:
        q_pairs: Query vectors; last dimension is ``head_dim`` (must be even).
        k_pairs: Key vectors, broadcast-compatible with ``q_pairs``.
        high_dim_ratio: Fraction of rotation pairs, counted from the
            high-index end, treated as high dimensions.
        interleaved: Layout of the rotation pairs.
            ``False`` (default) pairs dim ``i`` with dim ``i + head_dim // 2``.
            This is the ``rotate_half`` layout used by transformers'
            ``apply_rotary_pos_emb`` for Llama.
            ``True`` pairs adjacent dims ``(2i, 2i + 1)`` (interleaved layout).

    Returns:
        Tensor with the broadcast shape of the inputs minus the last dim.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    h = q_pairs * k_pairs
    head_dim = h.shape[-1]
    if head_dim % 2:
        raise ValueError(f"head_dim must be even to form RoPE pairs, got {head_dim}")
    half = head_dim // 2

    # g[..., j]: contribution of rotation pair j to the dot product q . k.
    if interleaved:
        g = h.reshape(*h.shape[:-1], half, 2).sum(dim=-1)
    else:
        g = h[..., :half] + h[..., half:]

    split_index = int(half * (1 - high_dim_ratio))
    high_dim_contribution = g[..., split_index:].sum(dim=-1)

    # Signed total, i.e. the full dot product (no absolute value is taken), so
    # the ratio can exceed 1 or be negative when terms cancel.
    total_contribution = g.sum(dim=-1)

    epsilon = 1e-8
    return high_dim_contribution / (total_contribution + epsilon)

def _average_rank(values: torch.Tensor) -> torch.Tensor:
    """Return 0-based fractional ranks of a 1-D tensor; tied values share their mean rank."""
    _, inverse, counts = torch.unique(
        values, sorted=True, return_inverse=True, return_counts=True
    )
    # Each distinct value occupies ranks [end - count + 1, end]; use the midpoint.
    end_rank = torch.cumsum(counts, dim=0) - 1
    mean_rank = end_rank.to(values.dtype) - (counts.to(values.dtype) - 1) / 2
    return mean_rank[inverse]


def spearman_correlation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Spearman rank correlation between two tensors (both are flattened).

    Ranks are tie-aware: tied values receive the average of the ranks they
    span (fractional ranking, as in the standard definition of Spearman's
    rho). Plain ``argsort().argsort()`` would give ties arbitrary distinct
    ranks. The correlation is the Pearson correlation of these ranks.

    Returns:
        A 0-d float tensor. ``0.0`` is returned when either input has zero
        rank variance (e.g. a constant input or fewer than two elements).
    """
    x_rank = _average_rank(x.flatten().float())
    y_rank = _average_rank(y.flatten().float())

    x_centered = x_rank - x_rank.mean()
    y_centered = y_rank - y_rank.mean()

    cov = torch.sum(x_centered * y_centered)
    denom = torch.sqrt(torch.sum(x_centered ** 2) * torch.sum(y_centered ** 2))
    if denom == 0:
        return torch.tensor(0.0, device=x_rank.device)
    return cov / denom

def calculate_high_dim_correlation(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    top_k: int = 100,
    high_dim_ratio: float = 0.5,
    query_offset: Optional[int] = None,
) -> torch.Tensor:
    """Per-head Spearman correlation between query-key distance and high-dim contribution.

    For every query row, the ``top_k`` keys with the largest raw dot-product
    score are selected. Each selected (query, key) pair yields two values:

    * the relative distance ``query_position - key_index``;
    * the score from ``calculate_high_dim_contribution_score``.

    The Spearman correlation of these values, pooled over all query rows and
    selected keys of a head, is returned per (batch, head).

    Args:
        query_states: Tensor of shape ``(bsz, num_heads, q_len, head_dim)``.
        key_states: Tensor of shape ``(bsz, num_heads, k_len, head_dim)``.
            Presumably it has the same head count as the queries (e.g. after
            ``repeat_kv``); the gather below requires matching head dims.
        top_k: Keys selected per query; clipped to ``k_len``.
        high_dim_ratio: Forwarded to ``calculate_high_dim_contribution_score``.
        query_offset: Absolute position of the first query row in the key
            sequence. Defaults to ``k_len - q_len``, i.e. the queries are
            assumed to be the trailing positions of the key sequence (as when
            a caller passes ``query_states[..., -32:, :]`` with all keys).

    Returns:
        Float32 tensor of shape ``(bsz, num_heads)``.
    """
    bsz, num_heads, seq_len, head_dim = query_states.shape
    k_len = key_states.shape[2]
    top_k = min(top_k, k_len)
    if query_offset is None:
        query_offset = k_len - seq_len

    # Raw logits: no 1/sqrt(d) scaling and no mask.
    # NOTE(review): without a causal mask, keys after a query's own position
    # can be selected (negative distances) - confirm this is intended.
    attn_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
    _, topk_indices = torch.topk(attn_scores, k=top_k, dim=-1, sorted=False)

    # Selected key vectors: (bsz, num_heads, seq_len, top_k, head_dim).
    expanded_indices = topk_indices.unsqueeze(-1).expand(-1, -1, -1, -1, head_dim)
    expanded_keys = key_states.unsqueeze(2).expand(-1, -1, seq_len, -1, -1)
    k_pairs = torch.gather(expanded_keys, 3, expanded_indices)
    # Each query row repeated against its top_k keys (broadcast view, no copy).
    q_pairs = query_states.unsqueeze(3).expand(-1, -1, -1, top_k, -1)

    high_dim_scores = calculate_high_dim_contribution_score(
        q_pairs, k_pairs, high_dim_ratio=high_dim_ratio
    )

    # Absolute query positions, so distances are real token distances. The
    # offset is a constant shift, so the rank-based correlation is unaffected.
    q_positions = (
        query_offset + torch.arange(seq_len, device=query_states.device)
    ).view(1, 1, seq_len, 1)
    relative_distances = q_positions - topk_indices

    correlation_scores = torch.zeros((bsz, num_heads), device=query_states.device)
    for b in range(bsz):
        for h in range(num_heads):
            correlation_scores[b, h] = spearman_correlation(
                relative_distances[b, h].flatten(), high_dim_scores[b, h].flatten()
            )

    return correlation_scores


def _allocate_keep_channels(importance_scores, head_dim, p=0.7, lamda=0.2):
    """Convert per-head importance scores into per-head kept-channel counts.

    Steps:

    1. Min-max normalise the scores to ``[0, 2 * lamda]``.
    2. Re-centre them so their mean is ``1 - p``.
    3. Scale by ``head_dim`` and round down to an even channel count.

    Edge cases:

    * If all scores are equal (or the range is non-finite, e.g. NaN), min-max
      normalisation is undefined (0/0 gives NaN, and ``int(nan)`` raises).
      Every head then gets the uniform ratio ``1 - p``.
    * Re-centring can push an outlier head's ratio below zero, so the counts
      are clipped to ``[0, head_dim]``.

    Args:
        importance_scores: One score per head (1-D sequence of floats).
        head_dim: Number of channels per head.
        p: ``1 - p`` is the mean fraction of channels kept per head.
        lamda: Half-width of the normalised score range.

    Returns:
        List of even ints, one kept-channel count per head.
    """
    scores = np.asarray(importance_scores, dtype=np.float64)
    half_dim = head_dim / 2
    score_range = scores.max() - scores.min()
    if np.isfinite(score_range) and score_range > 0:
        ratios = (scores - scores.min()) * (1 / score_range * lamda * 2)
    else:
        ratios = np.zeros_like(scores)
    ratios = ratios - np.mean(ratios) + (1 - p)
    return [min(max(int(r * half_dim) * 2, 0), head_dim) for r in ratios]


def llama_flash_attn2_forward_SnapKV(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """FlashAttention-2 forward for Llama with per-head key-channel budgets.

    Prefill step (the new keys cover the whole ``kv_seq_len``):

    * Compute a per-head score from the last 32 queries
      (``calculate_high_dim_correlation``).
    * Turn the scores into per-head kept-channel counts
      (``self.keep_channels``).
    * Pass keys, queries, values and budgets to
      ``self.kv_cluster.update_thinkv``.
    * Store its outputs via ``past_key_value.update_think``.

    Decoding steps:

    * Scatter the cached pruned keys back to full width with
      ``recover_cache``.
    * Concatenate them with the keys returned by ``past_key_value.update``.

    ``output_attentions`` is forced to ``False``; the attention weights
    returned are always ``None``.

    Returns:
        ``(attn_output, None, past_key_value)``.
    """
    init_snapkv(self)
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop("padding_mask")

    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        # Prefer the layer's own token counter (set on prefill and advanced on
        # each decode step below); fall back to the cache's usable length when
        # the counter is missing or zero.
        if hasattr(self, "kv_seq_len"):
            if self.kv_seq_len != 0:
                kv_seq_len += self.kv_seq_len
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}

        if key_states.shape[-2] == kv_seq_len:
            # Prefill: score heads and allocate per-head channel budgets.
            self.kv_seq_len = kv_seq_len
            positional_head_scores = calculate_high_dim_correlation(
                query_states[..., -32:, :], key_states, top_k=100
            )

            # NOTE(review): squeeze(0) + tolist() assumes batch size 1.
            importance_scores = positional_head_scores.squeeze(0).tolist()
            self.keep_channels = _allocate_keep_channels(
                importance_scores, self.head_dim, p=0.7, lamda=0.2
            )

            kv_pruned, kv_recent, mask, value_states_compress = self.kv_cluster.update_thinkv(
                key_states,
                query_states,
                value_states,
                attention_mask,
                self.num_key_value_groups,
                self.keep_channels,
            )
            past_key_value.update_think(
                kv_pruned, kv_recent, mask, value_states_compress, self.layer_idx, cache_kwargs
            )
        else:
            # Decoding: rebuild full-width keys from the pruned cache entries.
            self.kv_seq_len += q_len
            key_states, value_states, key_pruned, mask = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
            recovered_key_states = recover_cache(key_pruned, mask, self.layer_idx, self.kv_cluster.ratio)
            key_states = torch.cat([recovered_key_states, key_states], dim=-2)

    # FlashAttention expects (bsz, seq_len, num_heads, head_dim).
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = self._flash_attention_forward(
        query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
    )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    # output_attentions is forced to False above, so no weights are returned.
    attn_weights = None

    return attn_output, attn_weights, past_key_value

def prepare_inputs_for_generation_llama(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Build the keyword arguments for a single generation step.

    On the first step (no cache yet), every attention layer's ``kv_seq_len``
    counter is reset to 0. On later steps:

    * tokens already covered by the cache are trimmed from ``input_ids``;
    * the attention mask is cropped when the cache reports a maximum length.

    If ``position_ids`` is not passed, it is derived from the attention mask.
    """
    if past_key_values is None:
        # Fresh generation: reset the per-layer token counters.
        for decoder_layer in self.model.layers:
            decoder_layer.self_attn.kv_seq_len = 0
    else:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # Legacy (non-Cache) past: use the counter kept by the first layer.
            past_length = self.model.layers[0].self_attn.kv_seq_len
            cache_length = past_length
            max_cache_length = None

        num_incoming = input_ids.shape[1]
        if attention_mask is not None and attention_mask.shape[1] > num_incoming:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
        elif past_length < num_incoming:
            input_ids = input_ids[:, past_length:]

        overflows_cache = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if overflows_cache:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if position_ids is None and attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            # Keep only the positions of the tokens fed in this step.
            position_ids = position_ids[:, -input_ids.shape[1]:]

    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}
    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs


def llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Llama decoder-stack forward that initialises SnapKV state first.

    Runs ``init_snapkv(self)``, embeds the inputs, and wraps a non-``Cache``
    past in ``DynamicCache`` when ``use_cache`` is set (converting back to the
    legacy format on return). It then runs every decoder layer and applies the
    final norm.

    NOTE(review): when ``cache_position``/``position_ids`` are not supplied,
    they are derived from ``past_key_values.get_seq_length()``. If the cache
    stores compressed entries, that length may differ from the number of
    tokens seen; confirm callers always pass ``position_ids`` during
    generation.

    Returns:
        ``BaseModelOutputWithPast`` or, if ``return_dict`` is false, a tuple
        of the non-``None`` fields.

    Raises:
        ValueError: If both or neither of ``input_ids``/``inputs_embeds`` are given.
    """
    init_snapkv(self)

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

    hidden_states = inputs_embeds

    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    # (A previously unused per-layer `past_key_values.get_seq_length(layer_idx)`
    # lookup was removed: it crashed with use_cache=False and no past.)
    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )