#!/bin/bash
# Launches main/znc_visual_tr.py for one dataset/architecture pair.
#
# Usage: <script> RANK DATASET ARCH
#   RANK     value forwarded to the trainer as the `rank` override
#   DATASET  dataset key; selects configs/<DATASET>/ and the learning rate
#   ARCH     architecture key: vgg, resnet or densenet (any other value
#            uses the densenet learning-rate table)
#
# Exits with status 2 when a required argument is missing or empty.

# Fail fast: without all three arguments the config path and the
# learning-rate lookup further down cannot be formed correctly.
if [ "$#" -lt 3 ] || [ -z "${1}" ] || [ -z "${2}" ] || [ -z "${3}" ]; then
    echo "usage: ${0##*/} RANK DATASET ARCH" >&2
    exit 2
fi

RANK=${1}
DATASET=${2}
ARCH=${3}

# Base learning rate per dataset, used when ARCH is "vgg".
declare -A vgg_lr=(
    [mnist]=0.06786
    [fmnist]=0.009597
    [svhn]=0.094015
    [cifar10]=0.048982
    [cifar100]=0.180451
    [stl10]=0.01842
)

# Base learning rate per dataset, used when ARCH is "resnet".
declare -A resnet_lr=(
    [mnist]=0.013296
    [fmnist]=0.13025
    [svhn]=0.009597
    [cifar10]=0.06786
    [cifar100]=0.13025
    [stl10]=0.25
)

# Base learning rate per dataset, used for every ARCH other than "vgg"
# and "resnet".
declare -A densenet_lr=(
    [mnist]=0.094015
    [fmnist]=0.009597
    [svhn]=0.06786
    [cifar10]=0.094015
    [cifar100]=0.13025
    [stl10]=0.06786
)

# Resolve `lr`, the base learning rate for the requested ARCH/DATASET pair.

# An empty key is an invalid associative-array subscript in bash, so reject
# it up front with a readable message instead of a "bad array subscript".
if [[ -z "${DATASET}" ]]; then
    echo "error: DATASET is empty; cannot look up a learning rate" >&2
    exit 1
fi

case "${ARCH}" in
    vgg)
        lr=${vgg_lr[${DATASET}]}
        ;;
    resnet)
        lr=${resnet_lr[${DATASET}]}
        ;;
    densenet)
        lr=${densenet_lr[${DATASET}]}
        ;;
    *)
        # Historical behaviour: every other ARCH uses the densenet table.
        # Keep that fallback for existing callers, but no longer silently.
        echo "warning: unrecognised ARCH '${ARCH}'; using the densenet learning-rate table" >&2
        lr=${densenet_lr[${DATASET}]}
        ;;
esac

# A dataset missing from the table would leave `train.optimizer.base_lr`
# without a value on the command line, so stop here instead.
if [[ -z "${lr}" ]]; then
    echo "error: no learning rate configured for DATASET '${DATASET}' (ARCH '${ARCH}')" >&2
    exit 1
fi

# Build the trainer command line as arrays so each option and value stays a
# single argv word. The earlier string form relied on unquoted expansion,
# which would split or glob-expand a DATASET/ARCH containing whitespace or
# wildcard characters.

# Config file for this dataset/architecture pair.
CFG=(--cfg "configs/${DATASET}/ce_${DATASET}_${ARCH}.yaml")

# Key/value overrides passed to main/znc_visual_tr.py after --cfg.
GPU_OPT=(ddp False dp False mixed_precision True rank "${RANK}")
ARGS=(
    save_only_result False
    output_dir ./output/znc_visual_tr
    name "Best.${DATASET}.${ARCH}"
    classifier.type OrthLinear
    classifier.bias False
    train.optimizer.base_lr "${lr}"
)

cmd=(python main/znc_visual_tr.py "${CFG[@]}" "${GPU_OPT[@]}" "${ARGS[@]}")

# Print the exact command first, then run it. The launch stays the last
# command so the script's exit status is the trainer's exit status.
echo "${cmd[@]}"
"${cmd[@]}"
