#!/bin/bash

# ============ Command-line arguments ============
# Usage: bash <this script> <Base_model> <Model_abbr>
#   Base_model - model name or path; only its basename is used (see Base_model_path below)
#   Model_abbr - short tag used in the log file name and in checkpoint directory names
# Both are required: without them the log name, the base-model path and every
# checkpoint path would be silently malformed, so fail fast before doing anything.
if [ $# -lt 2 ] || [ -z "$1" ] || [ -z "$2" ]; then
    echo "用法: $0 <Base_model> <Model_abbr>" >&2
    exit 1
fi
Base_model=$1
Model_abbr=$2

# ============ Resume-from-checkpoint configuration ============
# RESUME_FROM: resume point, formatted like "questioner_v1_step4" or "solver_v2_step10"
#   - empty: train from scratch
#   - questioner_v1_step4: resume the round-1 questioner from its step 4
#   - solver_v1_step10: resume the round-1 solver from its step 10
#   - solver_v2_step0: skip the v1 training and start at solver_v2 (using the final
#     solver_v1 model as the starting point)
#     [note] step0 means the stage has not saved any checkpoint yet, so it starts
#     from the previous stage's final model
export RESUME_FROM="${RESUME_FROM:-}"

# ============ Logging ============
# Create the logs directory (one level above this script's directory) if missing.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
LOG_DIR="${SCRIPT_DIR}/../logs"
mkdir -p "$LOG_DIR"

# Log file name carries Model_abbr and the launch date/time.
LOG_FILE="${LOG_DIR}/main_${Model_abbr}_$(date '+%Y%m%d_%H%M%S').txt"

# Send all subsequent stdout and stderr to both the terminal and the log file.
exec > >(tee -a "$LOG_FILE") 2>&1

echo "=========================================="
echo "日志文件: $LOG_FILE"
echo "开始时间: $(date '+%Y-%m-%d %H:%M:%S')"
echo "=========================================="

# ============ Environment variables ============
export STORAGE_PATH="/path/to/R-zero/storage"
export HUGGINGFACENAME="anonymous_user"

# ============ GPU time-sharing configuration ============
# In time-sharing mode all GPUs are used by both the Trainer and the vLLM services.
# Per the original design notes, vLLM services go to sleep right after start-up (holding
# no GPU memory), are woken up when rewards must be computed, and sleep again afterwards.
export TOTAL_GPU_COUNT=${TOTAL_GPU_COUNT:-8}
# Time-sharing: the number of vLLM services equals the total GPU count, so both service
# counts default to TOTAL_GPU_COUNT (overriding only TOTAL_GPU_COUNT keeps them in sync).
export VLLM_GPU_COUNT=${VLLM_GPU_COUNT:-$TOTAL_GPU_COUNT}
export CODE_VLLM_GPU_COUNT=${CODE_VLLM_GPU_COUNT:-$TOTAL_GPU_COUNT}

# ============ Embedding type configuration ============
# EMBEDDING_TYPE: "nl" (natural language) or "code" (Python code); default "code"
export EMBEDDING_TYPE="${EMBEDDING_TYPE:-code}"
# KEEP_BATCH_PENALTY_UNCHANGED: whether to keep the batch-penalty computation unchanged
# (default false)
export KEEP_BATCH_PENALTY_UNCHANGED="${KEEP_BATCH_PENALTY_UNCHANGED:-false}"

export QUESTION_TO_CODE_MODEL="${QUESTION_TO_CODE_MODEL:-Qwen2.5-Coder-7B-Instruct}"
export QUESTION_TO_CODE_MODEL_PATH="${QUESTION_TO_CODE_MODEL_PATH:-/path/to/models/Qwen2.5-Coder-7B-Instruct}"
export NL_EMBEDDING_MODEL="${NL_EMBEDDING_MODEL:-BAAI/bge-large-en-v1.5}"
export CODE_EMBEDDING_MODEL="${CODE_EMBEDDING_MODEL:-jinaai/jina-code-embeddings-1.5b}"

# ============ Memory-penalty ablation configuration ============
# PENALTY_ALPHA: batch-penalty weight (default 1.0)
export PENALTY_ALPHA="${PENALTY_ALPHA:-1.0}"
# PENALTY_BETA: memory-penalty weight (default 1.0; set to 0.0 to disable the memory penalty)
export PENALTY_BETA="${PENALTY_BETA:-1.0}"
# MEMORY_PENALTY_THRESHOLD: max_similarity threshold (default 0.5)
export MEMORY_PENALTY_THRESHOLD="${MEMORY_PENALTY_THRESHOLD:-0.5}"
# MEMORY_PENALTY_MEAN_THRESHOLD: mean_similarity threshold (default 0.25; the mean is
# usually lower than the max, hence the lower threshold)
export MEMORY_PENALTY_MEAN_THRESHOLD="${MEMORY_PENALTY_MEAN_THRESHOLD:-0.25}"
# MEMORY_PENALTY_GAMMA: weight of max_penalty (default 0.5 = max and mean weighted
# equally; set to 1.0 to use the max only)
export MEMORY_PENALTY_GAMMA="${MEMORY_PENALTY_GAMMA:-0.5}"

# ============ Experience replay configuration (continual learning) ============
# REPLAY_STRATEGY: experience-replay strategy (default "post_eval")
#   - "none": no experience replay (original behaviour)
#   - "pre_eval": mix in historical data before Evaluate; pseudo-labels are regenerated
#                 by the current Solver
#   - "post_eval": mix in historical data right before training; historical pseudo-labels
#                  are reused
export REPLAY_STRATEGY="${REPLAY_STRATEGY:-post_eval}"
# REPLAY_RATIO: share of historical data (default 0.3, i.e. historical:new = 3:7)
export REPLAY_RATIO="${REPLAY_RATIO:-0.3}"
# REPLAY_SAMPLING: sampling strategy (default "uniform")
#   - "uniform": uniform random sampling
#   - "stratified": stratified sampling (more recent rounds get higher weight)
#   - "recent_first": prefer the most recent data
#   - "score_weighted": sample proportionally to score (higher score = more likely)
export REPLAY_SAMPLING="${REPLAY_SAMPLING:-uniform}"

# ============ Evaluation configuration ============
# ENABLE_FINAL_EVAL: run Final Evaluation after every solver training (default true)
export ENABLE_FINAL_EVAL="${ENABLE_FINAL_EVAL:-true}"

echo "=========================================="
echo "1. 环境变量配置:"
echo "  STORAGE_PATH: $STORAGE_PATH"
echo "  HUGGINGFACENAME: $HUGGINGFACENAME"
echo "2. GPU 时分复用配置:"
echo "  TOTAL_GPU_COUNT: $TOTAL_GPU_COUNT"
echo "  VLLM_GPU_COUNT: $VLLM_GPU_COUNT (time-shared with Trainer)"
echo "  CODE_VLLM_GPU_COUNT: $CODE_VLLM_GPU_COUNT (time-shared with Trainer)"
echo "  [注] 时分复用模式: vLLM服务使用sleep/wake_up机制与Trainer共享GPU"
echo "3. Embedding配置:"
echo "  EMBEDDING_TYPE: $EMBEDDING_TYPE"
echo "  KEEP_BATCH_PENALTY_UNCHANGED: $KEEP_BATCH_PENALTY_UNCHANGED"
echo "  QUESTION_TO_CODE_MODEL: $QUESTION_TO_CODE_MODEL"
echo "  QUESTION_TO_CODE_MODEL_PATH: $QUESTION_TO_CODE_MODEL_PATH"
echo "  NL_EMBEDDING_MODEL: $NL_EMBEDDING_MODEL"
echo "  CODE_EMBEDDING_MODEL: $CODE_EMBEDDING_MODEL"
echo "4. Penalty配置:"
echo "  PENALTY_ALPHA (batch): $PENALTY_ALPHA"
echo "  PENALTY_BETA (memory): $PENALTY_BETA"
echo "  MEMORY_PENALTY_THRESHOLD (max): $MEMORY_PENALTY_THRESHOLD"
echo "  MEMORY_PENALTY_MEAN_THRESHOLD: $MEMORY_PENALTY_MEAN_THRESHOLD"
echo "  MEMORY_PENALTY_GAMMA (max weight): $MEMORY_PENALTY_GAMMA"
echo "5. Experience Replay配置:"
echo "  REPLAY_STRATEGY: $REPLAY_STRATEGY"
echo "  REPLAY_RATIO: $REPLAY_RATIO"
echo "  REPLAY_SAMPLING: $REPLAY_SAMPLING"
echo "6. Evaluation配置:"
echo "  ENABLE_FINAL_EVAL: $ENABLE_FINAL_EVAL"
echo "7. 断点续训配置:"
echo "  RESUME_FROM: ${RESUME_FROM:-'(从头开始)'}"
echo "=========================================="


# ============ Self-evolution training loop ============
echo "=========================================="
echo "开始自进化循环训练"
echo "=========================================="
# The base model is looked up by its basename under <grandparent of STORAGE_PATH>/models/.
# Every expansion is quoted so paths containing spaces are not word-split.
Base_model_path="$(dirname "$(dirname "$STORAGE_PATH")")/models/$(basename -- "$Base_model")"
export MODEL_ABBR=$Model_abbr  # exported for child processes (reward function, etc.)
echo "Model_abbr: $Model_abbr"
echo "MODEL_ABBR (env): $MODEL_ABBR"

# ============ Parse the resume configuration ============
# Format: questioner_v1_step4 or solver_v2_step10
RESUME_TYPE=""      # "questioner" or "solver"
RESUME_VERSION=0    # round number (>= 1 when resuming)
RESUME_STEP=0       # checkpoint step inside that stage (0 = no checkpoint saved yet)
if [ -n "$RESUME_FROM" ]; then
    if [[ "$RESUME_FROM" =~ ^(questioner|solver)_v([0-9]+)_step([0-9]+)$ ]]; then
        RESUME_TYPE="${BASH_REMATCH[1]}"
        # 10# forces base-10 so leading zeros ("step04") neither turn into octal nor
        # leak into checkpoint directory names (global_step_04 would not exist).
        RESUME_VERSION=$((10#${BASH_REMATCH[2]}))
        RESUME_STEP=$((10#${BASH_REMATCH[3]}))
        # Rounds are numbered from 1; a v0 resume point would match no stage and the
        # run would silently start from scratch, so reject it explicitly.
        if [ "$RESUME_VERSION" -lt 1 ]; then
            echo "[ERROR] RESUME_FROM 的版本号必须 >= 1: $RESUME_FROM" >&2
            exit 1
        fi
        echo "=========================================="
        echo "断点续训模式:"
        echo "  类型: $RESUME_TYPE"
        echo "  版本: v$RESUME_VERSION"
        echo "  Step: $RESUME_STEP"
        echo "=========================================="
    else
        echo "[ERROR] RESUME_FROM 格式错误: $RESUME_FROM" >&2
        echo "  正确格式: questioner_v1_step4 或 solver_v2_step10" >&2
        exit 1
    fi
fi

# Helper: decide whether a training stage was already completed before the resume point.
# Arguments:
#   $1 - stage type: "questioner" or "solver"
#   $2 - round number
# Globals read: RESUME_TYPE, RESUME_VERSION
# Returns:
#   0 (skip) when the stage comes strictly before the RESUME_FROM stage in training
#     order (rounds ascending; questioner before solver inside a round);
#   1 (run) otherwise - including the resumed stage itself and when not resuming.
should_skip_training() {
    local stage=$1
    local round=$2

    # Not resuming: every stage runs.
    if [ -z "$RESUME_TYPE" ]; then
        return 1
    fi

    # Every stage of an earlier round has already finished.
    if [ "$round" -lt "$RESUME_VERSION" ]; then
        return 0
    fi

    # Inside the resume round the questioner precedes the solver, so resuming a solver
    # means that round's questioner is already done.
    if [ "$round" -eq "$RESUME_VERSION" ] && [ "$stage" = "questioner" ] && [ "$RESUME_TYPE" = "solver" ]; then
        return 0
    fi

    return 1
}

# Helper: print the checkpoint directory a stage should resume from.
# Arguments:
#   $1 - stage type: "questioner" or "solver"
#   $2 - round number
# Globals read: RESUME_TYPE, RESUME_VERSION, RESUME_STEP, STORAGE_PATH, Model_abbr
# Outputs:
#   - nothing, when this stage is not the one named by RESUME_FROM;
#   - an empty line, when it is but RESUME_STEP is 0: no checkpoint has been saved yet,
#     so the stage starts from its regular model_path (e.g. solver_v2_step0 skips the
#     v1 training and starts from the final solver_v1 model);
#   - otherwise ${STORAGE_PATH}/models/${Model_abbr}_<type>_v<round>/global_step_<RESUME_STEP>.
# Returns 0.
get_resume_checkpoint() {
    local stage=$1
    local round=$2
    local ckpt

    # Not the stage being resumed: print nothing.
    if [ "$RESUME_TYPE" != "$stage" ] || ! [ "$RESUME_VERSION" -eq "$round" ]; then
        return 0
    fi

    if [ "$RESUME_STEP" -eq 0 ]; then
        ckpt=""
    else
        ckpt="${STORAGE_PATH}/models/${Model_abbr}_${stage}_v${round}/global_step_${RESUME_STEP}"
    fi
    printf '%s\n' "$ckpt"
}

# ============ Training rounds 1..NUM_ROUNDS ============
# Every round i trains questioner_v$i and then solver_v$i:
#   - questioner_v$i is launched with (previous solver, previous questioner);
#     in round 1 both are the base model.
#   - solver_v$i is launched with (previous solver, questioner_v$i);
#     in round 1 the previous solver is the base model.
# A finished stage is consumed through its exported HuggingFace weights at a fixed step:
# global_step_${QUESTIONER_FINAL_STEP} for questioners, global_step_${SOLVER_FINAL_STEP}
# for solvers. Defaults (5 rounds, steps 5 / 15) reproduce the original hard-coded values.
# NOTE(review): the final-step numbers must match the training length configured inside
# the sub-scripts - confirm both sides when changing either.
# Sub-scripts are resolved relative to the current working directory.
NUM_ROUNDS="${NUM_ROUNDS:-5}"
QUESTIONER_FINAL_STEP="${QUESTIONER_FINAL_STEP:-5}"
SOLVER_FINAL_STEP="${SOLVER_FINAL_STEP:-15}"
for _cfg in NUM_ROUNDS QUESTIONER_FINAL_STEP SOLVER_FINAL_STEP; do
    if ! [[ "${!_cfg}" =~ ^[0-9]+$ ]]; then
        echo "[ERROR] ${_cfg} 必须是非负整数，当前值: ${!_cfg}" >&2
        exit 1
    fi
done

# A resume point beyond the last round would make every stage below print [SKIP].
if [ -n "$RESUME_TYPE" ] && [ "$RESUME_VERSION" -gt "$NUM_ROUNDS" ]; then
    echo "[WARN] RESUME_FROM=${RESUME_FROM} 超出总轮数 ${NUM_ROUNDS}，所有阶段都将被跳过" >&2
fi

# Print the HuggingFace export directory of a finished stage.
#   $1 - stage suffix, e.g. "questioner_v1" or "solver_v2"
#   $2 - global step at which that stage's final checkpoint was saved
stage_hf_path() {
    echo "${STORAGE_PATH}/models/${Model_abbr}_$1/global_step_$2/actor/huggingface"
}

# Run one training command and abort the whole pipeline if it fails: every later
# stage would otherwise start from a model directory that was never written.
#   $1   - stage label for the error message (e.g. "questioner_v3")
#   $2.. - command to execute
run_stage() {
    local label=$1
    shift
    local rc=0
    "$@" || rc=$?
    if [ "$rc" -ne 0 ]; then
        echo "[ERROR] ${label} 训练失败 (exit code ${rc})，终止自进化循环" >&2
        exit "$rc"
    fi
}

for (( i = 1; i <= 10#$NUM_ROUNDS; i++ )); do
    prev=$((i - 1))
    if [ "$i" -eq 1 ]; then
        prev_solver="$Base_model_path"
        prev_questioner="$Base_model_path"
    else
        prev_solver="$(stage_hf_path "solver_v${prev}" "$SOLVER_FINAL_STEP")"
        prev_questioner="$(stage_hf_path "questioner_v${prev}" "$QUESTIONER_FINAL_STEP")"
    fi

    # Questioner v$i
    if should_skip_training "questioner" "$i"; then
        echo "[SKIP] questioner_v${i} 已完成，跳过"
    else
        resume_ckpt=$(get_resume_checkpoint "questioner" "$i")
        run_stage "questioner_v${i}" bash scripts/questioner_train_penalty.sh \
            "$prev_solver" \
            "$prev_questioner" \
            "${Model_abbr}_questioner_v${i}" \
            "$i" \
            "$resume_ckpt"
    fi

    # Solver v$i
    if should_skip_training "solver" "$i"; then
        echo "[SKIP] solver_v${i} 已完成，跳过"
    else
        resume_ckpt=$(get_resume_checkpoint "solver" "$i")
        run_stage "solver_v${i}" bash scripts/solver_train.sh \
            "$prev_solver" \
            "$(stage_hf_path "questioner_v${i}" "$QUESTIONER_FINAL_STEP")" \
            "${Model_abbr}_solver_v${i}" \
            "$i" \
            "$resume_ckpt"
    fi
done

# Final evaluation is currently disabled; uncomment to run it after the last round:
# bash evaluation/evaluate.bash $Base_model

# Closing banner with the finish time.
printf '%s\n' \
    "==========================================" \
    "自进化循环训练结束" \
    "结束时间: $(date '+%Y-%m-%d %H:%M:%S')" \
    "=========================================="
