#!/bin/bash

# ============ Read configuration ============
# Positional arguments:
#   $1 solver_model_path       model passed to vllm_service_init/start.sh (Solver services)
#   $2 questioner_model_path   model trained as the Questioner (worker.actor.model.model_path)
#   $3 save_path               experiment name; checkpoints go to ${STORAGE_PATH}/models/<save_path>
#   $4 iteration               only echoed for logging in this script
#   $5 resume_checkpoint_path  optional: checkpoint to resume from (should end with global_step_*)
solver_model_path=$1
questioner_model_path=$2
save_path=$3
iteration=$4
resume_checkpoint_path=${5:-""}

# Fail fast, before any vLLM service is started. An empty save_path would otherwise
# make the trainer write straight into ${STORAGE_PATH}/models and merge the wrong dir.
if [ -z "$solver_model_path" ] || [ -z "$questioner_model_path" ] || [ -z "$save_path" ]; then
    echo "Usage: $0 <solver_model_path> <questioner_model_path> <save_path> <iteration> [resume_checkpoint_path]" >&2
    exit 1
fi
echo "save_path: $save_path"

# Unique id for this run (nanosecond timestamp), exported so child processes can read it.
RUN_ID=$(date +%s%N)
export RUN_ID
echo "RUN_ID=$RUN_ID"

# Log a summary of this run's configuration. The resume checkpoint line is only
# printed when a checkpoint path was supplied.
printf '%s\n' \
    "==========================================" \
    "[Questioner Training] Some Configurations" \
    "  Experiment Name: $save_path" \
    "  Questioner Model Path: $questioner_model_path" \
    "  Solver Model Path: $solver_model_path" \
    "  Iteration: $iteration" \
    "  Learning Rate: ${QUESTIONER_LR:-1e-6}"
[ -z "$resume_checkpoint_path" ] || printf '%s\n' "  Resume Checkpoint: $resume_checkpoint_path"
printf '%s\n' \
    "==========================================" \
    "  Embedding Type: $EMBEDDING_TYPE" \
    "  Keep Batch Penalty Unchanged: $KEEP_BATCH_PENALTY_UNCHANGED" \
    "=========================================="

# ============ GPU allocation (time-sharing mode) ============
# In time-sharing mode all GPUs are shared between the Trainer and the vLLM services:
# the vLLM services go to sleep right after startup, releasing their GPU memory; when
# rewards must be computed they are woken up while the Trainer pauses. This lets the
# Trainer use every GPU.

# An unset TOTAL_GPU_COUNT would otherwise yield an empty CUDA_VISIBLE_DEVICES and an
# empty trainer.n_gpus_per_node. Fall back to 8, the same default this script uses
# for the number of vLLM services.
if [ -z "$TOTAL_GPU_COUNT" ]; then
    echo "[GPU] WARNING: TOTAL_GPU_COUNT is not set; defaulting to 8" >&2
    TOTAL_GPU_COUNT=8
fi
if ! [[ "$TOTAL_GPU_COUNT" =~ ^[1-9][0-9]*$ ]]; then
    echo "[GPU] ERROR: TOTAL_GPU_COUNT must be a positive integer, got '$TOTAL_GPU_COUNT'" >&2
    exit 1
fi

TRAINER_GPU_COUNT=$TOTAL_GPU_COUNT
# Comma-separated device list "0,1,...,N-1" for CUDA_VISIBLE_DEVICES.
TRAINER_GPUS=$(seq -s, 0 $((TRAINER_GPU_COUNT - 1)))

echo "=========================================="
echo "[时分复用模式] GPU Allocation"
echo "  Total GPUs: $TOTAL_GPU_COUNT"
echo "  Trainer GPUs: $TRAINER_GPUS (count: $TRAINER_GPU_COUNT)"
echo "  vLLM Services: All GPUs (time-shared with sleep mode)"
echo "=========================================="

# ============ Start the Solver vLLM services (sleep mode) ============
# The services are launched on all GPUs and put to sleep immediately, so they do not
# hold GPU memory until they are woken up.
echo "[vLLM] Starting Solver vLLM services (will sleep immediately)..."
# Arguments are quoted so a model path containing spaces is passed as one argument.
if bash vllm_service_init/start.sh "$solver_model_path" "$RUN_ID" 0.8; then
    echo "Solver vLLM services started with RUN_ID=$RUN_ID (in sleep mode)"
else
    # $? here is start.sh's exit status (the if-condition's status).
    echo "[vLLM] WARNING: vllm_service_init/start.sh exited with status $?; Solver vLLM services may not be running" >&2
fi

# In code mode, also start the code-generation vLLM services.
# The Solver vLLM services must be fully asleep before Code vLLM starts; otherwise
# vLLM's memory profiling sees GPU memory changing underneath it and errors out.
if [ "$EMBEDDING_TYPE" = "code" ]; then
    echo "[vLLM] Verifying Solver vLLM services are fully sleeping before starting Code vLLM..."

    # Poll each Solver service's /is_sleeping endpoint until all report sleeping.
    MAX_WAIT=180  # give up after 3 minutes
    WAIT_INTERVAL=10
    WAITED=0
    # NOTE(review): assumes one Solver service per GPU on consecutive ports from
    # BASE_PORT (8 when TOTAL_GPU_COUNT is unset) — confirm against start.sh.
    N_SERVICES=${TOTAL_GPU_COUNT:-8}
    BASE_PORT=5000
    CURL_TIMEOUT=5  # seconds; stops a hung service from stalling the loop forever
    # Accepts both compact {"is_sleeping":true} and spaced {"is_sleeping": true} JSON.
    # [[:space:]] is used instead of \s, which is a GNU-only extension.
    sleeping_re='"is_sleeping":[[:space:]]*true'

    while true; do
        ALL_SLEEPING=true
        SLEEPING_COUNT=0

        for ((i = 0; i < N_SERVICES; i++)); do
            port=$((BASE_PORT + i))
            response=$(curl -s --max-time "$CURL_TIMEOUT" "http://127.0.0.1:$port/is_sleeping" 2>/dev/null)
            if [[ $response =~ $sleeping_re ]]; then
                SLEEPING_COUNT=$((SLEEPING_COUNT + 1))
            else
                ALL_SLEEPING=false
            fi
        done

        echo "[vLLM] Solver vLLM status: $SLEEPING_COUNT/$N_SERVICES services sleeping (waited ${WAITED}s)"

        if [ "$ALL_SLEEPING" = true ]; then
            echo "[vLLM] All Solver vLLM services confirmed sleeping!"
            break
        fi
        # Checked after polling so the state is re-examined once more at MAX_WAIT
        # instead of declaring a timeout right after the last sleep.
        if [ "$WAITED" -ge "$MAX_WAIT" ]; then
            break
        fi

        sleep "$WAIT_INTERVAL"
        WAITED=$((WAITED + WAIT_INTERVAL))
    done

    # Timeout is decided by the actual poll result, not by the elapsed counter.
    if [ "$ALL_SLEEPING" != true ]; then
        echo "[vLLM] WARNING: Timeout waiting for Solver vLLM services to sleep, continuing anyway..."
    fi

    # Give GPU memory an extra 30 s to settle before Code vLLM profiles it.
    echo "[vLLM] Waiting 30s for GPU memory to fully stabilize..."
    sleep 30

    # Quoting keeps argument positions stable: an unquoted empty path would shift
    # RUN_ID into start_code.sh's first argument.
    if [ -z "$QUESTION_TO_CODE_MODEL_PATH" ]; then
        echo "[vLLM] WARNING: QUESTION_TO_CODE_MODEL_PATH is empty; Code vLLM will get an empty model path" >&2
    fi
    echo "[vLLM] Starting Code vLLM services (will sleep immediately)..."
    bash vllm_service_init/start_code.sh "$QUESTION_TO_CODE_MODEL_PATH" "$RUN_ID" 0.8
    echo "Code vLLM services started with RUN_ID=$RUN_ID (in sleep mode)"
fi

# ============ Train the Questioner ============
echo "Start training questioner: $questioner_model_path -> $save_path"

# Number of training steps. With save_freq=1/save_limit=1 the checkpoint left behind
# is global_step_<steps>, which is what gets merged below, so both the trainer flag
# and the merge path derive from this one value. Defaults to 5.
questioner_max_steps=${QUESTIONER_MAX_STEPS:-5}

# Trainer overrides, kept in an array so values containing spaces stay intact.
# In time-sharing mode the Trainer uses every GPU; the external vLLM services sleep and
# the Trainer pauses while they are awake, so a high gpu_memory_utilization is used.
train_args=(
    config=examples/config.yaml
    data.max_response_length=4096
    "worker.actor.model.model_path=$questioner_model_path"
    "trainer.experiment_name=$save_path"
    "trainer.save_checkpoint_path=${STORAGE_PATH}/models/$save_path"
    trainer.total_epochs=1000
    worker.reward.reward_function=./examples/reward_function/caller_penalty.py:compute_score
    trainer.val_freq=-1
    trainer.val_before_train=false
    "trainer.n_gpus_per_node=$TRAINER_GPU_COUNT"
    data.format_prompt=./examples/format_prompt/questioner.jinja
    worker.rollout.n=4
    worker.rollout.gpu_memory_utilization=0.8
    worker.actor.global_batch_size=128
    "worker.actor.optim.lr=${QUESTIONER_LR:-1e-6}"
    "trainer.max_steps=$questioner_max_steps"
    trainer.save_freq=1
    trainer.save_limit=1
)

# Optional: resume from a checkpoint.
if [ -n "$resume_checkpoint_path" ]; then
    echo "Resuming from checkpoint: $resume_checkpoint_path"
    train_args+=("trainer.load_checkpoint_path=$resume_checkpoint_path")
fi

CUDA_VISIBLE_DEVICES=$TRAINER_GPUS python3 -m verl.trainer.main "${train_args[@]}"
train_status=$?
final_status=$train_status

sleep 10

# Merge the sharded actor checkpoint into a standalone model. --cleanup is not passed,
# so the shard files are kept.
actor_dir="${STORAGE_PATH}/models/$save_path/global_step_${questioner_max_steps}/actor"
if [ "$train_status" -eq 0 ]; then
    echo "merging model and cleanup shards"
    if ! python scripts/model_merger.py --local_dir "$actor_dir"; then
        echo "[ERROR] model merge failed for $actor_dir" >&2
        final_status=1
    fi
else
    # No final checkpoint to merge; still fall through so the services are torn down.
    echo "[ERROR] questioner training exited with status $train_status; skipping model merge" >&2
fi

sleep 20

# Best-effort teardown of the trainer and the vLLM services (pkill exit codes ignored).
# NOTE(review): `pkill python` matches every process named python that this user may
# signal, not only ones started by this script — risky on shared nodes.
pkill python
pkill -f vllm

echo "questioner training finished"
# Propagate training/merge failures to the caller instead of always exiting 0.
exit "$final_status"
