Keywords: implicit differentiation, bilevel optimization, autodiff, jax
TL;DR: We propose an approach for automatic implicit differentiation.
Abstract: Automatic differentiation (autodiff) has revolutionized machine learning. It
allows expressing complex computations by composing elementary ones in creative
ways and removes the burden of computing their derivatives by hand. More
recently, differentiation of optimization problem solutions has attracted
widespread attention, with applications in optimization layers and in
bi-level problems such as hyper-parameter optimization and meta-learning.
However, so far, implicit differentiation has remained difficult for
practitioners to use, as it often requires tedious, case-by-case mathematical
derivations and implementations. In this paper, we propose
automatic implicit differentiation, an efficient
and modular approach for implicit differentiation of optimization problems. In
our approach, the user defines directly in Python a function $F$ capturing the
optimality conditions of the problem to be differentiated. Once this is done, we
leverage autodiff of $F$ and the implicit function theorem to automatically
differentiate the optimization problem. Our approach thus combines the benefits
of implicit differentiation and autodiff. It is efficient, as it can be added on
top of any state-of-the-art solver, and modular, as the optimality condition
specification is decoupled from the implicit differentiation mechanism. We show
that seemingly simple principles allow us to recover many existing implicit
differentiation methods and to create new ones easily. We demonstrate the ease of
formulating and solving bi-level optimization problems using our framework. We
also showcase an application to the sensitivity analysis of molecular dynamics.
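To make the mechanism concrete, here is a minimal, self-contained sketch of implicit differentiation via the implicit function theorem in JAX. It is not the library API described in the paper; the function names (`F`, `solve`, `implicit_grad`) and the ridge-regression setting are illustrative assumptions. The optimality condition $F(w, \lambda) = X^\top(Xw - y) + \lambda w = 0$ is autodiffed at the solution, and the implicit function theorem then gives $\partial w^\star / \partial \lambda = -(\partial_w F)^{-1} \partial_\lambda F$.

```python
# Hedged sketch (not the paper's API): differentiate the ridge-regression
# solution w.r.t. its regularization strength using the implicit function
# theorem, with both Jacobians of the optimality condition F obtained by autodiff.
import jax
import jax.numpy as jnp

def F(w, lam, X, y):
    # Optimality condition: gradient of the ridge objective is zero at w*.
    return X.T @ (X @ w - y) + lam * w

def solve(lam, X, y):
    # Inner solver: closed-form ridge solution (any solver could be plugged in here).
    n = X.shape[1]
    return jnp.linalg.solve(X.T @ X + lam * jnp.eye(n), X.T @ y)

def implicit_grad(lam, X, y):
    # Implicit function theorem: dw*/dlam = -(dF/dw)^{-1} dF/dlam,
    # evaluated at the solution returned by the inner solver.
    w_star = solve(lam, X, y)
    dF_dw = jax.jacobian(F, argnums=0)(w_star, lam, X, y)
    dF_dlam = jax.jacobian(F, argnums=1)(w_star, lam, X, y)
    return -jnp.linalg.solve(dF_dw, dF_dlam)

X = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (20,))
print(implicit_grad(1.0, X, y))         # derivative via implicit differentiation
print(jax.jacobian(solve)(1.0, X, y))   # direct autodiff through the solver, for comparison
```

The last two lines compare the implicit derivative against direct autodiff through the closed-form solver; in general, the implicit route avoids differentiating through a solver's iterations.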
Supplementary Material: pdf
Community Implementations: [2 code implementations](https://www.catalyzex.com/paper/efficient-and-modular-implicit/code)