import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """One Conv2d -> BatchNorm2d -> ReLU stage, optionally followed by pooling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution (and normalized by BN).
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
            The defaults (3, 1, 1) keep the spatial size unchanged.
        do_maxpool: when true, apply ``F.max_pool2d`` with a 2x2 window after
            the activation, halving (floor) height and width.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, do_maxpool=False):
        super().__init__()
        # NOTE: the conv keeps nn.Conv2d's default bias even though BN follows;
        # left as-is so parameter / state_dict keys stay unchanged.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.do_maxpool = do_maxpool

    def forward(self, x):
        """Apply conv, BN and an in-place ReLU, then pool if configured."""
        activated = F.relu(self.bn(self.conv(x)), inplace=True)
        return F.max_pool2d(activated, 2) if self.do_maxpool else activated


class ConvNet(nn.Module):
    """Run the modules in ``conv_layers`` in order and flatten the result.

    ``conv_layers`` starts empty; with no layers added, ``forward`` only
    reshapes its input to ``(batch, -1)``. No classifier head is attached.
    """

    def __init__(self):
        super().__init__()
        self.conv_layers = nn.ModuleList()

    def forward(self, x):
        """Feed ``x`` through each layer, then collapse all non-batch dims."""
        features = x
        for block in self.conv_layers:
            features = block(features)
        batch = features.size(0)
        return features.view(batch, -1)


class ConvNet4(nn.Module):
    """Feature extractor of four ConvBlocks, each ending in 2x2 max pooling.

    The first block maps 3 input channels to ``conv_channels``; the other three
    keep ``conv_channels``. ``forward`` returns the last feature map flattened
    to ``(batch, -1)``; there is no classifier head.

    Args:
        conv_channels: channels produced by every block.
        img_size: accepted for interface compatibility; not used to build any
            layer. For an ``img_size x img_size`` input (default ConvBlock
            kernel/stride/padding) the flattened feature size is
            ``conv_channels * (img_size // 16) ** 2``.
    """

    def __init__(self, conv_channels=32, img_size=32):
        super().__init__()
        self.conv_layers = nn.ModuleList(
            [
                ConvBlock(3, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
            ]
        )

        # Kaiming-normal conv weights (fan_out, relu); norm layers start as the
        # identity affine map (weight 1, bias 0). Conv biases keep their
        # PyTorch default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run every block in order and flatten all non-batch dimensions."""
        out = x
        for layer in self.conv_layers:
            out = layer(out)
        # flatten(1) matches view(size(0), -1) for non-empty batches, but also
        # works for an empty batch (where view cannot infer -1) and for
        # non-contiguous tensors.
        return out.flatten(1)

class ConvNet5(nn.Module):
    """Feature extractor of five ConvBlocks, each ending in 2x2 max pooling.

    The first block maps 3 input channels to ``conv_channels``; the other four
    keep ``conv_channels``. ``forward`` returns the last feature map flattened
    to ``(batch, -1)``; there is no classifier head.

    Args:
        conv_channels: channels produced by every block.
        img_size: accepted for interface compatibility; not used to build any
            layer. For an ``img_size x img_size`` input (default ConvBlock
            kernel/stride/padding) the flattened feature size is
            ``conv_channels * (img_size // 32) ** 2``.
    """

    def __init__(self, conv_channels=32, img_size=32):
        super().__init__()
        self.conv_layers = nn.ModuleList(
            [
                ConvBlock(3, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
            ]
        )

        # Kaiming-normal conv weights (fan_out, relu); norm layers start as the
        # identity affine map (weight 1, bias 0). Conv biases keep their
        # PyTorch default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run every block in order and flatten all non-batch dimensions."""
        out = x
        for layer in self.conv_layers:
            out = layer(out)
        # flatten(1) matches view(size(0), -1) for non-empty batches, but also
        # works for an empty batch (where view cannot infer -1) and for
        # non-contiguous tensors.
        return out.flatten(1)

class ConvNet6(nn.Module):
    """Feature extractor of six ConvBlocks with pooling after blocks 1, 2, 4, 6.

    The first block maps 3 input channels to ``conv_channels``; the other five
    keep ``conv_channels``. ``forward`` returns the last feature map flattened
    to ``(batch, -1)``; there is no classifier head.

    Args:
        conv_channels: channels produced by every block.
        img_size: accepted for interface compatibility; not used to build any
            layer. With four poolings, an ``img_size x img_size`` input
            (default ConvBlock kernel/stride/padding) yields a flattened feature
            size of ``conv_channels * (img_size // 16) ** 2``.
    """

    def __init__(self, conv_channels=32, img_size=32):
        super().__init__()
        # ConvBlock takes no perturbation argument; passing `self.perturb`
        # (an attribute that is never defined) or `perturb_idx=` made this
        # constructor raise before any layer was built.
        self.conv_layers = nn.ModuleList(
            [
                ConvBlock(3, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=False),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
                ConvBlock(conv_channels, conv_channels, do_maxpool=False),
                ConvBlock(conv_channels, conv_channels, do_maxpool=True),
            ]
        )

        # Kaiming-normal conv weights (fan_out, relu); norm layers start as the
        # identity affine map (weight 1, bias 0). Conv biases keep their
        # PyTorch default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run every block in order and flatten all non-batch dimensions."""
        out = x
        for layer in self.conv_layers:
            out = layer(out)
        # flatten(1) also copes with an empty batch, unlike view(size(0), -1).
        return out.flatten(1)


class VGG9(nn.Module):
    """VGG-style feature extractor of eight ConvBlocks.

    Channel progression is 3 -> 64 -> 128 -> 256 -> 256 -> 512 -> 512 -> 512
    -> 512. Max pooling (2x2) follows blocks 1, 2, 4, 6 and 8, i.e. five times.
    ``forward`` returns the last feature map flattened to ``(batch, -1)``; there
    is no classifier head.

    Args:
        img_size: accepted for interface compatibility; not used to build any
            layer. An ``img_size x img_size`` input (default ConvBlock
            kernel/stride/padding) yields a flattened feature size of
            ``512 * (img_size // 32) ** 2``.
    """

    def __init__(self, img_size=32):
        super().__init__()
        # ConvBlock takes no perturbation argument; passing `self.perturb`
        # (an attribute that is never defined) made this constructor raise
        # before any layer was built.
        self.conv_layers = nn.ModuleList(
            [
                ConvBlock(3, 64, do_maxpool=True),
                ConvBlock(64, 128, do_maxpool=True),
                ConvBlock(128, 256, do_maxpool=False),
                ConvBlock(256, 256, do_maxpool=True),
                ConvBlock(256, 512, do_maxpool=False),
                ConvBlock(512, 512, do_maxpool=True),
                ConvBlock(512, 512, do_maxpool=False),
                ConvBlock(512, 512, do_maxpool=True),
            ]
        )

        # Kaiming-normal conv weights (fan_out, relu); norm layers start as the
        # identity affine map (weight 1, bias 0). Conv biases keep their
        # PyTorch default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run every block in order and flatten all non-batch dimensions."""
        out = x
        for layer in self.conv_layers:
            out = layer(out)
        # flatten(1) also copes with an empty batch, unlike view(size(0), -1).
        return out.flatten(1)
