from typing import Any
from transformers import Qwen2ForCausalLM

import torch
import math
from torch import nn
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    repeat_kv,
    rotate_half,
    Qwen2DecoderLayer,
    Qwen2Model,
    Qwen2ForCausalLM,
    Qwen2Config,
    Cache
)
from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding  # Import the rotary embedding class
from typing import Optional, Tuple, Union

def apply_single_rotary_pos_emb(inputs, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to a single tensor (queries *or* keys).

    Computes ``inputs * cos + rotate_half(inputs) * sin`` after inserting a
    broadcast axis into ``cos`` and ``sin`` at ``unsqueeze_dim``.

    Args:
        inputs: Tensor to rotate. Callers in this file pass tensors laid out
            as ``(batch, heads, seq, head_dim)``.
        cos: Cosine part of the rotary embedding.
        sin: Sine part of the rotary embedding.
        position_ids: Optional integer tensor of positions. It is only used
            when ``cos``/``sin`` are 2-D lookup tables (one row per
            position); the rows for ``position_ids`` are then gathered so the
            result lines up with ``inputs``. Tensors with any other rank are
            assumed to be per-position already and are used unchanged.
        unsqueeze_dim: Axis at which ``cos``/``sin`` get a size-1 dimension
            so they broadcast over the head axis of ``inputs``.

    Returns:
        The rotated tensor, broadcast-compatible with ``inputs``.
    """
    if position_ids is not None:
        # A 2-D table indexed by absolute position must be gathered per token;
        # otherwise the unsqueeze below puts the position axis where the head
        # axis of `inputs` is expected.
        if cos.dim() == 2:
            cos = cos[position_ids]
        if sin.dim() == 2:
            sin = sin[position_ids]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (inputs * cos) + (rotate_half(inputs) * sin)

class Qwen2ModifiedAttention(Qwen2Attention):
    """Explicit matmul/softmax attention that rotates keys after the cache update.

    Keys are written to the KV cache *without* rotary embeddings; the rotary
    embedding is applied to the full key tensor returned by the cache on every
    call. When ``config.use_sel_attn`` is truthy, attention rows are
    additionally restricted to the best-scoring key groups described by
    ``config.group_size`` (see ``_select_key_groups``).

    NOTE(review): this class uses ``past_key_value.get_usable_length`` and
    ``self.rotary_emb(..., seq_len=...)``; confirm both exist in the installed
    transformers version.
    """

    # How many of the largest attention weights are summed to score a group.
    SEL_TOKENS_PER_GROUP = 5
    # How many of the best-scoring groups stay visible to each query row.
    SEL_TOP_GROUPS = 2

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape ``(batch, q_len, hidden)``.
            attention_mask: Optional additive mask; must have shape
                ``(batch, 1, q_len, kv_seq_len)``.
            position_ids: Positions forwarded to the rotary helper for the
                queries (and for the keys when no cache is used).
            past_key_value: Optional cache; updated in place with the
                un-rotated keys and the values of this call.
            output_attentions: When false, ``None`` is returned instead of the
                attention weights.
            use_cache: Accepted for interface compatibility; not read here.
            **kwargs: Extra arguments; only ``padding_mask`` is inspected, to
                emit a deprecation warning.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: If a cache is given but ``self.layer_idx`` is ``None``,
                or if the attention weights, mask or output have unexpected
                shapes.
        """
        if "padding_mask" in kwargs:
            # Function-scope import: `warnings` is not imported at the top of
            # this file, so this path would otherwise fail with NameError.
            import warnings

            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Apply rotary embeddings only to query_states
        query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)

        # Update cache (without applying rotary embeddings to key_states)
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

            # Positions 0..kv_len-1 for every key returned by the cache. The
            # length is read from the returned tensor itself so it always
            # matches the keys being rotated (and avoids the deprecated
            # `seen_tokens` cache attribute).
            full_position_ids = torch.arange(
                key_states.shape[-2], dtype=torch.long, device=query_states.device
            ).unsqueeze(0)
        else:
            full_position_ids = position_ids

        # Apply rotary embeddings to key_states (after possible cache retrieval)
        key_states = apply_single_rotary_pos_emb(key_states, cos, sin, full_position_ids)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Selective attention is skipped when the query covers the whole key
        # sequence with more than one token (q_len == kv_seq_len > 1).
        if getattr(self.config, "use_sel_attn", False) and (q_len == 1 or q_len != kv_seq_len):
            attn_weights = self._select_key_groups(attn_weights)

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _select_key_groups(self, attn_weights: torch.Tensor) -> torch.Tensor:
        """Zero attention outside the best-scoring key groups and renormalise.

        ``config.group_size`` lists the lengths of consecutive key groups that
        start at key position 0. Each group is scored by the sum of its
        largest attention weights; scores are optionally summed over heads
        and/or queries depending on the substrings ``"head"`` / ``"query"`` in
        ``config.aggr_mode``. The top groups and the trailing
        ``sum(group_size)`` key positions are kept, everything else is zeroed,
        and each row is divided by its remaining sum.

        Args:
            attn_weights: Softmax output of shape
                ``(batch, heads, q_len, kv_len)``.

        Returns:
            Tensor of the same shape with the selection applied. The input is
            returned unchanged when ``config.group_size`` is empty.
        """
        group_sizes = [int(size) for size in self.config.group_size]
        if not group_sizes:
            return attn_weights
        num_groups = len(group_sizes)

        # Exclusive prefix sum in plain Python integers: start offsets stay
        # exact regardless of the (possibly half-precision) attention dtype.
        group_starts = []
        offset = 0
        for size in group_sizes:
            group_starts.append(offset)
            offset += size
        kv_size = offset  # total number of key positions covered by groups

        # Score every group by the sum of its largest attention weights.
        per_group_scores = []
        for start, size in zip(group_starts, group_sizes):
            group_weights = attn_weights[..., start:start + size]  # [batch, head, q_len, group_len]
            # Clamp k so groups shorter than the budget do not make topk fail.
            k = min(self.SEL_TOKENS_PER_GROUP, group_weights.shape[-1])
            per_group_scores.append(torch.topk(group_weights, k=k, dim=-1).values.sum(dim=-1))
        group_scores = torch.stack(per_group_scores, dim=-1)  # [batch, head, q_len, num_groups]

        # Optional aggregation so heads and/or queries share one selection.
        aggr_mode = self.config.aggr_mode
        if "head" in aggr_mode:
            group_scores = group_scores.sum(dim=1, keepdim=True)
        if "query" in aggr_mode:
            group_scores = group_scores.sum(dim=2, keepdim=True)

        top_groups = min(self.SEL_TOP_GROUPS, num_groups)
        _, topk_indices = torch.topk(group_scores, k=top_groups, dim=-1)

        # Map each key position to its group id; positions outside every group
        # point at an extra, never-selected slot `num_groups`.
        kv_len = attn_weights.shape[-1]
        group_of_position = torch.full(
            (kv_len,), num_groups, dtype=torch.long, device=attn_weights.device
        )
        for group_idx, (start, size) in enumerate(zip(group_starts, group_sizes)):
            group_of_position[start:start + size] = group_idx

        # One-hot flags of the selected groups, then spread onto key positions.
        # Aggregated axes have size 1 and broadcast against attn_weights.
        selected = group_scores.new_zeros((*group_scores.shape[:-1], num_groups + 1))
        selected.scatter_(-1, topk_indices, 1.0)
        mask = selected.index_select(-1, group_of_position)

        # Always keep the trailing `kv_size` key positions ("query part").
        # NOTE(review): if the intent is "every key after the groups", the
        # slice would be `kv_size:` rather than `-kv_size:` — behaviour is
        # kept as written; confirm.
        mask[..., -kv_size:] = 1

        attn_weights = attn_weights * mask
        # Renormalise so every row sums to one again.
        return attn_weights / attn_weights.sum(dim=-1, keepdim=True)

class Qwen2ModifiedSdpaAttention(Qwen2Attention):
    """SDPA attention that caches un-rotated keys and rotates them after retrieval.

    Queries are rotated with the rotary embedding supplied by the caller (or
    computed from ``position_ids``). Keys are stored in the cache without
    rotation; after the cache update the whole key tensor is rotated with
    positions ``0 .. kv_len - 1``.

    NOTE(review): because key positions are always ``arange(kv_len)`` when a
    cache is used, queries and keys only share a position scale when a
    token's cache slot equals its position id — confirm for padded batches.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention with ``scaled_dot_product_attention``.

        Args:
            hidden_states: Input of shape ``(batch, q_len, hidden)``.
            attention_mask: Optional mask; its last axis is sliced to the key
                length before being handed to SDPA.
            position_ids: Used to compute cos/sin only when
                ``position_embeddings`` is not supplied.
            past_key_value: Optional cache; updated with the un-rotated keys
                and the values of this call.
            output_attentions: Accepted for interface compatibility; attention
                weights are never returned by this implementation.
            use_cache: Accepted for interface compatibility; not read here.
            cache_position: Forwarded to the cache update.
            position_embeddings: Optional precomputed ``(cos, sin)`` for the
                query positions.

        Returns:
            ``(attn_output, None, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            # No module-level `logger` is visible in this file, so obtain the
            # library logger here; referencing the bare name would raise
            # NameError on this path.
            from transformers.utils import logging as hf_logging

            hf_logging.get_logger(__name__).warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        # Apply rotary embeddings only to query_states
        query_states = apply_single_rotary_pos_emb(query_states, cos, sin)

        # Update cache (without applying rotary embeddings to key_states)
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

            # Recompute cos/sin for every key returned by the cache. The
            # length comes from the returned tensor so cos/sin always match
            # the keys they multiply (and the deprecated `seen_tokens` cache
            # attribute is not needed).
            full_position_ids = torch.arange(
                key_states.shape[-2], dtype=torch.long, device=query_states.device
            ).unsqueeze(0)
            cos, sin = self.rotary_emb(value_states, full_position_ids)

        # Apply rotary embeddings to key_states (after possible cache retrieval)
        key_states = apply_single_rotary_pos_emb(key_states, cos, sin)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` flag instead of an inline conditional
        # in the SDPA call to support both torch.compile's dynamic shapes and full graph options.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = causal_mask is None and q_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class Qwen2ModifiedDecoderLayer(Qwen2DecoderLayer):
    """Qwen2 decoder layer whose ``self_attn`` is ``Qwen2ModifiedSdpaAttention``."""

    def __init__(self, config: Qwen2Config, layer_idx: int):
        """Build the stock layer, then install the modified attention module.

        Args:
            config: Model configuration passed to the parent layer and to the
                attention module.
            layer_idx: Index of this layer, passed through to both.
        """
        super().__init__(config, layer_idx)
        # Overwrite the `self_attn` attribute after the parent constructor ran.
        modified_attention = Qwen2ModifiedSdpaAttention(config, layer_idx)
        self.self_attn = modified_attention

class Qwen2ModifiedModel(Qwen2Model):
    """Qwen2 backbone whose decoder stack is built from ``Qwen2ModifiedDecoderLayer``."""

    def __init__(self, config: Qwen2Config):
        """Build the stock model, then replace ``self.layers``.

        Args:
            config: Model configuration; ``config.num_hidden_layers`` decides
                how many modified decoder layers are created.
        """
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen2ModifiedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # The replacement layers are created after the parent constructor has
        # finished, so any weight initialisation / final processing it ran did
        # not see them. Run the post-init hook again so a model built directly
        # from a config gets the same treatment for the new layers.
        # NOTE(review): relies on the transformers convention that model
        # constructors end with `self.post_init()` — confirm for the installed
        # version.
        self.post_init()

class Qwen2ModifiedForCausalLM(Qwen2ForCausalLM):
    """Causal-LM wrapper that uses ``Qwen2ModifiedModel`` as its backbone."""

    def __init__(self, config: Qwen2Config):
        """Build the stock causal LM, then swap in the modified backbone.

        Args:
            config: Model configuration passed to the parent class and to the
                replacement backbone.
        """
        super().__init__(config)
        self.model = Qwen2ModifiedModel(config)
        # The backbone is replaced after the parent constructor has finished,
        # so anything its final processing linked to the old backbone (e.g.
        # tied input/output embeddings) would still point at the discarded
        # module. Re-run the post-init hook against the new backbone.
        # NOTE(review): relies on the transformers convention that model
        # constructors end with `self.post_init()` — confirm for the installed
        # version.
        self.post_init()