# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from nsa_lib.configuration_nsa import NSAConfig
from nsa_lib.modeling_nsa import NSAForCausalLM, NSAModel
from nsa_lib.ops.parallel import parallel_nsa

# Register the NSA classes with the transformers ``Auto*`` factories so that
# ``NSAConfig`` (and its ``model_type`` string) map to the NSA model classes.
# These registrations execute as a side effect of importing this module.
AutoConfig.register(NSAConfig.model_type, NSAConfig)
for _auto_factory, _model_class in (
    (AutoModel, NSAModel),
    (AutoModelForCausalLM, NSAForCausalLM),
):
    _auto_factory.register(NSAConfig, _model_class)
# Remove the loop variables so they do not linger as module attributes.
del _auto_factory, _model_class


# Names exported by a star-import of this module.
__all__ = [
    'NSAConfig',
    'NSAModel',
    'NSAForCausalLM',
    'parallel_nsa',
]


__version__ = '0.1'
