#!/bin/bash

# GPU自动监控和训练脚本
# 功能：检测空闲GPU，当有足够空闲GPU时自动启动训练

# 配置参数
TRAIN_SCRIPT="/home/*/workspace/verl-agent/examples/migpo_trainer/run_webshop.sh"
CONDA_ENV="verl-agent-webshop"  # Conda环境名称
CONDA_BASE="$HOME/miniforge3"  # Conda安装路径，根据实际情况修改
CHECK_INTERVAL=10  # 检查间隔（秒）
GPU_UTIL_THRESHOLD=10  # GPU利用率阈值（%）
GPU_MEM_THRESHOLD=10   # GPU显存使用阈值（%）
GPUS_PER_TASK=2  # 每个训练任务使用的GPU数量
LOG_FILE="/tmp/auto_train_webshop_$(date +%Y%m%d_%H%M%S).log"
PID_DIR="/tmp/auto_train_webshop_pids"  # 存储多个训练进程PID的目录

# 创建PID目录
mkdir -p "$PID_DIR"

# 日志函数
log() {
    echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1" | tee -a "$LOG_FILE"
}

# 检测conda安装路径
detect_conda_base() {
    # 方法1: 使用conda info命令
    if command -v conda &> /dev/null; then
        local conda_base=$(conda info --base 2>/dev/null)
        if [ -n "$conda_base" ] && [ -d "$conda_base" ]; then
            echo "$conda_base"
            return 0
        fi
    fi

    # 方法2: 检查常见的conda安装路径
    local common_paths=(
        "$HOME/miniforge3"
        "$HOME/miniconda3"
        "$HOME/anaconda3"
        "/opt/conda"
        "/usr/local/anaconda3"
    )

    for path in "${common_paths[@]}"; do
        if [ -d "$path" ] && [ -f "$path/etc/profile.d/conda.sh" ]; then
            echo "$path"
            return 0
        fi
    done

    return 1
}

# 验证conda环境
validate_conda_env() {
    local conda_base=$1
    local env_name=$2

    # 初始化conda
    if [ -f "$conda_base/etc/profile.d/conda.sh" ]; then
        source "$conda_base/etc/profile.d/conda.sh"
    else
        log "错误: 无法找到conda初始化脚本: $conda_base/etc/profile.d/conda.sh"
        return 1
    fi

    # 检查环境是否存在
    if ! conda env list | grep -q "^${env_name}\s"; then
        log "错误: Conda环境 '$env_name' 不存在"
        log "可用的conda环境:"
        conda env list | tee -a "$LOG_FILE"
        return 1
    fi

    log "Conda环境验证成功: $env_name"
    return 0
}

# 检查nvidia-smi是否可用
if ! command -v nvidia-smi &> /dev/null; then
    log "错误: nvidia-smi 命令不可用"
    exit 1
fi

# 自动检测conda路径
log "正在检测conda安装路径..."
DETECTED_CONDA_BASE=$(detect_conda_base)
if [ $? -eq 0 ]; then
    CONDA_BASE="$DETECTED_CONDA_BASE"
    log "检测到conda安装路径: $CONDA_BASE"
else
    log "警告: 无法自动检测conda路径，使用配置的路径: $CONDA_BASE"
    if [ ! -d "$CONDA_BASE" ]; then
        log "错误: Conda路径不存在: $CONDA_BASE"
        exit 1
    fi
fi

# 验证conda环境
log "正在验证conda环境: $CONDA_ENV"
if ! validate_conda_env "$CONDA_BASE" "$CONDA_ENV"; then
    log "错误: Conda环境验证失败"
    exit 1
fi

# 获取空闲GPU列表（根据实际利用率和显存使用情况判断）
get_free_gpus() {
    local free_gpus=()

    # 获取所有GPU的利用率和显存使用情况
    while IFS=',' read -r gpu_id util mem_used mem_total; do
        # 计算显存使用百分比
        mem_percent=$(awk "BEGIN {printf \"%.0f\", ($mem_used/$mem_total)*100}")

        # 判断GPU是否空闲（根据实际使用情况，不考虑PID文件）
        if [ "$util" -lt "$GPU_UTIL_THRESHOLD" ] && [ "$mem_percent" -lt "$GPU_MEM_THRESHOLD" ]; then
            free_gpus+=("$gpu_id")
        fi
    done < <(nvidia-smi --query-gpu=index,utilization.gpu,memory.used,memory.total --format=csv,noheader,nounits)

    echo "${free_gpus[@]}"
}

# 清理已结束的训练进程和僵尸进程
cleanup_finished_tasks() {
    local cleaned=0

    # 首先清理所有僵尸进程
    # 使用wait -n来清理任何已结束的子进程，避免僵尸进程
    while wait -n 2>/dev/null; do
        : # 继续清理直到没有更多已结束的子进程
    done

    for pid_file in "$PID_DIR"/*.pid; do
        [ -e "$pid_file" ] || continue

        local pid=$(grep "^PID=" "$pid_file" | cut -d'=' -f2)
        if [ -n "$pid" ]; then
            # 检查进程是否还在运行
            if ! ps -p "$pid" > /dev/null 2>&1; then
                local gpus=$(grep "^GPUS=" "$pid_file" | cut -d'=' -f2)
                log "训练进程 $pid (GPU: $gpus) 已结束，清理PID文件"
                rm -f "$pid_file"
                cleaned=$((cleaned + 1))
            fi
        fi
    done

    if [ $cleaned -gt 0 ]; then
        log "清理了 $cleaned 个已结束的训练进程"
    fi
}

# 选择指定数量的GPU
select_gpus() {
    local available_gpus=($1)
    local num_to_select=$2
    local num_available=${#available_gpus[@]}

    if [ "$num_available" -lt "$num_to_select" ]; then
        return 1
    fi

    # 选择前N个GPU
    local selected_gpus=("${available_gpus[@]:0:$num_to_select}")

    # 将数组转换为逗号分隔的字符串
    local gpu_string=$(IFS=,; echo "${selected_gpus[*]}")
    echo "$gpu_string"
    return 0
}

# 获取当前运行的训练任务数量
get_running_tasks_count() {
    local count=0
    for pid_file in "$PID_DIR"/*.pid; do
        [ -e "$pid_file" ] || continue
        local pid=$(grep "^PID=" "$pid_file" | cut -d'=' -f2)
        if [ -n "$pid" ] && ps -p "$pid" > /dev/null 2>&1; then
            count=$((count + 1))
        fi
    done
    echo $count
}

# 启动训练
start_training() {
    local gpu_ids=$1
    local timestamp=$(date +%Y%m%d_%H%M%S)

    log "=========================================="
    log "准备启动新的训练任务"
    log "使用GPU: $gpu_ids"
    log "Conda环境: $CONDA_ENV"
    log "训练脚本: $TRAIN_SCRIPT"
    log "=========================================="

    # 设置CUDA_VISIBLE_DEVICES并启动训练
    cd /home/*/workspace/verl-agent

    # 使用nohup在后台运行训练，并将输出重定向到日志文件
    TRAIN_LOG="/tmp/webshop_training_${timestamp}_gpu${gpu_ids}.log"

    # 创建一个临时脚本来激活conda环境并运行训练
    TEMP_SCRIPT="/tmp/run_training_${timestamp}.sh"
    cat > "$TEMP_SCRIPT" << EOF
#!/bin/bash
# 初始化conda
source "$CONDA_BASE/etc/profile.d/conda.sh"

# 激活conda环境
conda activate $CONDA_ENV

# 设置CUDA_VISIBLE_DEVICES
export CUDA_VISIBLE_DEVICES=$gpu_ids

# 切换到工作目录
cd /home/*/workspace/verl-agent

# 运行训练脚本
bash "$TRAIN_SCRIPT"
EOF

    chmod +x "$TEMP_SCRIPT"

    # 使用setsid在新会话中运行，避免产生僵尸进程
    # setsid会让进程脱离当前终端和进程组，不会成为监控脚本的子进程
    setsid bash "$TEMP_SCRIPT" > "$TRAIN_LOG" 2>&1 &

    local train_pid=$!

    # 立即disown，确保进程完全独立
    disown $train_pid 2>/dev/null || true

    # 创建PID文件，记录进程ID和使用的GPU
    local pid_file="$PID_DIR/train_${train_pid}.pid"
    cat > "$pid_file" << EOF
PID=$train_pid
GPUS=$gpu_ids
START_TIME=$timestamp
LOG_FILE=$TRAIN_LOG
SCRIPT=$TEMP_SCRIPT
EOF

    log "训练任务已启动"
    log "  PID: $train_pid"
    log "  GPU: $gpu_ids"
    log "  日志: $TRAIN_LOG"
    log "=========================================="
}

# 主循环
main() {
    log "=========================================="
    log "GPU自动监控训练脚本启动"
    log "训练脚本: $TRAIN_SCRIPT"
    log "每个任务使用GPU数: $GPUS_PER_TASK"
    log "检查间隔: ${CHECK_INTERVAL}秒"
    log "GPU利用率阈值: ${GPU_UTIL_THRESHOLD}%"
    log "GPU显存阈值: ${GPU_MEM_THRESHOLD}%"
    log "日志文件: $LOG_FILE"
    log "PID目录: $PID_DIR"
    log "=========================================="

    while true; do
        # 清理已结束的训练进程
        cleanup_finished_tasks

        # 获取当前运行的任务数
        local running_count=$(get_running_tasks_count)
        log "当前运行的训练任务数: $running_count"

        # 获取空闲GPU
        free_gpus=$(get_free_gpus)
        local free_gpus_array=($free_gpus)
        local num_free=${#free_gpus_array[@]}

        log "检测到 $num_free 张空闲GPU: ${free_gpus_array[*]}"

        # 如果有足够的空闲GPU，启动新的训练任务
        if [ "$num_free" -ge "$GPUS_PER_TASK" ]; then
            # 选择GPU
            if selected_gpus=$(select_gpus "$free_gpus" "$GPUS_PER_TASK"); then
                log "发现足够的空闲GPU，准备启动新的训练任务"
                # 启动训练
                start_training "$selected_gpus"
            else
                log "无法选择足够的GPU"
            fi
        else
            log "空闲GPU不足 $GPUS_PER_TASK 张，等待中..."
        fi

        # 等待下一次检查
        log "等待 ${CHECK_INTERVAL} 秒后进行下一次检查..."
        log "----------------------------------------"
        sleep "$CHECK_INTERVAL"
    done
}

# 信号处理
cleanup() {
    log "收到终止信号，清理并退出..."

    # 检查是否有正在运行的训练进程
    local running_pids=()
    for pid_file in "$PID_DIR"/*.pid; do
        [ -e "$pid_file" ] || continue

        local pid=$(grep "^PID=" "$pid_file" | cut -d'=' -f2)
        local gpus=$(grep "^GPUS=" "$pid_file" | cut -d'=' -f2)

        if [ -n "$pid" ] && ps -p "$pid" > /dev/null 2>&1; then
            running_pids+=("$pid")
            log "警告: 训练进程 $pid (GPU: $gpus) 仍在运行"
        fi
    done

    if [ ${#running_pids[@]} -gt 0 ]; then
        log "=========================================="
        log "共有 ${#running_pids[@]} 个训练进程仍在运行"
        log "这些进程不会自动终止"
        log "如需终止所有训练进程，请执行:"
        log "  kill ${running_pids[*]}"
        log "或者逐个终止"
        log "=========================================="
    else
        log "没有正在运行的训练进程"
    fi

    exit 0
}

trap cleanup SIGINT SIGTERM

# 启动主循环
main
