from torch.nn import BatchNorm1d

def normalization_layer(state, event, num_channels):
    """Create a new ``BatchNorm1d`` layer normalising ``num_channels`` features.

    Args:
        state: Accepted but not read here; presumably supplied by the event
            dispatcher this handler is registered with (see ``register``).
        event: Accepted but not read here; same note as ``state``.
        num_channels: Number of features/channels the layer normalises.

    Returns:
        A freshly constructed ``torch.nn.BatchNorm1d``. It is built with
        eps=1e-5 and momentum=0.1, with a learnable affine transform, and it
        tracks running statistics. These values are PyTorch's documented
        defaults, written out as keywords so they are readable at a glance.
    """
    del state, event  # unused; kept only for the handler call signature
    return BatchNorm1d(
        num_features=num_channels,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    )

def register(mf):
    """Register this module's batch-normalisation handlers on ``mf``.

    Calls ``mf.register_event(name, handler, unique=True)`` once per entry,
    in this order:

    * ``'normalization_layer'``: ``normalization_layer``, which builds a
      ``BatchNorm1d`` instance.
    * ``'normalization_layer_cls'``: a zero-argument callable returning the
      ``BatchNorm1d`` class itself, not an instance.
    """
    # NOTE(review): the two handlers take different arguments (three vs.
    # none). Confirm that ``mf`` dispatches each event with the arity its
    # handler expects.
    handlers = (
        ('normalization_layer', normalization_layer),
        ('normalization_layer_cls', lambda: BatchNorm1d),
    )
    for event_name, handler in handlers:
        mf.register_event(event_name, handler, unique=True)
