# Specialization of Sub-Paths for Adaptive Depth Networks

This is the official implementation of “Specialization of Sub-Paths for Adaptive Depth Networks”, currently under submission to ICLR 2023.


## Requirements

We conducted our experiments with:

- Python 3.8
- PyTorch 1.12, torchvision 0.13, CUDA 11
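Before reproducing the results, it can help to confirm that your environment matches the versions listed above. The following is a minimal sketch that reads installed package versions through the standard-library `importlib.metadata` without importing PyTorch. The helper names `installed_versions` and `matches_tested` are hypothetical and not part of this repository, and the CUDA version is not checked here because that would require importing `torch`.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions the experiments were run with (from the list above).
TESTED = {"python": "3.8", "torch": "1.12", "torchvision": "0.13"}


def installed_versions():
    """Return installed Python, torch, and torchvision versions (None if a package is missing)."""
    found = {"python": f"{sys.version_info.major}.{sys.version_info.minor}"}
    for pkg in ("torch", "torchvision"):
        try:
            found[pkg] = version(pkg)
        except PackageNotFoundError:
            found[pkg] = None
    return found


def matches_tested(found):
    """Report whether each component's major.minor version matches the tested release."""
    report = {}
    for name, tested in TESTED.items():
        got = found.get(name)
        # Compare only major.minor, e.g. "1.12.1+cu113" -> "1.12".
        report[name] = got is not None and ".".join(got.split(".")[:2]) == tested
    return report


if __name__ == "__main__":
    found = installed_versions()
    for name, ok in matches_tested(found).items():
        print(f"{name}: installed={found[name]} tested={TESTED[name]} match={ok}")
```

A mismatch does not necessarily mean the code will fail; it only flags where your setup differs from the one used for the reported numbers.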


## Training

To train ResNet50-Multi on ILSVRC2012, run this command:

```train
CUDA_VISIBLE_DEVICES=0,1 python train_ilsvrc.py --model=ResNet50_Multi --batch_size=256 --dataset_path=<path_to_imagenet_dataset> --alpha 0.5
```

To train MobileNetV2-Multi, run this command:

```train
CUDA_VISIBLE_DEVICES=0,1 python train_ilsvrc.py --model=MobileNetV2_Multi --batch_size=256 --dataset_path=<path_to_imagenet_dataset> --epoch 300 --lr-scheduler multisteplr --lr 0.1 --lr-multi-steps 150 225 285 --weight_decay 1e-5 --lr-gamma 0.1 --alpha 0.7
```
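The MobileNetV2-Multi command trains for 300 epochs with a step schedule: the learning rate starts at 0.1 (`--lr`) and is multiplied by 0.1 (`--lr-gamma`) at epochs 150, 225, and 285 (`--lr-multi-steps`). The sketch below reproduces that rule in plain Python, assuming that `--lr-scheduler multisteplr` maps to PyTorch's `torch.optim.lr_scheduler.MultiStepLR` stepped once per epoch with no warmup; the function name `multistep_lr` is hypothetical.

```python
from bisect import bisect_right


def multistep_lr(epoch, base_lr=0.1, milestones=(150, 225, 285), gamma=0.1):
    """Learning rate used during `epoch` (0-indexed) under multi-step decay.

    Follows the MultiStepLR rule: the base rate is multiplied by `gamma`
    once for every milestone that has already been reached.
    """
    decays = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** decays


if __name__ == "__main__":
    # Print the rate at the start of each phase of the 300-epoch schedule.
    for start in (0, 150, 225, 285):
        print(f"epochs {start}+: lr = {multistep_lr(start):.0e}")
```

Under these assumptions the run spends 150 epochs at 0.1, 75 at 0.01, 60 at 0.001, and the final 15 at 0.0001.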


## Evaluation

*ResNet50_Multi* and *MobileNetV2_Multi* are available for evaluation.
To evaluate a different model, change `--model` and `--pretrained` accordingly.

To evaluate the **super-net** of ResNet50-Multi, run:

```eval
python eval_ilsvrc.py --model=ResNet50_Multi --pretrained=./pretrained/ILSVRC-ResNet50-Multi.pth --dataset_path=<path_to_dataset> 
```

To evaluate the **base-net** of ResNet50-Multi, add the `--skip` option and run:

```eval
python eval_ilsvrc.py --model=ResNet50_Multi --pretrained=./pretrained/ILSVRC-ResNet50-Multi.pth --dataset_path=<path_to_dataset> --skip
```
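Because the super-net and the base-net are evaluated from the same checkpoint and differ only in the `--skip` flag, both commands can be generated from one helper. The sketch below is hypothetical (`eval_command` is not part of this repository) and uses only the flags shown above; `/path/to/imagenet` is a placeholder for your ILSVRC2012 directory.

```python
import shlex


def eval_command(model, pretrained, dataset_path, base_net=False):
    """Build the eval_ilsvrc.py argument list used in this README.

    The checkpoint is evaluated as the super-net by default and as the
    base-net when `--skip` is appended.
    """
    cmd = [
        "python", "eval_ilsvrc.py",
        f"--model={model}",
        f"--pretrained={pretrained}",
        f"--dataset_path={dataset_path}",
    ]
    if base_net:
        cmd.append("--skip")
    return cmd


if __name__ == "__main__":
    ckpt = "./pretrained/ILSVRC-ResNet50-Multi.pth"
    dataset = "/path/to/imagenet"  # placeholder: replace with your dataset path
    for base_net in (False, True):
        # Print a shell-ready command line for each sub-network.
        print(shlex.join(eval_command("ResNet50_Multi", ckpt, dataset, base_net)))
```

To run the evaluations directly instead of printing them, each list can be passed to `subprocess.run(cmd, check=True)` from the repository root.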


## Results and Pretrained Models

Our adaptive depth networks achieve the following performance on the ILSVRC2012 validation set.

### ILSVRC2012 Classification

| Model name                    | Acc@1 | Acc@5 | FLOPs | Pretrained model |
| ----------------------------- | ----- | ----- | ----- | ---------------- |
| ResNet50-Multi (super-net)    | 77.6% | 93.7% | 4.11G | [Download](https://drive.google.com/file/d/1QSjMYcpZvEN4Wtbxq5HGnSXEhalC8If8/view?usp=sharing) |
| ResNet50-Multi (base-net)     | 76.1% | 93.2% | 2.58G | same checkpoint, evaluated with `--skip` |
| MobileNetV2-Multi (super-net) | 72.7% | 90.8% | 0.32G | [Download](https://drive.google.com/file/d/1m2PZzX0lRAYIbRC0K9RzYMchuYRR3l-F/view?usp=sharing) |
| MobileNetV2-Multi (base-net)  | 70.7% | 89.8% | 0.25G | same checkpoint, evaluated with `--skip` |
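To put the table in perspective, the base-net's savings can be expressed relative to the super-net. The sketch below only does arithmetic on the Acc@1 and FLOPs values copied from the table; the `RESULTS` dictionary and the `tradeoff` helper are illustrative names, not part of the repository.

```python
# (Acc@1 in %, GFLOPs) copied from the ILSVRC2012 table above.
RESULTS = {
    "ResNet50-Multi": {"super": (77.6, 4.11), "base": (76.1, 2.58)},
    "MobileNetV2-Multi": {"super": (72.7, 0.32), "base": (70.7, 0.25)},
}


def tradeoff(model):
    """Return (FLOPs saved by the base-net, in whole %, Acc@1 drop in points)."""
    acc_super, flops_super = RESULTS[model]["super"]
    acc_base, flops_base = RESULTS[model]["base"]
    saved = 100.0 * (1.0 - flops_base / flops_super)
    return round(saved), round(acc_super - acc_base, 1)


if __name__ == "__main__":
    for name in RESULTS:
        saved, drop = tradeoff(name)
        print(f"{name}: base-net uses about {saved}% fewer FLOPs for a {drop}-point Acc@1 drop")
```

By this arithmetic, the ResNet50-Multi base-net needs roughly 37% fewer FLOPs at a cost of 1.5 points of Acc@1, and the MobileNetV2-Multi base-net roughly 22% fewer FLOPs at a cost of 2.0 points.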





