#!/bin/bash
#
# Auto-restarting launcher for multi-GPU training via `accelerate launch`.
#
# Runs forever. Each iteration asks find_latest_ckpt.py for the newest
# checkpoint under CKPT_ROOT:
#   - if it prints a non-empty path, training resumes from it via the
#     trainer.load_from_state_dict override;
#   - otherwise a fresh run is started.
# Whenever the training process exits (successfully or not), the script
# waits 10 seconds and starts the next iteration.
#
# Replace every your_* placeholder below before running.

CKPT_ROOT=your_save_path

# Overrides shared by the resume and the fresh-start launches. They are kept
# in a single array so the two paths cannot drift apart. The Hydra list
# override is quoted because an unquoted `[...]` is a shell glob bracket
# expression and could be replaced by a matching filename in the cwd.
COMMON_OVERRIDES=(
    '+experiments=[large_scale_train]'
    debug=true
    loader.batch_size=1
    data.data_dir_train=your_dataset_path
    data.data_dir_val=your_dataset_path
    model.vqgan_config=your_vqgan_config_path
    model.vqgan_ckpt=your_vqgan_ckpt_path
    model.llama_ckpt=your_llama_ckpt_path
    model.liquid_ckpt=your_liquid_ckpt_path
)

while true; do
    # NOTE(review): if find_latest_ckpt.py itself fails and prints nothing,
    # this falls through to a fresh run — confirm that is acceptable.
    LATEST_CKPT=$(python find_latest_ckpt.py "$CKPT_ROOT")

    # A fresh port is drawn on every launch, so a previous run that is still
    # releasing its port does not block the restart. $RANDOM alone spans
    # 0-32767, which includes privileged ports (<1024); use 20000-39999.
    MAIN_PORT=$(( 20000 + RANDOM % 20000 ))

    LAUNCH_ARGS=(
        --num_processes 8
        --multi_gpu
        --main_process_port="$MAIN_PORT"
        main.py
        "${COMMON_OVERRIDES[@]}"
    )

    if [[ -n "$LATEST_CKPT" ]]; then
        echo "✅ 找到最新 checkpoint: $LATEST_CKPT"
        accelerate launch "${LAUNCH_ARGS[@]}" \
            trainer.load_from_state_dict="$LATEST_CKPT"
    else
        echo "⚠️ 没有找到 checkpoint，开始新的训练..."
        accelerate launch "${LAUNCH_ARGS[@]}"
    fi

    echo "⏳ 训练进程结束，等待 10 秒后重启..."
    sleep 10
done
