#!/bin/bash
#
# Launch plain pretraining (experiment/train/pretrain_plain.py) for one
# model / dataset / epoch-count combination.
#
# Must be run from the repository root: every path used below is relative
# to the current working directory.
# Requires ./scripts/account/wandb_config.sh, which is sourced for W&B
# settings (presumably credentials/entity — confirm against that file).

export WANDB_PROJECT=mem_pretrain

# Fail fast instead of silently training without the W&B configuration when
# the script is launched from the wrong directory.
wandb_config=./scripts/account/wandb_config.sh
if [[ ! -f "${wandb_config}" ]]; then
    printf 'error: %s not found; run this script from the repository root\n' \
        "${wandb_config}" >&2
    exit 1
fi
# shellcheck source=/dev/null
source "${wandb_config}"

OUTPUT_DIR="output/pretrain_plain_v3"
MODEL_NAME_OR_PATH="tiny-llama2-382M"
DATASET_PATH=/path/to/version2/data/dwiki3.8M
NUM_TRAIN_EPOCHS=1

# Give each run its own sub-directory:
#   <OUTPUT_DIR>/<model>_plain_<dataset basename>_ep<epochs>
# The trailing slash is stripped first so ".../dwiki3.8M/" still yields
# "dwiki3.8M" rather than an empty dataset component.
dataset_basename=${DATASET_PATH%/}
dataset_basename=${dataset_basename##*/}
OUTPUT_DIR="${OUTPUT_DIR}/${MODEL_NAME_OR_PATH}_plain_${dataset_basename}_ep${NUM_TRAIN_EPOCHS}"


# Run the plain (non-SFT) pretraining script.
# Arguments live in an array so every expansion stays quoted and optional
# flags can be enabled by uncommenting a single line, without having to
# keep a chain of trailing backslashes consistent.
train_args=(
    --model_name_or_path "${MODEL_NAME_OR_PATH}"
    --dataset_name "${DATASET_PATH}"
    --packing
    --learning_rate 5.0e-4
    --num_train_epochs "${NUM_TRAIN_EPOCHS}"
    --per_device_train_batch_size 32
    --gradient_accumulation_steps 2
    --logging_steps 20
    --gradient_checkpointing
    --eval_strategy steps
    --eval_steps 100
    --save_steps 5000
    --save_total_limit 10
    --dataset_text_field annotated_text
    --output_dir "${OUTPUT_DIR}"
    # Optional flags — uncomment to enable:
    # --resume_from_checkpoint ./output/pretrain_plain/gpt2_plain_dwiki60k_ep4/checkpoint-7740
    # --evaluate_only
)

python ./experiment/train/pretrain_plain.py "${train_args[@]}"

