#!/usr/bin/env bash
CONFIG=$1
shift

# —— 多机配置（不变） ——
NNODES=${ARNOLD_WORKER_NUM:-1}
NODE_RANK=${ARNOLD_ID:-0}
MASTER_ADDR=${METIS_WORKER_0_HOST:-"127.0.0.1"}
MASTER_PORT=${METIS_WORKER_0_PORT:-29500}

# —— 设备探测 ——
npu-smi info

if [ $# -ge 1 ] && [[ "$1" =~ ^[0-9]+$ ]]; then
  DEVICES=$1; shift
  DETECTED_TYPE="user-specified"
  MODEL_LIST=""
else
  if command -v npu-smi &>/dev/null; then
    MODEL_LIST=$(
      npu-smi info \
      | grep -E '^\|\s*[0-9]+[[:space:]]+[A-Za-z]' \
      | sed -E 's/^\|\s*([0-9]+).*/\1/'
    )
    DEVICES=$(echo "$MODEL_LIST" | wc -l | tr -d ' ')
    DETECTED_TYPE="Ascend NPU"
  elif [ -n "${ASCEND_VISIBLE_DEVICES-}" ]; then
    MODEL_LIST=$(echo "$ASCEND_VISIBLE_DEVICES" | tr ',' '\n')
    DEVICES=$(echo "$MODEL_LIST" | wc -l | tr -d ' ')
    DETECTED_TYPE="ASCEND_VISIBLE_DEVICES"
  else
    DEVICES=1
    DETECTED_TYPE="fallback"
    MODEL_LIST="0"
  fi
fi

# —— 把真正的 ID 列表导出给子进程 ——  
# 这样 torchrun 分配 local_rank=0,1,… 时，都会按顺序映射到 MODEL_LIST 中的真实 NPU ID。
if [ -n "$MODEL_LIST" ]; then
  # 逗号拼成一行： e.g. "0,3"
  VISIBLE=$(echo "$MODEL_LIST" | paste -sd, -)
  export ASCEND_VISIBLE_DEVICES="$VISIBLE"
fi

# —— 输出检测到的信息 ——
echo "================ Device Info ================"
echo "  Device type   : ${DETECTED_TYPE}"
echo "  Device count  : ${DEVICES}"
if [ -n "${MODEL_LIST}" ]; then
  echo "  Model list    :"
  echo "${MODEL_LIST}" | sed 's/^/    - /'
  echo "  -> exported ASCEND_VISIBLE_DEVICES=${ASCEND_VISIBLE_DEVICES}"
fi
echo "============================================="

# —— 启动分布式训练 ——
TRAIN_ARGS=("$@")
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
torchrun \
  --nnodes="$NNODES" \
  --nproc_per_node="$DEVICES" \
  --node_rank="$NODE_RANK" \
  --master_addr="$MASTER_ADDR" \
  --master_port="$MASTER_PORT" \
  "$(dirname "$0")/train.py" \
  "$CONFIG" \
  --launcher pytorch \
  "${TRAIN_ARGS[@]}"
