import torch
import torch.nn as nn
from typing import Tuple


def get_vgg2_5_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Create VGG5 module list for PCN (experimental)
    MaxPool moved to end of current module, adaptive FC layer size

    Four conv modules (3x3 conv with padding 1, ReLU, 2x2 max-pool) followed by
    a single linear classifier module.

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        input_size: (height, width) of the input images; used to size the
            final linear layer. Non-square inputs are supported.

    Returns:
        nn.ModuleList of 5 nn.Sequential modules.

    Raises:
        ValueError: If either input dimension is smaller than 16, which would
            leave no spatial extent after the four max-pools.
    """
    # Padding-1 convs keep the spatial size and each 2x2 max-pool floor-halves
    # it, so four pools reduce each dimension to size // 2**4.
    feat_h, feat_w = input_size[0] // 2**4, input_size[1] // 2**4
    if feat_h < 1 or feat_w < 1:
        raise ValueError(
            f"input_size {tuple(input_size)} is too small for VGG5: "
            f"both dimensions must be at least {2**4}"
        )
    modules = nn.ModuleList([
        nn.Sequential(nn.Conv2d(input_channels, 128, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * feat_h * feat_w, num_classes, bias=True)
        )
    ])
    return modules


def get_vgg2_7_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Create VGG7 module list for PCN (experimental)
    MaxPool moved to end of current module, adaptive FC layer size

    Six conv modules (3x3 kernels, ReLU) followed by a linear classifier module.
    modules[0], modules[2] and modules[4] end in a 2x2 max-pool; the convs in
    modules[3] and modules[5] use padding 0 and so shrink each spatial
    dimension by 2.

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        input_size: (height, width) of the input images; used to size the
            final linear layer. Non-square inputs are supported.

    Returns:
        nn.ModuleList of 7 nn.Sequential modules.

    Raises:
        ValueError: If input_size is too small to leave at least a 1x1
            feature map after the conv stack.
    """

    def _final_extent(size: int) -> int:
        # Trace one spatial dimension through the conv stack built below:
        # padding-1 convs keep the size, 2x2 max-pools floor-halve it and
        # padding-0 3x3 convs remove 2.
        size //= 2  # modules[0]: max-pool
        size //= 2  # modules[2]: max-pool
        size -= 2   # modules[3]: conv, padding 0
        size //= 2  # modules[4]: max-pool
        size -= 2   # modules[5]: conv, padding 0
        return size

    feat_h, feat_w = _final_extent(input_size[0]), _final_extent(input_size[1])
    # A final extent >= 1 also guarantees every intermediate pool/conv had
    # enough input, so this single check covers the whole stack.
    if feat_h < 1 or feat_w < 1:
        raise ValueError(
            f"input_size {tuple(input_size)} is too small for VGG7: the conv stack "
            f"would leave a {feat_h}x{feat_w} feature map"
        )
    modules = nn.ModuleList([
        nn.Sequential(nn.Conv2d(input_channels, 128, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True)),

        nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(256, 256, 3, 1, 0), nn.ReLU(inplace=True)),

        nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(512, 512, 3, 1, 0), nn.ReLU(inplace=True)),

        nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * feat_h * feat_w, num_classes, bias=True)
        )
    ])
    return modules


def get_vgg2_9_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Create VGG9 module list for PCN (experimental)
    MaxPool moved to end of current module, adaptive FC layer size

    Eight conv modules (3x3 conv with padding 1, ReLU), every other one ending
    in a 2x2 max-pool, followed by a single linear classifier module.

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        input_size: (height, width) of the input images; used to size the
            final linear layer. Non-square inputs are supported.

    Returns:
        nn.ModuleList of 9 nn.Sequential modules.

    Raises:
        ValueError: If either input dimension is smaller than 16, which would
            leave no spatial extent after the four max-pools.
    """
    # Padding-1 convs keep the spatial size and each of the four 2x2 max-pools
    # floor-halves it, so each dimension ends at size // 2**4.
    feat_h, feat_w = input_size[0] // 2**4, input_size[1] // 2**4
    if feat_h < 1 or feat_w < 1:
        raise ValueError(
            f"input_size {tuple(input_size)} is too small for VGG9: "
            f"both dimensions must be at least {2**4}"
        )
    modules = nn.ModuleList([
        nn.Sequential(nn.Conv2d(input_channels, 128, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True)),

        nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True)),

        nn.Sequential(nn.Conv2d(256, 512, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True)),

        nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)),

        nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True)),

        nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * feat_h * feat_w, num_classes, bias=True)
        )
    ])
    return modules


def _make_vgg_pcn_modules(cfg: Tuple[object, ...], input_channels: int, num_classes: int) -> nn.ModuleList:
    """
    Build a VGG module list for PCN from a layer configuration.

    Each integer in ``cfg`` adds one module ``Conv2d(prev, width, 3, 1, 1)`` +
    ``ReLU(inplace=True)``. The marker ``"M"`` appends ``MaxPool2d(2, 2)`` to
    the module of the conv immediately before it, so pooling sits at the end
    of a conv module rather than in a module of its own. The conv modules are
    followed by three classifier modules:

    * AdaptiveAvgPool2d((7, 7)) + Flatten + Linear(C*7*7, 4096) + ReLU
    * Dropout(0.5) + Linear(4096, 4096) + ReLU
    * Dropout(0.5) + Linear(4096, num_classes)

    where C is the width of the last conv in ``cfg``.

    Layers are instantiated in network order, so a fixed seed gives the same
    initial weights as listing the modules out by hand.

    NOTE(review): torchvision's VGG configs also place a max-pool after the
    final conv stage (five pools in total); the configs passed here pool only
    four times. Confirm this is intentional.

    Args:
        cfg: Conv widths and ``"M"`` pool markers; must start with a width.
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.

    Returns:
        nn.ModuleList with one nn.Sequential per conv plus three classifier
        modules.
    """
    stages = []  # one list of layers per conv module
    channels = input_channels
    for entry in cfg:
        if entry == "M":
            stages[-1].append(nn.MaxPool2d(2, 2))
        else:
            stages.append([nn.Conv2d(channels, entry, 3, 1, 1), nn.ReLU(inplace=True)])
            channels = entry

    modules = [nn.Sequential(*layers) for layers in stages]
    modules.append(nn.Sequential(
        nn.AdaptiveAvgPool2d((7, 7)),
        nn.Flatten(),
        nn.Linear(channels * 7 * 7, 4096),
        nn.ReLU(inplace=True),
    ))
    modules.append(nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
    ))
    modules.append(nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(4096, num_classes),
    ))
    return nn.ModuleList(modules)


def get_vgg2_11_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Create VGG11 module list for PCN (paper version with MaxPool at end)

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        input_size: Unused; accepted so every builder in this module shares one
            signature. AdaptiveAvgPool2d((7, 7)) makes the classifier size
            independent of the input resolution.

    Returns:
        nn.ModuleList of 11 nn.Sequential modules (8 conv + 3 classifier).
    """
    cfg = (64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512)
    return _make_vgg_pcn_modules(cfg, input_channels, num_classes)


def get_vgg2_13_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Create VGG13 module list for PCN (paper version with MaxPool at end)

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        input_size: Unused; accepted so every builder in this module shares one
            signature. AdaptiveAvgPool2d((7, 7)) makes the classifier size
            independent of the input resolution.

    Returns:
        nn.ModuleList of 13 nn.Sequential modules (10 conv + 3 classifier).
    """
    cfg = (64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512)
    return _make_vgg_pcn_modules(cfg, input_channels, num_classes)


def get_vgg2_16_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Create VGG16 module list for PCN (paper version with MaxPool at end)

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        input_size: Unused; accepted so every builder in this module shares one
            signature. AdaptiveAvgPool2d((7, 7)) makes the classifier size
            independent of the input resolution.

    Returns:
        nn.ModuleList of 16 nn.Sequential modules (13 conv + 3 classifier).
    """
    cfg = (64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512)
    return _make_vgg_pcn_modules(cfg, input_channels, num_classes)


def get_vgg2_19_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Create VGG19 module list for PCN (paper version with MaxPool at end)

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        input_size: Unused; accepted so every builder in this module shares one
            signature. AdaptiveAvgPool2d((7, 7)) makes the classifier size
            independent of the input resolution.

    Returns:
        nn.ModuleList of 19 nn.Sequential modules (16 conv + 3 classifier).
    """
    cfg = (
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, 256, "M",
        512, 512, 512, 512, "M",
        512, 512, 512, 512,
    )
    return _make_vgg_pcn_modules(cfg, input_channels, num_classes)
