#!/bin/bash

# Main settings with default values. Every setting can be overridden from the
# environment, e.g. `TASK=MNLI K=64 SEED=42 bash <this-script>`.
TASK=${TASK:-"SST-2"}           # see all the options in the "cases" below
SEED=${SEED:-13}                # random seed and also data seed, by default the data split seeds are {13, 21, 42, 87, 100}
K=${K:-16}                      # choose from {16, 64, 512} by default
MODEL=${MODEL:-"facebook/opt-125m"}  # model name or path, passed to run.py as --model_name_or_path
TYPE=${TYPE:-"prompt"}          # fine-tuning setting, choose from "finetune" and "prompt"
TRAINER=${TRAINER:-"standard"}  # choose from "standard", "kernel" and "linearhead"
TAG=${TAG:-}                    # set a tag to distinguish and aggregate runs in the log
NUM_GPU=${NUM_GPU:-1}           # by default use 1 GPU, set to 0 for CPU-only training
OPT=${OPT:-"adam"}              # passed to run.py as --optimizer
FORMULA=${FORMULA:-"signgd"}    # passed to run.py as --kernel_formula
STEPS=${STEPS:-32}              # passed to run.py as --max_steps

# Path-safe model name used in result directories: every "/" becomes "-"
# (e.g. "facebook/opt-125m" -> "facebook-opt-125m").
MODELNAME=${MODEL//\//-}

# Per-task prompt configuration: TEMPLATE, MAPPING (label -> label word),
# SFC_PROMPT and optional extra run.py flags in TASK_EXTRA.
TASK_EXTRA=""
case "$TASK" in
    SST-2)
        TEMPLATE="*bos**sent_0*_It_was"
        MAPPING="{'0':'terrible','1':'great'}"
        SFC_PROMPT="_It_was"
        ;;
    QQP)
        TEMPLATE="*bos**sent_0*.*+sentu_1*"
        MAPPING="{'0':'No','1':'Yes'}"
        SFC_PROMPT="."
        ;;
    QNLI)
        TEMPLATE="*bos**sent-_0*.*+sentu-_1*?"
        MAPPING="{'not_entailment':'No','entailment':'Yes'}"
        SFC_PROMPT="?"
        ;;
    MNLI)
        TEMPLATE="*bos**sent-_0*.*+sentu-_1*?"
        MAPPING="{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}"
        TASK_EXTRA="--max_seq_len 256 --first_sent_limit 240"
        SFC_PROMPT="?"
        ;;
    SNLI)
        TEMPLATE="*bos**sent-_0*.*+sentu-_1*?"
        MAPPING="{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}"
        TASK_EXTRA="--max_seq_len 256 --num_sample 4"
        SFC_PROMPT="?"
        ;;
    trec)
        TEMPLATE="*bos**+sent_0*:"
        MAPPING="{0:'Description',1:'Entity',2:'Expression',3:'Human',4:'Location',5:'Number'}"
        TASK_EXTRA="--first_sent_limit 110"
        SFC_PROMPT=":"
        ;;
    mr)
        TEMPLATE="*bos**sent_0*_It_was"
        MAPPING="{0:'terrible',1:'great'}"
        TASK_EXTRA="--first_sent_limit 110 --other_sent_limit 50"
        SFC_PROMPT="_It_was"
        ;;
    cr)
        TEMPLATE="*bos**sent_0*_It_was"
        MAPPING="{0:'terrible',1:'great'}"
        TASK_EXTRA="--first_sent_limit 110 --other_sent_limit 50"
        SFC_PROMPT="_It_was"
        ;;
    mpqa)
        TEMPLATE="*bos**sent_0*_It_was"
        MAPPING="{0:'terrible',1:'great'}"
        TASK_EXTRA="--first_sent_limit 110"
        SFC_PROMPT="_It_was"
        ;;
    CoLA)
        TEMPLATE="*bos**sent_0*_This_is"
        MAPPING="{'0':'incorrect','1':'correct'}"
        SFC_PROMPT="_This_is"
        ;;
    subj)
        TEMPLATE="*bos**sent_0*_This_is"
        MAPPING="{0:'subjective',1:'objective'}"
        TASK_EXTRA="--first_sent_limit 110 --other_sent_limit 50"
        SFC_PROMPT="_This_is"
        ;;
    MRPC)
        TEMPLATE="*bos**sent-_0*.*+sentu-_1*?"
        MAPPING="{'0':'No','1':'Yes'}"
        SFC_PROMPT="?"
        ;;
    RTE)
        TEMPLATE="*bos**sent-_0*.*+sentu-_1*?"
        MAPPING="{'not_entailment':'No','entailment':'Yes'}"
        TASK_EXTRA="--max_seq_len 256 --first_sent_limit 240"
        SFC_PROMPT="?"
        ;;
    *)
        # Unknown task: none of the arms above set TEMPLATE, MAPPING or SFC_PROMPT,
        # so they may only come from the environment. Without them run.py would get
        # a malformed "--template --mapping ... --sfc_prompt ..." list; fail early.
        if [[ -z "${TEMPLATE:-}" || -z "${MAPPING:-}" || -z "${SFC_PROMPT:-}" ]]; then
            printf 'Unknown TASK "%s": add a case for it or set TEMPLATE, MAPPING and SFC_PROMPT in the environment\n' "$TASK" >&2
            exit 1
        fi
        ;;
esac

if [[ -n "$LOAD_KERNELS_TAG" ]]; then
    # Load pre-computed kernels from an existing result directory. The path uses
    # MODELNAME ("/" replaced by "-"), matching the result/$TASK-$MODELNAME-...
    # layout used for --output_dir; the raw MODEL id (e.g. "facebook/opt-125m")
    # would point into a nested directory this script never creates.
    LOAD_KERNELS="--load_kernels result/$TASK-$MODELNAME-$TYPE-$TRAINER-$LOAD_KERNELS_TAG/$K-$SEED"
fi

# Arguments for run.py, kept as ONE whitespace-separated string: the launch
# commands expand it unquoted so it word-splits into separate argv entries.
# Consequently no value in here (including the script's own extra arguments,
# appended via $*) may contain whitespace.
# --tag is emitted only when TAG is non-empty: a bare "--tag" followed by the
# next flag is rejected by a standard argparse option expecting one value.
ALL_ARGS_TOGETHER="
    --model_name_or_path $MODEL --few_shot_type $TYPE
    --task_name $TASK --template $TEMPLATE --mapping $MAPPING
    --sfc_prompt $SFC_PROMPT
    --data_dir data/k-shot-1k-test/$TASK/$K-$SEED
    --overwrite_output_dir --output_dir result/$TASK-$MODELNAME-$TYPE-$TRAINER-$TAG$GRID_TAG/$K-$SEED
    --num_k $K
    ${TAG:+--tag $TAG}
    --max_seq_length 128
    --seed $SEED
    --do_eval --do_predict
    --trainer $TRAINER
    --optimizer $OPT --kernel_formula $FORMULA
    --max_steps $STEPS
    $TASK_EXTRA
    $LOAD_KERNELS
    $*
"
# Create the per-configuration result directory (--output_dir points at its
# $K-$SEED subdirectory). -p also creates result/ when it is missing and does not
# error on reruns where the directory already exists.
mkdir -p -- "result/$TASK-$MODELNAME-$TYPE-$TRAINER-$TAG$GRID_TAG/"

# ALL_ARGS_TOGETHER is expanded unquoted on purpose (it must word-split), but
# pathname expansion must not apply to it: templates contain '*' and several
# SFC prompts are a bare '?', which would glob-match files in the current
# directory. Disable globbing for the launch only.
set -f
if (( NUM_GPU > 1 )); then
    # Numeric comparison; [[ a > b ]] would compare the strings lexicographically.
    # Randomly set a port number
    # If you encounter "address already used" error, just run again or manually set an available port id.
    PORT_ID=$(( RANDOM + 1000 ))

    # Allow multiple threads
    export OMP_NUM_THREADS=8

    python -m torch.distributed.launch --nproc_per_node "$NUM_GPU" --master_port "$PORT_ID" run.py \
        $ALL_ARGS_TOGETHER
else
    python run.py \
        $ALL_ARGS_TOGETHER
fi
set +f

# Remove this run's output directory (the --output_dir given to run.py) after the
# run. NOTE(review): this deletes everything run.py wrote there; presumably the
# results of interest are recorded elsewhere (e.g. logs) — confirm.
# The path is quoted so TAG/GRID_TAG values with spaces or glob characters
# (e.g. TAG='*') cannot widen the deletion to other runs' directories.
rm -rf -- "result/$TASK-$MODELNAME-$TYPE-$TRAINER-$TAG$GRID_TAG/$K-$SEED"
