#!/usr/bin/env bash
#
# Helper launcher for running the qwen25vl_3b_car recipe under an MPI job on a
# single node (GPU count auto-detected, falling back to 8).  Only MPI rank 0
# performs work: it starts a local Ray head, forwards the training call to the
# base script qwen25vl_3b_car.sh (forcing it to attach to that Ray cluster),
# then stops Ray.  Other MPI ranks block until training finishes and exit with
# the same status, so the MPI job ends cleanly.

set -Eeuo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
BASE_SCRIPT="${SCRIPT_DIR}/qwen25vl_3b_car.sh"

# The base script is invoked as `bash "${BASE_SCRIPT}"`, so it only needs to be
# a readable regular file; the executable bit is not required.
if [[ ! -f "${BASE_SCRIPT}" || ! -r "${BASE_SCRIPT}" ]]; then
  echo "Base script not found or not readable: ${BASE_SCRIPT}" >&2
  exit 1
fi

# Shared state directory (override via RAY_MPI_STATE_DIR if necessary).  Every
# rank must resolve the same path: rank 0 writes these files, others poll them.
STATE_DIR="${RAY_MPI_STATE_DIR:-${PROJECT_ROOT}/.ray_mpi_state}"
mkdir -p "${STATE_DIR}"
HEAD_FILE="${STATE_DIR}/head_single_node.txt"
EXIT_FILE="${STATE_DIR}/exit_single_node.txt"

# Global rank of this process, taken from the first launcher variable that is
# set: Open MPI, PMI (MPICH / Intel MPI), PMIx-based launchers, Slurm srun.
# Defaults to 0 for a plain, non-MPI invocation.
RANK="${OMPI_COMM_WORLD_RANK:-${PMI_RANK:-${PMIX_RANK:-${SLURM_PROCID:-0}}}}"

#######################################
# Print the number of GPUs the Ray head should advertise on this node.
# Resolution order:
#   1. NUM_LOCAL_GPUS, if set and non-empty (explicit override, printed as-is).
#   2. CUDA_VISIBLE_DEVICES, if set and non-empty: number of comma-separated
#      entries.  Checked before nvidia-smi because nvidia-smi lists every GPU
#      on the node regardless of CUDA_VISIBLE_DEVICES and would over-report
#      for a job restricted to a subset of devices.
#   3. Number of physical "GPU N:" lines from `nvidia-smi -L` (indented MIG
#      sub-device lines are not counted), if the tool exists and reports > 0.
#   4. Fallback: 8.
# Globals:   NUM_LOCAL_GPUS, CUDA_VISIBLE_DEVICES (read)
# Outputs:   the GPU count on stdout
#######################################
detect_gpu_count() {
  if [[ -n "${NUM_LOCAL_GPUS:-}" ]]; then
    echo "${NUM_LOCAL_GPUS}"
    return
  fi

  if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    local -a visible
    IFS=',' read -r -a visible <<< "${CUDA_VISIBLE_DEVICES}"
    echo "${#visible[@]}"
    return
  fi

  if command -v nvidia-smi >/dev/null 2>&1; then
    local count
    # `|| true`: grep -c exits 1 on zero matches and nvidia-smi may fail; under
    # pipefail (and inherit_errexit) that must not abort the caller.  The
    # regex check below rejects anything unusable.
    count="$(nvidia-smi -L 2>/dev/null | grep -c '^GPU ')" || true
    if [[ "${count}" =~ ^[0-9]+$ ]] && (( count > 0 )); then
      echo "${count}"
      return
    fi
  fi

  echo 8
}

#######################################
# (Re)start a Ray head on this machine.
# Any Ray instance already running here is force-stopped first (failures of
# that step are ignored); stdout of `ray start` is discarded, stderr is kept.
# Arguments:
#   $1 - node IP address (--node-ip-address)
#   $2 - head port (--port)
#   $3 - dashboard port (--dashboard-port); dashboard binds to 0.0.0.0
#   $4 - number of GPUs to advertise (--num-gpus)
# Returns:   exit status of `ray start`
#######################################
start_head() {
  local -r node_ip="$1" head_port="$2" dashboard_port="$3" gpu_count="$4"
  local -a head_opts=(
    --head
    "--node-ip-address=${node_ip}"
    "--port=${head_port}"
    "--dashboard-host=0.0.0.0"
    "--dashboard-port=${dashboard_port}"
    "--num-gpus=${gpu_count}"
    --disable-usage-stats
  )

  # Best-effort cleanup of a previous Ray instance on this node.
  ray stop --force >/dev/null 2>&1 || true
  ray start "${head_opts[@]}" >/dev/null
}

# Ray head endpoint and resources; only rank 0 uses them (other ranks just wait).
HEAD_PORT="${RAY_PORT:-6379}"
DASHBOARD_PORT="${RAY_DASHBOARD_PORT:-8265}"
# Default the head IP to the first address printed by `hostname -I`.  Under
# pipefail a failing `hostname -I` gives the assignment a non-zero status, which
# would make `set -e` abort every rank silently; `|| true` keeps the (empty)
# result so rank 0's explicit "Unable to determine head IP" check reports it.
HEAD_IP="${RAY_HEAD_IP:-}"
if [[ -z "${HEAD_IP}" ]]; then
  HEAD_IP="$(hostname -I | awk '{print $1}')" || true
fi
NUM_GPUS="$(detect_gpu_count)"

if [[ "${RANK}" == "0" ]]; then
  # Rank 0: owns the Ray head and the training run.
  rm -f "${HEAD_FILE}" "${EXIT_FILE}"

  # Set to 1 once we attempt to start Ray, so the EXIT trap knows to stop it.
  ray_head_launched=0

  #######################################
  # EXIT trap for rank 0: publish this shell's exit status to EXIT_FILE, then
  # stop the local Ray head if a start was attempted.  Running it as a trap
  # releases the waiting ranks on every exit path of this shell (the early
  # `exit 1` below, a `set -e` abort when `ray start` fails, or normal
  # completion) instead of leaving them polling forever.
  # Globals:   EXIT_FILE, ray_head_launched (read)
  #######################################
  publish_exit_status() {
    local rc=$?
    local tmp="${EXIT_FILE}.tmp.$$"
    trap - EXIT
    # Write-then-rename so a polling rank never reads a half-written file.
    { printf '%s\n' "${rc}" > "${tmp}" && mv -f -- "${tmp}" "${EXIT_FILE}"; } \
      || echo "[MPI-1node] WARNING: could not write ${EXIT_FILE}; other ranks may hang" >&2
    sync || true
    if (( ray_head_launched )); then
      ray stop --force >/dev/null 2>&1 || true
    fi
    exit "${rc}"
  }
  trap publish_exit_status EXIT

  if [[ -z "${HEAD_IP}" ]]; then
    echo "Unable to determine head IP; set RAY_HEAD_IP explicitly." >&2
    exit 1
  fi

  echo "[MPI-1node] Starting Ray head at ${HEAD_IP}:${HEAD_PORT} (GPUs=${NUM_GPUS})" >&2
  ray_head_launched=1
  start_head "${HEAD_IP}" "${HEAD_PORT}" "${DASHBOARD_PORT}" "${NUM_GPUS}"
  echo "${HEAD_IP}:${HEAD_PORT}" > "${HEAD_FILE}"
  sync

  export RAY_ADDRESS="auto"

  # Capture the training status without tripping `set -e`.
  status=0
  bash "${BASE_SCRIPT}" \
    ray_kwargs.ray_init.address=auto \
    ray_kwargs.ray_init._node_ip_address="${HEAD_IP}" \
    "$@" || status=$?

  # The EXIT trap publishes ${status} to EXIT_FILE and stops Ray.
  exit "${status}"
else
  echo "MPI rank ${RANK} waiting for training to finish..." >&2
  # Defer the first check by one poll interval so rank 0 has time to delete an
  # EXIT_FILE left over from a previous run; otherwise this rank could exit at
  # once with a stale status (a stale non-zero status may also make the MPI
  # launcher abort the live job).
  # NOTE(review): this only narrows the race; keying the state files by a
  # per-job identifier would close it.
  sleep 10
  until [[ -f "${EXIT_FILE}" ]]; do
    sleep 10
  done
  status="$(<"${EXIT_FILE}")"
  # Guard `exit` against a corrupt or foreign status file.
  if [[ ! "${status:-0}" =~ ^[0-9]+$ ]]; then
    echo "MPI rank ${RANK}: unexpected content in ${EXIT_FILE}: '${status}'" >&2
    status=1
  fi
  exit "${status:-0}"
fi
