from typing import Literal, Optional

import torch
import torch.nn as nn


class Block(nn.Module):
    """Pre-activation residual block.

    Layout of the residual branch:
    BN -> Mish -> conv3x3(stride) -> BN -> Mish -> [Dropout] -> conv3x3.

    The first BN -> Mish pair (``layers_in``) is computed once. Its output feeds both the
    convolutional branch (``layers_out``) and, when one exists, the 1x1 projection
    ``shortcut``. When no projection is needed, the raw (un-normalized) input is added back
    unchanged.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the block.
        stride: Stride of the first 3x3 convolution and of the shortcut projection.
        momentum: Momentum forwarded to both ``nn.BatchNorm2d`` layers.
        dropout: Dropout probability applied between the two convolutions; no dropout
            layer is inserted when ``dropout <= 0``.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int, momentum: float, dropout: float):
        super(Block, self).__init__()
        self.layers_in, self.layers_out = nn.Sequential(), nn.Sequential()
        self.layers_in.append(nn.BatchNorm2d(in_channels, momentum=momentum))
        self.layers_in.append(nn.Mish())
        self.layers_out.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False))
        self.layers_out.append(nn.BatchNorm2d(out_channels, momentum=momentum))
        self.layers_out.append(nn.Mish())
        if dropout > 0.0:
            self.layers_out.append(nn.Dropout(p=dropout))
        self.layers_out.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False))
        # The identity can only be added directly when neither the channel count nor the
        # spatial size changes. A strided block with equal channel counts still needs a
        # projection, otherwise the residual add fails on mismatched H/W.
        needs_projection = in_channels != out_channels or stride != 1
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
            if needs_projection
            else None
        )

    def forward(self, input):
        """Return ``residual + layers_out(layers_in(input))``.

        ``residual`` is ``input`` itself when there is no projection. Otherwise it is
        ``shortcut(layers_in(input))``.
        """
        activated = self.layers_in(input)
        residual = input if self.shortcut is None else self.shortcut(activated)
        return torch.add(residual, self.layers_out(activated))


class Group(nn.Module):
    """A stage of ``num_blocks`` residual ``Block`` modules applied in sequence.

    Only the first block maps ``in_channels`` -> ``out_channels`` and applies ``stride``.
    Every following block maps ``out_channels`` -> ``out_channels`` with stride 1.

    Args:
        num_blocks: Number of residual blocks in the stage.
        in_channels: Channel count entering the stage.
        out_channels: Channel count produced by every block of the stage.
        stride: Stride of the first block.
        momentum: BatchNorm momentum forwarded to every block.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(self, num_blocks: int, in_channels: int, out_channels: int, stride: int, momentum: float, dropout: float):
        super(Group, self).__init__()
        specs = self._block_specs(num_blocks, in_channels, out_channels, stride)
        self.blocks = nn.Sequential(*[
            Block(block_in, block_out, block_stride, momentum, dropout)
            for block_in, block_out, block_stride in specs
        ])

    @staticmethod
    def _block_specs(num_blocks: int, in_channels: int, out_channels: int, stride: int) -> list:
        """Return one ``(in_channels, out_channels, stride)`` tuple per block of the stage."""
        # Conditional expressions (not `cond and a or b`), so falsy values are never swapped out.
        return [
            (in_channels if i == 0 else out_channels, out_channels, stride if i == 0 else 1)
            for i in range(num_blocks)
        ]

    def forward(self, input):
        """Run ``input`` through every block of the stage in order."""
        return self.blocks(input)


class WideResNet(nn.Module):
    """Wide residual network feature extractor built from pre-activation ``Group`` stages.

    Layout:
    conv3x3(image_channels -> channels[0]) -> ``num_groups`` groups -> BN -> Mish
    -> optional global pooling -> flatten.

    Output shape:
    - ``(N, channels[-1])`` when ``global_pool`` is set.
    - ``(N, C * H * W)`` (the flattened final feature map) when ``global_pool`` is None.

    Channel widths:
    - ``channels[0] = factor_base``.
    - ``channels[i] = factor_base * 2 ** (i - 1) * factor_widen`` for ``1 <= i <= num_groups``.

    The first group runs with stride 1; every later group downsamples with stride 2.

    Args:
        num_groups: Number of residual stages.
        num_blocks: Number of residual blocks per stage.
        factor_base: Width of the stem convolution and base width of the stages.
        factor_widen: Widening factor applied to every stage.
        image_channels: Channel count of the input images.
        momentum: BatchNorm momentum used throughout the network.
        dropout: Dropout probability inside each residual block (disabled when <= 0).
        global_pool: ``'avg'`` or ``'max'`` for global pooling before flattening,
            or ``None`` to flatten the full feature map.

    Raises:
        ValueError: If ``global_pool`` is not ``None``, ``'avg'`` or ``'max'``.
    """

    def __init__(self, num_groups: int, num_blocks: int, factor_base: int, factor_widen: int, image_channels: int, momentum: float = 0.1, dropout: float = 0.0, global_pool: Optional[Literal['avg', 'max']] = None):
        super(WideResNet, self).__init__()

        # Validate up front: an unrecognised value used to silently disable pooling,
        # changing the output feature size without any error.
        pool_layers = {None: None, 'avg': nn.AdaptiveAvgPool2d, 'max': nn.AdaptiveMaxPool2d}
        if global_pool not in pool_layers:
            raise ValueError(f"global_pool must be None, 'avg' or 'max', got {global_pool!r}")

        self.channels = [factor_base] + [factor_base * 2 ** (i - 1) * factor_widen for i in range(1, num_groups + 1)]
        self.networks = nn.Sequential(nn.Conv2d(image_channels, self.channels[0], kernel_size=3, stride=1, padding=1, bias=False))

        for i in range(num_groups):
            # The first stage keeps the stem's resolution; later stages downsample.
            stride = 1 if i == 0 else 2
            self.networks.append(Group(num_blocks, self.channels[i], self.channels[i + 1], stride, momentum, dropout))

        # Final pre-activation blocks leave their output un-normalized, so close with BN + Mish.
        self.networks.append(nn.BatchNorm2d(self.channels[-1], momentum=momentum))
        self.networks.append(nn.Mish())
        pool_cls = pool_layers[global_pool]
        if pool_cls is not None:
            self.networks.append(pool_cls(1))
        self.networks.append(nn.Flatten())

    def forward(self, input):
        """Return the flattened features for ``input``; see the class docstring for the shape."""
        return self.networks(input)
