#!/bin/bash

## Distributed training configuration

# Exported so that the training processes started later inherit this PyTorch
# CUDA allocator setting.
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_CUDA_ALLOC_CONF

# Rendezvous endpoint given to torchrun.
# [Required] MASTER_ADDR is the master node IP for multi-GPU training; the
# port is drawn at random from the range 20000-29999 on every invocation.
MASTER_PORT="$(shuf -n 1 -i 20000-29999)"
MASTER_ADDR='127.0.0.2'

# Node count: taken from WORLD_SIZE when it is set and non-empty, otherwise 1.
NNODES="${WORLD_SIZE:-1}"

# Number of worker processes per node. Alternatives kept for reference:
#   NPROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l)
#   NNODES=1
NPROC_PER_NODE=1
# DeepSpeed configuration file handed to the trainer via --deepspeed
deepspeed='zero3.json'

# Model configuration: base checkpoint to fine-tune.
# NOTE(review): this value looks like a local directory rather than a
# HuggingFace Hub model ID -- confirm which one is intended.
llm='models/Qwen2.5-VL-7B-Instruct'

# Training hyperparameters
lr='2e-5'
batch_size=8
grad_accum_steps=4

# Training entry point (script executed by torchrun)
entry_file='qwen-vl-finetune/qwenvl/train/train_lora.py'

# Output configuration: checkpoints are written to a directory named after the run.
# NOTE(review): "origianl" looks like a typo for "original"; it is left as-is
# because renaming it would change the checkpoint directory and the run name.
run_name='Qwen2.5-vl-7b-lora-vision-origianl'
output_dir="finetune_checkpoints/$run_name"

# Leftover option notes:
#   data_flatten
#   per_device_train_batch_size ${batch_size}
#   per_device_eval_batch_size ${batch_size}

# Batch-size settings that may be overridden from the environment; each one
# falls back to its default when unset or empty.
: "${GPUS:=1}"
: "${BATCH_SIZE:=16}"
: "${PER_DEVICE_BATCH_SIZE:=4}"
# Integer division, evaluated left to right.
# NOTE(review): GRADIENT_ACC is computed here, but the training arguments in
# this file pass grad_accum_steps instead -- confirm which value is intended.
GRADIENT_ACC=$(( BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS ))

# Dataset configuration (replace with public dataset names)
datasets='extend1'

# Training arguments.
# Held in an array so that every option and value reaches the entry script as
# exactly one word, regardless of IFS, spaces, or glob characters in a value.
# The meaning of each flag is defined by the entry script, not by this file.
# NOTE(review): --lora_model is False although the run name and entry script
# mention LoRA -- confirm this is intentional.
args=(
    --lora_model False
    --deepspeed "${deepspeed}"
    --model_name_or_path "${llm}"
    --dataset_use "${datasets}"
    --data_flatten True
    --tune_mm_vision False
    --tune_mm_mlp False
    --tune_mm_llm True
    --bf16
    --output_dir "${output_dir}"
    --num_train_epochs 0.5
    --per_device_train_batch_size "${batch_size}"
    --per_device_eval_batch_size "${batch_size}"
    --gradient_accumulation_steps "${grad_accum_steps}"
    --max_pixels 50176
    --min_pixels 784
    --eval_strategy no
    --save_strategy steps
    --save_steps 100
    --save_total_limit 5
    --learning_rate "${lr}"
    --weight_decay 0
    --warmup_ratio 0.03
    --max_grad_norm 1
    --lr_scheduler_type cosine
    --logging_steps 1
    --model_max_length 8192
    --gradient_checkpointing True
    --dataloader_num_workers 4
    --run_name "${run_name}"
    --report_to wandb
)

LOG_FILE="./logs/lora.txt"
# tee does not create missing parent directories; without this the log would
# be lost on a fresh checkout while training keeps running.
mkdir -p -- "$(dirname -- "${LOG_FILE}")"

# Make the pipeline report torchrun's failure instead of tee's success, so a
# crashed training run yields a non-zero exit status for this script.
set -o pipefail

# Launch training: single node, restricted to GPU 0, with stdout and stderr
# shown on the terminal and copied to LOG_FILE.
CUDA_VISIBLE_DEVICES=0 torchrun \
    --nproc_per_node="${NPROC_PER_NODE}" \
    --master_addr="${MASTER_ADDR}" \
    --master_port="${MASTER_PORT}" \
    --nnodes=1 \
    --node_rank=0 \
    "${entry_file}" "${args[@]}" 2>&1 | tee "${LOG_FILE}"