from .plot_scalar import plot_scalar
from .plot_scalar2d import plot_scalar2d

import numpy as np

def plot_gradnorm(state, net, plot_suffix):
    """Plot gradient-norm summaries of ``net``'s parameters.

    Two independent summaries can be emitted, each gated by a flag looked up
    through ``S`` (not defined or imported in this file -- presumably a
    settings accessor made available elsewhere; TODO confirm):

    * ``"visdom.summary.gradnorm"``: the mean, over all parameters that have a
      gradient, of each parameter's L2 gradient norm. It is plotted via
      ``plot_scalar`` with title ``"gradnorm_" + plot_suffix``.
    * ``"visdom.summary.gradnorm_per_conv"``: the L2 gradient norm of every
      4-D parameter (presumably convolution kernels -- confirm). Each is
      plotted via ``plot_scalar2d`` as its own series, named by its index
      among the 4-D parameters, with title
      ``"gradnorm_per_conv_" + plot_suffix``.

    Args:
        state: Training state, read as ``state["total_i"]`` and
            ``state.all["total_i"]`` for the x-axis value.
        net: Object whose ``parameters()`` yields tensors exposing ``.grad``,
            ``.dim()`` and ``.grad.norm(2).item()`` (presumably a
            ``torch.nn.Module``).
        plot_suffix: String appended to every plot title.

    Parameters whose ``grad`` is ``None`` are ignored. If no parameter has a
    gradient, nothing is plotted.
    """
    params = [p for p in net.parameters() if p.grad is not None]
    if not params:
        # np.mean([]) would return nan (plus a RuntimeWarning), and plotting
        # that point carries no information, so skip the update entirely.
        return

    if S("visdom.summary.gradnorm"):
        mean_norm = np.mean([p.grad.norm(2).item() for p in params])
        plot_scalar(mean_norm, state["total_i"], title="gradnorm_" + plot_suffix)

    if S("visdom.summary.gradnorm_per_conv"):
        conv_params = [p for p in params if p.dim() == 4]
        # NOTE(review): this branch reads the step counter through
        # ``state.all[...]`` while the branch above uses ``state[...]``;
        # kept as-is -- confirm both are meant to be the same counter.
        for i, p in enumerate(conv_params):
            plot_scalar2d(p.grad.norm(2).item(), str(i), state.all["total_i"],
                          title=("gradnorm_per_conv_" + plot_suffix))

    # TODO: a "rel_gradnorm_" plot (norm normalized by parameter count) was
    # sketched here previously; if revived, use p.numel() for the count.


