import torch.nn as nn


class CNN(nn.Module):
    """Convolutional feature extractor for single-channel 2-D input.

    The network is three stages of 3x3 convolutions (stride 1, padding 1,
    no bias). Each convolution is followed by ``BatchNorm2d`` and an in-place
    ``ReLU``, and every stage ends with a 2x2 max-pool:

    * stage 1: 1 -> 32 -> 32 channels
    * stage 2: 32 -> 64 -> 64 channels
    * stage 3: 64 -> 128 channels

    A final ``nn.Flatten`` (default ``start_dim=1``) collapses everything
    after the first dimension. For a batched ``(N, 1, H, W)`` input, the
    output is therefore ``(N, 128 * (H // 8) * (W // 8))``.
    """

    def __init__(self):
        super().__init__()

        # Channel progression per stage. Consecutive pairs become
        # conv -> batch-norm -> ReLU units, and each stage closes with a pool.
        # Convs are created in the same order as a hand-written stack, so
        # seeded weight initialisation is unaffected.
        stage_channels = ((1, 32, 32), (32, 64, 64), (64, 128))

        modules = []
        for chain in stage_channels:
            for c_in, c_out in zip(chain[:-1], chain[1:]):
                modules.extend((
                    # The conv bias is redundant: the BatchNorm right after
                    # it subtracts the per-channel mean anyway.
                    nn.Conv2d(c_in, c_out, 3, 1, padding=1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True),
                ))
            modules.append(nn.MaxPool2d(2))
        modules.append(nn.Flatten())

        # Positional construction registers the children as "0", "1", ...,
        # giving the same state_dict keys as appending one by one.
        self.layers = nn.Sequential(*modules)

    def forward(self, input):
        """Apply the full layer stack to ``input`` and return the flattened features."""
        features = self.layers(input)
        return features