from typing import Any

import torch
from torchinfo import summary
import torch.nn as nn

import pytorch_lightning as pl


def build_model(padding_mode: str, output_classes: int = 10) -> nn.Sequential:
    """Build an unpadded convolutional classifier for 1x28x28 inputs.

    The network stacks nine 32-channel ``Conv2d`` + ``ReLU`` pairs (five with
    a 3x3 kernel, then four with a 5x5 kernel), a final 2x2 convolution that
    brings the feature map down to 64x1x1, and a two-layer fully connected
    head. A torchinfo summary for a single 1x28x28 sample is produced on the
    CPU before the model is returned.

    Args:
        padding_mode: Unused; accepted only so the signature stays compatible
            with other model builders.
        output_classes: Number of output units of the final linear layer.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    # Kernel sizes of the hidden convolutions. Without padding and with the
    # default stride of 1, each one shrinks the spatial side by
    # (kernel_size - 1): 28 -> 26 -> 24 -> 22 -> 20 -> 18 -> 14 -> 10 -> 6 -> 2.
    hidden_kernel_sizes = (3, 3, 3, 3, 3, 5, 5, 5, 5)

    layers = []
    in_channels = 1  # single-channel input image
    for kernel_size in hidden_kernel_sizes:
        layers.append(nn.Conv2d(in_channels, 32, kernel_size=kernel_size))
        layers.append(nn.ReLU())
        in_channels = 32

    layers.extend([
        nn.Conv2d(32, 64, kernel_size=2),  # 2 - 2 + 1 = 1 -> 64 x 1 x 1
        nn.Flatten(),
        nn.Linear(64, 50),
        nn.ReLU(),
        nn.Linear(50, output_classes),
    ])
    model = nn.Sequential(*layers)

    # torchinfo expects input_size to include the batch dimension. Without it
    # the sample is treated as unbatched, Flatten() keeps the 64 channels on
    # dim 0 and the first Linear layer fails with a shape mismatch.
    summary(model, (1, 1, 28, 28), device='cpu')

    return model

