#!/usr/bin/env bash
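#
# Bootstraps a two-node (eight GPUs in total) MPI job for the qwen25vl_3b_car
# recipe.  Rank 0 starts the Ray head and runs training; every other rank
# joins as a Ray worker and stays up until rank 0 reports that training has
# finished.  The original single-node script is reused so that
# hyper-parameter overrides stay identical.
#
# Ranks rendezvous through files in a state directory, which must therefore
# be visible from every node (e.g. on a shared filesystem).  Each rank stops
# any local Ray instance and starts its own Ray node on this machine, so
# launch exactly one MPI rank per node.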

set -Eeuo pipefail

# Resolve paths relative to this file so the launcher works from any cwd.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
BASE_SCRIPT="${SCRIPT_DIR}/qwen25vl_3b_car.sh"

# The base script is run as `bash "${BASE_SCRIPT}"`, so it only needs to be a
# readable regular file; its executable bit does not matter.
if [[ ! -f "${BASE_SCRIPT}" || ! -r "${BASE_SCRIPT}" ]]; then
  echo "Base script not found or not readable: ${BASE_SCRIPT}" >&2
  exit 1
fi

# Rendezvous directory: rank 0 publishes the Ray head address and the
# training exit status here and the other ranks poll for them, so it must be
# visible from every node (override with RAY_MPI_STATE_DIR if PROJECT_ROOT is
# not on a shared filesystem).
STATE_DIR="${RAY_MPI_STATE_DIR:-${PROJECT_ROOT}/.ray_mpi_state}"
mkdir -p "${STATE_DIR}"
HEAD_FILE="${STATE_DIR}/head_two_node.txt"
EXIT_FILE="${STATE_DIR}/exit_two_node.txt"

# Rank / world size from Open MPI (OMPI_COMM_WORLD_*) or PMI (PMI_*)
# launchers; without a launcher this runs as a single rank-0 process.
RANK="${OMPI_COMM_WORLD_RANK:-${PMI_RANK:-0}}"
WORLD_SIZE="${OMPI_COMM_WORLD_SIZE:-${PMI_SIZE:-1}}"

#######################################
# Print the number of GPUs this node should register with Ray.
# Sources, in priority order:
#   1. NUM_LOCAL_GPUS, if set (explicit override);
#   2. CUDA_VISIBLE_DEVICES, if non-empty: the number of comma-separated
#      entries.  Checked before nvidia-smi because nvidia-smi ignores this
#      variable and would report every GPU on the machine;
#   3. the number of "GPU ..." lines printed by `nvidia-smi -L` (indented
#      MIG device lines are not counted);
#   4. a fallback of 4.
# Outputs: the count on stdout.
#######################################
detect_gpu_count() {
  if [[ -n "${NUM_LOCAL_GPUS:-}" ]]; then
    echo "${NUM_LOCAL_GPUS}"
    return
  fi

  if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    local -a devices
    IFS=',' read -r -a devices <<< "${CUDA_VISIBLE_DEVICES}"
    echo "${#devices[@]}"
    return
  fi

  if command -v nvidia-smi >/dev/null 2>&1; then
    local count
    # grep -c exits 1 when nothing matches (still printing "0"); `|| true`
    # keeps that, or an nvidia-smi failure under pipefail, from tripping -e.
    count="$(nvidia-smi -L 2>/dev/null | grep -c '^GPU ' || true)"
    if [[ "${count}" =~ ^[0-9]+$ ]] && (( count > 0 )); then
      echo "${count}"
      return
    fi
  fi

  echo 4
}

#######################################
# Block until a regular file appears, checking every 2 seconds.
# Arguments:
#   $1 - path to wait for
#   $2 - timeout in seconds (default 300); 0 or less waits forever
# Returns: 0 once the file exists; 1 on timeout (message on stderr).
#######################################
wait_for_file() {
  local target="$1"
  local limit="${2:-300}"
  local elapsed=0

  while :; do
    if [[ -f "${target}" ]]; then
      return 0
    fi
    if (( limit > 0 && elapsed >= limit )); then
      echo "Timeout waiting for ${target}" >&2
      return 1
    fi
    sleep 2
    elapsed=$(( elapsed + 2 ))
  done
}

#######################################
# Block until a TCP connection to HOST:PORT succeeds (i.e. the Ray head is
# accepting connections).  Each attempt uses a 2 s connect timeout, with up
# to 2 s between attempts.
# Arguments:
#   $1 - host name or IP of the Ray head
#   $2 - port
#   $3 - optional timeout in seconds (default 300, may be fractional);
#        0 or less waits forever, matching wait_for_file
# Returns: 0 once connected; 1 on timeout or if no Python is available.
#######################################
wait_for_head() {
  local host="$1"
  local port="$2"
  local timeout="${3:-300}"
  local py

  # Prefer `python` (as before) but fall back to `python3` on systems that
  # ship only the latter.
  py="$(command -v python || command -v python3 || true)"
  if [[ -z "${py}" ]]; then
    echo "wait_for_head: neither python nor python3 found on PATH" >&2
    return 1
  fi

  "${py}" - "${host}" "${port}" "${timeout}" <<'PY'
import socket
import sys
import time

host = sys.argv[1]
port = int(sys.argv[2])
timeout = float(sys.argv[3])
# A non-positive timeout means "no deadline".
deadline = time.monotonic() + timeout if timeout > 0 else None

while True:
    try:
        with socket.create_connection((host, port), timeout=2):
            sys.exit(0)
    except OSError:
        pass
    if deadline is None:
        time.sleep(2)
        continue
    remaining = deadline - time.monotonic()
    if remaining <= 0:
        sys.exit(1)
    # Never sleep past the deadline.
    time.sleep(min(2.0, remaining))
PY
}

#######################################
# Start this node as the Ray head, replacing any local Ray instance.
# Arguments:
#   $1 - IP address advertised by the head (--node-ip-address)
#   $2 - GCS port (--port)
#   $3 - dashboard port (dashboard listens on all interfaces)
#   $4 - number of GPUs to register
# Returns: the exit status of `ray start`; its stdout is discarded.
#######################################
start_head() {
  local -a ray_args=(
    --head
    --node-ip-address="$1"
    --port="$2"
    --dashboard-host="0.0.0.0"
    --dashboard-port="$3"
    --num-gpus="$4"
    --disable-usage-stats
  )

  # Best effort: stopping must not fail startup when nothing is running.
  ray stop --force >/dev/null 2>&1 || true

  ray start "${ray_args[@]}" >/dev/null
}

#######################################
# Join an existing Ray cluster as a worker, replacing any local Ray instance.
# Arguments:
#   $1 - head address as HOST:PORT (--address)
#   $2 - IP address advertised by this node (--node-ip-address)
#   $3 - number of GPUs to register
# Returns: the exit status of `ray start`; its stdout is discarded.
#######################################
start_worker() {
  local -a ray_args=(
    --address="$1"
    --node-ip-address="$2"
    --num-gpus="$3"
    --disable-usage-stats
  )

  # Best effort: stopping must not fail startup when nothing is running.
  ray stop --force >/dev/null 2>&1 || true

  ray start "${ray_args[@]}" >/dev/null
}

HEAD_PORT="${RAY_PORT:-6379}"
DASHBOARD_PORT="${RAY_DASHBOARD_PORT:-8265}"
# `hostname -I` is Linux-specific.  Without `|| true`, its failure would make
# this assignment fail under `set -e -o pipefail` and abort silently, never
# reaching the explicit error message below.
LOCAL_IP="${MPI_LOCAL_IP:-$(hostname -I 2>/dev/null | awk '{print $1}' || true)}"
NUM_GPUS="$(detect_gpu_count)"

if [[ -z "${LOCAL_IP}" ]]; then
  echo "Unable to determine local IP; set MPI_LOCAL_IP." >&2
  exit 1
fi

# NUM_LOCAL_GPUS is user-supplied; reject garbage here rather than letting
# `ray start --num-gpus` fail with a less obvious error.
if [[ ! "${NUM_GPUS}" =~ ^[0-9]+$ ]]; then
  echo "Invalid GPU count '${NUM_GPUS}'; set NUM_LOCAL_GPUS to a non-negative integer." >&2
  exit 1
fi

#######################################
# Atomically publish a one-line value to a rendezvous file: write a temp file
# in the same directory, then rename it into place.  Other ranks poll with
# `[[ -f ]]`, so a plain `echo > file` could let them read it while it is
# still empty (an empty exit status used to be treated as success).
# Arguments:
#   $1 - destination path
#   $2 - value to write
#######################################
publish_state() {
  local path="$1"
  local value="$2"
  local tmp="${path}.tmp.$$"
  printf '%s\n' "${value}" > "${tmp}"
  mv -f -- "${tmp}" "${path}"
}

if [[ "${RANK}" == "0" ]]; then
  # Clear rendezvous files from a previous run before starting the head.
  rm -f "${HEAD_FILE}" "${EXIT_FILE}"
  HEAD_IP="${RAY_HEAD_IP:-${LOCAL_IP}}"

  # On any exit path (including a `set -e` abort), publish an exit status so
  # workers blocked on EXIT_FILE are released.  Also drop HEAD_FILE so a later
  # run cannot read this run's address, and stop the local Ray head.
  # `status` stays 1 unless training actually ran.
  status=1
  rank0_cleanup() {
    if [[ ! -f "${EXIT_FILE}" ]]; then
      publish_state "${EXIT_FILE}" "${status}" || true
    fi
    rm -f "${HEAD_FILE}"
    ray stop --force >/dev/null 2>&1 || true
  }
  trap rank0_cleanup EXIT

  echo "[MPI-2node] Rank0 starting Ray head at ${HEAD_IP}:${HEAD_PORT} (GPUs=${NUM_GPUS})" >&2
  start_head "${HEAD_IP}" "${HEAD_PORT}" "${DASHBOARD_PORT}" "${NUM_GPUS}"
  publish_state "${HEAD_FILE}" "${HEAD_IP}:${HEAD_PORT}"
  sync

  export RAY_ADDRESS="auto"

  # Capture the training status without letting `set -e` abort on failure.
  status=0
  bash "${BASE_SCRIPT}" \
    ray_kwargs.ray_init.address=auto \
    ray_kwargs.ray_init._node_ip_address="${HEAD_IP}" \
    "$@" || status=$?

  publish_state "${EXIT_FILE}" "${status}"
  sync
  # Give polling workers time to read the status before the head goes away.
  sleep 5

  # rank0_cleanup (EXIT trap) stops Ray.
  exit "${status}"
else
  wait_for_file "${HEAD_FILE}" 300
  head_target="$(<"${HEAD_FILE}")"
  # Split on the LAST colon so a host containing ':' keeps its full value.
  head_host="${head_target%:*}"
  head_port="${head_target##*:}"

  if [[ -z "${head_host}" || ! "${head_port}" =~ ^[0-9]+$ ]]; then
    echo "Malformed Ray head address '${head_target}' in ${HEAD_FILE}" >&2
    exit 1
  fi

  if ! wait_for_head "${head_host}" "${head_port}"; then
    echo "Failed to reach Ray head at ${head_target}" >&2
    exit 1
  fi

  # Stop the local Ray worker on every exit path, not only the normal one.
  trap 'ray stop --force >/dev/null 2>&1 || true' EXIT

  echo "[MPI-2node] Rank ${RANK} joining Ray at ${head_target} (GPUs=${NUM_GPUS}, IP=${LOCAL_IP})" >&2
  start_worker "${head_target}" "${LOCAL_IP}" "${NUM_GPUS}"

  # Stay up until rank 0 publishes the training exit status.
  wait_for_file "${EXIT_FILE}" 0
  status="$(<"${EXIT_FILE}")"
  if [[ ! "${status}" =~ ^[0-9]+$ ]]; then
    echo "[MPI-2node] Rank ${RANK}: unreadable exit status '${status}' in ${EXIT_FILE}; assuming failure" >&2
    status=1
  fi

  exit "${status}"
fi
