#!/bin/bash

# Connection settings, populated from the command line by parse_args.
MASTER_ADDR=""
# NOTE(review): PORT is parsed but not referenced in the visible part of this
# script (the health-check ports are a fixed list) -- confirm whether it is needed.
PORT=""
RANK=""

#######################################
# Parse the launcher's command-line flags into globals.
# Recognized flags: --master_addr VALUE, --port VALUE, --rank VALUE.
# Any other argument is silently skipped.
# Globals:   MASTER_ADDR, PORT, RANK (written)
# Arguments: the script's positional parameters
# Outputs:   an error message on stderr when a flag is missing its value
# Returns:   0 on success; exits the shell with status 2 on a missing value
#######################################
parse_args() {
    while [[ $# -gt 0 ]]; do
        case "$1" in
            --master_addr | --port | --rank)
                # `shift 2` with only one argument left fails WITHOUT shifting,
                # so an unchecked trailing flag would make this loop spin forever.
                if [[ $# -lt 2 ]]; then
                    printf 'error: option %s requires a value\n' "$1" >&2
                    exit 2
                fi
                case "$1" in
                    --master_addr) MASTER_ADDR=$2 ;;
                    --port) PORT=$2 ;;
                    --rank) RANK=$2 ;;
                esac
                shift 2
                ;;
            *)
                # Unknown arguments are ignored rather than rejected.
                shift
                ;;
        esac
    done
}

parse_args "$@"

# Block until every remote vLLM health endpoint on MASTER_ADDR answers HTTP 200.

# Without an address the probe URL is malformed and the loop below could never
# succeed, so fail fast instead of hanging forever.
if [[ -z "$MASTER_ADDR" ]]; then
    printf 'error: --master_addr is required\n' >&2
    exit 1
fi

echo "[Node$RANK] Waiting for remote vLLM at $MASTER_ADDR ..."

# Ports probed on the master host, checked in order.
ports=(5000 5001 5002 5003 5004 5005 5006 5007)

for port in "${ports[@]}"; do
    # Retry every 2 seconds, indefinitely, until this port reports healthy.
    while true; do
        # The timeouts bound each probe so a black-holed connection cannot stall
        # the loop; curl prints 000 as the status code when no response arrives.
        status=$(curl -s -o /dev/null \
            --connect-timeout 5 --max-time 10 \
            -w "%{http_code}" "http://${MASTER_ADDR}:${port}/health")
        if [[ "$status" == "200" ]]; then
            echo "[Node$RANK] vLLM port $port READY."
            break
        fi
        sleep 2
    done
done


# ---- Run environment --------------------------------------------------------
# Re-export the master address under the name RANK0_ADDR.
export RANK0_ADDR="$MASTER_ADDR"

# Run identifier taken from the clock (`date +%s%N`).
RUN_ID=$(date +%s%N)
export RUN_ID
export STORAGE_PATH="test"
export VLLM_DISABLE_COMPILE_CACHE=1

# ---- Model naming -----------------------------------------------------------
export Model_abbr="Qwen3-8B-Base"
export Ch_model='test'
questioner_model_dir="${Model_abbr}_questioner_v3_random"

# ---- Dataset files ----------------------------------------------------------
export train_data_set='train.jsonl'
export val_data_set='val.json'

# ---- Output directories (created up front) ----------------------------------
temp_results_dir="${STORAGE_PATH}/temp_results"
model_dir="${STORAGE_PATH}/models"
mkdir -p "${temp_results_dir}" "${model_dir}/${questioner_model_dir}"

printf 'Start training questioner: %s -> %s\n' "$Ch_model" "$questioner_model_dir"


# Launch questioner training through verl.trainer.main with GPUs 0-7 visible.
# The key=value overrides are collected in an array, one per element, and are
# passed in the order listed.
trainer_overrides=(
    config=examples/config.yaml
    data.max_response_length=4096
    "worker.actor.model.model_path=${Ch_model}"
    "trainer.experiment_name=${questioner_model_dir}"
    "trainer.save_checkpoint_path=${STORAGE_PATH}/models/${questioner_model_dir}"
    worker.reward.reward_function=./examples/reward_function/caller_penalty.py:compute_score
    "data.train_files=${train_data_set}"
    "data.val_files=${val_data_set}"
    trainer.total_epochs=1
    trainer.val_freq=-1
    trainer.n_gpus_per_node=8
    data.format_prompt=./examples/format_prompt/questioner_w_memory.jinja
    data.rollout_batch_size=32
    worker.actor.global_batch_size=32
    worker.rollout.n=8
    worker.actor.micro_batch_size_per_device_for_update=8
    worker.actor.rollout_n=8
    data.val_batch_size=4
    data.answer_key=problem
    worker.actor.use_sparse_kl=false
    worker.rollout.temperature=0.7
    trainer.save_freq=5
)

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m verl.trainer.main "${trainer_overrides[@]}"

# Brief pause before post-processing (presumably lets the trainer's processes
# finish shutting down -- TODO confirm).
sleep 5

# Merge the saved actor checkpoints.
echo "merging model"

# The checkpoints live under the directory the trainer was told to save into
# (trainer.save_checkpoint_path). The previous code referenced $save_path, which
# is never assigned in this script, so the merger was pointed at the wrong path.
checkpoint_root="${STORAGE_PATH}/models/${questioner_model_dir}"

# NOTE(review): steps 30 and 60 are hard-coded; which steps actually exist
# depends on trainer.save_freq and the number of training steps -- confirm.
for step in 30 60; do
    actor_dir="${checkpoint_root}/global_step_${step}/actor"
    if [[ -d "$actor_dir" ]]; then
        # python3 matches the interpreter used to launch the trainer.
        python3 scripts/model_merger.py --local_dir "$actor_dir"
    else
        printf 'warning: checkpoint directory %s not found, skipping merge\n' "$actor_dir" >&2
    fi
done

# NOTE(review): this signals every process whose name matches "python" that the
# current user may signal, not only processes started by this script -- confirm
# this is acceptable on shared machines.
pkill python










