#!/usr/bin/env bash

# Positional arguments: $1 is the config, $2 is the checkpoint.
# Both are dropped from "$@" afterwards so that only the optional
# arguments remain for the code below.
CONFIG="${1}"
CHECKPOINT="${2}"
shift 2

# Multi-node settings are taken from environment variables; each one falls
# back to a single-node default when its variable is unset or empty.
#   ARNOLD_WORKER_NUM   -> NNODES       (default 1)
#   ARNOLD_ID           -> NODE_RANK    (default 0)
#   METIS_WORKER_0_HOST -> MASTER_ADDR  (default 127.0.0.1)
#   METIS_WORKER_0_PORT -> MASTER_PORT  (default 29500)
MASTER_ADDR="${METIS_WORKER_0_HOST:-127.0.0.1}"
MASTER_PORT="${METIS_WORKER_0_PORT:-29500}"
NNODES="${ARNOLD_WORKER_NUM:-1}"
NODE_RANK="${ARNOLD_ID:-0}"

#######################################
# Print the number of accelerator devices to use on this node.
# Probes are tried in order and the first one reporting a positive count wins:
#   1) nvidia-smi (NVIDIA GPUs)
#   2) npu-smi    (Huawei Ascend NPUs)
#   3) the comma-separated ASCEND_VISIBLE_DEVICES environment variable
# If no probe reports a device the function falls back to 1.
# Globals:   ASCEND_VISIBLE_DEVICES (read)
# Arguments: none
# Outputs:   the device count (always >= 1) on stdout
#######################################
_detect_local_devices() {
  local count=0
  local listing=""
  local -a ids=()

  if command -v nvidia-smi &> /dev/null; then
    # Only trust the listing when nvidia-smi itself succeeded; on failure its
    # error text would otherwise be counted as if each line were a GPU.
    if listing=$(nvidia-smi --query-gpu=name --format=csv,noheader); then
      # grep -c prints 0 (and exits non-zero) when there are no non-empty lines.
      count=$(grep -c . <<< "$listing")
    fi
  fi

  # A tool that is installed but reports nothing must not block the next probe.
  if (( count == 0 )) && command -v npu-smi &> /dev/null; then
    count=$(npu-smi info --list | grep -c "Device Id")
  fi

  if (( count == 0 )) && [[ -n "${ASCEND_VISIBLE_DEVICES-}" ]]; then
    # Split on commas; a trailing comma does not add a phantom device.
    IFS=',' read -r -a ids <<< "$ASCEND_VISIBLE_DEVICES"
    count=${#ids[@]}
  fi

  # Never hand torchrun a zero (or empty) process count.
  (( count > 0 )) || count=1
  printf '%s\n' "$count"
}

# Resolve DEVICES: if the next argument is a plain non-negative integer it is
# taken as the per-node device count and consumed; otherwise auto-detect.
if [[ $# -ge 1 && "${1-}" =~ ^[0-9]+$ ]]; then
  DEVICES=$1
  shift
else
  DEVICES=$(_detect_local_devices)
fi

# Every argument still left on the command line is forwarded to test.py.
TEST_ARGS=("$@")

# Global device count across all nodes (optional; only for logging or
# dynamic adjustment).
TOTAL_DEVICES=$(( NNODES * DEVICES ))

# Print the effective launch settings as an aligned "label = value" table.
printf '%s\n' "Launching inference with:"
printf '  %-13s = %s\n' \
  "CONFIG"        "$CONFIG" \
  "CHECKPOINT"    "$CHECKPOINT" \
  "NNODES"        "$NNODES" \
  "NODE_RANK"     "$NODE_RANK" \
  "MASTER_ADDR"   "$MASTER_ADDR" \
  "MASTER_PORT"   "$MASTER_PORT" \
  "DEVICES/node"  "$DEVICES" \
  "TOTAL_DEVICES" "$TOTAL_DEVICES" \
  "EXTRA ARGS"    "${TEST_ARGS[*]}"

# Launch test.py (located next to this script) under torchrun, starting
# DEVICES worker processes on this node. CONFIG and CHECKPOINT are passed as
# positional arguments, followed by "--launcher pytorch" and any extra args.
SCRIPT_DIR=$(dirname "$0")

# The parent of the script directory is prepended to PYTHONPATH. The inherited
# value is appended only when it is non-empty: an unconditional ":$PYTHONPATH"
# leaves a trailing ":" when the variable is unset, i.e. an empty search-path
# entry (commonly resolved by Python to the current working directory).
PYTHONPATH="${SCRIPT_DIR}/..${PYTHONPATH:+:${PYTHONPATH}}" \
torchrun \
  --nnodes="$NNODES" \
  --nproc_per_node="$DEVICES" \
  --node_rank="$NODE_RANK" \
  --master_addr="$MASTER_ADDR" \
  --master_port="$MASTER_PORT" \
  "${SCRIPT_DIR}/test.py" \
  "$CONFIG" \
  "$CHECKPOINT" \
  --launcher pytorch \
  "${TEST_ARGS[@]}"
