# Main training shell scripts

TODAY=$(date "+%Y%m%d")
TYPES="pc"
EPOCHS=300

DATAPATH='./data/'
CIFAR10=${DATAPATH}'cifar10'
CIFAR100=${DATAPATH}'cifar100'

LOGPATH='./logs/'
RESPATH=${LOGPATH}'resnet/'
VGGPATH=${LOGPATH}'vgg/'
DENSEPATH=${LOGPATH}'densenet/'
MOBILEPATH=${LOGPATH}'mobilenetv2/'


PROGRAM_NAME=`/usr/bin/basename "$0"`
echo shell arg 0: $0
echo USING BASENAME: ${PROGRAM_NAME}
arg_data=default
arg_arch=default

function print_usage(){
/bin/cat << EOF
Usage:
    ${PROGRAM_NAME} [-d arg_data] [-a arg_arch] [-p arg_peer] [-e arg_exp] [-s seed] [-c arg_cuda]
Option:
    -d, dataset
    -a, model
    -p, num_peers
    -e, exposure
    -s, seed
    -c, cuda
EOF
}
if [ $# -eq 0 ];
then
    print_usage
    exit 1
fi

while getopts "d:a:p:e:s:c:h" opt
do
    case $opt in
        d) arg_data=$OPTARG; echo "ARG DATA: $arg_data";;
        a) arg_arch=$OPTARG; echo "ARG ARCH: $arg_arch";;
        p) arg_peer=$OPTARG; echo "ARG PEER: $arg_peer";;
        e) arg_exp=$OPTARG; echo "ARG EXP: $arg_exp";;
        s) arg_seed=$OPTARG; echo "ARG SEED: $arg_seed";;
        c) arg_cuda=$OPTARG; echo "ARG EXP: $arg_cuda";;
        h) print_usage;;
    esac
done

## CIFAR-10
if [ "$arg_data" = "cifar10" ];
then
    if [ "$arg_arch" = "My_ResNet" ]
    then
    python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --num_branches=$arg_peer \
                --depth=32 \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=128 \
                --milestones 150 225 \
                --T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR10} \
                --save=${RESPATH}"${TODAY}_r32_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
                --margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
    if [ "$arg_arch" = "My_VGG" ]
    then
    python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --num_branches=$arg_peer \
                --depth=16 \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=128 \
                --milestones 150 225 \
                --T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR10} \
                --save=${VGGPATH}"${TODAY}_vgg16_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
                --margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
    if [ "$arg_arch" = "My_DenseNet" ]
    then
        python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --num_branches=$arg_peer \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=128 \
                --milestones 150 225 \
		--T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR10} \
                --save=${DENSEPATH}"${TODAY}_d4012_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
                --margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
    if [ "$arg_arch" = "My_MobileNetV2" ]
    then
        python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --num_branches=$arg_peer \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=128 \
                --milestones 150 225 \
		--T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR10} \
                --save=${MOBILEPATH}"${TODAY}_mbnv2_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
                --margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
    if [ "$arg_arch" = "OurNet" ];
    then
        python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --dml_arch='VGG' \
                --depth=16 \
                --num_branches=$arg_peer \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=256 \
                --milestones 150 225 \
                --T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR10} \
                --save=${VGGPATH}"${TODAY}_NETvgg16_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
		--margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
fi

# CIFAR-100
if [ "$arg_data" = "cifar100" ];
then
    if [ "$arg_arch" = "My_ResNet" ]
    then
    python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --num_branches=$arg_peer \
                --depth=32 \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=128 \
                --milestones 150 225 \
                --T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR100} \
                --save=${RESPATH}"${TODAY}_r32_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
                --margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
    if [ "$arg_arch" = "My_VGG" ]
    then
    python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --num_branches=$arg_peer \
                --depth=16 \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=128 \
                --milestones 150 225 \
                --T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR100} \
                --save=${VGGPATH}"${TODAY}_vgg16_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
                --margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
    if [ "$arg_arch" = "My_DenseNet" ]
    then
        python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --num_branches=$arg_peer \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=128 \
                --milestones 150 225 \
		--T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR100} \
                --save=${DENSEPATH}"${TODAY}_d4012_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
                --margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
    if [ "$arg_arch" = "My_MobileNetV2" ]
    then
        python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --num_branches=$arg_peer \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=128 \
                --milestones 150 225 \
		--T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR100} \
                --save=${MOBILEPATH}"${TODAY}_mbnv2_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
                --margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
    if [ "$arg_arch" = "OurNet" ];
    then
        python3 main.py \
                --dataset=$arg_data \
                --arch=$arg_arch \
                --dml_arch='MobileNetV2' \
                --num_branches=$arg_peer \
                --epochs=${EPOCHS} \
                --batch-size=128 \
                --test-batch-size=256 \
                --milestones 150 225 \
                --T 3 \
                --seed=$arg_seed \
                --consistency_rampup=80 \
                --data=${CIFAR100} \
                --save=${MOBILEPATH}"${TODAY}_NETmbnv2_$arg_peer-b_$arg_exp-e_$arg_seed-seed_${TYPES}_$arg_data" \
		--margin=$arg_exp \
                --ngpu=$arg_cuda
    fi
fi
