import torch
import torch.nn as nn
from typing import Optional, Tuple

from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2PreTrainedModel,
    Qwen2Attention,
    Qwen2MLP,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
    repeat_kv,
    apply_rotary_pos_emb,
)
from transformers.integrations.sdpa_attention import sdpa_attention_forward
from torch.nn import functional as F


class MultiHeadRMSNorm(nn.Module):
    def __init__(self, dim, heads=1):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.gamma * self.scale


class Qwen2BidirectionalSdpaAttention(Qwen2Attention):
    """
    An SDPA-based attention that does NOT apply causal masking.
    Inherits from Qwen2Attention, but sets self.is_causal = False.
    """

    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        self.is_causal = False
        self.qk_norm = config.qk_norm
        if self.qk_norm:
            self.q_norm = MultiHeadRMSNorm(
                config.hidden_size // config.num_attention_heads,
                config.num_attention_heads,
            )
            self.k_norm = MultiHeadRMSNorm(
                config.hidden_size // config.num_attention_heads,
                config.num_key_value_heads,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin
            )

        if self.qk_norm:
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        attn_output, attn_weights = sdpa_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            is_causal=False,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output


class Qwen2EncoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2BidirectionalSdpaAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)

        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        # Norm + Self-Attn
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        # Norm + MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class Qwen2Encoder(Qwen2PreTrainedModel):
    supports_gradient_checkpointing = True

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen2EncoderLayer(config, i) for i in range(self.config.num_hidden_layers)]
        )
        if config.rope:
            self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        else:
            self.rotary_emb = None
        if hasattr(config, "norm") and config.norm:
            self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = None
        self.gradient_checkpointing = True
        self.post_init()

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, hidden_states):
        bsz, seq_len, _ = hidden_states.size()
        position_ids = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0)

        # Compute RoPE embeddings once, shared across layers
        if self.rotary_emb is not None:
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
        else:
            position_embeddings = None

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    position_embeddings,
                )
            else:
                hidden_states = layer(
                    hidden_states,
                    position_embeddings=position_embeddings,
                )
        if self.norm:
            hidden_states = self.norm(hidden_states)
        return hidden_states
