import torch
import torch.nn as nn
from typing import Tuple


class BasicBlock(nn.Module):
    """Basic ResNet block for ResNet18/34.

    Two 3x3 conv + BatchNorm layers joined by a residual connection, followed
    by a ReLU on the sum. The shortcut is the identity when the input and
    output shapes already match; otherwise it is a strided 1x1 conv +
    BatchNorm projection.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (``expansion * planes``; expansion is 1).
        stride: Stride of the first 3x3 conv and of the projection shortcut.
    """
    # Output channels = expansion * planes.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential returns its input unchanged (identity shortcut).
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            # Project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        # ResNet applies the non-linearity after the residual sum; without it
        # every block (unlike the stem) would emit un-rectified activations.
        return torch.relu(out)


class BottleneckBlock(nn.Module):
    """Bottleneck ResNet block for ResNet50+.

    A 1x1 reduce conv, a (possibly strided) 3x3 conv and a 1x1 expand conv,
    each followed by BatchNorm. The result is added to the shortcut and a ReLU
    is applied to the sum. The shortcut is the identity when shapes match;
    otherwise it is a strided 1x1 conv + BatchNorm projection.

    Args:
        in_planes: Number of input channels.
        planes: Width of the inner convolutions; output has ``4 * planes`` channels.
        stride: Stride of the 3x3 conv and of the projection shortcut.
    """
    # Output channels = expansion * planes.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        # An empty Sequential returns its input unchanged (identity shortcut).
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            # Project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        """Return ``relu(bn3(conv3(...)) + shortcut(x))`` for the bottleneck branch."""
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        # ResNet applies the non-linearity after the residual sum; without it
        # every block (unlike the stem) would emit un-rectified activations.
        return torch.relu(out)


def get_resnet18_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the ResNet18 layer stack for PCN as a flat list of modules.

    The list holds one ``nn.Sequential`` per PCN layer: the 7x7/stride-2 conv
    stem, then one entry per ``BasicBlock`` (stage layout 2-2-2-2, the very
    first block preceded by a 3x3/stride-2 max-pool), and finally a global
    average pool + linear classifier.

    For a 32x32 input the stem and max-pool shrink the feature map to 8x8
    before the first block, and the last stage runs at 1x1.

    Args:
        input_channels: Number of input channels (1 for MNIST, 3 for CIFAR/ImageNet)
        num_classes: Number of output classes
        input_size: Input image size (height, width). Not used by the
            construction; the adaptive average pool makes the classifier
            independent of spatial size.

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    # Each stage: (output planes, number of blocks, stride of its first block).
    stages = ((64, 2, 1), (128, 2, 2), (256, 2, 2), (512, 2, 2))

    layers = [
        nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
    ]

    channels = 64
    for planes, depth, first_stride in stages:
        for i in range(depth):
            block = BasicBlock(channels, planes, stride=first_stride if i == 0 else 1)
            channels = planes * BasicBlock.expansion
            if len(layers) == 1:
                # Only the very first block carries the max-pool.
                layers.append(nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1), block))
            else:
                layers.append(nn.Sequential(block))

    layers.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(channels, num_classes),
        )
    )
    return nn.ModuleList(layers)


def get_resnet34_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the ResNet34 layer stack for PCN as a flat list of modules.

    The list holds one ``nn.Sequential`` per PCN layer: the 7x7/stride-2 conv
    stem, then one entry per ``BasicBlock`` (stage layout 3-4-6-3, the very
    first block preceded by a 3x3/stride-2 max-pool), and finally a global
    average pool + linear classifier.

    For a 32x32 input the stem and max-pool shrink the feature map to 8x8
    before the first block, and the last stage runs at 1x1.

    Args:
        input_channels: Number of input channels (1 for MNIST, 3 for CIFAR/ImageNet)
        num_classes: Number of output classes
        input_size: Input image size (height, width). Not used by the
            construction; the adaptive average pool makes the classifier
            independent of spatial size.

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    # Each stage: (output planes, number of blocks, stride of its first block).
    stages = ((64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2))

    layers = [
        nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
    ]

    channels = 64
    for planes, depth, first_stride in stages:
        for i in range(depth):
            block = BasicBlock(channels, planes, stride=first_stride if i == 0 else 1)
            channels = planes * BasicBlock.expansion
            if len(layers) == 1:
                # Only the very first block carries the max-pool.
                layers.append(nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1), block))
            else:
                layers.append(nn.Sequential(block))

    layers.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(channels, num_classes),
        )
    )
    return nn.ModuleList(layers)


def get_resnet50_modules(input_channels: int = 3, num_classes: int = 10, input_size: Tuple[int, int] = (32, 32)) -> nn.ModuleList:
    """
    Build the ResNet50 layer stack for PCN as a flat list of modules.

    The list holds one ``nn.Sequential`` per PCN layer: the 7x7/stride-2 conv
    stem, then one entry per ``BottleneckBlock`` (stage layout 3-4-6-3, the
    very first block preceded by a 3x3/stride-2 max-pool), and finally a
    global average pool + linear classifier over 2048 features.

    For a 32x32 input the stem and max-pool shrink the feature map to 8x8
    before the first block, and the last stage runs at 1x1.

    Args:
        input_channels: Number of input channels (1 for MNIST, 3 for CIFAR/ImageNet)
        num_classes: Number of output classes
        input_size: Input image size (height, width). Not used by the
            construction; the adaptive average pool makes the classifier
            independent of spatial size.

    Returns:
        nn.ModuleList: List of sequential modules for PCN
    """
    # Each stage: (bottleneck planes, number of blocks, stride of its first block).
    stages = ((64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2))

    layers = [
        nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
    ]

    channels = 64
    for planes, depth, first_stride in stages:
        for i in range(depth):
            block = BottleneckBlock(channels, planes, stride=first_stride if i == 0 else 1)
            # Bottleneck blocks widen their output by the expansion factor (4).
            channels = planes * BottleneckBlock.expansion
            if len(layers) == 1:
                # Only the very first block carries the max-pool.
                layers.append(nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1), block))
            else:
                layers.append(nn.Sequential(block))

    layers.append(
        nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(channels, num_classes),
        )
    )
    return nn.ModuleList(layers)
