import torch
import torch.nn as nn


class CNN(nn.Module):
    """Small image classifier: Conv2d(3x3, 16 maps) -> ReLU -> 2x2 max-pool -> Linear.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output logits.
        input_size: Spatial size of the input images, either an int (square
            images) or a ``(height, width)`` pair. When omitted, the original
            assumption is kept: 28x28 for single-channel input and 32x32
            otherwise (presumably MNIST- and CIFAR-style data). These defaults
            give 2704 and 3600 flattened features respectively.

    Raises:
        ValueError: If ``input_size`` is too small to survive the 3x3
            convolution followed by the 2x2 pooling.
    """

    def __init__(self, input_dim=1, output_dim=10, *, input_size=None):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, 16, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.activation1 = nn.ReLU()

        if input_size is None:
            # Backward-compatible defaults matching the former hard-coded widths.
            input_size = 28 if input_dim == 1 else 32
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # Derive the flattened width from the real layers instead of magic
        # numbers. torch.zeros draws no random numbers, so seeded parameter
        # initialization (conv1, then linear2) is unchanged.
        with torch.no_grad():
            try:
                probe = self.pool(
                    self.activation1(self.conv1(torch.zeros(1, input_dim, height, width)))
                )
            except RuntimeError as err:
                raise ValueError(
                    f"input_size {(height, width)} is too small for a 3x3 conv "
                    f"followed by 2x2 max-pooling"
                ) from err
        in_features = probe[0].numel()

        self.linear2 = nn.Linear(in_features, output_dim)

    def forward(self, x):
        """Return class logits for a batch of images.

        ``x`` is expected to be batched as (N, input_dim, H, W), with (H, W)
        equal to the ``input_size`` the model was built for. The batch dimension
        is required because ``torch.flatten(x, 1)`` keeps dim 0 as the batch.
        Returns a tensor of shape (N, output_dim).
        """
        x = self.pool(self.activation1(self.conv1(x)))
        x = torch.flatten(x, 1)
        outputs = self.linear2(x)
        return outputs
