
# LoRA SFT on the Android Control ("ac") training split with ms-swift,
# followed by merging a trained adapter checkpoint into a full model.
#
# Required environment:
#   pt_ckpt                  - base model passed to `swift sft --model`
#   run_name                 - run identifier; training output goes to /${run_name}
#   RANK                     - this node's rank (forwarded as NODE_RANK)
#   MASTER_ADDR, MASTER_PORT - rendezvous address for the multi-node launch

# Fail fast on missing inputs: an empty run_name would turn --output_dir
# into "/" and the checkpoint search below into a scan of the root dir.
: "${pt_ckpt:?pt_ckpt must be set to the base model checkpoint}"
: "${run_name:?run_name must be set}"

nnodes=2
nproc_per_node=8

# Distributed LoRA fine-tune (vision tower and aligner frozen, LLM trained).
# Global batch per optimizer step = per_device_train_batch_size (1)
#   x gradient_accumulation_steps (8) x nnodes (2) x nproc_per_node (8) = 128.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
MAX_PIXELS=2058400 \
NNODES="$nnodes" \
NODE_RANK="${RANK}" \
MASTER_ADDR="${MASTER_ADDR}" \
MASTER_PORT="${MASTER_PORT}" \
NPROC_PER_NODE="$nproc_per_node" \
swift sft \
    --model "${pt_ckpt}" \
    --train_type lora \
    --target_modules all-linear \
    --lora_rank 8 \
    --lora_alpha 16 \
    --learning_rate 4e-4 \
    --num_train_epochs 2.0 \
    --dataset "/swift-format-data/ac/t5_train_high.jsonl" \
    --torch_dtype bfloat16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --packing false \
    --streaming false \
    --lazy_tokenize true \
    --eval_strategy no \
    --eval_steps 500 \
    --save_strategy "epoch" \
    --save_total_limit 2 \
    --logging_steps 5 \
    --deepspeed zero3 \
    --max_length 16384 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 16 \
    --dataset_num_proc 8 \
    --save_only_model true \
    --output_dir "/${run_name}" \
    --run_name "${run_name}" \
    --report_to wandb \
    --attn_impl flash_attn \
    --freeze_llm false \
    --freeze_vit true \
    --freeze_aligner true \
    --tags "" || {
  # Do not go on to merge stale or missing checkpoints after a failed run.
  printf 'error: swift sft failed; skipping checkpoint merge\n' >&2
  exit 1
}


# Collect checkpoint directories written under the training output dir.
# NOTE(review): -maxdepth 2 suggests swift nests them as
# /<run_name>/<version-dir>/checkpoint-<step> — confirm.
mapfile -t checkpoints < <(find "/${run_name:?run_name must be set}" -maxdepth 2 -type d -name '*check*')

# Order numerically by the text between the first and second '-' of each
# basename (the <step> in "checkpoint-<step>"). Tab-separating the key from
# the path keeps paths containing spaces intact. The loop runs inside the
# process substitution, so its temporaries do not leak into this shell.
mapfile -t sorted_checkpoints < <(
  for ckpt in "${checkpoints[@]}"; do
    ckpt_base=${ckpt##*/}
    ckpt_step=${ckpt_base#*-}
    printf '%s\t%s\n' "${ckpt_step%%-*}" "$ckpt"
  done | sort -t $'\t' -k1,1n | cut -f2-
)

if (( ${#sorted_checkpoints[@]} == 0 )); then
  printf 'error: no checkpoint directories found under /%s\n' "$run_name" >&2
  exit 1
fi

# Take the final (highest-step) entry, not a fixed index, so the newest
# checkpoint is selected whether the run left one, two, or more checkpoints.
last_idx=$(( ${#sorted_checkpoints[@]} - 1 ))
checkpoint=${sorted_checkpoints[$last_idx]}
printf '%s\n' "$checkpoint"

# Merge the selected LoRA adapter into its base weights (on CPU) and write a
# standalone model; infer_ckpt then points at it.
# NOTE(review): the merged model goes to ${run_name}/full relative to the
# current working directory, while training wrote to /${run_name} — confirm
# this is intended before relying on infer_ckpt from another directory.
: "${checkpoint:?no checkpoint selected for merging}"
swift export \
    --adapters "$checkpoint" \
    --merge_lora true \
    --output_dir "${run_name}/full" \
    --device_map 'cpu' || {
  printf 'error: swift export (LoRA merge) failed for %s\n' "$checkpoint" >&2
  exit 1
}
infer_ckpt="${run_name}/full"

