#!/usr/bin/env bash
set -euo pipefail

# Usage:
#   bash run_sft_with_mempeak.sh
#   bash run_sft_with_mempeak.sh 0            # monitor GPU 0 only
#   bash run_sft_with_mempeak.sh all 0.2      # monitor all GPUs, sample every 0.2s
#
# Arguments:
#   $1: GPU selector: "all" (default) or a specific GPU index (e.g. 0 / 1)
#   $2: sampling interval in seconds (default 0.1)
# Positional arguments (both optional):
#   $1 -> GPU_SEL:  "all", or a single non-negative GPU index.
#   $2 -> INTERVAL: number of seconds between samples (handed to `sleep`).
GPU_SEL="${1:-all}"
INTERVAL="${2:-0.1}"

# Reject malformed arguments before anything is started. Without this, a bad
# interval only surfaces when `sleep` fails inside the monitoring loop (after
# the training job is already running), and a selector that is not a plain
# index never matches a key of the per-GPU peak table.
GPU_SEL_RE='^(all|[0-9]+)$'
INTERVAL_RE='^([0-9]+(\.[0-9]*)?|\.[0-9]+)$'
if [[ ! "$GPU_SEL" =~ $GPU_SEL_RE ]]; then
  echo "[ERROR] Invalid GPU selector '$GPU_SEL': expected 'all' or a GPU index." >&2
  exit 2
fi
if [[ ! "$INTERVAL" =~ $INTERVAL_RE ]]; then
  echo "[ERROR] Invalid sampling interval '$INTERVAL': expected a non-negative number of seconds." >&2
  exit 2
fi

# Bail out early (status 127) when nvidia-smi is not on PATH; every sample
# taken by this script comes from that tool.
command -v nvidia-smi >/dev/null 2>&1 || {
  echo "[ERROR] 未找到 nvidia-smi。这个脚本用于 NVIDIA GPU 环境。" >&2
  exit 127
}

# Log directory and CSV file. Both can be overridden through the environment
# (LOG_DIR, LOG_FILE); by default the file name carries a launch timestamp.
: "${LOG_DIR:=./logs}"
mkdir -p "$LOG_DIR"
TS="$(date +%Y%m%d_%H%M%S)"
: "${LOG_FILE:=$LOG_DIR/gpu_mem_${TS}.csv}"

# Start the file with the CSV header (truncates any existing content).
printf '%s\n' "timestamp,gpu_index,mem_used_mib,peak_gpu_mib,peak_overall_mib" > "$LOG_FILE"

# Launch the training script in the background and remember its PID so the
# monitor can poll it and later collect its exit status.
# The path defaults to scripts/run_weight_sft.sh and may be overridden with the
# SFT_SCRIPT environment variable; the log line prints the path that is
# actually executed.
SFT_SCRIPT="${SFT_SCRIPT:-scripts/run_weight_sft.sh}"
echo "[INFO] Launch: bash $SFT_SCRIPT"
bash "$SFT_SCRIPT" &
SFT_PID=$!

# Signal handling: on Ctrl+C / kill, forward a SIGTERM to the background
# training process whose PID is stored in SFT_PID.
cleanup() {
  printf '%s\n' "[INFO] Caught signal, stopping sft.sh (pid=$SFT_PID) ..." >&2
  # The process may already be gone; a failed kill is deliberately ignored.
  if ! kill -TERM "$SFT_PID" 2>/dev/null; then
    :
  fi
}
trap cleanup INT
trap cleanup TERM

# Peak bookkeeping: the highest memory.used value seen across all samples, and
# an associative array holding the highest value seen per GPU index.
PEAK_OVERALL=0
typeset -A PEAK_BY_GPU

#######################################
# Discover the GPU indices reported by nvidia-smi and seed the peak table.
# Globals:
#   ALL_GPUS     (written) indexed array of GPU indices, in nvidia-smi order
#   PEAK_BY_GPU  (written) one entry per GPU index, initialised to 0
# Outputs:
#   Error message on stderr when no GPU index could be read.
# Returns:
#   Exits the script with status 2 when the GPU list is empty.
#######################################
init_gpu_list() {
  local line g
  ALL_GPUS=()
  # Keep only purely numeric lines: anything else on stdout (for example a
  # diagnostic message printed instead of a device list) is not a GPU index
  # and must not defeat the empty-list check below.
  while IFS= read -r line || [[ -n "$line" ]]; do
    if [[ "$line" =~ ^[0-9]+$ ]]; then
      ALL_GPUS+=("$line")
    fi
  done < <(nvidia-smi --query-gpu=index --format=csv,noheader,nounits | tr -d ' ')

  if (( ${#ALL_GPUS[@]} == 0 )); then
    echo "[ERROR] nvidia-smi 未返回 GPU 列表。" >&2
    exit 2
  fi

  for g in "${ALL_GPUS[@]}"; do
    PEAK_BY_GPU["$g"]=0
  done
}

# Print one "idx,used" row per GPU reported by nvidia-smi, with every space
# stripped from the CSV output.
query_all() {
  local fields="index,memory.used"
  nvidia-smi --query-gpu="$fields" --format=csv,noheader,nounits | tr -d ' '
}

# Print the "idx,used" row(s) for the GPU selected by $1 (passed to
# `nvidia-smi -i`), with every space stripped from the CSV output.
query_one() {
  local gpu_id="$1"
  local fields="index,memory.used"
  nvidia-smi -i "$gpu_id" --query-gpu="$fields" --format=csv,noheader,nounits | tr -d ' '
}

# Populate the GPU list / peak table, then announce what is being monitored
# and where samples are written.
init_gpu_list

printf '[INFO] Monitoring GPU memory... (GPU_SEL=%s, interval=%ss)\n' "$GPU_SEL" "$INTERVAL"
printf '[INFO] Log: %s\n' "$LOG_FILE"

#######################################
# Consume "idx,used" rows from stdin, update the peak trackers and append one
# CSV line per valid row to the log.
# Globals:
#   PEAK_BY_GPU   (read/written) per-GPU peak, keyed by GPU index
#   PEAK_OVERALL  (read/written) highest value seen on any GPU
#   LOG_FILE      (read) CSV file that rows are appended to
# Arguments:
#   $1 - timestamp string written in the first CSV column
#######################################
record_samples() {
  local now="$1" idx used peak
  while IFS=',' read -r idx used; do
    # Skip blank or non-numeric rows (e.g. a placeholder such as "[N/A]"):
    # they cannot be compared arithmetically and must not reach (( )).
    [[ "$idx" =~ ^[0-9]+$ && "$used" =~ ^[0-9]+$ ]] || continue

    # Default to 0 so an index missing from the table does not abort the
    # script under `set -u`.
    peak="${PEAK_BY_GPU[$idx]:-0}"
    if (( used > peak )); then
      peak="$used"
      PEAK_BY_GPU["$idx"]="$peak"
    fi
    if (( used > PEAK_OVERALL )); then
      PEAK_OVERALL="$used"
    fi
    printf '%s,%s,%s,%s,%s\n' "$now" "$idx" "$used" "$peak" "$PEAK_OVERALL" >> "$LOG_FILE"
  done
}

# Sample until the training process (SFT_PID) no longer exists.
while kill -0 "$SFT_PID" >/dev/null 2>&1; do
  NOW="$(date +%F' '%T)"

  if [[ "$GPU_SEL" == "all" ]]; then
    record_samples "$NOW" < <(query_all)
  else
    # Monitor a single GPU only.
    record_samples "$NOW" < <(query_one "$GPU_SEL")
  fi

  # A sleep that was killed by a signal (status > 128, e.g. by Ctrl+C) must not
  # trip `set -e`: keep polling until the training process has exited so the
  # final summary is still printed. Any other sleep failure remains fatal.
  sleep "$INTERVAL" || {
    sleep_rc=$?
    (( sleep_rc > 128 )) || exit "$sleep_rc"
  }
done

# Wait for the training process and capture its exit status.
# SFT_RC is initialised explicitly so that a value inherited from the
# environment can never be reported in place of the real status.
SFT_RC=0
wait "$SFT_PID" || SFT_RC=$?

echo ""
echo "[RESULT] sft.sh exit code: $SFT_RC"
echo "[RESULT] Peak overall GPU memory.used: ${PEAK_OVERALL} MiB"

# Per-GPU peaks. The `:-0` defaults keep `set -u` from aborting here (and
# thereby replacing the training exit code) when a GPU has no entry in the
# table, e.g. because the selected GPU never produced a sample.
if [[ "$GPU_SEL" == "all" ]]; then
  echo "[RESULT] Peak per GPU:"
  for g in "${ALL_GPUS[@]}"; do
    echo "  GPU $g: ${PEAK_BY_GPU[$g]:-0} MiB"
  done
else
  echo "[RESULT] Peak GPU $GPU_SEL: ${PEAK_BY_GPU[$GPU_SEL]:-0} MiB"
fi

echo "[RESULT] Log saved to: $LOG_FILE"
# Propagate the training script's exit status to the caller.
exit "$SFT_RC"
