import torch
import torch.nn as nn

from pado.core import PadoModuleMixin, PadoModule

__all__ = ["MaxPool2d", "AvgPool2d", "GlobalAvgPool"]


class MaxPool2d(nn.MaxPool2d, PadoModuleMixin):
    """2D max pooling layer combining ``torch.nn.MaxPool2d`` with ``PadoModuleMixin``.

    All constructor arguments are forwarded unchanged to ``nn.MaxPool2d``.
    ``return_indices`` is not exposed here, so it keeps the torch default.
    """

    def __init__(
        self,
        kernel_size,
        stride=None,
        padding=0,
        dilation=1,
        ceil_mode: bool = False,
    ) -> None:
        # Each base is initialised explicitly (not through ``super()``):
        # the torch layer first, then the mixin.
        nn.MaxPool2d.__init__(
            self,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            ceil_mode=ceil_mode,
        )
        PadoModuleMixin.__init__(self)


class AvgPool2d(nn.AvgPool2d, PadoModuleMixin):
    """2D average pooling layer combining ``torch.nn.AvgPool2d`` with ``PadoModuleMixin``.

    All constructor arguments are forwarded unchanged to ``nn.AvgPool2d``.
    ``divisor_override`` is not exposed here, so it keeps the torch default.
    """

    def __init__(
        self,
        kernel_size,
        stride=None,
        padding=0,
        ceil_mode: bool = False,
        count_include_pad: bool = True,
    ) -> None:
        # Each base is initialised explicitly (not through ``super()``):
        # the torch layer first, then the mixin.
        nn.AvgPool2d.__init__(
            self,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            ceil_mode=ceil_mode,
            count_include_pad=count_include_pad,
        )
        PadoModuleMixin.__init__(self)


class GlobalAvgPool(PadoModule):
    """Average a tensor over every dimension after the first two.

    ``keepdim`` selects whether the reduced dimensions are retained with
    size 1 or dropped from the output.
    """

    def __init__(self, keepdim: bool = False) -> None:
        super().__init__()
        self.keepdim = keepdim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Global average pooling over all trailing dimensions.

        :param x:       (batch_size, num_channels, ...)
        :return:
                y:      (batch_size, num_channels)          if keepdim = False
                        (batch_size, num_channels, 1, ...)  if keepdim = True

        Inputs with two or fewer dimensions have nothing to pool and are
        returned unchanged (the same tensor object, regardless of keepdim).
        """
        rank = x.dim()
        if rank > 2:
            # Reduce over dims 2 .. rank-1, i.e. everything after
            # the first two dimensions.
            trailing_dims = list(range(2, rank))
            return x.mean(dim=trailing_dims, keepdim=self.keepdim)
        return x
