import sys

import torch


def hook_fn_decorator(name, param, print_val=False):
    """Build a gradient hook that aborts the process on NaN gradients.

    The returned ``hook_fn(grad)`` checks the incoming gradient tensor. If any
    element is NaN, it prints a message and terminates the process with exit
    status 1. Otherwise it optionally prints the gradient's norm. It always
    returns ``None``. It is presumably meant to be attached with
    ``Tensor.register_hook``; confirm this against the callers.

    Args:
        name: Label used in the printed messages.
        param: Not used by the hook. It is kept so existing callers that pass
            it keep working.
        print_val: If True, print ``name`` and ``grad.norm()`` on every call.

    Returns:
        A one-argument callable ``hook_fn(grad)``.
    """

    def hook_fn(grad):
        if grad.isnan().any():
            print(f"grad of {name} contains nan, exiting")
            # Use sys.exit instead of the site-provided exit(). That builtin
            # is meant for the interactive shell and may be absent
            # (python -S, embedded interpreters). It also exits with status 0,
            # which would report a NaN-aborted run as successful.
            sys.exit(1)
        if print_val:
            print(f"{name}: {grad.norm()}")

    return hook_fn


def summary(model, *, trainable_only=False):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Object exposing ``parameters()``, such as a ``torch.nn.Module``.
        trainable_only: If True, count only parameters whose ``requires_grad``
            is set. The default, False, counts every parameter, as before.

    Returns:
        The sum of ``numel()`` over the selected parameters (0 if there are
        none).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if not trainable_only or p.requires_grad
    )
