import torch


def hessian_step(model, lr, lamb=1e3, epsilon=1e-5):
    """Add a gradient-difference correction term to every parameter gradient.

    The function is stateful and keeps its state on the objects it is given:

    * ``model.fb`` -- flag; when it is missing or truthy the call only
      (re)initialises state (presumably "first batch" -- confirm with callers).
    * ``model.n_param`` -- total number of parameter elements, computed during
      initialisation.
    * ``p.grad_`` -- per-parameter snapshot of the gradient left by the
      previous call (``None`` when the parameter had no gradient).

    Initialisation call: snapshots each ``p.grad``, counts elements, sets
    ``model.fb = False`` and leaves the gradients untouched.

    Subsequent calls: for each parameter that has a gradient and a previous
    snapshot, adds in place

        (lamb / model.n_param / lr) * sign(grad) * (1 - grad / (grad_ + epsilon))

    to ``p.grad`` and then snapshots the *corrected* gradient for next time.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        lr: learning rate; the correction is scaled by ``1 / lr``.
        lamb: overall strength of the correction term.
        epsilon: constant added to the previous gradient before dividing.

    Returns:
        None. ``p.grad``, ``p.grad_``, ``model.fb`` and ``model.n_param`` are
        updated as side effects.
    """
    if not hasattr(model, "fb") or model.fb:
        total = 0
        for p in model.parameters():
            # Store a copy, not a reference: if the training loop later reuses
            # the gradient buffer in place (e.g. zeroing and re-accumulating),
            # an aliased snapshot would always equal the current gradient and
            # the correction term would collapse to ~0.
            p.grad_ = None if p.grad is None else p.grad.detach().clone()
            total += p.numel()
        model.n_param = total
        model.fb = False
        return

    if model.n_param == 0:
        # Nothing to scale by; also avoids dividing by zero below.
        return

    # Loop-invariant scale factor, computed once.
    const = lamb / model.n_param / lr

    for p in model.parameters():
        grad = p.grad
        if grad is None:
            # Frozen / unused parameter: nothing to correct this step.
            continue

        prev = getattr(p, "grad_", None)
        if prev is not None:
            # NOTE(review): ``prev + epsilon`` is exactly zero wherever
            # ``prev == -epsilon``, which yields inf/nan for that element.
            # Formula kept as-is to preserve numerics -- confirm intent.
            correction = const * torch.sign(grad) * (1 - grad / (prev + epsilon))
            grad.add_(correction)

        # Snapshot the corrected gradient (as the original did), but as an
        # independent copy so later in-place edits of p.grad cannot change it.
        p.grad_ = grad.detach().clone()
