import math
import types
import warnings
from typing import Optional, Tuple

import torch
from transformers.cache_utils import Cache
from torch import nn
from transformers.utils import TransformersKwargs
from transformers.processing_utils import Unpack
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Compute scaled dot-product attention with plain matmul/softmax ops.

    Args:
        module: Attention module; only ``num_key_value_groups`` and
            ``training`` are read from it.
        query: Query tensor. The ``transpose(2, 3)`` and the four-index mask
            slice below imply 4-D inputs (presumably
            ``(batch, heads, seq, head_dim)`` -- confirm against the caller).
        key: Key tensor, passed through ``repeat_kv`` before use.
        value: Value tensor, passed through ``repeat_kv`` before use.
        attention_mask: Optional additive mask; its last dimension is cropped
            to the key length before being added to the scores.
        scaling: Factor multiplied into the raw query/key scores.
        dropout: Dropout probability applied to the attention weights; only
            active when ``module.training`` is true.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` is the weighted
        values with dims 1 and 2 swapped and made contiguous, and
        ``attn_weights`` are the post-softmax, post-dropout weights.
    """
    n_rep = module.num_key_value_groups
    expanded_keys = repeat_kv(key, n_rep)
    expanded_values = repeat_kv(value, n_rep)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the mask so its last dimension matches the number of keys.
        kv_len = expanded_keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax is evaluated in float32, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs

def attn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Attention forward pass that always goes through ``eager_attention_forward``.

    Intended to be bound to an attention module as its ``forward``. Reads the
    following attributes from ``self``: ``q_proj``, ``k_proj``, ``v_proj``,
    ``o_proj``, ``head_dim``, ``layer_idx``, ``scaling``, ``attention_dropout``
    and ``training`` (``eager_attention_forward`` additionally reads
    ``num_key_value_groups``).

    Args:
        hidden_states: Input activations. NOTE(review): the ``transpose(1, 2)``
            after the head split assumes ``(batch, seq_len, hidden_size)`` --
            confirm against the caller.
        position_embeddings: ``(cos, sin)`` pair passed to
            ``apply_rotary_pos_emb``. Despite the ``None`` default (kept for
            signature compatibility) this argument is required.
        attention_mask: Optional additive mask forwarded to the attention
            kernel.
        past_key_values: Optional cache; when given, the new key/value states
            are pushed through ``past_key_values.update`` for ``self.layer_idx``
            and the returned tensors are used for attention.
        cache_position: Forwarded to the cache update via ``cache_kwargs``.
        **kwargs: Forwarded unchanged to ``eager_attention_forward``.

    Returns:
        ``(attn_output, attn_weights)``: the output of ``o_proj`` and the
        attention weights returned by ``eager_attention_forward``.

    Raises:
        TypeError: If ``position_embeddings`` is ``None``.
    """
    # Fail fast with a clear message instead of an opaque unpacking error
    # raised only after the projections have already been computed.
    if position_embeddings is None:
        raise TypeError(
            "attn_forward() requires `position_embeddings` as a (cos, sin) tuple, got None"
        )
    cos, sin = position_embeddings

    input_shape = hidden_states.shape[:-1]
    # Split the projected feature dimension into (num_heads, head_dim).
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_values is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_values.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    attn_output, attn_weights = eager_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        # Dropout is only applied while the module is in training mode.
        dropout=self.attention_dropout if self.training else 0.0,
        scaling=self.scaling,
        **kwargs,
    )

    # Merge the per-head outputs back into a single feature dimension.
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights



def enable_llama_custom_attention(layer, layer_id):
    """
    Replace the forward function of ``layer.self_attn`` with the custom ``attn_forward``.

    Args:
        layer: Object exposing a ``self_attn`` attribute (the attention module
            to patch).
        layer_id: Index stored on the attention module as ``layer_id``. It is
            also used as ``layer_idx`` when the module has no usable
            ``layer_idx`` of its own.

    Returns:
        The patched attention module (the same object as ``layer.self_attn``).
    """
    modified_module = layer.self_attn
    modified_module.layer_id = layer_id
    # attn_forward addresses the KV cache through `layer_idx`, not `layer_id`.
    # Fill it in only when absent/None so an existing value is never overridden.
    if getattr(modified_module, "layer_idx", None) is None:
        modified_module.layer_idx = layer_id
    # Bind as a method so `self` inside attn_forward is this attention module.
    modified_module.forward = types.MethodType(attn_forward, modified_module)

    return modified_module