from .lenet import LeNet
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201
from .efficientnet import EfficientNet_B0, EfficientNet_B1, EfficientNet_B2, EfficientNet_B3, EfficientNet_B4, EfficientNet_B5, EfficientNet_B6, EfficientNet_B7, EfficientNet_V2_S, EfficientNet_V2_M, EfficientNet_V2_L
from .alexnet import AlexNet
from .convnext import ConvNeXt_Tiny, ConvNeXt_Small, ConvNeXt_Base, ConvNeXt_Large
from .googlenet import GoogLeNet
from .vgg import VGG11, VGG13, VGG16, VGG19
from .mobilenet import MobileNet_V2, MobileNet_V3_Small, MobileNet_V3_Large


# Registry of every buildable architecture, keyed by the public name string
# that ``get_model`` looks up. Each value is one of the constructors imported
# at the top of this module. Entries are grouped by model family; insertion
# order is preserved.
MODEL_DICT = dict(
    # Classic small CNN
    LeNet=LeNet,
    # ResNet family
    ResNet18=ResNet18,
    ResNet34=ResNet34,
    ResNet50=ResNet50,
    ResNet101=ResNet101,
    ResNet152=ResNet152,
    # DenseNet family
    DenseNet121=DenseNet121,
    DenseNet161=DenseNet161,
    DenseNet169=DenseNet169,
    DenseNet201=DenseNet201,
    # EfficientNet (V1 B0-B7, then V2 S/M/L)
    EfficientNet_B0=EfficientNet_B0,
    EfficientNet_B1=EfficientNet_B1,
    EfficientNet_B2=EfficientNet_B2,
    EfficientNet_B3=EfficientNet_B3,
    EfficientNet_B4=EfficientNet_B4,
    EfficientNet_B5=EfficientNet_B5,
    EfficientNet_B6=EfficientNet_B6,
    EfficientNet_B7=EfficientNet_B7,
    EfficientNet_V2_S=EfficientNet_V2_S,
    EfficientNet_V2_M=EfficientNet_V2_M,
    EfficientNet_V2_L=EfficientNet_V2_L,
    # AlexNet
    AlexNet=AlexNet,
    # ConvNeXt family
    ConvNeXt_Tiny=ConvNeXt_Tiny,
    ConvNeXt_Small=ConvNeXt_Small,
    ConvNeXt_Base=ConvNeXt_Base,
    ConvNeXt_Large=ConvNeXt_Large,
    # GoogLeNet
    GoogLeNet=GoogLeNet,
    # VGG family
    VGG11=VGG11,
    VGG13=VGG13,
    VGG16=VGG16,
    VGG19=VGG19,
    # MobileNet family
    MobileNet_V2=MobileNet_V2,
    MobileNet_V3_Small=MobileNet_V3_Small,
    MobileNet_V3_Large=MobileNet_V3_Large,
)


def get_model(model, input_shape, num_classes, device="cuda", verbose=True):
    """Construct a registered model by name and move it to ``device``.

    Args:
        model: Registry name of the architecture; must be a key of
            ``MODEL_DICT`` (e.g. ``"ResNet18"``).
        input_shape: Iterable of dimensions. It is passed to the model
            constructor as its first positional argument and rendered as
            ``AxBxC`` in the verbose log line. (Presumably channel/spatial
            dims — confirm against the model classes.)
        num_classes: Passed to the model constructor as its second positional
            argument.
        device: Argument forwarded to the constructed model's ``.to()`` call.
        verbose: If true, print a one-line summary before constructing.

    Returns:
        Whatever ``cls(input_shape, num_classes).to(device)`` returns for the
        registered constructor.

    Raises:
        NotImplementedError: If ``model`` is not a key of ``MODEL_DICT``. The
            message lists the registered names.
    """
    # Resolve the name before logging, so an unknown name fails immediately
    # instead of first printing a misleading "Constructing ..." line.
    try:
        cls = MODEL_DICT[model]
    except KeyError:
        available = ", ".join(MODEL_DICT)
        # Keep NotImplementedError for backward compatibility with callers
        # that catch it. Suppress the internal KeyError context.
        raise NotImplementedError(
            f"Unknown model {model!r}; available models: {available}"
        ) from None

    if verbose:
        print(f"Constructing {model} with input shape {'x'.join(map(str, input_shape))} and {num_classes} classes ...")

    return cls(input_shape, num_classes).to(device)