from typing import Literal
import torch.nn as nn
from torchinfo import summary


def build_model(
    padding_mode: Literal['zeros', 'reflect', 'replicate', 'circular'],
    output_classes: int = 10,
    *,
    show_summary: bool = True,
) -> nn.Sequential:
    """Build a small all-convolutional classifier for single-channel 28x28 images.

    Nine ``Conv2d`` layers (every padded one uses ``padding_mode``) are followed by
    a two-layer MLP head. Downsampling is done by the stride-2 and stride-4
    convolutions. The spatial arithmetic is tuned so that a 1x28x28 input is
    reduced to a 64x1x1 feature map right before ``Flatten``, which is exactly the
    64 features the first ``Linear`` expects. Many other input sizes will fail
    at that layer.

    Args:
        padding_mode: Padding mode forwarded to every padded ``Conv2d``.
        output_classes: Number of output units of the final ``Linear`` layer.
        show_summary: If True (the default, matching the previous behaviour),
            print a ``torchinfo`` summary for a batch of one 1x28x28 image.

    Returns:
        An ``nn.Sequential`` mapping ``(N, 1, 28, 28)`` inputs to
        ``(N, output_classes)`` unnormalized scores (no softmax is applied).

    Raises:
        ValueError: If ``padding_mode`` is not one of the supported modes.
    """
    valid_modes = ('zeros', 'reflect', 'replicate', 'circular')
    # An explicit raise (rather than `assert`) keeps this check active under `python -O`.
    if padding_mode not in valid_modes:
        raise ValueError(f"padding_mode must be one of {valid_modes}, got {padding_mode!r}")

    # Spatial side length after each conv: out = floor((in + 2*padding - kernel) / stride) + 1
    model = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, padding=3, padding_mode=padding_mode),  # 28 -> 32
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=3, padding_mode=padding_mode, stride=2),  # 32 -> 18
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=3, padding_mode=padding_mode),  # 18 -> 22
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=3, padding_mode=padding_mode, stride=2),  # 22 -> 13
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=3, padding_mode=padding_mode),  # 13 -> 17
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=3, padding_mode=padding_mode, stride=2),  # 17 -> 11
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=3, padding_mode=padding_mode),  # 11 -> 15
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=5, padding=5, padding_mode=padding_mode, stride=4),  # 15 -> 6
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=6),  # 6 -> 1, giving a 64x1x1 feature map
        nn.Flatten(),  # (N, 64, 1, 1) -> (N, 64)
        nn.Linear(64, 50),
        nn.ReLU(),
        nn.Linear(50, output_classes),
    )

    if show_summary:
        # torchinfo's input_size must include the batch dimension. A bare (1, 28, 28)
        # would run as an unbatched 1x28x28 image, so Flatten would emit (64, 1),
        # which Linear(64, 50) cannot consume.
        summary(model, input_size=(1, 1, 28, 28), device='cpu')
    return model


if __name__ == '__main__':
    # Script entry point: construct the zero-padded variant of the model.
    build_model(padding_mode='zeros')