from .cifar_resnet import BasicBlock, Bottleneck, ResNet
from .feature_extractor import FeatureExtractor
from .wrn import WideResNet


def get_model(args, name, num_classes):
    """Build a bare classification network from its identifier.

    Args:
        args: Not used by this function; accepted for interface
            compatibility with existing callers.
        name: Model identifier. Supported values are ``"resnet18"``,
            ``"wrn_34_10"``, ``"cifar_resnet18"`` and ``"cifar_wrn_34_10"``.
        num_classes: Number of output classes, forwarded to the network
            constructor.

    Returns:
        A freshly constructed ``ResNet`` or ``WideResNet`` instance.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported
            identifiers. The message names the offending value and lists
            the valid ones.
    """
    # Zero-argument builders, so only the requested network is instantiated.
    # Keeping every supported name in one table also keeps the error message
    # below in sync with what is actually buildable.
    builders = {
        "resnet18": lambda: ResNet(
            BasicBlock,
            [2, 2, 2, 2],
            num_classes=num_classes,
            original=True,
        ),
        # Positional arguments: depth=34, num_classes, widen factor=10.
        "wrn_34_10": lambda: WideResNet(
            34,
            num_classes,
            10,
            dropRate=0.0,
            original=True,
        ),
        # Relies on ResNet's own default for ``original``.
        "cifar_resnet18": lambda: ResNet(
            BasicBlock, [2, 2, 2, 2], num_classes=num_classes
        ),
        "cifar_wrn_34_10": lambda: WideResNet(
            34,
            num_classes,
            10,
            dropRate=0.0,
            original=False,
        ),
    }
    try:
        build = builders[name]
    except KeyError:
        supported = ", ".join(sorted(builders))
        raise NotImplementedError(
            f"Unsupported model name {name!r}; supported names: {supported}"
        ) from None
    return build()


def load_model(
    args,
    name,
    num_classes,
    extract_layers=None,
    is_avg_pool=True,
    is_relu=True,
):
    """Build the network ``name`` and wrap it in a ``FeatureExtractor``.

    Args:
        args: Forwarded unchanged to ``get_model``.
        name: Model identifier understood by ``get_model``.
        num_classes: Number of output classes, forwarded to ``get_model``.
        extract_layers: Layer specification forwarded to ``FeatureExtractor``.
            Presumably these are the layers whose features are exposed;
            TODO confirm against ``FeatureExtractor``. Defaults to a new
            empty list on each call.
        is_avg_pool: Forwarded to ``FeatureExtractor`` as ``is_avg_pool``.
        is_relu: Forwarded to ``FeatureExtractor`` as ``is_relu``.

    Returns:
        The ``FeatureExtractor`` wrapping the constructed network.
    """
    # A ``[]`` default would be one list object shared by every call. If
    # FeatureExtractor stored or mutated it, state would leak between models,
    # so build a fresh list per call instead.
    if extract_layers is None:
        extract_layers = []

    net = get_model(args, name, num_classes)

    net = FeatureExtractor(
        net, extract_layers, is_avg_pool=is_avg_pool, is_relu=is_relu
    )
    return net
