from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    """Configuration for the ChatGLM model, extended with ``mot_*`` settings.

    Every constructor argument is stored verbatim as an attribute of the same
    name, with two exceptions:

    * ``padded_vocab_size`` is stored both as ``padded_vocab_size`` and as
      ``vocab_size``. ``vocab_size`` is the generic attribute name used by
      ``transformers``.
    * List-valued arguments are shallow-copied, so instances never share a
      list object with each other or with the signature defaults.

    Any remaining ``**kwargs`` are forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the ``mot_*`` group (``mot_dim``, ``mot_hidden_size``,
    ``mot_ffn_hidden_size``, ``mot_vocab_size``, ``mot_window``,
    ``mot_eos_token_id``, ``mot_loss_weight``, ``mot_cross_attn_mode``)
    presumably configures one or more auxiliary token streams (one list
    entry per stream?). ``audio_offset`` looks like the token-id offset of
    the audio vocabulary. Neither meaning is established in this file;
    confirm against the modeling code. Note that ``mot_window`` and
    ``mot_loss_weight`` default to 2 entries while the other ``mot_*`` lists
    default to 1.
    """

    model_type = "chatglm"

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        mot_dim=[512],
        mot_hidden_size=[512],
        mot_ffn_hidden_size=[6848],
        mot_vocab_size=[4096],  # NOTE: [2048] was used as an alternative value
        mot_window=[26, 52],
        mot_eos_token_id=[152352],
        mot_loss_weight=[0.0, 1.0],
        mot_cross_attn_mode="all",
        audio_offset=152353,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        classifier_dropout=None,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        output_conv=False,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        rope_ratio=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        config_dforcing=None,
        **kwargs,
    ):
        def _own(value):
            # The list defaults above are evaluated once, at definition time.
            # Storing them directly would let one instance's in-place edit
            # (e.g. ``cfg.mot_window.append(...)``) leak into every later
            # default-constructed config, so give each instance its own copy.
            return list(value) if isinstance(value, list) else value

        self.num_layers = num_layers
        # transformers code reads ``vocab_size``; keep it in sync with the
        # padded size used by this model.
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.mot_cross_attn_mode = mot_cross_attn_mode
        self.seq_length = seq_length
        self.mot_dim = _own(mot_dim)
        self.mot_hidden_size = _own(mot_hidden_size)
        self.mot_ffn_hidden_size = _own(mot_ffn_hidden_size)
        self.mot_vocab_size = _own(mot_vocab_size)
        self.mot_window = _own(mot_window)
        self.mot_loss_weight = _own(mot_loss_weight)
        self.mot_eos_token_id = _own(mot_eos_token_id)
        self.audio_offset = audio_offset
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.output_conv = output_conv
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.rope_ratio = rope_ratio
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.config_dforcing = config_dforcing
        # Model-specific attributes are set first; the base class then handles
        # the generic options passed through **kwargs.
        super().__init__(**kwargs)
