# How Does Adaptive Optimization Impact Local Neural Network Geometry?

## Requirements

To install requirements:

```setup
pip install -r requirements.txt
```

>📋  Please install the required packages listed in `requirements.txt`.

## Training

To fine-tune BERT-small for sentence classification on the IMDB dataset (see Section 4.1 of the paper for more details), run:

```train
python train_BERT.py
```

To train a transformer for the translation task on Multi30k (see Section 4.1 of the paper for more details), run:

```train
python train_translation_task.py
```

>📋 The model and dataset are downloaded automatically when the scripts run.
>
>📋 After training, there will be two new folders named `BERT_results` and `translation_results`. They store the training losses and the models saved at some epochs.

## Calculating the diagonal of the loss Hessian

To calculate the diagonal of the loss Hessian, run:

```eval
python hessian_BERT.py
python hessian_translation.py
```

>📋 Each script creates a new folder named `diagHessian_adaptGrad_200` inside `BERT_results` or `translation_results`, respectively, containing the results.
>
>📋 In the new folder, the files for SGD+M have 3 columns, representing (from left to right) the diagonal of the Hessian, the gradient, and the momentum for 200 selected coordinates. The files for Adam have 4 columns, representing (from left to right) the diagonal of the Hessian, the gradient, $m_t$, and $v_t$ for 200 selected coordinates.
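
The column layout above can be read back with a few lines of Python. This is a minimal sketch, not part of the repository: the helper name `parse_columns`, the column labels, and the assumption that values are separated by whitespace or commas are all hypothetical, so check the delimiter against the files the scripts actually write.

```python
from typing import Dict, List

# Column labels follow the left-to-right order described above.
# The label strings themselves are hypothetical names chosen for this sketch.
SGDM_COLUMNS = ["hessian_diag", "grad", "momentum"]
ADAM_COLUMNS = ["hessian_diag", "grad", "m_t", "v_t"]


def parse_columns(text: str) -> Dict[str, List[float]]:
    """Split the text of a results file into named columns.

    Assumes one coordinate per row and whitespace- or comma-separated
    numeric values (an assumption, not confirmed by the scripts).
    """
    rows = [
        line.replace(",", " ").split()
        for line in text.splitlines()
        if line.strip()
    ]
    if not rows:
        raise ValueError("no data rows found")

    width = len(rows[0])
    if width == 3:
        names = SGDM_COLUMNS  # SGD+M files: Hessian diagonal, gradient, momentum
    elif width == 4:
        names = ADAM_COLUMNS  # Adam files: Hessian diagonal, gradient, m_t, v_t
    else:
        raise ValueError(f"expected 3 or 4 columns, got {width}")

    columns: Dict[str, List[float]] = {name: [] for name in names}
    for row in rows:
        if len(row) != width:
            raise ValueError("rows have inconsistent column counts")
        for name, value in zip(names, row):
            columns[name].append(float(value))
    return columns


# Illustrative values only, not real results.
sample_adam = "0.5 0.1 0.09 0.01\n2.0 -0.3 -0.25 0.08\n"
parsed = parse_columns(sample_adam)
print(parsed["hessian_diag"])
```

To use it on a real file, pass the file's contents, for example `parse_columns(open(path).read())`, where `path` points to one of the files in `diagHessian_adaptGrad_200`.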

## Calculating $R_{\text{med}}^{\text{OPT}}$

To calculate $R_{\text{med}}^{\text{OPT}}$ from the results in `diagHessian_adaptGrad_200`, run `R_med_OPT_BERT.m` and `R_med_OPT_translation.m` in MATLAB.

>📋  The MATLAB scripts will create a file `R_med_OPT.txt` in the folder `BERT_results` or `translation_results`, which stores the $R_{\text{med}}^{\text{OPT}}$ values in a table.
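
For readers without MATLAB, the statistic can be approximated in Python. The sketch below rests on an assumption this README does not state: that $R_{\text{med}}^{\text{OPT}}$ is the largest absolute entry of the Hessian diagonal divided by the median absolute entry, computed over the sampled coordinates of one layer. The function name `r_med` is hypothetical, and the MATLAB scripts may aggregate or filter differently, so treat any output as a sanity check rather than a reproduction of `R_med_OPT.txt`.

```python
from statistics import median
from typing import Sequence


def r_med(hessian_diag: Sequence[float]) -> float:
    """Ratio of the max to the median of |H_ii| over the given coordinates.

    Assumed definition (not confirmed by the repository's MATLAB scripts):
    R_med = max_i |H_ii| / median_i |H_ii|.
    """
    magnitudes = [abs(h) for h in hessian_diag]
    if not magnitudes:
        raise ValueError("need at least one Hessian diagonal entry")
    med = median(magnitudes)
    if med == 0:
        raise ValueError("median magnitude is zero; ratio is undefined")
    return max(magnitudes) / med


# Illustrative values only: magnitudes are 1, 2, 4, so the ratio is 4 / 2.
print(r_med([1.0, -2.0, 4.0]))  # → 2.0
```

A value near 1 would mean the sampled diagonal entries have similar magnitudes, while a large value would mean a few coordinates dominate.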

## Results

Sentence classification task on BERT-small

|         |             Step0              |             Step0              |            Step750             |            Step750             |                           Step750                            |            Step1250            |            Step1250            |                           Step1250                           |
| :-----: | :----------------------------: | :----------------------------: | :----------------------------: | :----------------------------: | :----------------------------------------------------------: | :----------------------------: | :----------------------------: | :----------------------------------------------------------: |
| Layer # | $R_{\text{med}}^{\text{SGDM}}$ | $R_{\text{med}}^{\text{Adam}}$ | $R_{\text{med}}^{\text{SGDM}}$ | $R_{\text{med}}^{\text{Adam}}$ | $\frac{R_{\text{med}}^{\text{SGDM}}}{R_{\text{med}}^{\text{Adam}}}$ | $R_{\text{med}}^{\text{SGDM}}$ | $R_{\text{med}}^{\text{Adam}}$ | $\frac{R_{\text{med}}^{\text{SGDM}}}{R_{\text{med}}^{\text{Adam}}}$ |
|    9    | 15.7  | 15.7  | 12.76  | 9.65   | 1.45 | 11.43  | 14.24  | 0.94 |
|   12    | 22.63 | 22.63 | 13.17  | 7.41   | 1.92 | 10.62  | 9.67   | 1.33 |
|   15    | 9.35  | 9.35  | 80.57  | 53.52  | 1.65 | 100.65 | 61.80  | 2.01 |
|   17    | 82.37 | 82.37 | 405.02 | 223.56 | 1.91 | 423.28 | 337.32 | 1.43 |
|   18    | 31.32 | 31.32 | 17.07  | 13.24  | 1.43 | 18.15  | 15.63  | 1.21 |
|   22    | 47.13 | 47.13 | 233.72 | 72.67  | 3.54 | 158.38 | 93.13  | 2.28 |
|   24    | 31.17 | 31.17 | 17.52  | 17.34  | 1.13 | 13.51  | 14.23  | 1.05 |

Translation task

|         |             Epoch0             |             Epoch0             |            Epoch30             |            Epoch30             |                           Epoch30                            |            Epoch55             |            Epoch55             |                           Epoch55                            |
| :-----: | :----------------------------: | :----------------------------: | :----------------------------: | :----------------------------: | :----------------------------------------------------------: | :----------------------------: | :----------------------------: | :----------------------------------------------------------: |
| Layer # | $R_{\text{med}}^{\text{SGDM}}$ | $R_{\text{med}}^{\text{Adam}}$ | $R_{\text{med}}^{\text{SGDM}}$ | $R_{\text{med}}^{\text{Adam}}$ | $\frac{R_{\text{med}}^{\text{SGDM}}}{R_{\text{med}}^{\text{Adam}}}$ | $R_{\text{med}}^{\text{SGDM}}$ | $R_{\text{med}}^{\text{Adam}}$ | $\frac{R_{\text{med}}^{\text{SGDM}}}{R_{\text{med}}^{\text{Adam}}}$ |
|    3    | 4.27  | 4.27  | 4.95  | 2.48  | 2.00 | 2.96  | 2.06  | 1.44 |
|    5    | 7.09  | 7.09  | 36.57 | 19.08 | 1.92 | 71.72 | 17.63 | 4.07 |
|    7    | 5.79  | 5.79  | 5.89  | 4.12  | 1.43 | 7.06  | 3.13  | 2.26 |
|    9    | 18.11 | 18.11 | 30.61 | 22.22 | 1.38 | 38.83 | 16.21 | 2.39 |
|   12    | 11.10 | 11.10 | 7.63  | 6.46  | 1.18 | 8.03  | 7.77  | 1.03 |
|   15    | 83.15 | 83.15 | 59.96 | 8.29  | 7.23 | 43.93 | 8.19  | 5.37 |
|   18    | 14.99 | 14.99 | 3.84  | 2.73  | 1.41 | 2.95  | 3.69  | 0.80 |
|   21    | 93.50 | 93.50 | 27.35 | 5.50  | 4.97 | 19.85 | 4.94  | 4.02 |
|   24    | 36.63 | 36.63 | 5.12  | 4.01  | 1.28 | 4.87  | 3.62  | 1.34 |
|   28    | 18.47 | 18.47 | 2.97  | 2.02  | 1.47 | 2.97  | 1.69  | 1.76 |

> 📋  The tables stored in `R_med_OPT.txt` have the same format but without the first two rows.