from torch import nn
import torch.nn.functional as F
import torch

# EMNIST CNN (2 conv layers + 1 dense layer)
class CNN21(nn.Module):
    """Small CNN: two 3x3 convolutions followed by a single linear classifier.

    Pipeline for a single-channel ``n x n`` input::

        conv(1->10, 3x3) -> ReLU -> conv(10->20, 3x3) -> ReLU
        -> 2x2 max-pool -> flatten -> linear(num_classes)

    Args:
        num_classes: Number of output logits.
        n: Height/width of the (square) input images. The default of 28
            reproduces the original fc input size of ``20 * 12 * 12``.

    Raises:
        ValueError: If ``n`` is too small to survive both convolutions and
            the pooling step.
    """

    def __init__(self, num_classes=10, n=28):
        super(CNN21, self).__init__()
        kernel_size = 3
        self.conv1 = nn.Conv2d(1, 10, kernel_size)
        self.conv2 = nn.Conv2d(10, 20, kernel_size)

        # Each unpadded stride-1 conv shrinks the side by kernel_size - 1;
        # the 2x2 max-pool in forward() then floor-halves it.
        pooled = (n - 2 * (kernel_size - 1)) // 2
        if pooled < 1:
            raise ValueError(
                f"input size n={n} is too small for two {kernel_size}x{kernel_size} "
                "convolutions followed by 2x2 max-pooling"
            )
        self.fc1 = nn.Linear(20 * pooled * pooled, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        Expects ``x`` shaped ``(batch, 1, n, n)`` with the ``n`` given at
        construction; other spatial sizes will not match ``fc1``.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

# CNN adapted from the original FedAvg paper (McMahan et al., 2017)
class CNN22(nn.Module):
    """CNN with two 5x5 convolutions and two fully connected layers.

    Pipeline for a single-channel ``n x n`` input::

        conv(1->32, 5x5) -> ReLU -> conv(32->64, 5x5) -> ReLU
        -> 2x2 max-pool -> flatten -> linear(hidden_size) -> ReLU
        -> linear(num_classes)

    NOTE(review): the CNN described in the FedAvg paper pools after *each*
    convolution and uses a 512-unit hidden layer; this variant pools once
    and defaults to 128 hidden units -- confirm this deviation is intended.

    Args:
        num_classes: Number of output logits.
        n: Height/width of the (square) input images.
        hidden_size: Width of the hidden fully connected layer.

    Raises:
        ValueError: If ``n`` is too small to survive both convolutions and
            the pooling step.
    """

    def __init__(self, num_classes=10, n=28, hidden_size=128):
        super(CNN22, self).__init__()
        kernel_size = 5
        self.conv1 = nn.Conv2d(1, 32, kernel_size)
        self.conv2 = nn.Conv2d(32, 64, kernel_size)

        conv_pool_n = self._conv_pool_output_size(n, kernel_size)
        if conv_pool_n < 1:
            raise ValueError(
                f"input size n={n} is too small for two {kernel_size}x{kernel_size} "
                "convolutions followed by 2x2 max-pooling"
            )

        self.fc1 = nn.Linear(64 * conv_pool_n * conv_pool_n, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

        num_params = sum(p.numel() for p in self.parameters())
        print(f"Number of parameters in the model: {num_params}")

    @staticmethod
    def _conv_pool_output_size(size, kernel_size=5, pool_size=2):
        """Return the side length after conv1, conv2 and the max-pool in forward().

        Both convolutions are unpadded with stride 1, so each removes
        ``kernel_size - 1`` pixels; ``max_pool2d`` with kernel == stride ==
        ``pool_size`` then floor-divides. (The previous formula folded the
        second conv into the pool step and over-counted by one for odd sizes.)
        """
        size -= 2 * (kernel_size - 1)
        return size // pool_size

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        Expects ``x`` shaped ``(batch, 1, n, n)`` with the ``n`` given at
        construction; other spatial sizes will not match ``fc1``.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x