














import math 
from typing import List ,Optional ,Tuple ,Union 

import torch 
import torch .utils .checkpoint 
from torch import nn 

from transformers .activations import ACT2FN 
from transformers .cache_utils import Cache ,DynamicCache ,StaticCache 
from transformers .generation import GenerationMixin 
from transformers .modeling_attn_mask_utils import AttentionMaskConverter 
from transformers .modeling_flash_attention_utils import (
FlashAttentionKwargs ,
_flash_attention_forward ,
)
from transformers .modeling_outputs import (
BaseModelOutputWithPast ,
CausalLMOutputWithPast ,
QuestionAnsweringModelOutput ,
SequenceClassifierOutputWithPast ,
TokenClassifierOutput ,
)
from transformers .modeling_rope_utils import ROPE_INIT_FUNCTIONS 
from transformers .modeling_utils import PreTrainedModel 
from transformers .processing_utils import Unpack 
from transformers .pytorch_utils import ALL_LAYERNORM_LAYERS 
from transformers .utils import (
LossKwargs ,
add_code_sample_docstrings ,
add_start_docstrings ,
add_start_docstrings_to_model_forward ,
is_flash_attn_greater_or_equal_2_10 ,
logging ,
replace_return_docstrings ,
)
from .configuration_internlm3 import InternLM3Config 


# Module-level logger from the transformers logging utilities (used for the deprecation / dtype
# warnings emitted throughout this file).
logger =logging .get_logger (__name__ )

# Name of the config class as a string; presumably consumed by the docstring helpers imported above
# (e.g. `replace_return_docstrings`) further down the file — not visible in this chunk.
_CONFIG_FOR_DOC ="InternLM3Config"


class InternLM3RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension (same formulation as T5LayerNorm).

    The input is upcast to float32, divided by its RMS over the last dimension (plus ``eps``
    inside the square root), cast back to the input dtype and finally multiplied by a learned
    per-channel gain. There is no mean subtraction and no bias.

    Args:
        hidden_size: size of the normalized (last) dimension; length of the gain vector.
        eps: small constant added to the mean of squares for numerical stability.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learned per-channel gain, starts as all ones (identity scaling).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Do the reduction in float32 to avoid precision loss for half-precision inputs.
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


# Register the RMSNorm class in transformers' shared `ALL_LAYERNORM_LAYERS` list (imported from
# `transformers.pytorch_utils`) so library code that looks for layer-norm module types also matches it
# (presumably e.g. when excluding norm weights from weight decay — not visible in this file).
ALL_LAYERNORM_LAYERS .append (InternLM3RMSNorm )


class InternLM3RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) producing per-position ``cos`` / ``sin`` tables.

    Inverse frequencies come from ``ROPE_INIT_FUNCTIONS[rope_type]``. With a ``config``, the rope
    type is read from ``config.rope_scaling`` (key ``"rope_type"``, falling back to the legacy
    ``"type"`` key) and is ``"default"`` when ``rope_scaling`` is ``None``. Without a ``config``
    (deprecated path) the explicit keyword arguments are forwarded to the init function.

    Args:
        dim, max_position_embeddings, base, scaling_factor, rope_type:
            Deprecated explicit RoPE parameters, only used when ``config`` is ``None``.
        device: device handed to the rope init function when computing ``inv_freq``.
        config (`InternLM3Config`, *optional*): model config carrying the RoPE settings.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[InternLM3Config] = None,
    ):
        super().__init__()

        # Extra kwargs for the rope init function; only populated on the deprecated (config-less) path.
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`InternLM3RotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            if config.rope_scaling is not None:
                # Accept both the current "rope_type" key and the legacy "type" key.
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Plain attribute (NOT a buffer) holding the unscaled frequencies so dynamic RoPE can restore
        # them. Because it is not a buffer, `module.to(...)` does not move it to a new device.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """Recompute or restore ``inv_freq`` for dynamic RoPE variants.

        1. When the sequence grows beyond the cached length, recompute (scaled) frequencies.
        2. When the sequence is back inside the original length after having been scaled, restore
           the original frequencies (avoids losing precision on short sequences).
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # `original_inv_freq` is not a buffer, so it stays on the device the module was built on
            # when the model is moved afterwards. Bring it to the current device before restoring it;
            # otherwise forward() would multiply tensors living on different devices.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, both cast to ``x.dtype``.

        ``x`` only supplies the device and the output dtype. ``position_ids`` must be 2-D
        ``(batch, seq_len)`` (required by the ``[:, None, :]`` indexing below). With ``F`` the length
        of ``inv_freq``, the outputs have shape ``(batch, seq_len, 2 * F)``.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # (batch, F, 1) @ (batch, 1, seq_len) -> (batch, F, seq_len)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Compute the angles with autocast disabled so they stay in float32. "mps" (and any
        # non-string device type) is mapped to "cpu" for the autocast context — presumably because
        # autocast did not accept "mps"; TODO confirm for current torch versions.
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Scaling factor returned by the rope init function alongside `inv_freq`.
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class InternLM3LinearScalingRotaryEmbedding(InternLM3RotaryEmbedding):
    """Deprecated alias of ``InternLM3RotaryEmbedding`` that forces ``rope_type="linear"``.

    Linear RoPE scaling credits: Reddit user /u/kaiokendev. Note that the forced ``rope_type``
    only takes effect on the config-less path; when ``config`` is passed, the base class reads the
    rope type from ``config.rope_scaling`` instead.
    """

    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "`InternLM3LinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
            "`InternLM3RotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
        )
        forwarded = {**kwargs, "rope_type": "linear"}
        super().__init__(*args, **forwarded)


class InternLM3DynamicNTKScalingRotaryEmbedding(InternLM3RotaryEmbedding):
    """Deprecated alias of ``InternLM3RotaryEmbedding`` that forces ``rope_type="dynamic"``.

    Dynamic NTK scaling credits: Reddit users /u/bloc97 and /u/emozilla. The forced ``rope_type``
    only takes effect on the config-less path; with a ``config`` the base class reads the rope type
    from ``config.rope_scaling``.
    """

    def __init__(self, *args, **kwargs):
        logger.warning_once(
            "`InternLM3DynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
            "`InternLM3RotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
            "__init__)."
        )
        forwarded = {**kwargs, "rope_type": "dynamic"}
        super().__init__(*args, **forwarded)


def rotate_half(x):
    """Rotate the last dimension by half: ``[front, back] -> [-back, front]``.

    The last dimension of size ``n`` is split at index ``n // 2``; the second part is negated
    and moved in front of the first part.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin`` after ``cos``/``sin`` gain a
    singleton dimension at ``unsqueeze_dim`` so they broadcast against ``q``/``k``.

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine part of the rotary embedding, e.g. `(batch_size, seq_len, head_dim)`.
        sin (`torch.Tensor`): sine part of the rotary embedding, same shape as ``cos``.
        position_ids (`torch.Tensor`, *optional*): deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1): where to insert the broadcast dimension.
            Use 1 when q/k are `(batch_size, heads, seq_len, head_dim)` and 2 when they are
            `(batch_size, seq_len, heads, head_dim)`.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_b) + (rotate_half(tensor) * sin_b)

    return _rotate(q), _rotate(k)


class InternLM3MLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    All three projections use ``config.bias`` as their bias flag; the activation is looked up in
    ``ACT2FN`` by ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.bias
        # Creation order (gate, up, down) is kept so parameter registration order is unchanged.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dimension 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: shape
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. For ``n_rep == 1`` the input
    object itself is returned.
    """
    # Unpacking first also enforces a 4-D input, even when no repetition is needed.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)


class InternLM3Attention(nn.Module):
    """Eager multi-head attention ('Attention Is All You Need') with grouped key/value heads.

    Keys/values use ``config.num_key_value_heads`` heads and are repeated up to
    ``config.num_attention_heads`` with ``repeat_kv`` before the score computation.
    ``forward`` returns ``(attn_output, attn_weights or None, past_key_value)``.
    """

    def __init__(self, config: InternLM3Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        # Q/K/V use `qkv_bias`; the output projection uses the general `bias` flag.
        self.q_proj = nn.Linear(self.hidden_size, q_width, bias=config.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=config.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=config.qkv_bias)
        self.o_proj = nn.Linear(q_width, self.hidden_size, bias=config.bias)

        # Fallback RoPE, only used when the caller does not pass `position_embeddings`.
        self.rotary_emb = InternLM3RotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        def _split_heads(projected):
            # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
            return projected.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        query_states = _split_heads(self.q_proj(hidden_states))
        key_states = _split_heads(self.k_proj(hidden_states))
        value_states = _split_heads(self.v_proj(hidden_states))

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # Additive mask, trimmed to the number of key positions actually present.
            scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

        # Softmax in float32 for stability, then back to the activation dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
        probs = nn.functional.dropout(probs, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(probs, value_states)

        expected_shape = (bsz, self.num_heads, q_len, self.head_dim)
        if attn_output.size() != expected_shape:
            raise ValueError(
                f"`attn_output` should be of size {expected_shape}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, (probs if output_attentions else None), past_key_value


class InternLM3FlashAttention2(InternLM3Attention):
    """InternLM3 attention backed by flash-attention via ``_flash_attention_forward``.

    The weights are inherited unchanged from ``InternLM3Attention``; only the forward pass differs.
    Attention probabilities are never materialized, so the returned weights are always ``None``,
    and a ``StaticCache`` is rejected.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Older flash-attn releases (as detected by `is_flash_attn_greater_or_equal_2_10`) align the
        # causal mask differently; the flag is forwarded as `use_top_left_mask` below.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        bsz, q_len, _ = hidden_states.size()

        # Project and reshape to (bsz, heads, q_len, head_dim) for RoPE and the cache update.
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = (
            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        )

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        # Back to (bsz, seq_len, heads, head_dim), the layout handed to `_flash_attention_forward`.
        query_states, key_states, value_states = (
            states.transpose(1, 2) for states in (query_states, key_states, value_states)
        )

        if query_states.dtype == torch.float32:
            # Inputs were upcast somewhere upstream; pick the dtype to cast back to.
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states, key_states, value_states = (
                states.to(target_dtype) for states in (query_states, key_states, value_states)
            )

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=self.attention_dropout if self.training else 0.0,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
            **kwargs,
        )

        attn_output = self.o_proj(attn_output.reshape(bsz, q_len, -1).contiguous())

        # No attention weights are available from flash attention.
        return attn_output, None, past_key_value


class InternLM3SdpaAttention(InternLM3Attention):
    """InternLM3 attention using ``torch.nn.functional.scaled_dot_product_attention``.

    Weights are inherited unchanged from ``InternLM3Attention``. When ``output_attentions`` is
    requested, it falls back to the eager parent implementation, since SDPA does not return weights.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            logger.warning_once(
                "InternLM3Model is using InternLM3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project, then reshape each to (bsz, heads, q_len, head_dim); evaluated in q, k, v order.
        query_states, key_states, value_states = (
            proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Trim the mask to the number of key positions actually present.
        causal_mask = None if attention_mask is None else attention_mask[:, :, :, : key_states.shape[-2]]

        if causal_mask is not None and query_states.device.type == "cuda":
            # NOTE(review): presumably a workaround for an SDPA issue with non-contiguous inputs plus a
            # custom mask on CUDA — TODO confirm which torch versions still need it.
            query_states, key_states, value_states = (
                states.contiguous() for states in (query_states, key_states, value_states)
            )

        # Ask SDPA to build its own causal mask only when no explicit mask was supplied and there is
        # more than one query token.
        use_builtin_causal = causal_mask is None and q_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=use_builtin_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        return self.o_proj(attn_output), None, past_key_value


# Maps `config._attn_implementation` to the attention module class instantiated per decoder layer.
InternLM3_ATTENTION_CLASSES = dict(
    eager=InternLM3Attention,
    flash_attention_2=InternLM3FlashAttention2,
    sdpa=InternLM3SdpaAttention,
)


class InternLM3DecoderLayer(nn.Module):
    """Pre-norm decoder layer: ``x + attn(norm(x))`` followed by ``x + mlp(norm(x))``.

    The attention class is picked from ``InternLM3_ATTENTION_CLASSES`` using
    ``config._attn_implementation``.
    """

    def __init__(self, config: InternLM3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        attention_cls = InternLM3_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        self.mlp = InternLM3MLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = InternLM3RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = InternLM3RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the layer on ``hidden_states``.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): `(batch_size, sequence_length)` when flash
                attention is used, otherwise `(batch_size, 1, query_sequence_length, key_sequence_length)`.
            position_ids (`torch.LongTensor`, *optional*): token positions, forwarded to the attention module.
            past_key_value (`Cache`, *optional*): key/value cache whose `update` is called by the attention module.
            output_attentions (`bool`, *optional*): also return the attention weights.
            use_cache (`bool`, *optional*): also return the (updated) cache.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): positions of the
                input tokens in the sequence, passed to the cache update.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): precomputed
                `(cos, sin)` of shape `(batch_size, seq_len, head_dim)`.
            kwargs (`dict`, *optional*): extra arguments forwarded untouched to the attention module
                (e.g. injected by FSDP or similar tooling).

        Returns:
            A tuple `(hidden_states,)`, followed by the attention weights when `output_attentions` is set
            and by the cache when `use_cache` is set.
        """
        # Attention sub-block (pre-norm + residual).
        attn_input = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm + residual).
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)
        return (hidden_states, *extras)


# Shared introduction prepended (via `add_start_docstrings`) to the docstrings of the InternLM3 model classes.
InternLM3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`InternLM3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare InternLM3 Model outputting raw hidden-states without any specific head on top.",
    InternLM3_START_DOCSTRING,
)
class InternLM3PreTrainedModel(PreTrainedModel):
    # Base class hooking InternLM3 into transformers' PreTrainedModel machinery (config class,
    # weight init, and the feature flags below). Documented through the decorator above.

    config_class = InternLM3Config
    base_model_prefix = "model"
    _no_split_modules = ["InternLM3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Capability flags.
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        # Linear/Embedding weights ~ N(0, config.initializer_range); Linear biases and the Embedding
        # padding row are zeroed. Other module types are left untouched.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


# Shared argument documentation attached to model `forward` methods via `add_start_docstrings_to_model_forward`.
INTERNLM3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range
            `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (keys and values of the self-attention blocks) that can be used to speed up
            sequential decoding. This typically consists of the `past_key_values` returned by the model at a previous
            stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2
            tensors of shape `(batch_size, num_key_value_heads, sequence_length, head_dim)`. This is also known as
            the legacy cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


@add_start_docstrings(
    "The bare InternLM3 Model outputting raw hidden-states without any specific head on top.",
    InternLM3_START_DOCSTRING,
)
class InternLM3Model(InternLM3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM3DecoderLayer`].

    Args:
        config: InternLM3Config
    """

    def __init__(self, config: InternLM3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [InternLM3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = InternLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = InternLM3RotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        if getattr(config, "pretraining_tp", 1) != 1:
            # `Logger.warn` is a deprecated alias; `Logger.warning` is the supported spelling.
            logger.warning("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")

        # Hand off to PreTrainedModel's post-initialization hook.
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(INTERNLM3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the embedding lookup (unless `inputs_embeds` is given), every decoder layer, and the final norm.

        Unset flags fall back to the matching `self.config` attributes. When caching is enabled and
        `past_key_values` is not a `Cache`, a `DynamicCache` is used internally and converted back to the legacy
        format before returning.

        Returns:
            A `BaseModelOutputWithPast` when `return_dict` is true. Otherwise, a tuple of the non-`None` values
            among (last hidden state, cache, all hidden states, all attentions).

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are provided.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR is true when both inputs are given or when neither is.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Legacy (tuple-based) caches are wrapped in a DynamicCache here and converted back on return.
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        # Default positions continue from however many tokens the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )
        hidden_states = inputs_embeds

        # Rotary embeddings are computed once and shared by all decoder layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                # Record each layer's input; the post-norm output is appended after the loop.
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # NOTE(review): `flash_attn_kwargs` are not forwarded on the checkpointed path (arguments are
                # passed positionally here) -- confirm this is intended.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            # Assumes the layer returns (hidden_states, [attn_weights], [cache]), with the optional entries present
            # only when requested. The layout is defined by InternLM3DecoderLayer (not visible here); confirm there.
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """
        Build the attention mask handed to the decoder layers for the configured attention implementation.

        Returns `None` when no explicit mask is required. In the flash_attention_2 case, it returns the original 2D
        mask when that mask contains a 0. Otherwise, it returns a 4D additive mask in `input_tensor.dtype`.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Pass the mask through only if it actually masks something (contains a 0).
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # For SDPA, skip building a mask when the converter helper reports it can be ignored.
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            # With a StaticCache the mask width is the cache's maximum shape.
            target_length = past_key_values.get_max_cache_shape()
        else:
            # NOTE(review): the reason for the extra `+ 1` key column is not evident from this file -- confirm
            # before changing.
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Presumably un-masks fully masked rows (e.g. left-padding rows) so SDPA on CUDA does not see rows
            # that are masked everywhere -- see AttentionMaskConverter._unmask_unattended.
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.

        Returns:
            The 4D additive mask: `0` where attention is allowed, `torch.finfo(dtype).min` where it is blocked.
            A 4D `attention_mask` is returned unchanged.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # Already 4D: assume the caller built the final mask.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Multiplying by the boolean zeroes every key position at or before each query's cache position,
            # leaving `min_dtype` only at future positions.
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                # Clone before the in-place edit below: `expand` returns a view that shares storage.
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]
                # Move the 2D mask onto the mask's device so a mask living elsewhere (e.g. a different GPU when
                # the model is split) does not fail the addition. This is a no-op when the devices already match.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                # The sum is 0 only where the causal mask allows attention but the padding mask is 0.
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    """Extra keyword arguments that `InternLM3ForCausalLM.forward` accepts through `**kwargs`.

    Merges the fields of `FlashAttentionKwargs` and `LossKwargs`; no fields of its own are declared.
    """


class InternLM3ForCausalLM(InternLM3PreTrainedModel, GenerationMixin):
    """InternLM3 decoder (`self.model`) followed by a bias-free linear projection onto the vocabulary."""

    # Parameter names declared to PreTrainedModel as tied weights.
    _tied_weights_keys = ["lm_head.weight"]
    # Tensor-parallel plan entry for the output projection.
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.model = InternLM3Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Hand off to PreTrainedModel's post-initialization hook.
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Install `value` as the decoder's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the vocabulary projection (`lm_head`)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the vocabulary projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap in a different decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.model

    @add_start_docstrings_to_model_forward(INTERNLM3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Target token ids for the language-modeling loss. Each entry should lie in `[0, ...,
                config.vocab_size]` or be `-100` (see `input_ids` docstring). Positions labelled `-100` are ignored
                (masked), so the loss only covers positions whose labels are in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                How many trailing positions get logits computed. `0` is a special case meaning every position of
                `input_ids`. Generation needs only the last token's logits, and restricting the projection to it
                saves memory, which becomes significant for long sequences or large vocabularies.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, InternLM3ForCausalLM

        >>> model = InternLM3ForCausalLM.from_pretrained("internlm/InternLM3-8b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("internlm/InternLM3-8b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.use_return_dict

        decoder_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        last_hidden_state = decoder_outputs[0]
        # With num_logits_to_keep == 0 the slice becomes `0:`, i.e. every position is projected.
        trailing_states = last_hidden_state[:, -num_logits_to_keep:, :]
        logits = self.lm_head(trailing_states)

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=config.vocab_size,
                **kwargs,
            )

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=decoder_outputs.past_key_values,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
            )

        # Tuple mode: logits followed by whatever the decoder returned after its last hidden state.
        tuple_result = (logits, *decoder_outputs[1:])
        if loss is None:
            return tuple_result
        return (loss, *tuple_result)