#!/bin/bash

### still running ###

SRC_PREFIX="reproduce_src" # log prefix of the source-model (Src) runs; matches the cifar checkpoint paths below
LOG_PREFIX="eval_results"

BASE_DATASETS=("tiny-imagenet") # options: "cifar10" "cifar100" "tiny-imagenet" "pacs"
METHODS=("BATTA")
# METHODS=("SimATTA_BIN" "SimATTA")
# METHODS=("Src" "BN_Stats" "TENT" "EATA" "SAR" "SoTTA" "RoTTA" "CoTTA")


### TTA ############################
# "BN_Stats" "TENT" "EATA" "SAR" "SoTTA" "RoTTA" "CoTTA"

### ADA baselines ################## TGTS must NOT be "cont" (only the online setting is supported)
# "EADA" "CLUE" "DIANA"

### Source-free ADA baselines ###### TGTS must NOT be "cont" (only the online setting is supported)
# "ELPT" "MHPL"

### Active TTA baselines ###########
# "SimATTA_BIN" "SimATTA"

### Ours ###########################
# "BATTA"

SEEDS=(0 1 2)
DISTS=(1)
# Each entry is a string of extra command-line flags that is word-split and
# appended to every main.py invocation; every entry produces its own set of runs.
VALIDATIONS=(

  ## continual setting ##
  "--log_name log/"

  ## random setting: requires TGTS="cont" ##
  # "--log_name log/ --random_setting"

  ## run a baseline with the SSL loss ##
  # "--log_name log/ --enable_batta"

)


MAIN_SCRIPT="main.py"


# TGTS is split on whitespace: one run is launched per target domain.

### continual adaptation ###########################
TGTS="cont"

### online adaptation ##############################

  ### cifar10, cifar100, tiny-imagenet ###
# TGTS="gaussian_noise-5
#       shot_noise-5
#       impulse_noise-5
#       defocus_blur-5
#       glass_blur-5
#       motion_blur-5
#       zoom_blur-5
#       snow-5
#       frost-5
#       fog-5
#       brightness-5
#       contrast-5
#       elastic_transform-5
#       pixelate-5
#       jpeg_compression-5"

  ### PACS ###
# TGTS="art_painting
#       cartoon
#       sketch"


echo BASE_DATASETS: "${BASE_DATASETS[@]}"
echo METHODS: "${METHODS[@]}"
echo SEEDS: "${SEEDS[@]}"
GPUS=(0 1 2 3 4 5 6 7) # available GPUs; jobs are assigned to them round-robin
NUM_GPUS=${#GPUS[@]}

sleep 1 # brief pause to review the settings printed above (Ctrl-C to abort)
mkdir -p raw_logs # console output of every job is saved here (-p: no error if it already exists)


#### Useful functions

#######################################
# Throttle background jobs: when this shell already has at least N background
# jobs, block until one of them finishes (`wait -n`, bash >= 4.3).
# Arguments: $1 - maximum number of concurrent jobs (default: 8)
# Returns:   exit status of the reaped job if it waited, otherwise 0
#######################################
wait_n() {
  local default_num_jobs=8 # num concurrent jobs
  local num_max_jobs=${1:-$default_num_jobs}
  # Local so the PID list does not leak into the caller's global namespace.
  local -a background
  background=($(jobs -p))
  if ((${#background[@]} >= num_max_jobs)); then
    wait -n
  fi
}


#######################################
# Launch one ${MAIN_SCRIPT} run in the background, then throttle concurrency.
# The GPU is chosen round-robin from GPUS by the job counter `i`, and the run's
# combined stdout/stderr is tee'd to raw_logs/.
# Globals:   MAIN_SCRIPT, GPUS, NUM_GPUS, LOG_PREFIX (read)
# Caller variables (resolved through bash's dynamic scoping, so this must be
#   called from test_time_adaptation): DATASET, SEED (read), i (incremented)
# Arguments: arguments forwarded verbatim to ${MAIN_SCRIPT} after --gpu_idx
#######################################
_launch_job() {
  local gpu_idx="${GPUS[i % NUM_GPUS]}"
  local log_file="raw_logs/${DATASET}_${LOG_PREFIX}_${SEED}_job${i}.txt"

  python "${MAIN_SCRIPT}" --gpu_idx "${gpu_idx}" "$@" 2>&1 | tee "${log_file}" &

  i=$((i + 1))
  wait_n
}

#######################################
# Run every configured (dataset, method, validation, seed[, dist], target)
# combination as a background job, throttled by wait_n, and wait for all of
# them to finish.
# Globals:   BASE_DATASETS, METHODS, VALIDATIONS, SEEDS, DISTS, TGTS,
#            SRC_PREFIX, LOG_PREFIX (read)
# Outputs:   per-job console logs in raw_logs/; warnings on stderr for
#            unsupported dataset/method combinations (those are skipped)
#######################################
test_time_adaptation() {
  ###############################################################
  ###### Run Baselines & Ours; Evaluation: Target domains  ######
  ###############################################################

  # Job counter: picks the GPU and names the raw log; incremented by _launch_job.
  local i=0
  local DATASET METHOD validation SEED TGT dist MODEL CP EPOCH
  local lr weight_decay update_every_x memory_size memory_type batch_size
  local elpt_k elpt_m elpt_interval elpt_sample_time mhpl_kk mhpl_alpha
  local loss_scaler bn_momentum high_threshold aug_threshold
  local e_margin d_margin fisher_alpha dropout_rate

  # NOTE: ${TGTS}, ${CP} and ${validation} are expanded unquoted on purpose:
  # each holds several whitespace-separated words that must split into
  # separate loop items / command-line arguments.
  for DATASET in "${BASE_DATASETS[@]}"; do
    for METHOD in "${METHODS[@]}"; do
      for validation in "${VALIDATIONS[@]}"; do
        for SEED in "${SEEDS[@]}"; do # multiple seeds
          # Defaults shared by all methods; the method branches override them.
          update_every_x="64"
          memory_size="64"
          lr="0.001" # other baselines
          weight_decay="0"

          # Backbone and starting checkpoint.
          case "${DATASET}" in
            pacs | tiny-imagenet)
              MODEL="resnet18_pretrained"
              # Every method starts from the normal source checkpoint. The TTA
              # baselines (Src/TENT/EATA/SAR/SoTTA/RoTTA/CoTTA) may use this instead:
              # CP="--load_checkpoint_path pretrained_weights/${DATASET}/enhanced_cp/cp_last.pth.tar"
              CP="--load_checkpoint_path pretrained_weights/${DATASET}/normal_cp/cp_last.pth.tar"
              ;;
            cifar10 | cifar100)
              MODEL="resnet18"
              # Per-seed source model; presumably written by a Src run with
              # LOG_PREFIX=${SRC_PREFIX} — confirm.
              CP="--load_checkpoint_path log/${DATASET}/Src/tgt_test/${SRC_PREFIX}_${SEED}/cp/cp_last.pth.tar"
              ;;
            imagenet)
              MODEL="resnet18_pretrained"
              CP=""
              ;;
            *)
              echo "WARNING: unknown dataset '${DATASET}', skipping" >&2
              continue
              ;;
          esac

          case "${METHOD}" in
            Src)
              EPOCH=0
              for TGT in ${TGTS}; do
                _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" --epoch "${EPOCH}" ${CP} \
                  --update_every_x "${update_every_x}" --seed "${SEED}" \
                  --log_prefix "${LOG_PREFIX}_${SEED}" \
                  ${validation}
              done
              ;;

            ELPT)
              EPOCH=1 # epoch is changed to elpt_interval * (elpt_sample_time + 1)
              batch_size=64
              elpt_k=5
              elpt_m=3
              elpt_sample_time=5
              if [[ "${DATASET}" == "pacs" ]]; then
                elpt_interval=6
              else
                elpt_interval=4
              fi
              for TGT in ${TGTS}; do
                _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" ${CP} \
                  --update_every_x "${update_every_x}" --seed "${SEED}" \
                  --log_prefix "${LOG_PREFIX}_${SEED}" --lr "${lr}" --epoch "${EPOCH}" --batch_size "${batch_size}" \
                  --elpt_k "${elpt_k}" --elpt_m "${elpt_m}" --elpt_interval "${elpt_interval}" --elpt_sample_time "${elpt_sample_time}" \
                  --turn_to_binary --w_final_loss_wrong 0.1 \
                  ${validation}
              done
              ;;

            MHPL)
              EPOCH=20
              batch_size=64
              mhpl_alpha=3.0
              if [[ "${DATASET}" == "pacs" ]]; then
                mhpl_kk=1
              else
                mhpl_kk=20
              fi
              for TGT in ${TGTS}; do
                _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" ${CP} \
                  --update_every_x "${update_every_x}" --seed "${SEED}" \
                  --log_prefix "${LOG_PREFIX}_${SEED}" --lr "${lr}" --epoch "${EPOCH}" --batch_size "${batch_size}" \
                  --mhpl_kk "${mhpl_kk}" --mhpl_alpha "${mhpl_alpha}" \
                  --turn_to_binary --w_final_loss_wrong 0.1 \
                  ${validation}
              done
              ;;

            EADA | CLUE | DIANA)
              if [[ "${METHOD}" == "EADA" ]]; then
                EPOCH=50
                lr="0.1"
                batch_size=32
              else # CLUE, DIANA
                EPOCH=20
                lr="0.0001"
                batch_size=64
              fi
              for TGT in ${TGTS}; do
                _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" ${CP} \
                  --update_every_x "${update_every_x}" --seed "${SEED}" \
                  --log_prefix "${LOG_PREFIX}_${SEED}" --lr "${lr}" --epoch "${EPOCH}" --batch_size "${batch_size}" \
                  --turn_to_binary --w_final_loss_wrong 0.1 \
                  ${validation}
              done
              ;;

            SoTTA)
              EPOCH=1
              loss_scaler=0
              bn_momentum=0.2
              memory_type="HUS"
              case "${DATASET}" in
                pacs | cifar10) high_threshold=0.99 ;;
                cifar100) high_threshold=0.66 ;;
                tiny-imagenet | imagenet) high_threshold=0.33 ;;
              esac
              for dist in "${DISTS[@]}"; do
                for TGT in ${TGTS}; do
                  _launch_job --dataset "${DATASET}" --method SoTTA --tgt "${TGT}" --model "${MODEL}" --epoch "${EPOCH}" ${CP} --seed "${SEED}" \
                    --remove_cp --online --use_learned_stats --lr "${lr}" --weight_decay "${weight_decay}" --tgt_train_dist "${dist}" \
                    --update_every_x "${update_every_x}" --memory_size "${memory_size}" --memory_type "${memory_type}" --bn_momentum "${bn_momentum}" \
                    --log_prefix "${LOG_PREFIX}_${SEED}_dist${dist}" \
                    --loss_scaler "${loss_scaler}" --sam \
                    ${validation} \
                    --high_threshold "${high_threshold}"
                done
              done
              ;;

            RoTTA)
              EPOCH=1
              loss_scaler=0
              bn_momentum=0.05
              memory_type="CSTU"
              for dist in "${DISTS[@]}"; do
                for TGT in ${TGTS}; do
                  _launch_job --dataset "${DATASET}" --method "RoTTA" --tgt "${TGT}" --model "${MODEL}" --epoch "${EPOCH}" ${CP} --seed "${SEED}" \
                    --remove_cp --online --use_learned_stats --lr "${lr}" --weight_decay "${weight_decay}" --tgt_train_dist "${dist}" \
                    --update_every_x "${update_every_x}" --memory_size "${memory_size}" --memory_type "${memory_type}" --bn_momentum "${bn_momentum}" \
                    --log_prefix "${LOG_PREFIX}_${SEED}_dist${dist}" \
                    --loss_scaler "${loss_scaler}" \
                    ${validation}
                done
              done
              ;;

            BN_Stats | TENT)
              EPOCH=1
              for dist in "${DISTS[@]}"; do
                for TGT in ${TGTS}; do
                  _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" --epoch "${EPOCH}" ${CP} --seed "${SEED}" \
                    --remove_cp --online --tgt_train_dist "${dist}" --update_every_x "${update_every_x}" --memory_size "${memory_size}" \
                    --lr "${lr}" --weight_decay "${weight_decay}" \
                    --log_prefix "${LOG_PREFIX}_${SEED}_dist${dist}" \
                    ${validation}
                done
              done
              ;;

            CoTTA)
              EPOCH=1
              aug_threshold=0.1
              for dist in "${DISTS[@]}"; do
                for TGT in ${TGTS}; do
                  _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" --epoch "${EPOCH}" ${CP} --seed "${SEED}" \
                    --remove_cp --online --tgt_train_dist "${dist}" --update_every_x "${update_every_x}" --memory_size "${memory_size}" \
                    --lr "${lr}" --weight_decay "${weight_decay}" \
                    --aug_threshold "${aug_threshold}" \
                    --log_prefix "${LOG_PREFIX}_${SEED}_dist${dist}" \
                    ${validation}
                done
              done
              ;;

            SAR)
              EPOCH=1
              lr=0.00025 # From SAR paper: args.lr = (0.00025 / 64) * bs * 2 if bs < 32 else 0.00025
              for dist in "${DISTS[@]}"; do
                for TGT in ${TGTS}; do
                  _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" ${CP} --epoch "${EPOCH}" --seed "${SEED}" \
                    --remove_cp --online --tgt_train_dist "${dist}" --update_every_x "${update_every_x}" --memory_size "${memory_size}" \
                    --lr "${lr}" --weight_decay "${weight_decay}" \
                    --log_prefix "${LOG_PREFIX}_${SEED}_dist${dist}" \
                    ${validation}
                done
              done
              ;;

            EATA | ETA)
              EPOCH=1
              case "${DATASET}" in
                pacs)
                  lr=0.001
                  e_margin=0.7784 # 0.4*ln(7)
                  d_margin=0.5
                  fisher_alpha=2000
                  ;;
                tiny-imagenet)
                  lr=0.001
                  e_margin=2.1193 # 0.4*ln(200)
                  d_margin=0.5
                  fisher_alpha=2000
                  ;;
                cifar10)
                  lr=0.005
                  e_margin=0.92103 # 0.4*ln(10)
                  d_margin=0.4
                  fisher_alpha=1
                  ;;
                cifar100)
                  lr=0.005
                  e_margin=1.84207 # 0.4*ln(100)
                  d_margin=0.4
                  fisher_alpha=1
                  ;;
                imagenet)
                  lr=0.00025
                  e_margin=2.76310 # 0.4*ln(1000)
                  d_margin=0.05
                  fisher_alpha=2000
                  ;;
              esac
              for dist in "${DISTS[@]}"; do
                for TGT in ${TGTS}; do
                  _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" ${CP} --epoch "${EPOCH}" --seed "${SEED}" \
                    --remove_cp --online --tgt_train_dist "${dist}" --update_every_x "${update_every_x}" --memory_size "${memory_size}" \
                    --lr "${lr}" --weight_decay "${weight_decay}" \
                    --log_prefix "${LOG_PREFIX}_${SEED}_dist${dist}" \
                    --e_margin "${e_margin}" --d_margin "${d_margin}" --fisher_alpha "${fisher_alpha}" \
                    ${validation}
                done
              done
              ;;

            SimATTA | SimATTA_BIN)
              EPOCH=10
              memory_type="FIFO"
              # update_every_x / memory_size keep the shared default of 64.
              if [[ "${DATASET}" == "tiny-imagenet" ]]; then
                lr=0.0001
              fi
              for dist in "${DISTS[@]}"; do
                for TGT in ${TGTS}; do
                  _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" ${CP} --epoch "${EPOCH}" --seed "${SEED}" \
                    --remove_cp --use_learned_stats --online --tgt_train_dist "${dist}" \
                    --update_every_x "${update_every_x}" --memory_size "${memory_size}" --memory_type "${memory_type}" \
                    --lr "${lr}" --weight_decay "${weight_decay}" --early_stop \
                    --log_prefix "${LOG_PREFIX}_${SEED}_dist${dist}" \
                    ${validation}
                done
              done
              ;;

            BATTA)
              case "${DATASET}" in
                pacs)
                  lr=0.001
                  dropout_rate=0.4
                  EPOCH=3
                  ;;
                tiny-imagenet)
                  lr=0.00005
                  dropout_rate=0.1
                  EPOCH=5
                  ;;
                cifar10 | cifar100)
                  lr=0.0001
                  dropout_rate=0.3
                  EPOCH=3
                  ;;
                *)
                  # No tuned hyperparameters: launching would pass empty/stale values.
                  echo "WARNING: BATTA has no settings for dataset '${DATASET}', skipping" >&2
                  continue
                  ;;
              esac
              memory_type="ActivePriorityFIFO"
              for dist in "${DISTS[@]}"; do
                for TGT in ${TGTS}; do
                  _launch_job --dataset "${DATASET}" --method "${METHOD}" --tgt "${TGT}" --model "${MODEL}" ${CP} --epoch "${EPOCH}" --seed "${SEED}" \
                    --remove_cp --online --tgt_train_dist "${dist}" \
                    --update_every_x "${update_every_x}" --memory_size "${memory_size}" --memory_type "${memory_type}" \
                    --lr "${lr}" \
                    --log_prefix "${LOG_PREFIX}_${SEED}_dist${dist}" --use_learned_stats --bn_momentum 0.3 --active_binary \
                    --dropout_rate "${dropout_rate}" --w_final_loss_correct 10.0 \
                    ${validation}
                done
              done
              ;;

            *)
              echo "WARNING: unknown method '${METHOD}', skipping" >&2
              ;;
          esac
        done
      done
    done
  done

  # Barrier: block until every launched job has finished.
  wait
}

test_time_adaptation
