# NxMTransformer

This repository contains the source for NxMTransformer, an extensible library for inducing semi-structured (NxM) sparsity in models. It is the companion code to the NeurIPS submission.

## Installation

For ease of development, NxMTransformer is provided as an installable package. To avoid reinstalling dependencies, however, the current wheel does not list its full set of dependencies, so they must be installed separately. The only prerequisites are a Python (>=3.6) environment and access to pip.

To install the components for NxMTransformer as well as the examples:
```
pip install -r requirements.txt
./install.sh
```

## Reproducing Paper Experiments

Scripts to run the experiments reported in the submitted paper are provided in `experiments/transformers/scripts`. Before use, each script must be modified to add a directory prefix to the save location of each run so that model checkpoints are saved appropriately. A short description of each script follows:

- `run_frequency_experiment`: Run an experimental sweep for a single task over batch size, rho, learning rate, and ADMM frequency (how often ADMM is performed). The ADMM frequency parameter controls how many ADMM iterations are performed in each training epoch. For example, with an ADMM frequency of 4, each epoch is divided into 4 sections; in each section the model is trained, the second ADMM subproblem is solved, and the auxiliary variable U is updated.
- `run_masked_finetune.sh`: Run an experimental sweep over batch size and learning rate for the masked finetuning workload (ASP). The model is pruned to satisfy the NxM constraints and then trained with that fixed mask. This was used to help build the baseline for Table 1.
- `run_per_epoch_experiment.sh`: Run an ADMM experiment in which, after each epoch, the model is fully pruned and evaluated on the validation set. Note that this evaluation affects the state of the RNG, so runs with identical seeds in per_epoch mode and in normal mode may produce different results. This corresponds to Figure 2.
- `run_per_epoch_masked_finetune.sh`: Run a masked finetune (ASP) for the specified sweep of hyperparameters. This corresponds to Figure 2.
- `run_raw_trainin.sh`: Dense model training pipeline. This was used to help build the baseline for Table 2.
- `run_single_model_anlysis.sh`: Used to measure similarity for models trained with the debug configuration. This was used to obtain data for Figure 3.
- `run_value_decay_analysis.sh`: Used to measure value decay, quantified by how frequently a value appears in the sparse mask. This was used to obtain data for Figure 4.
- `run_vanilla_experiment.sh`: Perform a hyperparameter sweep for either unstructured or NxM sparsity using ADMM. This script was used to obtain results for Table 1.
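To make the ADMM frequency and the NxM constraint more concrete, here is a minimal pure-Python sketch. It assumes the standard ADMM pruning formulation: W is trained on the task loss plus a (rho/2)·||W - Z + U||² penalty, the second subproblem sets Z to the projection of W + U onto the NxM set, and U is then updated as U += W - Z (in this formulation U acts as the scaled dual variable). The function names `nxm_project`, `admm_sections`, `admm_update`, and `admm_penalty` are hypothetical and not part of the NxMTransformer API, and grouping m consecutive entries of a flat list is an assumption about how weights are grouped; the library's actual implementation may differ.

```python
from typing import List, Tuple


def nxm_project(weights: List[float], n: int, m: int) -> List[float]:
    """Project onto the NxM set: in every consecutive group of m values,
    keep the n largest-magnitude entries and zero the rest."""
    if len(weights) % m != 0:
        raise ValueError("len(weights) must be a multiple of m")
    projected = [0.0] * len(weights)
    for start in range(0, len(weights), m):
        group = range(start, start + m)
        # Stable sort by descending magnitude: ties keep the earlier index.
        keep = sorted(group, key=lambda i: -abs(weights[i]))[:n]
        for i in keep:
            projected[i] = weights[i]
    return projected


def admm_sections(steps_per_epoch: int, admm_frequency: int) -> List[int]:
    """0-based steps (within one epoch) at which each of the admm_frequency
    training sections ends and an ADMM update is performed."""
    if not 1 <= admm_frequency <= steps_per_epoch:
        raise ValueError("need 1 <= admm_frequency <= steps_per_epoch")
    return [(k + 1) * steps_per_epoch // admm_frequency - 1
            for k in range(admm_frequency)]


def admm_update(w: List[float], u: List[float], n: int,
                m: int) -> Tuple[List[float], List[float]]:
    """Second subproblem (Z = proj(W + U)) followed by the U update."""
    z = nxm_project([wi + ui for wi, ui in zip(w, u)], n, m)
    u_next = [ui + wi - zi for ui, wi, zi in zip(u, w, z)]
    return z, u_next


def admm_penalty(w: List[float], z: List[float], u: List[float],
                 rho: float) -> float:
    """(rho / 2) * ||W - Z + U||^2, added to the task loss during training."""
    return 0.5 * rho * sum((wi - zi + ui) ** 2 for wi, zi, ui in zip(w, z, u))


if __name__ == "__main__":
    w = [0.9, -0.1, 0.4, -0.7, 0.05, 0.3, -0.2, 0.6]
    print(nxm_project(w, n=2, m=4))  # [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, 0.0, 0.6]
    print(admm_sections(steps_per_epoch=100, admm_frequency=4))  # [24, 49, 74, 99]
```

With an ADMM frequency of 4 and 100 steps per epoch, `admm_update` would run after steps 24, 49, 74, and 99. Applying `nxm_project` once to the dense weights and keeping the resulting nonzero pattern fixed gives roughly the kind of magnitude-based mask used by the masked finetuning (ASP) baseline.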

## Configuration Files

Three configuration files are provided in `experiments/transformers/configs`. The two BERT configuration files were used with the `bert-base-uncased` model, whereas the DistilBERT configuration file is compatible with `distilbert-base-uncased`.

The debug variant of the BERT configuration saves detailed information about the ADMM state after each epoch for 4 evenly spaced Transformer layers of the model. Models trained with this configuration can be passed to the two analysis scripts described above (`run_single_model_anlysis.sh` and `run_value_decay_analysis.sh`). Note that the data saved in this mode can require a significant amount of disk space.
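As an illustration of the value-decay measurement, the sketch below computes, for each weight position, the fraction of saved per-epoch snapshots in which that position is kept by the sparse mask. It assumes each snapshot can be reduced to a flat list of values for one layer; the actual on-disk format of the debug data is not described here, and `mask_from_values` and `mask_frequency` are hypothetical helpers rather than the analysis script's own code.

```python
from typing import List, Sequence


def mask_from_values(values: Sequence[float]) -> List[int]:
    """Binary mask of a sparse snapshot: 1 where the value is kept (nonzero)."""
    return [1 if v != 0.0 else 0 for v in values]


def mask_frequency(masks: Sequence[Sequence[int]]) -> List[float]:
    """Fraction of snapshots in which each position appears in the mask."""
    if not masks:
        raise ValueError("need at least one mask")
    width = len(masks[0])
    if any(len(mask) != width for mask in masks):
        raise ValueError("all masks must have the same length")
    return [sum(mask[i] for mask in masks) / len(masks) for i in range(width)]


if __name__ == "__main__":
    # Four hypothetical per-epoch 2:4 masks for a single group of weights.
    masks = [[1, 0, 0, 1], [1, 0, 1, 0], [1, 0, 0, 1], [1, 1, 0, 0]]
    print(mask_frequency(masks))  # [1.0, 0.25, 0.25, 0.5]
```

A position with a frequency near 1.0 stays in the mask across epochs, while lower frequencies indicate values that drift in and out of the sparse pattern.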

The provided configuration files are only compatible with single-GPU training. Distributing a model across multiple GPUs changes its parameter names, because the model is wrapped in an encapsulating module. To make use of multiple GPUs in one system, we recommend setting the `CUDA_VISIBLE_DEVICES` variable in each training script to a different GPU, so that each experiment sweep runs on its own device. Note that this variable must be set in all cases; otherwise, Transformers will attempt to parallelize the model across all visible devices.
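The sketch below illustrates both points. `torch.nn.DataParallel` stores the wrapped model in its `module` attribute, so every parameter name gains a `module.` prefix and no longer matches a name written for the unwrapped model (assuming, for illustration, that the configuration refers to parameters by their `named_parameters()` names). It also shows an equivalent way to pin each sweep to one GPU from a launcher process, assuming the scripts do not overwrite `CUDA_VISIBLE_DEVICES` themselves. The helpers `wrapped_name`, `single_gpu_env`, and `launch_sweeps` are hypothetical and not part of this repository.

```python
import os
import subprocess
from typing import Dict, List, Optional


def wrapped_name(name: str) -> str:
    """Parameter name after wrapping a model in torch.nn.DataParallel,
    which keeps the original model as its `module` attribute."""
    return "module." + name


def single_gpu_env(gpu_id: int,
                   base: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Copy of the environment in which only GPU `gpu_id` is visible."""
    env = dict(os.environ if base is None else base)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    return env


def launch_sweeps(scripts: List[str], dry_run: bool = True) -> List[List[str]]:
    """Start script i on GPU i, all in parallel; with dry_run=True the
    commands are only built and returned, nothing is executed."""
    commands = [["bash", script] for script in scripts]
    if not dry_run:
        procs = [subprocess.Popen(cmd, env=single_gpu_env(i))
                 for i, cmd in enumerate(commands)]
        for proc in procs:
            proc.wait()
    return commands


if __name__ == "__main__":
    name = "bert.encoder.layer.0.attention.self.query.weight"
    print(wrapped_name(name))  # module.bert.encoder.layer.0.attention.self.query.weight
    prefix = "experiments/transformers/scripts/"
    print(launch_sweeps([prefix + "run_vanilla_experiment.sh",
                         prefix + "run_masked_finetune.sh"]))
```

Passing the variable through the child process's environment means it is set before any CUDA initialization happens in that process, which is what restricts Transformers to a single device.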