"""
Modified from: https://github.com/pskugit/custom-conv2d/blob/master/models/customconv.py
"""
import torch.nn as nn
import torch

from ...layers import ReConvLayer

class ReLeNetFullConv(nn.Module):
    """Fully convolutional network assembled from ``ReConvLayer`` stages.

    Wiring used by ``forward`` (in order):
        conv1 -> ReLU -> 2x max-pool
        conv2 -> ReLU -> 2x max-pool
        conv3 -> ReLU
        conv4 -> ReLU
        conv5 -> global average pool (output size 1)
    after which the result is reshaped to ``(x.size(0), -1)``.

    conv5 is configured with 10 as its second positional argument; presumably
    that is the output-channel count (i.e. ``ReConvLayer`` mirrors
    ``nn.Conv2d``'s ``(in, out, kernel)`` order) — TODO confirm.

    NOTE(review): the constructor raises ``NotImplementedError`` before doing
    anything else, so this class cannot currently be instantiated and the
    layer definitions in ``__init__`` never run.

    Args:
        in_dim: channel count given to the first ``ReConvLayer``.
        in_size: stored on the instance; unused by the code in this class.
        in_norm: forwarded as ``in_norm`` to every ``ReConvLayer``.
        we_norm: forwarded as ``we_norm`` to every ``ReConvLayer``.
        in_reg: stored on the instance only; never passed to ``ReConvLayer``.
        we_reg: stored on the instance only; never passed to ``ReConvLayer``.
    """

    def __init__(
        self,
        in_dim: int = 1,
        in_size: int = 28,
        in_norm: bool = False,
        we_norm: bool = False,
        in_reg: bool = False,
        we_reg: bool = False,
    ):
        # Unconditional: nothing below this line executes.
        raise NotImplementedError()
        super().__init__()

        # Remember the constructor configuration.
        self.in_dim, self.in_size = in_dim, in_size
        self.in_norm, self.we_norm = in_norm, we_norm
        self.in_reg, self.we_reg = in_reg, we_reg

        # Normalisation flags shared by every ReConvLayer below.
        # NOTE(review): in_reg / we_reg are not forwarded — confirm intent.
        norm = {"in_norm": in_norm, "we_norm": we_norm}

        # Stage 1 and 2: conv + ReLU + 2x max-pool.
        self.conv1 = ReConvLayer(in_dim, 8, 5, **norm)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = ReConvLayer(8, 32, 5, **norm)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)

        # Stage 3 and 4: padded 3x3 convs + ReLU, no pooling.
        self.conv3 = ReConvLayer(32, 128, 3, padding=1, **norm)
        self.relu3 = nn.ReLU()
        self.conv4 = ReConvLayer(128, 64, 3, padding=1, **norm)
        self.relu4 = nn.ReLU()

        # Final conv (10 outputs) followed by global average pooling.
        self.conv5 = ReConvLayer(64, 10, 3, padding=1, **norm)
        self.gpool = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every stage in order and flatten all but the leading dim.

        NOTE(review): assumes ``x`` is a 4-D ``(batch, in_dim, H, W)``
        tensor — confirm against callers.
        """
        stages = (
            self.conv1, self.relu1, self.pool1,
            self.conv2, self.relu2, self.pool2,
            self.conv3, self.relu3,
            self.conv4, self.relu4,
            self.conv5,
            self.gpool,
        )
        for stage in stages:
            x = stage(x)
        # The global pool leaves a (C, 1, 1) map per sample; collapse it to (C,).
        return x.reshape(x.size(0), -1)



