# -*- coding: utf-8 -*-
import typing
from collections import OrderedDict

import torch


def make_conv_block(in_channels: int = 3, dropout: float = 0.0, num_groups: int = 4,
                    activation=torch.nn.Tanh,
                    pool_kernel_size: typing.Union[int, typing.Tuple[int, int]] = (2, 2)) -> torch.nn.Module:
    """Build the LeNet-5 convolutional feature extractor.

    Two stages of ``Conv2d(kernel_size=5, stride=1)`` mapping
    ``in_channels -> 6 -> 16`` channels. Each stage is, in order: conv,
    optional GroupNorm, optional Dropout2d, activation, MaxPool2d.

    Args:
        in_channels: Number of channels of the input images.
        dropout: Dropout2d probability applied after each conv/norm;
            ``0`` disables dropout.
        num_groups: Requested number of GroupNorm groups; ``<= 0`` disables
            normalisation. GroupNorm requires the group count to divide the
            channel count, so each stage uses the largest divisor of its
            channel count that does not exceed ``num_groups``. For example,
            the default 4 yields 3 groups for the 6-channel stage and 4 for
            the 16-channel stage.
        activation: Zero-argument callable, typically a ``torch.nn`` module
            class, that returns the activation module.
        pool_kernel_size: Kernel size passed to ``torch.nn.MaxPool2d``.

    Returns:
        A ``torch.nn.Sequential`` whose children are named ``conv-i``,
        ``norm-i`` (if enabled), ``dropout-i`` (if enabled), ``activate-i``
        and ``pool-i`` for stage ``i``.
    """
    block = OrderedDict()

    channels = [in_channels, 6, 16]
    for i, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
        block[f"conv-{i}"] = torch.nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                             kernel_size=5, stride=1)
        if num_groups > 0:
            # GroupNorm raises ValueError unless the group count divides the
            # channel count (e.g. 4 groups over 6 channels), so clamp the
            # request to the largest valid divisor.
            groups = max(g for g in range(1, min(num_groups, c_out) + 1) if c_out % g == 0)
            block[f"norm-{i}"] = torch.nn.GroupNorm(groups, c_out)
        if dropout > 0.0:
            block[f"dropout-{i}"] = torch.nn.Dropout2d(p=dropout)
        block[f"activate-{i}"] = activation()
        block[f"pool-{i}"] = torch.nn.MaxPool2d(pool_kernel_size)

    return torch.nn.Sequential(block)


def make_linear_block(in_features: int = 16*5*5, out_features: int = 10,
                      features: typing.Sequence[int] = (120, 84),
                      dropout: float = 0.0, num_groups: int = 4,
                      activation=torch.nn.Tanh) -> torch.nn.Module:
    """Build the fully connected head of LeNet-5.

    Hidden layers map ``in_features -> features[0] -> ... -> features[-1]``.
    Each hidden layer is, in order: Linear, optional GroupNorm, optional
    Dropout, activation. A final ``Linear(features[-1], out_features)`` named
    ``output`` follows, with no activation.

    Args:
        in_features: Width of the flattened input vectors.
        out_features: Width of the output layer (e.g. number of classes).
        features: Widths of the hidden layers; may be empty.
        dropout: Element-wise dropout probability after each hidden layer;
            ``0`` disables dropout.
        num_groups: Requested number of GroupNorm groups; ``<= 0`` disables
            normalisation. Each layer uses the largest divisor of its width
            that does not exceed ``num_groups``, because GroupNorm requires
            the group count to divide the channel count.
        activation: Zero-argument callable, typically a ``torch.nn`` module
            class, that returns the activation module.

    Returns:
        A ``torch.nn.Sequential`` with children ``linear-i``, ``norm-i``
        (if enabled), ``dropout-i`` (if enabled), ``activate-i`` and
        ``output``.
    """
    block = OrderedDict()

    widths = [in_features, *features]
    for i, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
        block[f"linear-{i}"] = torch.nn.Linear(n_in, n_out)
        if num_groups > 0:
            # GroupNorm raises ValueError unless the group count divides the
            # width, so clamp the request to the largest valid divisor.
            groups = max(g for g in range(1, min(num_groups, n_out) + 1) if n_out % g == 0)
            block[f"norm-{i}"] = torch.nn.GroupNorm(groups, n_out)
        if dropout > 0.0:
            # Plain element-wise dropout for (N, F) activations. Dropout2d
            # expects (N, C, H, W) / (C, H, W) inputs; on a 2-D tensor PyTorch
            # warns that this is deprecated and treats it as unbatched, which
            # zeroes whole samples instead of individual features.
            block[f"dropout-{i}"] = torch.nn.Dropout(p=dropout)
        block[f"activate-{i}"] = activation()

    block["output"] = torch.nn.Linear(widths[-1], out_features)

    return torch.nn.Sequential(block)


class LeNet5(torch.nn.Module):
    """LeNet-5 style image classifier.

    A two-stage convolutional feature extractor (``make_conv_block``) is
    followed by a fully connected head (``make_linear_block``) with hidden
    widths 120 and 84.

    Args:
        in_features: Number of input image channels (e.g. 3 for RGB).
        num_classes: Width of the output score vector.
        activation: Activation module class, or the name of one in
            ``torch.nn`` (e.g. ``"ReLU"``). It is instantiated with no
            arguments; an unknown name raises ``AttributeError``.
        conv_dropout: Dropout probability in the conv block (0 disables it).
        linear_dropout: Dropout probability in the linear block (0 disables it).
        conv_num_groups: GroupNorm groups in the conv block (0 disables it).
        linear_num_groups: GroupNorm groups in the linear block (0 disables it).
        input_size: Spatial size ``(H, W)`` of the input images, or a single
            int for square images. It is used only to size the first linear
            layer. The default ``(32, 32)`` gives the classic 16*5*5 = 400
            flattened features.
    """

    def __init__(self, in_features: int = 3, num_classes: int = 10,
                 activation: typing.Union[str, typing.Callable] = torch.nn.Tanh,
                 conv_dropout: float = 0.0, linear_dropout: float = 0.0,
                 conv_num_groups: int = 4, linear_num_groups: int = 0,
                 input_size: typing.Union[int, typing.Tuple[int, int]] = (32, 32)):
        super().__init__()
        if isinstance(activation, str):
            activation = getattr(torch.nn, activation)
        if isinstance(input_size, int):
            input_size = (input_size, input_size)

        self.conv_block = make_conv_block(in_channels=in_features, dropout=conv_dropout, num_groups=conv_num_groups,
                                          activation=activation, pool_kernel_size=(2, 2))

        flat_features = self._infer_flat_features(in_features, input_size)
        self.linear_block = make_linear_block(in_features=flat_features, out_features=num_classes, features=[120, 84],
                                              dropout=linear_dropout, num_groups=linear_num_groups, activation=activation)

    def _infer_flat_features(self, in_channels: int, input_size: typing.Tuple[int, int]) -> int:
        """Return the per-sample number of features ``self.conv_block`` emits
        for a ``(in_channels, *input_size)`` image.

        Raises whatever the conv block raises if ``input_size`` is too small.
        """
        was_training = self.conv_block.training
        # eval() keeps dropout layers from drawing on the global RNG, so the
        # linear layers created next are initialised exactly as without a probe.
        self.conv_block.eval()
        try:
            with torch.no_grad():
                probe = self.conv_block(torch.zeros(1, in_channels, *input_size))
        finally:
            self.conv_block.train(was_training)
        return probe.numel()

    def forward(self, x):
        """Map an image batch ``(N, in_features, H, W)`` to unnormalised
        class scores ``(N, num_classes)``; no softmax is applied."""
        x = self.conv_block(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear_block(x)
        return x
