#!/usr/bin/env bash
#
# Launches a multi-node Ray cluster (one Ray head on RANK 0, Ray workers on
# every other rank) and submits a verl PPO/GRPO training job from the head.
#
# Required environment (read further down): MASTER_ADDR, WORLD_SIZE, RANK.
#
# The script relies on Bash-only syntax ([[ ]], source), hence the explicit
# bash shebang above.

# Trace every command into the job log.
set -x

# NCCL transport settings; the values are specific to the target cluster.
# See the NCCL environment-variable reference for the meaning of each knob.
export NCCL_IB_HCA='mlx5_bond_1,mlx5_bond_2,mlx5_bond_3,mlx5_bond_4'
export NCCL_IB_TC='136'
export NCCL_IB_SL='5'
export NCCL_IB_GID_INDEX='3'
export NCCL_SOCKET_IFNAME='en,eth,em,bond'
export NCCL_DEBUG='INFO'

# Names under which the run is reported (passed to trainer.project_name and
# trainer.experiment_name in the training command).
project_name='Test'
exp_name='DLC_2nodes_ray'

# Weights & Biases credential.
# A WANDB_API_KEY already present in the environment takes precedence.
# NOTE(review): the fallback below is a credential committed to the source
# tree; it is kept only so existing launches keep working. Rotate it and
# supply the key through the job environment instead.
# Command tracing is paused around the export so the key is not echoed into
# the job log, then restored to whatever state it was in.
case $- in
    *x*) wandb_restore_xtrace=1; { set +x; } 2>/dev/null ;;
    *) wandb_restore_xtrace=0 ;;
esac
export WANDB_API_KEY="${WANDB_API_KEY:-50720d0546e5277bc2ac6e8ec92a78aa30ac8112}"
if [ "${wandb_restore_xtrace}" -eq 1 ]; then set -x; fi
unset wandb_restore_xtrace

# Read the cluster topology from the environment.
# Fail fast when a value is missing: with RANK unset, the numeric tests below
# error out and the node would silently take the worker branch; with
# WORLD_SIZE unset, the GPU total evaluates to 0.
: "${MASTER_ADDR:?MASTER_ADDR must be set to the address of the rank-0 node}"
: "${WORLD_SIZE:?WORLD_SIZE must be set to the number of nodes}"
: "${RANK:?RANK must be set to the index of this node}"
for required_int in WORLD_SIZE RANK; do
    case "${!required_int}" in
        '' | *[!0-9]*)
            echo "${required_int} must be a non-negative integer, got '${!required_int}'" >&2
            exit 1
            ;;
    esac
done
unset required_int
# MASTER_PORT is accepted but not referenced anywhere else in this script.
MASTER_PORT=${MASTER_PORT:-}
GPUS_PER_NODE=8  # adjust to match the actual cluster environment
CPUS_PER_TASK=64 # adjust to match the actual cluster environment
total_num_gpu=$((GPUS_PER_NODE * WORLD_SIZE))

# Ray head node port and the resulting head address that workers connect to.
RAY_PORT=6379
ip_head="${MASTER_ADDR}:${RAY_PORT}"
RAY_CLUSTER_ADDRESS="${ip_head}"
# Maximum number of seconds the head waits for all nodes to join.
RAY_NODE_WAIT_TIMEOUT=300

# Activate the conda environment and move into the verl checkout.
# Every later step uses paths relative to this directory and the interpreter
# from this environment, so stop immediately if any part of the setup fails.
conda_init_script=/nas/shared/sys2/liyizhi/conda_init.sh
if [ ! -r "${conda_init_script}" ]; then
    echo "Cannot read conda init script: ${conda_init_script}" >&2
    exit 1
fi
# The exit status of the sourced script itself is deliberately not checked:
# it only reflects the last command inside that file.
source "${conda_init_script}"
conda activate verl_public || {
    echo "Failed to activate conda environment 'verl_public'" >&2
    exit 1
}
cd /nas/shared/sys2/liyizhi/verl_public || {
    echo "Cannot cd to /nas/shared/sys2/liyizhi/verl_public" >&2
    exit 1
}

# Base model checkpoint handed to the trainer as actor_rollout_ref.model.path.
# Point this at a different directory to train another model.
MODEL_PATH='/nas/shared/sys2/liyizhi/models/Qwen2.5-Coder-7B'



# Start the local Ray process: a head on RANK 0, a worker on every other rank.
# Both are started with --block and pushed to the background so that this
# script keeps running while the Ray process stays attached to it.
if [ "$RANK" -eq 0 ]; then
    # Head node (RANK 0)
    echo "Node ${RANK}: Starting Ray HEAD at ${MASTER_ADDR}"
    ray start --head --node-ip-address="${MASTER_ADDR}" --port="${RAY_PORT}" \
        --dashboard-port=8266 \
        --num-cpus "${CPUS_PER_TASK}" --num-gpus "${GPUS_PER_NODE}" \
        --dashboard-host=0.0.0.0 \
        --block &
    # Give the head time to come up before it is queried.
    echo "Node ${RANK}: Waiting for Ray head to initialize..."
    sleep 20
else
    # Worker nodes (RANK > 0)
    echo "Node ${RANK}: Waiting for Ray head to be ready at ${ip_head}..."
    # Poll the head instead of sleeping for a fixed time: a fixed delay either
    # wastes time or is too short when the head starts slowly, in which case
    # the worker fails to connect and never joins.
    head_wait_limit=${RAY_NODE_WAIT_TIMEOUT:-300}
    head_wait_started=$(date +%s)
    until ray status --address="${ip_head}" >/dev/null 2>&1; do
        if [ $(($(date +%s) - head_wait_started)) -ge "${head_wait_limit}" ]; then
            # Best effort: still try to connect, as the script always did.
            echo "Node ${RANK}: Ray head at ${ip_head} not reachable after ${head_wait_limit}s; trying to connect anyway." >&2
            break
        fi
        sleep 5
    done

    echo "Node ${RANK}: Starting Ray WORKER, connecting to ${ip_head}"
    ray start --address "${ip_head}" \
        --num-cpus "${CPUS_PER_TASK}" --num-gpus "${GPUS_PER_NODE}" \
        --block &
    echo "Node ${RANK}: Ray worker initiated connection."
    sleep 10 # allow the worker time to register with the head
fi


if [ "$RANK" -eq 0 ]; then
    echo "Node ${RANK}: Waiting for ${WORLD_SIZE} nodes to join the Ray cluster..."
    # Counts only nodes whose "Alive" flag is set. ray.nodes() also lists
    # records of nodes that have died, so counting every entry could let the
    # gate below pass with fewer live nodes than required.
    # The count is printed after ray.shutdown() so it is the last stdout line.
    alive_nodes_py='import ray; ray.init(address="auto", ignore_reinit_error=True); alive = sum(1 for n in ray.nodes() if n["Alive"]); ray.shutdown(); print(alive)'
    wait_started=$(date +%s)
    while true; do
        current_nodes=$(python3 -c "${alive_nodes_py}" 2>/dev/null | tail -n 1)
        # Anything that is not a plain integer (connection failure, stray
        # output) is treated as "no nodes yet".
        [[ "${current_nodes}" =~ ^[0-9]+$ ]] || current_nodes=0

        if [[ "${current_nodes}" -ge "${WORLD_SIZE}" ]]; then
            echo "Node ${RANK}: Detected ${current_nodes} nodes. All nodes joined!"
            break
        fi

        waited=$(($(date +%s) - wait_started))
        if [[ "${waited}" -ge "${RAY_NODE_WAIT_TIMEOUT}" ]]; then
            echo "Node ${RANK}: Timeout waiting for all Ray nodes to join. Found ${current_nodes} nodes."
            # Dump the raw node table to help diagnose which node is missing.
            python3 -c 'import ray; ray.init(address="auto", ignore_reinit_error=True); print(ray.nodes()); ray.shutdown()' || true
            exit 1
        fi

        echo "Node ${RANK}: Currently ${current_nodes}/${WORLD_SIZE} nodes connected. Waiting..."
        sleep 5
    done

    # check_cluster_status() {
    #     echo `ray status --address $RAY_CLUSTER_ADDRESS`
    #     ray status --address $RAY_CLUSTER_ADDRESS | grep -q "0.0/${total_num_gpu}.0 GPU"
    # }

    # while ! check_cluster_status; do
    #     echo "Waiting for sufficient workers..."
    #     sleep 10  # Adjust the sleep interval as needed
    # done

    echo "Node ${RANK}: Testing Ray initialization..."
    # Informational check: connect to the running cluster and print one line
    # per node. A failure here is reported but does not stop the script.
    # The per-node values are read into plain variables before formatting:
    # backslash-escaped quotes inside an f-string expression are a Python
    # SyntaxError, which made this check fail before it could run.
    python3 -c '
import ray
try:
    ray.init(address="auto", ignore_reinit_error=True)
    print("\n=== Ray Cluster Status ===")
    nodes = ray.nodes()
    print(f"Number of nodes: {len(nodes)}")
    for node in nodes:
        status = "Alive" if node["Alive"] else "Dead"
        hostname = node.get("NodeManagerHostname", "N/A")
        resources = node.get("Resources", {})
        print(f"  Node: {hostname}, Status: {status}, Resources: {resources}")
    ray.shutdown()
    print("Ray initialization successful!")
except Exception as e:
    import traceback
    print(f"Ray initialization failed: {str(e)}")
    traceback.print_exc()
' || echo "Node ${RANK}: Ray status check command failed."
    echo "Node ${RANK}: === Ray test completed ==="

    # --- Data preprocessing (runs only on RANK 0) ---
    echo "Node ${RANK}: Starting data preprocessing..."
    # gsm8k is written to the directory the training command reads
    # (data.train_files / data.val_files). The previous target, ../data/gsm8k,
    # resolves outside the verl_public checkout, so freshly generated files
    # were never the ones used for training.
    # Failures are reported but not fatal, so a run can still proceed with
    # parquet files that are already in place.
    python3 "examples/data_preprocess/gsm8k.py" "--local_dir" "/nas/shared/sys2/liyizhi/verl_public/data/gsm8k" \
        || echo "Node ${RANK}: WARNING: gsm8k preprocessing failed; training will use whatever files already exist." >&2
    # The math dataset is not referenced by the training command in this
    # script; its output location is left unchanged.
    python3 "examples/data_preprocess/math_dataset.py" "--local_dir" "../data/math" \
        || echo "Node ${RANK}: WARNING: math dataset preprocessing failed." >&2

    # NOTE(review): no model download or test step exists here; the model is
    # taken as-is from MODEL_PATH by the training command.
    echo "Node ${RANK}: Downloading and testing base model..."

    echo "Node ${RANK}: == Data and model loading Done =="
    echo "Node ${RANK}: Preparing for training..."

    # --- Training launch (runs only on RANK 0) ---
    echo "Node ${RANK}: Starting VERL PPO training..."

    # torchrun is not used: the job is submitted once, from the head, through
    # the Ray dashboard started above (port 8266), and verl.trainer.main_ppo
    # schedules its work on the Ray cluster.
    # NOTE(review): nothing in this script sets PYTHONUNBUFFERED; presumably
    # verl/trainer/runtime_env.yaml takes care of it -- confirm.
    # Add --no-wait to "ray job submit" to return right after submission.

    # Hydra overrides for main_ppo, in the order they are passed.
    ppo_overrides=(
        algorithm.adv_estimator=grpo
        data.train_files=/nas/shared/sys2/liyizhi/verl_public/data/gsm8k/train.parquet
        data.val_files=/nas/shared/sys2/liyizhi/verl_public/data/gsm8k/test.parquet
        data.train_batch_size=1024
        data.max_prompt_length=512
        data.max_response_length=1024
        data.filter_overlong_prompts=True
        data.truncation='error'
        actor_rollout_ref.model.path="${MODEL_PATH}"
        actor_rollout_ref.actor.optim.lr=1e-6
        actor_rollout_ref.model.use_remove_padding=True
        actor_rollout_ref.actor.ppo_mini_batch_size=256
        actor_rollout_ref.actor.use_dynamic_bsz=True
        actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000
        actor_rollout_ref.actor.use_kl_loss=True
        actor_rollout_ref.actor.kl_loss_coef=0.001
        actor_rollout_ref.actor.kl_loss_type=low_var_kl
        actor_rollout_ref.actor.entropy_coeff=0
        actor_rollout_ref.model.enable_gradient_checkpointing=True
        actor_rollout_ref.actor.fsdp_config.param_offload=False
        actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
        actor_rollout_ref.rollout.tensor_model_parallel_size=2
        actor_rollout_ref.rollout.name=vllm
        actor_rollout_ref.rollout.gpu_memory_utilization=0.6
        actor_rollout_ref.rollout.n=5
        actor_rollout_ref.ref.fsdp_config.param_offload=True
        actor_rollout_ref.rollout.enforce_eager=False
        actor_rollout_ref.rollout.free_cache_engine=False
        algorithm.use_kl_in_reward=False
        trainer.critic_warmup=0
        trainer.logger=['console','wandb']
        trainer.project_name="${project_name}"
        trainer.experiment_name="${exp_name}"
        trainer.val_before_train=True
        trainer.n_gpus_per_node="${GPUS_PER_NODE}"
        trainer.nnodes="${WORLD_SIZE}"
        trainer.save_freq=-1
        trainer.test_freq=5
        trainer.total_epochs=15
    )

    ray job submit --address="http://127.0.0.1:8266" \
        --runtime-env=verl/trainer/runtime_env.yaml \
        -- \
        python3 -m verl.trainer.main_ppo "${ppo_overrides[@]}" | tee debug.log
    # The status of a pipeline is that of its last command (tee), which hides
    # a failed submission or job. Read the status of "ray job submit" itself.
    train_status=${PIPESTATUS[0]}
    if [ "${train_status}" -ne 0 ]; then
        echo "Node ${RANK}: Training failed (ray job submit exited with ${train_status}); see debug.log." >&2
        exit "${train_status}"
    fi

    echo "Node ${RANK}: Training finished."

else
    # --- Worker nodes: stay alive while the head drives training ---
    echo "Node ${RANK}: Worker node waiting for training to complete..."
    # The Ray worker was started in the background earlier in this script.
    # Block here forever so the script on this node does not return while the
    # job is running.
    # NOTE(review): presumably the job orchestrator treats a node whose script
    # has exited as finished -- confirm; if so, this node has to be stopped
    # from outside once training is done.
    # `wait` on the background `ray start` would be an alternative, but it may
    # return immediately if that process is not a live child of this shell.
    sleep infinity
    # Only reached if the sleep above is interrupted.
    echo "Node ${RANK}: Worker finished waiting."

fi

# Final status line, then report success to the caller.
printf '%s\n' "Node ${RANK}: Script execution finished."
exit 0
