from models.resnet import ResNet18
from functools import partial


# Registry mapping a model name to a zero-config factory; extra constructor
# arguments are supplied by the caller via get_model(name, *args, **kwargs).
# "resnet18nobn" is the same ResNet18 class with use_batchnorm pre-bound to False.
model_factories = dict(
    resnet18=ResNet18,
    resnet18nobn=partial(ResNet18, use_batchnorm=False),
)

def get_available_models():
    """Return the names of all registered models.

    The result is the live ``dict_keys`` view of ``model_factories``. It
    reflects later changes to the registry and supports ``in`` tests and set
    operations.
    """
    registry = model_factories
    return registry.keys()


def get_model(name, *args, **kwargs):
    """Construct the model registered under ``name``.

    Args:
        name: A key of ``model_factories`` (the same names returned by
            ``get_available_models()``).
        *args: Positional arguments forwarded unchanged to the factory.
        **kwargs: Keyword arguments forwarded unchanged to the factory.

    Returns:
        Whatever the registered factory returns (presumably a model
        instance; depends on the factory).

    Raises:
        KeyError: If ``name`` is not registered. The message lists the
            available model names.
    """
    # Only the lookup sits inside the try, so a KeyError raised by the factory
    # itself is not mistaken for an unknown model name.
    try:
        factory = model_factories[name]
    except KeyError:
        available = ", ".join(sorted(map(str, model_factories)))
        # Still a KeyError so existing callers catching it keep working;
        # `from None` hides the redundant original lookup traceback.
        raise KeyError(
            f"Unknown model {name!r}; available models: {available}"
        ) from None
    return factory(*args, **kwargs)