import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# pylint: disable=no-member


class T5LayerNorm(nn.Module):
    """Layer norm in the T5 style.

    No bias and no subtraction of the mean: the input is divided by the
    root-mean-square of its last dimension and then multiplied by a learned
    per-feature weight ``w``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """Construct a T5-style layer norm.

        Args:
            hidden_size: length of the learned weight vector ``w``; must
                broadcast against the last dimension of the input.
            eps: added to the mean square before the square root, to avoid
                division by zero.
        """
        super().__init__()
        self.w = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the weight ``w``.

        Returns a tensor in ``w``'s dtype when ``w`` is float16 or bfloat16.
        Otherwise the result follows normal type promotion.
        """
        # The mean square is always computed in float32 so that squaring
        # half-precision values does not overflow or lose precision.
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # Dividing by a float32 tensor promotes a half-precision x to float32.
        x = x / torch.sqrt(variance + self.variance_epsilon)

        # Cast back to the weight's reduced precision. Both half-precision
        # formats need this; otherwise the output silently becomes float32.
        if self.w.dtype in (torch.float16, torch.bfloat16):
            x = x.to(self.w.dtype)
        return self.w * x


class ReZeroNorm(nn.Module):
    """ReZero-style "norm": scale the input by a learned gate that starts at 0.

    With the default ``hidden_size=1`` the gate is a 0-dim (scalar) parameter.
    For larger sizes it is a vector of length ``hidden_size``.
    """

    def __init__(self, hidden_size=1):
        super().__init__()
        initial_gate = torch.zeros(hidden_size)
        # squeeze(0) only collapses a size-1 leading dim, giving a scalar gate.
        self.w = nn.Parameter(initial_gate.squeeze(0))

    def forward(self, x):
        """Return ``x`` multiplied by the learned gate ``w``."""
        gated = self.w * x
        return gated


class NoNorm(nn.Module):
    """Identity placeholder used when no normalization is wanted.

    ``hidden_size`` is accepted but ignored, so this class can be built with
    a ``hidden_size`` argument just like the other norms in this module.
    It has no parameters.
    """

    def __init__(self, hidden_size=1):
        super(NoNorm, self).__init__()

    def forward(self, x):
        """Return ``x`` itself, unchanged."""
        passthrough = x
        return passthrough


class ScaleNorm(nn.Module):
    """ScaleNorm: RMS-normalize the last dimension, then apply one learned scalar.

    Because ``x / rms(x) == sqrt(d) * x / ||x||_2`` (with ``d`` the size of the
    last dimension), this matches ScaleNorm's ``g * x / ||x||`` with the gain
    reparameterized as ``g = scale * sqrt(d)``, up to how ``eps`` is applied.
    """

    def __init__(self, hidden_size=1, eps=1e-5):
        """Construct a ScaleNorm layer.

        Args:
            hidden_size: unused; the gain is a single scalar. It is accepted so
                the class can be built like the other norms in this module.
            eps: added to the mean square before the square root, to avoid
                division by zero.
        """
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its last dimension and multiply by ``scale``."""
        # Compute the mean square in float32 for numerical stability with
        # half-precision inputs.
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x / torch.sqrt(variance + self.eps)

        # The float32 division above promotes half-precision inputs, so cast
        # back to the parameter's reduced precision when applicable.
        if self.scale.dtype in (torch.float16, torch.bfloat16):
            x = x.to(self.scale.dtype)
        return self.scale * x


# Lookup table from a string name to the normalization-layer class to build
# (presumably selected by a model config value — confirm against callers).
NORM2FN = dict(
    rezero=ReZeroNorm,
    layer_norm=nn.LayerNorm,
    t5_layer_norm=T5LayerNorm,
    scale_norm=ScaleNorm,
    none=NoNorm,
)
