from .vgg import _vgg
from .resnet import _resnet, Bottleneck, BasicBlock
from torchvision.models import ResNet50_Weights, ResNet18_Weights, VGG16_Weights, VGG16_BN_Weights



def get_model(name):
    """Build the network registered under ``name``.

    Recognised names (aliases grouped on one line):

    * ``'vgg16_bn'``, ``'vgg'``     -> ``_vgg('D', True, VGG16_BN_Weights.DEFAULT)``
    * ``'vgg16'``                   -> ``_vgg('D', False, VGG16_Weights.DEFAULT)``
    * ``'resnet'``, ``'resnet50'``  -> ``_resnet(Bottleneck, [3, 4, 6, 3], ResNet50_Weights.DEFAULT)``
    * ``'resnet18'``                -> ``_resnet(BasicBlock, [2, 2, 2, 2], ResNet18_Weights.DEFAULT)``

    Args:
        name: Model identifier, compared by equality against the aliases above.

    Returns:
        Whatever the project-local ``_vgg`` / ``_resnet`` builder returns
        (presumably a ``torch.nn.Module`` -- not visible from this file).

    Raises:
        ValueError: If ``name`` matches none of the known aliases.
    """
    # Ordered (aliases, factory) pairs. Factories are lambdas so that only the
    # selected model is constructed and only its weights enum is accessed.
    # NOTE(review): 'D' is presumably the VGG-16 layer config key and the bool
    # presumably toggles batch norm -- confirm against .vgg._vgg.
    registry = (
        (('vgg16_bn', 'vgg'),
         lambda: _vgg('D', True, VGG16_BN_Weights.DEFAULT)),
        (('vgg16',),
         lambda: _vgg('D', False, VGG16_Weights.DEFAULT)),
        (('resnet', 'resnet50'),
         lambda: _resnet(Bottleneck, [3, 4, 6, 3], ResNet50_Weights.DEFAULT)),
        (('resnet18',),
         lambda: _resnet(BasicBlock, [2, 2, 2, 2], ResNet18_Weights.DEFAULT)),
    )

    for aliases, build in registry:
        # Sequence membership (==), not a dict lookup, so unhashable inputs
        # fall through to the ValueError below exactly as before.
        if name in aliases:
            return build()

    raise ValueError('Unknown model: {}'.format(name))
