from torch import nn
from .conv_delta_orthogonal import conv_delta_orthogonal_

def activation(state, event):
    """Return a newly constructed ``nn.Tanh`` module.

    ``state`` and ``event`` are not read; they only keep the same
    ``(state, event, ...)`` handler shape used by ``init_conv`` in this module.
    """
    tanh_cls = nn.Tanh
    layer = tanh_cls()
    return layer

def init_conv(state, event, m):
    """Initialize ``m.weight`` in place using the scheme named by ``state["init"]``.

    Supported schemes:

    * ``"kaiming"`` -- ``nn.init.kaiming_normal_`` with ``mode='fan_out'`` and
      ``nonlinearity='relu'``.
    * ``"delta"`` -- the project helper ``conv_delta_orthogonal_``.

    Args:
        state: mapping that holds the ``"init"`` option.
        event: not read; kept for the shared handler signature.
        m: module exposing a ``weight`` tensor -- presumably an ``nn.Conv*d``;
            confirm against the caller.

    Raises:
        KeyError: if ``state`` has no ``"init"`` entry.
        ValueError: if ``state["init"]`` is not a supported scheme.
    """
    scheme = state["init"]
    if scheme == "kaiming":
        # NOTE(review): the gain is computed for ReLU although this module's
        # activation is Tanh; kept as-is to preserve existing numerics.
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif scheme == "delta":
        conv_delta_orthogonal_(m.weight)
    else:
        # Name the offending value and the accepted options so config typos
        # are easy to diagnose.
        raise ValueError(
            f"Initialization not supported: {scheme!r} "
            f"(expected 'kaiming' or 'delta')"
        )

def register(mf):
    """Register this module's event handlers and option defaults on ``mf``.

    Events registered (each with ``unique=True``; the exact meaning of that
    flag is defined by ``mf`` -- presumably one handler per event, confirm):

    * ``'activation_layer_cls'`` -> a zero-argument callable returning the
      ``nn.Tanh`` class itself.
    * ``'activation_layer'`` -> ``activation``, which builds an ``nn.Tanh()``.
    * ``'init_conv'`` -> ``init_conv``, the conv-weight initializer.

    Defaults registered: ``"init"`` = ``"kaiming"``.
    """
    mf.register_event('activation_layer_cls', lambda: nn.Tanh, unique=True)
    mf.register_event('activation_layer', activation, unique=True)
    mf.register_event('init_conv', init_conv, unique=True)
    mf.register_defaults({
        # Accepted values are exactly "kaiming" or "delta" (see init_conv);
        # anything else makes init_conv raise ValueError.
        "init": "kaiming",
    })