def compute_norm(model, p=2):
    """Return the total p-norm of the gradients of ``model``'s trainable parameters.

    The value equals the p-norm of all gradient entries concatenated into a
    single vector: ``(sum_i ||g_i||_p ** p) ** (1 / p)`` for finite ``p``, and
    ``max_i ||g_i||_inf`` for ``p = inf``.

    Parameters with ``requires_grad=False`` are ignored. Parameters whose
    ``.grad`` is ``None`` are also skipped. This happens before ``backward()``
    has run, or when a parameter did not take part in the forward pass.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        p: Order of the norm. A positive number, or ``float("inf")`` for the
            max-absolute-value norm.

    Returns:
        The total gradient norm as a Python float. It is ``0.0`` when no
        parameter has a gradient.
    """
    is_inf = p == float("inf")
    total = 0.0
    for param in model.parameters():
        # A trainable parameter can still have no gradient (unused in the
        # forward pass, or backward() not called yet); skip it instead of
        # crashing on ``None``.
        if not param.requires_grad or param.grad is None:
            continue
        param_norm = param.grad.detach().norm(p).item()
        if is_inf:
            # The inf-norm of a concatenation is the max of the parts' inf-norms;
            # raising to the power ``inf`` would not combine them correctly.
            total = max(total, param_norm)
        else:
            total += param_norm ** p
    return total if is_inf else total ** (1.0 / p)
