import torch

def compute_grad_norm(params):
    """Return the global L2 norm of the gradients of ``params``.

    The result is ``sqrt(sum(||g_i||_2 ** 2))`` over every contributing
    gradient ``g_i``. That equals the 2-norm of all gradient entries
    concatenated into one vector.

    Args:
        params: An iterable of tensors (e.g. ``model.parameters()``) or a
            single tensor. Entries whose ``grad`` is ``None`` or whose
            ``requires_grad`` is ``False`` are skipped.

    Returns:
        A 0-dim tensor holding the norm. It is placed on the device of the
        first contributing gradient. When no entry contributes a gradient, a
        CPU ``tensor(0.)`` is returned.
    """
    if isinstance(params, torch.Tensor):
        # Iterating a bare tensor would walk its rows (non-leaf slices with
        # no .grad); treat it as a one-element collection instead.
        params = [params]

    total = None
    for p in params:
        if p.grad is None or not p.requires_grad:
            continue
        # detach() reads the gradient outside autograd, like .data did, but
        # is the supported API for that.
        sq_norm = torch.norm(p.grad.detach()) ** 2
        if total is None:
            total = sq_norm
        else:
            # Gradients may live on different devices (e.g. model-parallel);
            # gather the partial sums onto the first gradient's device.
            total = total + sq_norm.to(total.device)

    if total is None:
        # Nothing contributed. torch.sqrt would reject a Python float, so
        # return an explicit zero tensor.
        return torch.tensor(0.0)
    return torch.sqrt(total)
