import torch
import torch.nn.functional as F

from pado.core import PadoModule
from pado.nn.functional import sync_batch_norm_func, masked_batch_norm_func, masked_sync_batch_norm_func
from pado.nn.parameter import ParameterModule, ParameterModuleWithOffset, BufferModule

# Public names exported by `from <module> import *`; `_BatchNorm` is intentionally not listed.
__all__ = [
    "BatchNorm",
    "MaskedBatchNorm",
]


class _BatchNorm(PadoModule):
    """Common configuration and state for the batch-norm layers of this module.

    Simplifications compared to the stock PyTorch layers:
    (1) a single class handles 1D, 2D, and 3D inputs.
    (2) running statistics are always tracked.
    (3) running statistics always follow an exponential moving average (EMA) with momentum.
    (4) SyncBN can optionally be requested (only with DDP, not for DP).
    (5) without GPU, Sync/Masked will be **DISABLED**. -> Should be fixed...

    This class does not inherit from nn.BatchNormNd, so be careful with 3rd party BN-related ops.

    :param num_features:    length of every per-channel tensor created here (weight, bias, running stats).
    :param eps:             stored as ``self.eps``; subclasses pass it to the normalization function.
    :param momentum:        stored as ``self.momentum``; subclasses pass it to the normalization function.
    :param use_scale:       if True, ``self.weight`` is a ParameterModuleWithOffset built from zeros with offset=1.0
                            (presumably an effective initial scale of 1 -- confirm in ParameterModuleWithOffset);
                            otherwise ``self.weight`` is None.
    :param use_offset:      if True, ``self.bias`` is a zero-initialized ParameterModule; otherwise None.
    :param sync_bn:         keyword-only flag stored as ``self.sync_bn``; subclasses read it to pick the
                            synchronized implementation.
    """

    def __init__(self,
                 num_features: int,
                 eps: float = 1e-5,
                 momentum: float = 0.1,
                 use_scale: bool = True,
                 use_offset: bool = True, *,
                 sync_bn: bool = False) -> None:
        super().__init__()

        # Plain (non-tensor) configuration.
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.use_scale = use_scale
        self.use_offset = use_offset
        self.sync_bn = sync_bn

        # Optional affine terms; None marks a disabled term.
        # Submodules are assigned in the order weight, bias, running_mean, running_var.
        self.weight = (
            ParameterModuleWithOffset(torch.zeros(num_features), offset=1.0)
            if use_scale else None
        )
        self.bias = ParameterModule(torch.zeros(num_features)) if use_offset else None

        # Running statistics start at mean 0 / variance 1.
        self.running_mean = BufferModule(torch.zeros(num_features))
        self.running_var = BufferModule(torch.ones(num_features))

    def forward(self, *args) -> torch.Tensor:
        """Abstract in the base class; always raises NotImplementedError."""
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Return the constructor summary; boolean options appear only when they differ from the default."""
        fields = [
            f"{self.num_features}",
            f"eps={self.eps}",
            f"momentum={self.momentum}",
        ]
        if not self.use_scale:
            fields.append("use_scale=False")
        if not self.use_offset:
            fields.append("use_offset=False")
        if self.sync_bn:
            fields.append("sync_bn=True")
        return ", ".join(fields)


class BatchNorm(_BatchNorm):
    """Batch normalization layer that casts its input to FP32 and optionally uses the sync implementation."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        BatchNorm forward.

        Dispatch:
          * eval mode, ``sync_bn`` disabled, or CPU input -> ``F.batch_norm``
            (called with the module's training flag).
          * training with ``sync_bn`` on a non-CPU device  -> ``sync_batch_norm_func``.

        :param x:       (batch_size, num_channels, ...)
        :return:
                y:      (batch_size, num_channels, ...)
        """
        # BN should always run in FP32.
        x = x.float()

        is_training = self.training

        # Running statistics are always fetched: this layer always tracks them.
        # clamp_min_ works in place on whatever tensor running_var() returns, flooring it at 1e-6.
        # NOTE(review): if BufferModule returns the stored buffer itself, the stored value is floored too -- confirm.
        running_mean = self.running_mean()
        running_var = self.running_var().clamp_min_(1e-6)

        # Affine terms are optional; a missing term is passed on as None.
        weight = None if self.weight is None else self.weight()
        bias = None if self.bias is None else self.bias()

        on_cpu = x.device == torch.device("cpu")
        if is_training and self.sync_bn and not on_cpu:
            # sync_bn and self.training and GPU
            return sync_batch_norm_func(
                x,
                weight,
                bias,
                running_mean,
                running_var,
                self.momentum,
                self.eps
            )

        return F.batch_norm(
            x,
            running_mean,
            running_var,
            weight,
            bias,
            is_training,
            self.momentum,
            self.eps
        )


class MaskedBatchNorm(_BatchNorm):
    """Batch normalization layer that forwards a mask to the masked (optionally synchronized) implementation."""

    def forward(self,
                x: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """
        MaskedBN forward.
        Masking is only required during training (=training stat update stage).

        Dispatch:
          * eval mode or CPU input            -> ``F.batch_norm``; the mask is validated but NOT used.
          * training, non-CPU, no ``sync_bn`` -> ``masked_batch_norm_func``.
          * training, non-CPU, ``sync_bn``    -> ``masked_sync_batch_norm_func``.

        Accepted mask layouts:
          * 3D input (b, c, s):     mask (b, 1, s), or (b, s) which is unsqueezed to (b, 1, s).
          * 4D input (b, c, h, w):  mask (b, 1, h, w), or a 2D mask that is reshaped to (., 1, ., 1) and
                                    expanded to (b, 1, h, w) (the time axis is assumed to be "h").

        :param x:           (batch_size, num_features, ...)
        :param mask:        (batch_size, 1, ...)
        :return:
                y:          (batch_size, num_features, ...)
        :raises ValueError:             if the mask cannot be brought to the expected shape.
        :raises NotImplementedError:    if ``x`` is neither 3D nor 4D.
        """
        x = x.float()  # BN should always run in FP32.

        # ------------------------------------------------------------------------------------------------ #
        # training flag
        bn_training = self.training

        # ------------------------------------------------------------------------------------------------ #
        # bring the mask to (b, 1, ...) and validate it
        if x.ndim == 3:
            b, _, s = x.shape
            if mask.ndim == 2:
                mask = mask.unsqueeze(1)  # (b, s) -> (b, 1, s)
            if mask.shape != (b, 1, s):
                raise ValueError(f"MaskedBN1D require same mask as input {tuple(x.shape)}, got {tuple(mask.shape)}.")
        elif x.ndim == 4:
            b, _, h, w = x.shape
            # we assume that time dimension is at "h" dimension (3rd dim.)
            if mask.ndim == 2:
                try:
                    # (b, h) -> (b, 1, h, 1) -> (b, 1, h, w)
                    mask = mask.unsqueeze(1).unsqueeze(-1).expand(b, 1, h, w)
                except RuntimeError as err:
                    # expand() rejects incompatible sizes with a RuntimeError; report it as the same
                    # ValueError the explicit shape check below (and the 3D path) would raise.
                    raise ValueError(
                        f"MaskedBN2D require same mask as input {tuple(x.shape)}, got {tuple(mask.shape)}."
                    ) from err
            if mask.shape != (b, 1, h, w):
                raise ValueError(f"MaskedBN2D require same mask as input {tuple(x.shape)}, got {tuple(mask.shape)}.")
        else:
            raise NotImplementedError(f"MaskedBN only takes 3D or 4D input, but got {tuple(x.shape)}.")

        # ------------------------------------------------------------------------------------------------ #
        # Running statistics are always fetched: this layer always tracks them.
        # clamp_min_ works in place on whatever tensor running_var() returns, flooring it at 1e-6.
        # NOTE(review): if BufferModule returns the stored buffer itself, the stored value is floored too -- confirm.
        running_mean = self.running_mean()
        running_var = self.running_var().clamp_min_(1e-6)

        # Affine terms are optional; a missing term is passed on as None.
        weight = self.weight() if (self.weight is not None) else None
        bias = self.bias() if (self.bias is not None) else None

        if (not bn_training) or (x.device == torch.device("cpu")):
            # Fallback without masking. Note that on CPU this also covers training mode.
            return F.batch_norm(
                x,
                running_mean,
                running_var,
                weight,
                bias,
                bn_training,
                self.momentum,
                self.eps
            )

        # self.training and GPU: choose the plain or the synchronized masked implementation.
        masked_func = masked_sync_batch_norm_func if self.sync_bn else masked_batch_norm_func
        return masked_func(
            x,
            weight,
            bias,
            mask,
            running_mean,
            running_var,
            self.momentum,
            self.eps
        )
