#!/bin/bash
# Hyperparameter sweep launcher: iterates over every combination of the arrays
# configured below and launches one `accelerate` training job per combination
# (see run_job).

# Experiment tag; embedded in the output directory (DIR) and log-file
# (LOG_FILE) names built inside the sweep loop.
TAG='exp_ID-textPT'

#######################################
# Launch one training run with `accelerate launch`, handing the current sweep
# point to run_main_<learning_paradigm>.py through environment variables
# prefixed to the command.
# Globals (read):
#   optim_seed vis_encoder dataset_name split_seed SETTING dataset_dir DIR
#   loss_type epoch_num lr decay batch_size method TEMPERATURE CONF_QUANTILE
#   REGULAR_THRESHOLD STEP_QUANTILE device_id round client_num partition
#   local_epoch beta selectlevel USE_SOFT_PARTIAL prototype num_prompt
#   lr_attention fedprox num_repesudo_round scheduler ctx_init joining_rate
#   learning_paradigm
#   Array-valued globals (e.g. round, client_num) contribute only their first
#   element, because they are expanded as plain ${name}.
# Arguments: none
# Returns:   the exit status of `accelerate launch`
# The accelerate config file, training script and model config are relative
# paths, so they resolve against the current working directory.
#######################################
run_job() {
    OPTIM_SEED="${optim_seed}" \
    VIS_ENCODER="${vis_encoder}" \
    DATASET_NAME="${dataset_name}" \
    SPLIT_SEED="${split_seed}" \
    MODEL="${SETTING}" \
    DATASET_DIR="${dataset_dir}" \
    OUTPUT_DIR="${DIR}" \
    LOSS_TYPE="${loss_type}" \
    EPOCHS="${epoch_num}" \
    LR="${lr}" \
    DECAY="${decay}" \
    BATCH_SIZE="${batch_size}" \
    CANDIDATE_METHOD="${method}" \
    TEMPERATURE="${TEMPERATURE}" \
    CONF_QUANTILE="${CONF_QUANTILE}" \
    CONF_THRESHOLD="quantile" \
    REGULAR_THRESHOLD="${REGULAR_THRESHOLD}" \
    STEP_QUANTILE="${STEP_QUANTILE}" \
    Device_ID="${device_id}" \
    Round="${round}" \
    Client_Num="${client_num}" \
    Partition="${partition}" \
    Local_Epoch="${local_epoch}" \
    Beta="${beta}" \
    Selectlevel="${selectlevel}" \
    USE_SOFT_PARTIAL="${USE_SOFT_PARTIAL}" \
    Prototype="${prototype}" \
    Num_prompt="${num_prompt}" \
    LR_attention="${lr_attention}" \
    FedProx="${fedprox}" \
    Num_repesudo_round="${num_repesudo_round}" \
    Scheduler="${scheduler}" \
    Ctx_init="${ctx_init}" \
    Joining_rate="${joining_rate}" \
    accelerate launch \
        --config_file methods_config/accelerate_localtest_config.yml \
        "run_main_${learning_paradigm}.py" \
        --model_config "${SETTING}_config_PLL.yml" \
        --learning_paradigm "${learning_paradigm}"
    # Quoting the positional arguments keeps each one a single argv word even
    # if a sweep value contains whitespace or glob characters.
}

# ---------------------------------------------------------------------------
# Sweep grid. Each array that the loops below iterate over multiplies the
# number of launched jobs.
# ---------------------------------------------------------------------------

# --- Data, model and hardware ----------------------------------------------
dataset_dirs=("dataset")          # directory containing the datasets (set this)
# Options: 'CUB' 'DTD' 'EuroSAT' 'RESICS45' 'Flowers102' 'cifar10' 'cifar100'
#          'FGVCAircraft' 'UCF101' 'food101'
dataset_names=("DTD" "RESICS45" "CUB" "UCF101")
# TRZSL split (500, 0 or 200); every other learning setting uses 500.
split_seeds=("500")
# Options: 'RN50' 'RN101' 'RN50x4' 'RN50x16' 'RN50x64' 'ResNet50'
#          'ViT-B/32' 'ViT-B/16' 'ViT-L/14' 'ViT-L/14@336px'
vis_encoders=("ViT-B/32")
# Options: 'textual_fpl' 'grip_textual' 'fedavg_grip_textual' 'our_grip_textual'
#          'global_text' 'opt_textual' 'our_grip_vision' 'our_dual_local_prompt'
SETTINGS=("our_dual_local_prompt")
learning_paradigms=("ul")         # one of: 'ul' 'ssl' 'trzsl'
device_ids=("3")

# --- Optimisation -----------------------------------------------------------
EPOCHS=("10")
DECAY=("0.05")
BATCH_SIZE=("64")
lrs=("0.1")
lr_attentions=("0.5")
optim_seeds=("2")                 # seeds used in our experiments: 1, 2, 3

# --- Method -----------------------------------------------------------------
methods=("CPL")
# Other documented loss options: 'cc' (default), 'rc_rc' (RC), 'lw_lw' (LW),
# 'rc_cav' (CAV).
# NOTE(review): 'CE' is not in that list; confirm run_main_*.py accepts it.
loss_types=("CE")
TEMPERATUREs=("1.0")
USE_SOFT_PARTIALs=("False")
selectlevels=("class")
num_repesudo_rounds=("5")
prototypes=("0")
num_prompts=("1")
joining_rates=("1")
betas=("0.1")
LAMBDAs=("1.0")                   # NOTE(review): not referenced elsewhere in this script

# --- Not swept: run_job reads only the first element of each ----------------
round=("20")
ctx_init=("0")
partition=("iid")                 # one of: "iid" "noniid" "noniid-labeldir"
local_epoch=("10")
fedprox=("0")
scheduler=("cosine")


#######################################
# Set the dataset-dependent values that run_job forwards.
# Globals (written): STEP_QUANTILE; client_num (one-element array)
# Arguments: $1 - dataset name
#######################################
select_dataset_defaults() {
    local dataset=$1

    # CUB uses STEP_QUANTILE=20; every other dataset uses 10.
    if [[ "$dataset" == "CUB" ]]; then
        STEP_QUANTILE=20
    else
        STEP_QUANTILE=10
    fi

    # Value forwarded to the job as Client_Num.
    case "$dataset" in
        Flowers102 | UCF101 | CUB | FGVCAircraft) client_num=(5) ;;
        *) client_num=(10) ;;
    esac
}

#######################################
# Choose the CONF_QUANTILE / REGULAR_THRESHOLD candidates for one
# (method, dataset, learning paradigm) combination.
# NOTE: CONF_QUANTILE represents the hyperparameter (alpha*100) in the paper;
#       REGULAR_THRESHOLD represents the hyperparameter beta in the paper.
# Globals (written): CONF_QUANTILEs, REGULAR_THRESHOLDs (one-element arrays)
# Arguments:
#   $1 - method name (only "CPL" is supported)
#   $2 - dataset name
#   $3 - learning paradigm; "trzsl" selects the "auto*k" threshold
# Exits the whole script with status 1 on an unknown method or dataset.
#######################################
select_cpl_hparams() {
    local method_name=$1 dataset=$2 paradigm=$3
    local conf trzsl_thr default_thr

    if [[ "$method_name" != "CPL" ]]; then
        echo "Invalid method name: ${method_name}" >&2
        exit 1
    fi

    # Trailing "alt:" notes record previously commented-out default_thr values.
    case "$dataset" in
        EuroSAT)      conf=40; trzsl_thr="auto*3.0";  default_thr="0.50" ;; # alt: 0.80
        Flowers102)   conf=40; trzsl_thr="auto*1.0";  default_thr="0.5"  ;; # alt: 0.99
        FGVCAircraft) conf=90; trzsl_thr="auto*1.50"; default_thr="0.97" ;;
        CUB)          conf=50; trzsl_thr="auto*2.25"; default_thr="0.99" ;; # conf note: for noniid = 0.5
        DTD)          conf=50; trzsl_thr="auto*2.0";  default_thr="0.90" ;; # alt: 0.5
        RESICS45)     conf=50; trzsl_thr="auto*2.0";  default_thr="0.97" ;; # alt: 0.3
        cifar10)      conf=40; trzsl_thr="auto*2.0";  default_thr="0.97" ;; # alt: 0.3
        cifar100)     conf=40; trzsl_thr="auto*2.0";  default_thr="0.97" ;; # alt: 0.3
        UCF101)       conf=50; trzsl_thr="auto*2.0";  default_thr="0.97" ;; # alt: 0.3
        food101)      conf=40; trzsl_thr="auto*2.0";  default_thr="0.97" ;; # alt: 0.3
        *)
            # Abort the sweep rather than launch jobs with unset/stale thresholds.
            echo "Invalid Dataset name: ${dataset}" >&2
            exit 1
            ;;
    esac

    CONF_QUANTILEs=("$conf")
    if [[ "$paradigm" == "trzsl" ]]; then
        REGULAR_THRESHOLDs=("$trzsl_thr")
    else
        REGULAR_THRESHOLDs=("$default_thr")
    fi
}

#######################################
# Print the number of jobs in the sweep: the product of the lengths of all 25
# arrays iterated below. CONF_QUANTILEs and REGULAR_THRESHOLDs are taken at
# their current (dataset-specific) values, so the count is exact only if every
# dataset selects the same number of candidates.
# Outputs: the job count on stdout
#######################################
count_sweep_jobs() {
    echo $(( ${#dataset_dirs[@]} * ${#dataset_names[@]} * \
             ${#learning_paradigms[@]} * ${#vis_encoders[@]} * \
             ${#optim_seeds[@]} * ${#split_seeds[@]} * \
             ${#EPOCHS[@]} * ${#DECAY[@]} * ${#BATCH_SIZE[@]} * \
             ${#SETTINGS[@]} * ${#loss_types[@]} * \
             ${#USE_SOFT_PARTIALs[@]} * ${#methods[@]} * \
             ${#device_ids[@]} * ${#betas[@]} * \
             ${#REGULAR_THRESHOLDs[@]} * ${#TEMPERATUREs[@]} * \
             ${#CONF_QUANTILEs[@]} * ${#selectlevels[@]} * \
             ${#prototypes[@]} * ${#num_prompts[@]} * \
             ${#num_repesudo_rounds[@]} * ${#lr_attentions[@]} * \
             ${#lrs[@]} * ${#joining_rates[@]} ))
}

for dataset_dir in "${dataset_dirs[@]}"; do
for dataset_name in "${dataset_names[@]}"; do

select_dataset_defaults "$dataset_name"

for learning_paradigm in "${learning_paradigms[@]}"; do
for vis_encoder in "${vis_encoders[@]}"; do
for optim_seed in "${optim_seeds[@]}"; do
for split_seed in "${split_seeds[@]}"; do
for epoch_num in "${EPOCHS[@]}"; do
for decay in "${DECAY[@]}"; do
for batch_size in "${BATCH_SIZE[@]}"; do
for SETTING in "${SETTINGS[@]}"; do
for loss_type in "${loss_types[@]}"; do
for USE_SOFT_PARTIAL in "${USE_SOFT_PARTIALs[@]}"; do
for method in "${methods[@]}"; do
for device_id in "${device_ids[@]}"; do
for beta in "${betas[@]}"; do

    select_cpl_hparams "$method" "$dataset_name" "$learning_paradigm"

    # Invariant across the inner loops, so computed and printed once here
    # rather than once per job.
    total_iterations=$(count_sweep_jobs)
    echo "The loop will iterate $total_iterations times."

for REGULAR_THRESHOLD in "${REGULAR_THRESHOLDs[@]}"; do
for TEMPERATURE in "${TEMPERATUREs[@]}"; do
for CONF_QUANTILE in "${CONF_QUANTILEs[@]}"; do
for selectlevel in "${selectlevels[@]}"; do
for prototype in "${prototypes[@]}"; do
for num_prompt in "${num_prompts[@]}"; do
for num_repesudo_round in "${num_repesudo_rounds[@]}"; do
for lr_attention in "${lr_attentions[@]}"; do
for lr in "${lrs[@]}"; do
for joining_rate in "${joining_rates[@]}"; do

    # Only used by the disabled accuracy-logging block below.
    LOG_FILE="script_results/log_${TAG}_${dataset_name}.txt"

    common_id="dataset-${dataset_name}_setting-${SETTING}_lpardigm-${learning_paradigm}_encoder-${vis_encoder}_split-${split_seed}_seed-${optim_seed}_epoch-${epoch_num}_lr-${lr}_decay-${decay}_bs-${batch_size}_loss-${loss_type}_method-${method}_T-${TEMPERATURE}_regularThr-${REGULAR_THRESHOLD}_confQ-${CONF_QUANTILE}"
    # NOTE(review): vis_encoder values such as 'ViT-B/32' contain '/', which
    # adds extra directory levels to DIR (both in the prefix and in common_id).
    DIR=./output/${dataset_name}/${SETTING}/${vis_encoder}_SplitSeed${split_seed}-${TAG}/SEED${optim_seed}/${common_id}

    # if [ -d "$DIR" ]; then
    #     echo -e "------------\n Results are available in ${DIR}. Skip this job"
    # else
    #     echo "======>>> Run this job and save the output to ${DIR}"
    run_job

    # ACCURACY=$(grep 'Testset accuracy:' ${DIR}/log.txt | awk -F': ' '{print $2}')
    # RECORD="id: ${common_id} ----> test * accuracy: ${ACCURACY}"
    # echo "${RECORD}" | tee -a ${LOG_FILE}
    # echo "${RECORD}" >> ${DIR}/log.txt
    # fi

done  # joining_rate
done  # lr
done  # lr_attention
done  # num_repesudo_round
done  # num_prompt
done  # prototype
done  # selectlevel
done  # CONF_QUANTILE
done  # TEMPERATURE
done  # REGULAR_THRESHOLD

done  # beta
done  # device_id
done  # method
done  # USE_SOFT_PARTIAL
done  # loss_type
done  # SETTING
done  # batch_size
done  # decay
done  # epoch_num
done  # split_seed
done  # optim_seed
done  # vis_encoder
done  # learning_paradigm
done  # dataset_name
done  # dataset_dir