# -*- coding: utf-8 -*-
import typing

import torch

__all__ = ["CNN"]  # Public API of this module: only the CNN model is exported.

# Intended for EMNIST-style inputs: the default fc1 size assumes 28x28 images.


class CNN(torch.nn.Module):
    """Small two-convolution CNN classifier, sized by default for 28x28 images.

    Pipeline: conv(3x3) -> GroupNorm -> ReLU -> conv(3x3) -> GroupNorm -> ReLU
    -> 2x2 max-pool -> Dropout(0.25) -> flatten -> Linear(128) -> ReLU
    -> Dropout(0.5) -> Linear(out_features).

    Args:
        in_channels: Number of channels ``conv1`` expects. Single-channel inputs
            are repeated to this many channels in ``forward``.
        out_features: Number of output logits (default 62 — presumably the
            EMNIST ByClass label count; confirm against the dataset in use).
        ngroup: Number of groups for each GroupNorm layer; must divide both
            32 and 64 (GroupNorm raises otherwise).
        input_size: Height/width of the square input images. Determines the
            input width of ``fc1``; the default 28 gives 64 * 12 * 12 = 9216.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two 3x3
            convolutions and the 2x2 pooling (i.e. smaller than 6).
    """

    def __init__(self, in_channels: int = 3, out_features: int = 62, ngroup: int = 4,
                 *, input_size: int = 28):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, 32, 3, 1)
        # Named ``bn*`` (kept so state_dict keys stay stable) but these are
        # GroupNorm layers, not BatchNorm.
        self.bn1 = torch.nn.GroupNorm(ngroup, 32)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.bn2 = torch.nn.GroupNorm(ngroup, 64)
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        self.fc1 = torch.nn.Linear(self._flat_features(input_size), 128)
        self.fc2 = torch.nn.Linear(128, out_features)

    @staticmethod
    def _flat_features(input_size: int) -> int:
        """Return the flattened feature count fed to ``fc1`` for a square input."""
        # Each unpadded 3x3 conv with stride 1 trims 2 pixels per side length,
        # then max_pool2d(2) halves it (flooring).
        side = (input_size - 4) // 2
        if side < 1:
            raise ValueError(f"input_size must be at least 6, got {input_size}")
        return 64 * side * side

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (N, out_features) for a batch ``x``.

        Assumes ``x`` is (N, C, H, W) with H == W == the ``input_size`` the model
        was built with. A single-channel ``x`` is repeated along the channel axis
        to match ``conv1``'s expected channel count.
        """
        expected_channels = self.conv1.in_channels
        # Repeat grayscale input to whatever conv1 was built for, rather than a
        # fixed 3 (which broke models constructed with in_channels != 3).
        if x.shape[1] == 1 and expected_channels != 1:
            x = x.repeat(1, expected_channels, 1, 1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x
