#! /bin/bash
# Script-wide defaults. The dataset dispatch further down may overwrite
# DATASET and GPU before the training command is assembled.
DATASET="TBD"   # stays "TBD" until a dataset branch assigns a real name
TEMP="3"        # forwarded to train.py as --temperature
KOPS="3"        # forwarded to train.py as --k_ops
GPU="0"         # forwarded to train.py as --gpu

# Preset for the "reduced_svhn" dataset.
# Takes no arguments; assigns the global variables that the SAVE name and
# the train.py invocation at the bottom of this script read.
# NOTE(review): unlike the sibling presets, CUTOUT is not assigned here even
# though the SAVE name interpolates ${CUTOUT} — confirm whether intended.
run_reduced_svhn() {
    DATASET="reduced_svhn"
    MODEL="wresnet28_10"
    # Literal string; presumably a placeholder to be replaced with a real
    # policy path before running — TODO confirm.
    PP="POLICY_PATH"
    TRAIN_PORTION="1"
    BATCH="128"
    EPOCH="160"
    WD="0.01"
    LR="0.05"
}

# Preset for the "svhn" dataset.
# Takes no arguments; assigns the global variables that the SAVE name and
# the train.py invocation at the bottom of this script read.
run_svhn() {
    DATASET="svhn"
    MODEL="wresnet28_10"
    # Literal string; presumably a placeholder to be replaced with a real
    # policy path before running — TODO confirm.
    PP="POLICY_PATH"
    TRAIN_PORTION="1"
    CUTOUT="16"
    BATCH="256"
    EPOCH="160"
    WD="0.001"
    LR="0.005"
}

# Preset for the "reduced_cifar10" dataset.
# Takes no arguments; assigns the global variables that the SAVE name and
# the train.py invocation at the bottom of this script read.
run_reduced_cifar10() {
    DATASET="reduced_cifar10"
    MODEL="wresnet40_2"
    # Literal string; presumably a placeholder to be replaced with a real
    # policy path before running — TODO confirm.
    PP="POLICY_PATH"
    TRAIN_PORTION="1"
    CUTOUT="16"
    BATCH="128"
    EPOCH="200"
    WD="0.0005"
    LR="0.1"
}

# Preset for the "reduced_cifar100" dataset.
# Takes no arguments; assigns the global variables that the SAVE name and
# the train.py invocation at the bottom of this script read.
run_reduced_cifar100() {
    DATASET="reduced_cifar100"
    MODEL="wresnet40_2"
    # Literal string; presumably a placeholder to be replaced with a real
    # policy path before running — TODO confirm.
    PP="POLICY_PATH"
    TRAIN_PORTION="1"
    CUTOUT="16"
    BATCH="256"
    EPOCH="200"
    WD="0.0005"
    LR="0.1"
}

# Preset for the "cifar100" dataset.
# Takes no arguments; assigns the global variables that the SAVE name and
# the train.py invocation at the bottom of this script read.
run_cifar100() {
    DATASET="cifar100"
    MODEL="wresnet40_2"
    # Literal string; presumably a placeholder to be replaced with a real
    # policy path before running — TODO confirm.
    PP="POLICY_PATH"
    TRAIN_PORTION="1"
    CUTOUT="16"
    BATCH="256"
    EPOCH="200"
    WD="0.0005"
    LR="0.1"
}

# Preset for the "cifar10" dataset.
# Takes no arguments; assigns the global variables that the SAVE name and
# the train.py invocation at the bottom of this script read.
run_cifar10() {
    DATASET="cifar10"
    MODEL="wresnet40_2"
    # Literal string; presumably a placeholder to be replaced with a real
    # policy path before running — TODO confirm.
    PP="POLICY_PATH"
    TRAIN_PORTION="1"
    CUTOUT="16"
    BATCH="256"
    EPOCH="200"
    WD="0.0005"
    LR="0.1"
}

# Shared preset for the datasets that the dispatch below routes here
# (pet, flower, car, aircraft).
# Takes no arguments; assigns the global variables that the SAVE name and
# the train.py invocation at the bottom of this script read.
# DATASET is deliberately left untouched: the caller assigns it right after
# invoking this function.
function run_transfer {
    MODEL=resnet50
    EPOCH=180
    BATCH=64
    LR=0.1
    WD=0.0001
    CUTOUT=16
    TRAIN_PORTION=1
    # Was `PP=PP=POLICY_PATH`, which stored the literal "PP=POLICY_PATH" and
    # handed --policy_path a value with a stray "PP=" prefix. Now matches the
    # other presets.
    PP=POLICY_PATH
}

# Select the hyper-parameter preset for the dataset named in $1.
# The selector is quoted so that a missing or space-containing argument is
# handled as an ordinary non-match instead of breaking the test expression.
case "${1:-}" in
    reduced_cifar10)
        run_reduced_cifar10
        ;;
    cifar10)
        run_cifar10
        ;;
    reduced_cifar100)
        run_reduced_cifar100
        ;;
    cifar100)
        run_cifar100
        ;;
    reduced_svhn)
        run_reduced_svhn
        ;;
    svhn)
        run_svhn
        ;;
    pet)
        run_transfer
        DATASET="pet"
        # NOTE(review): only this dataset overrides the default GPU index —
        # confirm this is intentional and not a leftover local setting.
        GPU=3
        ;;
    flower | car | aircraft)
        # These three differ only in the dataset name, which equals $1.
        run_transfer
        DATASET="$1"
        ;;
    *)
        # Without a matching preset the hyper-parameter variables would be
        # unset and DATASET would remain "TBD"; stop before launching training.
        printf 'error: unknown or missing dataset "%s"\n' "${1:-}" >&2
        printf 'usage: %s <dataset> <aug_mode>\n' "$0" >&2
        printf 'datasets: reduced_cifar10 cifar10 reduced_cifar100 cifar100 reduced_svhn svhn pet flower car aircraft\n' >&2
        exit 1
        ;;
esac

# $2 is the augmentation mode: it prefixes the SAVE name and is forwarded as
# --aug_mode. Abort with a clear message when it is missing rather than
# launching training with an empty value.
: "${2:?missing <aug_mode> argument (usage: $0 <dataset> <aug_mode>)}"

# Output name encoding the run configuration.
SAVE="${2}_${DATASET}_${MODEL}_${BATCH}_${EPOCH}_cutout_${CUTOUT}_lr${LR}_wd${WD}_kops_${KOPS}_TEMP_${TEMP}"

# Build the argument list as an array so every value reaches train.py as
# exactly one argument, even if it contains spaces or glob characters.
train_args=(
    --k_ops "${KOPS}"
    --report_freq 10
    --num_workers 8
    --add_aug
    --aug_mode "${2}"
    --temperature "${TEMP}"
    --epochs "${EPOCH}"
    --batch_size "${BATCH}"
    --learning_rate "${LR}"
    --dataset "${DATASET}"
    --model_name "${MODEL}"
    --save "${SAVE}"
    --gpu "${GPU}"
    --weight_decay "${WD}"
    --policy_path "${PP}"
    # Honour the preset's TRAIN_PORTION instead of a hard-coded 1; fall back
    # to 1 (the previous fixed value) when no preset assigned it.
    --train_portion "${TRAIN_PORTION:-1}"
)
python ada_aug/train.py "${train_args[@]}"
