from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers.models.phi3 import Phi3Config


# Module-level logger created through the transformers logging helper.
logger = logging.get_logger(__name__)


class RNSAPhi3Config(Phi3Config):
    """Phi-3 configuration extended with RNSA-specific options.

    Every option listed below is stored on the instance under its own name,
    without validation or conversion, and this happens *before* the
    ``Phi3Config`` initializer runs. How each value is interpreted is decided
    by the code that consumes this config, which is not part of this file;
    the hints marked "presumably" are assumptions to be confirmed there.

    Args:
        forget_gate_bias_init: Defaults to ``10.0``. Presumably the initial
            bias of the forget gate -- TODO confirm against the model code.
        forget_weight: Defaults to ``1.0``.
        memory_size: Defaults to ``1024``.
        forget_gate: Defaults to ``'fg4'``. Identifier string; the set of
            accepted values is not defined here.
        base_loss: Defaults to ``'ntp'``. Identifier string; the set of
            accepted values is not defined here.
        attn_impl: Defaults to ``'rnsa_flex'``. Identifier string; the set
            of accepted values is not defined here.
        compress_memory: Defaults to ``False``.
        compress_strategy: Defaults to ``'lw_knorm_alpha'``. Identifier
            string; the set of accepted values is not defined here.
        buffer_size: Defaults to ``1``. Compression is run every
            ``buffer_size`` tokens.
        trainable_params: Defaults to ``None``.
        max_seq_len: Defaults to ``20480``.
        forget_gate_intermediate_size: Defaults to ``128``.
        **kwargs: Forwarded unchanged to ``Phi3Config.__init__``.
    """

    def __init__(
        self,
        forget_gate_bias_init=10.0,
        forget_weight=1.0,
        memory_size=1024,
        forget_gate='fg4',
        base_loss='ntp',
        attn_impl='rnsa_flex',
        compress_memory=False,
        compress_strategy='lw_knorm_alpha',
        buffer_size=1,
        trainable_params=None,
        max_seq_len=20480,
        forget_gate_intermediate_size=128,
        **kwargs,
    ):
        # Gathered in the exact order in which the attributes are assigned,
        # so the instance's attribute order is the same as with one explicit
        # assignment per option.
        rnsa_options = {
            "forget_gate_bias_init": forget_gate_bias_init,
            "forget_weight": forget_weight,
            "forget_gate_intermediate_size": forget_gate_intermediate_size,
            "memory_size": memory_size,
            "forget_gate": forget_gate,
            "attn_impl": attn_impl,
            "base_loss": base_loss,
            "trainable_params": trainable_params,
            "compress_memory": compress_memory,
            "compress_strategy": compress_strategy,
            "buffer_size": buffer_size,  # run compression every `buffer_size` tokens
            "max_seq_len": max_seq_len,
        }
        for option_name, option_value in rnsa_options.items():
            # setattr goes through the same __setattr__ hook as `self.x = v`.
            setattr(self, option_name, option_value)

        # Remaining keyword arguments are handled by the Phi-3 base config.
        super().__init__(**kwargs)


# Public API: the name exported by a star-import of this module.
__all__ = ["RNSAPhi3Config"]
