#!/usr/bin/env bash
# Launcher for train_gen.py.
#
# Usage: bash <this-script> CUDA_ID TASK MODEL SEED
#   CUDA_ID  exported verbatim as CUDA_VISIBLE_DEVICES
#   TASK     task name; drives the hyper-parameter selection below and is
#            passed as --task_name and used in the data/output directories
#   MODEL    passed as --model_name_or_path and used in the output directory
#   SEED     used in the data directory (16-$SEED) and the output directory
#
# Bash is required: later sections use (( )) arithmetic tests, which plain
# POSIX sh does not provide.

# GPU selection and CUDA runtime settings for the training process.
CUDA_ID=$1
export CUDA_VISIBLE_DEVICES="$CUDA_ID"
# NOTE(review): per the CUDA documentation this makes kernel launches
# synchronous, which helps debugging but slows training; confirm it is wanted.
export CUDA_LAUNCH_BLOCKING=1

# Arguments
TASK=$2
MODEL=$3
SEED=$4

# Save and evaluation parameters
# SAVE > 0 keeps saving enabled; otherwise --no_save is added to the command.
SAVE=1

# Placeholder for task-specific extra arguments; nothing in the visible part
# of this script reads it.
TASK_EXTRA=""

#######################################
# Select the training hyper-parameters for a task/model pair.
# Globals (written only on success):
#   BSZ, LR, EPOCHS, TRAIN_MODE  - optimisation settings
#   WEIGHT_LR, META_LR           - meta weighting learning rates
#   EVAL                         - evaluation interval in steps (0 disables)
# Arguments:
#   $1 - task name
#   $2 - model name
# Returns:
#   0 if a configuration exists for the pair, 1 otherwise (globals untouched)
#######################################
select_hyperparameters() {
    local task=$1
    local model=$2

    case "$task" in
        CoLA | SST-2 | MRPC | QQP | MNLI | QNLI)
            # These six tasks share one configuration, defined for ctrl only.
            if [[ "$model" != ctrl ]]; then
                return 1
            fi

            # Hyper-parameters
            BSZ=2
            LR=5e-3
            EPOCHS=20
            TRAIN_MODE=prefix-infix

            # Meta weighting parameters
            WEIGHT_LR=5e-3
            META_LR=1e-2

            # Evaluation setting
            EVAL=1
            ;;
        *)
            # RTE and any other task name have no configuration defined.
            return 1
            ;;
    esac
}

if ! select_hyperparameters "$TASK" "$MODEL"; then
    echo "warning: no hyper-parameters are defined for task '${TASK}' with model '${MODEL}'; BSZ, LR, EPOCHS, TRAIN_MODE, WEIGHT_LR, META_LR and EVAL were not set by this script" >&2
fi

#######################################
# Translate the EVAL and SAVE switches into command-line flags for the
# training command.
# Globals:
#   EVAL (read)        - > 0 enables step-wise evaluation every EVAL steps
#   SAVE (read)        - > 0 keeps saving on; otherwise --no_save is passed
#   EVAL_EXTRA (written), SAVE_EXTRA (written) - resulting flag strings
# An unset or empty EVAL/SAVE is treated as 0, so the arithmetic test never
# sees an empty operand (which would be a syntax error).
#######################################
set_extra_flags() {
    if (( ${EVAL:-0} > 0 )); then
        EVAL_EXTRA="--evaluation_strategy steps --eval_steps $EVAL"
    else
        EVAL_EXTRA=""
    fi

    if (( ${SAVE:-0} > 0 )); then
        SAVE_EXTRA=""
    else
        SAVE_EXTRA="--no_save"
    fi
}

set_extra_flags

# Launch training. Refuse to start when any value that is interpolated into
# the command line is missing; otherwise options such as --weight_net_lr
# would be passed without a value.
for required_var in TASK MODEL SEED BSZ LR EPOCHS TRAIN_MODE WEIGHT_LR META_LR; do
    if [[ -z "${!required_var:-}" ]]; then
        echo "error: ${required_var} is not set; cannot launch train_gen.py (task '${TASK:-}', model '${MODEL:-}')" >&2
        exit 1
    fi
done

echo "Training ${MODEL}, with seed ${SEED} for ${TASK}"

# Build the argument list as an array so every value stays a single word.
train_args=(
    --task_name "$TASK"
    --data_dir "data/k-shot/$TASK/16-$SEED"
    --model_name_or_path "$MODEL"
    --max_seq_length 128
    --do_train
    --do_eval
    --logging_steps 5
    --weight_net_lr "$WEIGHT_LR"
    --meta_lr "$META_LR"
    --learning_rate "$LR"
    --per_device_train_batch_size "$BSZ"
    --train_mode "$TRAIN_MODE"
    --freeze_control_code
    --num_train_epochs "$EPOCHS"
    --output_dir "./train_gen_all_label_disc_final_${SEED}/${TASK}/${MODEL}-${TRAIN_MODE}-${LR}-${BSZ}-${EPOCHS}-meta-weight"
    --overwrite_output_dir
)

# SAVE_EXTRA and EVAL_EXTRA are flag strings that may hold zero or several
# words, so they are left unquoted on purpose to be word-split.
# shellcheck disable=SC2086
python train_gen.py "${train_args[@]}" $SAVE_EXTRA $EVAL_EXTRA
