from typing import Tuple, Union
import torch

from pado.core import PadoModule
from pado.nn.modules.activation import get_activation_cls
from pado.nn.modules.conv import Conv2d
from pado.nn.modules.batchnorm import BatchNorm

__all__ = ["ConvBatchNorm2d", "ConvBatchNormAct2d"]


class ConvBatchNorm2d(PadoModule):
    """2D convolution (built with ``bias=False``) followed by batch normalization.

    Convolution hyper-parameters are forwarded to ``Conv2d``; normalization
    hyper-parameters are forwarded to ``BatchNorm`` over ``out_channels``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, ...]],
                 stride: Union[int, Tuple[int, ...]] = 1,
                 padding: Union[int, Tuple[int, ...]] = 0,
                 dilation: Union[int, Tuple[int, ...]] = 1,
                 groups: int = 1,
                 eps: float = 1e-5,
                 momentum: float = 0.1,
                 use_scale: bool = True,
                 use_offset: bool = True, *,
                 padding_mode: str = "zeros",
                 sync_bn: bool = False,
                 partial: bool = False) -> None:
        """
        :param in_channels:     number of input channels of the convolution.
        :param out_channels:    number of output channels (also the BatchNorm width).
        :param kernel_size:     convolution kernel size.
        :param stride:          convolution stride.
        :param padding:         convolution padding.
        :param dilation:        convolution dilation.
        :param groups:          convolution groups.
        :param eps:             BatchNorm epsilon.
        :param momentum:        BatchNorm momentum.
        :param use_scale:       forwarded to BatchNorm as ``use_scale``.
        :param use_offset:      forwarded to BatchNorm as ``use_offset``.
        :param padding_mode:    convolution padding mode (keyword-only).
        :param sync_bn:         forwarded to BatchNorm as ``sync_bn`` (keyword-only).
        :param partial:         forwarded to Conv2d as ``partial`` (keyword-only).
        """
        super().__init__()

        # The convolution carries no bias; presumably the BatchNorm offset
        # (use_offset) is meant to provide the additive term instead.
        conv_options = dict(groups=groups, bias=False,
                            padding_mode=padding_mode, partial=partial)
        bn_options = dict(eps=eps, momentum=momentum, use_scale=use_scale,
                          use_offset=use_offset, sync_bn=sync_bn)

        # conv is registered before bn, keeping the original submodule order.
        self.conv = Conv2d(in_channels, out_channels, kernel_size,
                           stride, padding, dilation, **conv_options)
        self.bn = BatchNorm(out_channels, **bn_options)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x:           (batch_size, in_channels, h, w)
        :return:
                result:     (batch_size, out_channels, out_h, out_w)
        """
        # Only ``x`` is passed on; a mask argument is intentionally not exposed.
        return self.bn(self.conv(x))


class ConvBatchNormAct2d(ConvBatchNorm2d):
    """:class:`ConvBatchNorm2d` followed by an activation module.

    All positional and keyword arguments other than ``act_type`` are forwarded
    unchanged to :class:`ConvBatchNorm2d`.
    """

    def __init__(self, *args, act_type: str = "relu", **kwargs) -> None:
        """
        :param args:        positional arguments forwarded to ConvBatchNorm2d.
        :param act_type:    activation name passed to ``get_activation_cls``
                            (keyword-only, default ``"relu"``).
        :param kwargs:      keyword arguments forwarded to ConvBatchNorm2d.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): the result of get_activation_cls(..., inplace=True) is
        # stored and later called on tensors, so it is expected to be a module
        # instance rather than a class - confirm against its implementation.
        self.act = get_activation_cls(act_type, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x:           (batch_size, in_channels, h, w)
        :return:
                result:     (batch_size, out_channels, out_h, out_w)
        """
        # Reuse the parent's conv -> bn pipeline instead of duplicating it, so
        # any change made there is picked up here automatically.
        # As in the parent, no mask argument is exposed.
        x = super().forward(x)
        # The activation was requested with inplace=True and may modify x in place.
        return self.act(x)
