import torch.nn as nn

from models.base_model import SequentialModel


class Conv_Net(SequentialModel):
    """Four-block convolutional network followed by a single linear layer.

    Each block is ``Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d``.
    The channel count doubles per block (``width``, ``2*width``, ``4*width``,
    ``8*width``). The pooling kernels are chosen so the last spatial dimension
    is reduced to 1 before ``Flatten``. For a square input this leaves a
    ``8*width`` feature vector for the final ``Linear`` layer.

    Args:
        width: Number of output channels of the first convolution.
        dims: ``[input_shape, num_outputs]``. ``dims[0][0]`` is used as the
            number of input channels and ``dims[0][-1]`` as the spatial size
            that drives the pooling schedule (presumably ``(C, H, W)`` with
            ``H == W`` -- only the last entry is inspected). ``dims[-1]`` is the
            output size of the final linear layer. Defaults to
            ``[(3, 32, 32), 10]`` (the CIFAR10 case).
        bias: Whether the convolution and linear layers use a bias term.

    Raises:
        ValueError: If ``dims[0][-1]`` is smaller than 4, which is too small
            for the pooling schedule.
    """

    def __init__(
        self,
        width=64,
        dims=None,
        bias=False,
    ):
        super().__init__()

        # Build a fresh list per instance: a list literal as the default
        # argument would be one object shared (and mutable) across instances.
        if dims is None:
            dims = [(3, 32, 32), 10]
        self.dims = dims

        kernel_size = 3
        stride = 1
        padding = 1

        # Kernel sizes for the four pooling layers, chosen so the spatial
        # size collapses to 1 and Flatten yields a vector.
        pooling_kernels = self._pooling_kernels(dims[0][-1])

        self.layers = nn.Sequential(
            nn.Conv2d(dims[0][0], width, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.ReLU(),
            nn.MaxPool2d(pooling_kernels[0]),
            ###
            nn.Conv2d(width, width * 2, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.ReLU(),
            nn.MaxPool2d(pooling_kernels[1]),
            ###
            nn.Conv2d(width * 2, width * 4, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.ReLU(),
            nn.MaxPool2d(pooling_kernels[2]),
            ###
            nn.Conv2d(width * 4, width * 8, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.ReLU(),
            nn.MaxPool2d(pooling_kernels[3]),
            ###
            nn.Flatten(),
            nn.Linear(width * 8, dims[-1], bias=bias),
        )

        # get_layer_input_shapes is not defined in this class; it is expected
        # to come from SequentialModel (not shown here).
        self.layer_input_shapes = self.get_layer_input_shapes()

    @staticmethod
    def _pooling_kernels(spatial_size):
        """Return the four MaxPool2d kernel sizes that reduce ``spatial_size`` to 1.

        The 3x3 / stride-1 / padding-1 convolutions preserve spatial size, so
        only the pools shrink it. The schedule is: 1 (no-op), 2, 2, then
        whatever remains, which is ``(spatial_size // 2) // 2``. MaxPool2d
        (with its default stride equal to the kernel) maps size ``n`` to
        ``n // k``. The final pool therefore outputs exactly 1.

        This reproduces the original schedules ``[1, 2, 2, 8]`` for 32
        (CIFAR10) and ``[1, 2, 2, 7]`` for 28 (MNIST). It also extends them
        to other sizes.

        Raises:
            ValueError: If ``spatial_size < 4``, which would yield a zero-sized
                final pooling kernel.
        """
        if spatial_size < 4:
            raise ValueError(
                f"Conv_Net requires an input spatial size of at least 4, got {spatial_size}"
            )
        return [1, 2, 2, (spatial_size // 2) // 2]
