# snn_quant

Post-training quantization and pruning with second-order Hessian information for spiking neural networks

Required packages: 
```bash
pip install spikingjelly torchvision numpy
```

Create folders:
```bash
mkdir -p datas/NMNIST/download datas/CIFAR10DVS/download datas/DVS128Gesture/download
```
Follow the dataset instructions to manually place the N-MNIST dataset in the datas/NMNIST/download folder and the DVS128 Gesture dataset in the datas/DVS128Gesture/download folder.

TRAINING

Training commands: 

N-MNIST:
```bash
python3 playgroundNMNIST.py -T 100 -device cuda -b 64 -epochs 200 -data-dir ./datas -opt adam -lr 1e-3 -tau 2.0
```
DVS128-Gesture:
```bash
python3 playgroundDVS128.py -T 20 -device cuda -b 16 -epochs 512 -opt adam -lr 1e-3 -amp -tau 2.0
```

CIFAR10-DVS:
```bash
python3 playgroundCIFAR10DVS.py -T 20 -device cuda -b 16 -epochs 512 -opt adam -lr 1e-3 -amp -tau 2.0
```

ASL-DVS:
```bash
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes 1 --nproc_per_node 2 playgroundASLDVS.py --local-rank=0 --T 30 --tau 2.0 --data-path ../datasets/ASL_DVS --model asl_4conv_snn --mixup-alpha 0.0 --cutmix-alpha 0.0 --b 16 --lr 0.1 --lr-scheduler cosa --epochs 90 --wd 1e-4
```

CIFAR100:
```bash
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes 1 --nproc_per_node 2 playgroundCIFAR100.py --model spiking_vgg16 --T 5 --data-path ./datas --batch-size 128 --lr 0.1 --lr-scheduler multistep --lr-milestone 150 225 --wd 1e-4 --epochs 300 --device cuda --pretrained --workers 8
```

ImageNet:

Download checkpoints from https://github.com/fangwei123456/Spike-Element-Wise-ResNet

For the MPS backend, replace ```-device cuda``` with ```-device mps```, then uncomment line 5 and comment out line 6 of modelutils.py.
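
A minimal sketch of choosing the value for the `-device` flag automatically. The helpers `pick_device` and `detect_device` are hypothetical and not part of this repository, and the `cpu` fallback is an assumption: the training scripts may not support it.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Prefer CUDA, then MPS, then CPU (the CPU fallback is an assumption)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    """Query PyTorch for the available backends (needs PyTorch 1.12 or newer)."""
    import torch  # imported lazily so pick_device works without PyTorch

    return pick_device(torch.cuda.is_available(),
                       torch.backends.mps.is_available())


if __name__ == "__main__":
    try:
        print("-device", detect_device())
    except ImportError:
        print("PyTorch is not installed; cannot detect a device")
```

This only selects the flag value; the edit to modelutils.py described above is still needed for MPS.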


PRUNING

Pruning is configured in the compression_XXX.py files. At the top of each file, MODEL_SAVE_PATH sets where the output files are saved, and TARGET_PRUNE sets the final target sparsity. OSBS=True selects OSBC pruning mode; OSBS=False selects OBC pruning mode.

In main, model_path and data_path specify where the model and the datasets are stored. In osbc_prune.py, MBP controls whether the algorithm uses magnitude-based pruning; MBP=True overrides OSBS.
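
The flag precedence described above can be sketched as follows. The functions `select_pruning_mode` and `magnitude_prune` are hypothetical illustrations, not code from this repository; only the flag names OSBS, MBP, and TARGET_PRUNE come from the text. The magnitude-based example assumes a flat list of weights and a single sparsity target, which may differ from how osbc_prune.py applies it per layer.

```python
def select_pruning_mode(osbs: bool, mbp: bool) -> str:
    """Return the pruning mode implied by the two flags (illustrative only)."""
    if mbp:
        # MBP takes priority: magnitude-based pruning regardless of OSBS.
        return "magnitude"
    return "OSBC" if osbs else "OBC"


def magnitude_prune(weights, target_sparsity):
    """Zero the smallest-magnitude weights until the target sparsity is met."""
    k = int(len(weights) * target_sparsity)  # number of weights to remove
    smallest = sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:k]
    pruned = list(weights)
    for i in smallest:
        pruned[i] = 0.0
    return pruned


if __name__ == "__main__":
    for osbs in (True, False):
        for mbp in (True, False):
            print(f"OSBS={osbs}, MBP={mbp} -> {select_pruning_mode(osbs, mbp)}")
    print(magnitude_prune([0.5, -0.1, 0.3, -0.05], target_sparsity=0.5))
```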

Pruning Commands:

N-MNIST:
```bash
python3 compression_nmnist.py
```
DVS128-Gesture:
```bash
python3 compression_dvs128.py
```

CIFAR10-DVS:
```bash
python3 compression_cifar10dvs.py
```

ASL-DVS:
```bash
python3 compression_asldvs.py
```

CIFAR100:
```bash
python3 compression_cifar100.py
```

ImageNet:
```bash
python3 compression_imagenet.py
```

FINETUNE

Fine-tuning was tested on the VGG16SNN trained on CIFAR100 and on the SEW-ResNets trained on ImageNet. First check how many GPUs you have, then adjust CUDA_VISIBLE_DEVICES and --nproc_per_node in the commands below to match. The examples use 2 CUDA devices. Note that the fine-tuning hours reported in the paper are GPU hours; using multiple GPUs only makes the process slightly quicker.

CIFAR100:
```bash
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes 1 --nproc_per_node 2 finetune_cifar100.py
```

ImageNet:
```bash
CUDA_VISIBLE_DEVICES=0,1 torchrun --standalone --nnodes 1 --nproc_per_node 2 finetune_imagenet.py
```
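
To adapt the commands above to a different number of GPUs, CUDA_VISIBLE_DEVICES and --nproc_per_node have to change together. A minimal sketch of that, assuming one process per visible GPU; the helper `build_torchrun_command` is hypothetical and not part of this repository:

```python
import shlex


def build_torchrun_command(script, gpu_ids):
    """Build a single-node torchrun command with one process per listed GPU."""
    if not gpu_ids:
        raise ValueError("at least one GPU id is required")
    visible = ",".join(str(g) for g in gpu_ids)
    args = [
        "torchrun", "--standalone",
        "--nnodes", "1",
        "--nproc_per_node", str(len(gpu_ids)),
        script,
    ]
    return f"CUDA_VISIBLE_DEVICES={visible} " + shlex.join(args)


if __name__ == "__main__":
    # Example: four GPUs instead of the two used in the commands above.
    print(build_torchrun_command("finetune_cifar100.py", [0, 1, 2, 3]))
```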

QUANTIZATION

Quantization is configured much like pruning, with TARGET_BIT_WIDTH controlling the quantization bit width. In osbc_quant.py, RTN=True overrides the OSBC setting, and the algorithm rounds to the nearest point on the quantization grid.
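
As a rough illustration of round-to-nearest, the sketch below assumes an asymmetric min-max uniform grid with 2^TARGET_BIT_WIDTH levels over a flat list of weights. The function `rtn_quantize` and the grid choice are assumptions for illustration; the grid actually used in osbc_quant.py may differ.

```python
def rtn_quantize(weights, bit_width):
    """Round each weight to the nearest point of a uniform min-max grid."""
    levels = 2 ** bit_width - 1          # number of steps between min and max
    w_min, w_max = min(weights), max(weights)
    if w_max == w_min:
        return list(weights)             # constant input: nothing to quantize
    scale = (w_max - w_min) / levels     # spacing between grid points
    return [w_min + round((w - w_min) / scale) * scale for w in weights]


if __name__ == "__main__":
    weights = [0.0, 0.2, 0.6, 1.0]
    for bits in (1, 2, 4):
        print(bits, rtn_quantize(weights, bits))
```

Unlike the Hessian-based modes, this baseline quantizes every weight independently and does not adjust the remaining weights to compensate for the rounding error.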

Quantization Commands:

N-MNIST:
```bash
python3 quantization_nmnist.py
```
DVS128-Gesture:
```bash
python3 quantization_dvs128.py
```
CIFAR10-DVS:
```bash
python3 quantization_cifar10dvs.py
```
