How to compute Hessian-vector products?

Published: 16 Feb 2024, Last Modified: 28 Mar 2024, BT@ICLR2024, CC BY 4.0
Keywords: hessian-vector products, automatic differentiation, jax, pytorch, bilevel optimization
Blogpost Url: https://iclr-blogposts.github.io/2024/blog/bench-hvp/
Abstract: The product between the Hessian of a function and a vector, the so-called Hessian-vector product (HVP), is a quantity that appears in optimization and machine learning. However, the computation of HVPs is often considered prohibitively expensive, preventing practitioners from using algorithms that rely on these quantities. Standard automatic differentiation theory predicts that computing an HVP has a cost of the same order of magnitude as computing a gradient. The goal of this blog post is to provide a practical counterpart to this theoretical result, showing that modern automatic differentiation frameworks, JAX and PyTorch, allow for efficient computation of HVPs for standard deep learning cost functions.
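For readers who want to try the idea the abstract describes, here is a minimal sketch of an HVP via forward-over-reverse differentiation in JAX. The function `f` and the inputs are illustrative placeholders, not taken from the blogpost itself:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Placeholder scalar-valued loss; any twice-differentiable
    # function of x would work here.
    return jnp.sum(jnp.sin(x) ** 2)

def hvp(fun, x, v):
    # Forward-over-reverse: push the direction v through the
    # gradient function with a Jacobian-vector product. jax.jvp
    # returns (grad(fun)(x), H(x) @ v); we keep the tangent part.
    return jax.jvp(jax.grad(fun), (x,), (v,))[1]

x = jnp.arange(3.0)
v = jnp.ones(3)
print(hvp(f, x, v))  # Hessian of f at x, multiplied by v
```

This composition never materializes the full Hessian, which is why its cost stays within a small constant factor of one gradient evaluation; other orderings (e.g. reverse-over-reverse) compute the same quantity with different memory/time trade-offs.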
Ref Papers: https://arxiv.org/abs/2010.07962, https://arxiv.org/abs/2010.01412
Id Of The Authors Of The Papers: ~Kaiyi_Ji1, ~Behnam_Neyshabur1
Conflict Of Interest: N/A
Submission Number: 18