from .base_vae import BaseVAEModel


def register_vaes():
    """Register the default VAE model classes.

    Does nothing when ``BaseVAEModel.already_registered()`` reports a truthy
    value; otherwise imports each default model class and calls its
    ``register()`` classmethod in a fixed order.
    """
    if not BaseVAEModel.already_registered():
        ##### Default Models #####
        # Imported inside the function, so the model modules are only loaded
        # when registration actually runs.
        from .daps import DAPSModel
        from .gumbel import GRMCKModel, GumbelModel
        from .ppo import PPOModel
        from .vae import VAEModel
        from .vqvae import FSQModel, VQVAEModel

        # The tuple order is the registration order.
        default_models = (
            DAPSModel,
            VQVAEModel,
            VAEModel,
            GumbelModel,
            GRMCKModel,
            PPOModel,
            FSQModel,
        )
        for model_cls in default_models:
            model_cls.register()


def create_vae(cfg):
    """Create a VAE model instance from a configuration object.

    Looks up ``cfg.model.name`` in ``BaseVAEModel.registered`` and returns
    whatever that entry's ``create(cfg)`` returns. This function does not
    perform registration itself; a name missing from the registry propagates
    the error raised by the lookup.
    """
    registry = BaseVAEModel.registered
    return registry[cfg.model.name].create(cfg)
