import torch
import torch.nn as nn
from typing import Tuple


def get_vgg3_5_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the experimental VGG5 variant (four 3x3 convs + one linear) as PCN stages.

    Every stage is an ``nn.Sequential``. The 2x2 max-pool that would normally close a
    conv stage is placed at the start of the following stage instead, so the last
    stage opens with the final pool, then global-average-pools to 1x1 and classifies.

    Args:
        input_channels: Channel count of the input image (``in_channels`` of the first conv).
        num_classes: ``out_features`` of the final ``nn.Linear``.
        input_size: Accepted for signature compatibility; not used by this builder.

    Returns:
        ``nn.ModuleList`` holding 5 ``nn.Sequential`` stages.
    """
    # (out_channels, opens_with_pool) for each conv stage, in construction order.
    conv_plan = ((128, False), (256, True), (512, True), (512, True))

    stages = []
    prev_channels = input_channels
    for width, opens_with_pool in conv_plan:
        layers = [nn.MaxPool2d(2, 2)] if opens_with_pool else []
        layers += [nn.Conv2d(prev_channels, width, 3, 1, 1), nn.ReLU(inplace=True)]
        stages.append(nn.Sequential(*layers))
        prev_channels = width

    # Trailing pool of the last conv stage, then collapse to a 512-d vector.
    stages.append(
        nn.Sequential(
            nn.MaxPool2d(2, 2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(prev_channels, num_classes),
        )
    )
    return nn.ModuleList(stages)


def get_vgg3_7_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the experimental VGG7 variant (six 3x3 convs + one linear) as PCN stages.

    Every stage is an ``nn.Sequential``. Max-pools sit at the start of the stage that
    follows them. The pooled 256->256 and 512->512 convs use padding 0, so each trims
    2 pixels from height and width. With a 32x32 input the feature map is 1x1 before
    the average pool. Smaller inputs (e.g. 24x24) leave a map smaller than the 3x3
    kernel at the last unpadded conv.
    NOTE(review): confirm the minimum supported input size.

    Args:
        input_channels: Channel count of the input image (``in_channels`` of the first conv).
        num_classes: ``out_features`` of the final ``nn.Linear``.
        input_size: Accepted for signature compatibility; not used by this builder.

    Returns:
        ``nn.ModuleList`` holding 7 ``nn.Sequential`` stages.
    """
    # (out_channels, opens_with_pool, padding) for each conv stage, in construction order.
    conv_plan = (
        (128, False, 1),
        (128, True, 1),
        (256, False, 1),
        (256, True, 0),
        (512, False, 1),
        (512, True, 0),
    )

    stages = []
    prev_channels = input_channels
    for width, opens_with_pool, pad in conv_plan:
        layers = [nn.MaxPool2d(2, 2)] if opens_with_pool else []
        layers += [nn.Conv2d(prev_channels, width, 3, 1, pad), nn.ReLU(inplace=True)]
        stages.append(nn.Sequential(*layers))
        prev_channels = width

    # No trailing pool here: global average pooling feeds the 512-d classifier directly.
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(prev_channels, num_classes),
        )
    )
    return nn.ModuleList(stages)


def get_vgg3_9_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the experimental VGG9 variant (eight 3x3 convs + one linear) as PCN stages.

    Convs come in pairs. The second conv of each pair is preceded, inside the same
    ``nn.Sequential``, by a 2x2 max-pool, so there are four pools in total. Global
    average pooling then feeds a single linear classifier.

    Args:
        input_channels: Channel count of the input image (``in_channels`` of the first conv).
        num_classes: ``out_features`` of the final ``nn.Linear``.
        input_size: Accepted for signature compatibility; not used by this builder.

    Returns:
        ``nn.ModuleList`` holding 9 ``nn.Sequential`` stages.
    """
    widths = (128, 128, 256, 256, 512, 512, 512, 512)

    stages = []
    in_ch = input_channels
    for idx, out_ch in enumerate(widths):
        # Odd-indexed stages open with the pool that was moved out of the stage before.
        pool = [nn.MaxPool2d(2, 2)] if idx % 2 else []
        stages.append(nn.Sequential(*pool, nn.Conv2d(in_ch, out_ch, 3, 1, 1), nn.ReLU(inplace=True)))
        in_ch = out_ch

    classifier = nn.Sequential(
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(in_ch, num_classes),
    )
    return nn.ModuleList([*stages, classifier])


def get_vgg3_11_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build VGG11 as a list of PCN stages: eight 3x3 convs (padding 1) plus three
    fully connected stages (512*7*7 -> 4096 -> 4096 -> num_classes).

    Each 2x2 max-pool is placed at the start of the conv stage that follows it. No
    pool follows the last conv stage; ``AdaptiveAvgPool2d((7, 7))`` alone sets the
    spatial size fed to the first linear layer. For example, a 32x32 input is 2x2
    after the four pools and is resampled to 7x7.

    Args:
        input_channels: Channel count of the input image (``in_channels`` of the first conv).
        num_classes: ``out_features`` of the final ``nn.Linear``.
        input_size: Accepted for signature compatibility; not used by this builder.

    Returns:
        ``nn.ModuleList`` holding 11 ``nn.Sequential`` stages.
    """
    pool_marker = "M"  # marks "the next conv stage opens with a 2x2 max-pool"
    layout = [64, pool_marker, 128, pool_marker, 256, 256, pool_marker, 512, 512, pool_marker, 512, 512]

    modules = []
    channels = input_channels
    pool_pending = False
    for entry in layout:
        if entry == pool_marker:
            pool_pending = True
            continue
        prefix = [nn.MaxPool2d(2, 2)] if pool_pending else []
        modules.append(nn.Sequential(*prefix, nn.Conv2d(channels, entry, 3, 1, 1), nn.ReLU(inplace=True)))
        channels = entry
        pool_pending = False

    hidden = 4096
    modules.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(channels * 7 * 7, hidden),
            nn.ReLU(inplace=True),
        )
    )
    modules.append(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True)))
    modules.append(nn.Sequential(nn.Linear(hidden, num_classes)))
    return nn.ModuleList(modules)


def get_vgg3_13_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build VGG13 as a list of PCN stages: ten 3x3 convs (padding 1) plus three
    fully connected stages (512*7*7 -> 4096 -> 4096 -> num_classes).

    Each 2x2 max-pool is placed at the start of the conv stage that follows it. No
    pool follows the last conv stage; ``AdaptiveAvgPool2d((7, 7))`` alone sets the
    spatial size fed to the first linear layer.

    Args:
        input_channels: Channel count of the input image (``in_channels`` of the first conv).
        num_classes: ``out_features`` of the final ``nn.Linear``.
        input_size: Accepted for signature compatibility; not used by this builder.

    Returns:
        ``nn.ModuleList`` holding 13 ``nn.Sequential`` stages.
    """
    pool_marker = "M"  # marks "the next conv stage opens with a 2x2 max-pool"
    layout = [64, 64, pool_marker, 128, 128, pool_marker, 256, 256, pool_marker, 512, 512, pool_marker, 512, 512]

    modules = []
    channels = input_channels
    pool_pending = False
    for entry in layout:
        if entry == pool_marker:
            pool_pending = True
            continue
        prefix = [nn.MaxPool2d(2, 2)] if pool_pending else []
        modules.append(nn.Sequential(*prefix, nn.Conv2d(channels, entry, 3, 1, 1), nn.ReLU(inplace=True)))
        channels = entry
        pool_pending = False

    hidden = 4096
    modules.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(channels * 7 * 7, hidden),
            nn.ReLU(inplace=True),
        )
    )
    modules.append(nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True)))
    modules.append(nn.Sequential(nn.Linear(hidden, num_classes)))
    return nn.ModuleList(modules)


def get_vgg3_16_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build VGG16 as a list of PCN stages: thirteen 3x3 convs (padding 1) in five
    resolution groups, plus three fully connected stages
    (512*7*7 -> 4096 -> 4096 -> num_classes).

    The first conv stage of every group after the first opens with a 2x2 max-pool.
    No pool follows the last group; ``AdaptiveAvgPool2d((7, 7))`` alone sets the
    spatial size fed to the first linear layer.

    Args:
        input_channels: Channel count of the input image (``in_channels`` of the first conv).
        num_classes: ``out_features`` of the final ``nn.Linear``.
        input_size: Accepted for signature compatibility; not used by this builder.

    Returns:
        ``nn.ModuleList`` holding 16 ``nn.Sequential`` stages.
    """
    # (out_channels, number of conv stages) for each resolution group.
    groups = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    stages = []
    in_ch = input_channels
    for group_idx, (out_ch, depth) in enumerate(groups):
        for conv_idx in range(depth):
            layers = []
            if group_idx > 0 and conv_idx == 0:
                # Pool that ends the previous group, moved to the start of this one.
                layers.append(nn.MaxPool2d(2, 2))
            layers.append(nn.Conv2d(in_ch, out_ch, 3, 1, 1))
            layers.append(nn.ReLU(inplace=True))
            stages.append(nn.Sequential(*layers))
            in_ch = out_ch

    width = 4096
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(in_ch * 7 * 7, width),
            nn.ReLU(inplace=True),
        )
    )
    stages.append(nn.Sequential(nn.Linear(width, width), nn.ReLU(inplace=True)))
    stages.append(nn.Sequential(nn.Linear(width, num_classes)))
    return nn.ModuleList(stages)


def get_vgg3_19_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build VGG19 as a list of PCN stages: sixteen 3x3 convs (padding 1) in five
    resolution groups, plus three fully connected stages
    (512*7*7 -> 4096 -> 4096 -> num_classes).

    The first conv stage of every group after the first opens with a 2x2 max-pool.
    No pool follows the last group; ``AdaptiveAvgPool2d((7, 7))`` alone sets the
    spatial size fed to the first linear layer.

    Args:
        input_channels: Channel count of the input image (``in_channels`` of the first conv).
        num_classes: ``out_features`` of the final ``nn.Linear``.
        input_size: Accepted for signature compatibility; not used by this builder.

    Returns:
        ``nn.ModuleList`` holding 19 ``nn.Sequential`` stages.
    """
    # (out_channels, number of conv stages) for each resolution group.
    groups = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))

    stages = []
    in_ch = input_channels
    for group_idx, (out_ch, depth) in enumerate(groups):
        for conv_idx in range(depth):
            layers = []
            if group_idx > 0 and conv_idx == 0:
                # Pool that ends the previous group, moved to the start of this one.
                layers.append(nn.MaxPool2d(2, 2))
            layers.append(nn.Conv2d(in_ch, out_ch, 3, 1, 1))
            layers.append(nn.ReLU(inplace=True))
            stages.append(nn.Sequential(*layers))
            in_ch = out_ch

    width = 4096
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(in_ch * 7 * 7, width),
            nn.ReLU(inplace=True),
        )
    )
    stages.append(nn.Sequential(nn.Linear(width, width), nn.ReLU(inplace=True)))
    stages.append(nn.Sequential(nn.Linear(width, num_classes)))
    return nn.ModuleList(stages)
