import torch.nn as nn
import torch

class LayerNorm2d(nn.Module):
    """Layer normalization across the channel dimension of a 4-D feature map.

    For an input of shape ``(N, C, H, W)``, every spatial location is
    normalized across its ``C`` channels to zero mean and unit (biased)
    variance. The result is then scaled and shifted by learnable per-channel
    ``weight`` and ``bias``.

    Args:
        nchan: Number of channels ``C``; sizes ``weight`` and ``bias``.
        eps: Small constant added to the variance before the square root.
            Without it, a location whose channels are all equal (zero
            variance, e.g. a constant input or ``nchan == 1``) yields 0/0 = NaN.
    """

    def __init__(self, nchan, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(nchan))
        self.bias = nn.Parameter(torch.zeros(nchan))
        # Plain float attribute, not a buffer: existing state_dicts still load.
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over dim 1, then apply the per-channel affine map."""
        mean = x.mean(1, keepdim=True)
        # Biased (population) variance, matching the original unbiased=False std.
        var = x.var(1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        # Reshape the (C,) parameters to (1, C, 1, 1) so they broadcast over N, H, W.
        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
 
class MyIdentity(nn.Module):
    """Pass-through module that returns its first input untouched.

    Any extra positional or keyword arguments are accepted and discarded,
    so callers may pass them without error.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` as-is; ``args`` and ``kwargs`` are ignored."""
        del args, kwargs  # accepted only for call-signature compatibility
        return x
       
def get_norm_layer(width, norm):
    """Build the normalization layer selected by ``norm``.

    Args:
        width: Channel count passed to the chosen normalization layer.
        norm: ``None`` for no normalization, ``"ln"`` for ``LayerNorm2d``,
            or ``"bn"`` for ``nn.BatchNorm2d``.

    Returns:
        ``MyIdentity()`` when ``norm`` is ``None``, otherwise
        ``LayerNorm2d(width)`` or ``nn.BatchNorm2d(width)``.

    Raises:
        ValueError: If ``norm`` is not one of the values above.
    """
    # Identity test: ``== None`` can be overridden by a custom ``__eq__``.
    if norm is None:
        return MyIdentity()
    if norm == "ln":
        return LayerNorm2d(width)
    if norm == "bn":
        return nn.BatchNorm2d(width)
    raise ValueError(f"Wrong value for normalization layer: {norm!r}")