# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers import LlamaConfig


# Module-level logger from the `transformers.utils.logging` helper, named after this module.
logger = logging.get_logger(__name__)


class RNSALlamaConfig(LlamaConfig):
    """`LlamaConfig` subclass that carries additional RNSA settings.

    Each extra argument below is stored unchanged as an instance attribute of
    the same name; no validation or conversion happens here. Every remaining
    keyword argument is forwarded to `LlamaConfig.__init__`.

    The meaning of most fields is defined by the code that consumes this
    config, which is not part of this module; descriptions marked
    "presumably" are inferred from the names and should be confirmed there.

    Args:
        forget_gate_bias_init: Default 10.0. Presumably the initial bias of
            the forget gate.
        forget_weight: Default 1.0. Presumably a weight on a forget-related
            loss term.
        memory_size: Default 1024.
        forget_gate: Default "fg4". Presumably selects a forget-gate variant.
        base_loss: Default "ntp".
        attn_impl: Default "rnsa_flex". Presumably selects the attention
            implementation.
        compress_memory: Default False.
        compress_strategy: Default "lw_knorm_alpha".
        buffer_size: Default 1. Compression is run every `buffer_size` tokens.
        trainable_params: Default None.
        max_seq_len: Default 20480.
        forget_gate_intermediate_size: Default 512.
        fg_dropout: Default 0.1.
        skip_layers: Default 0.
        **kwargs: Passed through to `LlamaConfig.__init__`.
    """

    def __init__(
        self,
        forget_gate_bias_init=10.0,
        forget_weight=1.0,
        memory_size=1024,
        forget_gate="fg4",
        base_loss="ntp",
        attn_impl="rnsa_flex",
        compress_memory=False,
        compress_strategy="lw_knorm_alpha",
        buffer_size=1,
        trainable_params=None,
        max_seq_len=20480,
        forget_gate_intermediate_size=512,
        fg_dropout=0.1,
        skip_layers=0,
        **kwargs,
    ):
        # Insertion order of this mapping fixes the order in which the
        # attributes are created on the instance.
        rnsa_fields = {
            "forget_gate_bias_init": forget_gate_bias_init,
            "forget_weight": forget_weight,
            "forget_gate_intermediate_size": forget_gate_intermediate_size,
            "memory_size": memory_size,
            "forget_gate": forget_gate,
            "attn_impl": attn_impl,
            "base_loss": base_loss,
            "trainable_params": trainable_params,
            "compress_memory": compress_memory,
            "compress_strategy": compress_strategy,
            # Run compression every `buffer_size` tokens.
            "buffer_size": buffer_size,
            "max_seq_len": max_seq_len,
            "skip_layers": skip_layers,
            "fg_dropout": fg_dropout,
        }
        for field_name, field_value in rnsa_fields.items():
            setattr(self, field_name, field_value)

        # The extra fields are set first; the parent initializer then
        # consumes whatever keyword arguments remain.
        super().__init__(**kwargs)


# Public names exported by a star-import of this module.
__all__ = ["RNSALlamaConfig"]
