from torch import nn

def activation(state, event):
    """Create a LeakyReLU layer whose negative slope is ``state["alpha"]``.

    The new layer is also appended to ``state.all["relus"]`` so the
    created activations can be found later.

    Args:
        state: mapping-like object supplying ``"alpha"``. It also exposes
            an ``all`` mapping whose ``"relus"`` entry supports ``append``.
        event: unused here; presumably part of the handler signature the
            framework expects (confirm against the caller).

    Returns:
        The newly constructed ``nn.LeakyReLU`` module.
    """
    slope = state["alpha"]
    layer = nn.LeakyReLU(negative_slope=slope)
    # Keep a handle on every created activation layer.
    created = state.all["relus"]
    created.append(layer)
    return layer

def init_conv(state, event, m):
    """Initialise ``m.weight`` in place with Kaiming-normal values.

    The gain is computed for a leaky-ReLU nonlinearity whose negative
    slope is ``state["alpha"]``, so the init matches layers built by this
    module's activation handler. Fan is taken in ``'fan_out'`` mode.
    Only ``m.weight`` is modified; any bias is left untouched.

    Args:
        state: mapping-like object supplying ``"alpha"``.
        event: unused here; presumably required by the handler signature.
        m: module with a ``weight`` tensor, presumably a convolution layer
            given the event name (confirm against the caller).
    """
    slope = state["alpha"]
    nn.init.kaiming_normal_(
        m.weight,
        a=slope,
        mode='fan_out',
        nonlinearity='leaky_relu',
    )

def register(mf):
    """Register the LeakyReLU event handlers, defaults and helpers on ``mf``.

    Registers these events, each with ``unique=True``:
      * ``activation_layer_cls``: returns the ``nn.LeakyReLU`` class.
      * ``activation_layer``: builds a slope-configured LeakyReLU instance.
      * ``init_conv``: Kaiming-normal weight init for a leaky-ReLU slope.

    It also declares the default ``alpha`` (the negative slope, 0.01).
    It then registers a shared ``relus`` ``nn.ModuleList`` helper with
    ``parsefn=False`` and ``scope=""``. The exact meaning of those flags
    is defined by the ``mf`` framework and is not visible here.
    """
    handlers = (
        ('activation_layer_cls', lambda: nn.LeakyReLU),
        ('activation_layer', activation),
        ('init_conv', init_conv),
    )
    for event_name, handler in handlers:
        mf.register_event(event_name, handler, unique=True)

    mf.register_defaults({"alpha": 0.01})

    mf.register_helpers(
        {"relus": nn.ModuleList()},
        parsefn=False,
        scope="",
    )