import enum
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding, apply_rotary_pos_emb, repeat_kv, Qwen2MLP, Qwen2RMSNorm, Qwen2DecoderLayer, Qwen2Model, Qwen2ForCausalLM

from itertools import combinations

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
from scipy.stats import spearmanr

import numpy as np

logger = logging.get_logger(__name__)

class AblatedQwen2Attention(nn.Module):
    """
    Eager multi-headed attention for Qwen2 that can additionally mask out a span of key positions.

    ``ignore_start`` / ``ignore_end`` are plain attributes (``None`` by default). When both are set,
    the attention scores for key positions ``ignore_start:ignore_end`` are overwritten with the most
    negative finite value of the score dtype before the softmax. Those keys therefore get effectively
    zero weight, unless every key a query can see lies inside the span.
    """

    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)

        # Key-position span [ignore_start, ignore_end) to mask; None disables the ablation.
        self.ignore_start = None
        self.ignore_end = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute eager scaled-dot-product attention, optionally masking the ignore span.

        Returns ``(attn_output, attn_weights or None, past_key_value)``. The weights are only
        returned when ``output_attentions`` is true; they have shape
        ``(bsz, num_heads, q_len, kv_len)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        if self.ignore_start is not None and self.ignore_end is not None:
            # Mask the ignored key span on the freshly computed scores instead of writing into
            # `attention_mask`. That tensor comes from the caller (and is presumably shared with the
            # other layers), so an in-place write leaks the ablation outside this call.
            # finfo(dtype).min stays finite in fp16/bf16. A literal -3.4e38 overflows to -inf there,
            # which turns fully-masked rows into NaN after the softmax.
            attn_weights[:, :, :, self.ignore_start:self.ignore_end] = torch.finfo(attn_weights.dtype).min

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
    
    
class AblatedQwen2DecoderLayer(Qwen2DecoderLayer):
    """Qwen2 decoder layer whose self-attention module is an ``AblatedQwen2Attention``.

    After the parent constructor runs, the attention, MLP and both RMSNorm sub-modules are
    (re)assigned here.
    """

    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width

        # Per the warning text, sliding-window attention is only implemented for flash_attention_2.
        sliding_window_unsupported = (
            config.sliding_window and config._attn_implementation != "flash_attention_2"
        )
        if sliding_window_unsupported:
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

        # Sub-modules are created in the same order as before so parameter initialisation is unchanged.
        self.self_attn = AblatedQwen2Attention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(width, eps=eps)
        self.post_attention_layernorm = Qwen2RMSNorm(width, eps=eps)
    

class AblatedQwen2Model(Qwen2Model):
    """
    Qwen2 decoder stack made of *config.num_hidden_layers* [`AblatedQwen2DecoderLayer`] layers.

    Args:
        config: Qwen2Config
    """

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        hidden = config.hidden_size
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Module creation order is kept (embeddings, layers, norm, rotary) so initialisation matches.
        self.embed_tokens = nn.Embedding(config.vocab_size, hidden, self.padding_idx)
        ablated_layers = [
            AblatedQwen2DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(ablated_layers)
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(hidden, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()


class AblatedQwen2ForCausalLM(Qwen2ForCausalLM):
    """Qwen2 causal LM that can mask a span of key positions in every attention layer.

    The span ``[ignore_start, ignore_end)`` is pushed into every layer's ``self_attn`` by
    :meth:`update_model`. Apart from the ``None`` defaults in ``__init__``, nothing in this class
    assigns ``ignore_start`` / ``ignore_end`` or ``tokenizer``; the caller is expected to set them.

    Masking is applied in two situations:
    - For ``max_ignore_step`` forward calls after the greedy next token is one of ``wait_token``.
    - When the attention score from :meth:`attn_score` is not ``<= attn_score_threshold``. The step
      is then recomputed with the span masked.

    Several helpers assume batch size 1: they call ``.item()`` and use ``input_ids[0]``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Swap in the decoder stack built from AblatedQwen2DecoderLayer.
        self.model = AblatedQwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Key-position span to mask; None disables masking and attention scoring.
        self.ignore_start = None
        self.ignore_end = None
        # Number of remaining forward calls that run with the span masked.
        self.ignore_step = 0
        # Must be assigned before insert_wait_list() / attn_score() are used.
        self.tokenizer = None
        # Surface forms whose first token id (see insert_wait_list) arms the masking window.
        self.wait_list = [
            "wait", "Wait", " wait", " Wait", "wait ", "Wait ",
            "hmm", "Hmm", " hmm", " Hmm", "hmm ", "Hmm ",
            "but", "But", " but", " But", "but ", "But ",
            "however", "However", " however", " However", "however ", "However ",
            "alternative", "Alternative", " alternative", " Alternative", "alternative ", "Alternative ",
            "alternatively", "Alternatively", " alternatively", " Alternatively", "alternatively ", "Alternatively ",
            "another", "Another", " another", " Another", "another ", "Another ",
            "check", "Check", " check", " Check", "check ", "Check ",
            "oh", "Oh", " oh", " Oh", "oh ", "Oh ",
            "maybe", "Maybe", " maybe", " Maybe", "maybe ", "Maybe ",
            "verify", "Verify", " verify", " Verify", "verify ", "Verify ",
            "other", "Other", " other", " Other", "other ", "Other ",
            "again", "Again", " again", " Again", "again ", "Again ",
            "now", "Now", " now", " Now", "now ", "Now ",
            "ah", "Ah", " ah", " Ah", "ah ", "Ah ",
            "any", "Any", " any", " Any", "any ", "Any ",
            "user", "User", " user", " User", "user ", "User ",
            "so", "So", " so", " So", "so ", "So ",
            "i", "I", " i", " I", "i ", "I "
        ]
        self.wait_token = []
        self.max_ignore_step = 10

        # Token ids of the current sequence, maintained by attn_score().
        self.before_token = []
        # Not used anywhere in this class.
        self.before_attention = None
        # Initialize weights and apply final processing
        self.post_init()
        # Token positions grouped into separator-delimited segments, maintained by attn_score().
        self.sentence_sep = [[]]
        self.sep_list = [',', ';', ':', '?', '!', '.', "<think>"]
        self.attn_score_threshold = 0

    def update_model(self, ignore_start, ignore_end):
        """Set the masked key span on every decoder layer's attention (None, None disables it)."""
        for layer in self.model.layers:
            layer.self_attn.ignore_start = ignore_start
            layer.self_attn.ignore_end = ignore_end

    def insert_wait_list(self):
        """Add the first token id of every ``wait_list`` entry to ``wait_token``, deduplicated.

        Special tokens are not added during tokenization. The first id is therefore the word's own
        first sub-token whether or not the tokenizer normally prepends a BOS token.
        """
        encoded = self.tokenizer(self.wait_list, add_special_tokens=False)["input_ids"]
        first_ids = [ids[0] for ids in encoded if ids]
        self.wait_token = list(set(self.wait_token) | set(first_ids))

    def spearman(self, attention_matrix):
        """Mean Spearman rank correlation between each row and the column-wise mean over rows.

        Args:
            attention_matrix: 2-D array with one row per layer and one column per segment.

        Returns:
            The average correlation coefficient as a float. Rows whose coefficient is undefined
            (e.g. constant rows) are skipped. Returns 0.0 if no coefficient is defined, including
            when there are fewer than two columns.
        """
        if attention_matrix.shape[1] < 2:
            return 0.0
        layer_mean = np.mean(attention_matrix, axis=0)
        coefficients = []
        for layer_row in attention_matrix:
            # spearmanr returns (statistic, pvalue); only the statistic may be averaged.
            coefficient = spearmanr(layer_row, layer_mean)[0]
            if np.isfinite(coefficient):
                coefficients.append(coefficient)
        if not coefficients:
            return 0.0
        return float(np.mean(coefficients))

    def spacing(self, attention_matrix, k_percent=0.25):
        """Score in (0, 1] for how evenly the top-k columns of the row-averaged profile are spaced.

        ``k = max(1, int(k_percent * n_cols))``. A score of 1 means the gaps between the k peak
        columns equal ``n_cols / k``. Returns 0.0 when there are fewer than two peaks.
        """
        layer_avg = np.mean(attention_matrix, axis=0)
        k = max(1, int(k_percent * len(layer_avg)))
        peak_indices = np.sort(np.argsort(layer_avg)[-k:])
        if len(peak_indices) < 2:
            return 0.0
        actual_spacing = np.diff(peak_indices)
        ideal_spacing = len(layer_avg) / k
        normalized_diff = np.abs(actual_spacing - ideal_spacing) / (ideal_spacing + 1e-10)
        return 1 / (1 + np.mean(normalized_diff))

    def remove_column(self, a, percent=0.01):
        """Drop the ``max(1, int(percent * n_cols))`` columns with the highest mean.

        The remaining columns keep their original order. An empty array is returned unchanged.
        """
        if a.size == 0:
            return a
        col_means = np.mean(a, axis=0)
        k = max(1, int(percent * a.shape[1]))
        top_cols = np.argsort(col_means)[-k:]
        remaining_cols = np.setdiff1d(np.arange(a.shape[1]), top_cols)
        return a[:, remaining_cols]

    def normalize(self, x, percentile=99):
        """Clip ``x`` at its ``percentile``-th percentile, then min-max scale it to [0, 1].

        An all-zero ``x`` is returned unchanged. If the clip value equals the minimum, the input has
        no spread to normalize and zeros are returned; previously this divided 0/0 and gave NaN.
        """
        if np.all(x == 0):
            return x
        max_clip = np.percentile(x, percentile)
        low = x.min()
        span = max_clip - low
        if span <= 0:
            return np.zeros_like(x, dtype=float)
        return (np.minimum(x, max_clip) - low) / span

    def get_avg_attention(self, outputs):
        """Average each layer's attention over batch and heads and stack the layers.

        Assumes each entry of ``outputs.attentions`` is ``(batch, heads, q_len, kv_len)``, as
        produced by AblatedQwen2Attention. Returns a float32 array of shape
        ``(num_layers, q_len, kv_len)``.
        """
        per_layer = [
            # .float() because torch cannot convert bfloat16 tensors to NumPy.
            layer_att.mean(dim=[0, 1]).detach().float().cpu().numpy()
            for layer_att in outputs.attentions
        ]
        return np.stack(per_layer, axis=0)

    def check_sep(self, text):
        """Return True if ``text`` contains any separator from ``sep_list``."""
        return any(sep in text for sep in self.sep_list)

    def process_matrix(self, return_data):
        """Average the columns of ``return_data`` within each ``sentence_sep`` group.

        Returns an array of shape ``(return_data.shape[0], len(sentence_sep))``. An empty group gets
        an all-zero column instead of NaN from ``mean([])``, so the caller's all-zero filter drops it.

        Raises:
            ValueError: if a group references a column outside ``return_data``.
        """
        num_cols = return_data.shape[1]
        new_matrix = np.zeros((return_data.shape[0], len(self.sentence_sep)))
        for i, col_indices in enumerate(self.sentence_sep):
            if not all(0 <= idx < num_cols for idx in col_indices):
                # Message reads: "column indices {col_indices} are out of the matrix range".
                raise ValueError(f"列索引 {col_indices} 超出矩阵范围")
            if not col_indices:
                continue
            new_matrix[:, i] = np.mean(return_data[:, col_indices], axis=1)
        return new_matrix

    def attn_score(self, outputs, input_ids):
        """Score the last query's attention over separator-delimited segments.

        Updates ``before_token`` and ``sentence_sep``. A multi-token ``input_ids`` (prefill) resets
        both; a single token appends to them. The score is then computed as follows:
        1. Each layer's last-query attention row is percentile-normalized.
        2. The row is averaged per segment.
        3. All-zero segments and the most-attended segments (top 1%, at least one) are dropped.
        4. The result is ``spacing * (mean Spearman correlation + 1) / 2``.

        Assumes batch size 1 (only ``input_ids[0]`` is used).
        """
        # (num_layers, q_len, kv_len); q_len is 1 during single-token decoding.
        attentions = self.get_avg_attention(outputs)
        if input_ids.shape[1] > 1:
            self.before_token = input_ids[0].detach().cpu().tolist()
        else:
            self.before_token.append(input_ids[0, 0].item())
        attentions = attentions[:, -1, :]
        return_data = np.apply_along_axis(self.normalize, 1, attentions)

        if input_ids.shape[1] > 1:
            self.sentence_sep = [[]]
            for i, each in enumerate(self.before_token):
                # A separator token starts a new segment.
                if self.check_sep(self.tokenizer.decode([each])):
                    self.sentence_sep.append([i])
                else:
                    self.sentence_sep[-1].append(i)
        else:
            last_pos = len(self.before_token) - 1
            if self.check_sep(self.tokenizer.decode([self.before_token[-1]])):
                self.sentence_sep.append([last_pos])
            else:
                self.sentence_sep[-1].append(last_pos)

        return_data = self.process_matrix(return_data)

        non_zero_cols = ~np.all(return_data == 0, axis=0)
        return_data = return_data[:, non_zero_cols]

        return_data = self.remove_column(return_data)
        spacing_data = self.spacing(return_data)
        spearman_data = (self.spearman(return_data) + 1) / 2
        return spacing_data * spearman_data

    def _run_model(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
        inputs_embeds,
        use_cache,
        output_hidden_states,
        cache_position,
    ):
        """Run the decoder stack and return a ModelOutput that always includes attentions.

        ``attn_score`` needs the attentions. Requesting them for the re-run too means both passes
        call the decoder with identical arguments.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=True,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )

    @staticmethod
    def _greedy_token_id(logits):
        """Return the argmax token id at the last position (batch size 1 assumed)."""
        probs = torch.softmax(logits[:, -1, :].detach().cpu(), dim=-1)
        return torch.argmax(probs, dim=-1).item()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run one generation step, possibly recomputed with the ignore span masked.

        Step by step:
        1. The decoder runs with the span masked only if ``ignore_step > 0``; ``ignore_step`` is
           then decremented.
        2. If ``ignore_start`` is set and ``input_ids`` is given, :meth:`attn_score` scores the
           attentions.
        3. If that score is not ``<= attn_score_threshold`` (a NaN score included), the cache
           entries added by the first pass are cropped and the step is recomputed with the span
           masked. Cropping only applies to caches that have ``crop``.
        4. If the greedy next token is in ``wait_token``, masking is armed for the next
           ``max_ignore_step`` calls.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:
            `CausalLMOutputWithPast`, or its tuple form when ``return_dict`` is false. Per-layer
            attentions are always included because the decoder always runs with
            ``output_attentions=True``. The ``output_attentions`` argument is accepted only for
            API compatibility.

        Example:

        ```python
        >>> from transformers import AutoTokenizer

        >>> model = AblatedQwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
        >>> model.tokenizer = tokenizer
        >>> model.insert_wait_list()

        >>> inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
        )

        # Arm or disarm the span mask for this step.
        if self.ignore_step > 0:
            self.update_model(self.ignore_start, self.ignore_end)
            self.ignore_step -= 1
        else:
            self.update_model(None, None)

        outputs = self._run_model(**model_kwargs)
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(outputs[0][:, -num_logits_to_keep:, :])

        # attn_score tracks token ids, so scoring is skipped when only embeddings were given.
        if self.ignore_start is not None and input_ids is not None:
            attn_score = self.attn_score(outputs, input_ids)
        else:
            attn_score = 0

        # Written as `not <=` so that a NaN score also triggers the masked re-run.
        if not (attn_score <= self.attn_score_threshold):
            # Drop the cache entries the first pass appended. A legacy tuple or a None cache was not
            # mutated, and a cache without `crop` is assumed to be overwritten in place.
            source = input_ids if input_ids is not None else inputs_embeds
            if past_key_values is not None and hasattr(past_key_values, "crop"):
                past_key_values.crop(-source.shape[1])
            self.update_model(self.ignore_start, self.ignore_end)
            outputs = self._run_model(**model_kwargs)
            logits = self.lm_head(outputs[0][:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if self._greedy_token_id(logits) in self.wait_token:
            self.ignore_step = self.max_ignore_step

        result = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        # to_tuple() drops None fields, matching the (loss?, logits, ...) tuple convention.
        return result if return_dict else result.to_tuple()