#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/qwen3_moe/modular_qwen3_moe.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_qwen3_moe.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple
from einops import rearrange
import math

import torch
from torch import nn

from transformers.generation import GenerationMixin
from transformers.utils import logging
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeForCausalLM,
    Qwen3MoeModel,
    Qwen3MoeRotaryEmbedding,
    Qwen3MoeMLP,
    Qwen3MoeRMSNorm,
    Qwen3MoeSparseMoeBlock,
    Qwen3MoeDecoderLayer,
    apply_rotary_pos_emb,
    eager_attention_forward
)
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from .configuration_qwen3_mea_moe import Qwen3MEAMoeConfig

# Module-level logger; used below for the one-time warning emitted when sdpa is requested with output_attentions.
logger = logging.get_logger(__name__)


class Qwen3MEAMoEHeadLinearCombine(nn.Module):
    """Map ``heads`` per-head feature vectors to a (usually larger) set of heads.

    Three modes, selected at construction time:

    * ``expand_only=True``: every input head is repeated ``new_heads // heads``
      times consecutively (``[h0, h0, ..., h1, h1, ...]``); no learned mixing.
    * ``n_wise == -1`` (default): every output head is a learned linear
      combination of all input heads, using ``self.weight`` of shape
      ``(1, new_heads, heads)``.
    * ``n_wise > 0``: input heads are split into consecutive groups of
      ``n_wise`` heads and each group is mixed by its own slice of
      ``self.weight`` of shape ``(new_heads // ratio // n_wise, ratio, n_wise)``
      with ``ratio = heads // new_heads``. The output then has
      ``weight.shape[0] * weight.shape[1]`` heads.
      NOTE(review): that count equals ``new_heads`` only for some configurations;
      confirm this matches the intended design.

    Args:
        heads: number of input heads.
        new_heads: number of heads to produce.
        n_wise: group size for grouped mixing, or -1 for full mixing.
        expand_only: repeat heads instead of mixing them.
    """

    def __init__(self,
                 heads: int,
                 new_heads: int,
                 n_wise: int = -1,
                 expand_only: bool = False,
                 ):
        super().__init__()

        # Validate before allocating parameters so a bad config fails with a clear message
        # (previously `heads < new_heads` with n_wise hit a ZeroDivisionError first).
        if expand_only:
            assert n_wise == -1, "expand_only and n_wise cannot be used together"
            assert new_heads % heads == 0, "new_heads must be divisible by heads"
        if n_wise != -1:
            assert heads % n_wise == 0, "heads must be divisible by n_wise"
            assert heads >= new_heads, "n_wise mode requires heads >= new_heads"

        self.heads = heads
        self.new_heads = new_heads
        self.n_wise = n_wise
        self.expand_only = expand_only
        # Misspelled alias kept so code that reads the old attribute name keeps working.
        self.exband_only = expand_only

        if n_wise != -1:
            ratio = heads // new_heads
            self.weight = nn.Parameter(torch.randn(new_heads // ratio // n_wise, ratio, n_wise))
        else:
            # NOTE: unused when expand_only=True; left in place so state_dict keys are unchanged.
            self.weight = nn.Parameter(torch.randn(1, new_heads, heads))

    def forward(self, x):
        """Combine heads.

        x: (bs, seqlen, heads, head_dim)
        return: (bs, seqlen, new_heads, head_dim) -- see the class docstring for the
        head count produced in n_wise mode.
        """
        if self.expand_only:
            # `expand` only works on singleton axes; repeat_interleave handles any head count.
            return x.repeat_interleave(self.new_heads // self.heads, dim=2)

        bsz, seqlen, num_heads, head_dim = x.shape
        # Full mixing treats all heads as one group; grouped mixing uses n_wise-sized groups
        # (the group size must match the weight's last dimension).
        group_size = num_heads if self.n_wise == -1 else self.n_wise
        grouped = x.reshape(bsz, seqlen, num_heads // group_size, group_size, head_dim)
        out = torch.einsum("bsgtd,grt->bsgrd", grouped, self.weight)
        return out.reshape(bsz, seqlen, -1, head_dim)

    def __repr__(self):
        return f"{type(self).__name__}(in_features={self.heads},out_features={self.new_heads})"


def yarn_get_mscale(scale=1, mscale=1):
    """Return the YaRN attention-magnitude correction for a RoPE scaling factor.

    No correction (1.0) is applied when ``scale`` does not exceed 1; otherwise the
    factor grows logarithmically with ``scale``, weighted by ``0.1 * mscale``.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


class Qwen3MEAMoEAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need', with per-layer K/V variants.

    The variant is ``config.pattern``; it is forced to ``""`` (plain attention) for
    layers listed in ``config.non_mea_layers``:

    * ``"nMLA"``: keys/values are projected to ``config.num_naive_MLA_kv_heads`` heads,
      concatenated, and re-projected to ``config.num_key_value_heads`` heads by
      ``c2k_proj`` / ``c2v_proj``.
    * ``"LC"``: keys/values are projected to ``num_k_base`` / ``num_v_base`` heads and
      mapped to ``config.num_key_value_heads`` heads by ``Qwen3MEAMoEHeadLinearCombine``
      (each combiner is skipped when its base count is 0).
    * ``"DFM"``: standard projections; heads are split into two halves and the output is
      ``out1 - lambda * out2`` with
      ``lambda = lambda_init + exp(sum(lambda_q1 * lambda_k1)) - exp(sum(lambda_q2 * lambda_k2))``.
    * anything else: standard grouped-query projections.
    """

    def __init__(self, config: Qwen3MEAMoeConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Layers listed in `non_mea_layers` fall back to plain attention.
        if self.layer_idx in config.non_mea_layers:
            self.pattern = ""
            self.num_k_base = 0
            self.num_v_base = 0
        else:
            self.pattern = config.pattern
            self.num_k_base = config.num_k_base
            self.num_v_base = config.num_v_base

        self.with_qknorm = config.with_qknorm
        if self.with_qknorm:
            self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
            self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape

        # Number of query heads sharing each key/value head (grouped-query attention).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

        # Only the width of the raw K/V projections depends on the pattern.
        if self.pattern == "nMLA":
            k_out_heads = v_out_heads = config.num_naive_MLA_kv_heads
        elif self.pattern == "LC":
            k_out_heads, v_out_heads = config.num_k_base, config.num_v_base
        else:
            k_out_heads = v_out_heads = config.num_key_value_heads

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim,
                                bias=config.qkv_bias)
        self.k_proj = nn.Linear(config.hidden_size, k_out_heads * self.head_dim, bias=config.qkv_bias)
        self.v_proj = nn.Linear(config.hidden_size, v_out_heads * self.head_dim, bias=config.qkv_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)

        if self.pattern == "nMLA":
            # Concatenated K and V latents -> full key / value heads.
            context_dim = 2 * config.num_naive_MLA_kv_heads * self.head_dim
            self.c2k_proj = nn.Linear(context_dim, config.num_key_value_heads * self.head_dim, bias=config.bias)
            self.c2v_proj = nn.Linear(context_dim, config.num_key_value_heads * self.head_dim, bias=config.bias)
        elif self.pattern == "LC":
            if self.num_k_base:
                self.k_proj_comb = Qwen3MEAMoEHeadLinearCombine(config.num_k_base, config.num_key_value_heads,
                                                                n_wise=config.k_n_wise,
                                                                expand_only=config.k_expand_only)
            if self.num_v_base:
                self.v_proj_comb = Qwen3MEAMoEHeadLinearCombine(config.num_v_base, config.num_key_value_heads,
                                                                n_wise=config.v_n_wise,
                                                                expand_only=config.v_expand_only)

        self.apply_pre_attn_group_norm = bool(config.pre_attn_group_norm)
        self.apply_post_attn_group_norm = bool(config.post_attn_group_norm)

        if self.apply_pre_attn_group_norm:
            # Use the resolved head_dim: `config.head_dim` may not exist (see the getattr above).
            self.pre_attn_group_RMS = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        if self.pattern == "DFM":
            assert config.num_key_value_heads % 2 == 0, "num_key_value_heads must be even in DFM pattern."
            self.lambda_init = 0.8

            self.lambda_q1 = nn.Parameter(torch.full((self.head_dim,), self.lambda_init), requires_grad=True)
            self.lambda_q2 = nn.Parameter(torch.full((self.head_dim,), self.lambda_init), requires_grad=True)

            self.lambda_k1 = nn.Parameter(torch.full((self.head_dim,), self.lambda_init), requires_grad=True)
            self.lambda_k2 = nn.Parameter(torch.full((self.head_dim,), self.lambda_init), requires_grad=True)

        if self.apply_post_attn_group_norm:
            # DFM concatenates the outputs for both value halves, so each head is 2 * head_dim wide.
            norm_dim = self.head_dim * 2 if self.pattern == "DFM" else self.head_dim
            self.post_attn_group_RMS = Qwen3MoeRMSNorm(norm_dim, eps=config.rms_norm_eps)

        self.scaling = self.head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            if mscale_all_dim:
                # Only read "factor" when it is actually needed, so rope_scaling dicts
                # without it do not raise a KeyError here.
                scaling_factor = self.config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def _differential_attention(self, attention_interface, query_states, key_states, value_states,
                                attention_mask, dropout, **kwargs):
        """DFM pattern: combine half-head attentions as ``out1 - lambda * out2``.

        Returns ``(attn_output, attn_weights)``; ``attn_weights`` is ``None`` when the
        backend does not return weights.
        """
        q1, q2 = torch.chunk(query_states, 2, dim=1)
        k1, k2 = torch.chunk(key_states, 2, dim=1)
        v1, v2 = torch.chunk(value_states, 2, dim=1)

        def attend(q, k, v):
            return attention_interface(
                self,
                q,
                k,
                v,
                attention_mask=attention_mask,
                dropout=dropout,
                scaling=self.scaling,
                **kwargs,
            )

        # Call order (11, 12, 21, 22) matches the original so dropout RNG use is unchanged.
        attn_output11, attn_weights1 = attend(q1, k1, v1)
        attn_output12, _ = attend(q1, k1, v2)
        attn_output21, attn_weights2 = attend(q2, k2, v1)
        attn_output22, _ = attend(q2, k2, v2)

        attn_output1 = torch.concat([attn_output11, attn_output12], dim=-1)
        attn_output2 = torch.concat([attn_output21, attn_output22], dim=-1)

        lambda_ = (
            self.lambda_init
            + torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
            - torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
        )

        attn_output = attn_output1 - lambda_ * attn_output2
        # Some backends (presumably flash attention) return None weights; combining None
        # used to raise a TypeError.
        attn_weights = None
        if attn_weights1 is not None and attn_weights2 is not None:
            attn_weights = attn_weights1 - lambda_ * attn_weights2
        return attn_output, attn_weights

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_embeddings: Tuple[torch.Tensor, torch.Tensor],
            attention_mask: Optional[torch.Tensor],
            past_key_value: Optional[Cache] = None,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute attention for ``hidden_states``.

        Returns:
            ``(attn_output, attn_weights)``; ``attn_weights`` is whatever the selected
            attention backend returns and may be ``None``.

        Raises:
            NotImplementedError: if the configured attention implementation is "sdpa".
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split projections into heads and move the head axis before the sequence axis
        # (assumes hidden_states is (batch, seq, hidden)).
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        if self.with_qknorm:
            query_states = self.q_norm(query_states)
        query_states = query_states.transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if self.pattern == "nMLA":
            # Fuse the small K/V latents and re-project them to the full K/V head count.
            key_states = key_states.transpose(1, 2).view(*input_shape, -1)
            value_states = value_states.transpose(1, 2).view(*input_shape, -1)
            context_states = torch.concat([key_states, value_states], dim=-1)
            key_states = self.c2k_proj(context_states).view(hidden_shape).transpose(1, 2)
            value_states = self.c2v_proj(context_states).view(hidden_shape).transpose(1, 2)
        elif self.pattern == "LC":
            # The head combiner expects (bs, seqlen, heads, head_dim).
            if self.num_k_base:
                key_states = self.k_proj_comb(key_states.transpose(1, 2)).transpose(1, 2)
            if self.num_v_base:
                value_states = self.v_proj_comb(value_states.transpose(1, 2)).transpose(1, 2)

        if self.with_qknorm:
            # k_norm is applied after the pattern-specific key construction.
            key_states = self.k_norm(key_states.transpose(1, 2)).transpose(1, 2)

        if self.apply_pre_attn_group_norm:
            value_states = self.pre_attn_group_RMS(value_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if self.config._attn_implementation == "sdpa":
            raise NotImplementedError(
                "sdpa is not implemented yet. Please use eager or flash attention."
            )

        dropout = 0.0 if not self.training else self.attention_dropout
        if self.pattern == "DFM":
            attn_output, attn_weights = self._differential_attention(
                attention_interface,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout,
                **kwargs,
            )
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=dropout,
                scaling=self.scaling,
                **kwargs,
            )

        if self.apply_post_attn_group_norm:
            attn_output = self.post_attn_group_RMS(attn_output)

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Qwen3MEAMoeDecoderLayer(Qwen3MoeDecoderLayer):
    """Qwen3-MoE decoder layer whose self-attention is `Qwen3MEAMoEAttention`.

    ``Qwen3MoeDecoderLayer.__init__`` is bypassed (``super(Qwen3MoeDecoderLayer, self)``),
    so the parent's submodules are not constructed; the forward pass is inherited.
    """

    def __init__(self, config: Qwen3MEAMoeConfig, layer_idx: int):
        super(Qwen3MoeDecoderLayer, self).__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3MEAMoEAttention(config, layer_idx)

        # Sparse MoE block on every `decoder_sparse_step`-th layer unless the layer is
        # forced dense via `mlp_only_layers`. The MLP is built exactly once (a throwaway
        # dense MLP used to be allocated first and immediately overwritten).
        use_sparse_moe = (layer_idx not in config.mlp_only_layers) and (
                config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen3MoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class Qwen3MEAMoeModel(Qwen3MoeModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3MEAMoeDecoderLayer`]

    Args:
        config: Qwen3MEAMoeConfig
    """
    # Declared explicitly (as on Qwen3MEAMoeForCausalLM) so loading and auto-class
    # registration use the MEA config rather than the inherited Qwen3-MoE one.
    config_class = Qwen3MEAMoeConfig

    def __init__(self, config: Qwen3MEAMoeConfig):
        # Bypass Qwen3MoeModel.__init__ and build this model's own submodules below.
        super(Qwen3MoeModel, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen3MEAMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()


class Qwen3MEAMoeForCausalLM(Qwen3MoeForCausalLM):
    """Causal language-model head on top of `Qwen3MEAMoeModel`.

    Reuses the Qwen3-MoE causal-LM behavior but builds its own backbone and head.
    """

    config_class = Qwen3MEAMoeConfig

    def __init__(self, config):
        # Skip Qwen3MoeForCausalLM.__init__ and run the next initializer in the MRO;
        # every submodule and attribute that class would set is assigned here instead.
        super(Qwen3MoeForCausalLM, self).__init__(config)

        # MoE routing / vocabulary bookkeeping copied from the config.
        self.vocab_size = config.vocab_size
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Backbone first, then the untied-bias output projection (same creation order as before).
        self.model = Qwen3MEAMoeModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()


# Names exported by `from <this module> import *`.
__all__ = [
    "Qwen3MEAMoeForCausalLM",
    "Qwen3MEAMoeModel",
]
