#!/usr/bin/env bash

# The first positional argument is the config handed to train.py. It is shifted
# off so the arguments that follow can be examined in order.
CONFIG="${1}"
shift

# Multi-node rendezvous settings. Each value comes from the environment when it
# is set and non-empty; otherwise the single-node / localhost default is used.
MASTER_ADDR="${METIS_WORKER_0_HOST:-127.0.0.1}"
MASTER_PORT="${METIS_WORKER_0_PORT:-29500}"
NNODES="${ARNOLD_WORKER_NUM:-1}"
NODE_RANK="${ARNOLD_ID:-0}"

#######################################
# Record a probe result in the globals if it lists at least one device.
# Blank lines are dropped first: piping an empty string through `wc -l`
# reports 1, which would turn "no devices found" into "one device".
# Globals:   DETECTED_TYPE, MODEL_LIST, DEVICES (written on success only)
# Arguments: $1 - label stored in DETECTED_TYPE
#            $2 - newline-separated device listing produced by a probe
# Returns:   0 if the listing was accepted, 1 if it had no non-blank line
#######################################
_accept_probe() {
  local label=$1 listing
  # grep exits 1 when nothing matches; that is an expected outcome here.
  listing=$(grep '[^[:space:]]' <<<"$2" || true)
  [[ -n "$listing" ]] || return 1
  DETECTED_TYPE=$label
  MODEL_LIST=$listing
  # Every remaining line is non-blank, so the line count is the device count.
  DEVICES=$(grep -c '' <<<"$listing")
}

#######################################
# Auto-detect the local accelerators.
# Probes run in this order: nvidia-smi, npu-smi, ASCEND_VISIBLE_DEVICES. A
# probe that yields no devices (tool fails or prints nothing usable) is
# skipped so the next one gets a chance. If none succeeds, fall back to 1.
# Globals:   DEVICES, DETECTED_TYPE, MODEL_LIST (written)
#            ASCEND_VISIBLE_DEVICES (read)
# Returns:   0 always
#######################################
detect_devices() {
  local raw

  if command -v nvidia-smi &>/dev/null; then
    # One GPU name per line. Output from a failing nvidia-smi is discarded so
    # that an error message printed on stdout is never counted as a GPU.
    raw=$(nvidia-smi --query-gpu=name --format=csv,noheader) || raw=""
    _accept_probe "NVIDIA GPU" "$raw" && return 0
  fi

  if command -v npu-smi &>/dev/null; then
    # Keep only the "Device Id" lines; no match simply means no devices.
    raw=$(npu-smi info --list | grep "Device Id" || true)
    _accept_probe "Ascend NPU" "$raw" && return 0
  fi

  if [[ -n "${ASCEND_VISIBLE_DEVICES-}" ]]; then
    # Comma-separated list -> one entry per line.
    raw=$(tr ',' '\n' <<<"$ASCEND_VISIBLE_DEVICES")
    _accept_probe "ASCEND_VISIBLE_DEVICES" "$raw" && return 0
  fi

  # Nothing could be detected: fall back to a single process.
  DEVICES=1
  DETECTED_TYPE="fallback"
  MODEL_LIST="unknown"
}

# Device count: if the next argument is all digits, take it as the number of
# devices and shift it off; otherwise auto-detect.
if [ $# -ge 1 ] && [[ "$1" =~ ^[0-9]+$ ]]; then
  DEVICES=$1
  shift
  DETECTED_TYPE="user-specified"
  MODEL_LIST=""
else
  detect_devices
fi

# Summarize the device selection made above.
printf '%s\n' "================ Device Info ================"
printf '  Device type   : %s\n' "${DETECTED_TYPE}"
printf '  Device count  : %s\n' "${DEVICES}"
# The per-device listing is printed only when MODEL_LIST is non-empty; each of
# its lines becomes one indented bullet.
if [[ -n ${MODEL_LIST} ]]; then
  printf '%s\n' "  Model list    :"
  sed 's/^/    - /' <<<"${MODEL_LIST}"
fi
printf '%s\n' "============================================="

# Every argument still left on the command line is forwarded to train.py.
TRAIN_ARGS=("$@")

# Refuse to launch without a config; otherwise train.py would be started with
# an empty string as its first argument.
if [[ -z "${CONFIG:-}" ]]; then
  printf 'Usage: %s CONFIG [NUM_DEVICES] [TRAIN_ARGS...]\n' "$0" >&2
  exit 2
fi

# Directory containing this script; train.py is looked up next to it.
SCRIPT_DIR=$(dirname -- "$0")

# Launch distributed training with the script's parent directory first on
# PYTHONPATH. ${PYTHONPATH:+...} appends the previous value only when there is
# one: a bare trailing ':' would add an empty entry to the search path.
PYTHONPATH="${SCRIPT_DIR}/..${PYTHONPATH:+:${PYTHONPATH}}" \
torchrun \
  --nnodes="$NNODES" \
  --nproc_per_node="$DEVICES" \
  --node_rank="$NODE_RANK" \
  --master_addr="$MASTER_ADDR" \
  --master_port="$MASTER_PORT" \
  "${SCRIPT_DIR}/train.py" \
  "$CONFIG" \
  --launcher pytorch \
  "${TRAIN_ARGS[@]}"
