#!/bin/bash



export WANDB_PROJECT=mem_pretrain
source ./scripts/account/wandb_config.sh

OUTPUT_DIR="output/pretrain_v3"
# MODEL_NAME_OR_PATH="gpt2"
MODEL_NAME_OR_PATH="tiny-llama2-674M"
# MODEL_NAME_OR_PATH="tiny-llama3"
# DATASET_PATH=./data/squad19k
DATASET_PATH=/path/to/version2/data/dwiki5.1M
NUM_TRAIN_EPOCHS=1

OUTPUT_DIR="${OUTPUT_DIR}/${MODEL_NAME_OR_PATH}_${DATASET_PATH##*/}_ep${NUM_TRAIN_EPOCHS}"


# Run the SFT script
python -m memgpt.trl.pretrain \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --dataset_name ${DATASET_PATH} \
    --packing \
    --learning_rate 5.0e-4 \
    --num_train_epochs ${NUM_TRAIN_EPOCHS} \
    --per_device_train_batch_size 32 \
    --gradient_accumulation_steps 2 \
    --logging_steps 20 \
    --gradient_checkpointing \
    --eval_strategy steps \
    --eval_steps 100 \
    --save_steps 5000 \
    --save_total_limit 10 \
    --dataset_text_field annotated_text \
    --output_dir ${OUTPUT_DIR} \
    # --resume_from_checkpoint ./output/pretrain/tiny-llama2-176M_dwiki60k_ep1/checkpoint-2971
# TODO: set text_len!!! (no text_len option is passed to the command above)

