import torch
import torch.nn as nn
from typing import Callable, Optional

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    an input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep == 1``
    the input tensor object itself is returned.
    """
    # Unpack first so a non-4D input is rejected even when n_rep == 1.
    b, kv_heads, seq, dim = hidden_states.shape
    if n_rep != 1:
        # expand() creates a broadcast view; reshape() then materialises it.
        tiled = hidden_states.unsqueeze(2).expand(b, kv_heads, n_rep, seq, dim)
        return tiled.reshape(b, kv_heads * n_rep, seq, dim)
    return hidden_states


def rotate_half(x):
    """Rotate the last dimension of ``x`` by half: ``[x1, x2] -> [-x2, x1]``.

    ``x1`` is the first ``x.shape[-1] // 2`` channels and ``x2`` is the
    remainder (which is one channel longer when the size is odd).
    """
    split_at = x.shape[-1] // 2
    lower, upper = x[..., :split_at], x[..., split_at:]
    return torch.cat([upper.neg(), lower], dim=-1)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain (non-fused) scaled dot-product attention with grouped KV heads.

    Args:
        module: attention module; its ``num_key_value_groups`` sets how many
            times each key/value head is repeated, and its ``training`` flag
            gates dropout.
        query, key, value: presumably (batch, heads, seq_len, head_dim) with
            fewer heads in ``key``/``value`` -- TODO confirm against callers.
        attention_mask: optional additive mask; it is sliced to the key
            length on its last axis and added to the raw scores.
        scaling: multiplier applied to ``query @ key^T``.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has dims 1 and 2
        swapped relative to the ``probs @ value`` product and is contiguous.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = (query @ keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax in float32 for numerical stability, then back to the input dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs


def attn_forward(layer, hidden_states, position_embeddings, past_key_value, attention_mask):
    """Run ``layer``'s self-attention (with its input layernorm) and stop before ``o_proj``.

    Steps: apply ``layer.input_layernorm``, project to Q/K/V, apply
    multimodal rotary embeddings, then compute eager attention.

    The output projection (``layer.self_attn.o_proj``) is deliberately NOT
    applied. The returned tensor is the per-head attention context merged to
    ``(bsz, q_len, -1)``.

    Args:
        layer: decoder layer exposing ``input_layernorm`` and ``self_attn``
            (with ``q_proj``/``k_proj``/``v_proj``, ``head_dim``,
            ``rope_scaling["mrope_section"]``, ``attention_dropout``,
            ``training`` and ``num_key_value_groups``).
        hidden_states: 3-D input, unpacked as ``(bsz, q_len, hidden)``.
        position_embeddings: ``(cos, sin)`` pair passed to the rotary helper.
        past_key_value: must be ``None``; KV caching is not supported.
        attention_mask: optional additive mask forwarded to the attention.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has shape
        ``(bsz, q_len, -1)``.

    Raises:
        NotImplementedError: if ``past_key_value`` is not ``None``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently ignore a supplied cache.
    if past_key_value is not None:
        raise NotImplementedError("attn_forward does not support past_key_value (KV cache)")

    attn = layer.self_attn
    hidden_states = layer.input_layernorm(hidden_states)
    bsz, q_len, _ = hidden_states.size()

    # Split projections into heads: (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim).
    query_states = attn.q_proj(hidden_states).view(bsz, q_len, -1, attn.head_dim).transpose(1, 2)
    key_states = attn.k_proj(hidden_states).view(bsz, q_len, -1, attn.head_dim).transpose(1, 2)
    value_states = attn.v_proj(hidden_states).view(bsz, q_len, -1, attn.head_dim).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, attn.rope_scaling["mrope_section"]
    )

    # Always uses the eager implementation, regardless of the model's
    # configured attention backend. Sliding-window attention is not applied.
    attn_output, attn_weights = eager_attention_forward(
        attn,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not attn.training else attn.attention_dropout,
        scaling=attn.head_dim**-0.5,
    )

    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    return attn_output, attn_weights
    


def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Apply multimodal (3D) rotary position embedding to query and key tensors.

    Reference: https://qwenlm.github.io/blog/qwen2-vl/

    Multimodal RoPE extends 1D RoPE. The channel (last) dimension of ``cos``/``sin``
    is split into sections, one each for the temporal, height and width position
    streams. Each section takes its values from the corresponding stream. For
    text tokens the three position indices are identical, so the result reduces
    to ordinary 1D RoPE.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine part of the rotary embedding. Its first
            dimension is indexed with 0, 1 and 2 (one slice per position stream),
            so it needs at least 3 entries there. Presumably it has shape
            (3, batch, seq_len, head_dim); confirm against the rotary module.
        sin (`torch.Tensor`): Sine part of the rotary embedding, same layout as ``cos``.
        mrope_section (sequence of int): Channel counts for the temporal, height
            and width streams. The sequence is repeated twice so that both halves
            of the last dimension use the same stream layout.
            ``rotate_half`` pairs channel ``j`` with ``j + head_dim // 2``; when
            ``sum(mrope_section) == head_dim // 2``, both channels of each pair
            then come from the same stream. The doubled sections must sum to
            ``cos.shape[-1]``.
        unsqueeze_dim (`int`, *optional*, defaults to 1): Dimension at which the
            merged ``cos``/``sin`` are unsqueezed so they broadcast against ``q``
            and ``k``. If merged ``cos``/``sin`` are (batch, seq_len, head_dim),
            use 1 for q/k shaped (batch, heads, seq_len, head_dim) and 2 for q/k
            shaped (batch, seq_len, heads, head_dim).

    Returns:
        `tuple(torch.Tensor)`: ``(q_embed, k_embed)``, the rotated query and key.
    """
    # list() accepts any sequence (list/tuple) before repeating it for both halves.
    sections = list(mrope_section) * 2

    def _merge_streams(table):
        # Chunk i takes its values from stream i % 3 (temporal, height, width, repeated).
        chunks = table.split(sections, dim=-1)
        merged = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
        return merged.unsqueeze(unsqueeze_dim)

    cos = _merge_streams(cos)
    sin = _merge_streams(sin)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def get_key_query_value(hidden_states, decoder_layer, position_ids, position_embeddings, get_only_k=False):
    """Compute raw and rotary-embedded key (and optionally query) projections.

    ``hidden_states`` is normalised with ``decoder_layer.input_layernorm`` and
    then projected. The "_cosine" outputs are the projections after multimodal
    RoPE, transposed back to (bsz, q_len, heads, head_dim) and flattened to
    (bsz, q_len, heads * head_dim).

    Args:
        hidden_states: 3-D input; the key projection is unpacked as (bsz, q_len, _).
        decoder_layer: layer exposing ``input_layernorm`` and ``self_attn``.
        position_ids: accepted for interface compatibility; unused here.
        position_embeddings: ``(cos, sin)`` pair for the rotary helper.
        get_only_k: if True, skip the query projection entirely.

    Returns:
        ``(current_k, current_k_cosine)`` when ``get_only_k`` is True, otherwise
        ``(current_k, current_q, current_k_cosine, current_q_cosine)``. The raw
        ``current_k``/``current_q`` are the projection outputs before head
        splitting.
    """
    attn = decoder_layer.self_attn
    normed = decoder_layer.input_layernorm(hidden_states)

    raw_k = attn.k_proj(normed)
    bsz, q_len, _ = raw_k.size()
    k_heads = raw_k.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    cos, sin = position_embeddings

    if get_only_k:
        # The helper always rotates two tensors; only the first result is kept.
        rotated_k, _ = apply_multimodal_rotary_pos_emb(
            k_heads, k_heads, cos, sin, attn.rope_scaling["mrope_section"]
        )
        return raw_k, rotated_k.transpose(1, 2).flatten(2, 3)

    raw_q = attn.q_proj(normed)
    q_heads = raw_q.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2)
    rotated_q, rotated_k = apply_multimodal_rotary_pos_emb(
        q_heads, k_heads, cos, sin, attn.rope_scaling["mrope_section"]
    )

    flat_k = rotated_k.transpose(1, 2).flatten(2, 3)
    flat_q = rotated_q.transpose(1, 2).flatten(2, 3)
    return raw_k, raw_q, flat_k, flat_q