from torch import nn
from moduleloader import outervar
from functools import partial

def plot_mean_std(state, event, net, shortcut=None, l=outervar, title="Act"):
    """Plot the mean and standard deviation of ``net`` as two scalar series.

    Both statistics are computed first. Each is then converted to a Python
    number with ``.item()`` and sent to ``event.plot_scalar2d`` under the
    titles ``"<title> Mean"`` and ``"<title> Std"``, in that order.

    Args:
        state: Unused here; presumably required by the event-hook calling
            convention (TODO confirm against moduleloader).
        event: Object providing ``plot_scalar2d(value, l, title=...)``.
        net: Object exposing ``mean()`` and ``std()`` whose results support
            ``.item()``. Presumably an activation tensor (verify).
        shortcut: Unused here; kept for the hook signature.
        l: Forwarded unchanged to ``plot_scalar2d``. Its default is the
            ``outervar`` sentinel from ``moduleloader``, presumably resolved
            by the loader at call time (TODO confirm).
        title: Prefix for both plot titles.
    """
    # Compute both statistics before plotting anything, as the hook always has.
    statistics = ((" Mean", net.mean()), (" Std", net.std()))
    for suffix, statistic in statistics:
        event.plot_scalar2d(statistic.item(), l, title=title + suffix)

def register(mf):
    """Register this module's defaults and ReLU plotting hooks with ``mf``.

    This sets the ``"plot.steps"`` default to 10. It then attaches
    ``plot_mean_std`` to the ``'before_relu'`` event with the title
    ``"Pre-Act"``, and to the ``'after_relu'`` event with its default title.

    Args:
        mf: Registry exposing ``register_defaults(dict)`` and
            ``register_event(name, callback)``.
    """
    default_settings = {"plot.steps": 10}
    mf.register_defaults(default_settings)

    # Pre-activation statistics get their own title so the two plots differ.
    pre_activation_hook = partial(plot_mean_std, title="Pre-Act")
    mf.register_event('before_relu', pre_activation_hook)
    mf.register_event('after_relu', plot_mean_std)
