from torch.nn import GroupNorm

def normalization_layer(state, event, num_channels):
    """Create a ``torch.nn.GroupNorm`` layer over ``num_channels`` channels.

    Args:
        state: Mapping holding this module's settings; only the
            ``"num_groups"`` entry is read (its default is set in ``register``).
        event: Not used by this handler. It is presumably supplied by the event
            dispatcher, so it stays in the signature.
        num_channels: Number of channels the normalization layer expects.

    Returns:
        A freshly constructed ``GroupNorm`` with PyTorch's default ``eps``
        and ``affine`` settings.
    """
    groups = state["num_groups"]
    return GroupNorm(num_groups=groups, num_channels=num_channels)

def register(mf):
    """Hook this module into the ``mf`` registry.

    Does three things, in order:
      * sets the default ``num_groups`` setting to ``2``;
      * registers ``normalization_layer``, which builds a ``GroupNorm``
        instance;
      * registers ``normalization_layer_cls``, which returns the ``GroupNorm``
        class itself rather than an instance.

    Both events are registered with ``unique=True``. Presumably this means
    only one module may provide each event; confirm against the registry's
    documentation.
    """
    mf.register_defaults({"num_groups": 2})

    register_unique = mf.register_event
    register_unique("normalization_layer", normalization_layer, unique=True)
    # Kept as a zero-argument lambda: callers get the class, not an instance.
    register_unique("normalization_layer_cls", lambda: GroupNorm, unique=True)
