from torch import nn


def _dataset_config(dataset):
    """Return ``(num_classes, img_size)`` for a supported dataset name.

    Raises:
        ValueError: if ``dataset`` is not "CIFAR10", "CIFAR100" or "tiny-imagenet".
    """
    if dataset == "CIFAR10":
        return 10, 32
    if dataset == "CIFAR100":
        return 100, 32
    if dataset == "tiny-imagenet":
        return 200, 64
    raise ValueError("Unsupported dataset. Only CIFAR10, CIFAR100 and tiny-imagenet are supported.")


def _activation_class(activation):
    """Return the ``torch.nn`` activation class matching an activation name.

    Raises:
        ValueError: if ``activation`` is not "relu", "gelu", "tanh" or "leaky_relu".
    """
    if activation == "relu":
        return nn.ReLU
    if activation == "gelu":
        return nn.GELU
    if activation == "tanh":
        return nn.Tanh
    if activation == "leaky_relu":
        return nn.LeakyReLU
    raise ValueError("Unsupported activation function. Only relu, gelu, tanh and leaky_relu are supported.")


def get_cnn_architecture(model_name, dataset, activation):
    """Build a CNN classifier as a list of ``nn.Sequential`` stages.

    Each element of the returned list is one stage. Chaining them in order
    (e.g. ``nn.Sequential(*architecture)``) yields the full classifier, which
    maps a ``(batch, 3, img_size, img_size)`` image tensor to
    ``(batch, num_classes)`` logits. The "...Input", "...Middle" and
    "...Output" variants contain the same layers as their base model but
    grouped into different stage boundaries.

    Args:
        model_name: One of "SmallVGG", "SmallVGGOutput", "SmallVGGInput",
            "SmallVGGMiddle", "VGG5", "VGG5Output", "VGG7", "VGG9", "VGG11",
            "VGG13", "VGG16".
        dataset: "CIFAR10" (10 classes, 32px), "CIFAR100" (100 classes, 32px)
            or "tiny-imagenet" (200 classes, 64px).
        activation: "relu", "gelu", "tanh" or "leaky_relu".

    Returns:
        list[nn.Sequential]: the model stages in forward order.

    Raises:
        ValueError: for an unsupported dataset, activation or model name
            (checked in that order).
    """
    num_classes, img_size = _dataset_config(dataset)
    activation = _activation_class(activation)

    # Flattened feature counts fed to the first Linear layer (512 channels
    # times the side length of the final feature map squared).
    # Five halvings: five stride-2 3x3 convs with padding 1, or five 2x2 pools.
    flat_div32 = 512 * (img_size // 2**5) ** 2
    # Four halvings: four 2x2 max-pools.
    flat_div16 = 512 * (img_size // 2**4) ** 2
    # VGG7: three 2x2 pools plus two unpadded 3x3 convs (each trims 2 px):
    # n -> n/4 -> n/4 - 2 -> (n/4 - 2)/2 -> (n/4 - 2)/2 - 2.
    # That is 1x1 for 32px inputs but 5x5 (not 2x2) for 64px inputs.
    flat_vgg7 = 512 * ((img_size // 4 - 2) // 2 - 2) ** 2

    if model_name == "SmallVGG":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 32, 3, 2, 1), activation()),
            nn.Sequential(nn.Conv2d(32, 64, 3, 2, 1), activation()),
            nn.Sequential(nn.Conv2d(64, 128, 3, 2, 1), activation()),
            nn.Sequential(nn.Conv2d(128, 256, 3, 2, 1), activation()),
            nn.Sequential(nn.Conv2d(256, 512, 3, 2, 1), activation()),
            nn.Sequential(nn.Flatten(), nn.Linear(flat_div32, num_classes, bias=True)),
        ]
    elif model_name == "SmallVGGOutput":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 32, 3, 2, 1), activation(),
                          nn.Conv2d(32, 64, 3, 2, 1), activation(),
                          nn.Conv2d(64, 128, 3, 2, 1), activation(),
                          nn.Conv2d(128, 256, 3, 2, 1), activation(),
                          nn.Conv2d(256, 512, 3, 2, 1), activation(), nn.Flatten()),
            nn.Sequential(nn.Linear(flat_div32, num_classes, bias=True)),
        ]
    elif model_name == "SmallVGGInput":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 32, 3, 2, 1), activation()),
            nn.Sequential(nn.Conv2d(32, 64, 3, 2, 1), activation(),
                          nn.Conv2d(64, 128, 3, 2, 1), activation(),
                          nn.Conv2d(128, 256, 3, 2, 1), activation(),
                          nn.Conv2d(256, 512, 3, 2, 1), activation(),
                          nn.Flatten(), nn.Linear(flat_div32, num_classes, bias=True)),
        ]
    elif model_name == "SmallVGGMiddle":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 32, 3, 2, 1), activation(),
                          nn.Conv2d(32, 64, 3, 2, 1), activation(),
                          nn.Conv2d(64, 128, 3, 2, 1), activation()),
            nn.Sequential(nn.Conv2d(128, 256, 3, 2, 1), activation(),
                          nn.Conv2d(256, 512, 3, 2, 1), activation(),
                          nn.Flatten(), nn.Linear(flat_div32, num_classes, bias=True)),
        ]
    elif model_name == "VGG5":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 128, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Flatten(), nn.Linear(flat_div16, num_classes, bias=True)),
        ]
    elif model_name == "VGG5Output":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 128, 3, 1, 1), activation(), nn.MaxPool2d(2, 2),
                          nn.Conv2d(128, 256, 3, 1, 1), activation(), nn.MaxPool2d(2, 2),
                          nn.Conv2d(256, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2),
                          nn.Conv2d(512, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2), nn.Flatten()),
            nn.Sequential(nn.Linear(flat_div16, num_classes, bias=True)),
        ]
    elif model_name == "VGG7":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 128, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), activation()),
            nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            # padding=0: shrinks each side by 2.
            nn.Sequential(nn.Conv2d(256, 256, 3, 1, 0), activation()),
            nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            # padding=0: shrinks each side by 2.
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 0), activation()),
            nn.Sequential(nn.Flatten(), nn.Linear(flat_vgg7, num_classes, bias=True)),
        ]
    elif model_name == "VGG9":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 128, 3, 1, 1), nn.BatchNorm2d(128), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), activation()),
            nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), activation()),
            nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), activation()),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), activation()),
            nn.Sequential(nn.Flatten(), nn.Linear(flat_div16, num_classes, bias=True)),
        ]
    elif model_name == "VGG11":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 64, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), activation()),
            nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), activation()),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), activation()),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Flatten(), nn.Linear(flat_div32, num_classes)),
        ]
    elif model_name == "VGG13":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 64, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), activation()),
            nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), activation()),
            nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), activation()),
            nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), activation()),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), activation()),
            nn.Sequential(nn.Flatten(), nn.Linear(flat_div32, flat_div32), activation()),
            nn.Sequential(nn.Linear(flat_div32, 1000), activation()),
            nn.Sequential(nn.Linear(1000, num_classes)),
        ]
    elif model_name == "VGG16":
        architecture = [
            nn.Sequential(nn.Conv2d(3, 64, 3, 1, 1), nn.BatchNorm2d(64), activation()),
            nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), activation()),
            nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), activation()),
            nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), activation()),
            nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), activation()),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), activation()),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), activation()),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), activation()),
            nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), activation(), nn.MaxPool2d(2, 2)),
            nn.Sequential(nn.Flatten(), nn.Linear(flat_div32, num_classes)),
        ]
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    return architecture

