Source code for archai.networks

import torch

from torch import nn
from torch.nn import DataParallel
# from torchvision import models

from .resnet import ResNet
from .pyramidnet import PyramidNet
from .shakeshake.shake_resnet import ShakeResNet
from .wideresnet import WideResNet
from .shakeshake.shake_resnext import ShakeResNeXt


[docs]def get_model(conf, num_class=10): name = conf['type'] if name == 'resnet50': model = ResNet(dataset='imagenet', depth=50, n_classes=num_class, bottleneck=True) elif name == 'resnet200': model = ResNet(dataset='imagenet', depth=200, n_classes=num_class, bottleneck=True) elif name == 'wresnet40_2': model = WideResNet(40, 2, dropout_rate=0.0, n_classes=num_class) elif name == 'wresnet28_10': model = WideResNet(28, 10, dropout_rate=0.0, n_classes=num_class) elif name == 'shakeshake26_2x32d': model = ShakeResNet(26, 32, num_class) elif name == 'shakeshake26_2x64d': model = ShakeResNet(26, 64, num_class) elif name == 'shakeshake26_2x96d': model = ShakeResNet(26, 96, num_class) elif name == 'shakeshake26_2x112d': model = ShakeResNet(26, 112, num_class) elif name == 'shakeshake26_2x96d_next': model = ShakeResNeXt(26, 96, 4, num_class) elif name == 'pyramid': model = PyramidNet('cifar10', depth=conf['depth'], alpha=conf['alpha'], n_classes=num_class, bottleneck=conf['bottleneck']) else: raise NameError('no model named, %s' % name)
[docs]def num_class(dataset): return { 'cifar10': 10, 'reduced_cifar10': 10, 'cifar10.1': 10, 'cifar100': 100, 'svhn': 10, 'reduced_svhn': 10, 'imagenet': 1000, 'reduced_imagenet': 120, }[dataset]