
from .VQVAE import vqvae, vqvae2, vqvae3, vqvae4

model_dict = {
    'vqvae': vqvae,
    'vqvae2': vqvae2,
    'vqvae3': vqvae3,
    'vqvae4': vqvae4
}

def get_model_class(model_name):
    if model_name in model_dict:
        return model_dict[model_name]
    else:
        raise NotImplementedError