from .resnet import resnet18, resnet20, resnet32, resnet34, resnet44, resnet56
from .vgg import VGG9, ConvNet4, ConvNet5, ConvNet6

def get_model(model_name: str, img_size):
    """Build a model instance from its string name.

    Names containing ``"conv"`` must end in an ``_<int>`` suffix that gives
    the channel count. For example, ``"conv4_64"`` builds
    ``ConvNet4(conv_channels=64, img_size=img_size)``. The other supported
    names (``vgg9``, ``resnet18``, ``resnet20``, ``resnet32``, ``resnet34``,
    ``resnet44``, ``resnet56``) are built with no arguments.

    Args:
        model_name: Architecture name, with the ``_<int>`` channel suffix
            when the name contains ``"conv"``.
        img_size: Forwarded to ``ConvNet4`` and ``ConvNet5`` only; the other
            architectures do not receive it.

    Returns:
        The newly constructed model object.

    Raises:
        DeprecationWarning: If ``model_name`` is ``"lenet"``, which is no
            longer supported.
        ValueError: If a name containing ``"conv"`` does not end in an
            integer ``_<channels>`` suffix.
        NotImplementedError: If ``model_name`` is not a known architecture.
    """
    # Placeholder; only the conv builders below read this value.
    conv_channels = -1
    if "conv" in model_name:
        # Split "conv4_64" into base "conv4" and channel suffix "64".
        base, _, suffix = model_name.rpartition("_")
        try:
            conv_channels = int(suffix)
        except ValueError as err:
            raise ValueError(
                f"model name {model_name!r} must end with an integer "
                "channel suffix, e.g. 'conv4_64'"
            ) from err
        model_name = base

    if model_name == "lenet":
        raise DeprecationWarning(
            "the 'lenet' model is deprecated and no longer supported"
        )

    # Zero-argument builders: lambdas defer the constructor lookup so only
    # the selected architecture is ever touched.
    builders = {
        "conv4": lambda: ConvNet4(conv_channels=conv_channels,
                                  img_size=img_size),
        "conv5": lambda: ConvNet5(conv_channels=conv_channels,
                                  img_size=img_size),
        # NOTE(review): unlike conv4/conv5, img_size is not forwarded here;
        # confirm ConvNet6 does not need it.
        "conv6": lambda: ConvNet6(conv_channels=conv_channels),
        "vgg9": lambda: VGG9(),
        "resnet18": lambda: resnet18(),
        "resnet20": lambda: resnet20(),
        "resnet32": lambda: resnet32(),
        "resnet34": lambda: resnet34(),
        "resnet44": lambda: resnet44(),
        "resnet56": lambda: resnet56(),
    }
    try:
        build = builders[model_name]
    except KeyError:
        raise NotImplementedError(
            f"unknown model name: {model_name!r}"
        ) from None
    return build()