#!/bin/bash
# Abort on the first failed command, use of an unset variable, or failed
# pipeline stage. Later steps consume files written by earlier ones
# (filter.py reads the phase-2 results file; the wts-* runs are pointed at a
# weak-labels file that is presumably produced by the weak runs), so carrying
# on after a failure would silently mix stale or missing data into results.
set -euo pipefail

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
readonly weak_model_size="14m"      # Pythia size for the --ft_mode weak runs
readonly strong_model_size="410m"   # Pythia size for the strong / wts-* runs
readonly num_epochs=1
readonly glue_ds="advgluepp"        # forwarded to wts.py as --glue_ds
readonly lambda_coeff=0.3           # --lambda_coeff for the lambda-enabled runs (phases 2 and 3)
readonly alpha_max=0.1              # --alpha_max for the wts-aux-loss runs

readonly weak_model="EleutherAI/pythia-${weak_model_size}"
readonly strong_model="EleutherAI/pythia-${strong_model_size}"

# Path-friendly variants used in results file names: every "." becomes "p"
# (e.g. 0.1 -> 0p1). Pure parameter expansion, no echo/sed subprocesses.
readonly weak_model_size_str=${weak_model_size//./p}
readonly strong_model_size_str=${strong_model_size//./p}
readonly alpha_max_str=${alpha_max//./p}
readonly lambda_coeff_str=${lambda_coeff//./p}

# Tasks run in every phase, in this order.
readonly -a tasks=("sst2" "qqp" "mnli" "mnli-mm" "qnli" "rte")

#######################################
# Launch a single wts.py fine-tuning job.
# Flags are emitted in a fixed order. --weak_labels_file is added only for
# the wts-* modes and --alpha_max only for wts-aux-loss, which is exactly
# the set of flags each mode received when every call was spelled out.
# Globals (read): glue_ds, weak_model, num_epochs, alpha_max
# Arguments:
#   $1 - value for --model_id (e.g. EleutherAI/pythia-410m)
#   $2 - value for --results_file
#   $3 - value for --task
#   $4 - value for --ft_mode: weak | strong | wts-naive | wts-aux-loss
#   $5 - value for --lambda_coeff
#   $6 - value for --validate_on: original | adversarial
#   $7 - run index; becomes the "-<run>" tag and part of the weak-labels path
# Returns: the exit status of wts.py
#######################################
run_wts() {
    local model_id=$1 results=$2 task=$3 ft_mode=$4
    local lambda=$5 validate_on=$6 run=$7
    local -a args=(
        --model_id "$model_id"
        --results_file "$results"
        --task "$task"
        --ft_mode "$ft_mode"
    )
    # wts-* modes are given the weak-labels file for this run.
    if [[ $ft_mode == wts-* ]]; then
        args+=(--weak_labels_file "weak_labels/${glue_ds}/${weak_model}-${run}.json")
    fi
    args+=(--num_epochs "$num_epochs")
    if [[ $ft_mode == wts-aux-loss ]]; then
        args+=(--alpha_max "$alpha_max")
    fi
    args+=(
        --lambda_coeff "$lambda"
        --validate_on "$validate_on"
        --tag "-${run}"
        --glue_ds "$glue_ds"
    )
    python wts.py "${args[@]}"
}

for run in 1; do
    # File name shared by every phase's results file for this run.
    results_name="pythia-${weak_model_size_str}-${strong_model_size_str}-ep${num_epochs}-${run}.json"

    # Phase 1: all four modes with --lambda_coeff 0, validated on the
    # original split.
    results_file="results/${glue_ds}/phase-1/alpha-${alpha_max_str}/${results_name}"
    echo "Phase 1"
    for task in "${tasks[@]}"; do
        run_wts "$weak_model"   "$results_file" "$task" weak         0 original "$run"
        run_wts "$strong_model" "$results_file" "$task" strong       0 original "$run"
        run_wts "$strong_model" "$results_file" "$task" wts-naive    0 original "$run"
        run_wts "$strong_model" "$results_file" "$task" wts-aux-loss 0 original "$run"
    done

    # Phase 2: weak and strong runs use lambda_coeff and adversarial
    # validation; the wts-* runs keep the phase-1 settings.
    echo "Phase 2"
    results_file="results/${glue_ds}/phase-2/lambda-${lambda_coeff_str}/alpha-${alpha_max_str}/${results_name}"
    for task in "${tasks[@]}"; do
        run_wts "$weak_model"   "$results_file" "$task" weak         "$lambda_coeff" adversarial "$run"
        run_wts "$strong_model" "$results_file" "$task" strong       "$lambda_coeff" adversarial "$run"
        run_wts "$strong_model" "$results_file" "$task" wts-naive    0 original "$run"
        run_wts "$strong_model" "$results_file" "$task" wts-aux-loss 0 original "$run"
    done

    # Phase 3: only the wts-* runs, now with lambda_coeff and adversarial
    # validation. Weak/strong numbers are carried over from phase 2.
    echo "Phase 3"
    results_file_p3="results/${glue_ds}/phase-3/lambda-${lambda_coeff_str}/alpha-${alpha_max_str}/${results_name}"

    # Filter weak and strong performances from phase 2 to phase 3
    python filter.py "$results_file" "$results_file_p3"

    for task in "${tasks[@]}"; do
        run_wts "$strong_model" "$results_file_p3" "$task" wts-naive    "$lambda_coeff" adversarial "$run"
        run_wts "$strong_model" "$results_file_p3" "$task" wts-aux-loss "$lambda_coeff" adversarial "$run"
    done
done
