# Trace each command so the resolved settings show up in the run log.
set -x

# Exported so child processes (the trainer) see it; Hydra uses this variable
# to print complete stack traces instead of abbreviated ones.
export HYDRA_FULL_ERROR=1
# export CUDA_VISIBLE_DEVICES=0,1,2,3

# Project label; used in the output path and passed to the trainer.
PROJECT_NAME="fortune"

## task config -- keep exactly one TASK line active
# TASK="formula"
TASK="text"



## model config -- keep exactly one MODEL line active
# MODEL="qwen"
MODEL="llama"
# MODEL="qwen14b"

## training config
NODE_NUM="1"
GPU_NUM="4"
MICRO_BATCH_SIZE_PER_GPU="1"


# Derived names: experiment id and the directory that receives its outputs.
EXP_NAME="sft_${TASK}_${MODEL}_all"
OUTPUT_DIR="training_outputs/sft/${PROJECT_NAME}/${EXP_NAME}"


# Map the MODEL selector to the checkpoint the trainer starts from
# (INIT_MODEL). Tracing is switched off here to keep the log short.
#
# Reads:  MODEL
# Writes: INIT_MODEL, and MODEL itself for the qwen14b case
# Exits:  status 1 if MODEL is not one of the supported selectors, so the
#         launch never proceeds with an unset INIT_MODEL.
set +x
case "$MODEL" in
    qwen)
        INIT_MODEL=Qwen/Qwen2.5-Coder-7B-Instruct
        ;;
    llama)
        INIT_MODEL=hf_models/meta-llama/Llama-3.1-8B-Instruct
        ;;
    qwen14b)
        INIT_MODEL=Qwen/Qwen2.5-Coder-14B-Instruct
        # MODEL is also used to build the processed-data path further down,
        # so from here on it resolves to the "qwen" data directory.
        # EXP_NAME was computed earlier and keeps the qwen14b label.
        # NOTE(review): presumably the 14B model shares the qwen-processed
        # data -- confirm.
        MODEL=qwen
        ;;
    *)
        printf 'error: unsupported MODEL "%s" (expected qwen, llama or qwen14b)\n' "$MODEL" >&2
        exit 1
        ;;
esac


# Assemble train_files as a bracketed, comma-separated list of single-quoted
# parquet paths, one per dataset:
#   ['data/processed_data/<TASK>/<MODEL>/<dataset>/sft.parquet', ...]
TRAIN_DATASET_LIST="wikitq tabfact finqa hitab multihiertt"
train_files=""

# TRAIN_DATASET_LIST is left unquoted on purpose: word splitting yields one
# iteration per dataset name.
for DATASET in $TRAIN_DATASET_LIST; do
    tmp_train_files="data/processed_data/${TASK}/${MODEL}/${DATASET}/sft.parquet"
    # Emit the ", " separator only when the list already has an entry.
    train_files="${train_files:+${train_files}, }'${tmp_train_files}'"
done

# Resume command tracing for the remainder of the script.
set -x

train_files="[${train_files}]"



# Launch the verl FSDP SFT trainer under torchrun, passing the run settings
# as key=value overrides.
#
# Reads: NODE_NUM, GPU_NUM, MICRO_BATCH_SIZE_PER_GPU, train_files,
#        INIT_MODEL, OUTPUT_DIR, PROJECT_NAME, EXP_NAME
#
# Quoting notes:
#   * Every expansion is double-quoted so a value can never be word-split
#     or globbed into extra arguments.
#   * List-valued overrides are passed as whole double-quoted words, the same
#     way train_files is, so the shell never treats "[...]" as a glob pattern
#     and the trainer receives explicit quoted list elements.
#   * ${INIT_MODEL:?...} aborts the launch if no checkpoint was selected,
#     instead of passing an empty model.partial_pretrain.
torchrun --standalone --nnodes="$NODE_NUM" --nproc_per_node="$GPU_NUM" \
     -m verl.trainer.fsdp_sft_trainer \
    data.train_files="$train_files" \
    data.val_files=null \
    data.train_batch_size=64 \
    data.micro_batch_size_per_gpu="$MICRO_BATCH_SIZE_PER_GPU" \
    optim.lr=2e-5 \
    data.prompt_key=prompt \
    data.response_key=extra_info \
    +data.prompt_dict_keys=null \
    "+data.response_dict_keys=['sft_response']" \
    data.max_length=11000 \
    data.truncation=error \
    model.partial_pretrain="${INIT_MODEL:?INIT_MODEL is not set; check the MODEL selection}" \
    trainer.default_local_dir="$OUTPUT_DIR" \
    trainer.project_name="$PROJECT_NAME" \
    trainer.experiment_name="$EXP_NAME" \
    "trainer.logger=['console','wandb']" \
    trainer.total_epochs=6 \
    trainer.default_hdfs_dir=null \
    ulysses_sequence_parallel_size=2 \
    use_remove_padding=true
