import torch
import torch.nn as nn
from typing import Tuple


def get_vgg5_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the VGG5 variant used by the PCN as an ordered list of stages.

    The network has four 3x3/stride-1 convolutions with widths
    [128, 256, 512, 512], followed by one linear classifier.
    - The second and third stages open with a 2x2 max-pool.
    - The final convolution is unpadded, so it trims one pixel from every
      border.
    - Each convolution is followed by an in-place ReLU.

    Args:
        input_channels: Number of channels in the input image.
        num_classes: Output width of the final Linear layer.
        input_size: Part of the shared factory signature; not read here. The
            head pools with AdaptiveAvgPool2d((1, 1)), so its size does not
            depend on resolution. The input must still be large enough for
            the unpadded convolution after two pools.

    Returns:
        An nn.ModuleList of five nn.Sequential stages (four conv, one head).
    """
    def conv_stage(c_in: int, c_out: int, pooled: bool = False, pad: int = 1) -> nn.Sequential:
        # [optional 2x2 max-pool] -> 3x3 conv (stride 1) -> in-place ReLU
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pooled else []
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=pad))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    stages = [
        conv_stage(input_channels, 128),
        conv_stage(128, 256, pooled=True),
        conv_stage(256, 512, pooled=True),
        conv_stage(512, 512, pad=0),  # unpadded: e.g. 8x8 -> 6x6 for a 32x32 input
    ]
    # Single-layer classifier on globally averaged features
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )
    )
    return nn.ModuleList(stages)


def get_vgg7_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the VGG7 variant used by the PCN as an ordered list of stages.

    Six 3x3/stride-1 convolutions with widths [128, 128, 256, 256, 512, 512]
    feed a single linear classifier.
    - The third and fifth stages open with a 2x2 max-pool.
    - The fourth and sixth convolutions are unpadded, so each reduces height
      and width by 2.
    - Every convolution is followed by an in-place ReLU.

    Args:
        input_channels: Number of channels in the input image.
        num_classes: Output width of the final Linear layer.
        input_size: Part of the shared factory signature; not read here. The
            head pools with AdaptiveAvgPool2d((1, 1)).

    Returns:
        An nn.ModuleList of seven nn.Sequential stages (six conv, one head).
    """
    def conv_stage(c_in: int, c_out: int, pooled: bool = False, pad: int = 1) -> nn.Sequential:
        # [optional 2x2 max-pool] -> 3x3 conv (stride 1) -> in-place ReLU
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pooled else []
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=pad))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    stages = [
        # 128-channel level
        conv_stage(input_channels, 128),
        conv_stage(128, 128),
        # 256-channel level
        conv_stage(128, 256, pooled=True),
        conv_stage(256, 256, pad=0),
        # 512-channel level
        conv_stage(256, 512, pooled=True),
        conv_stage(512, 512, pad=0),
    ]
    # Single-layer classifier on globally averaged features
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )
    )
    return nn.ModuleList(stages)


def get_vgg9_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the VGG9 variant used by the PCN as an ordered list of stages.

    There are six padded 3x3/stride-1 convolutions with widths
    [64, 128, 256, 256, 512, 512], followed by a three-layer classifier
    (512 -> 512 -> 512 -> num_classes).
    - The first convolution of every level after the first is preceded by a
      2x2 max-pool.
    - Each convolution and hidden Linear is followed by an in-place ReLU.

    Args:
        input_channels: Number of channels in the input image.
        num_classes: Output width of the final Linear layer.
        input_size: Part of the shared factory signature; not read here. The
            head pools with AdaptiveAvgPool2d((1, 1)).

    Returns:
        An nn.ModuleList of nine nn.Sequential stages (six conv, three linear).
    """
    # (output channels, number of convolutions) for each resolution level
    plan = ((64, 1), (128, 1), (256, 2), (512, 2))

    stages = []
    c_in = input_channels
    for level, (width, depth) in enumerate(plan):
        for k in range(depth):
            # Entering a new (non-first) level halves the spatial size first.
            downsample = level > 0 and k == 0
            layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if downsample else []
            layers += [nn.Conv2d(c_in, width, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True)]
            stages.append(nn.Sequential(*layers))
            c_in = width

    # Classifier: global average pool + flatten, then 512 -> 512 -> 512 -> num_classes
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
        )
    )
    stages.append(nn.Sequential(nn.Linear(512, 512), nn.ReLU(inplace=True)))
    stages.append(nn.Sequential(nn.Linear(512, num_classes)))
    return nn.ModuleList(stages)


def get_vgg11_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the VGG11 variant used by the PCN as an ordered list of stages.

    There are eight padded 3x3/stride-1 convolutions with widths
    [64, 128, 256, 256, 512, 512, 512, 512], followed by a three-layer
    classifier (512 -> 512 -> 256 -> num_classes).
    - The first convolution of every level after the first is preceded by a
      2x2 max-pool.
    - Each convolution and hidden Linear is followed by an in-place ReLU.

    Args:
        input_channels: Number of channels in the input image.
        num_classes: Output width of the final Linear layer.
        input_size: Part of the shared factory signature; not read here. The
            head pools with AdaptiveAvgPool2d((1, 1)).

    Returns:
        An nn.ModuleList of eleven nn.Sequential stages (eight conv, three linear).
    """
    # (output channels, number of convolutions) for each resolution level
    plan = ((64, 1), (128, 1), (256, 2), (512, 4))

    stages = []
    c_in = input_channels
    for level, (width, depth) in enumerate(plan):
        for k in range(depth):
            # Entering a new (non-first) level halves the spatial size first.
            downsample = level > 0 and k == 0
            layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if downsample else []
            layers += [nn.Conv2d(c_in, width, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True)]
            stages.append(nn.Sequential(*layers))
            c_in = width

    # Classifier: global average pool + flatten, then 512 -> 512 -> 256 -> num_classes
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
        )
    )
    stages.append(nn.Sequential(nn.Linear(512, 256), nn.ReLU(inplace=True)))
    stages.append(nn.Sequential(nn.Linear(256, num_classes)))
    return nn.ModuleList(stages)


def get_vgg13_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the VGG13 variant used by the PCN as an ordered list of stages.

    There are twelve padded 3x3/stride-1 convolutions: four each at 128, 256
    and 512 channels. A single linear classifier follows.
    - The first convolution at 256 and at 512 channels is preceded by a 2x2
      max-pool.
    - Every convolution is followed by an in-place ReLU.

    Args:
        input_channels: Number of channels in the input image.
        num_classes: Output width of the final Linear layer.
        input_size: Part of the shared factory signature; not read here. The
            head pools with AdaptiveAvgPool2d((1, 1)).

    Returns:
        An nn.ModuleList of thirteen nn.Sequential stages (twelve conv, one head).
    """
    # Per-convolution output widths; "M" marks a 2x2 max-pool that is placed
    # at the front of the next convolution's stage (not a stage of its own).
    cfg = [128, 128, 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512]

    stages = []
    c_in = input_channels
    pool_pending = False
    for entry in cfg:
        if entry == "M":
            pool_pending = True
            continue
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool_pending else []
        layers += [nn.Conv2d(c_in, entry, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True)]
        stages.append(nn.Sequential(*layers))
        c_in, pool_pending = entry, False

    # Single-layer classifier on globally averaged features
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes),
        )
    )
    return nn.ModuleList(stages)


def get_vgg16_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the VGG16 variant used by the PCN as an ordered list of stages.

    There are thirteen padded 3x3/stride-1 convolutions in five levels
    (64x2, 128x2, 256x3, 512x3, 512x3), followed by a three-layer classifier
    (512 -> 512 -> 512 -> num_classes).
    - Every level after the first starts with a 2x2 max-pool.
    - Each convolution and hidden Linear is followed by an in-place ReLU.

    Args:
        input_channels: Number of channels in the input image.
        num_classes: Output width of the final Linear layer.
        input_size: Part of the shared factory signature; not read here. The
            head pools with AdaptiveAvgPool2d((1, 1)).

    Returns:
        An nn.ModuleList of sixteen nn.Sequential stages (thirteen conv, three linear).
    """
    # Per-convolution output widths; "M" marks a 2x2 max-pool that is placed
    # at the front of the next convolution's stage (not a stage of its own).
    cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512]

    stages = []
    c_in = input_channels
    pool_pending = False
    for entry in cfg:
        if entry == "M":
            pool_pending = True
            continue
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool_pending else []
        layers += [nn.Conv2d(c_in, entry, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True)]
        stages.append(nn.Sequential(*layers))
        c_in, pool_pending = entry, False

    # Classifier: global average pool + flatten, then 512 -> 512 -> 512 -> num_classes
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
        )
    )
    stages.append(nn.Sequential(nn.Linear(512, 512), nn.ReLU(inplace=True)))
    stages.append(nn.Sequential(nn.Linear(512, num_classes)))
    return nn.ModuleList(stages)


def get_vgg19_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the VGG19 variant used by the PCN as an ordered list of stages.

    There are sixteen padded 3x3/stride-1 convolutions in five levels
    (64x2, 128x2, 256x4, 512x4, 512x4), followed by a three-layer classifier
    (512 -> 512 -> 512 -> num_classes).
    - Every level after the first starts with a 2x2 max-pool.
    - Each convolution and hidden Linear is followed by an in-place ReLU.

    Args:
        input_channels: Number of channels in the input image.
        num_classes: Output width of the final Linear layer.
        input_size: Part of the shared factory signature; not read here. The
            head pools with AdaptiveAvgPool2d((1, 1)).

    Returns:
        An nn.ModuleList of nineteen nn.Sequential stages (sixteen conv, three linear).
    """
    # Per-convolution output widths; "M" marks a 2x2 max-pool that is placed
    # at the front of the next convolution's stage (not a stage of its own).
    cfg = [
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, 256, "M",
        512, 512, 512, 512, "M",
        512, 512, 512, 512,
    ]

    stages = []
    c_in = input_channels
    pool_pending = False
    for entry in cfg:
        if entry == "M":
            pool_pending = True
            continue
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool_pending else []
        layers += [nn.Conv2d(c_in, entry, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True)]
        stages.append(nn.Sequential(*layers))
        c_in, pool_pending = entry, False

    # Classifier: global average pool + flatten, then 512 -> 512 -> 512 -> num_classes
    stages.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
        )
    )
    stages.append(nn.Sequential(nn.Linear(512, 512), nn.ReLU(inplace=True)))
    stages.append(nn.Sequential(nn.Linear(512, num_classes)))
    return nn.ModuleList(stages)
