Automatic Functional Differentiation in JAX

Published: 16 Jan 2024, Last Modified: 05 Mar 2024, ICLR 2024 poster
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics.
Keywords: JAX, automatic differentiation, calculus of variations, higher-order functions, functionals, operators, JVP rules, transpose rules
Submission Guidelines: I certify that this submission complies with the submission instructions as described on
TL;DR: Describes an augmentation to JAX that enables automatic functional differentiation.
Abstract: We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators). By representing functions as an infinite-dimensional generalization of arrays, we seamlessly use JAX's existing primitive system to implement higher-order functions. We present a set of primitive operators that serve as foundational building blocks for constructing several key types of functionals. For every introduced primitive operator, we derive and implement both linearization and transposition rules, aligning with JAX's internal protocols for forward- and reverse-mode automatic differentiation. This enhancement allows functional differentiation to be written in the same syntax traditionally used for functions. The resulting functional gradients are themselves functions, ready to be invoked in Python. We showcase this tool's efficacy and simplicity through applications where functional derivatives are indispensable.
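To make the notion of a functional derivative concrete, the following is a minimal sketch that uses only standard JAX (`jax.grad`, `jax.numpy`) on a discretized function. It is not the paper's implementation, which operates on functions directly rather than on grid values. The functional F[f] = ∫ f(x)^2 dx on [0, 1], the midpoint grid, and all names (`F`, `f_vals`, `dF_df`) are assumptions chosen for illustration. For this functional the analytic functional derivative is δF/δf = 2f, which the discretized gradient should approximately reproduce once divided by the quadrature weight.

```python
import jax
import jax.numpy as jnp

# Illustrative assumption: represent f by its values on a uniform midpoint grid over [0, 1].
n = 1000
dx = 1.0 / n
x = (jnp.arange(n) + 0.5) * dx

def F(f_vals):
    # Midpoint-rule approximation of the functional F[f] = integral of f(x)^2 over [0, 1].
    return jnp.sum(f_vals ** 2) * dx

# Sample function f(x) = sin(2*pi*x), evaluated on the grid.
f_vals = jnp.sin(2.0 * jnp.pi * x)

# jax.grad returns dF/df_i = 2 * f_i * dx. Dividing by the quadrature weight dx
# gives a discretized estimate of the functional derivative, which is 2 * f(x) here.
dF_df = jax.grad(F)(f_vals) / dx

# Largest deviation from the analytic result 2 * f(x) on the grid.
max_err = float(jnp.max(jnp.abs(dF_df - 2.0 * f_vals)))
```

In this grid-based sketch the gradient is an array of values. The approach described in the abstract instead returns the functional gradient as a callable function, which avoids committing to a grid in advance.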
Anonymous Url: I certify that there is no URL (e.g., github page) that could be used to find authors' identity.
No Acknowledgement Section: I certify that there is no acknowledgement section in this submission for double blind review.
Primary Area: infrastructure, software libraries, hardware, etc.
Submission Number: 1974