Faster Training of Neural ODEs Using Gauß–Legendre Quadrature

Published: 05 Sept 2023, Last Modified: 05 Sept 2023. Accepted by TMLR.
Abstract: Neural ODEs demonstrate strong performance in generative and time-series modelling. However, training them via the adjoint method is slow compared to discrete models because it requires numerically solving ODEs. A common approach to speeding up neural ODEs is to regularise the solutions; however, this can reduce the expressivity of the model, which matters particularly when the trajectory itself is of interest. In this paper, we propose an alternative way to speed up the training of neural ODEs. The key idea is to accelerate the adjoint method by using Gauß–Legendre quadrature to evaluate integrals faster than ODE-based methods while remaining memory efficient. We also extend the idea to training SDEs using the Wong–Zakai theorem, by training a corresponding ODE and transferring the parameters. Our approach leads to faster training of neural ODEs, especially for large models. It also presents a new way to train SDE-based models.
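The core numerical ingredient described in the abstract is replacing an ODE solve with Gauß–Legendre quadrature for the adjoint integrals. A minimal sketch of that building block is below: it uses NumPy's standard `leggauss` nodes and weights on $[-1, 1]$ and the affine rescaling to a general interval $[a, b]$ mentioned in the revision notes. The function name `gauss_legendre_integral` is illustrative, not the paper's actual API.

```python
import numpy as np

def gauss_legendre_integral(f, a, b, n=8):
    """Approximate the integral of f over [a, b] with n-point
    Gauss-Legendre quadrature, rescaling nodes from [-1, 1]."""
    nodes, weights = np.polynomial.legendre.leggauss(n)
    # Affine map from the reference interval [-1, 1] to [a, b]
    t = 0.5 * (b - a) * nodes + 0.5 * (b + a)
    # The Jacobian of the map contributes the factor (b - a) / 2
    return 0.5 * (b - a) * np.sum(weights * f(t))

# Example: integrate t^3 over [0, 2]; the exact value is 4.
approx = gauss_legendre_integral(lambda t: t**3, 0.0, 2.0)
```

Because $n$-point Gauß–Legendre quadrature is exact for polynomials up to degree $2n - 1$, the example above recovers the exact integral; for the smooth vector fields of a trained neural ODE, a small fixed number of evaluations can replace a full adaptive ODE solve of the adjoint integral.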
License: Creative Commons Attribution 4.0 International (CC BY 4.0)
Submission Length: Regular submission (no more than 12 pages of main content)
Changes Since Last Submission: We have:
- Added further explanation around Figure 1, the analytical experiment
- Improved Figure 1
- Explained the memory usage in Figure 2 in more depth
- Included a description of how the SDE is trained as an approximate ODE, with functions transferred to an SDE solver at test time
- Added computational and memory complexities for the direct vs adjoint methods, and explained their differences in more detail
- Added information about error estimates in numerical integration
- Added a description of how we can shift and scale the integration domain from a general $[a, b]$ to the $[-1, 1]$ interval for use of Gauß–Legendre quadrature
- Included approximate calculations of the number of function evaluations we can expect from the GQ method
- Added a new experiment looking at crossing trajectories
- Rephrased the paragraph about seminorms to improve clarity
- Added a clarifying sentence about equation (13) and the meaning of $m$ in that equation
- Added further citations in the related work
- Improved writing around ambiguous points
Video: https://www.youtube.com/watch?v=pKbLwsqy8aM
Code: https://github.com/a-norcliffe/torch_gq_adjoint
Supplementary Material: zip
Assigned Action Editor: ~Kevin_Swersky1
Submission Number: 979