import torch
import torch.backends.cudnn as cudnn

# import all archs
from torchvision.models.resnet import resnet18, resnet34, resnet50
from multiquery_randomized_smoothing.src.models.cifar_resnet import resnet as resnet_cifar
from multiquery_randomized_smoothing.src.models.model_arch_test import model_test
from multiquery_randomized_smoothing.src.models.modified_resnet_adam import modified_resnet_adam
from multiquery_randomized_smoothing.src.dataset_utils import get_normalize_layer, PreProcessLayer

# Architecture names accepted by ``get_architecture``.
# The original first five entries keep their positions so index-based users
# are unaffected; the remaining names are the other branches that
# ``get_architecture`` actually builds.
# NOTE(review): "simple_cnn" and "STNSmoothed" are not constructed by
# ``get_architecture`` in this module; they are kept for backward
# compatibility (callers may validate against this list) — confirm where
# they are built before removing them.
ARCHITECTURES = [
    "resnet50",
    "cifar_resnet20",
    "cifar_resnet110",
    "simple_cnn",
    "STNSmoothed",
    "resnet18",
    "resnet34",
    "modified_resnet_adam",
    "unet",
]

def get_architecture(arch: str = "cifar_resnet110",
                     prepend_preprocess_layer: bool = False,
                     prepend_normalize_layer: bool = False,
                     dataset: str = "cifar10",
                     input_size: int = 32,  # required for when we adjust in_features of fc layer of cifar_resnet110
                     input_channels: int = 3,
                     num_classes: int = 10,
                     mask_output: str = "penalty") -> torch.nn.Module:
    """Return a neural network (with random weights).

    :param arch: the architecture name. Built here: "resnet18", "resnet34",
        "resnet50" (only together with ``dataset="imagenet"``),
        "cifar_resnet20", "cifar_resnet110", "modified_resnet_adam", "unet".
    :param prepend_preprocess_layer: if True, wrap the model so that
        ``PreProcessLayer(prob_flip=0.5)`` runs first, before any
        normalization and before the network.
    :param prepend_normalize_layer: if True, wrap the model so that the layer
        returned by ``get_normalize_layer(dataset)`` runs before the network.
    :param dataset: the dataset name; selects whether the "resnet50" branch
        is available and which normalization layer is used. Presumably should
        be in the datasets.DATASETS list — verify against callers.
    :param input_size: forwarded to ``resnet_cifar`` for "cifar_resnet110"
        only (used there to size the fc layer).
    :param input_channels: currently unused by every branch.
    :param num_classes: forwarded to "cifar_resnet20" and "cifar_resnet110"
        only; the other branches do not pass it to their constructors.
    :param mask_output: forwarded to ``modified_resnet_adam`` only.
    :return: a PyTorch module
    :raises ValueError: if ``arch`` (or the ``arch``/``dataset`` combination)
        is not supported.
    """
    if arch == "resnet18":
        model = resnet18()
    elif arch == "resnet34":
        model = resnet34()
    elif arch == "resnet50" and dataset == "imagenet":
        # ``resnet50()`` builds random weights; the legacy ``pretrained=False``
        # keyword means the same thing but is deprecated in torchvision >= 0.13.
        model = torch.nn.DataParallel(resnet50())
        # Global side effect: let cuDNN autotune convolution algorithms.
        cudnn.benchmark = True
    elif arch == "cifar_resnet20":
        model = resnet_cifar(depth=20, num_classes=num_classes)
    elif arch == "cifar_resnet110":
        model = resnet_cifar(depth=110, input_size=input_size, num_classes=num_classes)
    elif arch == "modified_resnet_adam":
        model = modified_resnet_adam(mask_output=mask_output, depth=110)
    elif arch == "unet":
        model = model_test(idx=0)
    else:
        # Fail fast with a clear message instead of reaching the wrapping
        # code below with ``model`` never assigned.
        raise ValueError(
            f"Unsupported architecture {arch!r} for dataset {dataset!r}"
        )

    if prepend_normalize_layer:
        normalize_layer = get_normalize_layer(dataset)
        model = torch.nn.Sequential(normalize_layer, model)

    if prepend_preprocess_layer:
        # Wrapped after normalization, so preprocessing is the outermost
        # (first-executed) stage of the resulting Sequential.
        preprocess_layer = PreProcessLayer(prob_flip=0.5)
        model = torch.nn.Sequential(preprocess_layer, model)

    return model
