import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """A 2-D convolution (no bias term) followed by a ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 1).
        padding: Zero-padding added on each spatial side (default 1).

    Note: the convolution is created with ``bias=False`` and no
    normalization layer follows it, so this block has no bias parameter.
    """

    def __init__(self, in_planes, planes, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_planes,
            out_channels=planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )

    def forward(self, x):
        """Convolve ``x`` and apply an out-of-place ReLU."""
        return F.relu(self.conv(x))

class ConvOnlyNetLayersBase(nn.Module):
    """Ten-layer, convolution-only feature extractor (wide variant).

    Ten ``ConvBlock`` stages widen the input from 3 channels to 2048. A
    global average pool then reduces each channel to a single value. Stages
    3, 5 and 7 use stride 2. With kernel 3 and padding 1, each of them maps
    a spatial size ``s`` to ``ceil(s / 2)``.

    Args:
        num_classes: Accepted for interface compatibility but currently
            unused. There is no classifier head, so the output is a
            2048-dim feature vector rather than ``num_classes`` logits.
    """

    # Conv stages in the order forward() applies them. They stay individual
    # attributes (conv1..conv10) so parameter names in state_dict() match
    # previously saved checkpoints.
    _LAYER_NAMES = (
        "conv1", "conv2", "conv3", "conv4", "conv5",
        "conv6", "conv7", "conv8", "conv9", "conv10",
    )

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE: num_classes is intentionally not used (no final linear layer).
        self.conv1 = ConvBlock(3, 64, kernel_size=3, stride=1, padding=1)       # Layer 1
        self.conv2 = ConvBlock(64, 64, kernel_size=3, stride=1, padding=1)      # Layer 2
        self.conv3 = ConvBlock(64, 128, kernel_size=3, stride=2, padding=1)     # Layer 3 (stride 2 downsampling)
        self.conv4 = ConvBlock(128, 128, kernel_size=3, stride=1, padding=1)    # Layer 4
        self.conv5 = ConvBlock(128, 256, kernel_size=3, stride=2, padding=1)    # Layer 5 (stride 2 downsampling)
        self.conv6 = ConvBlock(256, 512, kernel_size=3, stride=1, padding=1)    # Layer 6
        self.conv7 = ConvBlock(512, 1024, kernel_size=3, stride=2, padding=1)   # Layer 7 (stride 2 downsampling)
        self.conv8 = ConvBlock(1024, 2048, kernel_size=3, stride=1, padding=1)  # Layer 8
        self.conv9 = ConvBlock(2048, 2048, kernel_size=3, stride=1, padding=1)  # Layer 9
        self.conv10 = ConvBlock(2048, 2048, kernel_size=3, stride=1, padding=1) # Layer 10

        # Global average pooling to 1x1 (stands in for fully connected layers).
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x, return_intermediates=False):
        """Run the ten conv stages, then global-average-pool.

        Args:
            x: Input tensor whose channel dimension is 3 (conv1 expects 3
                input channels). Presumably a batch of shape (N, 3, H, W);
                confirm against callers.
            return_intermediates: If True, also return a list containing a
                ``clone()`` of each of the ten stage outputs, in order.

        Returns:
            For a batched input, an ``(N, 2048)`` tensor of pooled features.
            When ``return_intermediates`` is True, the tuple
            ``(features, intermediates)`` is returned instead.
        """
        intermediates = []
        out = x
        for name in self._LAYER_NAMES:
            out = getattr(self, name)(out)
            if return_intermediates:
                # Clone so each recorded activation has its own storage,
                # separate from the tensor fed into the next stage.
                intermediates.append(out.clone())

        # (N, C, H, W) -> (N, C, 1, 1) -> (N, C). Unlike view(N, -1),
        # torch.flatten also works for a zero-size batch, where -1 would be
        # ambiguous and view raises.
        out = torch.flatten(self.global_avg_pool(out), 1)

        if return_intermediates:
            return out, intermediates
        return out
    
class ConvOnlyNetLayersSmall(nn.Module):
    """Ten-layer, convolution-only feature extractor (narrow variant).

    The first ``ConvBlock`` maps 3 input channels to 64. Every later stage
    keeps 64 channels. A global average pool then reduces each channel to a
    single value. Stages 3, 5 and 7 use stride 2. With kernel 3 and
    padding 1, each of them maps a spatial size ``s`` to ``ceil(s / 2)``.

    Args:
        num_classes: Accepted for interface compatibility but currently
            unused. There is no classifier head, so the output is a 64-dim
            feature vector rather than ``num_classes`` logits.
    """

    # Conv stages in the order forward() applies them. They stay individual
    # attributes (conv1..conv10) so parameter names in state_dict() match
    # previously saved checkpoints.
    _LAYER_NAMES = (
        "conv1", "conv2", "conv3", "conv4", "conv5",
        "conv6", "conv7", "conv8", "conv9", "conv10",
    )

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE: num_classes is intentionally not used (no final linear layer).
        self.conv1 = ConvBlock(3, 64, kernel_size=3, stride=1, padding=1)   # Layer 1
        self.conv2 = ConvBlock(64, 64, kernel_size=3, stride=1, padding=1)  # Layer 2
        self.conv3 = ConvBlock(64, 64, kernel_size=3, stride=2, padding=1)  # Layer 3 (stride 2 downsampling)
        self.conv4 = ConvBlock(64, 64, kernel_size=3, stride=1, padding=1)  # Layer 4
        self.conv5 = ConvBlock(64, 64, kernel_size=3, stride=2, padding=1)  # Layer 5 (stride 2 downsampling)
        self.conv6 = ConvBlock(64, 64, kernel_size=3, stride=1, padding=1)  # Layer 6
        self.conv7 = ConvBlock(64, 64, kernel_size=3, stride=2, padding=1)  # Layer 7 (stride 2 downsampling)
        self.conv8 = ConvBlock(64, 64, kernel_size=3, stride=1, padding=1)  # Layer 8
        self.conv9 = ConvBlock(64, 64, kernel_size=3, stride=1, padding=1)  # Layer 9
        self.conv10 = ConvBlock(64, 64, kernel_size=3, stride=1, padding=1) # Layer 10

        # Global average pooling to 1x1 (stands in for fully connected layers).
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x, return_intermediates=False):
        """Run the ten conv stages, then global-average-pool.

        Args:
            x: Input tensor whose channel dimension is 3 (conv1 expects 3
                input channels). Presumably a batch of shape (N, 3, H, W);
                confirm against callers.
            return_intermediates: If True, also return a list containing a
                ``clone()`` of each of the ten stage outputs, in order.

        Returns:
            For a batched input, an ``(N, 64)`` tensor of pooled features.
            When ``return_intermediates`` is True, the tuple
            ``(features, intermediates)`` is returned instead.
        """
        intermediates = []
        out = x
        for name in self._LAYER_NAMES:
            out = getattr(self, name)(out)
            if return_intermediates:
                # Clone so each recorded activation has its own storage,
                # separate from the tensor fed into the next stage.
                intermediates.append(out.clone())

        # (N, C, H, W) -> (N, C, 1, 1) -> (N, C). Unlike view(N, -1),
        # torch.flatten also works for a zero-size batch, where -1 would be
        # ambiguous and view raises.
        out = torch.flatten(self.global_avg_pool(out), 1)

        if return_intermediates:
            return out, intermediates
        return out