#!/usr/bin/env bash
#
# Fine-tune a model on the 10k-example "warmup" training set
# (${BASE_DATA_PATH}/warmup_10k.jsonl), then run `run_eval.sh --all-evals` on
# the resulting checkpoint.
#
# Usage:
#   bash <this-script> [KEY=VALUE ...]
#
# Each KEY=VALUE argument is exported as an environment variable after the
# defaults below are set, so those defaults can be overridden, e.g.
#   bash <this-script> MODEL_NAME=qwen2_5-3b SEED=2
#
# Strict mode: abort on the first command that exits non-zero (so evaluation
# is not started after a failed training command), on use of an unset
# variable, and on a failure anywhere in a pipeline.
set -euo pipefail

# Random seed forwarded to the trainer.
export SEED=1

export BASE_DATA_PATH=./data/training_data
export BASE_CKPT_PATH=./checkpoints
# Short model alias; mapped to a full model ID (e.g. meta-llama/Llama-2-7b-hf)
# before training.
export MODEL_NAME=llama2-7b

# Export every KEY=VALUE command-line argument as an environment variable,
# overriding the defaults set above. The key is everything before the first
# '=' and the value is everything after it, so values may contain '='
# (A=b=c sets A to "b=c") and may be empty (A= sets A to "").
# Exits with status 1 on an argument with no '=' or whose key is not a valid
# shell identifier, instead of silently ignoring a mistyped override.
export_kv_args() {
   local arg key
   for arg in "$@"; do
      if [[ "$arg" != *=* ]]; then
         printf 'Invalid argument (expected KEY=VALUE): %s\n' "$arg" >&2
         exit 1
      fi
      # Parameter expansion rather than `echo $arg | cut`: user text is never
      # word-split, globbed, or parsed as echo options.
      key=${arg%%=*}
      if [[ ! "$key" =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
         printf 'Invalid variable name in argument: %s\n' "$arg" >&2
         exit 1
      fi
      export "$key=${arg#*=}"
   done
}

export_kv_args "$@"

# Names derived from the (possibly overridden) settings above.
# RUN_NAME is built from the short MODEL_NAME alias, because the alias has not
# yet been replaced by the full model ID at this point. It also gets a
# $RANDOM suffix, so repeated launches very likely get separate checkpoint
# directories.
export TRAIN_FILE="${BASE_DATA_PATH}/warmup_10k.jsonl"
export RUN_NAME="${MODEL_NAME}_warmup_10k_${RANDOM}"
export OUTPUT_PATH="${BASE_CKPT_PATH}/${RUN_NAME}"

# Replace the short MODEL_NAME alias with the full model ID passed to the
# trainer. An unknown alias aborts the run before any training starts.
case "$MODEL_NAME" in
    llama2-7b)    MODEL_NAME=meta-llama/Llama-2-7b-hf ;;
    llama3-1b)    MODEL_NAME=meta-llama/Llama-3.2-1B ;;
    llama3-3b)    MODEL_NAME=meta-llama/Llama-3.2-3B ;;
    qwen2_5-1_5b) MODEL_NAME=Qwen/Qwen2.5-1.5B ;;
    qwen2_5-3b)   MODEL_NAME=Qwen/Qwen2.5-3B ;;
    *)
        # The quoted case selector lets an empty or space-containing value
        # land here cleanly, instead of tripping `[` with a syntax error.
        printf 'Invalid model name: %s\n' "$MODEL_NAME" >&2
        printf 'Valid names: llama2-7b llama3-1b llama3-3b qwen2_5-1_5b qwen2_5-3b\n' >&2
        exit 1
        ;;
esac

# Fine-tune MODEL_NAME on TRAIN_FILE, writing to OUTPUT_PATH, then evaluate
# that checkpoint with run_eval.sh. The trainer arguments live in an array so
# every expansion stays a single word, even if a path contains spaces.
# NOTE(review): --is_llama=True is passed for every model, including the
# Qwen aliases. Confirm this is intended for non-Llama models.
train_args=(
    --model_name "$MODEL_NAME"
    --output_dir "$OUTPUT_PATH"
    --per_device_train_batch_size 1
    --gradient_accumulation_steps 128
    --num_train_epochs 2
    --learning_rate 2e-5
    --seed "$SEED"
    --warmup_ratio 0.03
    --lr_scheduler_type linear
    --weight_decay 0.
    --evaluation_strategy no
    --save_strategy no
    --logging_steps 1
    --is_llama=True
    --use_hf_auth_token True
    --train_dataset "$TRAIN_FILE"
)

# Do not evaluate a checkpoint whose training command failed. Propagate the
# trainer's exit status instead.
python -m minimal_multitask.instruction_tune "${train_args[@]}" || {
    train_status=$?
    printf 'Training failed (exit %s); skipping evaluation of %s\n' \
        "$train_status" "$OUTPUT_PATH" >&2
    exit "$train_status"
}

bash run_eval.sh "$OUTPUT_PATH" --all-evals