# Receding Neuron Importances for Structured Pruning
XXXX-1

## Abstract
Structured pruning efficiently compresses networks by identifying and removing unimportant neurons. While this can be elegantly achieved by applying sparsity-inducing regularisation on BatchNorm parameters, an L1 penalty would shrink all scaling factors rather than just those of superfluous neurons. To tackle this issue, we introduce a simple BatchNorm variation with bounded scaling parameters, based on which we design a novel regularisation term that suppresses only neurons with low importance. Under our method, the weights of unnecessary neurons effectively recede, producing a polarised bimodal distribution of importances. We show that neural networks trained this way can be pruned to a larger extent and with less deterioration. We one-shot prune VGG and ResNet architectures at different ratios on CIFAR and ImageNet datasets. In the case of VGG-style networks, our method significantly outperforms existing approaches, particularly under a severe pruning regime.

## Experiments and Usage
- Datasets: CIFAR10, CIFAR100
- Methods: RNI (ours), UCS (baseline), L1 (Slimming)
- Architectures: VGG-16, ResNet-56
- To keep things user-friendly, all training, pruning and fine-tuning can be run from an IPython notebook using only two functions.
- Both training and pruning-finetuning return the best model and the training logs.
- Sample usage for the CIFAR scenarios can be found in `Results.ipynb`.

### Training
- Trains a new full/baseline model from scratch, which can be pruned later.
```
model, logs = train(
    dataset="cifar10",  # "cifar10", "cifar100"
    arch="vgg16",       # "vgg16", "resnet56"
    method="rni",       # "rni", "ucs", "l1"
    reg_weight=1e-4,    # strength of sparsity regularisation, will be turned off for ucs
    b=3,                # "shift" hyper-parameter for rni, won't do anything for l1 or ucs
    epochs=160,
    lr=0.1,             # scheduled to be divided by 10 at 50% and 75% of epochs.
    batch_size=64,
    seed=0,
    device="cuda:0",
)
```
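The `lr` comment above describes a step schedule: the learning rate is divided by 10 at 50% and at 75% of the epochs. The following is a minimal framework-free sketch of that schedule. The helper name `step_lr`, its arguments, and the assumption that the drop applies from the 0-indexed milestone epoch onwards are illustrative and not taken from the repository.

```python
def step_lr(base_lr, epoch, total_epochs, milestones=(0.5, 0.75), factor=0.1):
    """Return the learning rate for a 0-indexed epoch (hypothetical helper).

    The rate is multiplied by `factor` once for every milestone fraction
    of `total_epochs` that has already been reached.
    """
    lr = base_lr
    for frac in milestones:
        if epoch >= int(total_epochs * frac):
            lr *= factor
    return lr


# With the defaults shown above (lr=0.1, epochs=160) the rate drops
# at epoch 80 and again at epoch 120.
schedule = [step_lr(0.1, epoch, 160) for epoch in range(160)]
```

In PyTorch the same shape can be obtained with `torch.optim.lr_scheduler.MultiStepLR` using `milestones=[80, 120]` and `gamma=0.1`; whether the repository uses that scheduler is not stated here.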

### Pruning and Finetuning
- Prunes and finetunes a given model.
- The pruning strategy is determined by the method used: local pruning for `ucs`, global pruning for `l1` and `rni`.

```
model, logs = prune_finetune(
    model,
    dataset="cifar10",  
    method="rni",       
    prune_pc=0.5,       # fraction of filters to prune (0.5 = 50%)
    epochs=160,
    lr=0.1,             
    batch_size=64,
    seed=0,
    device="cuda:0",
)
```
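The difference between local and global pruning can be illustrated with a small framework-free sketch. It assumes each layer exposes a list of per-channel importance scores; the function names, the score values and the ranking rule are hypothetical, and the repository's own pruning logic (defined per architecture in the model modules) may differ. In particular, how `ucs` ranks channels within a layer is not stated here.

```python
def local_prune(importances, prune_pc):
    """Per layer, drop the lowest-scoring `prune_pc` fraction of channels.

    `importances` is a list of per-layer score lists. Returns keep-masks
    (True = keep) with the same shape.
    """
    masks = []
    for layer in importances:
        n_prune = int(len(layer) * prune_pc)
        order = sorted(range(len(layer)), key=lambda i: layer[i])
        pruned = set(order[:n_prune])
        masks.append([i not in pruned for i in range(len(layer))])
    return masks


def global_prune(importances, prune_pc):
    """Rank all channels together and drop the lowest `prune_pc` fraction overall."""
    flat = [
        (score, layer_idx, ch_idx)
        for layer_idx, layer in enumerate(importances)
        for ch_idx, score in enumerate(layer)
    ]
    n_prune = int(len(flat) * prune_pc)
    pruned = {(l, c) for _, l, c in sorted(flat)[:n_prune]}
    return [
        [(l, c) not in pruned for c in range(len(layer))]
        for l, layer in enumerate(importances)
    ]


# Two toy layers with four channels each (made-up scores).
scores = [[0.9, 0.8, 0.7, 0.6], [0.05, 0.02, 0.95, 0.01]]
local_masks = local_prune(scores, 0.5)    # two channels removed from every layer
global_masks = global_prune(scores, 0.5)  # one removed from layer 0, three from layer 1
```

Global ranking lets the pruning ratio vary per layer, which is why it pairs naturally with methods that learn comparable importances across layers. A real implementation would also need a guard against removing every channel of a layer, which this sketch omits.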


## Implementation Details
- Written in PyTorch.
- Deliberately simple: a single underlying trainer function is used for all methods, architectures and datasets.
- Pruning logic is defined in the model modules, as it depends on the architecture.
- The Sigmoid BatchNorm implementation is based on PyTorch's BatchNorm code.
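
To give a rough idea of what a BatchNorm with bounded scaling parameters looks like, here is a framework-free sketch for a single channel. It assumes the scale is obtained by passing a raw learnable parameter through a sigmoid, so that the effective importance lies in (0, 1). The function names are hypothetical, only training-mode batch statistics are used, and neither the `b` shift hyper-parameter nor the regularisation term is modelled, since their exact form is not given here.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_batchnorm_channel(values, raw_scale, bias=0.0, eps=1e-5):
    """Normalise one channel over the batch, then scale by sigmoid(raw_scale).

    Assumption: the bounded scale is sigmoid(raw_scale); the repository's
    layer may parameterise it differently.
    """
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    scale = sigmoid(raw_scale)  # bounded importance in (0, 1)
    return [scale * (v - mean) / math.sqrt(var + eps) + bias for v in values]


batch = [1.0, 2.0, 3.0, 4.0]
# A channel with a large raw scale passes its normalised activations through.
active = sigmoid_batchnorm_channel(batch, raw_scale=4.0)
# A channel whose raw scale has receded outputs (almost) only its bias.
receded = sigmoid_batchnorm_channel(batch, raw_scale=-8.0)
```

Because the scale saturates near 0 and 1, channels tend to end up either clearly on or clearly off, which matches the polarised importances described in the abstract.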