Keywords: neural ode, memory efficient, reverse accuracy, gradient estimation
Abstract: Neural ordinary differential equations (Neural ODEs) are a new family of deep-learning models with continuous depth. However, the numerical estimation of the gradient in the continuous case is not well resolved: existing implementations of the adjoint method suffer from inaccuracy in the reverse-time trajectory, while the naive method and the adaptive checkpoint adjoint method (ACA) have a memory cost that grows with integration time. In this paper, based on the asynchronous leapfrog (ALF) solver, we propose the Memory-efficient ALF Integrator (MALI), which has a constant memory cost w.r.t. integration time, similar to the adjoint method, and guarantees accuracy in the reverse-time trajectory (and hence in gradient estimation). We validate MALI on various tasks: on image recognition, to our knowledge, MALI is the first method to enable feasible training of a Neural ODE on ImageNet and to outperform a well-tuned ResNet, while existing methods fail due to either a heavy memory burden or inaccuracy; on time-series modeling, MALI significantly outperforms the adjoint method; and on continuous generative models, MALI achieves new state-of-the-art performance. We provide a PyPI package: https://jzkay12.github.io/TorchDiffEqPack
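The constant memory cost comes from the time-reversibility of the ALF scheme: each step is algebraically invertible, so the backward pass can reconstruct the forward trajectory instead of storing it. Below is a minimal sketch of one ALF step and its exact inverse, for intuition only; the function names and the toy dynamics are illustrative assumptions, not the package's actual API (see the linked repository for the real MALI implementation, which adds adaptive step size and the adjoint gradient computation).

```python
import torch

def f(t, y):
    # Toy dynamics standing in for the Neural ODE's network:
    # any function f(t, y) returning dy/dt works here.
    return -y

def alf_step(f, t, y, v, h):
    """One asynchronous-leapfrog step from (y, v) at time t to time t + h.
    v is an auxiliary approximation of dy/dt carried alongside y."""
    y_half = y + 0.5 * h * v          # half-step in the state
    f_half = f(t + 0.5 * h, y_half)   # single function evaluation per step
    v_new = 2.0 * f_half - v          # reflect the velocity through f_half
    y_new = y_half + 0.5 * h * v_new  # second half-step with updated velocity
    return y_new, v_new

def alf_step_inverse(f, t, y_new, v_new, h):
    """Exact inverse of alf_step (in exact arithmetic): reconstructs (y, v)
    at time t from the state at t + h, so no forward states need be stored."""
    y_half = y_new - 0.5 * h * v_new
    f_half = f(t + 0.5 * h, y_half)
    v = 2.0 * f_half - v_new
    y = y_half - 0.5 * h * v
    return y, v

# Round trip: take one forward step, then invert it to recover the start state.
t0, h = 0.0, 0.1
y0 = torch.tensor([1.0])
v0 = f(t0, y0)  # initialize the velocity as dy/dt at t0
y1, v1 = alf_step(f, t0, y0, v0, h)
y_rec, v_rec = alf_step_inverse(f, t0, y1, v1, h)
print(torch.allclose(y_rec, y0), torch.allclose(v_rec, v0))  # True True
```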
One-sentence Summary: An ODE solver that guarantees accuracy in the reverse-time trajectory at a constant memory cost
Code Of Ethics: I acknowledge that I and all co-authors of this work have read and commit to adhering to the ICLR Code of Ethics
Code: [juntang-zhuang/TorchDiffEqPack](https://github.com/juntang-zhuang/TorchDiffEqPack)
Data: [CIFAR-10](https://paperswithcode.com/dataset/cifar-10), [DeepMind Control Suite](https://paperswithcode.com/dataset/deepmind-control-suite)
Community Implementations: [4 code implementations](https://www.catalyzex.com/paper/arxiv:2102.04668/code)