import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.cache_utils import Cache, StaticCache, SlidingWindowCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers import logging
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv, MistralAttention
from  asym_kv.util.cache_utils import apply_rotary_pos_emb_single_withpos, CompressCache
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, Qwen2Attention
import types
import logging
logger = logging.getLogger(__name__)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING

# NOTE(review): module-level flag that is not read anywhere in the visible
# part of this file — confirm whether other modules import it before
# changing or removing it.
cumsum_pos = True

def Qwen2Attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Eager attention forward whose rotary positions come from the cache length.

    Meant to be bound onto a ``Qwen2Attention`` instance in place of its own
    ``forward``. Rotary positions are not taken from ``position_ids``; they
    are rebuilt from the number of entries currently stored in the cache for
    this layer: queries get positions ``past_len .. past_len + q_len - 1`` and
    keys get ``0 .. past_len + q_len - 1``. Keys are handed to the cache
    *before* rotation and rotated after retrieval on every call.

    Args:
        hidden_states: Input of shape ``(bsz, q_len, hidden)``.
        attention_mask: Optional additive mask; it is sliced along its last
            dimension to the key length and added to the attention scores.
        position_ids: Accepted for interface compatibility; not used.
        past_key_value: Optional cache. Its ``key_cache[self.layer_idx]`` is
            inspected for the stored length and ``update`` is called with the
            new, un-rotated keys and values.
        output_attentions: When false, the returned weights are ``None``.
        use_cache: Accepted for interface compatibility; not used.
        cache_position: Accepted for interface compatibility; not used.

    Returns:
        ``(attn_output, attn_weights, past_key_value)`` where ``attn_output``
        has shape ``(bsz, q_len, self.hidden_size)``.

    Raises:
        ValueError: If the attention output does not have shape
            ``(bsz, self.num_heads, q_len, self.head_dim)``.
    """
    bsz, q_len, _ = hidden_states.size()
    # Take the device from the activations rather than from `position_ids`,
    # which is optional and may be None.
    device = hidden_states.device

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Number of entries already stored for this layer (0 on the first call or
    # when no cache is used).
    past_len = 0
    if past_key_value is not None and len(past_key_value.key_cache) > self.layer_idx:
        past_len = past_key_value.key_cache[self.layer_idx].shape[-2]
    full_len = past_len + key_states.shape[-2]

    # Positions are counted inside the cache, not in the original sequence.
    # NOTE(review): presumably this keeps positions contiguous when the cache
    # holds fewer entries than tokens seen — confirm against CompressCache.
    query_position_ids = (
        torch.arange(past_len, full_len, device=device, dtype=torch.long)
        .unsqueeze(0)
        .expand(bsz, -1)
    )
    key_position_ids = torch.arange(full_len, device=device).unsqueeze(0)

    cos, sin = self.rotary_emb(value_states, seq_len=full_len)

    # The cache receives un-rotated keys; rotation happens below on whatever
    # `update` hands back.
    if past_key_value is not None:
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

    query_states = apply_rotary_pos_emb_single_withpos(query_states, cos, sin, position_ids=query_position_ids)
    key_states = apply_rotary_pos_emb_single_withpos(key_states, cos, sin, position_ids=key_position_ids)

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

def enable_qwen2_compress_attention_recursive(model):
    """Rebind ``forward`` on every ``Qwen2Attention`` below *model*, in place.

    Walks the direct children of *model* in reverse registration order,
    descends into each one, and binds ``Qwen2Attention_forward`` as the
    ``forward`` method of every child that is a ``Qwen2Attention`` instance.
    *model* itself is not checked or patched, only its descendants.

    Args:
        model: A ``torch.nn.Module`` whose submodule tree should be patched.

    Returns:
        None. Matching submodules are modified in place.
    """
    for _name, module in reversed(list(model._modules.items())):
        # A child slot may be registered as None (e.g. via
        # ``add_module(name, None)``); there is nothing to visit or patch.
        if module is None:
            continue

        # Descending into a module without children is a no-op, so no
        # separate "has children" check is needed.
        enable_qwen2_compress_attention_recursive(module)

        if isinstance(module, Qwen2Attention):
            module.forward = types.MethodType(Qwen2Attention_forward, module)

def enable_qwen2_compress_attention(model):
    """Public entry point: patch the attention layers found under *model*.

    Delegates to ``enable_qwen2_compress_attention_recursive`` and returns
    nothing; the patching happens in place.
    """
    patch_submodules = enable_qwen2_compress_attention_recursive
    patch_submodules(model)

