#!/bin/bash



# Weights & Biases setup. WANDB_PROJECT is exported before the config file is
# sourced, so it is already set in the environment when that file runs.
WANDB_PROJECT="mem_tune"
export WANDB_PROJECT
# Path is relative to the current working directory.
# NOTE(review): presumably holds account-specific W&B settings -- confirm.
. ./scripts/account/wandb_config.sh

# ---- Run configuration ------------------------------------------------------

# Parent directory for this run's outputs; a run-specific leaf is appended
# at the end of this section.
OUTPUT_DIR='./output/tune_annotator'

# Checkpoint to fine-tune. Alternative checkpoint kept for reference:
# MODEL_NAME_OR_PATH=./output/ft_annotator_v8/Llama-3.1-8B-Instruct_squad-train19k_chatgpt_gpt4o-v5.1_ep1
MODEL_NAME_OR_PATH='/path/to/checkpoints/Llama-3.1-8B-Instruct'

# Dataset name; the JSON file is looked up under ./data/ft/.
# Alternative dataset kept for reference:
# DATASET=squad-train1k_chatgpt_gpt4o-v7.1
DATASET='squad-train1k_dwiki-train1k_chatgpt_gpt4o-v7.1'
DATASET_PATH="./data/ft/${DATASET}.json"

NUM_TRAIN_EPOCHS=10

# Final output directory: <parent>/<checkpoint basename>_<dataset>_ep<epochs>.
# ${MODEL_NAME_OR_PATH##*/} strips everything up to and including the last '/'.
OUTPUT_DIR+="/${MODEL_NAME_OR_PATH##*/}_${DATASET}_ep${NUM_TRAIN_EPOCHS}"

# Launch supervised fine-tuning with LoRA adapters (--use_peft) through the
# memgpt.trl.sft module.
#
# Reads MODEL_NAME_OR_PATH, DATASET_PATH, NUM_TRAIN_EPOCHS and OUTPUT_DIR,
# which are set earlier in this script. Every expansion is double-quoted so a
# path containing whitespace or glob characters is passed as one argument.
#
# Samples per optimizer step, per device:
#   per_device_train_batch_size (8) * gradient_accumulation_steps (4) = 32.
python -m memgpt.trl.sft \
    --model_name_or_path "${MODEL_NAME_OR_PATH}" \
    --dataset_name "${DATASET_PATH}" \
    --learning_rate 2.0e-4 \
    --num_train_epochs "${NUM_TRAIN_EPOCHS}" \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --gradient_checkpointing \
    --save_total_limit 2 \
    --eval_strategy steps \
    --per_device_eval_batch_size 4 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --dataset_text_field formatted_text \
    --output_dir "${OUTPUT_DIR}" \
    --max_seq_length 2048 \
    --logging_steps 1 \
    --save_steps 30 \
    --eval_steps 10 \
    --eval_accumulation_steps 1
# The final option above intentionally has no trailing backslash: the command
# ends there instead of continuing into the comment lines below.
#
# Disabled / alternative options. To enable one, move it into the command
# above and keep a trailing backslash on every line except the last:
#     --logging_steps 1
#     --save_steps 30
#     --eval_steps 50
#     --packing

# steps = (4 epochs * 19k examples) / (16 batch size * 4 accum steps) = 1187.5 (~1188)
# steps = (4 epochs * 1k examples) / (16 batch size * 4 accum steps) = 62.5
# steps = (2 epochs * 19k examples) / (8 batch size * 4 accum steps) = 1187.5 (~1188)