#!/bin/bash

# Default settings. Every variable below is exported, including the ones
# that are intentionally left empty, so child processes see all of them.

# WANDB_* settings
WANDB_DISABLED=True
WANDB_ENTITY=
WANDB_PROJECT=
WANDB_RUN_GROUP=""
export WANDB_DISABLED WANDB_ENTITY WANDB_PROJECT WANDB_RUN_GROUP

# Model alias and precisions
MODEL_NAME=llama
MODEL_PRECISION=bfloat16
COMPUTE_PRECISION=amp_bf16
export MODEL_NAME MODEL_PRECISION COMPUTE_PRECISION

# Optimisation settings
NUM_EPOCHS=1
LR=6e-6 # 3e-5 for sql, 6e-6 for gsm8k and 4e-5 viggo (single epoch)
WEIGHT_DECAY=0
WARMUP=20
BS=32
PER_DEVICE_BS=1
SEED=42
export NUM_EPOCHS LR WEIGHT_DECAY WARMUP BS PER_DEVICE_BS SEED

# Dataset name plus optional subset / weights files (empty by default)
DATASET=gsm8k
SUBSET_PATH=
WEIGHTS_PATH=
export DATASET SUBSET_PATH WEIGHTS_PATH

# Directory under which each run's save folder is created
BASE_SAVE_PATH="./checkpoints"
export BASE_SAVE_PATH

# Export KEY=VALUE command-line arguments as environment variables, e.g.
# "bash script.sh LR=1e-5" overrides the LR default set above.
#
# Arguments: any number of KEY=VALUE words.
#   - KEY is everything before the first '='; VALUE is everything after it,
#     so VALUE may itself contain '=' or whitespace.
#   - A bare KEY (no '=') exports KEY with an empty value.
#   - A KEY that is not a valid shell identifier is reported on stderr and
#     skipped.
apply_overrides() {
  local arg key value
  local -r name_re='^[A-Za-z_][A-Za-z0-9_]*$'

  for arg in "$@"; do
    # Split with parameter expansion only: piping an unquoted expansion
    # through echo/cut word-splits and globs the argument, which can make
    # the extracted key disagree with the offset used for the value.
    key=${arg%%=*}
    if [[ "$arg" == *=* ]]; then
      value=${arg#*=}
    else
      value=""
    fi

    if [[ ! "$key" =~ $name_re ]]; then
      printf 'Ignoring override with invalid variable name: %s\n' "$arg" >&2
      continue
    fi

    export "${key}=${value}"
  done
}

apply_overrides "$@"

# Map the short MODEL_NAME alias to the full model identifier exported as
# MODEL (passed to composer as the model and tokenizer name).
#
# For an unrecognised alias:
#   - if MODEL is already non-empty (e.g. supplied as a MODEL=... override),
#     it is kept and only a warning is printed;
#   - otherwise the script aborts, because continuing would launch training
#     with an empty model identifier.
case "$MODEL_NAME" in
  qwen)
    MODEL="Qwen/Qwen2.5-7B-Instruct"
    ;;
  qwen14b)
    MODEL="Qwen/Qwen2.5-14B-Instruct"
    ;;
  qwen32b)
    MODEL="Qwen/Qwen2.5-32B-Instruct"
    ;;
  llama)
    MODEL="meta-llama/Meta-Llama-3-8B-Instruct"
    ;;
  *)
    if [ -n "${MODEL:-}" ]; then
      echo "Unknown model name: $MODEL_NAME (keeping MODEL=$MODEL)" >&2
    else
      echo "Unknown model name: $MODEL_NAME" >&2
      exit 1
    fi
    ;;
esac
export MODEL

# Print the final component of the path given as $1, or nothing when $1 is
# empty. Calling basename with no operand is an error, and SUBSET_PATH /
# WEIGHTS_PATH are empty by default, so the empty case is handled here.
path_label() {
  [ -n "${1:-}" ] || return 0
  basename -- "$1"
}

export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" # if not set, default to 0
export MAX_DURATION="${NUM_EPOCHS}ep"
# export MAX_DURATION=1ba
# The run name encodes the main settings; $RANDOM is appended so repeated
# runs with identical settings get different names.
export RUN_NAME="sdpa_${MODEL_NAME}f-lraft-${DATASET}-seed${SEED}-wd${WEIGHT_DECAY}-lr${LR}-subset_$(path_label "${SUBSET_PATH}")-weights_$(path_label "${WEIGHTS_PATH}")_$RANDOM"
export CONFIG="./configs/fft_${DATASET}_subset.yaml"

export RUN_SAVE_PATH="${BASE_SAVE_PATH}/${RUN_NAME}"

# export NCCL_NTHREADS=64 # for faster fsdp communication on RTX

# Launch training: run main.py through composer with the YAML config
# followed by key=value overrides built from the variables above.
# Every argument is quoted so values containing spaces or glob characters
# (e.g. file paths) reach composer as a single, unmodified word.
composer main.py \
    "${CONFIG}" \
    "model.pretrained_model_name_or_path=${MODEL}" \
    "tokenizer.name=${MODEL}" \
    "model.master_weights_dtype=${MODEL_PRECISION}" \
    "precision=${COMPUTE_PRECISION}" \
    "max_duration=${MAX_DURATION}" \
    "run_name=${RUN_NAME}" \
    "optimizer.lr=${LR}" \
    "optimizer.weight_decay=${WEIGHT_DECAY}" \
    "global_train_batch_size=${BS}" \
    "device_train_microbatch_size=${PER_DEVICE_BS}" \
    "device_eval_batch_size=${PER_DEVICE_BS}" \
    "scheduler.t_warmup=${WARMUP}ba" \
    "seed=${SEED}" \
    "algorithms.data_seeder.seed=${SEED}" \
    "callbacks.hq_hf_checkpointer.precision=${MODEL_PRECISION}" \
    "callbacks.hq_hf_checkpointer.save_folder=${RUN_SAVE_PATH}" \
    "train_loader.dataset.hf_kwargs.data_files=${SUBSET_PATH}" \
    "algorithms.loss_weighter.weights_path=${WEIGHTS_PATH}"
train_status=$?

# Stop here if training failed: the steps that follow move and delete
# files under the run's save folder and then start an evaluation, none of
# which makes sense without a finished run.
if [ "${train_status}" -ne 0 ]; then
  echo "composer exited with status ${train_status}; aborting" >&2
  exit "${train_status}"
fi

# move the checkpoint (saved by llm-foundry) to the correct directory
#
# Arguments: $1 - run save directory that contains a "huggingface" folder.
# Globals:   exports LAST_SAVE_DIR_NAME (name of the most recently
#            modified entry inside "$1/huggingface").
# Behaviour: moves the contents of that newest entry up into $1, then
#            removes "$1/huggingface".
# Returns:   non-zero, without deleting anything, if the folder is missing
#            or empty or if the move fails.
promote_latest_checkpoint() {
  local run_dir=$1
  local hf_dir="${run_dir}/huggingface"
  local entry newest=""

  if [[ -z "$run_dir" || ! -d "$hf_dir" ]]; then
    echo "checkpoint folder not found: ${hf_dir}" >&2
    return 1
  fi

  # Pick the most recently modified entry by comparing mtimes directly
  # instead of parsing the output of ls.
  for entry in "$hf_dir"/*; do
    [[ -e "$entry" ]] || continue   # unmatched glob stays literal
    if [[ -z "$newest" || "$entry" -nt "$newest" ]]; then
      newest=$entry
    fi
  done

  if [[ -z "$newest" ]]; then
    echo "no checkpoint found in ${hf_dir}" >&2
    return 1
  fi

  export LAST_SAVE_DIR_NAME="${newest##*/}"

  # Only delete the huggingface folder after its contents were moved.
  if ! mv -- "$newest"/* "$run_dir"; then
    echo "failed to move checkpoint files from ${newest}" >&2
    return 1
  fi
  rm -rf -- "$hf_dir"
}

promote_latest_checkpoint "${RUN_SAVE_PATH}" || exit 1

# Print the first entry of the comma-separated device list given as $1.
# Taking only the first character would truncate multi-character entries
# (e.g. "12,13" would become "1").
first_cuda_device() {
  printf '%s\n' "${1%%,*}"
}

echo "find the model at ${RUN_SAVE_PATH}"
# python eval.py --dataset=${DATASET} --model_path=${RUN_SAVE_PATH} --precision=${MODEL_PRECISION}
#echo "bash eval_only.sh DATASET=${DATASET} MODEL=${RUN_SAVE_PATH} MODEL_PRECISION=${MODEL_PRECISION}" >> qeval.sh
# Run eval_only.sh on the saved model, restricted to the first visible device.
CUDA_VISIBLE_DEVICES="$(first_cuda_device "${CUDA_VISIBLE_DEVICES}")" bash eval_only.sh "DATASET=${DATASET}" "MODEL=${RUN_SAVE_PATH}" "MODEL_PRECISION=${MODEL_PRECISION}"
