# ---------------------------------------------------------------------------
# GPU / Python runtime
# ---------------------------------------------------------------------------
# Devices made visible to the job (six GPUs, indices 0-5).
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
# NOTE(review): presumably makes the training code skip its dependency
# version check -- confirm against train/stage_sft.
export DISABLE_VERSION_CHECK="1"
# Module search path for the SFT code, relative to the directory the script
# is launched from (so run it from the repository root).
export PYTHONPATH="${PWD}/train/stage_sft"
# PyTorch CUDA caching-allocator option.
export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128"
# OpenMP threads per process.
export OMP_NUM_THREADS="4"

# ---------------------------------------------------------------------------
# Distributed communication (Gloo / NCCL)
# ---------------------------------------------------------------------------
# Network interface used for Gloo and NCCL socket traffic.
export GLOO_SOCKET_IFNAME="ens42f0"
export NCCL_SOCKET_IFNAME="ens42f0"
# Verbose NCCL logging.
export NCCL_DEBUG="INFO"
# InfiniBand/RoCE transport settings: IB left enabled (0 = not disabled),
# GID index 3, restricted to the two listed mlx5 adapters.
export NCCL_IB_DISABLE="0"
export NCCL_IB_CUDA_SUPPORT="1"
export NCCL_IB_GID_INDEX="3"
export NCCL_IB_HCA="mlx5_0,mlx5_1"

# ---------------------------------------------------------------------------
# Experiment tracking / launcher behaviour
# ---------------------------------------------------------------------------
# Keep Weights & Biases in offline mode.
export WANDB_MODE="offline"
export ACCELERATE_CPU_AFFINITY="1"

# User-editable locations and dataset selection. The values below are
# placeholders; replace them before launching a run.

# Directory containing the training images.
IMAGE_DIR=/path/to/imgs/files/

# Pretrained model used as the starting point.
PRETRAIN_MODEL_PATH=/your/path/to/model

# Directory that receives the trained model and the training log.
OUTPUT_PATH=/your/path/to/output/model

# Name of the dataset to train on.
DATASET=MEDICAL-NO-THINKING-SFT

# Export everything in one go so child processes see the settings.
export IMAGE_DIR PRETRAIN_MODEL_PATH OUTPUT_PATH DATASET

# Make sure the output directory exists before training starts, because the
# training log is written into it. `-p` creates missing parent directories
# and succeeds silently when the directory is already there, so no separate
# existence test is needed; `--` protects against a path starting with '-'.
mkdir -p -- "$OUTPUT_PATH"

# Launch supervised fine-tuning: torchrun starts 6 worker processes on a
# single node (matching the six GPUs in CUDA_VISIBLE_DEVICES), rendezvous on
# port 29514, and runs train/stage_sft/train.py with the DeepSpeed config in
# scripts/train/zero3.json.
#
# stdout and stderr are merged and copied to ${OUTPUT_PATH}/train.log while
# still being shown on the terminal.
#
# Every line of the command, including the last option, must end with a
# backslash: without it the command stops early and the `2>&1 | tee ...`
# line runs as a separate, empty pipeline that leaves train.log empty.
# All variable expansions are quoted so paths containing spaces or glob
# characters are passed through as single arguments.
#
# Note: the exit status of this pipeline is that of `tee`, not of torchrun.
#
# Alternative launcher kept for reference (replaces the torchrun line):
# python train/stage_sft/train.py \
torchrun --nproc_per_node=6 --nnodes=1 --master_port=29514 \
  train/stage_sft/train.py \
  --deepspeed scripts/train/zero3.json \
  --stage sft \
  --do_train \
  --model_name_or_path "$PRETRAIN_MODEL_PATH" \
  --dataset "$DATASET" \
  --image_dir "$IMAGE_DIR" \
  --template qwen2_vl \
  --finetuning_type full \
  --output_dir "$OUTPUT_PATH" \
  --overwrite_cache \
  --overwrite_output_dir \
  --warmup_steps 100 \
  --weight_decay 0.1 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 2 \
  --gradient_checkpointing \
  --ddp_timeout 90000 \
  --learning_rate 1e-5 \
  --lr_scheduler_type cosine \
  --logging_steps 5 \
  --save_steps 600 \
  --plot_loss \
  --num_train_epochs 1 \
  --bf16 \
  --save_only_model \
  2>&1 | tee "${OUTPUT_PATH}/train.log"