from torch import nn

def gradientnorm(state, event):
    """Plot gradient statistics for the network's ``convs`` parameters.

    Registered as an ``after_step`` handler by ``register``. On batches where
    ``state.all["main.current_batch"]`` is a multiple of ``state["plot.steps"]``,
    it does the following:

    * For every parameter of ``state.all["net"]`` whose name contains
      ``"convs"`` and that has a gradient, compute the mean absolute
      gradient value.
    * Emit the average of those per-parameter means via
      ``event.plot_scalar``.
    * Emit the list of per-parameter means (one entry per parameter, in
      ``named_parameters`` order) via ``event.plot_scalar2d``.

    Despite the "Gradient Norm" title, the plotted value is a mean of
    ``|grad|``, not an L2 norm.

    If no matching parameter has a gradient, nothing is plotted. Previously
    this case crashed, because ``sum([])`` is the int ``0``, which has no
    ``.item()``, and the average divides by zero.

    Args:
        state: framework state object. Presumably it exposes an ``.all``
            mapping (with ``"net"`` and ``"main.current_batch"``) plus item
            access for plugin options such as ``"plot.steps"``. TODO confirm
            against the framework.
        event: object providing ``plot_scalar`` and ``plot_scalar2d``.
    """
    if state.all["main.current_batch"] % state["plot.steps"] != 0:
        return

    # Only compute the (per-parameter) gradient reductions on batches that
    # are actually plotted.
    grads = [
        param.grad.abs().mean()
        for name, param in state.all["net"].named_parameters()
        if param.grad is not None and "convs" in name
    ]
    if not grads:
        # No matching parameter has a gradient; averaging would divide by zero.
        return

    # NOTE(review): ylabel "Weight Norm" looks copied from a weight-norm plot;
    # kept unchanged to preserve existing plot output.
    event.plot_scalar(sum(grads).item() / len(grads), title="Gradient Norm", ylabel="Weight Norm")
    event.plot_scalar2d(grads, title="Gradient Norm", ylabel="Weight No.")

def register(mf):
    """Install this plugin into the framework handle ``mf``.

    Registers the default option ``"plot.steps"`` (10), then attaches
    ``gradientnorm`` to the ``'after_step'`` event.
    """
    plugin_defaults = {"plot.steps": 10}
    mf.register_defaults(plugin_defaults)
    mf.register_event('after_step', gradientnorm)
