import torch
import torch.nn.functional as F
from pado.core import PadoModule
from pado.nn.parameter import ParameterModule, ParameterModuleWithOffset

# Public API of this module: the normalization layers defined below.
__all__ = ["LayerNorm", "GroupNorm", "GroupLayerNorm"]


class LayerNorm(PadoModule):
    """Layer normalization over the trailing ``normalized_shape`` dimensions.

    Delegates the computation to ``torch.nn.functional.layer_norm``; the
    optional scale (``weight``) and offset (``bias``) are held in parameter
    modules that are called at forward time to obtain the tensors.

    Args:
        normalized_shape: Size(s) of the trailing dimension(s) to normalize
            over. A single ``int`` is treated as a one-element shape.
        eps: Passed to ``F.layer_norm`` as ``eps``.
        use_scale: When true, a ``weight`` parameter module is created;
            otherwise ``weight`` is ``None``.
        use_offset: When true, a ``bias`` parameter module is created;
            otherwise ``bias`` is ``None``.
    """

    def __init__(self,
                 normalized_shape,
                 eps: float = 1e-5,
                 use_scale: bool = True,
                 use_offset: bool = True) -> None:
        super().__init__()
        # Promote a bare int to a one-element shape.
        shape = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.normalized_shape = list(shape)
        self.eps = eps
        self.use_scale = use_scale
        self.use_offset = use_offset

        # The scale is stored as zeros with offset=1.0 -- presumably the
        # effective scale is "stored value + 1", i.e. identity at init;
        # verify against ParameterModuleWithOffset.
        self.weight = (
            ParameterModuleWithOffset(torch.zeros(shape), offset=1.0)
            if self.use_scale
            else None
        )
        self.bias = ParameterModule(torch.zeros(shape)) if self.use_offset else None

    def extra_repr(self) -> str:
        """Return the summary string: shape, eps, and any disabled options."""
        parts = [f"{self.normalized_shape}", f"eps={self.eps}"]
        if not self.use_scale:
            parts.append("use_scale=False")
        if not self.use_offset:
            parts.append("use_offset=False")
        return ", ".join(parts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply layer normalization to ``x`` and return the result."""
        # Materialize the affine tensors only for the parts that are enabled.
        scale = None if self.weight is None else self.weight()
        shift = None if self.bias is None else self.bias()
        return F.layer_norm(x, self.normalized_shape, weight=scale, bias=shift, eps=self.eps)


class GroupNorm(PadoModule):
    """Group normalization backed by ``torch.nn.functional.group_norm``.

    The optional per-channel scale (``weight``) and offset (``bias``) are held
    in parameter modules that are called at forward time to obtain the
    tensors.

    Args:
        num_groups: Number of groups the channels are split into.
        num_channels: Number of channels; also the length of ``weight`` and
            ``bias`` when they are enabled.
        eps: Passed to ``F.group_norm`` as ``eps``.
        use_scale: When true, a ``weight`` parameter module is created;
            otherwise ``weight`` is ``None``.
        use_offset: When true, a ``bias`` parameter module is created;
            otherwise ``bias`` is ``None``.

    Raises:
        ValueError: If ``num_groups`` is not positive, or ``num_channels`` is
            not divisible by ``num_groups``.
    """

    def __init__(self,
                 num_groups: int,
                 num_channels: int,
                 eps: float = 1e-5,
                 use_scale: bool = True,
                 use_offset: bool = True) -> None:
        super().__init__()
        # Fail at construction time (as GroupLayerNorm does) rather than with
        # a less specific error on the first forward pass.
        if num_groups <= 0:
            raise ValueError(f"GroupNorm num_groups must be positive, got {num_groups}.")
        if num_channels % num_groups != 0:
            raise ValueError(f"GroupNorm num_channels {num_channels} is not divisible by groups {num_groups}.")

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.use_scale = use_scale
        self.use_offset = use_offset

        if self.use_scale:
            # Stored as zeros with offset=1.0 -- presumably the effective
            # scale is "stored value + 1", i.e. identity at init; verify
            # against ParameterModuleWithOffset.
            self.weight = ParameterModuleWithOffset(torch.zeros(num_channels), offset=1.0)
        else:
            self.weight = None

        if self.use_offset:
            self.bias = ParameterModule(torch.zeros(num_channels))
        else:
            self.bias = None

    def extra_repr(self) -> str:
        """Return the summary string: groups, channels, eps, disabled options."""
        s = f"{self.num_groups}, {self.num_channels}, eps={self.eps}"
        if not self.use_scale:
            s += ", use_scale=False"
        if not self.use_offset:
            s += ", use_offset=False"
        return s

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply group normalization to ``x`` and return the result."""
        # Materialize the affine tensors only for the parts that are enabled.
        weight = self.weight() if (self.weight is not None) else None
        bias = self.bias() if (self.bias is not None) else None
        return F.group_norm(x, self.num_groups, weight, bias, eps=self.eps)


class GroupLayerNorm(PadoModule):
    """Layer normalization applied independently to groups of the last dim.

    The last dimension (of size ``normalize_dim``) is split into
    ``num_groups`` contiguous chunks of ``normalize_dim // num_groups``
    elements. Each chunk is normalized on its own with ``F.layer_norm``
    (without affine parameters), the chunks are merged back, and the optional
    scale and offset of length ``normalize_dim`` are applied.

    Args:
        normalize_dim: Size of the last input dimension. Unlike ``LayerNorm``,
            only a single dimension is accepted.
        num_groups: Number of groups the last dimension is split into.
        eps: Passed to ``F.layer_norm`` as ``eps``.
        use_scale: When true, a ``weight`` parameter module is created;
            otherwise ``weight`` is ``None``.
        use_offset: When true, a ``bias`` parameter module is created;
            otherwise ``bias`` is ``None``.

    Raises:
        ValueError: If ``num_groups`` is not positive, or ``normalize_dim`` is
            not divisible by ``num_groups``.
    """

    def __init__(self,
                 normalize_dim: int,  # different to LayerNorm, it only accepts 1-dim
                 num_groups: int,
                 eps: float = 1e-5,
                 use_scale: bool = True,
                 use_offset: bool = True) -> None:
        super().__init__()
        # Reject non-positive group counts explicitly: zero would otherwise
        # surface as ZeroDivisionError and a negative count would yield a
        # negative dim_per_group.
        if num_groups <= 0:
            raise ValueError(f"GroupLayerNorm num_groups must be positive, got {num_groups}.")
        if normalize_dim % num_groups != 0:
            raise ValueError(f"GroupLayerNorm normalize_dim {normalize_dim} is not divisible by groups {num_groups}.")

        self.normalize_dim = normalize_dim
        self.num_groups = num_groups
        self.dim_per_group = normalize_dim // num_groups

        self.eps = eps
        self.use_scale = use_scale
        self.use_offset = use_offset

        if self.use_scale:
            # Stored as zeros with offset=1.0 -- presumably the effective
            # scale is "stored value + 1", i.e. identity at init; verify
            # against ParameterModuleWithOffset.
            self.weight = ParameterModuleWithOffset(torch.zeros(normalize_dim), offset=1.0)
        else:
            self.weight = None

        if self.use_offset:
            self.bias = ParameterModule(torch.zeros(normalize_dim))
        else:
            self.bias = None

    def extra_repr(self) -> str:
        """Return the summary string: dim, groups, eps, disabled options."""
        s = f"{self.normalize_dim}, {self.num_groups}, eps={self.eps}"
        if not self.use_scale:
            s += ", use_scale=False"
        if not self.use_offset:
            s += ", use_offset=False"
        return s

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize each group of the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``normalize_dim``.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` has no dimensions or its last dimension is
                not ``normalize_dim``.
        """
        x_shape = tuple(x.shape)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not x_shape or x_shape[-1] != self.normalize_dim:
            raise ValueError(
                f"GroupLayerNorm expected an input whose last dimension is {self.normalize_dim}, "
                f"got shape {x_shape}."
            )

        # ``reshape`` also accepts non-contiguous inputs (e.g. transposed
        # tensors) where ``view`` raises. Spelling out ``dim_per_group``
        # instead of ``-1`` keeps the shape unambiguous for zero-element
        # inputs.
        grouped = x.reshape(x_shape[:-1] + (self.num_groups, self.dim_per_group))
        y = F.layer_norm(grouped, [self.dim_per_group], None, None, self.eps)
        y = y.reshape(x_shape)

        # The affine parameters have length ``normalize_dim`` and broadcast
        # over the last dimension of ``y``.
        if self.weight is not None:
            y = y * self.weight()
        if self.bias is not None:
            y = y + self.bias()

        return y
