from transformers.utils import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from dataclasses import dataclass
from torch.nn import CrossEntropyLoss
from torch.nn.utils import skip_init
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.utils import (
    ModelOutput,
    StoppingCriteria,
    StoppingCriteriaList,
)
from torch.amp.autocast_mode import autocast
from .configuration_chatglm import ChatGLMConfig

from .modeling_chatglm import (
    ChatGLMPreTrainedModel,
    LayerNorm,
    RMSNorm,
    RotaryEmbedding,
    MLP,
    Embedding,
    SelfAttention,
    GLMBlock,
    ChatGLMForConditionalGeneration,
    _config_to_kwargs,
    apply_rotary_pos_emb,
    default_init,
    split_tensor_along_last_dim,
)

from modules.dforcing import (
    DiscreteDiffusion,
    MotDiscreteDiffusion,
    ContinuousDiffusion,
    StochasticTimeEmbedding,
)

# Output-field names that are treated like cache entries.
# The first group are cache attribute names used by different Hugging Face
# model families; the second group are MoT-specific fields that also appear on
# MOTModelOutputWithPast.
# NOTE(review): presumably consumed by a generation loop that carries these
# fields from one step's outputs into the next step's inputs — confirm.
ALL_CACHE_NAMES = [
    "past_key_values",  # default
    "cache_params",  # mamba-based models
    "state",  # rwkv
    "mems",  # xlnet
    "past_buckets_states",  # reformer
] + [
    # MoT per-stream state
    "mot_input_ids",
    "mot_input_embs",
    "valid_pos",
    "type_ids",
    "position_ids",
]


@dataclass
class MOTBaseModelOutputWithPast(ModelOutput):
    """Backbone output: final hidden state, KV cache, optional per-layer
    hidden states / attentions, plus the per-stream ``valid_pos`` masks.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # One boolean mask per stream (base stream first, then each MoT stream)
    # marking which positions belong to that stream.
    valid_pos: Optional[List[torch.BoolTensor]] = None


@dataclass
class MOTModelOutputWithPast(ModelOutput):
    """Model output with loss / logits / cache plus MoT per-stream state.

    The last five fields (``mot_input_ids`` .. ``position_ids``) are the same
    names listed in ``ALL_CACHE_NAMES``.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # One id tensor per MoT stream.
    mot_input_ids: Optional[List[torch.Tensor]] = None
    # Presumably one embedding tensor per MoT stream — confirm against caller.
    mot_input_embs: Optional[List[torch.FloatTensor]] = None
    # One boolean mask per stream, base stream first.
    valid_pos: Optional[List[torch.Tensor]] = None
    type_ids: Optional[List[torch.Tensor]] = None
    position_ids: Optional[torch.Tensor] = None


@dataclass
class MOTGenerateDecoderOnlyOutput(ModelOutput):
    """Result container for decoder-only generation with MoT streams.

    Carries the usual generated ``sequences`` / ``scores`` / ``logits`` /
    ``attentions`` / ``hidden_states`` / ``past_key_values`` plus MoT outputs.
    """

    sequences: Optional[torch.LongTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Presumably one generated id tensor per MoT stream — confirm.
    mot_output_ids: Optional[List[torch.Tensor]] = None
    # Presumably one generated embedding tensor per MoT stream — confirm.
    mot_output_embs: Optional[List[torch.Tensor]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


logger = logging.get_logger(__name__)
# Sentinel value. In this module it pre-fills the positions a stream does not
# own (MOTSelfAttention.apply_output_proj); -100 is also the default
# ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
IGNORE_INDEX = -100


class MOTMLP(MLP):
    """MLP with its own ``hidden_size`` / ``ffn_hidden_size``.

    The parent ``MLP`` sizes its two projections from ``config``; this subclass
    replaces both so a MoT stream can use its own widths while inheriting the
    parent's ``forward``.

    Args:
        hidden_size: Input and output width of this stream.
        ffn_hidden_size: Intermediate width; the up-projection emits
            ``2 * ffn_hidden_size`` features (see the SwiGLU link below).
        add_bias_linear: Whether both projections carry a bias.
        config: Model config, forwarded to the parent constructor and to
            ``_config_to_kwargs`` for the Linear factory kwargs.
        device: Device for the new projections.
    """

    def __init__(
        self, hidden_size, ffn_hidden_size, add_bias_linear, config, device=None
    ):
        # The parent's config-sized projections are overwritten right below, so
        # build them on the meta device: no memory is allocated or initialised
        # for weights that would be discarded immediately.
        super().__init__(config, torch.device("meta"))

        linear_kwargs = dict(
            bias=add_bias_linear, device=device, **_config_to_kwargs(config)
        )

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            hidden_size, ffn_hidden_size * 2, **linear_kwargs
        )

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, **linear_kwargs)


class MOTEmbedding(Embedding):
    """Word embedding with its own vocabulary size and hidden size.

    The parent ``Embedding`` sizes its table from ``config``; this subclass
    replaces the table so a MoT stream can use its own ``padded_vocab_size``
    and ``hidden_size`` while inheriting the parent's ``forward``.

    Args:
        padded_vocab_size: Number of rows of the embedding table.
        hidden_size: Embedding width.
        config: Model config; supplies ``torch_dtype`` and
            ``fp32_residual_connection``.
        device: Device for the embedding table.
    """

    def __init__(
        self, padded_vocab_size, hidden_size, config: ChatGLMConfig, device=None
    ):
        # The parent's config-sized table is overwritten below. Building it on
        # the meta device avoids allocating a full
        # ``config.padded_vocab_size x config.hidden_size`` table (previously on
        # the default device, ignoring ``device``) only to discard it.
        super().__init__(config, torch.device("meta"))

        # Record the width actually used by this table.
        self.hidden_size = hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            padded_vocab_size,
            hidden_size,
            dtype=config.torch_dtype,
            device=device,
        )
        self.fp32_residual_connection = config.fp32_residual_connection


class MOTRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding evaluated at explicit, per-sample position ids."""

    def forward_impl(
        self,
        *args,
        seq_idx: torch.Tensor,
        n_elem: int,
        dtype: torch.dtype,
        device: torch.device,
        base: int = 10000,
        **kwargs,
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

        Args:
            seq_idx: Position indices of any shape ``(*S)`` (1-D or batched).
            n_elem: Rotary dimension; ``theta`` has ``ceil(n_elem / 2)`` entries.
            dtype: Target dtype; float16 / int8 give a half cache, bfloat16 a
                bfloat16 cache, anything else stays float32.
            device: Device on which ``theta`` is built.
            base: RoPE base, scaled by ``self.rope_ratio``.

        Returns:
            Tensor of shape ``(*S, len(theta), 2)`` holding ``(cos, sin)`` pairs.
        """

        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        base = base * self.rope_ratio
        theta = 1.0 / (
            base
            ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)
        )

        # Product of every position index with every theta_i. For a 1-D
        # ``seq_idx`` this is exactly ``torch.outer(seq_idx, theta)``; the
        # broadcast form additionally accepts a whole batch of position rows.
        idx_theta = (seq_idx.unsqueeze(-1) * theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()

        return cache

    def forward(
        self,
        *args,
        position_ids: torch.Tensor,
        **kwargs,
    ):
        """Return rotary ``(cos, sin)`` tables for every row of ``position_ids``.

        For ``position_ids`` of shape ``(batch, seq)`` the result has shape
        ``(batch, seq, len(theta), 2)``. This equals stacking one
        ``forward_impl`` call per row, but it is computed in a single call, so
        ``theta`` is built once instead of once per sample.
        """
        if not torch.is_tensor(position_ids):
            # Still accept a sequence of equal-length 1-D tensors.
            position_ids = torch.stack(list(position_ids))

        return self.forward_impl(
            seq_idx=position_ids,
            n_elem=self.dim,
            dtype=self.inv_freq.dtype,
            device=self.inv_freq.device,
        )


class MOTSelfAttention(SelfAttention):
    """Self-attention shared by the base stream and the MoT streams.

    The base stream (index 0) uses the parent's ``query_key_value`` and
    ``dense`` projections. Each MoT stream ``i`` (list index ``i + 1``) has its
    own ``mot_query_key_value[i]`` (``mot_hidden_size[i] -> qkv_hidden_size``)
    and ``mot_dense[i]`` (``projection_size -> mot_hidden_size[i]``). The
    streams are merged into one QKV tensor, attended over in a single
    ``core_attention`` call, and projected back into one tensor per stream.

    ``hidden_states`` and ``valid_pos`` are lists with one entry per stream,
    base first; the masks are documented below as shape ``(batch, length)``.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super().__init__(config, layer_number, device)

        # One hidden width per MoT stream.
        self.mot_hidden_size = config.mot_hidden_size

        num_layers = config.num_layers
        # Decide from the configured mode and this layer's 1-based
        # ``layer_number`` whether this layer shares attention across streams.
        attention_mode = config.mot_cross_attn_mode
        if attention_mode == "all":
            share_attn = True
        elif attention_mode == "first":
            share_attn = layer_number == 1
        elif attention_mode == "last":
            share_attn = layer_number == num_layers
        elif attention_mode == "Ahalf":
            share_attn = layer_number <= num_layers // 2
        elif attention_mode == "halfB":
            share_attn = layer_number > num_layers // 2
        elif attention_mode == "odd":
            share_attn = layer_number % 2 == 1
        elif attention_mode == "even":
            share_attn = layer_number % 2 == 0
        elif attention_mode == "None":
            share_attn = False
        else:
            raise NotImplementedError(
                "Unsupported attention mode: {}".format(attention_mode)
            )

        # NOTE: the attribute is the NEGATION of the local ``share_attn``.
        # ``self.share_attn`` is True for layers that do NOT share attention,
        # and forward() then swaps in ``independent_attn_mask``.
        self.share_attn = not share_attn

        # The per-stream projections below are created for every layer,
        # regardless of the attention mode.

        # MoT stream input -> shared QKV space.
        self.mot_query_key_value = nn.ModuleList(
            [
                nn.Linear(
                    ch,
                    self.qkv_hidden_size,
                    bias=config.add_bias_linear or config.add_qkv_bias,
                    device=device,
                    **_config_to_kwargs(config),
                )
                for ch in self.mot_hidden_size
            ]
        )

        # Shared attention output -> each MoT stream's hidden width.
        self.mot_dense = nn.ModuleList(
            [
                nn.Linear(
                    self.projection_size,
                    ch,
                    bias=config.add_bias_linear,
                    device=device,
                    **_config_to_kwargs(config),
                )
                for ch in self.mot_hidden_size
            ]
        )

    def apply_qkv_proj(
        self,
        hidden_states: List[torch.Tensor],
        valid_pos: List[torch.Tensor],
        module: nn.Module,
        mot_module: nn.ModuleList | List,
    ) -> torch.Tensor:
        """Project every stream to QKV and merge the results into one tensor.

        The base stream is projected at all positions. Then, for each MoT
        stream ``i``, positions where ``valid_pos[i + 1]`` is set but
        ``valid_pos[i]`` is not are overwritten in place with that stream's own
        projection. Positions valid in both keep the earlier projection.

        NOTE(review): for ``i > 0`` the comparison is against the previous MoT
        stream's mask, not the base stream's. Confirm this is intended when
        more than one MoT stream is configured.

        Args:
            hidden_states: One tensor per stream, base first.
            valid_pos: One boolean mask per stream, base first.
            module: QKV projection for the base stream.
            mot_module: One QKV projection per MoT stream.

        Returns:
            The merged QKV projections as a single tensor.
        """
        # Text-Audio Part
        hidden_states_ = module(hidden_states[0])

        # MOT Part
        for i, _ in enumerate(self.mot_hidden_size):
            pos = valid_pos[i + 1]
            both_valid = valid_pos[i] & valid_pos[i + 1]  # shape (batch, length)
            pos = pos & ~both_valid  # Only keep positions not in both_valid
            input_tensor = hidden_states[i + 1][pos]
            output_tensor = mot_module[i](input_tensor)
            hidden_states_[pos] = output_tensor

        return hidden_states_

    def apply_output_proj(
        self,
        original_hidden_states: List[torch.Tensor],
        valid_pos: List[torch.Tensor],
        hidden_states: torch.Tensor,
        module: nn.Module,
        mot_module: nn.ModuleList | List,
    ) -> List[torch.Tensor]:
        """Split the shared attention output back into one tensor per stream.

        Each output has the shape and dtype of the matching entry of
        ``original_hidden_states``. It is pre-filled with ``IGNORE_INDEX``, and
        only the positions selected by that stream's ``valid_pos`` mask receive
        the stream's output projection. A position valid in several streams is
        written into each of them.

        Args:
            original_hidden_states: Per-stream inputs; used only as templates
                for the output shapes.
            valid_pos: One boolean mask per stream, base first.
            hidden_states: Shared attention output, indexed with the masks.
            module: Output projection for the base stream.
            mot_module: One output projection per MoT stream.

        Returns:
            One tensor per stream, base first.
        """
        hidden_states_ = [
            torch.full_like(original_hidden_state, IGNORE_INDEX)
            for original_hidden_state in original_hidden_states
        ]

        hidden_states_[0][valid_pos[0]] = module(hidden_states[valid_pos[0]])
        for i in range(1, len(hidden_states_)):
            hidden_states_[i][valid_pos[i]] = mot_module[i - 1](
                hidden_states[valid_pos[i]]
            )

        return hidden_states_

    def independent_attn_mask(self, attention_mask, valid_pos):
        """Restrict attention keys to base-only positions.

        Keys are kept only where ``valid_pos[0]`` is set and ``valid_pos[1]``
        is not; only the first MoT stream's mask is consulted. The arithmetic
        follows ``get_mot_attn_mask``: treat 1 as "may attend", multiply by the
        key mask, then ``< 0.5`` produces a boolean mask that is True where
        attention is blocked.

        NOTE(review): the arithmetic treats the incoming ``attention_mask`` as
        1/True = may attend. If the caller passes a mask that is already
        True = blocked (the form ``get_mot_attn_mask`` returns), this inverts
        its meaning. Confirm what forward() receives. It also assumes
        ``attention_mask`` is not None and presumably shaped
        ``(batch, 1, q, k)``.
        """
        attention_mask = attention_mask.clone()
        attention_mask.squeeze_(1)
        base_local_mask = valid_pos[0]
        mot_local_mask = valid_pos[1]
        base_local_mask = base_local_mask & ~mot_local_mask
        local_mask = base_local_mask.float()
        attention_mask = attention_mask * local_mask.unsqueeze(1)
        attention_mask = (attention_mask < 0.5).bool()
        attention_mask.unsqueeze_(1)  # multi-head attention
        return attention_mask

    def forward(
        self,
        hidden_states: List[torch.Tensor],
        valid_pos: List[torch.Tensor],
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        """Run shared multi-stream self-attention.

        Returns:
            ``(outputs, kv_cache)``. ``outputs`` is one tensor per stream (see
            ``apply_output_proj``). ``kv_cache`` is None when ``use_cache`` is
            False. Otherwise it is a stacked key/value tensor when no cache was
            passed in, or a ``(key, value)`` tuple when one was.
        """
        # hidden_states: list of per-stream tensors, presumably [b, sq, h_stream]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)]
        mixed_x_layer = self.apply_qkv_proj(
            hidden_states, valid_pos, self.query_key_value, self.mot_query_key_value
        )

        # Non-shared layers replace the mask (see NOTE in independent_attn_mask).
        if self.share_attn:
            attention_mask = self.independent_attn_mask(attention_mask, valid_pos)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(
                mixed_x_layer, 3
            )

        # [b, sq, np, hn] -> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [
            k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]
        ]

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference: prepend cached keys/values along
        # the sequence axis (dim 2 after the transpose above)
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=2)
            value_layer = torch.cat((cache_v, value_layer), dim=2)

        if use_cache:
            if kv_cache is None:
                # Prefill: pack key and value into one tensor with two leading
                # singleton-derived dims so the transformer can concatenate
                # layers along dim 0.
                kv_cache = torch.cat(
                    (
                        key_layer.unsqueeze(0).unsqueeze(0),
                        value_layer.unsqueeze(0).unsqueeze(0),
                    ),
                    dim=1,
                )
            else:
                kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        # Grouped-query attention: repeat each KV group so every query head has
        # a matching key/value head.
        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(2)
            key_layer = key_layer.expand(
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
                -1,
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:1]
                + (self.num_attention_heads_per_partition,)
                + key_layer.size()[3:]
            )
            value_layer = value_layer.unsqueeze(2)
            value_layer = value_layer.expand(
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
                -1,
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:1]
                + (self.num_attention_heads_per_partition,)
                + value_layer.size()[3:]
            )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(
            query_layer, key_layer, value_layer, attention_mask
        )

        # =================
        # Output: split the shared context back into one tensor per stream.
        # context_layer is indexed with the (batch, length) valid_pos masks, so
        # its leading dims are presumably [b, sq].
        # =================

        output = self.apply_output_proj(
            hidden_states, valid_pos, context_layer, self.dense, self.mot_dense
        )

        return output, kv_cache


class MOTGLMBlock(GLMBlock):
    """A single transformer layer operating on a list of streams.

    The base stream (index 0) uses the parent's ``input_layernorm``,
    ``post_attention_layernorm`` and ``mlp``. Each MoT stream ``i`` (list index
    ``i + 1``) uses ``mot_input_layernorm[i]``, ``mot_post_attention_layernorm[i]``
    and ``mot_mlp[i]``. All streams go through one shared ``MOTSelfAttention``,
    which decides per layer (from ``config.mot_cross_attn_mode``) whether
    attention is shared across streams.

    Input and output are lists with one tensor per stream, base first.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super().__init__(config, layer_number, device)

        self.mot_hidden_size = config.mot_hidden_size
        self.mot_ffn_hidden_size = config.mot_ffn_hidden_size

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm

        self.self_attention = MOTSelfAttention(config, layer_number, device=device)

        def _stream_norms():
            # One norm per MoT stream, sized to that stream's hidden width.
            return nn.ModuleList(
                [
                    LayerNormFunc(
                        ch,
                        eps=config.layernorm_epsilon,
                        device=device,
                        dtype=config.torch_dtype,
                    )
                    for ch in self.mot_hidden_size
                ]
            )

        self.mot_input_layernorm = _stream_norms()
        self.mot_post_attention_layernorm = _stream_norms()

        self.mot_mlp = nn.ModuleList(
            [
                MOTMLP(
                    ch,
                    self.mot_ffn_hidden_size[i],
                    config.add_bias_linear,
                    config,
                    device=device,
                )
                for i, ch in enumerate(self.mot_hidden_size)
            ]
        )

    def apply_residual(self, hidden_states, residual):
        """Return ``residual[i] + dropout(hidden_states[i])`` for every stream.

        Raises:
            ValueError: If the two lists have different lengths.
        """
        return [
            res
            + nn.functional.dropout(h, p=self.hidden_dropout, training=self.training)
            for h, res in zip(hidden_states, residual, strict=True)
        ]

    def forward(
        self,
        hidden_states: List[torch.Tensor],
        valid_pos: List[torch.Tensor],
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        """Apply pre-norm attention and MLP with residuals to every stream.

        Args:
            hidden_states: One tensor per stream, base first (presumably
                ``[b, s, h_stream]`` each).
            valid_pos: One boolean mask per stream, base first.
            attention_mask: Passed through to ``MOTSelfAttention``.
            rotary_pos_emb: Passed through to ``MOTSelfAttention``.
            kv_cache: This layer's cache, or None.
            use_cache: Whether the attention layer should return a cache.

        Returns:
            ``(outputs, kv_cache)``: one tensor per stream, plus the cache
            returned by the attention layer.
        """
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = [self.input_layernorm(hidden_states[0])] + [
            norm(h)
            for norm, h in zip(self.mot_input_layernorm, hidden_states[1:], strict=True)
        ]

        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            valid_pos,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = self.apply_residual(attention_output, residual)

        # Layer norm post the self attention.
        layernorm_output = [self.post_attention_layernorm(layernorm_input[0])] + [
            norm(h)
            for norm, h in zip(
                self.mot_post_attention_layernorm, layernorm_input[1:], strict=True
            )
        ]

        # MLP.
        mlp_output = [self.mlp(layernorm_output[0])] + [
            mlp(h) for mlp, h in zip(self.mot_mlp, layernorm_output[1:], strict=True)
        ]

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.apply_residual(mlp_output, residual)

        return output, kv_cache


class MOTGLMTransformer(nn.Module):
    """Stack of ``MOTGLMBlock`` layers with optional per-stream final norms."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super().__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        self.mot_hidden_size = config.mot_hidden_size

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers (layer numbers are 1-based).
        def build_layer(layer_number):
            return MOTGLMBlock(config, layer_number, device=device)

        self.layers = nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_layers)]
        )

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(
                config.hidden_size,
                eps=config.layernorm_epsilon,
                device=device,
                dtype=config.torch_dtype,
            )

            # One final norm per MoT stream.
            self.mot_final_layernorm = nn.ModuleList(
                [
                    LayerNormFunc(
                        ch,
                        eps=config.layernorm_epsilon,
                        device=device,
                        dtype=config.torch_dtype,
                    )
                    for ch in self.mot_hidden_size
                ]
            )

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
        self,
        hidden_states: List[torch.Tensor],
        valid_pos: List[torch.Tensor],
        attention_mask,
        rotary_pos_emb,
        kv_caches=None,
        use_cache: Optional[bool] = True,
        output_hidden_states: Optional[bool] = False,
    ):
        """Run every layer on the list of streams.

        Args:
            hidden_states: One tensor per stream, base first. Not mutated.
            valid_pos: One boolean mask per stream, base first.
            attention_mask: Passed to every layer.
            rotary_pos_emb: Passed to every layer.
            kv_caches: Per-layer caches from a previous step, or falsy for none.
            use_cache: Whether to collect a cache. Forced off under gradient
                checkpointing in training.
            output_hidden_states: Whether to record each layer's input list
                and the final (pre-norm) output list.

        Returns:
            ``(hidden_states, presents, all_hidden_states, all_self_attentions)``.
            ``hidden_states`` is a new per-stream list, final-normed when
            ``post_layer_norm`` is set. ``presents`` is None when caching is off.
            Otherwise it is the layers' caches concatenated along dim 0 during
            prefill, or a tuple of per-layer caches when ``kv_caches`` were
            given. ``all_self_attentions`` is always None.
        """
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Decided after the checkpointing override so that a disabled cache is
        # reported as None rather than an empty tuple.
        presents = () if use_cache else None

        all_self_attentions = None
        all_hidden_states = ()
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = checkpoint(
                    layer,
                    hidden_states,
                    valid_pos,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache,
                    use_reentrant=False,
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    valid_pos,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache,
                )

            hidden_states, kv_cache = layer_ret

            if use_cache:
                # token by token decoding, use tuple format
                if kv_caches[0] is not None:
                    presents = presents + (kv_cache,)
                # prefilling in decoding, use tensor format to save cuda memory
                else:
                    if len(presents) == 0:
                        presents = kv_cache
                    else:
                        presents = torch.cat(
                            (presents, kv_cache.to(presents.device)), dim=0
                        )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Final layer norm. Build a NEW list: ``hidden_states`` may be the very
        # object just appended to ``all_hidden_states`` (or the caller's input
        # when there are no layers), so normalising it in place would silently
        # overwrite that recorded pre-norm entry.
        if self.post_layer_norm:
            hidden_states = [self.final_layernorm(hidden_states[0])] + [
                norm(h)
                for norm, h in zip(
                    self.mot_final_layernorm, hidden_states[1:], strict=True
                )
            ]

        return hidden_states, presents, all_hidden_states, all_self_attentions


class MOTChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(
        self, config: ChatGLMConfig, config_dforcing, device=None, empty_init=True
    ):
        """Build the base and MoT embeddings, the shared transformer and the heads.

        Args:
            config: Model config. The ``mot_*`` fields (``mot_hidden_size``,
                ``mot_vocab_size``, ``mot_dim``, ...) have one entry per MoT
                stream.
            config_dforcing: Diffusion-forcing config, consumed by
                ``get_time_embedding_for_diffusion``.
            device: Optional device for the created modules.
            empty_init: If True, the base embedding, encoder and output layer are
                built with ``torch.nn.utils.skip_init`` (parameters not
                initialised); otherwise with ``default_init``.
        """
        super().__init__(config)
        self.mot_hidden_size = config.mot_hidden_size
        self.mot_vocab_size = config.mot_vocab_size
        self.mot_window = config.mot_window

        init_method = skip_init if empty_init else default_init
        init_kwargs = {} if device is None else {"device": device}

        # Base-stream token embedding.
        self.embedding = init_method(
            MOTEmbedding,
            config.padded_vocab_size,
            config.hidden_size,
            config,
            **init_kwargs,
        )
        # One token embedding per MoT stream.
        self.mot_embedding = nn.ModuleList(
            [
                MOTEmbedding(self.mot_vocab_size[i], ch, config, **init_kwargs)
                for i, ch in enumerate(config.mot_hidden_size)
            ]
        )

        # Per-stream projection from the stream's external feature width
        # (``mot_dim[i]``) into its hidden width.
        self.mot_mapping = nn.ModuleList(
            [
                nn.Linear(
                    config.mot_dim[i],
                    ch,
                    bias=False,
                    dtype=config.torch_dtype,
                    **init_kwargs,
                )
                for i, ch in enumerate(self.mot_hidden_size)
            ]
        )

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads
            if config.kv_channels is None
            else config.kv_channels
        )

        self.rotary_pos_emb = MOTRotaryEmbedding(
            rotary_dim // 2,
            rope_ratio=config.rope_ratio,
            original_impl=config.original_rope,
            device=device,
            dtype=config.torch_dtype,
        )

        self.encoder = init_method(MOTGLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            dtype=config.torch_dtype,
            **init_kwargs,
        )
        # Per-stream projection from hidden width back to ``mot_dim[i]``.
        self.mot_emb_output_layer = nn.ModuleList(
            [
                nn.Linear(
                    ch,
                    config.mot_dim[i],
                    bias=False,
                    dtype=config.torch_dtype,
                    **init_kwargs,
                )
                for i, ch in enumerate(self.mot_hidden_size)
            ]
        )
        # Optional length-preserving (k=3, pad=1) conv per stream. Uses the same
        # dtype/device as every other layer built here. Configs without
        # ``output_conv`` are treated as disabled.
        if getattr(self.config, "output_conv", False):
            self.mot_emb_output_conv = nn.ModuleList(
                [
                    nn.Conv1d(
                        ch,
                        ch,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        dtype=config.torch_dtype,
                        **init_kwargs,
                    )
                    for ch in self.mot_hidden_size
                ]
            )

        self.get_time_embedding_for_diffusion(config_dforcing)

    def get_time_embedding_for_diffusion(self, config_dforcing):
        """Create ``self.noise_level_pos_embedding`` and initialise its Linear layers.

        ``config_dforcing.x_shape[0]`` becomes the embedding width, and
        ``config_dforcing.backbone.use_fourier_noise_embedding`` is passed as
        ``use_fourier`` to ``StochasticTimeEmbedding``. Every ``nn.Linear``
        inside the new module gets normal(std=0.02) weights and a zero bias.
        Returns nothing.
        """
        embed_width = config_dforcing.x_shape[0]
        fourier = config_dforcing.backbone.use_fourier_noise_embedding

        self.noise_level_pos_embedding = StochasticTimeEmbedding(
            time_embed_dim=embed_width,
            use_fourier=fourier,
        )

        def _reset_linear(submodule: nn.Module) -> None:
            # Only Linear layers are re-initialised; other modules keep defaults.
            if not isinstance(submodule, nn.Linear):
                return
            nn.init.normal_(submodule.weight, std=0.02)
            if submodule.bias is not None:
                nn.init.zeros_(submodule.bias)

        self.noise_level_pos_embedding.apply(_reset_linear)

    def get_input_embeddings(self):  # type: ignore
        """Return the word-embedding tables: base stream first, then each MoT stream."""
        return [
            self.embedding.word_embeddings,
            *(mot.word_embeddings for mot in self.mot_embedding),
        ]

    def set_input_embeddings(self, value: List[nn.Embedding]):  # type: ignore
        """Install word-embedding tables in the order ``get_input_embeddings`` returns.

        Args:
            value: ``[base_table, mot_table_0, mot_table_1, ...]``, i.e. exactly
                one table for the base stream plus one per MoT stream.

        Raises:
            ValueError: If ``value`` does not hold exactly that many tables. The
                check runs before anything is assigned, so a bad call never
                leaves a partially updated model.
        """
        expected = 1 + len(self.mot_embedding)
        if len(value) != expected:
            raise ValueError(
                f"set_input_embeddings expects {expected} embedding tables "
                f"(base + {len(self.mot_embedding)} MoT), got {len(value)}"
            )
        self.embedding.word_embeddings = value[0]
        for mot, table in zip(self.mot_embedding, value[1:]):
            mot.word_embeddings = table

    def get_mot_attn_mask(
        self,
        batch_size: int,
        seq_length: int,
        valid_pos: List[torch.Tensor],
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ) -> torch.BoolTensor:
        """Build the boolean attention mask for the current query positions.

        The mask is built in four steps:

        1. Start from a causal (lower-triangular) mask over past + current
           positions.
        2. Restrict which key positions may be attended using the stream masks.
        3. Keep only the last ``seq_length`` query rows.
        4. Convert to a boolean mask that is True where attention is blocked.

        The result has a singleton head dimension:
        ``(batch_size, 1, seq_length, past_length + seq_length)``.

        Key restriction:

        * ``base_local_mask = valid_pos[0] & ~valid_pos[1]`` (base-only
          positions).
        * ``mot_local_mask = base_local_mask | valid_pos[1]``.
        * Without ``type_ids`` only base-only keys are allowed. With
          ``type_ids``, the two masks are chosen elementwise by
          ``type_ids > 0.5``.

        Args:
            batch_size: Batch size.
            seq_length: Number of current (query) positions.
            valid_pos: Per-stream masks. Only ``valid_pos[0]`` and
                ``valid_pos[1]`` are read. They are broadcast against the full
                ``past + current`` key axis, so presumably each is
                ``(batch_size, past_length + seq_length)``; confirm.
            past_key_values: KV cache. ``past_key_values[0][0].shape[2]`` is
                taken as the past length.
            padding_mask: Currently ignored (overwritten with None, see TODO).
            type_ids: Optional stream-type indicator. Values > 0.5 select
                ``mot_local_mask``. Must broadcast against ``valid_pos[0]``.
            device: Device for the mask.

        Returns:
            Boolean mask, True = blocked.
        """
        # TODO: support input attention mask
        # NOTE: this discards any caller-provided padding_mask.
        padding_mask = None

        # Calculate total length including past context
        past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]
        total_length = past_length + seq_length

        # Create base causal mask (True/1 = may attend at this stage)
        full_attention_mask = torch.ones(
            batch_size, total_length, total_length, device=device, dtype=torch.bool
        ).tril_()

        # Mask out mot for text/audio queries
        base_local_mask = valid_pos[0]
        mot_local_mask = valid_pos[1]
        base_local_mask = base_local_mask & ~mot_local_mask
        mot_local_mask = base_local_mask | mot_local_mask
        if type_ids is not None:
            is_mot = type_ids > 0.5
            local_mask = torch.where(is_mot, mot_local_mask, base_local_mask).float()
        else:
            local_mask = base_local_mask.float()

        # Zero out key columns not in local_mask (result is float from here on).
        full_attention_mask = full_attention_mask * local_mask.unsqueeze(1)
        # Prefill only: query rows whose own position is not in local_mask get
        # +1 on every key, so nothing in those rows ends up blocked (same trick
        # as the padding handling below). NOTE(review): such rows can then also
        # see future positions, and with a KV cache (past_length > 0) this
        # adjustment is skipped. ``local_mask`` is never None at this point.
        if not past_length and local_mask is not None:
            full_attention_mask -= local_mask.unsqueeze(-1) - 1

        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1

        # Keep only the current query rows, then flip to True = blocked.
        full_attention_mask = full_attention_mask[:, -seq_length:]
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)  # multi-head attention

        return full_attention_mask  # type: ignore

    def forward(
        self,
        input_ids: torch.Tensor,
        mot_input_ids: List[torch.Tensor],
        mot_input_embs: List[torch.Tensor],
        noise_levels: List[torch.Tensor],
        position_ids: torch.Tensor,
        valid_pos: List[torch.Tensor],
        attention_mask: Optional[torch.BoolTensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the MOT backbone over the base stream plus the MOT branches.

        Args:
            input_ids: base-stream token ids; unpacked as
                ``(batch_size, seq_length)``.
            mot_input_ids: per-branch MOT token ids. NOTE: the entries of this
                list are replaced IN PLACE by copies with invalid positions
                zeroed, so the caller's list is modified.
            mot_input_embs: per-branch continuous MOT inputs. Also modified in
                place like ``mot_input_ids``.
            noise_levels: diffusion noise levels; only ``noise_levels[0]`` is
                embedded and added to the MOT inputs.
            position_ids: positions for the rotary embedding; sliced to the last
                ``seq_length`` entries when ``past_key_values`` is truthy.
            valid_pos: validity masks, base stream first then one per MOT
                branch. Their last ``seq_length`` entries zero out invalid
                inputs; the full masks are used for the attention mask.
            attention_mask: accepted but not used by this method.
            type_ids: modality ids forwarded to ``get_mot_attn_mask``.
            past_key_values: KV cache forwarded to the encoder.
            inputs_embeds: precomputed per-stream embeddings. When given, the
                embedding layers, the MOT mapping and the noise-level embedding
                are all skipped.
            use_cache, output_hidden_states, return_dict: default to the config.
            output_attentions: accepted but unused.

        Returns:
            ``MOTBaseModelOutputWithPast``. If ``return_dict`` is False, a tuple
            of ``(hidden_states, presents, all_hidden_states,
            all_self_attentions)`` with the ``None`` entries dropped.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        batch_size, seq_length = input_ids.shape

        # Validity masks restricted to the positions processed in this call.
        # if past_key_values is not None:
        current_valid_pos = [valid_pos[0][..., -seq_length:]]
        current_valid_pos.extend(
            [x[..., -seq_length:] for i, x in enumerate(valid_pos[1:])]
        )

        # Zero out invalid positions (e.g. IGNORE_INDEX ids) so embedding
        # lookups and mappings see valid values. Modifies the caller's lists.
        input_ids = input_ids.masked_fill(~current_valid_pos[0], 0)
        for i, x in enumerate(mot_input_ids):
            mot_input_ids[i] = x.masked_fill(~current_valid_pos[i + 1], 0)
        for i, x in enumerate(mot_input_embs):
            mot_input_embs[i] = x.masked_fill(
                ~current_valid_pos[i + 1].unsqueeze(-1), 0
            )

        if inputs_embeds is None:
            # add timestep positional embedding
            # NOTE: only noise_levels[0] is embedded and zip() truncates to the
            # shorter list, so only a single MOT branch is supported here.
            time_emb = [self.noise_level_pos_embedding(noise_levels[0])]
            mot_input_embs = [x + k for x, k in zip(mot_input_embs, time_emb)]

            inputs_embeds = [self.embedding(input_ids)]
            # put all the inputs_embeds to bfloat16
            # inputs_embeds = [self.embedding(input_ids).to(mot_input_embs[0].dtype)]
            inputs_embeds.extend(
                # Diffusion forcing, mot_embedding to motion_embeding
                [x(mot_input_embs[i]) for i, x in enumerate(self.mot_mapping)]
            )

            # NOTE(review): the dtype is hard-coded to bfloat16 regardless of
            # the model dtype — confirm this is intended.
            inputs_embeds = [x.to(torch.bfloat16) for x in inputs_embeds]

            # if self.config.fp32_residual_connection:
            # inputs_embeds = [
            #     x.float() for x in inputs_embeds
            # ]

        # MOT Attention mask (built from the full, un-sliced valid_pos)
        full_attention_mask = self.get_mot_attn_mask(
            batch_size,
            seq_length,
            valid_pos,
            past_key_values,
            type_ids=type_ids,
            device=input_ids.device,
        )

        # With a cache, only the newest seq_length positions need rotary codes.
        if past_key_values:
            position_ids = position_ids[..., -seq_length:]
        full_rotary_pos_emb = self.rotary_pos_emb(position_ids=position_ids)

        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds,
            current_valid_pos,
            full_attention_mask,
            rotary_pos_emb=full_rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )

        # If the encoder returns the cache as a single stacked tensor, split it
        # into a tuple (over dim 0) of tuples of tensors (split again over the
        # next dim). NOTE(review): presumably (layers, 2, ...) -> per-layer (k, v).
        if presents is not None and type(presents) is torch.Tensor:
            presents = presents.split(1, dim=0)
            presents = list(presents)
            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
            presents = tuple(presents)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return MOTBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,  # type: ignore
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            valid_pos=valid_pos,  # type: ignore
        )


def interpolate_position_ids(
    position_ids, window_ratio=(3, 6), pad_val=-100, device=None
):
    """
    Fill every run of ``pad_val`` in each row of *position_ids* with fractional positions.

    Non-pad entries are copied unchanged. For a run of ``mask_len`` pad entries
    starting at index ``start``, with ``left_win, right_win = window_ratio``:

    * ``left_len = int(mask_len * left_win / right_win)`` and
      ``step = 1 / (right_win / left_win + 1)``;
    * ``base`` is the value ``left_len`` entries before the run (clamped to
      ``row[0]`` if that index would be negative), or ``0`` when the run starts
      the row;
    * the ``j``-th pad entry (1-based) becomes ``base + step * j``, where
      ``base`` is increased by ``step`` whenever ``j > 1`` and
      ``j % (right_win / left_win) == 1``.

    Example with the default ``window_ratio=(3, 6)`` (a leading pad run is
    interpolated from 0, not kept as ``pad_val``)::

        [-100, 1, 2, 3, 4, -100, -100, -100, -100, -100, -100, 5]
        ->
        [0.33, 1.0, 2.0, 3.0, 4.0, 2.33, 2.67, 3.33, 3.67, 4.33, 4.67, 5.0]

    Args:
        position_ids: 2-D tensor (iterated row by row; each row is a 1-D tensor).
        window_ratio: ``(left_win, right_win)`` ratio controlling the fill.
        pad_val: value marking positions to fill.
        device: device for the returned tensor.

    Returns:
        float32 tensor with the same number of rows and columns as
        *position_ids*.
    """
    # TODO: support more than one modalities
    left_win, right_win = window_ratio
    # Loop-invariant quantities derived from the window ratio.
    ratio = right_win / left_win
    step = 1 / (ratio + 1)

    result = []
    for row in position_ids:
        output = []
        n = len(row)
        i = 0
        while i < n:
            if row[i] != pad_val:
                output.append(row[i].item())
                i += 1
                continue

            # Consume one run of pad values: row[start:i].
            start = i
            while i < n and row[i] == pad_val:
                i += 1
            mask_len = i - start
            left_len = int(mask_len * left_win / right_win)

            if start > 0:
                # Clamp: a run starting within `left_len` of the row start must
                # anchor on row[0], not wrap around to the end of the row.
                base = row[max(start - left_len, 0)].item()
            else:
                base = 0

            # Fill in the mask area
            for j in range(1, mask_len + 1):
                if j > 1 and j % ratio == 1:
                    base += step
                output.append(base + step * j)

        result.append(output)

    return torch.tensor(result, device=device, dtype=torch.float32)


class MOTChatGLMForConditionalGeneration(ChatGLMForConditionalGeneration):
    """ChatGLM conditional-generation model with MOT branches driven by diffusion forcing.

    The base stream (text/audio token ids) is decoded autoregressively with the
    ChatGLM LM head. MOT-branch embeddings (``mot_input_embs``) are predicted by
    ``self.diffusion_model`` (a ``MotDiscreteDiffusion`` wrapper around
    ``self.transformer``) one chunk at a time, whenever the generation router
    switches to a MOT chunk.

    ``generate`` keeps running state on the instance (``input_ids``,
    ``mot_input_ids``, ``mot_input_embs``, ``type_ids``, ``attention_mask``) and
    resets it at the start of every call, so one instance cannot interleave two
    generations.
    """

    _keys_to_ignore_on_load_missing = [r".*\.mot_.*"]

    def __init__(self, config: ChatGLMConfig, empty_init=False, device=None):
        super().__init__(config)

        # Replace the base transformer with the MOT-aware backbone.
        config_dforcing = config.config_dforcing
        self.transformer = MOTChatGLMModel(
            config, config_dforcing, empty_init=empty_init, device=device
        )
        self.mot_window = config.mot_window
        self.mot_vocab_size = config.mot_vocab_size
        self.eos_token_id = config.eos_token_id
        self.mot_eos_token_id = config.mot_eos_token_id
        self.mot_loss_weight = config.mot_loss_weight

        # Running generation state (reset by `generate`).
        self.type_ids = None
        self.attention_mask = None
        self.input_ids = None
        self.mot_input_ids = None
        self.mot_input_embs = None
        # Generation flags, also (re)set by `generate`; initialised here so that
        # `forward` can be called directly without an AttributeError.
        self.eval_first_chunk = False
        self.gt_mot_input_embs = None
        self.router = MOTGenerationRouter(
            config.audio_offset,
            config.mot_window[0],
            config.mot_window[1],
            self.eos_token_id,
            self.mot_eos_token_id,
        )
        self.df_router = DFRouter(config.mot_window)

        self.load_with_diffusion_backbone(config_dforcing)

    def load_with_diffusion_backbone(self, dforcing_config):
        """Store ``dforcing_config`` and wrap ``self.transformer`` in the diffusion model.

        Raises:
            NotImplementedError: if ``dforcing_config.diffusion.is_continuous``
                is truthy; only discrete diffusion is supported.
        """
        self.dforcing_config = dforcing_config

        if dforcing_config.diffusion.is_continuous:
            raise NotImplementedError("Continuous diffusion is not supported yet.")

        # Continuous diffusion was rejected above, so the discrete model is the
        # only remaining choice.
        self.diffusion_model = MotDiscreteDiffusion(
            cfg=dforcing_config,
            model=self.transformer,
            x_shape=dforcing_config.x_shape,
            max_tokens=dforcing_config.backbone.max_tokens_training,
            external_cond_dim=dforcing_config.external_cond_dim,
        )

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear: Xavier-uniform weight, zero bias. LayerNorm: weight 1, bias 0.
        RMSNorm: weight 1. Embedding: N(0, 0.02) with the padding row zeroed.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            if module.weight is not None:
                nn.init.constant_(module.weight, 1.0)
        elif isinstance(module, RMSNorm):
            nn.init.constant_(module.weight, 1.0)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                nn.init.constant_(module.weight[module.padding_idx], 0)

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
    ) -> Dict[str, Any]:
        """Carry state from ``outputs`` into ``model_kwargs`` for the next step.

        * Every name in ``ALL_CACHE_NAMES`` present in ``outputs`` is copied over.
        * ``attention_mask`` gains a column of ones.
        * ``position_ids`` gains one entry equal to its last value + 1.
        * Every ``valid_pos`` mask gains a ``True`` column.
        * ``type_ids`` is replaced by ``self.type_ids`` when that is strictly
          longer, then gains a ``0`` column; the result is stored both in
          ``model_kwargs`` and on ``self``.
        * ``is_first_forward`` is set to False.
        """
        # Get the cache name
        for possible_cache_name in ALL_CACHE_NAMES:
            if possible_cache_name in outputs:
                model_kwargs[possible_cache_name] = getattr(
                    outputs, possible_cache_name
                )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=-1,
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        # update valid_pos
        if "valid_pos" in model_kwargs:
            valid_pos = model_kwargs["valid_pos"]
            current_valid_pos = [
                torch.cat([x, x.new_ones((x.shape[0], 1), dtype=torch.bool)], dim=-1)
                for x in valid_pos
            ]
            model_kwargs["valid_pos"] = current_valid_pos

        # generation modality flag
        if "type_ids" in model_kwargs:
            # update for motional tokens
            type_ids = model_kwargs["type_ids"]
            current_type_ids = torch.cat(
                [
                    self.type_ids
                    if type_ids.shape[1]
                    < self.type_ids.shape[1]  # keep longer sequence
                    else type_ids,
                    torch.zeros(
                        (type_ids.shape[0], 1), dtype=torch.long, device=type_ids.device
                    ),
                ],
                dim=-1,
            )

            model_kwargs["type_ids"] = current_type_ids
            self.type_ids = current_type_ids

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def get_position_ids(self, input_ids: torch.Tensor, device=None, pad_token_id=None):
        """Count positions over entries of ``input_ids`` that are not ``IGNORE_INDEX``.

        Each valid entry gets ``(number of valid entries up to and including it) - 1``;
        ignored entries are set to ``IGNORE_INDEX``. ``device`` and
        ``pad_token_id`` are accepted for interface compatibility but unused.
        """
        # (batch_size, seq_length)
        valid_mask = input_ids != IGNORE_INDEX
        position_ids = valid_mask.cumsum(dim=1) - 1
        # Mask out the padding positions
        position_ids = torch.where(
            valid_mask,
            position_ids,
            torch.full_like(input_ids, IGNORE_INDEX, dtype=torch.float),
        )
        return position_ids

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        mot_input_ids: List[torch.Tensor],
        mot_input_embs: Optional[List[torch.Tensor]] = None,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        is_first_forward: bool = True,
        **kwargs,
    ) -> dict:
        """Build the ``forward`` kwargs for the next generation step.

        First call (``is_first_forward``): missing ``position_ids`` are derived
        and interpolated, missing ``mot_input_embs`` / ``type_ids`` are zero
        filled, and the prompt is passed through ``self.router.process_input_ids``.

        Later calls: the newest sampled token is appended to ``self.input_ids``
        and everything is re-routed by ``self.router``. If the last position is
        then marked as MOT (``type_ids[0, -1] > 0.5``), ``self.df_router``
        prepares a whole chunk for diffusion (``is_diffusion=True``); otherwise,
        with a KV cache, only the last position of every stream is fed.
        """
        valid_pos = None
        type_ids = None
        is_training = False
        is_diffusion = False

        if not is_first_forward:
            valid_pos = kwargs["valid_pos"]
            type_ids = kwargs["type_ids"]

            self.input_ids = torch.cat([self.input_ids, input_ids[:, -1:]], dim=-1)
            (
                self.input_ids,
                mot_input_ids,
                mot_input_embs,
                position_ids,
                valid_pos,
                type_ids,
            ) = self.router(
                self.input_ids,
                self.mot_input_ids,
                self.mot_input_embs,
                position_ids,
                valid_pos,
                self.type_ids,
            )

            # diffusion part
            if type_ids[0, -1] > 0.5:
                # chunk size
                # TODO audio_position_count[0] only for batch_size = 1
                window_size = (
                    self.router.audio_position_count[0]
                    * self.router.rate_audio_to_motion
                )

                is_diffusion = True
                (
                    self.input_ids,
                    mot_input_ids,
                    mot_input_embs,
                    position_ids,
                    valid_pos,
                    type_ids,
                ) = self.df_router(
                    window_size,
                    self.input_ids,
                    mot_input_ids,
                    mot_input_embs,
                    position_ids,
                    valid_pos,
                    type_ids,
                )
                self.router.fix_for_motion_chunk(window_size)
                input_ids = self.input_ids

            # llm part
            else:
                if past_key_values is not None:
                    input_ids = self.input_ids[:, -1:]
                    mot_input_ids = [x[:, -1:] for x in mot_input_ids]
                    mot_input_embs = [x[:, -1:, ...] for x in mot_input_embs]
                    type_ids = type_ids[:, -1:]

        else:
            # only last token for input_ids if past is not None
            if position_ids is None:
                position_ids = self.get_position_ids(input_ids, device=input_ids.device)
                position_ids = interpolate_position_ids(
                    position_ids, self.mot_window, IGNORE_INDEX, input_ids.device
                )
            # init mot_input_embs
            if mot_input_embs is None:
                mot_input_embs = [
                    torch.zeros(
                        input_ids.shape[0],
                        input_ids.shape[1],
                        self.dforcing_config.x_shape[0],
                    ).to(input_ids.device)
                ]

            if type_ids is None:
                type_ids = torch.zeros(
                    input_ids.shape,
                    device=input_ids.device,
                    dtype=torch.long,
                )

            input_ids, position_ids, type_ids = self.router.process_input_ids(
                input_ids, position_ids, type_ids
            )
            self.input_ids = input_ids

        return {
            "input_ids": input_ids,
            "mot_input_ids": mot_input_ids,
            "mot_input_embs": mot_input_embs,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "valid_pos": valid_pos,
            "type_ids": type_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache,
            "is_training": is_training,
            "is_diffusion": is_diffusion,
        }

    def forward(
        self,
        input_ids: torch.Tensor,
        mot_input_ids: List[torch.Tensor],
        mot_input_embs: Optional[List[torch.Tensor]],
        position_ids: Optional[torch.Tensor] = None,
        valid_pos: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_last_logit: Optional[bool] = False,
        is_training: Optional[bool] = True,
        is_diffusion: Optional[bool] = False,
    ):  # type: ignore
        """Diffusion-forcing training step, or one generation step.

        Missing ``position_ids``, ``valid_pos`` and ``type_ids`` are derived from
        ``input_ids`` / ``mot_input_ids`` (entries equal to ``IGNORE_INDEX`` are
        invalid).

        With ``is_training`` (the default), ``mot_input_embs[0]`` is noised and
        denoised by ``self.diffusion_model`` and the reweighted diffusion loss is
        returned; ``logits`` is None.

        Otherwise one generation step runs:

        * if ``is_diffusion``, a MOT chunk of ``router.motion_count[0]``
          positions is predicted without a KV cache, appended to
          ``self.mot_input_ids`` / ``self.mot_input_embs``, and the router is
          switched back to text; the last ``chunk_size + 1`` positions are then
          passed through the transformer below;
        * the transformer runs with every MOT position at the maximum noise
          level, producing LM logits for the base stream, and placeholder MOT
          entries (id ``-100``, zero embedding) are appended to the running
          state (a single one after a diffusion chunk).

        Raises:
            ValueError: in generation mode when ``return_dict`` is False.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # For forward
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
            position_ids = interpolate_position_ids(
                position_ids, self.mot_window, IGNORE_INDEX, input_ids.device
            )

        if valid_pos is None:
            valid_pos = [torch.ne(input_ids, IGNORE_INDEX)]
            valid_pos.extend(
                [torch.ne(input_id, IGNORE_INDEX) for input_id in mot_input_ids]
            )

        if type_ids is None:
            # 1 where the base stream is invalid, shifted left by one position
            # (the last position gets 0).
            type_ids = torch.where(
                valid_pos[0],
                torch.zeros_like(valid_pos[0], dtype=torch.long),
                torch.ones_like(valid_pos[0], dtype=torch.long),
            )
            type_ids = torch.cat(
                [type_ids[..., 1:], torch.zeros_like(type_ids[..., :1])], dim=-1
            )
        if self.input_ids is None:
            self.input_ids = input_ids

        if self.type_ids is None:
            self.type_ids = type_ids

        if self.attention_mask is None:
            self.attention_mask = attention_mask

        if self.mot_input_ids is None:
            self.mot_input_ids = mot_input_ids

        if self.mot_input_embs is None:
            self.mot_input_embs = mot_input_embs

        if is_training:
            # Diffusion forcing on the first MOT branch (mot_input_embs[0]).
            masks = mot_input_ids[0] != -100
            noise_levels, masks = self.diffusion_model._get_training_noise_levels(
                self.dforcing_config, mot_input_embs[0], masks
            )
            external_cond = None
            training_inputs = {
                "input_ids": input_ids,
                "mot_input_ids": mot_input_ids,
                "mot_input_embs": mot_input_embs,
                "position_ids": position_ids,
                "valid_pos": valid_pos,
                "attention_mask": attention_mask,
                "type_ids": type_ids,
                "inputs_embeds": inputs_embeds,
                "use_cache": use_cache,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
            # The predicted MOT embeddings are not part of the training output.
            _, loss, transformer_outputs = self.diffusion_model(
                x=mot_input_embs[0],
                k=noise_levels,
                external_cond=external_cond,
                training_inputs=training_inputs,
            )
            loss = self.diffusion_model._reweight_loss(loss, masks.float())

            # TODO loss for text/audio ids, glm_voice SFT training
            lm_logits = None

        else:
            if is_diffusion:
                # Predict one MOT chunk. Sampling runs without a KV cache
                # (use_cache=False and no past_key_values).
                conditions = None
                sample_inputs = {
                    "input_ids": input_ids,
                    "mot_input_ids": mot_input_ids,
                    "position_ids": position_ids,
                    "valid_pos": valid_pos,
                    "attention_mask": attention_mask,
                    "type_ids": type_ids,
                    "inputs_embeds": inputs_embeds,
                    "use_cache": False,
                    "output_hidden_states": output_hidden_states,
                    "return_dict": return_dict,
                }

                # chunk 0 [text/audio padding]+[motion chunk]
                # chunk 1 [text/audio padding]+[motion chunk]+[text/audio padding]+[motion chunk]
                # TODO batch size 1 only (router state of sample 0 is used)
                chunk_size = self.router.motion_count[0]

                mot_input_embs_pred, transformer_outputs = (
                    self.diffusion_model._predict_sequence_step(
                        xs=mot_input_embs[0],
                        conditions=conditions,
                        sample_inputs=sample_inputs,
                        chunk_size=chunk_size,
                    )
                )

                mot_input_embs_pred = [mot_input_embs_pred]

                # Append only the newly predicted chunk to the running state.
                self.mot_input_ids = [
                    torch.cat([x, y[:, -chunk_size:]], dim=-1)
                    for x, y in zip(self.mot_input_ids, mot_input_ids)
                ]
                self.mot_input_embs = [
                    torch.cat([x, y[:, -chunk_size:]], dim=-2)
                    for x, y in zip(self.mot_input_embs, mot_input_embs_pred)
                ]

                # Keep the LM head on the device of the hidden states.
                hidden_states = transformer_outputs[0]
                self.transformer.output_layer = self.transformer.output_layer.to(
                    hidden_states[0].device
                )

                # prepare for AR step
                self.router.switch_to_text(0, self.input_ids, position_ids, valid_pos)

                # evaluate first chunk only
                if self.eval_first_chunk:
                    self.router.end_flags[0] = True

                # Pass the last chunk_size + 1 positions through the transformer
                # below (for the KV cache).
                input_ids = self.input_ids[:, -(chunk_size + 1) :]
                mot_input_ids = [x[:, -(chunk_size + 1) :] for x in self.mot_input_ids]
                mot_input_embs = [
                    x[:, -(chunk_size + 1) :, :] for x in self.mot_input_embs
                ]

            # Autoregressive step for the base (text/audio) stream: every MOT
            # position is presented at the maximum noise level.
            max_noise_level = self.dforcing_config.diffusion.timesteps - 1
            noise_levels = [
                torch.ones_like(x[:, :, 0]) * max_noise_level for x in mot_input_embs
            ]

            transformer_outputs = self.transformer(
                input_ids=input_ids,
                mot_input_ids=[mot_input_id.clone() for mot_input_id in mot_input_ids],
                mot_input_embs=mot_input_embs,
                noise_levels=noise_levels,
                position_ids=position_ids,
                valid_pos=valid_pos,
                attention_mask=attention_mask,
                type_ids=type_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            hidden_states = transformer_outputs[0]
            if return_last_logit:
                hidden_states = [x[:, -1:] for x in hidden_states]

            lm_logits = self.transformer.output_layer(hidden_states[0])

            if self.config.output_conv:
                for i, x in enumerate(self.transformer.mot_emb_output_conv):
                    hidden_states[i + 1] = x(
                        hidden_states[i + 1].permute(0, 2, 1)
                    ).permute(0, 2, 1)
            mot_embs_noise_pred = [
                x(hidden_states[i + 1])
                # TODO ablation remove motion layer
                # for directly diffusion forcing Inference
                for i, x in enumerate(self.transformer.mot_emb_output_layer)
            ]
            mot_embs_noise_pred = [x.to(torch.float32) for x in mot_embs_noise_pred]
            # The MOT prediction is not decoded during generation; zero
            # placeholders of the same shape/dtype are appended instead.
            mot_input_embs_pred = [torch.zeros_like(x) for x in mot_embs_noise_pred]
            mot_ids = [torch.full_like(lm_logits[:, :, 0], -100, dtype=torch.long)]

            if is_diffusion:
                # The chunk was already appended above; add a single placeholder
                # for the AR part. Slice ids AND embeddings so the running
                # state stays aligned.
                mot_ids = [x[:, -1:] for x in mot_ids]
                mot_input_embs_pred = [x[:, -1:, ...] for x in mot_input_embs_pred]

            # prepare inputs for next step
            self.mot_input_ids = [
                torch.cat([x, y], dim=-1) for x, y in zip(self.mot_input_ids, mot_ids)
            ]
            self.mot_input_embs = [
                torch.cat([x, y], dim=-2)
                for x, y in zip(self.mot_input_embs, mot_input_embs_pred)
            ]

            # No loss is computed during generation.
            loss = None

            if not return_dict:
                raise ValueError(
                    "Diffusion forcing, not implemented for generation only"
                )

        return MOTModelOutputWithPast(
            loss=loss,  # type: ignore
            logits=lm_logits,
            mot_input_ids=mot_input_ids,
            mot_input_embs=mot_input_embs,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            valid_pos=transformer_outputs.valid_pos,
            position_ids=position_ids,
            type_ids=type_ids,
        )

    def generate(self, *args, **kwargs):
        """Generate interleaved base tokens and MOT chunks for a single sample.

        ``input_ids`` is taken from the first positional argument or the
        ``input_ids`` keyword. ``kwargs`` must contain ``pad_token_id`` and
        ``max_new_tokens``. Optional kwargs consumed here: ``labels`` (passed
        to the router), ``motion_embs`` (stored as ``self.gt_mot_input_embs``;
        only read when ``labels`` is given) and ``eval_first_chunk``.

        Raises:
            ValueError: if ``input_ids`` is missing, the batch size is not 1, or
                ``dforcing_config.sampling.chunk_size != mot_window[1]``.
        """
        # Get the input ids
        if len(args) > 0:
            input_ids = args[0]
        elif "input_ids" in kwargs:
            input_ids = kwargs["input_ids"]
        else:
            raise ValueError(
                "You have to specify either `input_ids` or pass them as the first argument."
            )

        # Router counters, `forward` and the post-processing below only track
        # sample 0. Reject other batch sizes before running any generation.
        if input_ids.shape[0] != 1:
            raise ValueError("Only support batch size 1")

        # NOTE(review): this modifies the caller's tensor in place; the same
        # tensor object is what `super().generate` receives via args/kwargs.
        input_ids[input_ids == kwargs["pad_token_id"]] = IGNORE_INDEX

        # Set the router and stopping criteria
        self.router.reset(len(input_ids))
        position_ids = self.get_position_ids(input_ids, device=input_ids.device).float()
        kwargs["position_ids"] = position_ids
        # Stop through MOTStoppingCriteria instead of the plain eos_token_id check.
        kwargs["eos_token_id"] = None
        kwargs["stopping_criteria"] = StoppingCriteriaList(
            [
                MOTStoppingCriteria(
                    kwargs["max_new_tokens"], self.eos_token_id, self.router
                )
            ]
        )

        # for evaluating training
        self.gt_mot_input_embs = None
        if "labels" in kwargs:
            self.router.set_generation_label(labels=kwargs["labels"])
            if "motion_embs" in kwargs:
                # Keep the ground truth and drop it from the kwargs forwarded
                # to the base generate().
                self.gt_mot_input_embs = kwargs.pop("motion_embs")

        # Evaluate only the first chunk (faster evaluation).
        self.eval_first_chunk = kwargs.pop("eval_first_chunk", False)

        # Reset the running generation state.
        self.type_ids = None
        self.attention_mask = None
        self.input_ids = None
        self.mot_input_ids = None
        self.mot_input_embs = None

        # Initialize the diffusion model for prediction during generation.
        self.diffusion_model.init_predict_videos()

        # Denoising runs chunk by chunk, so the sampling chunk size must match
        # the MOT window.
        if self.dforcing_config.sampling.chunk_size != self.mot_window[1]:
            raise ValueError(
                "dforcing_config.sampling.chunk_size must equal mot_window[1]"
            )

        output = super().generate(*args, **kwargs)

        # The last base token is a padding token reserved for the next chunk.
        self.input_ids[:, -1] = IGNORE_INDEX

        # remove last motion padding
        self.mot_input_ids = [x[:, :-1] for x in self.mot_input_ids]
        self.mot_input_embs = [x[:, :-1, ...] for x in self.mot_input_embs]

        return MOTGenerateDecoderOnlyOutput(
            sequences=self.input_ids,
            scores=output.scores,
            attentions=output.attentions,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            mot_output_ids=self.mot_input_ids,
            mot_output_embs=self.mot_input_embs,
        )


class MOTStoppingCriteria(StoppingCriteria):
    """
    Stop generation once every mot stream has finished and no audio tokens
    are outstanding.

    The decision is read entirely from the shared ``router`` object.
    Generation stops when every entry of ``router.end_flags`` is truthy and
    every entry of ``router.audio_count`` equals 0.

    Args:
        max_new_token: Stored on the instance; not consulted by ``__call__``.
        eos_token_id: Stored on the instance; not consulted by ``__call__``.
        router: Object exposing ``end_flags`` and ``audio_count`` iterables.
            These are presumably updated by the model while decoding
            (TODO confirm against the router implementation).
    """

    # Per-step diagnostics go through the library logger at debug level
    # instead of unconditional prints that flooded stdout on every step.
    _logger = logging.get_logger(__name__)

    def __init__(self, max_new_token, eos_token_id, router):
        self.max_new_token = max_new_token
        self.eos_token_id = eos_token_id
        self.router = router

    def __call__(self, input_ids, scores, **kwargs) -> torch.BoolTensor:
        """Return a 0-d bool tensor that is True when generation should stop.

        ``input_ids`` and ``scores`` belong to the ``StoppingCriteria``
        interface but are not inspected here.
        """
        end_flags = self.router.end_flags
        audio_count = self.router.audio_count

        # Every mot stream has signalled its end.
        all_motion_eos = all(end_flags)
        # No stream still has audio tokens pending.
        all_audio_count_zero = all(count == 0 for count in audio_count)
        should_stop = all_motion_eos and all_audio_count_zero

        self._logger.debug(
            "MOT stopping check: end_flags=%s audio_count=%s",
            end_flags,
            audio_count,
        )
        return torch.tensor(should_stop, dtype=torch.bool)  # type: ignore


def apply_chat_template(
    instructions,
    targets=None,
    audio_input=False,
    behaviour_output=False,
    generation_prefix=False,
) -> list[str]:
    """
    Wrap each instruction (and optional target) in the chat template.

    Each output string consists of, in order:

    1. An optional system section: the ``<|system|>`` tag, a newline, the
       system prompt and a newline. It is omitted when the instruction
       already contains ``<|system|>``.
    2. The ``<|user|>`` tag, a newline and the instruction.
    3. ``<|assistant|>streaming_transcription`` followed by a newline.
    4. The target text.

    Args:
        instructions: Iterable of user instruction strings.
        targets: Optional iterable of assistant target strings, one per
            instruction. Defaults to empty targets, which yields prompts
            ready for generation.
        audio_input: If True, the system prompt describes a "speech"
            instruction; otherwise a "text" instruction.
        behaviour_output: Currently unused; kept for interface compatibility.
        generation_prefix: Currently unused; kept for interface compatibility.

    Returns:
        One templated string per instruction, in input order.

    Raises:
        ValueError: If ``targets`` is given and its length differs from the
            number of instructions. ``zip`` would otherwise silently drop
            the unmatched items.
    """
    # Materialize so len() works for any iterable input.
    instructions = list(instructions)
    if targets is None:
        targets = [""] * len(instructions)
    else:
        targets = list(targets)
        if len(targets) != len(instructions):
            raise ValueError(
                f"Got {len(instructions)} instructions but {len(targets)} targets"
            )

    system_prompt_base = "User will provide you with a {} instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."
    # The modality depends only on `audio_input`, so build the prompt once.
    system_prompt = system_prompt_base.format("speech" if audio_input else "text")

    labels = []
    for user_input, target_output in zip(instructions, targets):
        label = ""
        # Callers may embed their own system section; don't add a second one.
        if "<|system|>" not in user_input:
            label += f"<|system|>\n{system_prompt}\n"
        label += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
        label += target_output
        labels.append(label)

    return labels


if __name__ == "__main__":
    # Manual smoke test: greedy-decode one prompt with the stock checkpoint
    # (via remote code) and with MOTChatGLMForConditionalGeneration, printing
    # both results. Requires a CUDA device and the checkpoint at `model_path`.
    model_path = "/mnt/deps/glm-4-voice-9b/"
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, legacy=True, trust_remote_code=True
    )
    tokenizer.padding_side = "left"

    generation_config = dict(
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
        max_new_tokens=456,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

    # A single prompt: the MOT generate path asserts batch size 1, so no
    # padding (and hence no attention mask) is involved.
    inputs = tokenizer(
        apply_chat_template(["What is love, baby don't hurt me, "]),
        return_tensors="pt",
        padding=True,
    )
    # NOTE(review): keep `input_ids` at module scope. The MOT `generate`
    # method references a bare `input_ids` that may resolve to this global;
    # confirm before moving this script into a function.
    input_ids = inputs["input_ids"].to("cuda")

    # --- Baseline model ---
    torch.manual_seed(0)
    model = (
        AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
        .eval()
        .to("cuda")
    )
    # Pass a copy; `input_ids` itself is reused for the MOT run below.
    outputs = model.generate(input_ids.clone(), **generation_config)
    print(outputs.sequences)
    print(
        tokenizer.batch_decode(
            outputs.sequences,
            spaces_between_special_tokens=False,
            skip_special_tokens=True,
        )
    )

    # Free the baseline model and its outputs (which include per-step hidden
    # states) before a second copy of the weights is loaded onto the GPU.
    del model, outputs
    torch.cuda.empty_cache()

    # --- MOT model ---
    # Mot streams start with every id set to IGNORE_INDEX and every
    # embedding set to zero.
    # NOTE(review): embeddings are shaped like input_ids (no feature dim);
    # confirm this matches what the MOT model expects.
    mot_input_ids = [torch.full_like(input_ids, IGNORE_INDEX, dtype=torch.long)]
    mot_input_embs = [torch.zeros_like(input_ids, dtype=torch.float16)]

    model_config2 = ChatGLMConfig.from_pretrained(model_path)
    with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        model2 = (
            MOTChatGLMForConditionalGeneration.from_pretrained(
                model_path, config=model_config2
            )
            .eval()
            .to("cuda")
        )
        torch.manual_seed(0)
        outputs2 = model2.generate(
            input_ids,
            mot_input_ids=mot_input_ids,
            mot_input_embs=mot_input_embs,
            **generation_config,
        )

    print(outputs2.sequences)
    print(
        tokenizer.batch_decode(
            outputs2.sequences,
            spaces_between_special_tokens=False,
            skip_special_tokens=True,
        )
    )
