#!/usr/bin/env bash
#
# Sweeps alpha_max for the wts.py fine-tuning pipeline (weak, strong,
# wts-naive and wts-aux-loss modes) on EleutherAI Pythia models, writing
# results under results/<glue_ds>/hyperparam_wts/phase-<N>/alpha-<a>/.
#
# Each setting below may be overridden from the environment, e.g.
#   WEAK_MODEL_SIZE=70m NUM_EPOCHS=2 ./<this script>
# Unset (or empty) variables fall back to the defaults shown.

weak_model_size="${WEAK_MODEL_SIZE:-14m}"
strong_model_size="${STRONG_MODEL_SIZE:-410m}"
num_epochs="${NUM_EPOCHS:-1}"
glue_ds="${GLUE_DS:-advgluepp}"

lambda_coeff="${LAMBDA_COEFF:-0.3}"
# Path-safe spelling of the coefficient (0.3 -> 0p3). Not referenced
# elsewhere in this script as written.
lambda_coeff_str=${lambda_coeff//./p}

weak_model="EleutherAI/pythia-${weak_model_size}"
strong_model="EleutherAI/pythia-${strong_model_size}"
# Path-safe size tags (dots -> "p"), used in result file names.
weak_model_size_str=${weak_model_size//./p}
strong_model_size_str=${strong_model_size//./p}

tasks=("sst2" "qqp" "mnli" "mnli-mm" "qnli" "rte")

# ---------------------------------------------------------------------------
# Sweep driver. For each run id and each alpha_max value it executes:
#   Phase 2 - weak/strong stages with --lambda_coeff "$lambda_coeff" and
#             --validate_on adversarial, then wts-naive/wts-aux-loss stages
#             with --lambda_coeff 0 and --validate_on original;
#   Phase 3 - filter.py is given the phase-2 and phase-3 results paths (to
#             carry the weak/strong results over), then wts-naive/wts-aux-loss
#             stages run with --lambda_coeff "$lambda_coeff" and
#             --validate_on adversarial.
# After that alpha sweep it runs a second alpha sweep for
#   Phase 1 - all four stages with --lambda_coeff 0 and --validate_on original.
#
# Each phase is resumable. A task is skipped when its phase's
# *-completed.txt marker file contains the line "<task>-done". That line is
# appended only after every stage of the task succeeded. A failing stage
# aborts the remaining stages of that task, and the task restarts from its
# first stage on the next invocation.
#
# Reads globals set above: weak_model, strong_model, weak_model_size_str,
# strong_model_size_str, num_epochs, glue_ds, lambda_coeff, tasks.
# ---------------------------------------------------------------------------

alpha_max_values=(0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0)

#######################################
# Derive the completion-marker path for a results JSON file and make sure the
# marker file and its parent directory exist. Later grep/append calls then
# work on a fresh tree.
# Arguments: $1 - results file path ending in ".json"
# Outputs:   marker file path on stdout
# Returns:   non-zero if the directory or marker file could not be created
#######################################
prepare_completion_file() {
    local marker="${1%.json}-completed.txt"
    mkdir -p -- "$(dirname -- "$marker")" || return 1
    touch -- "$marker" || return 1
    printf '%s\n' "$marker"
}

#######################################
# Succeed iff marker file $1 contains exactly the line "$2-done".
# Whole-line, fixed-string matching keeps one task name from matching inside
# another task's marker line.
#######################################
task_completed() {
    grep -qxF -- "$2-done" "$1"
}

#######################################
# Record task $2 as finished by appending "$2-done" to marker file $1.
#######################################
mark_task_completed() {
    printf '%s-done\n' "$2" >> "$1"
}

#######################################
# Run one wts.py stage.
# Arguments: $1 - phase number, $2 - task, $3 - run id (used only in the
#            failure message); remaining arguments go verbatim to wts.py.
# Returns:   0 on success; 1 (after printing a failure message) otherwise.
#######################################
run_wts_stage() {
    local phase=$1 task=$2 run_id=$3
    shift 3
    if ! python wts.py "$@"; then
        echo "Phase ${phase} of task ${task} failed for run ${run_id}"
        return 1
    fi
}

#######################################
# Weak fine-tuning, then strong fine-tuning, for one task.
# Globals:   weak_model, strong_model, num_epochs, glue_ds (read)
# Arguments: $1 phase, $2 task, $3 run id, $4 results file,
#            $5 --lambda_coeff value, $6 --validate_on value
# Returns:   non-zero as soon as a stage fails
#######################################
run_baseline_stages() {
    local phase=$1 task=$2 run_id=$3 results=$4 lambda=$5 validate=$6

    # Weak fine-tuning
    run_wts_stage "$phase" "$task" "$run_id" \
        --model_id "$weak_model" \
        --results_file "$results" \
        --task "$task" \
        --ft_mode weak \
        --num_epochs "$num_epochs" \
        --lambda_coeff "$lambda" \
        --validate_on "$validate" \
        --tag "-${run_id}" \
        --glue_ds "$glue_ds" || return 1

    # Strong fine-tuning
    run_wts_stage "$phase" "$task" "$run_id" \
        --model_id "$strong_model" \
        --results_file "$results" \
        --task "$task" \
        --ft_mode strong \
        --num_epochs "$num_epochs" \
        --lambda_coeff "$lambda" \
        --validate_on "$validate" \
        --tag "-${run_id}" \
        --glue_ds "$glue_ds" || return 1
}

#######################################
# Naive WTS fine-tuning, then WTS fine-tuning with auxiliary loss, of the
# strong model for one task.
# Globals:   weak_model, strong_model, num_epochs, glue_ds (read)
# Arguments: $1 phase, $2 task, $3 run id, $4 results file,
#            $5 --lambda_coeff value, $6 --validate_on value,
#            $7 --alpha_max value (aux-loss stage only)
# Returns:   non-zero as soon as a stage fails
#######################################
run_transfer_stages() {
    local phase=$1 task=$2 run_id=$3 results=$4 lambda=$5 validate=$6 alpha=$7
    # NOTE(review): every phase passes this same weak-labels path for a given
    # run. Presumably wts.py writes it in some mode; confirm, because phase
    # order would then decide which weak model's labels are consumed.
    local weak_labels="weak_labels/${glue_ds}/${weak_model}-${run_id}.json"

    # Naive WTS fine-tuning
    run_wts_stage "$phase" "$task" "$run_id" \
        --model_id "$strong_model" \
        --results_file "$results" \
        --task "$task" \
        --ft_mode wts-naive \
        --weak_labels_file "$weak_labels" \
        --num_epochs "$num_epochs" \
        --lambda_coeff "$lambda" \
        --validate_on "$validate" \
        --tag "-${run_id}" \
        --glue_ds "$glue_ds" || return 1

    # WTS fine-tuning with auxiliary loss
    run_wts_stage "$phase" "$task" "$run_id" \
        --model_id "$strong_model" \
        --results_file "$results" \
        --task "$task" \
        --ft_mode wts-aux-loss \
        --weak_labels_file "$weak_labels" \
        --num_epochs "$num_epochs" \
        --alpha_max "$alpha" \
        --lambda_coeff "$lambda" \
        --validate_on "$validate" \
        --tag "-${run_id}" \
        --glue_ds "$glue_ds" || return 1
}

# Add more run ids here to repeat the whole sweep.
for run in 1; do
    results_name="pythia-${weak_model_size_str}-${strong_model_size_str}-ep${num_epochs}-${run}.json"

    for alpha_max in "${alpha_max_values[@]}"; do
        alpha_max_str=${alpha_max//./p}

        # Phase 2
        echo "Phase 2"
        results_file="results/${glue_ds}/hyperparam_wts/phase-2/alpha-${alpha_max_str}/${results_name}"
        # If the marker cannot be created, skip this alpha (phase 3 included).
        completion_file=$(prepare_completion_file "$results_file") || continue

        for task in "${tasks[@]}"; do
            if task_completed "$completion_file" "$task"; then
                echo "Phase 2 of task ${task} already completed for run ${run}"
                continue
            fi
            run_baseline_stages 2 "$task" "$run" "$results_file" "$lambda_coeff" adversarial &&
                run_transfer_stages 2 "$task" "$run" "$results_file" 0 original "$alpha_max" &&
                mark_task_completed "$completion_file" "$task"
        done

        # Phase 3
        echo "Phase 3"
        results_file_p3="results/${glue_ds}/hyperparam_wts/phase-3/alpha-${alpha_max_str}/${results_name}"
        completion_file=$(prepare_completion_file "$results_file_p3") || continue

        # Filter weak and strong performances from phase 2 to phase 3.
        # NOTE(review): this runs on every invocation, even when every phase-3
        # task is already marked done. Confirm filter.py does not discard
        # phase-3 results already present in $results_file_p3.
        if ! python filter.py "$results_file" "$results_file_p3"; then
            # Without the filtered baselines, phase-3 tasks would be marked
            # done on incomplete results; skip them so a rerun retries.
            echo "Phase 3 filtering failed for run ${run}; skipping phase 3 for alpha_max ${alpha_max}"
            continue
        fi

        for task in "${tasks[@]}"; do
            if task_completed "$completion_file" "$task"; then
                echo "Phase 3 of task ${task} already completed for run ${run}"
                continue
            fi
            run_transfer_stages 3 "$task" "$run" "$results_file_p3" "$lambda_coeff" adversarial "$alpha_max" &&
                mark_task_completed "$completion_file" "$task"
        done
    done

    # NOTE(review): phase 1 deliberately(?) runs after phases 2 and 3, as in
    # the original ordering.
    for alpha_max in "${alpha_max_values[@]}"; do
        alpha_max_str=${alpha_max//./p}

        # Phase 1
        echo "Phase 1"
        results_file="results/${glue_ds}/hyperparam_wts/phase-1/alpha-${alpha_max_str}/${results_name}"
        completion_file=$(prepare_completion_file "$results_file") || continue

        for task in "${tasks[@]}"; do
            if task_completed "$completion_file" "$task"; then
                echo "Phase 1 of task ${task} already completed for run ${run}"
                continue
            fi
            run_baseline_stages 1 "$task" "$run" "$results_file" 0 original &&
                run_transfer_stages 1 "$task" "$run" "$results_file" 0 original "$alpha_max" &&
                mark_task_completed "$completion_file" "$task"
        done
    done
done
