# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.deltatransformer.configuration_transformer import TransformerConfig as DeltaTransformerConfig
from fla.models.deltatransformer.modeling_transformer import TransformerForCausalLM as DeltaTransformerForCausalLM
from fla.models.deltatransformer.modeling_transformer import TransformerModel as DeltaTransformerModel


# Register the DeltaTransformer classes with the Hugging Face Auto* classes.
# The config is registered under its own `model_type` string; each model
# class is then registered against the config class.
AutoConfig.register(DeltaTransformerConfig.model_type, DeltaTransformerConfig)
for _auto_cls, _model_cls in (
    (AutoModel, DeltaTransformerModel),
    (AutoModelForCausalLM, DeltaTransformerForCausalLM),
):
    _auto_cls.register(DeltaTransformerConfig, _model_cls)
# Drop the loop temporaries so no extra names are left in the module namespace.
del _auto_cls, _model_cls


# Public names of this module (what `import *` exposes).
__all__ = [
    'DeltaTransformerConfig',
    'DeltaTransformerForCausalLM',
    'DeltaTransformerModel',
]
