#!/usr/bin/env bash


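# Fine-tunes a checkpoint on one GLUE task by invoking train.py.
# args: task prefix dir ckp lr seed update_tnf tnf_lambda tnf_gamma tnf_bp arch root
#   $6  random seed, or "ALL" to run seeds 1 through 5
#   $7  "update-tnf" to add the TNF-embedding update flags; any other value skips them
#   $8  TNF lambda (empty -> 0.5)
#   $9  TNF gamma  (empty -> 0.1)
#   $10 "yes" to add --glue-tnf-bp True
#   $11 value for --arch ("tnf_large" selects the large-model batch settings)
#   $12 output root directory, relative to the prefix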

# Positional arguments. Arguments 7 and 10-12 are not captured here; they are
# read directly where they are used further down in the script.
TASK="${1}"         # name of the GLUE task
PREFIX="${2}"       # prefix prepended to every path
DIR="${3}"          # data directory, relative to the prefix
CKP="${4}"          # checkpoint file of the model
LR="${5}"           # learning rate
SEED="${6}"         # random seed
TNF_LAMBDA="${8}"   # lambda value for TNF
TNF_GAMMA="${9}"    # gamma value for TNF

# Fall back to the built-in values when a TNF hyper-parameter is empty,
# announcing each fallback on stdout.
if [[ -z "$TNF_LAMBDA" ]]; then
    echo "defaultly set tnf lambda to 0.5"
    TNF_LAMBDA=0.5
fi

if [[ -z "$TNF_GAMMA" ]]; then
    echo "defaultly set tnf gamma to 0.1"
    TNF_GAMMA=0.1
fi

# Extra train.py flags that switch on TNF-embedding updates. They are added
# only when the seventh argument is exactly "update-tnf"; otherwise the
# variable stays empty.
UPDATE_TNF=""
case "$7" in
    update-tnf)
        UPDATE_TNF="--update-tnf-emb default --ctx windowavg --ctx-window-size 8"
        echo "update tnf: $UPDATE_TNF "
        ;;
    *)
        echo "donot update tnf: $UPDATE_TNF "
        ;;
esac

# Optional "--glue-tnf-bp True" flag, enabled only when the tenth argument is
# exactly "yes". Any other value is echoed back and the flag is left empty.
TNF_BP=""
case "${10}" in
    yes)
        TNF_BP="--glue-tnf-bp True"
        ;;
    *)
        # The argument is deliberately expanded the same way as before so the
        # printed message does not change.
        echo "donot bp tnf: "${10}" "
        ;;
esac

# Add --rel-pos to the train.py command line whenever the data directory name
# contains the substring "rel-pos".
REL_POS=""
case "$DIR" in
    *rel-pos*)
        echo "use --rel-pos"
        REL_POS=--rel-pos
        ;;
esac


# Model architecture (eleventh argument) and the hyper-parameters tied to it.
ARCH=${11}

# Settings shared by every architecture.
N_EPOCH=10
WEIGHT_DECAY=0.1
WARMUP_RATIO=0.06
# Passed to --validate-interval-updates further down.
# NOTE(review): the previous comment read "valid 20 times per epoch", but the
# flag name suggests an interval counted in updates -- confirm which is meant.
VALID_FREQ=20

# Per-architecture batch settings: "tnf_large" halves the sentences and tokens
# per GPU and doubles the update frequency; everything else uses base values.
case "$ARCH" in
    tnf_large)
        echo "run tnf large"
        SENT_PER_GPU=8
        UPDATE_FREQ=4
        MAX_TOKENS=1100
        ;;
    *)
        echo "run tnf base"
        SENT_PER_GPU=16
        UPDATE_FREQ=2
        MAX_TOKENS=2200
        ;;
esac

# Checkpoint to fine-tune from; it is later handed to train.py through
# --restore-file.
BERT_MODEL_PATH="$PREFIX/$DIR/$CKP"

# The path is quoted so the existence test still works when it contains
# whitespace. Unquoted, `[` received several words, failed with
# "too many arguments", and the missing-checkpoint guard was silently skipped.
if [[ ! -e "$BERT_MODEL_PATH" ]]; then
    echo "Checkpoint $BERT_MODEL_PATH doesn't exist"
    # NOTE(review): this exits with status 0, so a caller cannot tell a missing
    # checkpoint from a successful run. Kept as-is for compatibility with
    # existing callers -- confirm whether a non-zero status is wanted.
    exit 0
fi

# Output root (twelfth argument) and the locations of the binarized task data.
ROOT=${12}
GLUE_DIR=glue-data
DATA_DIR="$PREFIX/$GLUE_DIR/$TASK/data-bin"
TNF_DATA_DIR="$PREFIX/$GLUE_DIR/$TASK/tnf-data-bin"

# Defaults: two-class classification scored by accuracy, no extra options.
OPTION=""
METRIC=accuracy
N_CLASSES=2

# Task-specific overrides. Tasks not listed here keep the defaults above.
case "$TASK" in
    MNLI)
        # Three classes, validated on two subsets.
        N_EPOCH=10
        N_CLASSES=3
        OPTION="--valid-subset valid,valid1"
        ;;
    QQP)
        N_EPOCH=10
        ;;
    CoLA)
        METRIC=mcc
        ;;
    STS-B)
        # Regression task: a single output scored by Pearson/Spearman.
        METRIC=pearson_spearman
        N_CLASSES=1
        OPTION="--regression-target"
        ;;
esac

echo $DATA_DIR

# Suffixes appended to the data-dir component of the output path, so runs with
# different TNF settings land in different directories. A suffix is dropped
# when its setting is off ("0" for the numeric ones, empty for the flags).

# "-update-tnf" only when the seventh argument asked for TNF updates.
if [[ "$7" == "update-tnf" ]]; then
    UB_SUFFIX="-update-tnf"
else
    UB_SUFFIX=""
fi

# Note the naming: the "-bs" suffix carries TNF_LAMBDA ...
if [[ "$TNF_LAMBDA" == "0" ]]; then
    BS_SUFFIX=""
else
    BS_SUFFIX="-bs${TNF_LAMBDA}"
fi

# ... and the "-bl" suffix carries TNF_GAMMA.
if [[ "$TNF_GAMMA" == "0" ]]; then
    LAMBDA_SUFFIX=""
else
    LAMBDA_SUFFIX="-bl${TNF_GAMMA}"
fi

# "-bp" only when the TNF back-propagation flag is set.
if [[ -z "$TNF_BP" ]]; then
    TNF_BP_SUFFIX=""
else
    TNF_BP_SUFFIX="-bp"
fi

# ---------------------------------------------------------------------------
# Fine-tuning driver.
#
# Runs train.py once for $SEED, or once for each seed in 1..5 when SEED is
# "ALL". Every run writes to
#   $PREFIX/$ROOT/<DIR + suffixes>/$CKP/$TASK/$LR-<seed>
# and its stdout is copied to train_log.txt in that directory. A run whose
# train_log.txt already contains both 'done training' and 'loaded checkpoint'
# is skipped.
#
# The train.py invocation used to be spelled out twice (once per branch) and
# its exit status was hidden behind `| tee`. It now lives in one helper, and a
# failing run is reported on stderr and reflected in the script's status.
# ---------------------------------------------------------------------------

# Prints the output directory for the seed given as $1.
glue_run_dir() {
    printf '%s\n' "$PREFIX/$ROOT/${DIR}${UB_SUFFIX}${BS_SUFFIX}${LAMBDA_SUFFIX}${TNF_BP_SUFFIX}/${CKP}/${TASK}/$LR-$1"
}

# Succeeds when directory $1 holds a train_log.txt with both completion markers.
glue_run_finished() {
    local log_file="$1/train_log.txt"
    [[ -e "$log_file" ]] \
        && grep -q 'done training' "$log_file" \
        && grep -q 'loaded checkpoint' "$log_file"
}

# Runs train.py with seed $1, saving into directory $2 and copying stdout to
# $2/train_log.txt. Returns the exit status of train.py (not of tee).
glue_train() {
    local seed=$1 out_dir=$2 rc
    local -a tnf_bp_args update_tnf_args task_args rel_pos_args

    # These globals each hold zero or more space-separated flags. Split them
    # into words explicitly so every other expansion below can be quoted.
    read -r -a tnf_bp_args <<< "$TNF_BP"
    read -r -a update_tnf_args <<< "$UPDATE_TNF"
    read -r -a task_args <<< "$OPTION"
    read -r -a rel_pos_args <<< "$REL_POS"

    python train.py "$DATA_DIR" --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
        --tnf-data "$TNF_DATA_DIR" "${tnf_bp_args[@]}" \
        --tnf-lambda "$TNF_LAMBDA" --tnf-gamma "$TNF_GAMMA" --update-tnf-lambda 0 "${update_tnf_args[@]}" \
        --tnf-emb-zero-init 1 --fix-dict-shift True \
        --restore-file "$BERT_MODEL_PATH" \
        --max-positions 512 \
        --max-sentences "$SENT_PER_GPU" --update-freq "$UPDATE_FREQ" \
        --max-tokens "$MAX_TOKENS" \
        --task tnf_sentence_prediction \
        --reset-optimizer --reset-dataloader --reset-meters \
        --required-batch-size-multiple 1 \
        --init-token 0 --separator-token 2 \
        --arch "$ARCH" \
        --criterion sentence_prediction "${task_args[@]}" \
        --num-classes "$N_CLASSES" \
        --dropout 0.1 --attention-dropout 0.1 \
        --weight-decay "$WEIGHT_DECAY" --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
        --clip-norm 0.0 --validate-interval-updates "$VALID_FREQ" \
        --lr-scheduler polynomial_decay --lr "$LR" --warmup-ratio "$WARMUP_RATIO" \
        --max-epoch "$N_EPOCH" --seed "$seed" --save-dir "$out_dir" --no-progress-bar --log-interval 100 --no-epoch-checkpoints --no-last-checkpoints --no-best-checkpoints \
        --find-unused-parameters --skip-invalid-size-inputs-valid-test --truncate-sequence --embedding-normalize \
        --tensorboard-logdir . \
        --best-checkpoint-metric "$METRIC" --maximize-best-checkpoint-metric "${rel_pos_args[@]}" \
        | tee "$out_dir/train_log.txt"

    # The pipeline's own status is tee's; PIPESTATUS[0] is train.py's.
    rc=${PIPESTATUS[0]}
    if [[ "$rc" -ne 0 ]]; then
        echo "train.py exited with status $rc (seed $seed, output $out_dir)" >&2
    fi
    return "$rc"
}

# Prepares the output directory for seed $1 and trains unless the run is
# already complete. Sets the global OUTPUT_PATH. Returns 0 when the run was
# skipped or succeeded, otherwise train.py's exit status.
glue_finetune_seed() {
    OUTPUT_PATH=$(glue_run_dir "$1")
    mkdir -p -- "$OUTPUT_PATH"
    echo "$OUTPUT_PATH"
    if glue_run_finished "$OUTPUT_PATH"; then
        echo "Training log existed"
        return 0
    fi
    glue_train "$1" "$OUTPUT_PATH"
}

FAILED_RUNS=0
if [[ "$SEED" == "ALL" ]]; then
    # Keep going after a failed seed so the remaining seeds still get trained.
    for I_SEED in 1 2 3 4 5; do
        glue_finetune_seed "$I_SEED" || FAILED_RUNS=$((FAILED_RUNS + 1))
    done
else
    glue_finetune_seed "$SEED" || FAILED_RUNS=1
fi

# Last command of the script: its status becomes the script's exit status, so
# callers see a failure when any train.py run failed.
[[ "$FAILED_RUNS" -eq 0 ]]