Keywords: implicit differentiation, bilevel optimization, autodiff, jax
Abstract: Automatic differentiation (autodiff) has revolutionized machine learning. It allows expressing complex computations by composing elementary ones in creative ways and removes the tedious burden of computing their derivatives by hand. More recently, differentiation of optimization problem solutions has attracted a great deal of research, with applications as a layer in a neural network, and in bilevel optimization, including hyperparameter optimization. However, the formulae for these derivatives often involve tedious manual derivation and implementation. In this paper, we propose a unified, efficient and modular approach for implicit differentiation of optimization problems. In our approach, the user defines directly in Python a function $F$ capturing the optimality conditions of the problem to be differentiated. Once this is done, we leverage autodiff of $F$ to automatically differentiate the optimization problem. This way, our approach combines the benefits of implicit differentiation and autodiff. We show that seemingly simple principles allow us to recover all recently proposed implicit differentiation methods and to easily create new ones. We describe in detail a JAX implementation of our framework and demonstrate, on four diverse tasks, how easily it lets us differentiate through optimization problems: hyperparameter optimization of multiclass SVMs, dataset distillation, task-driven dictionary learning and sensitivity analysis of molecular dynamics.
One-sentence Summary: A unified, efficient and modular approach for implicit differentiation of optimization problems
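To make the recipe described in the abstract concrete, the following is a minimal, hypothetical JAX sketch (not the paper's actual library API): the optimality conditions $F(x, \theta) = 0$ are written as an ordinary Python function, and the implicit function theorem, combined with autodiff of $F$, yields gradients of the inner solution with respect to the outer parameter. The ridge-regression inner problem, the helper names, and the use of `jax.custom_vjp` are illustrative assumptions.

```python
# Sketch of implicit differentiation via autodiff of an optimality-condition
# function F (hypothetical example, not the paper's library API).
# Implicit function theorem: dx*/dtheta = -(dF/dx)^{-1} dF/dtheta.
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([1.0, 2.0, 3.0])

def F(x, theta):
    """Optimality conditions of min_x 0.5*||Ax - b||^2 + 0.5*theta*||x||^2."""
    return A.T @ (A @ x - b) + theta * x

def solver(theta):
    """Inner solver (closed form here; could be any iterative routine)."""
    n = A.shape[1]
    return jnp.linalg.solve(A.T @ A + theta * jnp.eye(n), A.T @ b)

@jax.custom_vjp
def ridge_solution(theta):
    return solver(theta)

def ridge_fwd(theta):
    x_star = solver(theta)
    return x_star, (x_star, theta)

def ridge_bwd(res, g):
    x_star, theta = res
    # Jacobians of F come from autodiff of the user-written F.
    dF_dx = jax.jacobian(F, argnums=0)(x_star, theta)      # shape (n, n)
    dF_dtheta = jax.jacobian(F, argnums=1)(x_star, theta)  # shape (n,)
    # Vector-Jacobian product implied by the implicit function theorem:
    # g^T dx*/dtheta = -((dF/dx)^{-T} g)^T dF/dtheta
    u = jnp.linalg.solve(dF_dx.T, g)
    return (-u @ dF_dtheta,)

ridge_solution.defvjp(ridge_fwd, ridge_bwd)

# Outer objective differentiated through the inner optimization problem.
def outer_loss(theta):
    return jnp.sum(ridge_solution(theta) ** 2)

print(jax.grad(outer_loss)(0.5))
```

In this sketch, swapping in a different inner problem only requires changing `F` and `solver`; the backward rule stays the same, which is the modularity the abstract refers to.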
Supplementary Material: zip
Community Implementations: [2 code implementations](https://www.catalyzex.com/paper/arxiv:2105.15183/code)