import torchvision

def get_model(model_name, weights):
    """Build the torchvision model registered under *model_name*.

    Only the name ``'vit_b_16'`` is recognised; it maps to
    ``torchvision.models.vit_b_16``.

    Args:
        model_name: Identifier of the requested model.
        weights: Forwarded unchanged as the ``weights`` keyword argument of
            the torchvision constructor.

    Returns:
        The object returned by ``torchvision.models.vit_b_16`` when
        *model_name* is ``'vit_b_16'``. For any other name, a notice is
        printed to stdout and ``None`` is returned, so callers must check
        for ``None`` before using the result.
    """
    if model_name == 'vit_b_16':
        return torchvision.models.vit_b_16(weights=weights)

    # Unknown name: report it and signal failure with None instead of raising.
    print('Model not registered!')
    return None