## Requirements

To install requirements:

Create and activate the conda env
```
conda create -n eq-pde-jax python=3.11
conda activate eq-pde-jax
```

Install dependencies
```
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install optax flax orbax wandb matplotlib trimesh hydra-core tqdm netCDF4 py-pde
conda install -c conda-forge dedalus
```

## Data

All datasets are generated as training progresses. The dataset code is in `experiments/fitting/datasets/pdes.py`.

## Training

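Commands are listed per experiment. Each experiment is run three times, with seeds 0, 1, and 2. In the commands below, `seed={0, 1, 2}` is shorthand for one of these values (for example `seed=0`), not literal shell syntax.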
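
Because every experiment is repeated for three seeds, a small shell loop can save retyping. The sketch below is a convenience wrapper only: `run_seeds` and the `DRY_RUN` switch are hypothetical names that are not part of this repository, and it assumes each script accepts a single `seed=<value>` override plus any extra overrides passed after the script path.

```bash
# Hypothetical helper (not part of the repository): run one training script
# once per seed in {0, 1, 2}, forwarding any extra overrides unchanged.
# With DRY_RUN=1 the commands are printed instead of executed.
run_seeds() {
  local script="$1"
  shift
  local seed
  for seed in 0 1 2; do
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo python "$script" "seed=$seed" "$@"
    else
      python "$script" "seed=$seed" "$@"
    fi
  done
}

export PYTHONPATH=.

# Print the three planar diffusion commands without launching training.
DRY_RUN=1 run_seeds experiments/fitting/fit_diff_plane.py
```

Dropping `DRY_RUN=1` launches the three runs one after another. Extra overrides go after the script path, for example `run_seeds experiments/fitting/fit_diff_sphere.py nef.invariant_type=abs_pos`.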

**Planar diffusion**
```bash
# With se2 bi-invariants
export PYTHONPATH=. && python experiments/fitting/fit_diff_plane.py seed={0, 1, 2}
```

**Navier-Stokes**
```bash
# With periodic bi-invariants
export PYTHONPATH=. && python experiments/fitting/fit_navier_stokes.py seed={0, 1, 2}
# With absolute position bi-invariants
export PYTHONPATH=. && python experiments/fitting/fit_navier_stokes.py seed={0, 1, 2} nef.invariant_type=abs_pos
# With periodic bi-invariants and autodecoding instead of meta-sgd
export PYTHONPATH=. && python experiments/fitting/fit_navier_stokes_nonmaml.py seed={0, 1, 2}
```

**Diffusion on the sphere**
```bash
# With so3 bi-invariants
export PYTHONPATH=. && python experiments/fitting/fit_diff_sphere.py seed={0, 1, 2}
# With absolute position bi-invariants
export PYTHONPATH=. && python experiments/fitting/fit_diff_sphere.py seed={0, 1, 2} nef.invariant_type=abs_pos
```

**Global shallow water equations**
Super-resolution is run in the same script; the last value logged to wandb comes from the super-resolution run.
```bash
export PYTHONPATH=. && python experiments/fitting/fit_shallow_water.py seed={0, 1, 2}
```

**Internally heated convection**
```bash
export PYTHONPATH=. && python experiments/fitting/fit_ihc.py seed={0, 1, 2}
```
