# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2 model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
import math
from math import sqrt

# Module-level logger obtained through transformers' logging helpers.
logger = logging.get_logger(__name__)


def _xavier_gain_to_std(gain, dim0, dim1):
    return gain * sqrt(2.0 / float(dim1 + dim0))


def _get_deepscale_value_std(dims, num_layers, layer_idx):
    def attn_block(r):
        r_out = 1 - p
        sigma_w1 = math.sqrt(math.sqrt((1 - p) / r) / dims)
        return sigma_w1, r_out

    def ffn_block(r):
        r_out = (1 - p) * (r + ((1 - r**2) ** 0.5 - r * math.acos(r)) / math.pi)
        return r_out

    p = 0.0  # no dropout
    lambda_sq = 1 - 2 / num_layers
    beta_sq = 2 / num_layers
    r0 = 0.221  # deus ex machina from Zipf
    sigma_new = 1.0
    sigma_attn_out_v_list = []
    r_in = r0 * (1 - p)

    r_list = []
    r_list.append(r_in)
    for _ in range(num_layers):
        sigma_attn_out_v, r_out = attn_block(r_in)
        r_in = (lambda_sq * r_in * sigma_new + beta_sq * r_out * 1.0) / (lambda_sq * sigma_new + beta_sq * 1.0)
        sigma_new = lambda_sq * sigma_new + beta_sq * 1.0
        sigma_attn_out_v_list.append(sigma_attn_out_v)
        r_out = ffn_block(r_in)
        r_in = (lambda_sq * r_in * sigma_new + beta_sq * r_out * 1.0) / (lambda_sq * sigma_new + beta_sq * 1.0)
        sigma_new = lambda_sq * sigma_new + beta_sq * 1.0
        r_list.append(r_in)

    return sigma_attn_out_v_list[layer_idx]


def get_factor_table(dim, intermed_dim, attn_head_dim, layer_idx=0, num_layers=16):
    """Build the table of initialization factors for every supported init scheme.

    A bit weird to have this as a fn that just defines the dict, but I like the compact summary that this setup
    provides.

    Args:
        dim: model width (``Qwen2Config.get_std`` passes ``hidden_size``).
        intermed_dim: MLP intermediate width; only used by the "deepnorm-*" and "noci-anagnostidis" entries.
        attn_head_dim: per-head attention width; used by the Xavier-style "q"/"k" entries.
        layer_idx: index of the layer being initialized. It is reduced modulo ``num_layers``, so negative values
            count from the last layer.
        num_layers: network depth used by the depth-scaled schemes.

    Returns:
        dict mapping scheme name -> dict mapping a role key ("std", "embedding", "in_proj", ...) to a float.

    Note:
        The "deep-scale-simple" / "deep-scale-full" entries are only present when ``num_layers >= 2``. Their skip
        weight ``sqrt(1 - 2 / num_layers)`` is undefined for a single layer. Since the whole table is built
        eagerly, computing it anyway would raise ValueError and make every other scheme unusable too.
    """
    layer_idx = layer_idx % num_layers
    if num_layers >= 2:
        # Pure function of (dim, num_layers, layer_idx) with an O(num_layers) recursion: evaluate it only once.
        deepscale_value_std = _get_deepscale_value_std(dim, num_layers, layer_idx)
        deep_scale_entries = {
            "deep-scale-simple": {
                "residual": sqrt(2 / num_layers),
                "skip": sqrt(1 - 2 / num_layers),
                "std": sqrt(1 / dim * sqrt(1 / 2)),
                "q": sqrt(1 / dim),
                "k": sqrt(1 / dim),
                "embedding": sqrt(1 / 3),
                "head": 1.0,
                "logit_scale": 1 / sqrt(dim),  # mentioned for BERT in 5.1
            },
            "deep-scale-full": {
                "residual": sqrt(2 / num_layers),
                "skip": sqrt(1 - 2 / num_layers),
                "std": sqrt(1 / dim * sqrt(1 / 2)),
                "q": sqrt(1 / dim),
                "k": sqrt(1 / dim),
                "v": deepscale_value_std,  # compare to sqrt(1/d * sqrt(sigma)) ?
                "out_attn": deepscale_value_std,
                "embedding": sqrt(1 / 3),  # technically just one because we have only one embedding type
                "head": 1.0,
                "logit_scale": 1 / sqrt(dim),  # mentioned for BERT in 5.1
            },
        }
    else:
        deep_scale_entries = {}

    lookup = {
        "mitchell": {
            "embedding": 1 / sqrt(dim),  # or 1
            "head": 1 / sqrt(dim),
            "in_proj": 1 / sqrt(dim),
            "out_proj": 1 / sqrt(dim) / sqrt(2 * (layer_idx + 1)),
        },  # from Zhang-Titov-Sennrich
        "normal": {"std": 1 / sqrt(dim)},
        "llama": {
            "embedding": 1.0,  # even without truncation, just a normal normal_
            "in_proj": 0.02,
            "out_proj": 0.02 / sqrt(2 * (layer_idx + 1)),
            "head": 1 / sqrt(dim),
        },  # apply in_proj definitely per q,k,v # small variation of zhang-titov-sennrich
        "llama-by-dim": {
            "embedding": 1.0,  # even without truncation, just a normal normal_
            "in_proj": 1 / sqrt(dim),
            "out_proj": 1 / sqrt(dim) / sqrt(2 * (layer_idx + 1)),
            "head": 1 / sqrt(dim),
        },  # apply in_proj definitely per q,k,v # small variation of zhang-titov-sennrich
        "llama-by-dim-ls": {
            "embedding": 1.0,  # even without truncation in the original, just a normal normal_
            "in_proj": 1 / sqrt(dim),
            "out_proj": 1 / sqrt(dim) / sqrt(2 * (layer_idx + 1)),
            "head": 1.0,
            "logit_scale": 1 / sqrt(dim),
        },
        # the ffn gate (w2) also counts as out_proj (but not in "olmo"/mitchell)
        "kaiming": {"std": sqrt(2.0 / dim)},  # need to account for intermed_dim on ffn out
        "bert": {"std": 0.02},
        "megatron": {"std": 0.02, "out_proj": 0.02 / sqrt(dim), "embedding": 0.02, "head": 1 / sqrt(dim)},
        "megatron2": {"std": sqrt(1 / (3 * dim))},
        "small": {"std": sqrt(2 / (5 * dim))},  # nguyen & salazar
        "scaled": {
            "std": sqrt(2 / (5 * dim)),
            "out_proj": sqrt(2 / (5 * dim)) / sqrt(2 * num_layers),
        },  # Le Scao, Biderman,
        "scaled-stuck": {
            "std": sqrt(2 / (5 * dim)),
            "out_proj": sqrt(2 / (5 * dim)) / sqrt(2 * 16),
        },  # Le Scao, Biderman,
        "takase": {
            "std": sqrt(2 / (5 * dim)),
            "out_proj": sqrt(2 / (5 * dim)) / sqrt(2 * num_layers),
            "embedding": sqrt(2 / (5 * dim)),
            "embed_scale": sqrt(dim),
            # "logit_scale": sqrt(2 / 5) / sqrt(dim),  # if weight-tied
        },  # spike-no-more, Takase et al.
        "takase-scaled": {
            "std": sqrt(2 / (5 * dim)),
            "out_proj": sqrt(2 / (5 * dim)) / sqrt(2 * num_layers),
            "embedding": sqrt(2 / (5 * dim)),
            "embed_scale": sqrt(dim),
            "logit_scale": sqrt(2 / 5) / sqrt(dim),  # if weight-tied
        },
        "wang": {"std": 2 / num_layers / sqrt(dim)},  # Wang& Komatsuzaki
        "deepnorm-straight": {
            "embedding": 0.02,  # undef in original, taken from megatron
            "gain": pow(8 * num_layers, -0.25),
            "skip": pow(2 * num_layers, 0.25),
            "mlp": _xavier_gain_to_std(pow(8 * num_layers, -0.25), dim, intermed_dim),
            "out_proj": _xavier_gain_to_std(pow(8 * num_layers, -0.25), dim, intermed_dim),
            "v": _xavier_gain_to_std(pow(8 * num_layers, -0.25), dim, dim),
            "out_attn": _xavier_gain_to_std(pow(8 * num_layers, -0.25), dim, dim),
            "q": _xavier_gain_to_std(1.0, dim, attn_head_dim),
            "k": _xavier_gain_to_std(1.0, dim, attn_head_dim),
            "head": 1 / sqrt(dim),  # undef in original, taken from megatron
        },
        "deepnorm-subln": {
            "embedding": 0.02,  # undef in original, taken from megatron
            "gain": sqrt(math.log(2 * num_layers)),
            "skip": pow(2 * num_layers, 0.25),
            "mlp": _xavier_gain_to_std(sqrt(math.log(2 * num_layers)), dim, intermed_dim),
            "v": _xavier_gain_to_std(sqrt(math.log(2 * num_layers)), dim, dim),
            "out_attn": _xavier_gain_to_std(sqrt(math.log(2 * num_layers)), dim, dim),
            "q": _xavier_gain_to_std(1.0, dim, attn_head_dim),
            "k": _xavier_gain_to_std(1.0, dim, attn_head_dim),
            "head": 1 / sqrt(dim),  # undef in original, taken from megatron
        },
        "noci-anagnostidis": {
            "residual": 1 / sqrt(num_layers),
            "mlp": _xavier_gain_to_std(1.0, dim, intermed_dim),
            "v": _xavier_gain_to_std(1.0, dim, dim),
            "out_attn": _xavier_gain_to_std(1.0, dim, dim),
            "q": _xavier_gain_to_std(1.0, dim, attn_head_dim),
            "k": _xavier_gain_to_std(1.0, dim, attn_head_dim),
            "std": 1.0,
            "logit_scale": 1 / sqrt(dim),
        },
        "shaped": {
            "residual": 0.2,  # gamma=0.1 from appendix, or around 0.2 from main?
            "skip": sqrt(1 - 0.2**2),  # needs to fulfill residual**2 + skip**2 = 1
            "in_proj": 1 / sqrt(dim),
            "out_proj": 1 / sqrt(dim),
            "std": 1.0,
            "q": 1 / sqrt(dim),
            "k": 1 / sqrt(dim),
            "logit_scale": 1 / sqrt(dim),
        },  # should go with identity-shaped activation functions
        # Spliced in here so the table keeps its original key order when the entries are defined.
        **deep_scale_entries,
        "scaled-and-logit-scale": {
            "std": sqrt(2 / (5 * dim)),
            "out_proj": sqrt(2 / (5 * dim)) / sqrt(2 * num_layers),
            "embedding": 1.0,
            "head": 1 / sqrt(dim),  # not used due to weight tying
            "logit_scale": 1 / sqrt(dim),
        },  # Le Scao, Biderman,
        "bernstein": {"std": 1.0},  # handled elsewhere  # a special in the twitter to code pipeline
        "illiterate": {
            "embedding": 1.0,
            "std": sqrt(1 / dim),
            "out_proj": 0.0,
            "head": 0.0,
        },
        "scaled-large-embed": {
            "embedding": 1.0,
            "std": sqrt(1 / dim),
            "out_proj": sqrt(1 / dim) / sqrt(2 * num_layers),
        },
    }
    return lookup


class Qwen2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Besides the standard Qwen2 fields, this variant stores extra settings (`recur_times`, `num_prelude_layers`,
    `num_coda_layers`, ...). Their names suggest a prelude / recurrent / coda layer layout. They are only stored
    here and interpreted by the modeling code. The exception is `init_std`, which [`~Qwen2Config.get_std`]
    resolves into per-layer initialization standard deviations.


    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`Qwen2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly. A legacy `type` key is copied into `rope_type` before validation.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`. Stored as `None`
            when `use_sliding_window` is `False`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        recur_times (`int`, *optional*, defaults to 1):
            Stored as `config.recur_times`; presumably the number of times the recurrent block is applied
            (interpreted by the modeling code, not by this class).
        num_prelude_layers (`int`, *optional*, defaults to 4):
            Stored as `config.num_prelude_layers`; presumably the number of layers before the recurrent block.
        num_coda_layers (`int`, *optional*, defaults to 4):
            Stored as `config.num_coda_layers`; presumably the number of layers after the recurrent block.
        recur_stategy (`str`, *optional*, defaults to `"blockwise"`):
            Stored as-is. The misspelling is part of the public keyword and attribute name.
        input_injection_type (`str`, *optional*, defaults to `"None"`):
            Stored as-is. Note the default is the string `"None"`, not the `None` object.
        state_init_strategy (`str`, *optional*, defaults to `"None"`):
            Stored as-is. Note the default is the string `"None"`, not the `None` object.
        init_std (`str`, *optional*, defaults to `"takase"`):
            Name of the weight-initialization scheme. Must be a key of the table built by `get_factor_table`.
            Used by [`~Qwen2Config.get_std`].
        attn_to_recur_key_values (`bool`, *optional*, defaults to `False`):
            Stored as-is; interpreted by the modeling code.
        ln_after_recur (`bool`, *optional*, defaults to `False`):
            Stored as-is; interpreted by the modeling code.

    ```python
    >>> from transformers import Qwen2Model, Qwen2Config

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        recur_times=1,
        num_prelude_layers=4,
        num_coda_layers=4,
        recur_stategy="blockwise",
        input_injection_type="None",
        state_init_strategy="None",
        init_std="takase",
        attn_to_recur_key_values=False,
        ln_after_recur=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.recur_times = recur_times
        self.num_prelude_layers = num_prelude_layers
        self.num_coda_layers = num_coda_layers
        self.recur_stategy = recur_stategy
        self.input_injection_type = input_injection_type
        self.state_init_strategy = state_init_strategy
        self.init_std = init_std
        self.attn_to_recur_key_values = attn_to_recur_key_values
        self.ln_after_recur = ln_after_recur
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def get_std(self, name_of_layer: str, layer_idx: int = -1):
        """Return the init standard deviation for ``name_of_layer`` under the ``self.init_std`` scheme.

        Args:
            name_of_layer: role key such as "embedding", "head", "q", "k", "v", "out_attn", "w1", "w2", "w3".
            layer_idx: index of the layer being initialized. ``get_factor_table`` reduces it modulo
                ``num_hidden_layers``, so the default ``-1`` refers to the last layer.

        Lookup order:
            1. An exact key in the scheme's table.
            2. "out_proj" for "out_attn"/"w2"/"w3".
            3. "in_proj" for "q"/"k"/"v"/"w1".
            4. "mlp" for "w1"/"w2"/"w3"/"mlp".
            5. Otherwise the scheme's generic "std".

        Raises:
            KeyError: if ``self.init_std`` is not a known scheme, or no rule matches and the scheme has no "std".
        """
        # Pass everything by keyword. The signature is (dim, intermed_dim, attn_head_dim, layer_idx, num_layers),
        # so a positional call with four arguments would put layer_idx into attn_head_dim, put num_hidden_layers
        # into layer_idx, and silently leave num_layers at its default of 16.
        init_table = get_factor_table(
            dim=self.hidden_size,
            intermed_dim=self.intermediate_size,
            attn_head_dim=self.hidden_size // self.num_attention_heads,  # per-head width for the q/k entries
            layer_idx=layer_idx,
            num_layers=self.num_hidden_layers,
        )[self.init_std]
        if name_of_layer in init_table:
            std = init_table[name_of_layer]
        elif "out_proj" in init_table and name_of_layer in ["out_attn", "w2", "w3"]:
            std = init_table["out_proj"]
        elif "in_proj" in init_table and name_of_layer in ["q", "k", "v", "w1"]:  # v is debated
            std = init_table["in_proj"]
        elif "mlp" in init_table and name_of_layer in ["w1", "w2", "w3", "mlp"]:
            std = init_table["mlp"]
        else:
            std = init_table["std"]
        return std