
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gated_deltastack.configuration_gated_deltastack import GatedDeltaStackConfig
from fla.models.gated_deltastack.modeling_gated_deltastack import GatedDeltaStackForCausalLM, GatedDeltaStackModel

# Make the GatedDeltaStack classes discoverable through the transformers Auto*
# factories. The config is keyed by its `model_type`; each model class is keyed
# by the config class. `exist_ok=True` is passed so that a registration which
# already exists (e.g. from a repeated import) does not raise.
AutoConfig.register(GatedDeltaStackConfig.model_type, GatedDeltaStackConfig, exist_ok=True)

# (Auto factory, concrete model class) pairs, registered in this order.
for _auto_factory, _model_class in (
    (AutoModel, GatedDeltaStackModel),
    (AutoModelForCausalLM, GatedDeltaStackForCausalLM),
):
    _auto_factory.register(GatedDeltaStackConfig, _model_class, exist_ok=True)
# Drop the loop variables so they do not leak into the package namespace.
del _auto_factory, _model_class

# Public API re-exported by this package.
__all__ = [
    'GatedDeltaStackConfig',
    'GatedDeltaStackForCausalLM',
    'GatedDeltaStackModel',
]
