from torch.nn import BatchNorm2d

def normalization_layer(state, event, num_channels):
    """Build a 2-D batch-normalization layer over ``num_channels`` channels.

    ``state`` and ``event`` are part of the handler signature (presumably
    supplied by whatever dispatches the ``'normalization_layer'`` event —
    confirm against ``mf``); they are not used here.

    Args:
        state: Unused.
        event: Unused.
        num_channels: Number of feature channels the layer normalizes.

    Returns:
        A ``torch.nn.BatchNorm2d`` with eps=1e-5, momentum=0.1, learnable
        affine parameters, and running-statistics tracking enabled.
    """
    # Spelled out by keyword so the configuration is readable at a glance;
    # these values match the positional arguments used previously.
    layer = BatchNorm2d(
        num_features=num_channels,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    )
    return layer

def register(mf):
    """Register this module's normalization handlers on the registry ``mf``.

    Two events are registered, in this order and each with ``unique=True``:

    * ``'normalization_layer'`` -> ``normalization_layer``, which builds a
      ``BatchNorm2d`` instance.
    * ``'normalization_layer_cls'`` -> a zero-argument callable returning the
      ``BatchNorm2d`` class itself.

    Args:
        mf: Object exposing ``register_event(name, handler, unique=...)``.
    """
    def _normalization_layer_cls():
        # NOTE(review): this handler takes no arguments, whereas
        # normalization_layer accepts (state, event, ...). Confirm how mf
        # invokes handlers before relying on either signature.
        return BatchNorm2d

    handlers = (
        ('normalization_layer', normalization_layer),
        ('normalization_layer_cls', _normalization_layer_cls),
    )
    for event_name, handler in handlers:
        mf.register_event(event_name, handler, unique=True)
