import torch.nn as nn


class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
    Both main-path convolutions are 3x3 with padding 1 and no bias; only the
    first one applies ``stride``. When the stride is not 1 or the channel
    count changes, the shortcut is a strided 1x1 convolution followed by
    batch norm so its output matches the main path. Otherwise the shortcut
    is an empty ``nn.Sequential``, which returns its input unchanged.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the produced feature map.
        stride: Stride of the first 3x3 convolution and of the projection
            shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        needs_projection = stride != 1 or in_channels != out_channels

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Registered after bn2 so parameter order / state_dict keys
        # ("shortcut.0.*", "shortcut.1.*") stay stable for checkpoints.
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x`` and return the ReLU-activated sum."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        # Out-of-place add: the identity shortcut returns ``x`` itself, and
        # the caller's tensor must not be modified.
        return self.relu(branch + self.shortcut(x))


class ResNet9(nn.Module):
    """ResNet-9 style network: four conv stages plus two residual blocks.

    Channel progression: 64 -> 128 -> (residual 128) -> 256 -> 512 ->
    (residual 512). ``conv2``, ``conv3`` and ``conv4`` each end with a 2x2
    max-pool, so the spatial size is divided by 8 before the classifier.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input image channels. It also selects the
            classifier head: ``1`` uses global average pooling, any other
            value uses global max pooling (presumably 1-channel means
            MNIST-like data, others CIFAR-like — confirm with callers).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.res1 = BasicBlock(128, 128)
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.res2 = BasicBlock(512, 512)

        self.in_channels = in_channels
        if in_channels == 1:
            pool = nn.AdaptiveAvgPool2d(1)
        else:
            # Global max pooling. A fixed nn.MaxPool2d(4) only reduces the
            # map to 1x1 when it is exactly 4x4 (32x32 input / 8). The
            # adaptive form gives the same result there and also handles
            # other input sizes. Pooling has no parameters, so state_dict
            # keys (classifier.2.*) are unaffected.
            pool = nn.AdaptiveMaxPool2d(1)
        self.classifier = nn.Sequential(
            pool,
            nn.Flatten(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x, return_activations=False):
        """Compute class logits, optionally with intermediate feature maps.

        Args:
            x: Input batch; assumed NCHW with ``in_channels`` channels
                (as required by the first ``nn.Conv2d``).
            return_activations: If True, also return the outputs of
                ``conv1``, ``conv2``, ``conv3`` and ``conv4``.

        Returns:
            ``logits``, or ``(logits, act1, act2, act3, act4)`` when
            ``return_activations`` is True.
        """
        activations = []

        # None of the following layers modify their input in place, so the
        # stored tensors remain valid after later stages run.
        x = self.conv1(x)
        activations.append(x)

        x = self.conv2(x)
        activations.append(x)

        x = self.res1(x)

        x = self.conv3(x)
        activations.append(x)

        x = self.conv4(x)
        activations.append(x)

        x = self.res2(x)
        x = self.classifier(x)

        if return_activations:
            return (x, *activations)
        return x


class ResNet18(nn.Module):
    """ResNet-18 with a 7x7/2 convolution + 3x3/2 max-pool stem.

    Four stages of two ``BasicBlock`` each (64/128/256/512 channels).
    Stages 2-4 downsample by 2 in their first block. Global average pooling
    and a linear layer produce the logits.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input image channels.

    Note:
        ``self.in_channels`` is the running channel count advanced by
        ``_make_layer`` while the stages are built. After ``__init__`` it
        holds 512, not the input channel count.
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        # Width fed into the next stage; updated by _make_layer.
        self.in_channels = 64

        # Stem.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages (construction order matters: each call reads and
        # advances self.in_channels).
        self.layer1 = self._make_layer(64, 2)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build one stage: a (possibly strided) entry block plus repeats.

        The first block maps ``self.in_channels`` to ``out_channels`` with
        ``stride``. The remaining ``blocks - 1`` blocks keep
        ``out_channels`` at stride 1. Sets ``self.in_channels`` to
        ``out_channels`` for the next stage.
        """
        entry = BasicBlock(self.in_channels, out_channels, stride)
        self.in_channels = out_channels
        repeats = [BasicBlock(out_channels, out_channels)
                   for _ in range(blocks - 1)]
        return nn.Sequential(entry, *repeats)

    def forward(self, x, return_activations=False):
        """Compute logits, optionally with per-stage feature maps.

        Returns ``logits``, or ``(logits, stem, layer1, layer2, layer3,
        layer4)`` outputs when ``return_activations`` is True. The stem
        output is taken after ReLU and before the max-pool.
        """
        stem = self.relu(self.bn1(self.conv1(x)))
        captured = [stem]

        h = self.maxpool(stem)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
            captured.append(h)

        pooled = self.avgpool(h)
        logits = self.fc(pooled.view(pooled.size(0), -1))

        return (logits, *captured) if return_activations else logits


class CifarResNet18(nn.Module):
    """ResNet18 optimized for CIFAR datasets.

    Compared to the ImageNet-style stem, this uses a 3x3 stride-1 stem
    convolution and no max-pool. Four stages of two ``BasicBlock`` each
    follow (64/128/256/512 channels, stages 2-4 downsampling by 2). Global
    average pooling, dropout and a linear classifier produce the logits.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input image channels.
        dropout: Dropout probability applied to the pooled features before
            the classifier. Defaults to 0.2, the value this model always
            used, so existing callers are unaffected.

    Note:
        ``self.in_channels`` is the running channel count advanced by
        ``_make_layer`` while the stages are built. After ``__init__`` it
        holds 512, not the input channel count.
    """

    def __init__(self, num_classes=100, in_channels=3, dropout=0.2):
        super().__init__()
        # Width fed into the next stage; updated by _make_layer.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(64, 2)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build one stage of ``BasicBlock`` modules.

        The first block maps ``self.in_channels`` to ``out_channels`` with
        ``stride``. The remaining ``blocks - 1`` blocks keep
        ``out_channels`` at stride 1. Sets ``self.in_channels`` to
        ``out_channels`` for the next stage.
        """
        layers = [BasicBlock(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        layers.extend(BasicBlock(out_channels, out_channels)
                      for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x, return_activations=False):
        """Compute logits, optionally with per-stage feature maps.

        Args:
            x: Input batch; assumed NCHW with ``in_channels`` channels.
            return_activations: If True, also return the stem output (after
                ReLU) and the outputs of ``layer1`` .. ``layer4``.

        Returns:
            ``logits``, or ``(logits, stem, l1, l2, l3, l4)`` when
            ``return_activations`` is True.
        """
        x = self.relu(self.bn1(self.conv1(x)))
        activations = [x]

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            activations.append(x)

        x = self.avgpool(x)
        # flatten(1) also handles an empty batch, where view(0, -1) raises.
        x = x.flatten(1)
        x = self.dropout(x)
        x = self.fc(x)

        if return_activations:
            return (x, *activations)
        return x