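# Usage: search.sh <dataset> <latent|input> <cnn|projection|vector>
#   dataset: reduced_cifar10 | cifar10 | reduced_cifar100 | cifar100 | reduced_svhn | svhn
#   e.g.  search.sh reduced_cifar10 latent projection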

# Defaults shared by every dataset preset. The run_* functions below may
# override CUTOUT; the others are forwarded to train_search.py as-is.
AF=10     # --arch_freq
CUTOUT=0  # --cutout_length unless a preset sets it
GPU=0     # --gpu
KOPS=1    # --k_ops (latent search only)
TEMP=1    # --temperature

function run_reduced_svhn {
    DATASET=reduced_svhn
    MODEL=wresnet40_2
    EPOCH=160
    BATCH=128
    LR=0.05
    WD=0.01
    ALR=0.001
}

# Preset for the full svhn dataset: sets the global hyperparameters read by
# the launch step (CUTOUT is not touched, so the script-level default applies).
run_svhn() {
    DATASET=svhn MODEL=wresnet40_2
    EPOCH=160 BATCH=128
    LR=0.005 WD=0.001 ALR=0.001
}

# Preset for reduced_cifar10: sets the global hyperparameters read by the
# launch step, including a cutout length of 16.
run_reduced_cifar10() {
    DATASET=reduced_cifar10 MODEL=wresnet40_2
    EPOCH=200 BATCH=128
    LR=0.1 WD=0.0005 ALR=0.001
    CUTOUT=16
}

# Preset for reduced_cifar100: sets the global hyperparameters read by the
# launch step, including a cutout length of 16.
run_reduced_cifar100() {
    DATASET=reduced_cifar100 MODEL=wresnet40_2
    EPOCH=200 BATCH=128
    LR=0.05 WD=0.005
    CUTOUT=16 ALR=0.001
}

# cifar100
# Preset for the full cifar100 dataset: sets the global hyperparameters read
# by the launch step. CUTOUT is not set here, so the script-level default
# applies.
function run_cifar100 {
    # Must name the full dataset (like run_cifar10 / run_svhn do); the value
    # "reduced_cifar100" made `search.sh cifar100 ...` search on the reduced
    # dataset instead.
    DATASET=cifar100
    MODEL=wresnet40_2
    EPOCH=200
    BATCH=32
    LR=0.025
    WD=0.0005
    ALR=0.001
}

# Preset for the full cifar10 dataset: sets the global hyperparameters read
# by the launch step, including a batch size of 512 and a cutout length of 16.
run_cifar10() {
    DATASET=cifar10 MODEL=wresnet40_2
    EPOCH=200 BATCH=512
    LR=0.1 WD=0.0005 ALR=0.001
    CUTOUT=16
}

# Apply the hyperparameter preset for the dataset named by $1.
# $1 is quoted so a missing argument does not break the test, and an unknown
# name aborts here instead of launching training with DATASET/MODEL/... unset.
case "${1:-}" in
    reduced_cifar10) run_reduced_cifar10 ;;
    cifar10) run_cifar10 ;;
    reduced_cifar100) run_reduced_cifar100 ;;
    cifar100) run_cifar100 ;;
    reduced_svhn) run_reduced_svhn ;;
    svhn) run_svhn ;;
    *)
        echo "error: unknown dataset '${1:-}'; expected one of: reduced_cifar10 cifar10 reduced_cifar100 cifar100 reduced_svhn svhn" >&2
        exit 2
        ;;
esac

# Augmentation mode from $3, forwarded as --aug_mode. Defaults to "vector"
# when omitted (as before, but without the `[` errors an unquoted empty $3
# produced). A misspelled mode now aborts instead of silently running a
# "vector" search.
AM=${3:-vector}
case "$AM" in
    cnn | projection | vector) ;;
    *)
        echo "error: unknown aug mode '${AM}'; expected one of: cnn projection vector" >&2
        exit 2
        ;;
esac

# Launch ada_aug/train_search.py. $2 selects the search space:
#   latent - adds --k_ops and --search_latent; SAVE name ends in _latent
#   input  - plain search; SAVE name ends in _input
# Anything else is an error (previously it silently did nothing and exited 0).
# NOTE(review): NPL is not set anywhere in this script; presumably it comes from
# the environment, otherwise the SAVE name contains "npl_" with an empty value.
# NOTE(review): --cutout is passed even when CUTOUT=0; presumably train_search.py
# treats a length of 0 as no cutout - confirm.

# Options shared by both modes, kept in an array so every value stays one word.
search_args=(
    --aug_mode "${AM}"
    --report_freq 10
    --num_workers 0
    --epochs "${EPOCH}"
    --batch_size "${BATCH}"
    --learning_rate "${LR}"
    --dataset "${DATASET}"
    --model_name "${MODEL}"
    --gpu "${GPU}"
    --weight_decay "${WD}"
    --arch_learning_rate "${ALR}"
    --arch_freq "${AF}"
    --cutout
    --cutout_length "${CUTOUT}"
    --temperature "${TEMP}"
)

case "${2:-}" in
    latent)
        echo "### Search in LATENT Space ####"
        SAVE=${DATASET}_${MODEL}_${BATCH}_${EPOCH}_alr${ALR}_af${AF}_cutout_${CUTOUT}_lr${LR}_wd${WD}_${AM}_kops_${KOPS}_npl_${NPL}_latent
        python ada_aug/train_search.py --k_ops "${KOPS}" --search_latent "${search_args[@]}" --save "${SAVE}"
        ;;
    input)
        echo "### Search in INPUT Space ####"
        SAVE=${DATASET}_${MODEL}_${BATCH}_${EPOCH}_alr${ALR}_af${AF}_cutout_${CUTOUT}_lr${LR}_wd${WD}_${AM}_input
        python ada_aug/train_search.py "${search_args[@]}" --save "${SAVE}"
        ;;
    *)
        echo "error: unknown search space '${2:-}'; expected 'latent' or 'input'" >&2
        exit 2
        ;;
esac
