This repository contains code for the paper _Improving Adam using Second-Order
Perspectives_, submitted to NeurIPS 2023.


# Installation

Our complete development environment under Python 3.10 is specified in
`local_requirements.txt`, with a list of top-level requirements given in
`Pipfile`. In theory, `pipenv install` in a fresh virtual environment will set
everything up; in practice, JAX in particular may need manual intervention
depending on your local CUDA and cuDNN versions.

At the time of writing, we depend on a bugfix to the KFAC-JAX library, which is
specified in `kfac_jax.patch`. This can be applied from the project root with
``` shell
$ patch -p0 -i kfac_jax.patch
```
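If the patch command is run a second time, for example after re-running setup, GNU patch reports a reversed (or previously applied) patch and, in an interactive shell, asks whether to undo it. Below is a minimal sketch of an idempotent wrapper. It assumes GNU patch, and `apply_patch_once` is a hypothetical helper that is not part of this repository. It first reverse-applies the patch as a dry run to detect whether the fix is already in place. Otherwise it applies the patch, but only after a forward dry run succeeds.

``` shell
# Hypothetical helper, not part of this repository. Assumes GNU patch and is
# meant to be run from the project root, like the patch command above.
apply_patch_once() {
  patchfile=${1:-kfac_jax.patch}
  # If the patch reverses cleanly, its changes are already in place.
  if patch -p0 -R -s -f --dry-run -i "$patchfile" >/dev/null 2>&1; then
    echo "$patchfile already applied"
  # Otherwise apply it, but only if a dry run shows every hunk applies cleanly.
  elif patch -p0 -s -f --dry-run -i "$patchfile" >/dev/null 2>&1; then
    patch -p0 -s -i "$patchfile" && echo "$patchfile applied"
  else
    echo "$patchfile does not apply cleanly; check your KFAC-JAX version" >&2
    return 1
  fi
}
```

Calling `apply_patch_once` from the project root then prints whether `kfac_jax.patch` was applied or was already present, and leaves the files untouched if neither direction applies cleanly.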

Datasets are not bundled with the repository; before first use, download them by
calling the corresponding dataset constructors with `download=True`.


# Running
Each dataset and algorithm is specified by its own YAML configuration file in
`configs/`; in particular, `AdamQLR_Damped_AdamDampedCurvature.yaml` specifies
the _AdamQLR_ algorithm described in our paper. To perform a single training
run, pass the dataset and algorithm configuration files to `train.py` with the
`-c` flag, e.g.:
``` shell
$ python train.py -c ./configs/fashion_mnist.yaml ./configs/AdamQLR_Damped_AdamDampedCurvature.yaml
```

A complete hyperparameter optimisation routine, including 50 repetitions of the
best hyperparameters found, can be performed by calling
`hyperparameter_optimisation.py` with the corresponding configuration files:
``` shell
$ python hyperparameter_optimisation.py -c ./configs/fashion_mnist.yaml ./configs/AdamQLR_Damped_AdamDampedCurvature.yaml ./configs/ASHA.yaml
```
`hyperparameter_optimisation.py` also contains helper functions for running
sensitivity studies.

To replicate all of our experimental results, the various `run_*.sh` scripts
are a useful starting point.
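Independently of the `run_*.sh` scripts, you may want to sweep your own choice of configurations. The sketch below does this in POSIX shell by looping over dataset and algorithm configs. Several names are assumptions introduced here: `sweep`, `DATASETS`, `ALGORITHMS`, `HPO` and `DRY_RUN` are hypothetical and are not flags of the repository's scripts. The defaults are just the config files named in this README.

``` shell
# Hypothetical helper, not part of this repository. Run from the project root.
# DATASETS and ALGORITHMS are space-separated config names (without .yaml)
# from configs/; the defaults are only the two configs named in this README.
# HPO=1 runs hyperparameter_optimisation.py with ASHA.yaml instead of train.py;
# DRY_RUN=1 prints each command instead of executing it.
sweep() {
  run=${DRY_RUN:+echo}
  for dataset in ${DATASETS:-fashion_mnist}; do
    for algorithm in ${ALGORITHMS:-AdamQLR_Damped_AdamDampedCurvature}; do
      if [ "${HPO:-0}" = 1 ]; then
        $run python hyperparameter_optimisation.py -c \
          "./configs/$dataset.yaml" "./configs/$algorithm.yaml" ./configs/ASHA.yaml || return
      else
        $run python train.py -c "./configs/$dataset.yaml" "./configs/$algorithm.yaml" || return
      fi
    done
  done
}
```

For example, `DRY_RUN=1 HPO=1 sweep` previews the hyperparameter optimisation command for every dataset/algorithm pair without running anything. Dropping `DRY_RUN=1` runs the commands in sequence and stops at the first failure.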

# Analysis
Logs are written in TensorBoard format to a `runs/` directory by default and can
be viewed with `tensorboard --logdir runs/`; the log paths can be changed with
command-line flags, as shown in `run_hpo.sh`.

All our experimental plots are produced using `paper_plots.py`, though you may
need to update the paths to match your local configuration.
