#!/bin/bash

# GPU自动监控和训练脚本 (使用tmux管理)
# 功能：检测空闲GPU，使用tmux启动训练任务

# 配置参数
TRAIN_SCRIPT="/home/*/workspace/verl-agent/examples/migpo_trainer/run_webshop.sh"
CONDA_ENV="verl-agent-webshop"  # Conda环境名称
CONDA_BASE="$HOME/miniforge3"  # Conda安装路径
CHECK_INTERVAL=100  # 检查间隔（秒）
STARTUP_WAIT=600  # 启动训练后等待时间（秒），默认10分钟
GPU_UTIL_THRESHOLD=10  # GPU利用率阈值（%）
GPU_MEM_THRESHOLD=10   # GPU显存使用阈值（%）
GPUS_PER_TASK=2  # 每个训练任务使用的GPU数量
LOG_FILE="/tmp/auto_train_webshop_tmux_$(date +%Y%m%d_%H%M%S).log"
TMUX_SESSION_PREFIX="webshop_train"  # tmux会话名称前缀

# 日志函数
log() {
    echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1" | tee -a "$LOG_FILE"
}

# 检查tmux是否可用
if ! command -v tmux &> /dev/null; then
    log "错误: tmux 命令不可用，请先安装tmux"
    log "安装命令: sudo apt-get install tmux 或 conda install tmux"
    exit 1
fi

# 检查nvidia-smi是否可用
if ! command -v nvidia-smi &> /dev/null; then
    log "错误: nvidia-smi 命令不可用"
    exit 1
fi

# 检测conda安装路径
detect_conda_base() {
    # 方法1: 使用conda info命令
    if command -v conda &> /dev/null; then
        local conda_base=$(conda info --base 2>/dev/null)
        if [ -n "$conda_base" ] && [ -d "$conda_base" ]; then
            echo "$conda_base"
            return 0
        fi
    fi

    # 方法2: 检查常见的conda安装路径
    local common_paths=(
        "$HOME/miniforge3"
    )

    for path in "${common_paths[@]}"; do
        if [ -d "$path" ] && [ -f "$path/etc/profile.d/conda.sh" ]; then
            echo "$path"
            return 0
        fi
    done

    return 1
}

# 验证conda环境
validate_conda_env() {
    local conda_base=$1
    local env_name=$2

    # 初始化conda
    if [ -f "$conda_base/etc/profile.d/conda.sh" ]; then
        source "$conda_base/etc/profile.d/conda.sh"
    else
        log "错误: 无法找到conda初始化脚本: $conda_base/etc/profile.d/conda.sh"
        return 1
    fi

    # 检查环境是否存在
    if ! conda env list | grep -q "^${env_name}\s"; then
        log "错误: Conda环境 '$env_name' 不存在"
        log "可用的conda环境:"
        conda env list | tee -a "$LOG_FILE"
        return 1
    fi

    log "Conda环境验证成功: $env_name"
    return 0
}

# 自动检测conda路径
log "正在检测conda安装路径..."
DETECTED_CONDA_BASE=$(detect_conda_base)
if [ $? -eq 0 ]; then
    CONDA_BASE="$DETECTED_CONDA_BASE"
    log "检测到conda安装路径: $CONDA_BASE"
else
    log "警告: 无法自动检测conda路径，使用配置的路径: $CONDA_BASE"
    if [ ! -d "$CONDA_BASE" ]; then
        log "错误: Conda路径不存在: $CONDA_BASE"
        exit 1
    fi
fi

# 验证conda环境
log "正在验证conda环境: $CONDA_ENV"
if ! validate_conda_env "$CONDA_BASE" "$CONDA_ENV"; then
    log "错误: Conda环境验证失败"
    exit 1
fi

# 获取空闲GPU列表
get_free_gpus() {
    local free_gpus=()

    # 获取所有GPU的利用率和显存使用情况
    while IFS=',' read -r gpu_id util mem_used mem_total; do
        # 计算显存使用百分比
        mem_percent=$(awk "BEGIN {printf \"%.0f\", ($mem_used/$mem_total)*100}")

        # 判断GPU是否空闲
        if [ "$util" -lt "$GPU_UTIL_THRESHOLD" ] && [ "$mem_percent" -lt "$GPU_MEM_THRESHOLD" ]; then
            free_gpus+=("$gpu_id")
        fi
    done < <(nvidia-smi --query-gpu=index,utilization.gpu,memory.used,memory.total --format=csv,noheader,nounits)

    echo "${free_gpus[@]}"
}

# 选择指定数量的GPU
select_gpus() {
    local available_gpus=($1)
    local num_to_select=$2
    local num_available=${#available_gpus[@]}

    if [ "$num_available" -lt "$num_to_select" ]; then
        return 1
    fi

    # 选择前N个GPU
    local selected_gpus=("${available_gpus[@]:0:$num_to_select}")

    # 将数组转换为逗号分隔的字符串
    local gpu_string=$(IFS=,; echo "${selected_gpus[*]}")
    echo "$gpu_string"
    return 0
}

# 获取当前运行的训练任务数量（通过tmux会话）
get_running_tasks_count() {
    local count=0
    # 统计以指定前缀开头的tmux会话数量
    count=$(tmux list-sessions 2>/dev/null | grep -c "^${TMUX_SESSION_PREFIX}_" || echo 0)
    echo $count
}

# 清理已结束的tmux会话
cleanup_finished_sessions() {
    local cleaned=0

    # 获取所有tmux会话
    local sessions=$(tmux list-sessions -F "#{session_name}" 2>/dev/null | grep "^${TMUX_SESSION_PREFIX}_" || true)

    for session in $sessions; do
        # 检查会话中的窗格是否还在运行
        # 如果训练脚本结束，tmux会话会自动关闭（因为我们没有保持会话）
        # 这里只是记录，tmux会自动清理已结束的会话
        :
    done

    # tmux会自动清理已结束的会话，不需要手动清理
    # 不会产生僵尸进程
}

# 启动训练任务（使用tmux）
start_training() {
    local gpu_ids=$1
    local timestamp=$(date +%Y%m%d_%H%M%S)
    local session_name="${TMUX_SESSION_PREFIX}_${timestamp}_gpu${gpu_ids}"

    log "=========================================="
    log "准备启动新的训练任务"
    log "使用GPU: $gpu_ids"
    log "Conda环境: $CONDA_ENV"
    log "Tmux会话: $session_name"
    log "训练脚本: $TRAIN_SCRIPT"
    log "=========================================="

    # 创建tmux会话并在其中运行训练
    # 使用-d参数创建后台会话
    tmux new-session -d -s "$session_name" bash -c "
        # 初始化conda
        source '$CONDA_BASE/etc/profile.d/conda.sh'

        # 激活conda环境
        conda activate $CONDA_ENV

        # 设置CUDA_VISIBLE_DEVICES
        export CUDA_VISIBLE_DEVICES=$gpu_ids

        # 切换到工作目录
        cd /home/*/workspace/verl-agent

        # 记录开始时间
        echo '=========================================='
        echo '训练任务启动'
        echo '时间: $(date)'
        echo 'GPU: $gpu_ids'
        echo 'Conda环境: $CONDA_ENV'
        echo '=========================================='

        # 运行训练脚本
        bash '$TRAIN_SCRIPT'

        # 训练结束后的信息
        echo '=========================================='
        echo '训练任务结束'
        echo '时间: $(date)'
        echo '=========================================='

        # 等待用户查看（可选，如果不需要可以删除这行）
        # read -p 'Press Enter to close this session...'
    "

    if [ $? -eq 0 ]; then
        log "训练任务已启动"
        log "  Tmux会话: $session_name"
        log "  GPU: $gpu_ids"
        log "  查看训练输出: tmux attach -t $session_name"
        log "  列出所有训练: tmux list-sessions | grep $TMUX_SESSION_PREFIX"
        log "=========================================="
        return 0
    else
        log "错误: 启动训练任务失败"
        return 1
    fi
}

# 主循环
main() {
    log "=========================================="
    log "GPU自动监控训练脚本启动 (使用tmux)"
    log "训练脚本: $TRAIN_SCRIPT"
    log "每个任务使用GPU数: $GPUS_PER_TASK"
    log "检查间隔: ${CHECK_INTERVAL}秒"
    log "启动等待时间: ${STARTUP_WAIT}秒 ($(($STARTUP_WAIT / 60))分钟)"
    log "GPU利用率阈值: ${GPU_UTIL_THRESHOLD}%"
    log "GPU显存阈值: ${GPU_MEM_THRESHOLD}%"
    log "日志文件: $LOG_FILE"
    log "Tmux会话前缀: $TMUX_SESSION_PREFIX"
    log "=========================================="

    while true; do
        # 清理已结束的会话（tmux会自动清理）
        cleanup_finished_sessions

        # 获取当前运行的任务数
        local running_count=$(get_running_tasks_count)
        log "当前运行的训练任务数: $running_count"

        # 获取空闲GPU
        free_gpus=$(get_free_gpus)
        local free_gpus_array=($free_gpus)
        local num_free=${#free_gpus_array[@]}

        log "检测到 $num_free 张空闲GPU: ${free_gpus_array[*]}"

        # 如果有足够的空闲GPU，启动新的训练任务
        if [ "$num_free" -ge "$GPUS_PER_TASK" ]; then
            # 选择GPU
            if selected_gpus=$(select_gpus "$free_gpus" "$GPUS_PER_TASK"); then
                log "发现足够的空闲GPU，准备启动新的训练任务"

                # 启动训练
                if start_training "$selected_gpus"; then
                    log "训练任务启动成功，等待 ${STARTUP_WAIT} 秒让训练完全启动..."
                    log "等待期间不会检测新的GPU，避免重复启动"
                    sleep "$STARTUP_WAIT"
                    log "等待结束，继续监控..."
                else
                    log "训练任务启动失败"
                fi
            else
                log "无法选择足够的GPU"
            fi
        else
            log "空闲GPU不足 $GPUS_PER_TASK 张，等待中..."
        fi

        # 等待下一次检查
        log "等待 ${CHECK_INTERVAL} 秒后进行下一次检查..."
        log "----------------------------------------"
        sleep "$CHECK_INTERVAL"
    done
}

# 信号处理
cleanup() {
    log "收到终止信号，清理并退出..."

    # 列出所有正在运行的训练会话
    local sessions=$(tmux list-sessions -F "#{session_name}" 2>/dev/null | grep "^${TMUX_SESSION_PREFIX}_" || true)

    if [ -n "$sessions" ]; then
        local session_count=$(echo "$sessions" | wc -l)
        log "=========================================="
        log "共有 $session_count 个训练任务仍在运行"
        log "这些tmux会话不会自动终止"
        log ""
        log "查看所有训练任务:"
        log "  tmux list-sessions | grep $TMUX_SESSION_PREFIX"
        log ""
        log "连接到某个训练任务:"
        for session in $sessions; do
            log "  tmux attach -t $session"
        done
        log ""
        log "终止某个训练任务:"
        for session in $sessions; do
            log "  tmux kill-session -t $session"
        done
        log ""
        log "终止所有训练任务:"
        log "  tmux list-sessions | grep $TMUX_SESSION_PREFIX | cut -d: -f1 | xargs -I {} tmux kill-session -t {}"
        log "=========================================="
    else
        log "没有正在运行的训练任务"
    fi

    exit 0
}

trap cleanup SIGINT SIGTERM

# 启动主循环
main
