import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import LlamaConfig
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    StaticCache,
)
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.utils import logging
from typing import (
    List,
    Optional,
    Tuple,
    Union,
)

from .modeling_llama import (
    ACT2FN,
    KwargsForCausalLM,
    LLAMA_ATTENTION_CLASSES,
    LlamaRMSNorm,
    LlamaPreTrainedModel,
    LlamaRotaryEmbedding,
)
from .modeling_outputs import (
    MoeModelOutputWithPast,
    MoLoSCausalLMOutputWithPast,
)
from .MoLoSConfig import MoLoSConfig


logger = logging.get_logger(__name__)


class MoLoSLlamaMLP(nn.Module):
    """ Gated MLP whose gate, up and down projections are low-rank factorised.

    Each projection is a pair of bias-free linear layers applied in sequence:
    `*_v_proj` maps the input to `rank` features and `*_u_proj` maps those
    features to the output size.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        ratio: float = 1.0,
    ):
        """ Build the six factor layers and look up the activation function.

        Args:
            hidden_size (int): Size of the MLP input and output features.
            intermediate_size (int): Size of the gated intermediate features.
            hidden_act (str): Key into `ACT2FN` selecting the gate activation.
            ratio (float, optional): Multiplier applied to the rank. Defaults to 1.0.
        """

        super().__init__()

        # A (u, v) pair holds rank * (hidden_size + intermediate_size) weights,
        # so ratio = 1.0 yields roughly the weight count of one dense projection.
        scaled_product = hidden_size * intermediate_size * ratio
        rank = int(scaled_product / (hidden_size + intermediate_size))

        # The creation order (down, gate, up; u before v) is kept fixed so that
        # parameter registration and random initialisation order do not change.
        self.down_u_proj = self._bias_free_linear(rank, hidden_size)
        self.down_v_proj = self._bias_free_linear(intermediate_size, rank)

        self.gate_u_proj = self._bias_free_linear(rank, intermediate_size)
        self.gate_v_proj = self._bias_free_linear(hidden_size, rank)

        self.up_u_proj = self._bias_free_linear(rank, intermediate_size)
        self.up_v_proj = self._bias_free_linear(hidden_size, rank)

        self.act_fn = ACT2FN[hidden_act]

    @staticmethod
    def _bias_free_linear(
        in_features: int,
        out_features: int,
    ) -> nn.Linear:
        """ Create a linear layer without a bias term.
        """

        return nn.Linear(
            in_features=in_features,
            out_features=out_features,
            bias=False,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """ Compute `down(act_fn(gate(x)) * up(x))` with every projection factorised.

        Args:
            x (torch.Tensor): Input whose last dimension is `hidden_size`.

        Returns:
            torch.Tensor: Output whose last dimension is `hidden_size`.
        """

        activated_gate = self.act_fn(self.gate_u_proj(self.gate_v_proj(x)))
        lifted = self.up_u_proj(self.up_v_proj(x))

        return self.down_u_proj(self.down_v_proj(activated_gate * lifted))


class MoLoSMoE(nn.Module):
    """ Mixture of `MoLoSLlamaMLP` experts selected by a bias-free linear router.

    Two routing granularities are supported, chosen by `molos_config.select_type`:
    - 'sequence': every sequence is sent to the single expert with the highest routing weight for the mean-pooled sequence.
    - 'token': every token is sent to its top-`selected_ex_num` experts, except the first token of each sequence, which is sent to all experts.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        molos_config: MoLoSConfig,
    ):
        """ Build the router and the experts.

        Args:
            hidden_size (int): Size of the input and output features.
            intermediate_size (int): Intermediate size passed to every expert.
            hidden_act (str): Activation name passed to every expert.
            molos_config (MoLoSConfig): Routing configuration. This class reads `ex_num`, `selected_ex_num`, `jitter_noise`, `ex_params_ratio`, `select_type` and `chosen_ex_idx` from it.
        """

        super().__init__()

        self.molos_config = molos_config
        self.ex_num = self.molos_config.ex_num
        self.jitter_noise = self.molos_config.jitter_noise
        self.selected_ex_num = self.molos_config.selected_ex_num

        # Multiplier applied to the output of token-level routing.
        self.scaling_factor = self.ex_num / self.selected_ex_num

        self.router = nn.Linear(
            in_features=hidden_size,
            out_features=self.molos_config.ex_num,
            bias=False,
        )

        self.experts = nn.ModuleList(modules=[
            MoLoSLlamaMLP(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=hidden_act,
                ratio=self.molos_config.ex_params_ratio,
            ) for _ in range(self.molos_config.ex_num)
        ])

    def forward(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """ Route the input through the experts.

        Args:
            x (torch.Tensor): Input of shape (batch_size, sequence_length, hidden_size).

        Raises:
            ValueError: If `select_type` is neither 'sequence' nor 'token', or if sequence-level routing is used with `selected_ex_num != 1`.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: The output, with the same shape as `x`, and the routing logits. The routing logits have shape (batch_size, ex_num) for sequence-level routing and (batch_size * sequence_length, ex_num) for token-level routing; they are None when a fixed expert is forced through `chosen_ex_idx`.
        """

        # Fails early for inputs that are not 3-dimensional.
        batch_size, sequence_length, hidden_size = x.shape

        # For debugging.
        # Use self.molos_config.chosen_ex_idx to bypass the router and force one expert.
        if self.molos_config.chosen_ex_idx != -1:
            return self.experts[self.molos_config.chosen_ex_idx](x=x), None

        # Add multiplicative jitter noise during training only.
        if self.training and self.jitter_noise > 0:
            x = x * torch.empty_like(input=x).uniform_(
                (1.0 - self.jitter_noise),
                (1.0 + self.jitter_noise),
            )

        select_type = self.molos_config.select_type

        if select_type == 'sequence':
            return self._route_sequences(x=x)

        if select_type == 'token':
            return self._route_tokens(x=x)

        raise ValueError(
            f'Invalid select_type: {select_type}. '
            'Valid options are: \"sequence\" or \"token\".'
        )

    def _route_sequences(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Send every sequence to the single expert with the highest routing weight.

        The expert output is written as is: it is multiplied neither by the routing weight nor by `scaling_factor`.
        """

        # As per config description, selected_ex_num must be 1 for sequence routing.
        if self.selected_ex_num != 1:
            raise ValueError(
                'For sequence-level routing (select_type = \"sequence\"), selected_ex_num must be 1.'
            )

        # Use the mean pooling of the sequence for routing.
        # Shape of routing_logits: (batch_size, ex_num)
        routing_logits = self.router(input=x.mean(dim=1))

        # Do softmax on ex_num dimension, in float32.
        routing_weights = F.softmax(
            input=routing_logits,
            dim=1,
            dtype=torch.float,
        )

        # Shape of selected_experts: (batch_size)
        selected_experts = torch.argmax(
            input=routing_weights,
            dim=-1,
        )

        # Shape of y: (batch_size, sequence_length, hidden_size)
        y = torch.zeros(
            size=x.shape,
            dtype=x.dtype,
            device=x.device,
        )

        for ex_idx in range(self.ex_num):
            # Find which sequences in the batch selected this expert.
            batch_indices = torch.where(selected_experts == ex_idx)[0]

            if batch_indices.numel() > 0:
                # Shape of output: (selected_sequences_num, sequence_length, hidden_size)
                output = self.experts[ex_idx](x=x[batch_indices])

                y[batch_indices] = output.to(dtype=x.dtype)

        return y, routing_logits

    def _route_tokens(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Send every token to its top-`selected_ex_num` experts and scale the result by `scaling_factor`.

        The first token of each sequence is instead processed by all experts, weighted by the full softmax.
        """

        batch_size, sequence_length, hidden_size = x.shape
        token_num = batch_size * sequence_length

        # New shape of x: (batch_size * sequence_length, hidden_size)
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, hidden_size)

        # Shape of routing_logits: (batch_size * sequence_length, ex_num)
        routing_logits = self.router(input=x)

        # Do softmax on ex_num dimension, in float32.
        routing_weights = F.softmax(
            input=routing_logits,
            dim=1,
            dtype=torch.float,
        )

        # Shape of y: (batch_size * sequence_length, hidden_size)
        y = torch.zeros_like(input=x)

        # Process the first token of each sequence with every expert to avoid the NCCL timeout issue.
        # IMPORTANT: NCCL timeout issue occurs when there is any expert that is not selected by any token in a batch.
        # -----
        # Shape of first_token_indices: (batch_size)
        first_token_indices = torch.arange(
            start=0,
            end=token_num,
            step=sequence_length,
            device=x.device,
        )

        # Shape of first_token_x: (batch_size, hidden_size)
        first_token_x = x[first_token_indices]

        # Shape of first_token_routing_weights: (batch_size, ex_num)
        first_token_routing_weights = routing_weights[first_token_indices].to(dtype=x.dtype)

        first_token_outputs = torch.zeros_like(input=first_token_x)
        for ex_idx in range(self.ex_num):
            first_token_outputs += (
                self.experts[ex_idx](x=first_token_x) * first_token_routing_weights[:, ex_idx, None]
            ).to(dtype=x.dtype)

        y[first_token_indices] = first_token_outputs
        # -----

        # Process all other tokens with top-k routing.
        # -----
        if token_num > 1:
            keep_mask = torch.ones(
                token_num,
                dtype=torch.bool,
                device=x.device,
            )
            keep_mask[first_token_indices] = False

            # Flat indices of the tokens that are not the first token of a sequence.
            other_token_indices = torch.arange(
                start=0,
                end=token_num,
                step=1,
                device=x.device,
            )[keep_mask]

            # Select the top-k experts.
            # Shape of other_routing_weights: (batch_size * sequence_length - batch_size, selected_ex_num)
            # Shape of selected_experts: (batch_size * sequence_length - batch_size, selected_ex_num)
            other_routing_weights, selected_experts = torch.topk(
                input=routing_weights[other_token_indices],
                k=self.selected_ex_num,
                dim=-1,
            )

            # Renormalise the selected weights so that they sum to 1 for each token.
            other_routing_weights = (other_routing_weights / other_routing_weights.sum(
                dim=-1,
                keepdim=True,
            )).to(dtype=x.dtype)

            # Shape of experts_mask: (batch_size * sequence_length - batch_size, selected_ex_num, ex_num)
            experts_mask = F.one_hot(
                input=selected_experts,
                num_classes=self.ex_num,
            )

            # Shape of new experts_mask: (ex_num, selected_ex_num, batch_size * sequence_length - batch_size)
            # The meaning of new experts_mask (i, j, k) is that the i-th expert is the k-th token's j-th selected expert.
            experts_mask = experts_mask.permute(2, 1, 0)

            # There is an example for the following code.
            # Suppose:
            # - The number of routed tokens is 5.
            # - selected_ex_num is 2.
            # And:
            # - experts_mask[100] is: tensor([[1, 0, 0, 0, 0], [0, 0, 0, 0, 1]])
            # - top_ex_indices and token_indices will be tensor([0, 1]) and tensor([0, 4]).
            # This means the 0-th token's 0-th selected expert and the 4-th token's 1-th selected expert are the 100-th expert.
            for ex_idx in range(self.ex_num):
                top_ex_indices, token_indices = torch.where(experts_mask[ex_idx])

                if token_indices.numel() > 0:
                    # Map positions among the routed tokens back to positions in the flattened input.
                    actual_indices = other_token_indices[token_indices]

                    output = (
                        self.experts[ex_idx](x=x[actual_indices])
                        * other_routing_weights[token_indices, top_ex_indices, None]
                    ).to(dtype=x.dtype)

                    # A token selected by several experts accumulates all of their outputs.
                    y.index_add_(
                        dim=0,
                        index=actual_indices,
                        source=output,
                    )
        # -----

        # Reshape y to the original shape and scale it.
        y = y.reshape(batch_size, sequence_length, hidden_size) * self.scaling_factor

        return y, routing_logits


class MoLoSLlamaDecoderLayer(nn.Module):
    """ Decoder layer made of a self-attention block followed by a `MoLoSMoE` block.

    Both blocks are pre-normalised with `LlamaRMSNorm` and wrapped in a residual connection.
    """

    def __init__(
        self,
        config: LlamaConfig,
        molos_config: MoLoSConfig,
        layer_idx: int,
    ):
        """ Build the attention block, the MoE block and the two norms.

        Args:
            config (LlamaConfig): Model configuration. This class reads `hidden_size`, `intermediate_size`, `hidden_act`, `rms_norm_eps` and `_attn_implementation` from it.
            molos_config (MoLoSConfig): Routing configuration forwarded to `MoLoSMoE`.
            layer_idx (int): Index of this layer, forwarded to the attention block.
        """

        super().__init__()

        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # The attention implementation is chosen by name from the configuration.
        attention_cls = LLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(
            config=config,
            layer_idx=layer_idx,
        )

        self.moe = MoLoSMoE(
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            hidden_act=config.hidden_act,
            molos_config=molos_config,
        )

        norm_kwargs = {
            'hidden_size': self.hidden_size,
            'eps': config.rms_norm_eps,
        }
        self.input_layernorm = LlamaRMSNorm(**norm_kwargs)
        self.post_attention_layernorm = LlamaRMSNorm(**norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        # position_embeddings will become mandatory in v4.46.
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor,
        Optional[Tuple[
            torch.FloatTensor,
            torch.FloatTensor,
            torch.FloatTensor,
        ]],
    ]:
        """ Apply self-attention and the MoE block, each behind a pre-norm and a residual connection.

        Args:
            hidden_states (torch.Tensor): Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (Optional[torch.Tensor], optional): Attention mask forwarded to the attention block. Defaults to None.
            position_ids (Optional[torch.LongTensor], optional): Positions of the input tokens, forwarded to the attention block. Defaults to None.
            past_key_value (Optional[Cache], optional): Key/value cache forwarded to the attention block. Defaults to None.
            output_attentions (Optional[bool], optional): Whether to append the attention weights to the outputs. Defaults to False.
            output_router_logits (Optional[bool], optional): Whether to append the routing logits to the outputs. Defaults to False.
            use_cache (Optional[bool], optional): Whether to append the key/value cache returned by the attention block to the outputs. Defaults to False.
            cache_position (Optional[torch.LongTensor], optional): Cache positions forwarded to the attention block. Defaults to None.
            position_embeddings (Optional[Tuple[torch.Tensor, torch.Tensor]], optional): Positional embeddings forwarded to the attention block. Defaults to None.
            kwargs (Optional[Dict], optional): Extra keyword arguments, forwarded unchanged to the attention block.

        Returns:
            Tuple: `(hidden_states,)` followed, in this order and only when requested, by the attention weights, the key/value cache and the routing logits. The routing logits are also omitted when the MoE block returns None for them.
        """

        # Self-attention block.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, attn_weights, key_value_cache = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        after_attention = hidden_states + attn_output

        # Mixture-of-experts block.
        moe_input = self.post_attention_layernorm(after_attention)
        moe_output, routing_logits = self.moe(moe_input)
        layer_output = after_attention + moe_output

        collected = [layer_output]

        if output_attentions:
            collected.append(attn_weights)

        if use_cache:
            collected.append(key_value_cache)

        if output_router_logits and routing_logits is not None:
            collected.append(routing_logits)

        return tuple(collected)


class MoLoSMixin:
    """ A mixin class for the MoLoS model.

    It expects the host class to provide `named_parameters()`.
    """

    def get_trainable_parameters_number(
        self,
        print_param_name: bool = False,
    ) -> Tuple[int, int]:
        """ Get the number of trainable parameters and all parameters in the model.

        Args:
            print_param_name (bool, optional): Whether to print the names of the trainable parameters. Defaults to False.

        Returns:
            Tuple[int, int]: The number of trainable parameters and all parameters in the model.
        """

        trainable_params = 0
        all_params = 0

        # Tracked explicitly so that the header is printed exactly once, even when the first trainable parameters are empty.
        header_printed = False

        for param_name, param in self.named_parameters():
            num_params = param.numel()

            # Handle this case if using DeepSpeed Zero-3 and the weights are initialized empty.
            if num_params == 0 and hasattr(param, 'ds_numel'):
                num_params = param.ds_numel

            # Due to the design of 4-bit linear layers from the bitsandbytes library, the number of parameters needs to be multiplied by 2 to obtain the correct count.
            if param.__class__.__name__ == 'Params4bit':
                num_params *= 2

            all_params += num_params

            if not param.requires_grad:
                continue

            if print_param_name:
                if not header_printed:
                    print('The name of the trainable parameters:')
                    header_printed = True

                print(f'- {param_name}')

            trainable_params += num_params

        return trainable_params, all_params

    def print_trainable_parameters(
        self,
        print_info: bool = False,
    ) -> str:
        """ Build, and optionally print, a summary of the trainable and total parameter counts.

        Args:
            print_info (bool, optional): Whether to print the trainable parameter names and the summary. Defaults to False.

        Returns:
            str: The message containing the number of trainable parameters and all parameters in the model.
        """

        trainable_params, all_params = self.get_trainable_parameters_number(print_param_name=print_info)

        # A model without parameters reports a rate of 0 instead of raising ZeroDivisionError.
        trainable_rate = 100 * trainable_params / all_params if all_params else 0.0

        message = (
            f'Trainable Params Number: {trainable_params:,d} || '
            f'All params Number: {all_params:,d} || '
            f'Trainable Rate (%): {trainable_rate:.5f}'
        )

        if print_info:
            print(message)

        return message


class MoLoSLlamaModel(
    LlamaPreTrainedModel,
    MoLoSMixin,
):

    def __init__(
        self,
        config: LlamaConfig,
        molos_config: MoLoSConfig,
    ):
        """ Build the token embedding, the decoder layers, the final norm and the rotary embedding.

        Args:
            config (LlamaConfig): Model configuration. This method reads `pad_token_id`, `vocab_size`, `hidden_size`, `num_hidden_layers` and `rms_norm_eps` from it.
            molos_config (MoLoSConfig): Routing configuration forwarded to every decoder layer.
        """

        super().__init__(config=config)

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.molos_config = molos_config

        self.embed_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.layers = nn.ModuleList(modules=[
            MoLoSLlamaDecoderLayer(
                config=config,
                molos_config=molos_config,
                layer_idx=layer_idx,
            ) for layer_idx in range(config.num_hidden_layers)
        ])

        self.norm = LlamaRMSNorm(
            hidden_size=config.hidden_size,
            eps=config.rms_norm_eps,
        )

        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        self.gradient_checkpointing = False

        if getattr(config, 'pretraining_tp', 1) != 1:
            # `Logger.warn` is a deprecated alias that was removed in Python 3.13; use `warning`.
            logger.warning('`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.')

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """ Return the module stored in `self.embed_tokens`, which embeds the input token ids.
        """

        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Module) -> None:
        """ Replace the module stored in `self.embed_tokens` with `value`.
        """

        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[
            Cache,
            List[torch.FloatTensor],
        ]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        """ Embed the inputs, run every decoder layer and apply the final norm.

        Args:
            input_ids (torch.LongTensor, optional): Token ids. Exactly one of `input_ids` and `inputs_embeds` must be given.
            attention_mask (Optional[torch.Tensor], optional): Attention mask used to build the causal mask. Defaults to None.
            position_ids (Optional[torch.LongTensor], optional): Token positions. Derived from `cache_position` when None.
            past_key_values (Optional[Union[Cache, List[torch.FloatTensor]]], optional): Key/value cache. A non-`Cache` value is converted to a `DynamicCache` when `use_cache` is set, and the returned cache is then converted back to the legacy format.
            inputs_embeds (Optional[torch.FloatTensor], optional): Precomputed input embeddings, used instead of `input_ids`.
            use_cache (Optional[bool], optional): Whether to return the key/value cache. Defaults to `config.use_cache`.
            output_attentions (Optional[bool], optional): Whether to return the attention weights of every layer. Defaults to `config.output_attentions`.
            output_hidden_states (Optional[bool], optional): Whether to return the hidden states of every layer. Defaults to `config.output_hidden_states`.
            output_router_logits (Optional[bool], optional): Whether to return the routing logits of every layer. Defaults to `molos_config.output_router_logits`.
            return_dict (Optional[bool], optional): Whether to return a `MoeModelOutputWithPast` instead of a tuple. Defaults to `config.use_return_dict`.
            cache_position (Optional[torch.LongTensor], optional): Positions of the input tokens in the cache. Derived from the cache length when None.
            flash_attn_kwargs: Extra keyword arguments forwarded to every decoder layer (not forwarded on the gradient checkpointing path).

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds` are given.

        Returns:
            Union[Tuple, MoeModelOutputWithPast]: The model outputs. The tuple form contains, in this order and skipping None values, the last hidden state, the cache, all hidden states and all attention weights.
        """

        if output_attentions is None:
            output_attentions = self.config.output_attentions

        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        if output_router_logits is None:
            output_router_logits = self.molos_config.output_router_logits

        if use_cache is None:
            use_cache = self.config.use_cache

        if return_dict is None:
            return_dict = self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                'You must specify exactly one of input_ids or inputs_embeds'
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
            )

            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Kept for BC (non `Cache`, `past_key_values` inputs).
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True

            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values=past_key_values)

                logger.warning_once(
                    'We detected that you are passing `past_key_values` as a tuple of tuples. '
                    'This is deprecated and will be removed in v4.47. '
                    'Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format).'
                )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() \
                if past_key_values is not None else 0

            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )
        hidden_states = inputs_embeds

        # Create position embeddings to be shared across the decoder layers.
        position_embeddings = self.rotary_emb(
            hidden_states,
            position_ids,
        )

        # Decoder layers.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        # A decoder layer returns (hidden_states, [attn_weights], [cache], [router_logits]).
        # It leaves the router logits out when its MoE block produced none, so the last
        # element cannot be assumed to be the router logits: locate them by position.
        router_logits_idx = 1 + int(bool(output_attentions)) + int(bool(use_cache))

        for decoder_layer in self.layers[:self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            if self.gradient_checkpointing and self.training:
                # Positional order must match the signature of the decoder layer's forward method.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

            if output_router_logits and \
                    isinstance(decoder_layer, MoLoSLlamaDecoderLayer) and \
                        len(layer_outputs) > router_logits_idx:
                all_router_logits += (layer_outputs[router_logits_idx], )

        hidden_states = self.norm(hidden_states)

        # Add hidden states from the last decoder layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None

        # next_cache is still None when no decoder layer ran.
        if return_legacy_cache and next_cache is not None:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            # NOTE(review): the router logits are not part of the tuple output; confirm whether callers using `return_dict=False` need them.
            return tuple(v for v in [
                hidden_states,
                next_cache,
                all_hidden_states,
                all_self_attns,
            ] if v is not None)

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """ Build the attention mask handed to the decoder layers.

        Args:
            attention_mask (torch.Tensor): The attention mask given by the caller, or None.
            input_tensor (torch.Tensor): The input embeddings; the batch size, the sequence length, the dtype and the device are read from it.
            cache_position (torch.Tensor): Positions of the input tokens in the cache.
            past_key_values (Cache): The key/value cache, or None.
            output_attentions (bool): Whether attention weights were requested.

        Returns:
            The mask to use, or None when no explicit mask is needed. With flash attention 2, this is `attention_mask` itself when it contains a zero and None otherwise. In the other cases it is the result of `_prepare_4d_causal_attention_mask_with_cache_position`, unless the SDPA shortcut applies.
        """

        attn_implementation = self.config._attn_implementation

        if attn_implementation == 'flash_attention_2':
            contains_zero = attention_mask is not None and 0.0 in attention_mask

            return attention_mask if contains_zero else None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in order to dispatch on Flash Attention 2.
        # This feature is not compatible with static cache, as SDPA will fail to infer the attention mask.
        seen_tokens_num = 0 if past_key_values is None \
            else past_key_values.get_seq_length()
        static_cache_in_use = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward.
        sdpa_without_attentions = attn_implementation == 'sdpa' and not output_attentions

        if sdpa_without_attentions and \
                not static_cache_in_use and \
                    AttentionMaskConverter._ignore_causal_mask_sdpa(
                        attention_mask,
                        inputs_embeds=input_tensor,
                        past_key_values_length=seen_tokens_num,
                        is_training=self.training,
                    ):
            return None

        query_length = input_tensor.shape[1]
        mask_dtype = input_tensor.dtype

        if static_cache_in_use:
            key_value_length = past_key_values.get_max_cache_shape()
        elif isinstance(attention_mask, torch.Tensor):
            key_value_length = attention_mask.shape[-1]
        else:
            key_value_length = seen_tokens_num + query_length + 1

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=query_length,
            target_length=key_value_length,
            dtype=mask_dtype,
            device=input_tensor.device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        needs_unmasking = sdpa_without_attentions \
            and attention_mask is not None \
                and attention_mask.device.type == 'cuda'

        if needs_unmasking:
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when using left padding.
            # This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask,
                torch.finfo(mask_dtype).min,
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (torch.Tensor): A 2D attention mask with shape `(batch_size, key_value_length)` or a 4D attention mask with shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (int): The length of the sequence being processed.
            target_length (int): The target length. When generating with a static cache, the mask should be as long as the static cache to account for zero-padding in the unfilled portions of the cache.
            dtype (torch.dtype): The dtype for the 4D attention mask.
            device (torch.device): The device on which to place the 4D attention mask.
            cache_position (torch.Tensor): Indices indicating the positions of tokens in the input sequence.
            batch_size (torch.Tensor): The batch size.
        """

        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )

            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)

            causal_mask *= (torch.arange(
                target_length,
                device=device,
            ) > cache_position.reshape(-1, 1))
            causal_mask = causal_mask[None, None, :, :].expand(
                batch_size, 1, -1, -1)

            if attention_mask is not None:
                # Copy to contiguous memory for in-place edit.
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]

                padding_mask = \
                    causal_mask[:, :, :, :mask_length] + \
                        attention_mask[:, None, None, :]
                padding_mask = (padding_mask == 0)

                causal_mask[:, :, :, :mask_length] = \
                    causal_mask[:, :, :, :mask_length].masked_fill(
                        padding_mask,
                        min_dtype,
                    )

        return causal_mask

    @staticmethod
    def calculate_load_balance_loss(
        routing_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
        ex_num: Optional[int] = None,
        selected_ex_num: int = 2,
        attention_mask: Optional[torch.Tensor] = None,
        eps: float = 1e-6,
    ) -> Union[torch.Tensor, int]:
        """
        Compute the load-balancing auxiliary loss, averaged over the given layers.

        For every layer the loss is `sum_e(tokens_per_expert[e] * router_prob_per_expert[e])`, where `tokens_per_expert[e]` is the fraction of tokens whose top-`selected_ex_num` experts include `e` and `router_prob_per_expert[e]` is the mean softmax probability assigned to `e`. The per-layer values are averaged and scaled by `ex_num`.

        With an `attention_mask`, both statistics are averaged over the tokens whose mask value is non-zero only, so an all-ones mask yields the same value as no mask (up to `eps`).

        Args:
            routing_logits (Union[torch.Tensor, Tuple[torch.Tensor], None]): One logits tensor, or a tuple/list with one logits tensor per layer. Each tensor is expected to be 2D with the experts on the last dimension (assumed `(batch_size * sequence_length, ex_num)`; confirm against the router). `None` or an empty sequence disables the loss.
            ex_num (Optional[int]): The number of experts. When `None`, it is taken from the last dimension of the first logits tensor.
            selected_ex_num (int): The number of experts selected per token (top-k).
            attention_mask (Optional[torch.Tensor]): Mask with one entry per token, flattened in the same order as the rows of the logits (for example shape `(batch_size, sequence_length)`); zero entries are excluded from the statistics.
            eps (float): Added to the number of unmasked tokens to avoid a division by zero when every token is masked.

        Returns:
            `0` when there are no routing logits; otherwise a scalar tensor on the device of the first logits tensor.
        """
        if routing_logits is None:
            return 0

        routing_logits_list = list(routing_logits) \
            if isinstance(routing_logits, (tuple, list)) else [routing_logits]

        if not routing_logits_list:
            return 0

        # Layers may live on different devices; everything is accumulated on the device of the first one.
        compute_device = routing_logits_list[0].device

        if ex_num is None:
            ex_num = routing_logits_list[0].shape[-1]

        token_mask = None
        valid_token_num = None
        if attention_mask is not None:
            # One column so the mask broadcasts over the expert dimension; float32 keeps the token sums exact enough for half-precision logits.
            token_mask = attention_mask.reshape(-1, 1).to(
                device=compute_device,
                dtype=torch.float32,
            )
            valid_token_num = token_mask.sum() + eps

        all_loss = 0
        for one_routing_logits in routing_logits_list:
            one_routing_logits = one_routing_logits.to(device=compute_device)

            routing_weights = F.softmax(one_routing_logits, dim=-1)

            _, selected_experts = torch.topk(
                routing_weights,
                k=selected_ex_num,
                dim=-1,
            )

            # Shape (tokens, ex_num): 1.0 where the expert is among the token's top-k choices, 0.0 elsewhere.
            expert_mask_summed = F.one_hot(
                selected_experts,
                num_classes=ex_num,
            ).float().sum(dim=1)

            if token_mask is None:
                # Compute the percentage of tokens routed to each expert.
                tokens_per_expert = expert_mask_summed.mean(dim=0)

                # Compute the average probability of routing to these experts.
                router_prob_per_expert = routing_weights.mean(dim=0)
            else:
                # Same statistics, averaged over the unmasked tokens only.
                tokens_per_expert = \
                    (expert_mask_summed * token_mask).sum(dim=0) / \
                        valid_token_num
                router_prob_per_expert = \
                    (routing_weights * token_mask).sum(dim=0) / \
                        valid_token_num

            all_loss = all_loss + torch.sum(
                tokens_per_expert * router_prob_per_expert
            )

        return all_loss * ex_num / len(routing_logits_list)

    # @staticmethod
    # def calculate_load_balance_loss(
    #     routing_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    #     ex_num: Optional[int] = None,
    #     selected_ex_num: int = 2,
    #     attention_mask: Optional[torch.Tensor] = None,
    # ) -> Union[torch.Tensor, int]:
    #     if routing_logits is None:
    #         return 0

    #     concatenated_routing_logits = None

    #     if isinstance(routing_logits, tuple):
    #         compute_device = routing_logits[0].device

    #         concatenated_routing_logits = torch.cat(
    #             tensors=[
    #                 layer_routing_logits.to(device=compute_device) \
    #                     for layer_routing_logits in routing_logits
    #             ],
    #             dim=0,
    #         )

    #     routing_weights = F.softmax(
    #         input=concatenated_routing_logits \
    #             if concatenated_routing_logits is not None \
    #                 else routing_logits,
    #         dim=-1,
    #     )

    #     _, selected_experts = torch.topk(
    #         input=routing_weights,
    #         k=selected_ex_num,
    #         dim=-1,
    #     )

    #     # Do not know why the parameter name is `input`, not `tensor` in F.one_hot.
    #     expert_mask = F.one_hot(
    #         input=selected_experts,
    #         num_classes=ex_num,
    #     )
    #     expert_mask_summed = torch.sum(
    #         input=expert_mask.float(),
    #         dim=1,
    #     )

    #     if attention_mask is None:
    #         # Compute the percentage of tokens routed to each experts.
    #         tokens_per_expert = torch.mean(
    #             input=expert_mask_summed,
    #             dim=0,
    #         )

    #         # Compute the average probability of routing to these experts.
    #         router_prob_per_expert = torch.mean(
    #             input=routing_weights,
    #             dim=0,
    #         )
    #     else:
    #         batch_size, sequence_length = attention_mask.shape
    #         num_hidden_layers = concatenated_routing_logits.shape[0] // \
    #             (batch_size * sequence_length)

    #         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask.
    #         expert_attention_mask = (
    #             attention_mask[None, :, :, None, None].expand(
    #                 size=(
    #                     num_hidden_layers,
    #                     batch_size,
    #                     sequence_length,
    #                     selected_ex_num,
    #                     ex_num,
    #                 )
    #             ).reshape(
    #                 -1,
    #                 selected_ex_num,
    #                 ex_num,
    #             ).to(device=compute_device)
    #         )
    #         expert_attention_mask_summed = torch.sum(
    #             input=expert_attention_mask.float(),
    #             dim=1,
    #         )

    #         # Compute the percentage of tokens routed to each experts.
    #         tokens_per_expert = \
    #             torch.sum(
    #                 input=(expert_mask_summed * expert_attention_mask_summed),
    #                 dim=0,
    #             ) / torch.sum(
    #                 input=expert_attention_mask,
    #                 dim=0,
    #             )

    #         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert.
    #         router_per_expert_attention_mask = (
    #             attention_mask[None, :, :, None].expand(
    #                 size=(
    #                     num_hidden_layers,
    #                     batch_size,
    #                     sequence_length,
    #                     ex_num,
    #                 )
    #             ).reshape(-1, ex_num).to(device=compute_device)
    #         )

    #         # Compute the average probability of routing to these experts.
    #         router_prob_per_expert = \
    #             torch.sum(
    #                 input=routing_weights * router_per_expert_attention_mask,
    #                 dim=0,
    #             ) / torch.sum(
    #                 input=router_per_expert_attention_mask,
    #                 dim=0,
    #             )

    #     overall_loss = torch.sum(
    #         input=(tokens_per_expert * router_prob_per_expert)
    #     )

    #     return overall_loss * ex_num

    @staticmethod
    def calculate_z_loss(
        routing_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    ) -> Union[torch.Tensor, int]:
        """
        Compute the mean squared log-sum-exp of the routing logits.

        Args:
            routing_logits (Union[torch.Tensor, Tuple[torch.Tensor], None]): A single logits tensor, or a tuple of logits tensors. A tuple is moved to the device of its first entry and concatenated along dim 0. `None` disables the loss.

        Returns:
            `0` when `routing_logits` is `None`; otherwise a scalar tensor equal to the mean of `logsumexp(logits, dim=-1) ** 2`.
        """
        if routing_logits is None:
            return 0

        if isinstance(routing_logits, tuple):
            # Gather every layer on the device of the first one before concatenating.
            target_device = routing_logits[0].device
            merged_logits = torch.cat(
                [
                    per_layer_logits.to(device=target_device)
                    for per_layer_logits in routing_logits
                ],
                dim=0,
            )
        else:
            merged_logits = routing_logits

        return torch.logsumexp(merged_logits, dim=-1).pow(2).mean()


class MoLoSLlamaForCausalLM(
    LlamaPreTrainedModel,
    GenerationMixin,
    MoLoSMixin,
):
    """
    Causal language model made of a `MoLoSLlamaModel` backbone and a bias-free linear LM head.

    When router logits are requested and labels are given, `forward` adds the router auxiliary losses (load balancing and z-loss), weighted by the coefficients in `molos_config`, to the language-modeling loss.
    """

    _tied_weights_keys = ['lm_head.weight']
    _tp_plan = {'lm_head': 'colwise_rep'}

    def __init__(
        self,
        config: LlamaConfig,
        molos_config: MoLoSConfig,
    ):
        """
        Args:
            config (LlamaConfig): The Llama configuration; `hidden_size` and `vocab_size` size the LM head.
            molos_config (MoLoSConfig): The MoLoS configuration, kept on the instance and forwarded to the backbone.
        """
        super().__init__(config=config)

        self.vocab_size = config.vocab_size
        self.molos_config = molos_config

        self.model = MoLoSLlamaModel(
            config=config,
            molos_config=molos_config,
        )

        self.lm_head = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=False,
        )

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the backbone."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the backbone."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone model."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        output_similarities: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, MoLoSCausalLMOutputWithPast]:
        """
        Run the backbone, project to vocabulary logits and, when possible, compute the losses.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position: Forwarded unchanged to the backbone. `attention_mask` is additionally passed to the load-balancing loss.
            labels (Optional[torch.LongTensor]): When given, the language-modeling loss is computed with `self.loss_function`.
            output_attentions, output_hidden_states (Optional[bool]): Default to the corresponding `self.config` values when `None`.
            output_router_logits (Optional[bool]): Defaults to `self.molos_config.output_router_logits` when `None`. Enables returning the router logits and computing the router auxiliary losses.
            output_similarities (Optional[bool]): Accepted for interface compatibility; not used by this method.
            return_dict (Optional[bool]): Defaults to `self.config.use_return_dict` when `None`.
            num_logits_to_keep (int): Logits are computed for the last `num_logits_to_keep` positions only; `0` keeps every position.
            **kwargs: Forwarded to both the backbone and `self.loss_function`.

        Returns:
            A `MoLoSCausalLMOutputWithPast` when `return_dict` is truthy. Otherwise a tuple laid out as `[loss,] [router_aux_loss, router_z_loss,] logits, [past_key_values,] [hidden_states,] [attentions]`: `loss` is present only when it was computed, the two router losses (possibly `None`) only when `output_router_logits` is truthy, and the trailing backbone entries only when they are not `None`.
        """
        output_attentions = output_attentions \
            if output_attentions is not None \
                else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states \
                if output_hidden_states is not None else \
                    self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits \
                if output_router_logits is not None \
                    else self.molos_config.output_router_logits
        )

        return_dict = return_dict \
            if return_dict is not None \
                else self.config.use_return_dict

        # Always ask the backbone for its structured output: its tuple form does not carry the router logits,
        # so reading them would fail whenever the caller requested `return_dict=False`.
        # The tuple form for the caller is rebuilt at the end of this method.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        routing_logits = None
        router_aux_loss = None
        router_z_loss = None
        if output_router_logits:
            routing_logits = outputs.router_logits

            # 'sequence' is not supported for the router_aux_loss.
            if self.molos_config.select_type != 'sequence' and \
                    labels is not None:
                if self.molos_config.router_aux_loss_coef > 0:
                    router_aux_loss = self.model.calculate_load_balance_loss(
                        routing_logits=routing_logits,
                        ex_num=self.molos_config.ex_num,
                        selected_ex_num=self.molos_config.selected_ex_num,
                        attention_mask=attention_mask,
                    )

                    # The helper returns the plain integer 0 when there are no routing logits; only tensors can be moved.
                    aux_term = router_aux_loss.to(device=loss.device) \
                        if isinstance(router_aux_loss, torch.Tensor) \
                            else router_aux_loss

                    loss = loss + \
                        self.molos_config.router_aux_loss_coef * aux_term

                if self.molos_config.router_z_loss_coef > 0:
                    router_z_loss = self.model.calculate_z_loss(
                        routing_logits=routing_logits
                    )

                    # Same guard as above: the helper may return the integer 0.
                    z_term = router_z_loss.to(device=loss.device) \
                        if isinstance(router_z_loss, torch.Tensor) \
                            else router_z_loss

                    loss = loss + \
                        self.molos_config.router_z_loss_coef * z_term

        if not return_dict:
            # Same entries, in the same order, as the backbone's own tuple output after its first element.
            output = (logits, ) + tuple(
                value for value in (
                    outputs.past_key_values,
                    outputs.hidden_states,
                    outputs.attentions,
                ) if value is not None
            )

            if output_router_logits:
                output = (router_aux_loss, router_z_loss, ) + output

            return (loss, ) + output if loss is not None else output

        return MoLoSCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            router_aux_loss=router_aux_loss,
            router_z_loss=router_z_loss,
            router_logits=routing_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
