#!/bin/bash
#
# Hyperparameter sweep over the warm-up value for weak-to-strong (WTS)
# fine-tuning with an auxiliary loss (Pythia weak model -> Pythia strong
# model) on the "${glue_ds}" dataset.
#
# For every run id and every warm_up value, each task in "tasks" goes through
# two stages of wts.py:
#   1. weak fine-tuning of the weak model (--ft_mode weak);
#   2. WTS fine-tuning of the strong model (--ft_mode wts-aux-loss) using the
#      weak labels file.
# A task is recorded as "<task>-done" in a per-configuration completion file
# only when both stages exit with status 0, so re-running this script resumes
# where it left off. A failing stage is reported on stderr and the task is
# skipped; it will be retried on the next invocation.
#
# The commented-out loops, results_file lines and python invocations are
# alternative sweeps (lambda_coeff, alpha_max, num_epochs) and alternative
# fine-tuning modes; swap them in to run a different experiment.

weak_model_size="14m"
strong_model_size="410m"
num_epochs=1
glue_ds="advgluepp"

weak_model="EleutherAI/pythia-${weak_model_size}"
strong_model="EleutherAI/pythia-${strong_model_size}"
# Path-friendly versions of the sizes: every "." is replaced by "p".
weak_model_size_str=${weak_model_size//./p}
strong_model_size_str=${strong_model_size//./p}

tasks=("sst2" "qqp" "mnli" "mnli-mm" "qnli" "rte")

for run in 5 10 15
do
    # for lambda_coeff in 0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
    # for alpha_max in 0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
    # for num_epochs in 1 2 3 4 5 6 7 8 9 10
    for warm_up in 0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
    do
        # lambda_coeff_str=${lambda_coeff//./p}
        # alpha_max_str=${alpha_max//./p}
        warm_up_str=${warm_up//./p}
        # results_file="results/${glue_ds}/hyperparam_wts/lambda-${lambda_coeff_str}/pythia-${weak_model_size_str}-${strong_model_size_str}-ep${num_epochs}-${run}.json"
        # results_file="results/${glue_ds}/hyperparam_wts/alpha-${alpha_max_str}/pythia-${weak_model_size_str}-${strong_model_size_str}-ep${num_epochs}-${run}.json"
        # results_file="results/${glue_ds}/hyperparam_wts/epochs-${num_epochs}/pythia-${weak_model_size_str}-${strong_model_size_str}-ep${num_epochs}-${run}.json"
        results_file="results/${glue_ds}/hyperparam_wts/warm_up-${warm_up_str}/pythia-${weak_model_size_str}-${strong_model_size_str}-ep${num_epochs}-${run}.json"
        # Sibling file that records finished tasks, one "<task>-done" per line.
        completion_file="${results_file%.json}-completed.txt"

        # Without a writable completion file no progress could be recorded, so
        # skip this configuration instead of running training blindly.
        if ! mkdir -p -- "$(dirname -- "$completion_file")" || ! touch -- "$completion_file"; then
            echo "warning: cannot create ${completion_file}; skipping warm_up=${warm_up} (run ${run})" >&2
            continue
        fi

        for task in "${tasks[@]}"
        do
            # Skip tasks already recorded as done. -x -F: match the whole line
            # literally, so one task's marker can never satisfy another task.
            if grep -qxF -- "${task}-done" "$completion_file"; then
                continue
            fi

            # Weak fine-tuning.
            # NOTE(review): none of these arguments depend on warm_up, yet this
            # stage re-runs for every warm_up value; consider hoisting it out of
            # the warm_up loop if the weak results are not needed in each
            # warm_up results file.
            if ! python wts.py \
                --model_id "$weak_model" \
                --results_file "$results_file" \
                --task "$task" \
                --ft_mode weak \
                --num_epochs "$num_epochs" \
                --lambda_coeff 0.3 \
                --validate_on adversarial \
                --tag "-${run}" \
                --glue_ds "$glue_ds"; then
                echo "warning: weak fine-tuning failed (task=${task}, warm_up=${warm_up}, run=${run}); skipping task" >&2
                continue
            fi

            # # Naive WTS fine-tuning
            # python wts.py \
            #     --model_id $strong_model \
            #     --results_file $results_file \
            #     --task $task \
            #     --ft_mode wts-naive \
            #     --weak_labels_file "weak_labels/${glue_ds}/${weak_model}-${run}.json" \
            #     --num_epochs $num_epochs \
            #     --lambda_coeff 0 \
            #     --validate_on original \
            #     --tag "-${run}" \
            #     --glue_ds $glue_ds

            # # Check if the script ran successfully
            # if [ $? -ne 0 ]; then
            #     continue
            # fi

            # Naive WTS fine-tuning (Phase 3)
            # python wts.py \
            #     --model_id $strong_model \
            #     --results_file $results_file \
            #     --task $task \
            #     --ft_mode wts-naive \
            #     --weak_labels_file "weak_labels/${glue_ds}/${weak_model}-${run}.json" \
            #     --num_epochs $num_epochs \
            #     --lambda_coeff $lambda_coeff \
            #     --validate_on adversarial \
            #     --tag "-${run}" \
            #     --glue_ds $glue_ds

            # WTS fine-tuning with auxiliary loss (Phase 3).
            # NOTE(review): presumably the weak labels file is produced by the
            # weak stage above; confirm against wts.py.
            if ! python wts.py \
                --model_id "$strong_model" \
                --results_file "$results_file" \
                --task "$task" \
                --ft_mode wts-aux-loss \
                --weak_labels_file "weak_labels/${glue_ds}/${weak_model}-${run}.json" \
                --num_epochs "$num_epochs" \
                --alpha_max 0.1 \
                --lambda_coeff 0.3 \
                --validate_on adversarial \
                --warm_up "$warm_up" \
                --tag "-${run}" \
                --glue_ds "$glue_ds"; then
                echo "warning: WTS aux-loss fine-tuning failed (task=${task}, warm_up=${warm_up}, run=${run}); skipping task" >&2
                continue
            fi

            # Strong fine-tuning
            # python wts.py \
            #     --model_id $strong_model \
            #     --results_file $results_file \
            #     --task $task \
            #     --ft_mode strong \
            #     --num_epochs $num_epochs \
            #     --lambda_coeff $lambda_coeff \
            #     --validate_on adversarial \
            #     --tag "-${run}" \
            #     --glue_ds $glue_ds

            # Both stages succeeded: record the task so later runs skip it.
            echo "${task}-done" >> "$completion_file"
        done
    done
done
