import torch


def print_grads(net):
    """Print a one-line gradient summary for every named parameter of ``net``.

    For each ``(name, parameter)`` pair yielded by ``net.named_parameters()``,
    print the name and the literal ``"grad"``, followed by one of:

    * the gradient's ``shape`` and the sum of all its elements (as a Python
      scalar, obtained via ``.item()``), when the gradient is populated; or
    * ``None``, when the parameter has no gradient.

    Args:
        net: An object exposing ``named_parameters()``, such as a
            ``torch.nn.Module``.
    """
    for param_name, tensor in net.named_parameters():
        grad = tensor.grad
        if grad is None:
            print(param_name, "grad", None)
            continue
        # Collapse the gradient to one scalar as a cheap fingerprint for debugging.
        grad_total = torch.sum(grad).item()
        print(param_name, "grad", grad.shape, grad_total)
