Keywords: Physics-Informed Neural Networks, Compilers, XLA, JAX, TensorFlow, PyTorch
TL;DR: This paper introduces PINNs-JAX, a new implementation of Physics-Informed Neural Networks (PINNs) using the JAX framework.
Abstract: Physics-Informed Neural Networks (PINNs) have become a pivotal technology for incorporating physical laws into neural network training and for solving nonlinear partial differential equations (PDEs). Improving the performance of PINN implementations can significantly accelerate simulations and foster the development of new methodologies. This paper presents PINNs-JAX, an implementation that builds on the JAX framework to leverage the capabilities of the XLA compiler, with the aim of improving computational efficiency and flexibility in PINN applications. We conduct a comprehensive comparison of PINNs-JAX against traditional PINN implementations in widely used frameworks, namely TensorFlow V1, TensorFlow V2, and PyTorch, evaluating performance on six examples spanning continuous, discrete, forward, and inverse problems. Our findings indicate that JAX-based PINNs outperform the other implementations on simpler examples, whereas TensorFlow V2 offers potential benefits for large-scale problems, as exemplified by the 3D Navier-Stokes case. To support collaborative development and further research, we have made the source code publicly available at: https://github.com/rezaakb/pinns-jax.
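To make the core idea concrete, below is a minimal, hypothetical sketch of a PINN written in JAX; it is not the authors' PINNs-JAX code, and the toy problem, network sizes, and learning rate are illustrative assumptions. It fits u(x) to the 1D Poisson problem u''(x) = -pi^2 sin(pi x) on [0, 1] with u(0) = u(1) = 0 (exact solution u(x) = sin(pi x)), and wraps the whole training step in jax.jit so that the XLA compiler, the mechanism the paper builds on, compiles it end to end.

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(1, 32, 32, 1)):
    # Simple MLP initialization; layer widths are an assumption for this demo.
    params = []
    for i in range(len(sizes) - 1):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (sizes[i], sizes[i + 1])) / jnp.sqrt(sizes[i])
        params.append((w, jnp.zeros(sizes[i + 1])))
    return params

def mlp(params, x):
    # Scalar-in, scalar-out network u(x).
    h = jnp.array([x])
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[0]

def residual(params, x):
    # PDE residual: u''(x) + pi^2 sin(pi x), with u'' from two nested grads.
    u_xx = jax.grad(jax.grad(mlp, argnums=1), argnums=1)(params, x)
    return u_xx + jnp.pi**2 * jnp.sin(jnp.pi * x)

def loss(params, xs):
    # Physics loss on collocation points plus boundary-condition penalty.
    pde = jnp.mean(jax.vmap(lambda x: residual(params, x))(xs) ** 2)
    bc = mlp(params, 0.0) ** 2 + mlp(params, 1.0) ** 2
    return pde + bc

@jax.jit  # the entire gradient-descent step is compiled by XLA
def step(params, xs, lr=1e-3):
    grads = jax.grad(loss)(params, xs)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
params = init_params(key)
xs = jnp.linspace(0.0, 1.0, 64)
for _ in range(2000):
    params = step(params, xs)
print("u(0.5) =", mlp(params, 0.5))  # exact value: sin(pi/2) = 1
```

The key design point this sketch illustrates is that JAX composes automatic differentiation (nested jax.grad for the PDE operator, jax.grad over the loss for training) with jit compilation, so the whole PINN training step is lowered to XLA rather than executed op by op.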
Submission Number: 93