from typing import Literal
import torch.nn as nn
from torchinfo import summary


def build_model(padding_mode: Literal['zeros', 'reflect', 'replicate', 'circular'],
                output_classes: int = 26) -> nn.Sequential:
    """Build a small all-convolutional classifier for single-channel images.

    Six ``nn.Conv2d`` layers (each followed by ReLU) extract features; three of
    them use ``stride=2`` to halve the spatial size.  For a 28x28 input the
    feature map goes 28 -> 14 -> 7 -> 4.  ``AdaptiveAvgPool2d((1, 1))`` then
    reduces it to a 64-dim vector, so the head does not fix the input size.
    The head is a two-layer MLP (64 -> 50 -> ``output_classes``).

    Args:
        padding_mode: Padding mode forwarded to every ``nn.Conv2d``.
        output_classes: Number of logits produced by the final linear layer.

    Returns:
        An ``nn.Sequential`` mapping ``(N, 1, H, W)`` tensors to
        ``(N, output_classes)`` raw logits (no softmax is applied).

    Raises:
        ValueError: If ``padding_mode`` is not one of the supported modes.
    """
    valid_modes = ('zeros', 'reflect', 'replicate', 'circular')
    if padding_mode not in valid_modes:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            f"padding_mode must be one of {valid_modes}, got {padding_mode!r}"
        )

    model = nn.Sequential(
        # 28x28 -> 14x14: the strided 3x3 conv downsamples (in place of a MaxPool2d(2)).
        nn.Conv2d(1, 32, kernel_size=3, padding=1, padding_mode=padding_mode),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1, padding_mode=padding_mode, stride=2),
        nn.ReLU(),
        # 14x14 -> 7x7: padding=2 makes the 5x5 conv size-preserving, so the
        # strided conv that follows produces exactly 7x7.
        nn.Conv2d(32, 32, kernel_size=5, padding=2, padding_mode=padding_mode),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=5, padding=2, padding_mode=padding_mode, stride=2),
        nn.ReLU(),
        # 7x7 -> 4x4: floor((7 + 4 - 5) / 2) + 1 = 4.
        nn.Conv2d(32, 32, kernel_size=5, padding=2, padding_mode=padding_mode),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=5, padding=2, padding_mode=padding_mode, stride=2),
        nn.ReLU(),

        # Global average pool to (N, 64, 1, 1), then flatten to (N, 64).
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(64, 50),
        nn.ReLU(),
        nn.Linear(50, output_classes),
    )

    return model

if __name__ == '__main__':
    # Smoke run: construct the network once with zero padding, then discard it.
    build_model(padding_mode='zeros')