# SLIM-QN: A **S**tochastic, **Li**ght, **M**omentumized **Q**uasi-**N**ewton Optimizer for Deep Neural Networks

This is the official implementation of the Vision Transformer experiments in the anonymous NeurIPS 2021 submission titled *SLIM-QN: A **S**tochastic, **Li**ght, **M**omentumized **Q**uasi-**N**ewton Optimizer for Deep Neural Networks*.

## Requirements
The code has been tested with
- Python 3.6.9
- CUDA 11.0
- PyTorch 1.8.1
- PyTorch Lightning 1.2.5

The experiments were run on either 4 or 8 GPUs using PyTorch DDP (DistributedDataParallel) with the NCCL backend.

To install requirements:

```setup
pip install -r requirements.txt
```

To run the training scripts, you will need to download the [ImageNet](https://www.image-net.org/) and [Oxford Flowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/) datasets.

A copy of Flowers102 with the directory structure expected by our code can be downloaded [here](https://s3.amazonaws.com/content.udacity-data.com/nd089/flower_data.tar.gz).

## Training

To train the models in the paper, either run the Jupyter notebooks in the root directory or run the Python command they contain directly from the command line.
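Because `.ipynb` files are plain JSON, the training command can be read out of a notebook without opening Jupyter. The following is a minimal sketch, assuming (hypothetically) that the command is written in a code cell as an IPython shell escape, i.e. a line starting with `!`; if the notebooks store it differently, the filter needs adapting. The function names `load_notebook` and `extract_shell_commands` are illustrative and not part of this repository.

```python
import json
from pathlib import Path
from typing import Dict, List


def load_notebook(path: str) -> Dict:
    """Load an .ipynb file; notebooks are plain JSON documents."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


def extract_shell_commands(notebook: Dict) -> List[str]:
    """Collect "!"-prefixed shell commands from the notebook's code cells.

    Assumption (hypothetical): the training command is an IPython shell escape.
    """
    commands = []
    for cell in notebook.get("cells", []):
        if cell.get("cell_type") != "code":
            continue  # skip markdown and raw cells
        source = cell.get("source", "")
        # nbformat stores "source" either as one string or as a list of lines.
        text = source if isinstance(source, str) else "".join(source)
        # Join backslash-continued lines into one logical command.
        text = text.replace("\\\n", " ")
        for line in text.splitlines():
            stripped = line.strip()
            if stripped.startswith("!"):
                commands.append(stripped[1:].strip())
    return commands
```

For example, `extract_shell_commands(load_notebook("run_vit_DATASET_OPTIMIZER.ipynb"))` returns the command(s) found in that notebook, which can then be pasted into a terminal.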

The notebook file name identifies the experiment and follows the format `run_vit_DATASET_OPTIMIZER.ipynb`, where `DATASET` and `OPTIMIZER` stand for the dataset and optimizer. Before running, set `--data_dir` to the path of the dataset on your machine.
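The naming scheme above can be mapped to and from a (dataset, optimizer) pair with a small helper. This is a sketch under one assumption: `OPTIMIZER` contains no underscore, so the last `_` separates it from `DATASET`. The helper names are hypothetical.

```python
import re
from typing import Tuple

# Naming scheme: run_vit_DATASET_OPTIMIZER.ipynb.
# Assumption: OPTIMIZER has no underscore, so the last "_" splits the two parts.
_NOTEBOOK_RE = re.compile(r"^run_vit_(?P<dataset>.+)_(?P<optimizer>[^_]+)\.ipynb$")


def notebook_name(dataset: str, optimizer: str) -> str:
    """Build the notebook file name for a dataset/optimizer pair."""
    return "run_vit_{}_{}.ipynb".format(dataset, optimizer)


def parse_notebook_name(filename: str) -> Tuple[str, str]:
    """Split a notebook file name into (DATASET, OPTIMIZER)."""
    match = _NOTEBOOK_RE.match(filename)
    if match is None:
        raise ValueError("expected run_vit_DATASET_OPTIMIZER.ipynb, got " + filename)
    return match.group("dataset"), match.group("optimizer")
```

`sorted(Path(".").glob("run_vit_*.ipynb"))` (with `from pathlib import Path`) lists the notebooks in the root directory; passing each `.name` to `parse_notebook_name` shows which experiments are available. The exact dataset and optimizer spellings should be read off the actual file names rather than assumed.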

**NOTE**: PyTorch DDP averages gradients across processes, i.e., it divides the all-reduced gradients by the number of GPUs. To exactly match our results when using a different number of GPUs than the one indicated in the scripts, multiply the learning rate (`--lr`) by `num_gpus_desired / num_gpus_in_script`.
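The rescaling rule in the note is a single multiplication. A minimal sketch as a helper; the function name `scaled_lr` and the example values (8 GPUs, `--lr 0.1`) are hypothetical and not taken from the scripts.

```python
def scaled_lr(lr_in_script: float, num_gpus_in_script: int, num_gpus_desired: int) -> float:
    """Rescale --lr by num_gpus_desired / num_gpus_in_script, as in the note above."""
    if num_gpus_in_script <= 0 or num_gpus_desired <= 0:
        raise ValueError("GPU counts must be positive")
    return lr_in_script * num_gpus_desired / num_gpus_in_script


# Hypothetical example: a script written for 8 GPUs with --lr 0.1, run on 4 GPUs.
new_lr = scaled_lr(0.1, num_gpus_in_script=8, num_gpus_desired=4)
print("--lr {:g}".format(new_lr))  # prints "--lr 0.05"
```

On the machine you train on, `torch.cuda.device_count()` returns the number of visible GPUs, which can be passed as `num_gpus_desired`.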

## Results

Our experimental results on the small Vision Transformer are summarized below.

### SGD
| Dataset / batch size | Peak val. acc.  | Iterations to peak |
| ---------------------|-----------------|--------------------|
|   Cifar-10 / 256     |     81.9%       |        15071       |
|   Cifar-10 / 1024    |     81.1%       |         5719       |
|   Flowers102 / 256   |     71.2%       |         3171       |
|   ImageNet / 1024    |     51.9%       |        80600       |

### SLIM-QN
| Dataset / batch size | Peak val. acc.  | Iterations to peak |
| ---------------------|-----------------| -------------------|
|    Cifar-10 / 256    |     82.7%       |         9262       |
|    Cifar-10 / 1024   |     83.3%       |         3679       |
|    Flowers102        |     72.7%       |         3171       |
|    ImageNet / 1024   |     53.6%       |       107066       |
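To read the two tables side by side, the accuracy difference and the ratio of iterations to peak can be computed per row. The sketch below copies the numbers from the tables; the names `SGD`, `SLIM_QN`, and `compare` are illustrative only. Since the SLIM-QN Flowers102 row does not state a batch size, rows are matched by dataset label.

```python
from typing import Dict, Tuple

# (peak val. acc. in %, iterations to peak), copied from the two tables above.
# The SLIM-QN Flowers102 row lists no batch size, so rows are keyed by label.
SGD = {
    "Cifar-10 / 256": (81.9, 15071),
    "Cifar-10 / 1024": (81.1, 5719),
    "Flowers102": (71.2, 3171),
    "ImageNet / 1024": (51.9, 80600),
}
SLIM_QN = {
    "Cifar-10 / 256": (82.7, 9262),
    "Cifar-10 / 1024": (83.3, 3679),
    "Flowers102": (72.7, 3171),
    "ImageNet / 1024": (53.6, 107066),
}


def compare(baseline: Dict[str, Tuple[float, int]],
            candidate: Dict[str, Tuple[float, int]]) -> Dict[str, Tuple[float, float]]:
    """Per row: (accuracy gain in points, baseline iterations / candidate iterations)."""
    result = {}
    for key, (acc_base, iters_base) in baseline.items():
        acc_cand, iters_cand = candidate[key]
        result[key] = (round(acc_cand - acc_base, 1), round(iters_base / iters_cand, 2))
    return result


for row, (gain, ratio) in compare(SGD, SLIM_QN).items():
    # ratio > 1 means SLIM-QN reached its peak in fewer iterations than SGD.
    print("{:<16} {:+.1f} pts  {:.2f}x".format(row, gain, ratio))
```

The iteration ratio counts optimizer steps only; it does not account for differences in per-iteration cost or wall-clock time between the two optimizers.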
