## Datasets (following [TAPS](https://github.com/MattWallingford/TAPS/tree/master))

**ImageNet-to-Sketch**
The 5 datasets that make up ImageNet-to-Sketch can be downloaded from the [PiggyBack repository](https://github.com/arunmallya/piggyback) at this link: [https://uofi.box.com/s/ixncr3d85guosajywhf7yridszzg5zsq](https://uofi.box.com/s/ixncr3d85guosajywhf7yridszzg5zsq)

**DomainNet**
The 6 DomainNet datasets can be downloaded from the [original website](http://ai.bu.edu/M3SDA/). A pre-formatted version can be downloaded [here](https://drive.google.com/file/d/1Eowq0kHzS0MKgo1oglqJIRAC_wDlqWP9/view?usp=sharing). The folder should be structured as follows:
```
├── DomainNet
    ├── sketch
        ├── train
        ├── test
    ├── infograph
        ├── train
        ├── test
    ...
    ├── clipart
        ├── train
        ├── test
```
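A quick way to confirm that a download matches the layout above is to check for every `<domain>/<split>` folder. The snippet below is a minimal sketch, assuming the folders are named after the six standard DomainNet domains (`clipart`, `infograph`, `painting`, `quickdraw`, `real`, `sketch`); `missing_domainnet_dirs` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path

# Assumed folder names: the six standard DomainNet domains.
DOMAINS = ("clipart", "infograph", "painting", "quickdraw", "real", "sketch")
SPLITS = ("train", "test")


def missing_domainnet_dirs(root):
    """Return the expected '<domain>/<split>' folders that do not exist under root."""
    root = Path(root)
    missing = []
    for domain in DOMAINS:
        for split in SPLITS:
            if not (root / domain / split).is_dir():
                # Use '/' explicitly so the result is the same on every OS.
                missing.append(f"{domain}/{split}")
    return missing
```

For example, `missing_domainnet_dirs("datasets/DomainNet")` returns an empty list when every `train`/`test` folder is present.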

Place the datasets in the `datasets` folder. If you place them elsewhere, use the `--dataset_name` flag to point to the dataset you want to fine-tune on.

## Train

Training environment:
```
torch==2.0.0+cu118
torchaudio==2.0.1+cu118
torchvision==0.15.1+cu118
timm==0.6.13
tqdm
Pillow==9.0.1
opencv-python==4.7.0.72
```
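To check that an installed environment matches the pins above, the standard library's `importlib.metadata` is enough. The helper below is a hypothetical sketch, not a script shipped with the repository. It compares version strings exactly, so a build whose version lacks the `+cu118` suffix is also reported as a mismatch; `tqdm` is left out because it is unpinned.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the training environment listed above.
PINNED = {
    "torch": "2.0.0+cu118",
    "torchaudio": "2.0.1+cu118",
    "torchvision": "0.15.1+cu118",
    "timm": "0.6.13",
    "Pillow": "9.0.1",
    "opencv-python": "4.7.0.72",
}


def version_mismatches(pins=PINNED, get_version=version):
    """Map each package whose installed version differs from its pin to (pinned, installed).

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches
```

Calling `print(version_mismatches())` prints `{}` when every pinned package matches.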


Training options:
```
optional arguments:
  -h, --help            show this help message and exit
  --gpuNums GPUNUMS     number of gpus
  --nEpochs NEPOCHS     number of epochs to train for
  --warmup WARMUP       the epochs for warmup
  --lr LR               Learning Rate. Default=0.1
  --mask_lr MASK_LR     Mask Learning Rate. Default=0.2
  --optim {ADAM,SGD,ADAMW}
                        optimizer. Default=ADAM
  --wd WD               weight decay. Default=0.0
  --momentum MOMENTUM   momentum. Default=0.9
  --threads THREADS     number of threads for data loader to use
  --backbone {vit_small_patch16_224,vit_base_patch16_224,resnet18,resnet34,resnet50,wide_resnet,timm_resnet18,timm_resnet26,timm_resnet34,timm_resnet50,densenet121,timm_densenet121}
                        backbone of the model
  --batchSize BATCHSIZE
                        training batch size
  --dataset_name DATASET_NAME
                        which dataset to train
  --resume_from RESUME_FROM
                        iteration to resume from
  --save_path SAVE_PATH
                        path to save the model
  --visual_file VISUAL_FILE
                        path to save the visual_data
  --logname LOGNAME     name of the logging file
  --chkname CHKNAME     name of the checkpoints folder
  --p P                 end p. Default=0.5
  --p_T P_T             the update T of p. Default=10 epochs
  --cropped CROPPED     crop the pic or not
  --num_iterations NUM_ITERATIONS
                        the iteration times of the all tasks
  --nprocs NPROCS       number of gpus
  --local_rank LOCAL_RANK
                        node rank for distributed training
  --seed SEED           seed for initializing training.
  --ip IP
  --port PORT
```

For example, to train ResNet50 on the ImageNet-to-Sketch benchmark, starting from the PyTorch pre-trained backbone saved in `chk/torch/resnet50-19c8e357.pth`:
```
python -u -m torch.distributed.launch --nproc_per_node=1 params_share_joint.py --dataset_name=datasets/ImageNet_sketch --lr=0.02 --nEpochs=50 --p_T=10 --optim=SGD --cropped=True --backbone=resnet50 --chkname=chk/torch/resnet50-19c8e357.pth
```
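When sweeping hyperparameters, it can be convenient to assemble the launch command programmatically. `build_train_command` below is a hypothetical helper (not part of the repository) that only uses the flags documented above; called with the values from the example, it reproduces that command.

```python
def build_train_command(dataset_name, backbone, chkname, nproc_per_node=1, **options):
    """Assemble the distributed launch command as an argv list.

    Extra keyword arguments become `--key=value` flags for params_share_joint.py,
    in the order they are passed.
    """
    cmd = [
        "python", "-u", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        "params_share_joint.py",
        f"--dataset_name={dataset_name}",
    ]
    cmd += [f"--{key}={value}" for key, value in options.items()]
    cmd += [f"--backbone={backbone}", f"--chkname={chkname}"]
    return cmd
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)`, for instance inside a loop over several `lr` values.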

To train in incremental learning mode, set `num_iterations=1` so that the algorithm visits each dataset only once. Otherwise, the algorithm obtains the final supernet iteratively by revisiting all tasks. The resulting supernet is saved to `save_path`.
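The difference between incremental mode and iterative training lies in how often each dataset is revisited. The sketch below shows an assumed shape of the outer loop; `train_supernet` and the `train_on_task` callback are hypothetical names, and the actual loop in `params_share_joint.py` (checkpointing, the `p` schedule, distributed setup) is not reproduced.

```python
def train_supernet(supernet, tasks, num_iterations, train_on_task):
    """Visit every task dataset `num_iterations` times, threading the shared supernet through.

    num_iterations=1 visits each dataset exactly once (incremental mode);
    larger values sweep over all tasks again to refine the same supernet.
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")
    for iteration in range(num_iterations):
        for task in tasks:
            # Placeholder for per-dataset fine-tuning; returns the updated supernet.
            supernet = train_on_task(supernet, task, iteration)
    return supernet
```

Passing a callback that only records `(iteration, task)` pairs makes the visiting order easy to inspect.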

## Visualization


See `visualization.sh` for an example that visualizes the first five layers of ResNet18 after 5 iterations. To use another pre-trained model, change `chkname` and `backbone`.
```
other optional arguments:
  --num_iterations NUM_ITERATIONS
                        the iteration times of the all tasks
  --visual_layers VISUAL_LAYERS
                        num of visualized layers
  --visual_start VISUAL_START
                        the id of the start visualized layers
  --resume_from RESUME_FROM
                        iteration to resume from
```
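One plausible reading of `--visual_start` and `--visual_layers` is that together they select a contiguous slice of the backbone's layers, so the ResNet18 example above would correspond to `visual_start=0` and `visual_layers=5`. This interpretation (including zero-based indexing) is an assumption, and `select_visual_layers` is a hypothetical helper.

```python
def select_visual_layers(layer_names, visual_start, visual_layers):
    """Return `visual_layers` layers starting at index `visual_start` (assumed zero-based)."""
    if visual_start < 0 or visual_layers < 0:
        raise ValueError("visual_start and visual_layers must be non-negative")
    # Slicing stops at the end of the list, so requesting too many layers is not an error.
    return layer_names[visual_start:visual_start + visual_layers]
```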

## TODO

- more efficient save method.