#!/bin/bash



# Weights & Biases: the project name is exported first, then the account
# settings file is sourced into this shell.
WANDB_PROJECT=mem_ft
export WANDB_PROJECT
source ./scripts/account/wandb_config.sh

# Base model checkpoint and fine-tuning dataset (JSON file under ./data/ft).
MODEL_NAME_OR_PATH="/path/to/checkpoints/Llama-3.1-8B-Instruct"
DATASET="squad-train19k_chatgpt_gpt4o-v5.1"
DATASET_PATH="./data/ft/${DATASET}.json"

NUM_TRAIN_EPOCHS=2

# Per-run output directory:
#   ./output/ft_corrector/<model basename>_<dataset>_ep<epochs>
# "${MODEL_NAME_OR_PATH##*/}" strips everything up to the last "/".
run_name="${MODEL_NAME_OR_PATH##*/}_${DATASET}_ep${NUM_TRAIN_EPOCHS}"
OUTPUT_DIR="./output/ft_corrector/${run_name}"

# Launch supervised fine-tuning through the memgpt.trl.sft module.
#
# NOTE(review): the original reminder at this spot read "comment packing!!!" --
# presumably a note to comment out --packing for some runs; confirm the intent
# before changing that flag.
#
# The arguments are collected in an array so that every expansion is quoted
# (a model/dataset/output path containing spaces stays a single argument) and
# no dangling line-continuation backslash is left after the final flag.
train_args=(
    # Model and data inputs.
    --model_name_or_path "${MODEL_NAME_OR_PATH}"
    --dataset_name "${DATASET_PATH}"

    # Optimisation schedule: 8 per-device samples x 4 accumulation steps.
    --learning_rate 2.0e-4
    --packing
    --num_train_epochs "${NUM_TRAIN_EPOCHS}"
    --per_device_train_batch_size 8
    --gradient_accumulation_steps 4
    --gradient_checkpointing

    # Checkpoint retention and evaluation mode.
    --save_total_limit 2
    --eval_strategy steps
    --per_device_eval_batch_size 4

    # PEFT / LoRA adapter settings.
    --use_peft
    --lora_r 32
    --lora_alpha 16

    # Dataset field to train on, output location, and sequence length cap.
    --dataset_text_field annotated_text
    --output_dir "${OUTPUT_DIR}"
    --max_seq_length 2048

    # Logging / checkpoint / evaluation cadence, expressed in steps.
    --logging_steps 1
    --save_steps 100
    --eval_steps 10
    --eval_accumulation_steps 1
)

python -m memgpt.trl.sft "${train_args[@]}"

# Estimated training steps:
#   (2 epochs * 19k examples) / (8 per-device batch size * 4 grad-accum steps) / packing factor ≈ 178
# ~5h (presumably the expected wall-clock time for the run)