#!/bin/bash
# train.sh

# === Model and paths ===
# Settings forwarded as command-line flags to train.py by the training command
# at the end of this script. All paths are relative, so they resolve against
# the directory the script is launched from.
model_type='qwen25vl'                          # -> --model_type
llm_model='model/Qwen2.5-VL-3B-Instruct'       # -> --model_name_or_path
output_dir='output/DAD-SFT'                    # -> --output_dir
data_path='dataset/sftdatabbox.json'           # -> --data_path
teacher_map_path='dataset/teacher_attn_map'    # -> --teacher_map_path
neg_map_path='dataset/neg_attn_map'            # -> --neg_map_path
deepspeed_config='scripts/zero3.json'          # -> --deepspeed

# === Environment ===
# Run/debug settings exported so that child processes inherit them.
export DEBUG_MODE='true' RUN_NAME='distill' LOG_PATH='./debug_distill'
# The proxy variables are exported as empty strings (set but empty, not unset).
export proxy='' http_proxy='' https_proxy=''

# === Training Command ===
# Runs train.py through the torch.distributed.run launcher with
# --nproc_per_node=8.
#
# The arguments are collected in an array and every variable expansion is
# quoted: an unquoted expansion would be word-split and glob-expanded if a path
# ever contained spaces or wildcard characters, and an empty value would vanish
# from the command line and shift the following flag into its place.
train_args=(
  # Paths and model selection (variables from the "Model and paths" section).
  --deepspeed "${deepspeed_config}"
  --data_path "${data_path}"
  --teacher_map_path "${teacher_map_path}"
  --neg_map_path "${neg_map_path}"
  --model_type "${model_type}"
  --model_name_or_path "${llm_model}"
  --output_dir "${output_dir}"

  # Data handling.
  --group_by_modality_length True
  --resize_ratio 4
  --model_max_length 24576
  --max_pixels 5720064
  --dataloader_num_workers 8

  # Numeric precision and memory.
  --bf16 True
  --tf32 True
  --gradient_checkpointing True

  # Batch sizes and run length.
  --num_train_epochs 2
  --per_device_train_batch_size 2
  --per_device_eval_batch_size 4
  --gradient_accumulation_steps 1

  # Optimizer and learning-rate schedule.
  --learning_rate 5e-6
  --weight_decay 0.
  --warmup_ratio 0.03
  --lr_scheduler_type cosine

  # Evaluation, checkpointing and logging cadence.
  --eval_strategy no
  --save_strategy steps
  --save_steps 2000
  --logging_steps 10

  # Trainable-parameter selection.
  --unfreeze_all_parameters True
)

python -m torch.distributed.run --nproc_per_node=8 train.py "${train_args[@]}"
