#!/bin/bash
# Path and environment setup for the LLaMA-Factory training demo.
#
# The project root is taken from the caller's current working directory, so
# the script has to be launched from the directory that contains
# LLaMA-Factory/, demo.yaml, models/ and saves/.
BASE_DIR=$(pwd)

# Stop immediately if the checkout is missing: continuing from the wrong
# working directory would start training with unresolved relative paths.
if ! cd "${BASE_DIR}/LLaMA-Factory"; then
    echo "Error: cannot enter '${BASE_DIR}/LLaMA-Factory'; run this script from the project root." >&2
    exit 1
fi

BASE="${BASE_DIR}"
TEMPLATE_YAML="${BASE}/demo.yaml"
DATA_BASE="${BASE_DIR}"

# Environment variables exported for the training process started below.
export FORCE_TORCHRUN=1
export PEFT_DISABLE_HUB_DOWNLOAD=1
export HF_HUB_DISABLE_TELEMETRY=1
export HF_HUB_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
export WANDB_SILENT=true
export TRANSFORMERS_VERBOSITY=error
export TOKENIZERS_PARALLELISM=false
# Project root and LLaMA-Factory sources are placed ahead of any existing entries.
export PYTHONPATH="${BASE_DIR}:${BASE_DIR}/LLaMA-Factory/src:$PYTHONPATH"

# NOTE(review): these three values are not read anywhere else in this script;
# the effective defaults are the ones applied to the positional arguments.
SET_BATCH_SIZE=8
EPOCHS=3
LEARNING_RATE=1e-4


# Where base models are read from and where training outputs are written.
MODELS_DIR="${DATA_BASE}/models"
SAVE_DIR="${DATA_BASE}/saves"

# GPU configuration
# Count the devices listed by nvidia-smi and expose the largest group of
# 8, 4, 2 or 1 GPUs that fits, always numbering from device 0.
GPU_COUNT=$(nvidia-smi --list-gpus | wc -l)

if (( GPU_COUNT >= 8 )); then
    visible_gpus=0,1,2,3,4,5,6,7
elif (( GPU_COUNT >= 4 )); then
    visible_gpus=0,1,2,3
elif (( GPU_COUNT >= 2 )); then
    visible_gpus=0,1
else
    # Fewer than two devices reported (including none): fall back to device 0.
    visible_gpus=0
fi

export CUDA_VISIBLE_DEVICES="${visible_gpus}"

# Usage:
#   demo.sh [DATA_METHOD] [TOKEN_METHOD] [DATA_RATIO] [TOKEN_RATIO] [MODEL] [BATCH_SIZE] [EPOCHS] [LR] [PLUG]
# Every positional argument is optional; an omitted or empty argument falls
# back to the default shown on its line.
DATA_METHOD=${1:-wise}
TOKEN_METHOD=${2:-wise}
DATA_RATIO=${3:-25}
TOKEN_RATIO=${4:-70}
MODEL=${5:-mistral-7b}
BATCH_SIZE=${6:-8}
EPOCH=${7:-3}
LR=${8:-1e-4}
PLUG=${9:-wisely}

# The dataset is selected through the environment rather than an argument.
DATASET=${DATASET:-"wizard"} #or MathInstruct

# Succeeds when $1 is a plain non-negative decimal number (25, 2.5, .5).
is_decimal() {
    local decimal_re='^([0-9]+([.][0-9]*)?|[.][0-9]+)$'
    [[ "$1" =~ $decimal_re ]]
}

# Both ratios are percentages that are divided by 100 when the config is
# generated; reject anything non-numeric here instead of silently writing a
# ratio of 0 (or feeding arbitrary text to awk).
for ratio_name in DATA_RATIO TOKEN_RATIO; do
    if ! is_decimal "${!ratio_name}"; then
        echo "Error: ${ratio_name} must be a non-negative number (a percentage), got '${!ratio_name}'." >&2
        exit 1
    fi
done

# Runs that use the "wisely" plug get a distinguishing suffix in their names.
PLUG_SUFFIX=""
if [[ "${PLUG}" == "wisely" ]]; then
    PLUG_SUFFIX="_wisely"
fi

NAME_SUFFIX="${DATA_METHOD}_${TOKEN_METHOD}_${LR}${PLUG_SUFFIX}"
OUTPUT_DIR="${SAVE_DIR}/${DATASET}/${MODEL}/data_ratio_${DATA_RATIO}/token_ratio_${TOKEN_RATIO}/${NAME_SUFFIX}"

# Build a throw-away training config from the template and the run settings.
# The file name carries every run parameter plus this shell's PID so that
# concurrent launches (for example a sweep over data/token methods) never
# write to the same file. The .yaml extension is kept.
YAML_TMP="${BASE}/temp_${DATASET}_${MODEL}_${DATA_METHOD}_${TOKEN_METHOD}_${DATA_RATIO}_${TOKEN_RATIO}_${BATCH_SIZE}_${EPOCH}_${LR}${PLUG_SUFFIX}_$$.yaml"

# Remove the generated config on every exit path, including the failures below.
trap 'rm -f -- "${YAML_TMP}"' EXIT

if [[ ! -r "${TEMPLATE_YAML}" ]]; then
    echo "Error: template config not found or unreadable: ${TEMPLATE_YAML}" >&2
    exit 1
fi

# Print a sed expression that replaces the value of YAML key $1 with $2.
# The key must be the first thing on its line (indentation is preserved), so
# longer keys that merely end in the same text are left untouched. Characters
# that are special in a sed replacement (\, & and the | delimiter) are escaped.
yaml_override() {
    local key=$1 value
    value=$(printf '%s' "$2" | sed -e 's/[\\&|]/\\&/g')
    printf '%s' "s|^\\([[:space:]]*\\)${key}: .*|\\1${key}: ${value}|"
}

# Percentages -> fractions. The values are handed to awk with -v so they are
# treated as data rather than spliced into the awk program text.
data_fraction=$(awk -v pct="${DATA_RATIO}" 'BEGIN { print pct / 100 }')
token_fraction=$(awk -v pct="${TOKEN_RATIO}" 'BEGIN { print pct / 100 }')

# Set model template
case "${MODEL}" in
    mistral-*) TEMPLATE=mistral ;;
    *) TEMPLATE=default ;;
esac

# Apply all overrides in a single pass; an existing top-level "template:" line
# is rewritten here as well, which avoids the non-portable in-place "sed -i".
if ! sed \
    -e "$(yaml_override dataset "${DATASET}")" \
    -e "$(yaml_override model_name_or_path "${MODELS_DIR}/${MODEL}")" \
    -e "$(yaml_override output_dir "${OUTPUT_DIR}")" \
    -e "$(yaml_override per_device_train_batch_size "${BATCH_SIZE}")" \
    -e "$(yaml_override num_train_epochs "${EPOCH}")" \
    -e "$(yaml_override learning_rate "${LR}")" \
    -e "$(yaml_override data_method "${DATA_METHOD}")" \
    -e "$(yaml_override data_ratio "${data_fraction}")" \
    -e "$(yaml_override token_method "${TOKEN_METHOD}")" \
    -e "$(yaml_override token_ratio "${token_fraction}")" \
    -e "s|^template: .*|template: ${TEMPLATE}|" \
    "${TEMPLATE_YAML}" > "${YAML_TMP}"; then
    echo "Error: failed to generate ${YAML_TMP} from ${TEMPLATE_YAML}" >&2
    exit 1
fi

# The template may not define a chat template at all; add one in that case.
if ! grep -q "^template:" "${YAML_TMP}"; then
    echo "template: ${TEMPLATE}" >> "${YAML_TMP}"
fi

# Append the plug settings; anything other than "wisely" disables the plug.
case "${PLUG}" in
    wisely)
        printf '%s\n' \
            "plug: wisely" \
            "wise_lambda: 0.5" \
            "wise_neighbor_window: 1" >> "${YAML_TMP}"
        ;;
    *)
        printf '%s\n' "plug: none" >> "${YAML_TMP}"
        ;;
esac

# Launch training, then always delete the generated config. The trainer's exit
# status is captured first and re-used as the script's own status; otherwise
# the trailing rm would make every run look successful to callers.
train_status=0
llamafactory-cli train "${YAML_TMP}" || train_status=$?

rm -f -- "${YAML_TMP}"
exit "${train_status}"
