# Does SGD really happen in tiny subspaces?

The experiments were conducted using PyTorch. This repository is based on the GitHub repository at https://github.com/locuslab/edge-of-stability, in order to replicate the experimental setup described in [Cohen et al. (2021)](https://openreview.net/forum?id=jh-rTtvkGeM).

### Preliminaries

To run the code, you need to set two environment variables:

1. Set the `DATASETS` environment variable to a directory where datasets will be stored. For example: `export DATASETS="datasets"`.
2. Set the `RESULTS` environment variable to a directory where results will be stored. For example: `export RESULTS="results"`.
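
Before launching a run, it can help to confirm that both variables are visible to Python. The following is a minimal sketch that is not part of the repository. The helper name `check_env` is hypothetical, and creating missing directories is a convenience assumed here, not something the training scripts are documented to require.

```python
import os
from pathlib import Path

# Hypothetical helper (not part of the repository): check that the two
# environment variables used by the scripts are set, and resolve them.
REQUIRED_VARS = ("DATASETS", "RESULTS")


def check_env(create: bool = True) -> dict:
    """Return a mapping from variable name to resolved directory path.

    Raises RuntimeError naming the first variable that is unset or empty.
    If `create` is True, missing directories are created (an assumption
    made for convenience, not a documented requirement).
    """
    paths = {}
    for name in REQUIRED_VARS:
        value = os.environ.get(name)
        if not value:
            raise RuntimeError(f"environment variable {name} is not set")
        path = Path(value).expanduser().resolve()
        if create:
            path.mkdir(parents=True, exist_ok=True)
        paths[name] = path
    return paths
```

Calling `check_env()` from the same shell session (after the `export` lines) returns the resolved paths, or raises `RuntimeError` naming the missing variable.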

### Quick Start

Terminal:
```
python src/train.py --help
```
Output:
```
usage: train.py [-h] [--max_steps MAX_STEPS] [--seed SEED] [--beta BETA] [--rho RHO] [--physical_batch_size PHYSICAL_BATCH_SIZE] [--batch_size BATCH_SIZE] [--acc_goal ACC_GOAL] [--loss_goal LOSS_GOAL] [--neigs NEIGS] [--neigs_dom NEIGS_DOM]
                [--eig_freq EIG_FREQ] [--save_freq SAVE_FREQ] [--save_model SAVE_MODEL] [--gpu_id GPU_ID]
                {cifar10-5k,mnist-5k,sst2-1k} arch_id {ce,mse,logtanh} {sgd,sam,adam} lr

Train a neural network.

positional arguments:
  {cifar10-5k,mnist-5k,sst2-1k}
                        which dataset to train
  arch_id               which network architectures to train
  {ce,mse,logtanh}      which loss function to use
  {sgd,sam,adam}        which optimization method to use
  lr                    the learning rate

options:
  -h, --help            show this help message and exit
  --max_steps MAX_STEPS
                        the maximum number of gradient steps to train for
  --seed SEED           the random seed used when initializing the network weights
  --beta BETA           momentum parameter
  --rho RHO             perturbation radius for SAM
  --physical_batch_size PHYSICAL_BATCH_SIZE
                        the maximum number of examples that we try to fit on the GPU at once
  --batch_size BATCH_SIZE
                        batch size of SGD
  --acc_goal ACC_GOAL   terminate training if the train accuracy ever crosses this value
  --loss_goal LOSS_GOAL
                        terminate training if the train loss ever crosses this value
  --neigs NEIGS         the number of top eigenvalues to compute
  --neigs_dom NEIGS_DOM
                        the number of dominant top eigenvalues
  --eig_freq EIG_FREQ   the frequency at which we compute the top Hessian eigenvalues (-1 means never)
  --save_freq SAVE_FREQ
                        the frequency at which we save results
  --save_model SAVE_MODEL
                        if 'true', save model weights at end of training
  --gpu_id GPU_ID       gpu (cuda device) id
```
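
The `sam` optimizer and the `--rho` flag refer to sharpness-aware minimization. For readers unfamiliar with it, here is a minimal sketch of one step of the standard SAM update: move to `w + rho * g / ||g||`, take the gradient there, then apply that gradient at the original point. This is not the code in `src/train.py`; the function `sam_step` is hypothetical, and how the repository combines SAM with `--beta` momentum is not shown here.

```python
import torch


def sam_step(params, loss_fn, lr: float, rho: float) -> None:
    """One plain SAM update, in place, on tensors with requires_grad=True.

    `loss_fn()` must recompute the (minibatch) loss from the current params.
    """
    # 1) Gradient at the current point w
    grads = torch.autograd.grad(loss_fn(), params)
    grad_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    # 2) Perturbation of radius rho along the normalized gradient
    eps = [rho * g / (grad_norm + 1e-12) for g in grads]
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.add_(e)
    # 3) Gradient at the perturbed point w + eps
    sam_grads = torch.autograd.grad(loss_fn(), params)
    # 4) Return to w and descend along the perturbed-point gradient
    with torch.no_grad():
        for p, e, g in zip(params, eps, sam_grads):
            p.sub_(e)
            p.sub_(lr * g)
```

With `rho=0` the perturbation vanishes and the step reduces to a plain gradient step, which is a quick sanity check for any SAM implementation.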

Terminal:
```
python src/train_proj.py --help
```
Output:
```
usage: train_proj.py [-h] [--max_steps MAX_STEPS] [--start_step START_STEP] [--seed SEED] [--beta BETA] [--rho RHO] [--physical_batch_size PHYSICAL_BATCH_SIZE] [--batch_size BATCH_SIZE] [--acc_goal ACC_GOAL] [--loss_goal LOSS_GOAL] [--neigs NEIGS]
                     [--neigs_dom NEIGS_DOM] [--save_freq SAVE_FREQ] [--save_model SAVE_MODEL] [--gpu_id GPU_ID]
                     {dom,bulk} {cifar10-5k,mnist-5k,sst2-1k} arch_id {ce,mse,logtanh} {sgd,sam,adam} lr

Train a neural network using projected updates.

positional arguments:
  {dom,bulk}            which subspace to project
  {cifar10-5k,mnist-5k,sst2-1k}
                        which dataset to train
  arch_id               which network architectures to train
  {ce,mse,logtanh}      which loss function to use
  {sgd,sam,adam}        which optimization method to use
  lr                    the learning rate

options:
  -h, --help            show this help message and exit
  --max_steps MAX_STEPS
                        the maximum number of gradient steps to train for
  --start_step START_STEP
                        the step to start projected method
  --seed SEED           the random seed used when initializing the network weights
  --beta BETA           momentum parameter
  --rho RHO             perturbation radius for SAM
  --physical_batch_size PHYSICAL_BATCH_SIZE
                        the maximum number of examples that we try to fit on the GPU at once
  --batch_size BATCH_SIZE
                        batch size of SGD
  --acc_goal ACC_GOAL   terminate training if the train accuracy ever crosses this value
  --loss_goal LOSS_GOAL
                        terminate training if the train loss ever crosses this value
  --neigs NEIGS         the number of top eigenvalues to compute
  --neigs_dom NEIGS_DOM
                        the number of dominant top eigenvalues
  --save_freq SAVE_FREQ
                        the frequency at which we save results
  --save_model SAVE_MODEL
                        if 'true', save model weights at end of training
  --gpu_id GPU_ID       gpu (cuda device) id
```
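
The help text only states that `{dom,bulk}` selects the subspace to project onto. Assuming, as the names Dom-SGD and Bulk-SGD suggest, that `dom` means the span of the top `--neigs_dom` Hessian eigenvectors and `bulk` its orthogonal complement, and that `--start_step` switches from plain to projected updates, one projected step can be sketched as below. This is not the code in `src/train_proj.py`; `project_update` and `projected_sgd_step` are hypothetical names, and the gradient is treated as a single flattened vector.

```python
import torch


def project_update(grad: torch.Tensor, evecs: torch.Tensor, mode: str) -> torch.Tensor:
    """Project a flattened gradient onto the dominant or bulk subspace.

    grad:  shape (p,), the (stochastic) gradient as one vector
    evecs: shape (p, k), orthonormal top-k Hessian eigenvectors (k = neigs_dom)
    mode:  "dom" keeps the component in span(evecs); "bulk" removes it
    """
    dom_part = evecs @ (evecs.T @ grad)  # U U^T g
    if mode == "dom":
        return dom_part
    if mode == "bulk":
        return grad - dom_part  # (I - U U^T) g
    raise ValueError(f"unknown mode: {mode!r}")


def projected_sgd_step(w: torch.Tensor, grad: torch.Tensor, evecs: torch.Tensor,
                       lr: float, step: int, start_step: int, mode: str) -> torch.Tensor:
    """Plain SGD before start_step, projected SGD from start_step onwards."""
    if step >= start_step:
        grad = project_update(grad, evecs, mode)
    return w - lr * grad
```

Since the two projections are complementary, the Dom and Bulk components of a gradient always sum back to the full gradient and are orthogonal to each other.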

### Example 1: Train MLP on MNIST-5k using SGD

Run in a terminal (SGD):
```
python src/train.py mnist-5k fc-tanh mse sgd 0.01 --max_steps 40000 --batch_size 50 --eig_freq 100 --neigs 20 --neigs_dom 10
```
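
With these flags, the run computes the top 20 (`--neigs`) Hessian eigenvalues every 100 (`--eig_freq`) steps and treats the first 10 (`--neigs_dom`) as the dominant ones. To illustrate what is being computed, here is a sketch that builds the Hessian of a tiny model column by column from Hessian-vector products and keeps its top eigenpairs. It is not the repository's routine: the names `hvp` and `top_hessian_eigs` are hypothetical, and at the scale of a real network an iterative eigensolver (for example Lanczos) would replace the dense matrix built here.

```python
import torch


def hvp(loss_fn, params, vec: torch.Tensor) -> torch.Tensor:
    """Hessian-vector product H v via double backpropagation.

    params: list of tensors with requires_grad=True
    vec:    flat tensor with as many entries as all params combined
    """
    grads = torch.autograd.grad(loss_fn(), params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    hv = torch.autograd.grad(flat_grad @ vec, params, allow_unused=True)
    # Parameters that do not affect the gradient contribute zero curvature
    return torch.cat([(torch.zeros_like(p) if h is None else h).reshape(-1)
                      for p, h in zip(params, hv)])


def top_hessian_eigs(loss_fn, params, neigs: int):
    """Top-`neigs` Hessian eigenpairs (largest first), built densely.

    Only practical for tiny models: it needs one HVP per parameter.
    """
    n = sum(p.numel() for p in params)
    basis = torch.eye(n, dtype=params[0].dtype)
    H = torch.stack([hvp(loss_fn, params, e) for e in basis], dim=1)
    H = 0.5 * (H + H.T)  # symmetrize away round-off
    evals, evecs = torch.linalg.eigh(H)  # ascending order
    order = torch.argsort(evals, descending=True)[:neigs]
    return evals[order], evecs[:, order]
```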

Plot results:
```
import os

import torch
import matplotlib.pyplot as plt

cmap_blue = plt.get_cmap("Blues_r", 20)
cmap_orange = plt.get_cmap("Oranges_r", 20)

# Must match the arguments passed to src/train.py above
dataset = "mnist-5k"
arch_id = "fc-tanh"
loss = "mse"
opt = "sgd"
lr = 0.01
beta = 0
batch_size = 50
seed = 0
neigs = 20
neigs_dom = 10
eig_freq = 100

def get_directory(dataset, arch_id, loss, seed, opt, lr, batch_size, beta):
    """Relative path (under $RESULTS) where src/train.py stores a run's outputs."""
    directory = f"{dataset}/{arch_id}/{loss}/seed_{seed}/{opt}"
    if beta == 0:
        directory = f"{directory}/lr_{lr}_batch_{batch_size}"
    else:
        directory = f"{directory}/lr_{lr}_batch_{batch_size}_beta_{beta}"
    return directory

directory = get_directory(dataset, arch_id, loss, seed, opt, lr, batch_size, beta)
# Read from $RESULTS, falling back to "results" as in the example export above
results_directory = f"{os.environ.get('RESULTS', 'results')}/{directory}"

train_loss = torch.load(f"{results_directory}/train_loss_final")
train_acc = torch.load(f"{results_directory}/train_acc_final")
eigs = torch.load(f"{results_directory}/eigs_final")
evecs_grad_cos = torch.load(f"{results_directory}/evecs_grad_cos_final")

# Eigenvalues are recorded every eig_freq steps; convert indices to training steps
eig_steps = torch.arange(eigs.shape[0]) * eig_freq

plt.figure(1, figsize=(8, 12))
# Panel 1: training loss at every step
plt.subplot(311)
plt.plot(torch.log(train_loss))
plt.ylabel('Training loss (log-scale)')

# Panel 2: top Hessian eigenvalues, dominant ones in blue, the rest in orange
plt.subplot(312)
for i in range(neigs_dom):
    plt.plot(eig_steps, eigs[:, i], color=cmap_blue(i))
for i in range(neigs - neigs_dom):
    plt.plot(eig_steps, eigs[:, neigs_dom + i], color=cmap_orange(i))
plt.ylabel('Top eigenvalues')

# Panel 3: norm of the gradient's cosines with the top neigs_dom eigenvectors
plt.subplot(313)
plt.plot(eig_steps, torch.norm(evecs_grad_cos[:, :neigs_dom], dim=1))
plt.ylabel(r'$\chi_{10}(\nabla L)$')
plt.xlabel('Training Steps')
plt.show()
```
![Example 1: training loss, top Hessian eigenvalues, and gradient alignment with the dominant subspace](./demo1.png)
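
The third panel plots `torch.norm(evecs_grad_cos[:, :neigs_dom], dim=1)`. Assuming that `evecs_grad_cos[t, i]` holds the cosine between the gradient and the i-th Hessian eigenvector at the t-th eigenvalue computation (the file's exact contents are not documented here), this norm equals `‖P_k g‖ / ‖g‖`, where `P_k` projects onto the span of the top k eigenvectors: the fraction of the gradient that lies in the dominant subspace. The sketch below checks that identity; `evec_grad_cos` and `chi` are hypothetical names.

```python
import torch


def evec_grad_cos(grad: torch.Tensor, evecs: torch.Tensor) -> torch.Tensor:
    """Cosine between a flattened gradient and each unit-norm eigenvector column."""
    return (evecs.T @ grad) / grad.norm()


def chi(grad: torch.Tensor, evecs: torch.Tensor, k: int) -> torch.Tensor:
    """Norm of the first k cosines, i.e. ||U_k U_k^T g|| / ||g|| for orthonormal U."""
    return torch.linalg.vector_norm(evec_grad_cos(grad, evecs)[:k])
```

Because the eigenvectors are orthonormal, the squared cosines over any subset of them sum to at most 1, so `chi` always lies in [0, 1]: a value near 1 means the gradient is almost entirely in the top-k subspace.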

### Example 2: Train MLP on MNIST-5k using Dom-SGD and Bulk-SGD

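Run the following commands in a terminal. The plotting code below also loads the SGD baseline from Example 1, so run that first.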

- Dom-SGD:
```
python src/train_proj.py dom mnist-5k fc-tanh mse sgd 0.01 --max_steps 20000 --batch_size 50 --start_step 5600 --neigs 20 --neigs_dom 10
```
- Bulk-SGD:
```
python src/train_proj.py bulk mnist-5k fc-tanh mse sgd 0.01 --max_steps 20000 --batch_size 50 --start_step 5600 --neigs 20 --neigs_dom 10
```

Plot results:
```
import os

import torch
import matplotlib.pyplot as plt

# Must match the arguments passed to src/train_proj.py above
dataset = "mnist-5k"
arch_id = "fc-tanh"
loss = "mse"
seed = 0
neigs = 20
neigs_dom = 10
opt = "sgd"
lr = 0.01
beta = 0
batch_size = 50
start_step = 5600

def get_directory(dataset: str, arch_id: str, loss: str, seed: int, opt: str, lr: float, batch_size: int, beta: float) -> str:
    """Relative path (under $RESULTS) of a src/train.py run."""
    directory = f"{dataset}/{arch_id}/{loss}/seed_{seed}/{opt}"
    if beta == 0:
        directory = f"{directory}/lr_{lr}_batch_{batch_size}"
    else:
        directory = f"{directory}/lr_{lr}_batch_{batch_size}_beta_{beta}"
    return directory

def get_proj_directory(proj: str, dataset: str, arch_id: str, loss: str, seed: int, opt: str, lr: float, batch_size: int, beta: float, start_step: int) -> str:
    """Relative path of a src/train_proj.py run: a {proj}_{start_step} subdirectory of the base run."""
    base = get_directory(dataset, arch_id, loss, seed, opt, lr, batch_size, beta)
    return f"{base}/{proj}_{start_step}"

results_root = os.environ.get("RESULTS", "results")
results_dir_base = f"{results_root}/{get_directory(dataset, arch_id, loss, seed, opt, lr, batch_size, beta)}"
results_dir_dom = f"{results_root}/{get_proj_directory('dom', dataset, arch_id, loss, seed, opt, lr, batch_size, beta, start_step)}"
results_dir_bulk = f"{results_root}/{get_proj_directory('bulk', dataset, arch_id, loss, seed, opt, lr, batch_size, beta, start_step)}"

train_loss_base = torch.load(f"{results_dir_base}/train_loss_final")
train_loss_dom = torch.load(f"{results_dir_dom}/train_loss_final")
train_loss_bulk = torch.load(f"{results_dir_bulk}/train_loss_final")

T = 20000  # plot horizon, equal to --max_steps of the projected runs
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(torch.log(train_loss_base[:T]), label="SGD")
# Projected runs are stored from step 0; only steps from start_step on are plotted.
# Building the x-axis from the slice length keeps it aligned if a run stopped early.
for name, curve in [("Dom-SGD", train_loss_dom), ("Bulk-SGD", train_loss_bulk)]:
    curve = curve[start_step:T]
    ax.plot(torch.arange(start_step, start_step + len(curve)), torch.log(curve), label=name)
ax.axvline(x=start_step, ls='--', color='black')
ax.set_xlabel('Training Steps')
ax.set_ylabel('Training loss (log-scale)')
ax.legend()
fig.tight_layout()
plt.show()
```
![Example 2: training loss of SGD, Dom-SGD, and Bulk-SGD](./demo2.png)
