from . import resnet
from . import wideresnet

def build_model(model_name, num_classes=10, n_channel=3, use_mps=True, device='mps', dataset='afhq'):
    """Construct a network by name and optionally move it to a device.

    Args:
        model_name: Architecture identifier. Supported values are
            'wideresnet-28-10' / 'wrn-28-10', 'wideresnet-40-2' / 'wrn-40-2',
            'resnet200' and 'resnet50'.
        num_classes: Number of output classes, forwarded to the model
            constructor.
        n_channel: Number of input channels. Only forwarded to the ResNet
            variants; the WideResNet constructors here do not receive it.
        use_mps: If true, ``model.to(device)`` is called before returning.
        device: Target passed to ``model.to(...)`` when ``use_mps`` is true.
        dataset: Dataset name forwarded to ``resnet.ResNet`` for 'resnet50'.
            'resnet200' always passes 'imagenet' regardless of this argument.

    Returns:
        The constructed model, moved to ``device`` when ``use_mps`` is true.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # NOTE: shake-shake (ss32/ss96/ss112) and pyramidnet builders used to be
    # listed here but are disabled; their modules are not imported in this file.
    if model_name in ('wideresnet-28-10', 'wrn-28-10'):
        # Positional args presumably (depth, widen_factor, dropout, num_classes)
        # — TODO confirm against wideresnet.WideResNet.
        model = wideresnet.WideResNet(28, 10, 0, num_classes)
    elif model_name in ('wideresnet-40-2', 'wrn-40-2'):
        model = wideresnet.WideResNet(40, 2, 0, num_classes)
    elif model_name == 'resnet200':
        # NOTE(review): dataset is pinned to 'imagenet' here, unlike resnet50
        # which honours the `dataset` argument — confirm this is intended.
        model = resnet.ResNet(dataset='imagenet', n_channel=n_channel, depth=200, num_classes=num_classes, bottleneck=True)
    elif model_name == 'resnet50':
        model = resnet.ResNet(dataset=dataset, n_channel=n_channel, depth=50, num_classes=num_classes, bottleneck=True)
    else:
        # Without this branch an unknown name leaves `model` unbound and the
        # call fails below with an opaque UnboundLocalError; fail fast instead.
        supported = ('wideresnet-28-10', 'wrn-28-10', 'wideresnet-40-2',
                     'wrn-40-2', 'resnet200', 'resnet50')
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {supported}"
        )

    if use_mps:
        model = model.to(device)
    return model
