#!/bin/bash

echo "----------Start llava15_train----------"

task_name=llava15_7b_DPO
exp_name=${1:-"logps"}
ckpt=${2:-"liuhaotian/llava-v1.5-7b"}
raw_data_path=${3:-""}
data_dir=${raw_data_path}-with-logps

echo "exp_name: "$exp_name
echo "ckpt: "$ckpt
echo "raw_data_path: "$raw_data_path

MASTER_PORT_START=10000
MASTER_PORT_END=65535
MASTER_PORT="$(
	comm -23 \
		<(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
		<(ss -Htan | awk '{ print $4 }' | awk -F ':' '{ print $NF }' | sort -u) |
		shuf | head -n 1
)"

gpu_vis=0,1,2,3

PYTHONPATH=./:$PYTHONPATH \
HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \
deepspeed --include localhost:$gpu_vis  --master_port $MASTER_PORT \
    --module muffin.train.train_llava15 \
    --deepspeed script/zero2.json \
    --model_name_or_path $ckpt \
    --raw_data_path $raw_data_path \
    --data_dir $data_dir \
    --image_folder not_used \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --fully_tune False \
    --train_lora True \
    --image_aspect_ratio pad \
    --bf16 True \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --output_dir result/symmpo/$task_name-$exp_name \
    --num_train_epochs 10 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --data_source_names '' \
    --data_source_weights 1 \
    --max_steps 1339 \
    --learning_rate 5e-7 \
    --weight_decay 0.01 \
    --warmup_ratio 0.05 \
    --lr_scheduler_type "cosine" \
    --logging_steps 2 \
    --logging_dir checkpoint/$task_name-$exp_name/log \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --task DPO \
    --report_to tensorboard \
    --run_name $exp_name \
    --dataloader_num_workers 16 \
    --dpo_use_average False \
    --dpo_token_weighted False \
    --dpo_token_weight 1.0 \
    --dpo_beta_1 0.1 \
    --dpo_beta_2 0.1 \
    --save_steps 335 \
    --save_total_limit 4 \
