from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config


# Module-level logger obtained from transformers' logging utilities
# (not referenced by the code visible in this file).
logger = logging.get_logger(__name__)


class RNSAQwen3Config(Qwen3Config):
    """Qwen3 configuration extended with RNSA-specific options.

    Each RNSA option is stored verbatim as an instance attribute of the same
    name; no validation or conversion happens here. All other keyword
    arguments are passed through unchanged to ``Qwen3Config.__init__``.

    How these options are consumed is decided by modeling code outside this
    file, so the per-argument notes below are inferred from names and should
    be confirmed against that code.

    Args:
        forget_gate_bias_init: Forget-gate bias initial value (default ``10.0``).
        forget_weight: Weight associated with the forget mechanism
            (default ``1.0``).
        memory_size: Memory size setting (default ``1024``).
        forget_gate: Identifier of the forget-gate variant (default ``'fg4'``).
        base_loss: Identifier of the base loss (default ``'ntp'``).
        attn_impl: Identifier of the attention implementation
            (default ``'rnsa_flex'``).
        compress_memory: Flag for memory compression (default ``False``).
        compress_strategy: Identifier of the compression strategy
            (default ``'lw_knorm_alpha'``).
        buffer_size: Compression is run every ``buffer_size`` tokens
            (default ``128``).
        trainable_params: Optional trainable-parameter specification; its
            expected format is not defined here (default ``None``).
        max_seq_len: Maximum sequence length setting (default ``131072``).
        forget_gate_intermediate_size: Intermediate width setting for the
            forget gate (default ``128``).
        logit_block_size: Block size setting for logit computation
            (default ``-1``; presumably a "disabled" sentinel — confirm).
        **kwargs: Forwarded unchanged to ``Qwen3Config``.
    """

    def __init__(
        self,
        forget_gate_bias_init=10.0,
        forget_weight=1.0,
        memory_size=1024,
        forget_gate='fg4',
        base_loss='ntp',
        attn_impl='rnsa_flex',
        compress_memory=False,
        compress_strategy='lw_knorm_alpha',
        buffer_size=128,
        trainable_params=None,
        max_seq_len=131072,
        forget_gate_intermediate_size=128,
        logit_block_size=-1,
        **kwargs,
    ):
        rnsa_options = {
            # Forget-gate settings.
            "forget_gate": forget_gate,
            "forget_gate_bias_init": forget_gate_bias_init,
            "forget_gate_intermediate_size": forget_gate_intermediate_size,
            "forget_weight": forget_weight,
            # Memory / compression settings.
            "memory_size": memory_size,
            "compress_memory": compress_memory,
            "compress_strategy": compress_strategy,
            "buffer_size": buffer_size,  # compression runs every `buffer_size` tokens
            # Training / runtime settings.
            "base_loss": base_loss,
            "attn_impl": attn_impl,
            "trainable_params": trainable_params,
            "max_seq_len": max_seq_len,
            "logit_block_size": logit_block_size,
        }
        # Attach the RNSA options before delegating, so they are already set
        # on the instance when the parent constructor runs.
        for option_name, option_value in rnsa_options.items():
            setattr(self, option_name, option_value)

        super().__init__(**kwargs)


# Public API of this module: only the config class is exported by `import *`.
__all__ = ["RNSAQwen3Config"]
