import torch
import torch.nn as nn
import torch.nn.functional as F
import sys


class FemnistCNN_bis(torch.nn.Module):
    """Two-convolution CNN classifier producing per-class log-probabilities.

    Architecture::

        conv(5x5, 64, stride 1) -> ReLU -> maxpool(2)
        -> conv(5x5, 128, stride 1) -> ReLU -> maxpool(2)
        -> linear(128*4*4 -> 1024) -> ReLU -> linear(1024 -> num_classes)
        -> log_softmax

    Neither convolution uses padding, so a 28x28 input shrinks
    28 -> 24 -> 12 -> 8 -> 4 and yields the ``128 * 4 * 4`` features that
    ``_f1`` expects. Inputs whose spatial size does not reduce to 4x4 are
    rejected with a shape error in ``_f1``.

    Args:
        num_classes: Width of the output layer. Defaults to 62, the value
            previously hard-coded here.
        in_channels: Number of input image channels. Defaults to 1.
    """

    def __init__(self, num_classes: int = 62, in_channels: int = 1):
        super().__init__()
        # Attribute names are kept unchanged so existing state_dict
        # checkpoints (keys "_c1.weight", "_f2.bias", ...) still load.
        self._c1 = nn.Conv2d(in_channels, 64, 5, 1)
        self._c2 = nn.Conv2d(64, 128, 5, 1)
        # 4x4 is the feature-map size after two (conv5 -> pool2) stages on a
        # 28x28 input; see the class docstring.
        self._f1 = nn.Linear(128 * 4 * 4, 1024)
        self._f2 = nn.Linear(1024, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            x: Image batch of shape ``(N, in_channels, H, W)`` where H and W
                reduce to 4 after the two conv/pool stages (e.g. 28x28).

        Returns:
            Tensor of shape ``(N, num_classes)`` containing ``log_softmax``
            scores along dim 1 (pair with ``nll_loss``, not ``cross_entropy``,
            which would apply log-softmax a second time).

        Raises:
            RuntimeError: if the flattened per-sample feature size does not
                match ``_f1`` (i.e. the input's spatial size is wrong).
        """
        x = F.max_pool2d(F.relu(self._c1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self._c2(x)), 2, 2)
        # Flatten per sample while keeping dim 0 (the batch) intact.
        # ``view(-1, 2048)`` would let a wrongly sized input silently regroup
        # elements into a different number of rows. ``view`` also fails on
        # non-contiguous (e.g. channels_last) tensors, whereas ``flatten``
        # copies only when it must.
        x = torch.flatten(x, 1)
        x = F.relu(self._f1(x))
        return F.log_softmax(self._f2(x), dim=1)