
export WANDB_API_KEY=xxxxxxx
torchrun \
    --nnodes=1 \
    --nproc_per_node=8 \
    train.py \
    --dataset_name xxx \
    --deepspeed configs/zero2.json \
    --model_name_or_path xxxx \
    --per_device_train_batch_size 1 \
    --output_dir xxx \
    --torch_dtype bfloat16 \
    --gradient_checkpointing \
    --learning_rate 2e-5 \
    --save_steps 100 \
    --save_total_limit 12 \
    --report_to wandb \
    --run_name xxxx \
    --logging_steps 1 \
    --num_train_epochs 2 











