import torch
import torch.nn as nn
import torch.nn.functional as F

# pylint: disable=no-member


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """Construct a layernorm module in the T5 style (RMS normalization).

        No bias and no subtraction of the mean: the input is divided by the
        root-mean-square of its last dimension and scaled by a learnable gain.

        Args:
            hidden_size: size of the learnable per-feature gain (``weight``).
            eps: added to the mean square before the square root to avoid
                division by zero.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Return ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``."""
        # layer norm should always be calculated in float32
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # Dividing by the float32 statistic promotes half-precision x to float32.
        x = x / torch.sqrt(variance + self.variance_epsilon)

        # Cast back when the module itself is in reduced precision (fp16 *or*
        # bf16); previously only fp16 was handled, so bf16 models leaked
        # float32 activations out of every norm layer.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            x = x.to(self.weight.dtype)
        return self.weight * x


class ReZeroNorm(nn.Module):
    """ReZero-style scaling: multiply the input by a learnable gain that starts at zero.

    Args:
        hidden_size: number of gain entries. With the default of 1 the gain is
            a 0-dim (scalar) parameter, because the length-1 axis is squeezed
            away; larger sizes keep a 1-D gain of that length.
    """

    def __init__(self, hidden_size=1):
        super().__init__()
        initial_gain = torch.zeros(hidden_size)
        # squeeze(0) turns a length-1 vector into a scalar; longer vectors are untouched.
        self.weight = nn.Parameter(initial_gain.squeeze(0))

    def forward(self, x):
        """Return ``x`` scaled (with broadcasting) by the learnable gain."""
        scaled = self.weight * x
        return scaled


class NoNorm(nn.Module):
    """Identity "normalization" that returns its input unchanged.

    ``hidden_size`` is accepted (and ignored) so this class can be constructed
    with the same argument as the other normalization modules.
    """

    def __init__(self, hidden_size=1):
        super().__init__()
        del hidden_size  # intentionally unused

    def forward(self, x):
        """Return ``x`` as-is (the very same object)."""
        return x


class ScaleNorm(nn.Module):
    """Scale normalization: divide the input by its L2 norm over the last
    dimension, then multiply by a learnable gain.

    Args:
        hidden_size: shape of the learnable gain (default 1, i.e. a single gain
            broadcast over the last dimension).
        value: initial value of every gain entry.
        eps: added to the norm before dividing, to avoid division by zero.
    """

    def __init__(self, hidden_size=1, value=1.0, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size) * value)
        # Attribute name kept for compatibility; here eps is added to the L2
        # norm, not to a variance.
        self.variance_epsilon = eps

    def forward(self, x):
        """Return ``weight * x / (||x||_2 + eps)`` with the norm over dim -1."""
        # The norm is always computed in float32 for numerical stability.
        norms = torch.norm(x.to(torch.float32), dim=-1, keepdim=True)
        # Dividing by float32 norms promotes half-precision x to float32 ...
        x = x / (norms + self.variance_epsilon)
        # ... so cast back when the module is in reduced precision; otherwise
        # half-precision models would get float32 outputs from this layer.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            x = x.to(self.weight.dtype)
        return self.weight * x


# Registry mapping config names to normalization module classes.
# NOTE(review): presumably instantiated as NORM2FN[name](hidden_size) by callers — confirm.
NORM2FN = {
    "rezero": ReZeroNorm,
    "layer_norm": nn.LayerNorm,
    "t5_layer_norm": T5LayerNorm,
    "scale_norm": ScaleNorm,
    "none": NoNorm,
}
