import classifier
from config import opt

# Registry mapping a model-name string to its constructor in the project-local
# `classifier` module; look a model up by its exact (lowercase) key.
# NOTE: entries 'newresnet34' (NewResNet34), 'newresnet50' (NewResNet50) and
# 'disresnet50' (DisResNet50) used to be listed here as commented-out code;
# re-add them if those classes are reinstated in `classifier`.
classifier_dict = {
    'alexnet': classifier.AlexNet,
    'resnet18': classifier.ResNet18,
    'resnet34': classifier.ResNet34,
    'resnet50': classifier.ResNet50,
    'inceptionv3': classifier.InceptionV3,
    'vgg': classifier.VGG,
    'densenet': classifier.DenseNet,
    'pyramidnet': classifier.PyramidNet,
    'resnext': classifier.ResNeXt,
    'wrn': classifier.WideResNet,
    'eweresnet50': classifier.EWEResNet50,
    'plainresnet50': classifier.PlainResNet50,
    'mobilenetv1': classifier.MobileNetV1,
    'mobilenetv2': classifier.MobileNetV2,
}

def load_model_weights(old_model, new_model, *, use_gpu=None):
    """Copy weights from ``old_model`` into ``new_model`` for shared layers.

    An entry of ``old_model.state_dict()`` is transferred only when
    ``new_model`` has an entry with the same name AND the same shape.
    Same-named entries whose shapes differ (e.g. a classification head sized
    for a different number of classes) are skipped and keep ``new_model``'s
    current values, rather than making ``load_state_dict`` fail with a
    size-mismatch error. Entries that exist only in ``new_model`` are left
    untouched as well.

    Args:
        old_model: source module whose state dict supplies the weights.
        new_model: destination module; it is modified in place.
        use_gpu: when true, move ``new_model`` to CUDA before loading.
            ``None`` (the default) falls back to ``opt.use_gpu`` from the
            project config, matching the previous behavior.

    Returns:
        ``new_model`` itself, with the compatible weights loaded.
    """
    if use_gpu is None:
        use_gpu = opt.use_gpu
    if use_gpu:
        # Module.cuda() converts the module in place; load_state_dict then
        # copies values into its existing tensors, so they stay on the GPU.
        new_model.cuda()

    source_state = old_model.state_dict()
    target_state = new_model.state_dict()
    # Keep only entries the target can actually accept: same key, same shape.
    # getattr() tolerates non-tensor state entries (they have no .shape).
    compatible = {
        key: value
        for key, value in source_state.items()
        if key in target_state
        and getattr(value, 'shape', None)
        == getattr(target_state[key], 'shape', None)
    }
    target_state.update(compatible)
    new_model.load_state_dict(target_state)
    return new_model