#!/bin/bash

# Pre-training ("pt") stage: run naming and distributed layout.
base_name=pretrain

# Used below for the output directory and as the reported run name.
run_name=${base_name}_pt

# Handed to the launcher through the NNODES / NPROC_PER_NODE environment
# variables.
nnodes=2
nproc_per_node=8

# LoRA fine-tuning of Qwen2.5-VL-3B-Instruct with `swift sft` (DeepSpeed
# ZeRO-3, bf16, flash attention). The vision tower and aligner are frozen;
# the LLM is not.
#
# RANK, MASTER_ADDR and MASTER_PORT are read from the environment and are
# expected to be provided on every node (presumably by the job launcher).
# Abort with a clear message when one is missing instead of forwarding an
# empty value to the trainer.
: "${RANK:?RANK (rank of this node) must be set}"
: "${MASTER_ADDR:?MASTER_ADDR must be set}"
: "${MASTER_PORT:?MASTER_PORT must be set}"

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
MAX_PIXELS=1058400 \
NNODES="${nnodes}" \
NODE_RANK="${RANK}" \
MASTER_ADDR="${MASTER_ADDR}" \
MASTER_PORT="${MASTER_PORT}" \
NPROC_PER_NODE="${nproc_per_node}" \
swift sft \
    --model Qwen2.5-VL-3B-Instruct \
    --use_adapter_learning True \
    --dataset "/swift-format-data/merge/trajdataset.jsonl" \
    --train_type lora \
    --learning_rate 2e-4 \
    --num_train_epochs 1.0 \
    --save_total_limit 1 \
    --torch_dtype bfloat16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --packing false \
    --streaming false \
    --lazy_tokenize true \
    --eval_strategy no \
    --eval_steps 500 \
    --save_strategy "epoch" \
    --logging_steps 5 \
    --deepspeed zero3 \
    --max_length 16384 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 16 \
    --dataset_num_proc 8 \
    --save_only_model true \
    --output_dir "output/${run_name}" \
    --run_name "${run_name}" \
    --report_to wandb \
    --attn_impl flash_attn \
    --freeze_llm false \
    --freeze_vit true \
    --freeze_aligner true \
    --tags "" || {
  # Stop here on failure: the steps that follow rely on a checkpoint
  # produced by this run.
  train_status=$?
  echo "error: swift sft failed with exit status ${train_status}" >&2
  exit "${train_status}"
}


# Locate the checkpoint directory produced by training.
# `find` returns matches in unspecified order, so sort them with a version
# sort and take the last one. This assumes run/checkpoint directory names
# carry increasing numeric suffixes (e.g. v0-…/checkpoint-100) — TODO confirm
# against the trainer's actual output layout.
ckpt_root="./output/${run_name}"
mapfile -t checkpoints < <(
  find "${ckpt_root}" -maxdepth 2 -type d -name '*check*' | sort -V
)
if (( ${#checkpoints[@]} == 0 )); then
  echo "error: no checkpoint directory found under ${ckpt_root}" >&2
  exit 1
fi
checkpoint="${checkpoints[${#checkpoints[@]}-1]}"
echo "${checkpoint}"

# Merge the LoRA adapter into the base weights on CPU and write the merged
# model to output/<run_name>/full.
swift export \
    --adapters "${checkpoint}" \
    --merge_lora true \
    --output_dir "output/${run_name}/full" \
    --device_map 'cpu' || {
  export_status=$?
  echo "error: swift export failed with exit status ${export_status}" >&2
  exit "${export_status}"
}
# Only set once the merge succeeded, so it never points at a missing model.
pt_ckpt="output/${run_name}/full"

