#!/usr/bin/env bash
#
# Launch configuration for a train.py run: reads the command-line arguments
# and derives the gate flags and the job name used by the rest of the script.
#
# Usage:
#   <script> <master_port> <setting> <alpha> <beta> <norm_type> <niter_warmup>
#
#   master_port   forwarded to torch.distributed.launch as --master_port
#   setting       1..4, selects the gate flags (see table below)
#   alpha, beta   forwarded to train.py as --gate_alpha / --gate_beta
#   norm_type     forwarded to train.py as --gate_norm_type
#   niter_warmup  forwarded to train.py as --niter_gate_adj_warmup
#
#   setting | sym_flag   | g_balance_flag | job-name tags
#   --------+------------+----------------+---------------
#      1    | --gate_sym | (none)         | sym
#      2    | (none)     | (none)         | asym
#      3    | --gate_sym | --g_blance     | sym,  gblance
#      4    | (none)     | --g_blance     | asym, gblance

# Site-specific paths, left blank here and meant to be filled in before use.
# folder:  output root; checkpoints are written under $folder/ckpts/.
# datadir: dataset location, passed to train.py as --data.
folder=""
datadir=""

# Every positional argument is consumed by the launch command, so a short
# argument list can only produce a malformed train.py command line. Fail early.
if [ "$#" -lt 6 ]; then
    echo "Usage: $0 <master_port> <setting:1-4> <alpha> <beta> <norm_type> <niter_warmup>" >&2
    exit 1
fi

setting=$2
alpha=$3
beta=$4
norm_type=$5
niter_warmup=$6

# Map the setting onto the two optional flags plus the tags embedded in the
# job name. The spelling "--g_blance" / "gblance" is kept as-is: it is the flag
# name handed to train.py (presumably it must match that script's argument
# parser -- verify before renaming).
case "$setting" in
    1)
        sym_flag="--gate_sym"
        g_balance_flag=""
        sym_tag="sym"
        balance_tag=""
        ;;
    2)
        sym_flag=""
        g_balance_flag=""
        sym_tag="asym"
        balance_tag=""
        ;;
    3)
        sym_flag="--gate_sym"
        g_balance_flag="--g_blance"
        sym_tag="sym"
        balance_tag="-gblance"
        ;;
    4)
        sym_flag=""
        g_balance_flag="--g_blance"
        sym_tag="asym"
        balance_tag="-gblance"
        ;;
    *)
        echo "Invalid setting" >&2
        exit 1
        ;;
esac

# Single job-name template shared by all settings; balance_tag is empty for
# settings 1 and 2, so no "-gblance" component appears there.
jobname="graphglobalepoch-m-${setting}-alpha${alpha}-beta${beta}-${sym_tag}-${norm_type}${balance_tag}-nwarmup${niter_warmup}"
# Load the Weights & Biases API key from wandb_key.secret in the current
# working directory; it is exported to train.py as WANDB_API_KEY at launch.
# A missing or unreadable file is reported but not fatal: the key is then empty.
if [ -r wandb_key.secret ]; then
    wandb_key=$(cat wandb_key.secret)
else
    echo "Warning: wandb_key.secret is missing or unreadable; WANDB_API_KEY will be empty" >&2
    wandb_key=""
fi

# Create the checkpoint directory up front. The path is quoted so a folder
# containing spaces or glob characters is handled as a single path.
# NOTE: with an empty $folder this resolves to /ckpts/ at the filesystem root.
if ! mkdir -p -- "$folder/ckpts/"; then
    echo "Error: cannot create checkpoint directory '$folder/ckpts/'" >&2
    exit 1
fi

# Number of worker processes to spawn and the GPUs made visible to them.
n_gpu=4
gpu_indices=0,1,2,3

# First positional argument of the script: the rendezvous port for the
# launcher. Captured here because "$1" means something else inside a function.
master_port=$1

#######################################
# Runs train.py under torch.distributed.launch with the shared
# hyper-parameters of this experiment.
# Globals (read): setting, wandb_key, gpu_indices, master_port, n_gpu,
#                 datadir, alpha, beta, norm_type, sym_flag, g_balance_flag,
#                 niter_warmup, folder, jobname
# Arguments:      extra train.py flags, appended after the shared ones
#                 (--job-name is supplied here, once per launch).
# Returns:        the exit status of the launcher.
#######################################
run_train_py() {
    # Every value is passed as its own quoted word instead of being packed
    # into one string and re-split by the shell. ${var:+"$var"} expands to
    # nothing at all when the optional flag is empty, so no empty argument
    # reaches train.py.
    SETTING="$setting" WANDB_API_KEY="$wandb_key" CUDA_VISIBLE_DEVICES="$gpu_indices" \
        python -m torch.distributed.launch \
        --master_port "$master_port" \
        --nproc_per_node="$n_gpu" \
        --use_env train.py \
        --data "$datadir" \
        --base_arch transformer \
        --architecture sgsgsgsgsgsg \
        --gate_name graph_global_per_epoch \
        --gate_alpha "$alpha" \
        --gate_beta "$beta" \
        --gate_norm_type "$norm_type" \
        --gate_softmax_temp 1 \
        ${sym_flag:+"$sym_flag"} ${g_balance_flag:+"$g_balance_flag"} \
        --niter_gate_adj_warmup "$niter_warmup" \
        --nlayers 6 \
        --hid-sz 352 \
        --inner-hid-sz 352 \
        --nheads 8 \
        --block-sz 512 \
        --attn-span 1024 \
        --dropout 0.1 \
        --load_balance 0.01 \
        --optim adam \
        --lr 0.0007 \
        --lr-warmup 4000 \
        --niter 80 \
        --batch-sz 48 \
        --batch-split 2 \
        --nbatches 1000 \
        --distributed \
        --checkpoint "$folder/ckpts/$jobname" \
        --wandb-entity "your-account" \
        --project-name graphSMoE \
        --wandb-flag \
        "$@"
}

echo Start training time
date
echo "Training ..."
train_status=0
run_train_py --job-name "$jobname" || train_status=$?

# Second launch: same configuration and checkpoint path, run with --resume and
# --full-eval-mode under an "eval-" prefixed job name. It is attempted
# regardless of the training exit status; both statuses are reported below.
eval_status=0
run_train_py --job-name "eval-$jobname" --resume --full-eval-mode || eval_status=$?

echo Finish time
date

# Surface failures to the caller (e.g. a job scheduler) instead of always
# finishing with the exit status of `date`.
if [ "$train_status" -ne 0 ] || [ "$eval_status" -ne 0 ]; then
    echo "train.py failed (training exit status: $train_status, evaluation exit status: $eval_status)" >&2
    exit 1
fi
