# Conflicting Biases at the Edge of Stability: Norm versus Sharpness Regularization

This repository contains the source code for the ICML 2026 submission *Conflicting Biases at the Edge of Stability: Norm versus Sharpness Regularization*.

Parts of this code were adapted from the [repository](https://github.com/locuslab/edge-of-stability) accompanying the paper [Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability](https://openreview.net/forum?id=jh-rTtvkGeM) by Jeremy Cohen, Simran Kaur, Yuanzhi Li, Zico Kolter, and Ameet Talwalkar. If you use this code, please cite the original work. For more details, see [Code attribution](#code-attribution).


The structure of this README is:
1. [Preliminaries](#preliminaries)
2. [Usage](#usage)
3. [Code attribution](#code-attribution)

## Preliminaries


### Required packages

The following package versions were used in our experiments (also listed in `requirements.txt`):
- matplotlib==3.10.3
- numpy==2.2.6
- scipy==1.15.3
- torch==2.5.1
- torchvision==0.20.1
- jax==0.6.0

### Environment variable definitions

To run the code, you need to set two environment variables:
- Set the `DATASETS` environment variable to a directory where datasets will be stored. For example: `export DATASETS="/my/directory/datasets"`.
- Set the `RESULTS` environment variable to a directory where results will be stored. For example: `export RESULTS="/my/directory/results"`.
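
Before launching any of the scripts, it can help to confirm that both variables are set and point to usable directories. The following is a minimal sketch and not part of the repository: the helper name `check_directories` is hypothetical, and it assumes only the two variable names listed above.

```python
import os
from pathlib import Path


def check_directories(names=("DATASETS", "RESULTS"), environ=None):
    """Return {name: Path} for each variable, creating missing directories.

    Hypothetical helper, not part of the repository. `environ` defaults to
    os.environ and can be replaced by a plain dict for testing.
    """
    environ = os.environ if environ is None else environ
    missing = [name for name in names if not environ.get(name)]
    if missing:
        raise KeyError("unset environment variables: " + ", ".join(missing))
    paths = {}
    for name in names:
        path = Path(environ[name]).expanduser()
        # Create the directory (and any parents) if it does not exist yet.
        path.mkdir(parents=True, exist_ok=True)
        paths[name] = path
    return paths
```

Calling `check_directories()` with no arguments reads the current process environment and raises a `KeyError` naming any variable that is unset.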

## Usage

Our experimental workflow for each setup consisted of four steps:
1. [Run gradient flow](#running-gradient-flow)
2. [Create a learning rate schedule](#creating-a-learning-rate-schedule)
3. [Run gradient descent](#running-gradient-descent)
4. [Generate plots](#generating-plots)

### Running gradient flow

After defining environment variables,
```
export DATASETS="/my/directory/datasets"
export RESULTS="/my/directory/results"
```
run
```
python src/flow.py [dataset] [arch_id] [loss] [tick] [max_time] --seed [seed] --loss_goal [loss_goal]
```
For example:
```
python src/flow.py cifar10-5k fc-relu mse 1.0 1000 --seed 43 --loss_goal 0.0001 --neigs 1 --eig_freq 1 --save_freq 20 --norm_type all
```

### Creating a learning rate schedule

Once the flow computation has completed, generate a range of learning rates for GD, following the procedure described in the Methodology section of our paper, by running
```
python src/make_lr_schedule.py [dataset] [arch_id] [loss] [tick] --seed [seed] --loss_goal [loss_goal]
```
For example:
```
python src/make_lr_schedule.py cifar10-5k fc-relu mse 1.0 --seed 43 --loss_goal 0.0001
```
The output will include a list of coarse and fine values, such as
```
Full coarse grid:
[0.0089 0.0203 0.0317 0.043 0.0544 0.0658 0.0772 0.0885 0.0999 0.1113 0.1227 0.1341 0.1454]
(...)
Full fine grid:
[0.0007 0.0015 0.0022 0.003  0.0037 0.0044 0.0052 0.0059 0.0067 0.0074 0.0082 0.0089]
```
and the same information will be saved in a `lr_schedule.json` file, in the appropriate location within the `RESULTS` folder. For example: `results/cifar10-5k/fc-relu/seed_43/mse/gd/lr_schedule.json`.
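
To locate the schedule file programmatically, the path can be assembled from the run parameters. This is a sketch under one assumption: the directory layout is inferred solely from the single example path above (`<root>/<dataset>/<arch_id>/seed_<seed>/<loss>/gd/lr_schedule.json`), and the function name `lr_schedule_path` is hypothetical. The internal structure of the JSON file is not described here, so the sketch stops at building the path.

```python
import os
from pathlib import Path


def lr_schedule_path(dataset, arch_id, loss, seed, results_root=None):
    """Build the expected location of lr_schedule.json.

    Hypothetical helper: the layout is inferred from the example path in
    this README and may not match other configurations.
    """
    if results_root is None:
        # Fall back to the RESULTS environment variable.
        results_root = os.environ["RESULTS"]
    return (
        Path(results_root)
        / dataset
        / arch_id
        / f"seed_{seed}"
        / loss
        / "gd"
        / "lr_schedule.json"
    )
```

For instance, `lr_schedule_path("cifar10-5k", "fc-relu", "mse", 43, "results")` reproduces the example path given above.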

### Running gradient descent

For each of the values in the learning rate schedule, run:
```
python src/gd.py [dataset] [arch_id] [loss] [learning_rate] [max_steps] --seed [seed] --loss_goal [loss_goal]
```
For example:
```
learning_rates=(0.0089 0.0203 0.0317 0.043 0.0544 0.0658 0.0772 0.0885 0.0999 0.1113 0.1227 0.1341 0.1454 0.0007 0.0015 0.0022 0.003  0.0037 0.0044 0.0052 0.0059 0.0067 0.0074 0.0082)

for lr in "${learning_rates[@]}"
do
    python src/gd.py cifar10-5k fc-relu mse "$lr" 2000000 \
        --seed 43 --loss_goal 0.0001 --neigs 1 \
        --eig_freq 2500 --save_freq 10000 --norm_type all
done
```
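
The learning rate list in the loop above is the coarse grid followed by the fine grid, with the shared value `0.0089` listed only once. The sketch below shows one way to build the same list and the corresponding command lines in Python. It is illustrative only: `merge_grids` and `gd_command` are hypothetical names, the grid values are copied from the example output above, and the flags are exactly those of the shell example.

```python
import shlex

# Grid values copied from the example output of make_lr_schedule.py above.
COARSE = [0.0089, 0.0203, 0.0317, 0.043, 0.0544, 0.0658, 0.0772,
          0.0885, 0.0999, 0.1113, 0.1227, 0.1341, 0.1454]
FINE = [0.0007, 0.0015, 0.0022, 0.003, 0.0037, 0.0044,
        0.0052, 0.0059, 0.0067, 0.0074, 0.0082, 0.0089]


def merge_grids(coarse, fine):
    """Concatenate coarse and fine grids, keeping the first copy of duplicates."""
    seen = set()
    merged = []
    for lr in list(coarse) + list(fine):
        if lr not in seen:
            seen.add(lr)
            merged.append(lr)
    return merged


def gd_command(lr, dataset="cifar10-5k", arch_id="fc-relu", loss="mse",
               max_steps=2000000, seed=43, loss_goal=0.0001):
    """Return the argument list for one GD run (same flags as the shell loop)."""
    return [
        "python", "src/gd.py", dataset, arch_id, loss, str(lr), str(max_steps),
        "--seed", str(seed), "--loss_goal", str(loss_goal), "--neigs", "1",
        "--eig_freq", "2500", "--save_freq", "10000", "--norm_type", "all",
    ]


learning_rates = merge_grids(COARSE, FINE)

if __name__ == "__main__":
    # Dry run: print the commands instead of executing them.
    for lr in learning_rates:
        print(shlex.join(gd_command(lr)))
```

To execute the runs rather than print them, each argument list can be passed to `subprocess.run(cmd, check=True)`.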

### Generating plots

Finally, to generate plots of final GD values, run
```
python src/plot.py [dataset] [arch_id] [loss] --seed [seed] --loss_goal [loss_goal] --load_lr [lr_schedule]
```
For example:
```
python src/plot.py cifar10-5k fc-relu mse --seed 43 --loss_goal 0.0001 --load_lr fine
```
To generate plots of iterates over training, run:
```
python src/iterate_plot.py [dataset] [arch_id] [loss] [index] --seed [seed] --norm_type 1 [--all] --es [loss_goal]
```
For example, for all (hard-coded) indices:
```
python src/iterate_plot.py cifar10-5k fc-relu mse -1 --seed 43 --norm_type 1 --all --es 0.0001
```
or for the fifth smallest learning rate:
```
python src/iterate_plot.py cifar10-5k fc-relu mse 5 --seed 43 --norm_type 1 --es 0.0001
```

#### Plotting options

The required parameters of `plot.py` are:

- `dataset` [string]: the dataset to train on. Some possible values are:
    - `cifar10`: the full `CIFAR-10` dataset
    - `cifar10-5k`: the first 5,000 examples from the full `CIFAR-10` dataset.  
    - `mnist`: the full `MNIST` dataset
    - `mnist-5k`: the first 5,000 examples from the full `MNIST` dataset.  
- `arch_id` [string]: which network architecture to train. See `load_architecture()` in `archs.py` for a full list of the permissible values.
- `loss` [string]: which loss function to use.  The possible values are:
    - `ce`: cross-entropy loss
    - `mse`: mean squared error (square loss).

The optional parameters of `plot.py` are:
- `seed` [int]: the random seed used when initializing the network weights.
- `loss_goal` [float]: a learning rate is included only if the train loss for this learning rate reached this value.
- `no_flow` [boolean flag]: if included, the GD values of the smallest available learning rate are used in place of GF.
- `eta_min` [float]: lowest eta to consider in the plot.
- `eta_max` [float]: highest eta to consider in the plot.
- `es` [float]: early stopping loss, must equal `loss_goal` if applicable. If included, the script plots values attained upon first crossing the early stopping loss, rather than the final values over the whole training.
- `load_lr` [str: "fine","coarse","all","old","cf"]: which learning rate schedule to use for the plot. 'all' loads all available; 'fine', 'coarse' and 'old' require the existence of the appropriate `lr_schedule.json` file.
- `max_smoothing` [int]: how many values on each side of the sharpness maximum should be used for smoothing.
- `show` [boolean flag]: if included, the script shows each plot.
- `plot_es` [boolean flag]: if included, the script also recursively creates plots for all early stopping instances between 1 and the current loss goal.
- `do_max` [boolean flag]: if included, also displays values attained at maximum sharpness for all plots where possible.
- `no_legend` [boolean flag]: if included, does not display the legend in any plot (names of the saved files are altered to include "_no_legend").
- `enable_titles` [boolean flag]: if included, plots each figure with a title.
- `general_captions` [boolean flag]: if included, captions in each plot are made non-specific to the concrete plotted value.

Options of all other scripts are analogous.
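
As an illustration of how `loss_goal`, `eta_min` and `eta_max` interact, the sketch below filters a set of learning rates by these three criteria. It is not the code used in `plot.py`: the function name `select_learning_rates` and the input format (a dictionary mapping each learning rate to its final train loss) are assumptions made for this example, and "reached" is interpreted as the final loss being at most `loss_goal`.

```python
def select_learning_rates(final_losses, loss_goal, eta_min=None, eta_max=None):
    """Return the sorted learning rates that pass all three filters.

    Hypothetical helper. `final_losses` maps learning rate -> final train loss.
    """
    selected = []
    for eta in sorted(final_losses):
        if final_losses[eta] > loss_goal:
            continue  # training never reached the loss goal
        if eta_min is not None and eta < eta_min:
            continue  # below the plotted range
        if eta_max is not None and eta > eta_max:
            continue  # above the plotted range
        selected.append(eta)
    return selected
```

Leaving `eta_min` and `eta_max` unset applies only the `loss_goal` filter.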

### Diagonal model

To reproduce the figures for the diagonal model, use the Jupyter notebook `diagonal_model.ipynb`.

## Code attribution

Parts of this repository were taken or adapted from the [repository](https://github.com/locuslab/edge-of-stability) accompanying the paper [Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability](https://openreview.net/forum?id=jh-rTtvkGeM) by Jeremy Cohen, Simran Kaur, Yuanzhi Li, Zico Kolter, and Ameet Talwalkar. If you use this code, please cite the original work.

In addition to enabling the systematic learning rate selection and plotting procedure described in the Methodology section of our paper and in the [Usage](#usage) section, the altered and new code expands the range of supported experimental settings (in particular, it adds the `MNIST` and `MNIST-5k` datasets, the vision transformer architecture, and a diagonal linear neural network with weight sharing). Further changes include minor bug fixes and adaptations to our own preferences.

The following code was taken **without alteration**:[$^1$](#foot)
- `cifar.py`
- `resnet_cifar.py`
- `vgg.py` (not used in our experiments)

The following code was **adapted** (see beginning of each file for a list of modifications):[$^1$](#foot)
- `archs.py`
- `data.py`
- `flow.py`
- `gd.py`
- `synthetic.py`
- `utilities.py`


The following code is entirely **original**:
- `iterate_plot.py`
- `make_lr_schedule.py`
- `mnist.py`
- `plot.py`
- `vision_transformer.py`
- `diagonal_model.ipynb`

<sup><a id="foot"></a>1: For documentation of these files, we refer to the [original repository](https://github.com/locuslab/edge-of-stability).</sup>

