import math
from torch import Tensor
import transformers
from typing import Optional, Union
from ...ops.softmax import SoftmaxFunction
import torch
# Reference: transformers/src/transformers/models/qwen3/modeling_qwen3.py
from transformers.models.qwen3.modeling_qwen3 import repeat_kv,apply_rotary_pos_emb,eager_attention_forward
from transformers.integrations.sdpa_attention import sdpa_attention_forward , use_gqa_in_sdpa
from torch import nn
from collections.abc import Callable

def qwen3_selfattn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_values = None,
    cache_position: Optional[torch.LongTensor] = None,
    compress_kwargs: dict | None = None,
    **kwargs,
):
    """Self-attention forward that routes the softmax through ``SoftmaxFunction``.

    The attention product is computed explicitly (``Q @ K^T``, additive bias,
    softmax, ``@ V``) so that ``SoftmaxFunction.apply`` can be used in place of
    a fused attention kernel.

    Args:
        hidden_states: Input activations; expected to be laid out as
            ``(batch, seq_len, hidden_size)`` (the projections are viewed as
            ``(*hidden_states.shape[:-1], -1, head_dim)`` and dims 1 and 2 are
            swapped).
        position_embeddings: ``(cos, sin)`` pair handed to
            ``apply_rotary_pos_emb``.
        attention_mask: ``None``, a boolean mask (``True`` = keep), or an
            additive mask. It must broadcast against the ``[B, H, L, S]``
            score tensor.
        past_key_values: Optional cache object; when given, its ``update``
            method is called with the new key/value states for
            ``self.layer_idx``.
        cache_position: Forwarded to the cache update via ``cache_kwargs``.
        compress_kwargs: Forwarded to ``SoftmaxFunction.apply`` while
            ``self.training`` is true; ``None`` is passed otherwise.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        ``(attn_output, attn_weights)``: the output-projected attention result
        and the post-softmax attention probabilities.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_values is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # The attention product is computed manually below, so grouped KV heads are
    # always expanded explicitly; no native-GQA kernel support is required.
    if hasattr(self, "num_key_value_groups"):
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

    # A causal mask is only synthesized when no explicit mask was supplied and
    # there is more than one query position.
    is_causal = query_states.shape[2] > 1 and attention_mask is None and getattr(self, "is_causal", True)
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    L, S = query_states.size(-2), key_states.size(-2)
    scale_factor = 1 / math.sqrt(query_states.size(-1)) if self.scaling is None else self.scaling

    # Build the additive bias at its natural (broadcastable) shape instead of
    # materializing a dense [B, H, L, S] tensor of zeros.
    attn_bias = None
    if is_causal:
        # Lower-triangular keep-mask (diagonal=0), created directly on the target device.
        keep_mask = torch.ones(L, S, dtype=torch.bool, device=query_states.device).tril(diagonal=0)
        attn_bias = torch.zeros(L, S, dtype=query_states.dtype, device=query_states.device)
        attn_bias.masked_fill_(keep_mask.logical_not(), float("-inf"))
    elif attention_mask is not None:
        if attention_mask.dtype == torch.bool:
            attn_bias = torch.zeros(
                attention_mask.shape, dtype=query_states.dtype, device=query_states.device
            )
            attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
        else:
            # Already additive; used as-is.
            attn_bias = attention_mask

    attn_weights = query_states @ key_states.transpose(-2, -1) * scale_factor  # [B, H, L, S]
    if attn_bias is not None:
        attn_weights += attn_bias
    attn_weights = SoftmaxFunction.apply(
        attn_weights,
        -1,
        attn_weights.dtype,
        compress_kwargs if self.training else None,
    )

    # NOTE: no dropout is applied to the attention probabilities (rate fixed at 0.0).
    attn_output = attn_weights @ value_states
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(*input_shape, -1)
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights


def llama3_selfattn_foward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_values = None,
    cache_position: Optional[torch.LongTensor] = None,
    compress_kwargs: dict | None = None,
    **kwargs,
):
    """Self-attention forward that routes the softmax through ``SoftmaxFunction``.

    The attention product is computed explicitly (``Q @ K^T``, additive bias,
    softmax, ``@ V``) so that ``SoftmaxFunction.apply`` can be used in place of
    a fused attention kernel. No normalization is applied to the query/key
    projections in this variant.

    Args:
        hidden_states: Input activations; expected to be laid out as
            ``(batch, seq_len, hidden_size)`` (the projections are viewed as
            ``(*hidden_states.shape[:-1], -1, head_dim)`` and dims 1 and 2 are
            swapped).
        position_embeddings: ``(cos, sin)`` pair handed to
            ``apply_rotary_pos_emb``.
        attention_mask: ``None``, a boolean mask (``True`` = keep), or an
            additive mask. It must broadcast against the ``[B, H, L, S]``
            score tensor.
        past_key_values: Optional cache object; when given, its ``update``
            method is called with the new key/value states for
            ``self.layer_idx``.
        cache_position: Forwarded to the cache update via ``cache_kwargs``.
        compress_kwargs: Forwarded to ``SoftmaxFunction.apply`` while
            ``self.training`` is true; ``None`` is passed otherwise.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        ``(attn_output, attn_weights)``: the output-projected attention result
        and the post-softmax attention probabilities.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_values is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # The attention product is computed manually below, so grouped KV heads are
    # always expanded explicitly; no native-GQA kernel support is required.
    if hasattr(self, "num_key_value_groups"):
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

    # A causal mask is only synthesized when no explicit mask was supplied and
    # there is more than one query position.
    is_causal = query_states.shape[2] > 1 and attention_mask is None and getattr(self, "is_causal", True)
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    L, S = query_states.size(-2), key_states.size(-2)
    scale_factor = 1 / math.sqrt(query_states.size(-1)) if self.scaling is None else self.scaling

    # Build the additive bias at its natural (broadcastable) shape instead of
    # materializing a dense [B, H, L, S] tensor of zeros.
    attn_bias = None
    if is_causal:
        # Lower-triangular keep-mask (diagonal=0), created directly on the target device.
        keep_mask = torch.ones(L, S, dtype=torch.bool, device=query_states.device).tril(diagonal=0)
        attn_bias = torch.zeros(L, S, dtype=query_states.dtype, device=query_states.device)
        attn_bias.masked_fill_(keep_mask.logical_not(), float("-inf"))
    elif attention_mask is not None:
        if attention_mask.dtype == torch.bool:
            attn_bias = torch.zeros(
                attention_mask.shape, dtype=query_states.dtype, device=query_states.device
            )
            attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
        else:
            # Already additive; used as-is.
            attn_bias = attention_mask

    attn_weights = query_states @ key_states.transpose(-2, -1) * scale_factor  # [B, H, L, S]
    if attn_bias is not None:
        attn_weights += attn_bias
    attn_weights = SoftmaxFunction.apply(
        attn_weights,
        -1,
        attn_weights.dtype,
        compress_kwargs if self.training else None,
    )

    # NOTE: no dropout is applied to the attention probabilities (rate fixed at 0.0).
    attn_output = attn_weights @ value_states
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(*input_shape, -1)
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights


# Correctly spelled alias; the misspelled name above is kept for existing callers.
llama3_selfattn_forward = llama3_selfattn_foward