#!/bin/bash
# Phase-1 weak-to-strong (WTS) sweep driver.
#
# For every run index and epoch budget, runs wts.py over each GLUE-style task
# in four stages (weak, strong, wts-naive, wts-aux-loss). Finished tasks are
# recorded in a per-configuration "*-completed.txt" file, so a rerun skips
# them.
#
# Deliberately no `set -e`: a failing wts.py stage is handled explicitly and
# only skips the current task instead of aborting the whole sweep.
set -uo pipefail

readonly weak_model_size="14m"
readonly strong_model_size="410m"
readonly glue_ds="advgluepp"
readonly alpha_max=0.1

readonly weak_model="EleutherAI/pythia-${weak_model_size}"
readonly strong_model="EleutherAI/pythia-${strong_model_size}"

# Filesystem-friendly variants: every '.' becomes 'p' (e.g. 0.1 -> 0p1).
# Builtin substitution avoids spawning echo|sed for a single value.
readonly weak_model_size_str=${weak_model_size//./p}
readonly strong_model_size_str=${strong_model_size//./p}
readonly alpha_max_str=${alpha_max//./p}

declare -ra tasks=("sst2" "qqp" "mnli" "mnli-mm" "qnli" "rte")

# Sweep driver: for every run index and epoch budget, run all phase-1 stages
# for each task, skipping tasks already recorded in the completion file.

# Directory holding every results/completion file of this phase. It must
# exist: otherwise `touch` fails, `grep` errors (and the task is treated as
# not done), and the completion append fails, silently disabling resume.
results_dir="results/${glue_ds}/phase-1/alpha-${alpha_max_str}"
if ! mkdir -p -- "$results_dir"; then
    echo "Cannot create results directory ${results_dir}" >&2
    exit 1
fi

# Weak-label file passed to both WTS stages; it varies only with the run.
# NOTE(review): weak_model contains a '/' (EleutherAI/pythia-...), so this
# resolves into an "EleutherAI" subdirectory — confirm it matches where
# wts.py writes/reads weak labels.
weak_labels_path() {
    printf '%s\n' "weak_labels/${glue_ds}/${weak_model}-${1}.json"
}

#######################################
# Run one wts.py fine-tuning stage for the current task/run/epoch setting.
# Globals:   task, run, num_epochs, results_file, glue_ds (read)
# Arguments: $1    - model id, passed as --model_id
#            $2    - fine-tuning mode, passed as --ft_mode
#            $3... - extra wts.py arguments specific to this stage
# Outputs:   a failure message on stderr if wts.py exits non-zero
# Returns:   0 if wts.py succeeded, 1 otherwise
#######################################
run_wts_stage() {
    local model_id=$1
    local ft_mode=$2
    shift 2
    if ! python wts.py \
        --model_id "$model_id" \
        --results_file "$results_file" \
        --task "$task" \
        --ft_mode "$ft_mode" \
        --num_epochs "$num_epochs" \
        --lambda_coeff 0 \
        --validate_on original \
        --tag "-${run}" \
        --glue_ds "$glue_ds" \
        "$@"; then
        echo "Phase 1 of task ${task} failed for run ${run} (ft_mode ${ft_mode})" >&2
        return 1
    fi
}

for run in 4 9 14; do
    weak_labels_file=$(weak_labels_path "$run")
    for num_epochs in 0.1 0.3 0.5 0.7 0.9; do
        num_epochs_str=${num_epochs//./p}
        results_file="${results_dir}/pythia-${weak_model_size_str}-${strong_model_size_str}-ep${num_epochs_str}-${run}.json"
        completion_file="${results_file%.json}-completed.txt"
        touch -- "$completion_file"

        # Phase 1
        echo "Phase 1"
        for task in "${tasks[@]}"; do
            # Skip tasks finished in an earlier invocation. -x/-F match the
            # whole line literally, so one task name cannot match another.
            if grep -qxF -- "${task}-done" "$completion_file"; then
                echo "Phase 1 of task ${task} already completed for run ${run}"
                continue
            fi

            # Stages run in order; the first failure skips this task
            # without marking it done, so a rerun retries it.
            run_wts_stage "$weak_model" weak || continue
            run_wts_stage "$strong_model" strong || continue
            run_wts_stage "$strong_model" wts-naive \
                --weak_labels_file "$weak_labels_file" || continue
            run_wts_stage "$strong_model" wts-aux-loss \
                --weak_labels_file "$weak_labels_file" \
                --alpha_max "$alpha_max" || continue

            # Mark task as completed
            echo "${task}-done" >> "$completion_file"
        done
    done
done
