### Official repository of oViT
---

The code is based on the [SparseML GitHub repository](https://github.com/neuralmagic/sparseml).

### Structure of the repository
---

The modified `SparseML` source code is located in `/src` and its subdirectories.
Pruning algorithms are implemented in the `src/sparseml/pytorch/sparsification/pruning` directory:

- GlobalMagnitudePruningModifier: `modifier_pruning_magnitude.py`
- OBSPruningModifier: `modifier_pruning_obs.py`
- OBSXPruningModifier: `modifier_pruning_obsx.py`

*Note*: 
`OBSX` is an alias for the pruning method defined in the `oViT` paper, i.e., to run oViT,
use `OBSXPruningModifier` in the pruning recipe.

The code to launch experiments is located inside the `/research` directory.

- ```research/``` — root directory for experiments \
    ```├── sparse_training.py``` — main script for gradual pruning (based on [train.py](https://github.com/rwightman/pytorch-image-models/blob/master/train.py) from timm) \
    ```├── one_shot_pruning.py``` — script for running one-shot pruning experiments \
    ```├── run_gradual_pruning.sh``` — script to launch `sparse_training.py` \
    ```├── run_one_shot_pruning.sh``` — script to launch `one_shot_pruning.py` \
    ```├── utils/``` — additional utils used in training scripts \
    ```├── configs/``` — `.yaml` recipes with training hyperparameters \
    ```├── recipes/``` — SparseML recipes 


### Usage
---

**Installation**

The recommended way to run `oViT` is via a conda environment.

**Configure environment**

One needs to install torch with GPU support and the timm library to run the code:

```bash
conda create --name ovit python==3.9
conda activate ovit
conda install scipy numpy scikit-learn pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 
# <unpack this archive>
cd oViT
pip install -r requirements.txt
```

To install `SparseML`, run the following in the root directory of the project:
```
python setup.py install
```

(*Optional*) We use [W&B](https://wandb.ai) for logging. Install it via `pip` in case you want to log data there:

```
pip install wandb
```

If logging to W&B, define the W&B environment variables prior to launching the script:
```bash
export WANDB_ENTITY=<your_entity>
export WANDB_PROJECT=<project_name>
export WANDB_NAME=<run_name>
```

**Potential issues**

One may face problems with protobuf.
To resolve the issue, run the following before executing any SparseML code:
```bash
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
```
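
When the code is started from a Python entry point instead of a shell, the same variable can be set programmatically. This is a minimal sketch, assuming it runs before anything imports protobuf (directly or through SparseML), since protobuf selects its backend at import time:

```python
import os

# Must run before the first import of protobuf (directly or via SparseML),
# because protobuf picks its backend when it is first imported.
# setdefault keeps a value that was already exported in the shell.
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")

# import sparseml  # safe to import once the variable is set
```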

**Workflow**

- Select config with training hyperparameters
- Select SparseML recipe (one from `/recipes` or a custom one)
- Define other hyperparameters in the launch script
- Enjoy!

*Note*: 
in most experiments we split the `QKV` layer in ViT into separate `Q`, `K`, `V` layers.

Pass `--split-qkv` to the training script to perform this change in the transformer blocks. This is important, since the training recipes assume a split `QKV` layer.
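
To illustrate what such a split amounts to, here is a minimal sketch and not the repository's implementation. It assumes the timm layout, where the fused projection is an `nn.Linear(dim, 3 * dim)` whose output features are ordered as `Q`, then `K`, then `V`. The helper name `split_qkv` is hypothetical.

```python
import torch
from torch import nn


def split_qkv(qkv: nn.Linear) -> tuple[nn.Linear, nn.Linear, nn.Linear]:
    """Split a fused QKV projection into separate Q, K and V layers (sketch)."""
    dim = qkv.in_features
    assert qkv.out_features == 3 * dim, "expected a fused (3 * dim, dim) projection"
    has_bias = qkv.bias is not None

    layers = []
    # Assumed layout: rows [0, dim) give Q, [dim, 2*dim) give K, [2*dim, 3*dim) give V.
    for i in range(3):
        layer = nn.Linear(dim, dim, bias=has_bias)
        with torch.no_grad():
            layer.weight.copy_(qkv.weight[i * dim:(i + 1) * dim])
            if has_bias:
                layer.bias.copy_(qkv.bias[i * dim:(i + 1) * dim])
        layers.append(layer)
    return tuple(layers)


# The three split layers together reproduce the fused output.
fused = nn.Linear(8, 24)
q, k, v = split_qkv(fused)
x = torch.randn(2, 5, 8)
assert torch.allclose(torch.cat([q(x), k(x), v(x)], dim=-1), fused(x), atol=1e-6)
```

With separate layers, each of `Q`, `K`, `V` can be addressed individually by name in a pruning recipe, which is presumably why the provided recipes expect the split form.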

**Example usage**

**One-shot pruning**

```bash
python one_shot_pruning.py \
    \
    --data-dir <data_dir> \
    \
    --sparseml-recipe <path_to_recipe> \
    \
    --model <model_name> \
    \
    --experiment <experiment_name> \
    \
    -b <obs_loader_batch_size> \
    -vb <validation_batch_size> \
    \
    --sparsities <list_of_sparsities> \
    \
    --split-qkv
```

**One-shot+finetune/gradual pruning**

```bash
python -m torch.distributed.launch \
    --nproc_per_node=<num_proc> \
    --master_port=<master_port> \
    sparse_training.py \
    \
    --data-dir <data_dir> \
    \
    --sparseml-recipe <path_to_recipe> \
    \
    --model <model_name> \
    \
    --experiment <experiment_name> \
    \
    -b <obs_loader_batch_size> \
    -vb <validation_batch_size> \
    \
    --sparsities <list_of_sparsities> \
    \
    --split-qkv
```

**Tweaking oViT hyperparameters**

There are several hyperparameters in the oViT method that can be adjusted for better performance.

```
    :param mask_type: String to define type of sparsity to apply. 'unstructured',
        'block4', 'N:M' are supported. Default is 'unstructured'. For N:M provide
        two integers that will be parsed.
    :param num_grads: number of gradients used to calculate the Fisher approximation
    :param damp: dampening factor, default is 1e-7
    :param fisher_block_size: size of blocks along the main diagonal of the Fisher
        approximation, default is 50
    :param grad_sampler_kwargs: kwargs to override default train dataloader config
        for pruner's gradient sampling.
    :param num_recomputations: number of EmpiricalFisher matrix recomputations
    :param blocks_in_parallel: number of rows traversed simultaneously by OBSX pruning modifier
    :param fisher_inv_device: select specific device to store Fisher inverses.
    :param traces_backup_dir: str. If one would like to store pruning traces on disk, one can 
        specify temporary dir for storage. 
```
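
When tuning `fisher_block_size` and deciding on `fisher_inv_device`, a rough memory estimate can help. The sketch below is a back-of-the-envelope calculation, not code from the repository. It assumes the Fisher inverse of a layer is stored as dense square blocks along the diagonal, that the last block is padded to full size, and that entries are 32-bit floats. The function name `fisher_inverse_bytes` is hypothetical.

```python
import math


def fisher_inverse_bytes(num_weights: int, fisher_block_size: int = 50,
                         bytes_per_element: int = 4) -> int:
    """Rough storage for a block-diagonal Fisher inverse of one layer (sketch)."""
    # The layer's weights are covered by ceil(d / B) diagonal blocks
    # (assumption: the last block is padded to the full size B) ...
    num_blocks = math.ceil(num_weights / fisher_block_size)
    # ... and each block is a dense B x B matrix.
    return num_blocks * fisher_block_size ** 2 * bytes_per_element


# A 768 x 768 projection (ViT-Base width) with the default block size.
d = 768 * 768
print(fisher_inverse_bytes(d) / 1e6)                         # about 118 MB
print(fisher_inverse_bytes(d, fisher_block_size=100) / 1e6)  # about 236 MB
```

Under these assumptions, storage grows linearly with `fisher_block_size`, so doubling the block size roughly doubles the memory needed per layer.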