#! /bin/bash

# Launcher for train_kmixaugmentor.py.
#
# Usage: <script> {cifar10|cifar100|tiny_imagenet} train
#
# Replace the <PATH_TO_...> placeholders (keep the surrounding quotes) before
# running. The values below are shared hyper-parameters; the run_<dataset>
# functions further down set DATASET, MODEL and MMP. The comment after each
# variable names the trainer flag it is forwarded to.

CUTOUT=0            # only appears in the SAVE run name; not passed as a flag
GC=5                # not referenced by the commands in this script
DATASET=TBD         # sentinel; overwritten by run_<dataset> (--dataset)
SD=0                # not referenced by the commands in this script
ALP=1               # --mix_alpha
MMP=0               # --pretrained_model_path; overwritten by run_<dataset>
LM=0                # --layer_mix
NCH=32              # --mask_n_channel
MM=K1               # --masknet_model
TRAIN_PORTION=1     # --train_ratio (previously assigned twice)
LR=0.0005           # --learning_rate
WD=0.0005           # --weight_decay
EPOCH=100           # --epochs
BATCH=128           # --batch_size
K=2                 # --k
GPU=0               # --gpu

# Quoted so the file parses before the placeholder is filled in, and so a
# dataset path containing spaces stays a single word.
ROOT='<PATH_TO_DATASET>'   # --dataroot

# cifar10
# Select the CIFAR-10 preset.
# Globals written: DATASET, MODEL, MMP (pretrained model path; replace the
# placeholder inside the quotes before running).
function run_cifar10 {
    DATASET=cifar10
    MODEL=cifar_resnet18
    # Quoted so the file parses with the placeholder and paths with spaces
    # are not split into a command.
    MMP='<PATH_TO_PRETRAINED_MODEL>'
}

# Select the CIFAR-100 preset.
# Arguments: $1 - optional model name: resnet18, preactresnet18 (default)
#                 or wresnet28_10.
# Globals written: DATASET, MODEL, MMP (pretrained model path for that model;
#                  replace the placeholders inside the quotes before running).
# Returns: 0 on success; 1 (globals untouched) for an unsupported model.
function run_cifar100 {
    local model=${1:-preactresnet18}
    case "${model}" in
        resnet18)
            MMP='<PATH_TO_PRETRAINED_MODEL>'
            ;;
        preactresnet18)
            MMP='<PATH_TO_PRETRAINED_MODEL>'
            ;;
        wresnet28_10)
            MMP='<PATH_TO_PRETRAINED_MODEL>'
            ;;
        *)
            printf 'run_cifar100: unsupported model %s\n' "${model}" >&2
            return 1
            ;;
    esac
    DATASET=cifar100
    MODEL=${model}
}

# Select the Tiny-ImageNet preset.
# Arguments: $1 - optional model name: resnet18 or preactresnet18 (default).
# Globals written: DATASET, MODEL, MMP (pretrained model path for that model;
#                  replace the placeholders inside the quotes before running).
# Returns: 0 on success; 1 (globals untouched) for an unsupported model.
function run_tiny_imagenet {
    local model=${1:-preactresnet18}
    case "${model}" in
        resnet18)
            MMP='<PATH_TO_PRETRAINED_MODEL>'
            ;;
        preactresnet18)
            MMP='<PATH_TO_PRETRAINED_MODEL>'
            ;;
        *)
            printf 'run_tiny_imagenet: unsupported model %s\n' "${model}" >&2
            return 1
            ;;
    esac
    DATASET=tiny_imagenet
    MODEL=${model}
}

# Print the invocation synopsis to stderr and exit with status 2.
function usage {
    printf 'Usage: %s {cifar10|cifar100|tiny_imagenet} train\n' "${0##*/}" >&2
    exit 2
}

# $1 selects the dataset preset. A missing or unknown value previously fell
# through silently, leaving DATASET=TBD and MODEL unset, so the unquoted
# ${MODEL} vanished and --model_name was passed no value.
case "${1:-}" in
    cifar10) run_cifar10 ;;
    cifar100) run_cifar100 ;;
    tiny_imagenet) run_tiny_imagenet ;;
    *) usage ;;
esac

# $2 selects the action; "train" is the only one implemented here.
if [[ "${2:-}" != "train" ]]; then
    usage
fi

SAVE="mixngaugmentor_${DATASET}_${MODEL}_${BATCH}_${EPOCH}_cutout_${CUTOUT}_lr${LR}_wd${WD}_alpha${ALP}_layermix${LM}_${MM}@${NCH}_K${K}"

# Build the trainer arguments as an array so values containing spaces
# (e.g. ROOT or MMP paths) reach python as single arguments.
train_args=(
    --dataroot "${ROOT}"
    --k "${K}"
    --masknet_model "${MM}"
    --mask_n_channel "${NCH}"
    --layer_mix "${LM}"
    --mix_alpha "${ALP}"
    --report_freq 50
    --num_workers 4
    --epochs "${EPOCH}"
    --batch_size "${BATCH}"
    --learning_rate "${LR}"
    --dataset "${DATASET}"
    --model_name "${MODEL}"
    --save "${SAVE}"
    --gpu "${GPU}"
    --weight_decay "${WD}"
    --train_ratio "${TRAIN_PORTION}"
    --pretrained_model_path "${MMP}"
)
python ./train_kmixaugmentor.py "${train_args[@]}"
