import torch
import torch.nn as nn
from typing import Tuple


def get_shallow_cnn_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the shallow CNN used by the PCN as a list of per-layer modules.

    The first four stages are 3x3 convolutions (stride 1, padding 1, 128
    output channels) each followed by an in-place ReLU, so they leave the
    spatial size unchanged. The final stage max-pools by 2, flattens and
    projects to ``num_classes`` outputs.

    Args:
        input_channels: Channels of the input image (1 for MNIST, 3 for CIFAR/ImageNet)
        num_classes: Number of output units of the final Linear layer
        input_size: (height, width) of the input image; only used to size the
            final Linear layer, whose in-features are
            128 * (height // 2) * (width // 2)

    Returns:
        nn.ModuleList: five nn.Sequential stages, one per PCN layer
    """
    channels = 128
    pooled_h = input_size[0] // 2  # MaxPool2d(2, 2) floors odd sizes
    pooled_w = input_size[1] // 2

    stages = []
    prev_channels = input_channels
    # Modules are created in the same order as a literal listing would, so
    # parameter initialisation consumes the RNG identically.
    for _ in range(4):
        stages.append(nn.Sequential(nn.Conv2d(prev_channels, channels, 3, 1, 1), nn.ReLU(inplace=True)))
        prev_channels = channels

    classifier = nn.Sequential(
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(channels * pooled_h * pooled_w, num_classes),
    )
    stages.append(classifier)

    return nn.ModuleList(stages)


def get_deep_cnn_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build a deeper, VGG-style CNN for the PCN as a list of per-layer modules.

    There are four stages of 3x3 convolutions (stride 1, padding 1, each
    followed by an in-place ReLU). Each stage is followed by its own 2x2 max-pool
    module. The stages have 2, 2, 3 and 3 convolutions with 64, 128, 256 and
    512 channels respectively. The head applies global average pooling,
    flattens, and projects the 512 features to ``num_classes`` outputs.

    Because the head uses AdaptiveAvgPool2d, the layer sizes do not depend on
    the input resolution. The four 2x2 pools mean each spatial dimension must
    be at least 16 for the last pool to produce a non-empty output.

    Args:
        input_channels: Channels of the input image (1 for MNIST, 3 for CIFAR/ImageNet)
        num_classes: Number of output units of the final Linear layer
        input_size: (height, width) of the input image; accepted for signature
            parity with the other factories but not used here

    Returns:
        nn.ModuleList: 15 nn.Sequential stages (10 conv, 4 pool, 1 classifier)
    """
    # (output channels, conv layers) per stage; each stage ends with a 2x2 max-pool.
    stage_plan = ((64, 2), (128, 2), (256, 3), (512, 3))

    blocks = []
    channels_in = input_channels
    for channels_out, depth in stage_plan:
        for _ in range(depth):
            blocks.append(nn.Sequential(nn.Conv2d(channels_in, channels_out, 3, 1, 1), nn.ReLU(inplace=True)))
            channels_in = channels_out
        blocks.append(nn.Sequential(nn.MaxPool2d(2, 2)))

    # After the loop channels_in holds the last stage's width (512).
    blocks.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(channels_in, num_classes),
        )
    )

    return nn.ModuleList(blocks)


def get_wide_cnn_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build a wide CNN for the PCN as a list of per-layer modules.

    There are four stages, each with two 3x3 convolutions (stride 1,
    padding 1, each followed by an in-place ReLU). Each stage is followed by
    its own 2x2 max-pool module. Stage widths are 128, 256, 512 and 1024
    channels. The head applies global average pooling, flattens, and projects
    the 1024 features to ``num_classes`` outputs.

    Because the head uses AdaptiveAvgPool2d, the layer sizes do not depend on
    the input resolution. The four 2x2 pools mean each spatial dimension must
    be at least 16 for the last pool to produce a non-empty output.

    Args:
        input_channels: Channels of the input image (1 for MNIST, 3 for CIFAR/ImageNet)
        num_classes: Number of output units of the final Linear layer
        input_size: (height, width) of the input image; accepted for signature
            parity with the other factories but not used here

    Returns:
        nn.ModuleList: 13 nn.Sequential stages (8 conv, 4 pool, 1 classifier)
    """
    stage_widths = (128, 256, 512, 1024)
    convs_per_stage = 2

    blocks = []
    channels_in = input_channels
    for channels_out in stage_widths:
        for _ in range(convs_per_stage):
            blocks.append(nn.Sequential(nn.Conv2d(channels_in, channels_out, 3, 1, 1), nn.ReLU(inplace=True)))
            channels_in = channels_out
        blocks.append(nn.Sequential(nn.MaxPool2d(2, 2)))

    # After the loop channels_in holds the last stage's width (1024).
    blocks.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(channels_in, num_classes),
        )
    )

    return nn.ModuleList(blocks)
