The file `abel.py` implements ABEL and can easily be used in any Flax code. The experiments in Table 2 that do not use ABEL are run directly from the Flax baselines: https://github.com/google/flax/tree/master/examples and https://github.com/google-research/google-research/tree/master/flax_models/cifar.

We include the code to run ABEL for the ImageNet and CIFAR/SVHN experiments; it is a small modification of the standard Flax baselines.

Before starting the training loop, we initialize ABEL:
```
from flax_models.abel import ABELScheduler
# num_epochs, base_learning_rate, steps_per_epoch, decay_factor and p_train_step_fn come from the training script.
scheduler = ABELScheduler(num_epochs, base_learning_rate, steps_per_epoch=steps_per_epoch,
                          decay_factor=decay_factor, train_fn=p_train_step_fn)
learning_rate_fn = scheduler.learning_rate_fn
```

At the end of each epoch, we then call the ABEL update rule, which takes the current `train_step` function and the mean weight norm and returns a (possibly updated) `train_step` function. If the learning rate has to be decayed, ABEL updates the optimizer accordingly.
```
# Called once per epoch; summary['param_norm'] is the mean weight norm over the epoch.
p_train_step = scheduler.update(p_train_step, summary['param_norm'])
```
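The contract of the update rule can be illustrated with a simplified, hypothetical stand-in for `ABELScheduler`. This is a sketch under stated assumptions, not the code in `abel.py`: it assumes that a "bounce" means the per-epoch mean weight norm switches from decreasing to increasing, and that a decay is applied by rebuilding the train step with the new learning rate (here with `functools.partial`; a real Flax script would typically also re-`pmap` it). `ToyABELScheduler`, `toy_train_fn` and the `learning_rate` keyword are illustrative names only; the sketch ignores `num_epochs` and `steps_per_epoch` and omits any other logic the real scheduler may contain.

```python
from functools import partial
from typing import Callable, List


class ToyABELScheduler:
    """Hypothetical stand-in for ABELScheduler (not the abel.py implementation)."""

    def __init__(self, base_learning_rate: float, decay_factor: float,
                 train_fn: Callable[..., float]) -> None:
        self.lr = base_learning_rate
        self.decay_factor = decay_factor
        self.train_fn = train_fn  # assumed to accept a `learning_rate` keyword
        self.norms: List[float] = []  # per-epoch mean weight norms seen so far

    def learning_rate_fn(self, step: int) -> float:
        # Piecewise constant: the value only changes when update() applies a decay.
        return self.lr

    def update(self, train_step: Callable[..., float],
               param_norm: float) -> Callable[..., float]:
        self.norms.append(param_norm)
        if len(self.norms) >= 3:
            prev_delta = self.norms[-2] - self.norms[-3]
            delta = self.norms[-1] - self.norms[-2]
            # Assumed "bounce": the norm was decreasing and is now increasing.
            if prev_delta < 0 < delta:
                self.lr *= self.decay_factor
                # Return a new step function bound to the decayed learning rate.
                return partial(self.train_fn, learning_rate=self.lr)
        # No bounce: keep using the current step function.
        return train_step


def toy_train_fn(batch, learning_rate: float) -> float:
    # Placeholder "training step" that only reports the learning rate it uses.
    return learning_rate


scheduler = ToyABELScheduler(0.1, decay_factor=0.1, train_fn=toy_train_fn)
p_train_step = partial(toy_train_fn, learning_rate=scheduler.lr)
for epoch_norm in [10.0, 9.0, 8.5, 8.7, 8.9]:
    p_train_step = scheduler.update(p_train_step, epoch_norm)
```

With the norms in this loop, the learning rate is decayed exactly once, at the fourth epoch, when the norm rises from 8.5 to 8.7; afterwards both `learning_rate_fn` and the returned step use the decayed value.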

The only other modifications to the base code are the additional learning-rate schedules, a label smoothing of 0.1 in the ImageNet experiment, and logging the weight norm as a metric.
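The two measurement/target changes above can be sketched in NumPy. `smooth_labels` uses the standard label-smoothing formula (one-hot targets scaled by `1 - smoothing` plus `smoothing / num_classes` on every class); the ImageNet code may implement it differently, for example inside the loss. `global_weight_norm` assumes the logged weight norm is the global L2 norm over all parameter arrays; exactly which parameters the training code includes is not stated here. Both function names are hypothetical.

```python
from collections.abc import Mapping

import numpy as np


def smooth_labels(labels: np.ndarray, num_classes: int, smoothing: float = 0.1) -> np.ndarray:
    """Standard label smoothing: mix one-hot targets with a uniform distribution."""
    one_hot = np.eye(num_classes)[labels]
    return one_hot * (1.0 - smoothing) + smoothing / num_classes


def global_weight_norm(params) -> float:
    """Global L2 norm over every array in a nested mapping of parameters."""
    def leaves(tree):
        if isinstance(tree, Mapping):
            for value in tree.values():
                yield from leaves(value)
        else:
            yield np.asarray(tree, dtype=np.float64)

    return float(np.sqrt(sum(np.sum(leaf ** 2) for leaf in leaves(params))))


targets = smooth_labels(np.array([0, 2]), num_classes=4)
norm = global_weight_norm({"dense": {"kernel": np.array([3.0]), "bias": np.array([0.0])},
                           "head": np.array([4.0])})
```

For a Flax parameter pytree, the manual traversal can be replaced by `jax.tree_util.tree_leaves(params)`, which handles any pytree container.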

Example commands to run ABEL:

- For the CIFAR/SVHN code on WRN-28-10 with the standard parameters:

```
  python3 -m flax_models.cifar.train --dataset=cifar100 --lr_schedule=ABEL --output_dir=cifar100_bs128_ABEL --learning_rate=0.1 --weight_decay=0.0005 --num_epochs=200
```

- For the ImageNet ResNet-50 experiments with the standard parameters:

```
  python3 -m flax_models.imagenet.imagenet_main --batch_size=2048 --cache=True --model_dir=imagenet_fp16_bs2048_ABEL_lr0.8 --learning_rate=0.1 --half_precision=True --lr_schedule=ABEL --num_epochs=90
```

The `--lr_schedule` argument accepts the values `ABEL`, `cosine`, `decay` and `simple_decay`. Note that the ImageNet code interprets `--learning_rate` as the learning rate per device; on our v3-8 TPU configuration, the effective learning rate is therefore 8 times that (0.8 for the command above, hence `lr0.8` in its `--model_dir`).
As described in https://github.com/google-research/google-research/tree/master/flax_models/cifar, `flax_models/cifar` can be used with different datasets, architectures and data augmentations.