import torch
import torch.nn as nn

def get_grad_norm(model):
	"""Return the global L2 norm of all gradients held by ``model``.

	Parameters whose ``.grad`` is ``None`` are skipped.

	Args:
		model: an object exposing ``parameters()`` (e.g. an ``nn.Module``).

	Returns:
		A 0-dim tensor equal to ``sqrt(sum_p ||p.grad||^2)``, or
		``torch.tensor(0.0)`` when no parameter currently has a gradient.
	"""
	total = None
	for p in model.parameters():
		if p.grad is None:
			continue
		# reshape (not view) so non-contiguous gradients are handled too;
		# view(-1) raises on e.g. transposed / channels_last layouts.
		g_flat = p.grad.detach().reshape(-1)
		sq = torch.dot(g_flat, g_flat)
		total = sq if total is None else total + sq

	if total is None:
		# No gradients at all: the norm of an empty set is 0. Previously this
		# path called torch.sqrt on a Python int, which raises TypeError.
		return torch.tensor(0.0)
	return torch.sqrt(total)

def get_param_norm(model):
	"""Return the global L2 norm of the parameters of ``model`` that have a gradient.

	Only parameters whose ``.grad`` is not ``None`` are included.

	Args:
		model: an object exposing ``parameters()`` (e.g. an ``nn.Module``).

	Returns:
		A 0-dim tensor equal to ``sqrt(sum_p ||p||^2)`` over the included
		parameters, or ``torch.tensor(0.0)`` when none are included.
	"""
	total = None
	for p in model.parameters():
		# NOTE(review): filtering on p.grad looks copied from get_grad_norm;
		# kept for backward compatibility — confirm whether parameters without
		# a gradient should count toward the parameter norm.
		if p.grad is None:
			continue
		# reshape (not view) so non-contiguous parameters are handled too.
		p_flat = p.detach().reshape(-1)
		sq = torch.dot(p_flat, p_flat)
		total = sq if total is None else total + sq

	if total is None:
		# Nothing included: return 0 rather than calling torch.sqrt on an int,
		# which raises TypeError.
		return torch.tensor(0.0)
	return torch.sqrt(total)