# Hessian-Free Natural Gradient Descent for Physics-Informed Computational Fluid Dynamics


## Requirements

Running the code requires the following Python version and packages:

- Python 3.10.10 or later
- [JAX](https://github.com/google/jax) 0.4.8 or later
- [JAXopt](https://github.com/google/jaxopt) 0.6 or later
- [Optax](https://github.com/deepmind/optax) 0.1.4 or later
- [jaxtyping](https://github.com/google/jaxtyping) 0.2.25 or later
- [matfree](https://github.com/pnkraemer/matfree) (for matrix-free operations)
- [Lineax](https://github.com/google/lineax)
- [jmp](https://github.com/deepmind/jmp)
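
To check an existing environment against these minimums, a script like the one below may help. It is a sketch and not part of the project. It reads installed versions with the standard-library `importlib.metadata` module, so no package is actually imported. It compares only the leading numeric part of each version string, which means pre-release suffixes are ignored. Packages listed above without a minimum version are only checked for presence.

```bash
python - <<'EOF'
# Report installed versions against the minimums listed in the README (stdlib only).
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the requirements list; None means "any version".
MINIMUMS = {
    "jax": "0.4.8",
    "jaxopt": "0.6",
    "optax": "0.1.4",
    "jaxtyping": "0.2.25",
    "matfree": None,
    "lineax": None,
    "jmp": None,
}

def release(v):
    # Leading numeric part of a version string: "0.4.30rc1" -> (0, 4, 30).
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(p) for p in m.group(0).split(".")) if m else ()

py_ok = sys.version_info[:3] >= (3, 10, 10)
py_status = "OK" if py_ok else "TOO OLD (need 3.10.10+)"
print(f"python {sys.version.split()[0]}: {py_status}")

for name, minimum in MINIMUMS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: MISSING")
        continue
    ok = minimum is None or release(installed) >= release(minimum)
    status = "OK" if ok else f"TOO OLD (need {minimum}+)"
    print(f"{name} {installed}: {status}")
EOF
```

The script only reports; it always exits with status 0, so it does not fail a CI job on its own.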

If you plan to use GPU or TPU acceleration, make sure the installed JAX build matches your hardware; for NVIDIA GPUs, it must be compatible with your CUDA version and driver. For details, see the [JAX installation guide](https://github.com/google/jax#installation).
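
A quick way to confirm which backend JAX actually picked up is to print its default backend and visible devices, as in the one-liner below (it uses only the public `jax.default_backend()` and `jax.devices()` functions). If it reports `cpu` on a machine with a GPU, the installed JAX is most likely a CPU-only build.

```bash
# Print the backend JAX selected ("cpu", "gpu", or "tpu") and the visible devices.
python -c "import jax; print(jax.default_backend()); print(jax.devices())"
```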

### Additional Notes
When running on a personal laptop, we recommend using smaller network architectures; they still produce satisfactory results within reasonable run times.

## Installation

1. Set up a virtual environment:
    ```bash
    python -m venv env
    source env/bin/activate   # For Linux/macOS
    # On Windows:
    # .\env\Scripts\activate
    ```

2. Install the required dependencies:
    ```bash
    pip install jax jaxopt optax jaxtyping lineax jmp
    ```

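    For GPU support, install a CUDA-enabled JAX build instead. For example, for CUDA 11 with cuDNN 8.6 (the extra names differ between JAX releases, and newer releases use extras such as `jax[cuda12]`, so check the JAX installation guide for your version):
    ```bash
    pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html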
    ```

3. Install `matfree` from the official GitHub repository:
    ```bash
    git clone https://github.com/pnkraemer/matfree.git
    cd matfree
    pip install .
    ```



