# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.bma_heads.configuration_bma_heads import BMAHeadsConfig
from fla.models.bma_heads.modeling_bma_heads import BMAHeadsBlock, BMAHeadsForCausalLM, BMAHeadsModel

__all__ = [
    'BMAHeadsConfig',
    'BMAHeadsForCausalLM',
    'BMAHeadsModel',
    'BMAHeadsBlock',
]

# Hook the BMAHeads classes into the transformers Auto* factories so that
# AutoConfig / AutoModel / AutoModelForCausalLM can resolve this model type.
# exist_ok=True tolerates a repeated registration of the same key instead of
# raising (e.g. if this package is imported/reloaded more than once).
AutoConfig.register(
    BMAHeadsConfig.model_type,
    BMAHeadsConfig,
    exist_ok=True,
)
# Both model classes below are keyed on the same config class.
AutoModel.register(
    BMAHeadsConfig,
    BMAHeadsModel,
    exist_ok=True,
)
AutoModelForCausalLM.register(
    BMAHeadsConfig,
    BMAHeadsForCausalLM,
    exist_ok=True,
)
