#!/usr/bin/env bash
# Solver training pipeline. Steps, in order: question generation, optional
# pre-eval replay sampling, question evaluation, memory bank update, upload,
# optional post-eval replay mixing, training, model merge, optional final
# evaluation. Per-step timings are appended to a log file under logs/.
#
# Usage:
#   <this script> SOLVER_MODEL_PATH QUESTIONER_MODEL_PATH EXPERIMENT_NAME \
#                 ITERATION [RESUME_CHECKPOINT_PATH]
#
# Environment variables read by this script:
#   STORAGE_PATH, HUGGINGFACENAME, TOTAL_GPU_COUNT, SOLVER_LR (default 1e-6),
#   EMBEDDING_TYPE, REPLAY_STRATEGY, REPLAY_RATIO, REPLAY_SAMPLING, MODEL_ABBR,
#   QUESTION_TO_CODE_MODEL_PATH, NL_EMBEDDING_MODEL, CODE_EMBEDDING_MODEL,
#   ENABLE_FINAL_EVAL
solver_model_path="${1-}"
questioner_model_path="${2-}"
experiment_name="${3-}"
iteration="${4-}"
# Optional: checkpoint path to resume training from (should end in
# global_step_*). Empty or omitted means a fresh run.
resume_checkpoint_path="${5:-}"

# Print the storage root and iteration first. The value is quoted and emitted
# with printf so it is shown verbatim (no word splitting, no glob expansion,
# no option parsing by echo).
printf '%s\n' "$STORAGE_PATH"
printf '%s\n' "Iteration: $iteration"

printf '%s\n' "start train solver $experiment_name $solver_model_path $questioner_model_path"

# Exported so that child processes started below inherit it.
export VLLM_DISABLE_COMPILE_CACHE=1

# ============ Print the configuration ============
_cfg_rule="=========================================="
printf '%s\n' \
    "$_cfg_rule" \
    "[Solver Training] Some Configurations" \
    "  Experiment Name: $experiment_name" \
    "  Questioner Model Path: $questioner_model_path" \
    "  Solver Model Path: $solver_model_path" \
    "  Iteration: $iteration" \
    "  Learning Rate: ${SOLVER_LR:-1e-6}"
# The resume line is only shown when a checkpoint path was supplied.
if [ -n "$resume_checkpoint_path" ]; then
    printf '%s\n' "  Resume Checkpoint: $resume_checkpoint_path"
fi
printf '%s\n' \
    "$_cfg_rule" \
    "  Embedding Type: $EMBEDDING_TYPE" \
    "$_cfg_rule" \
    "  Replay Strategy: $REPLAY_STRATEGY" \
    "  Replay Ratio: $REPLAY_RATIO" \
    "  Replay Sampling: $REPLAY_SAMPLING" \
    "$_cfg_rule" \
    "  Enable Final Eval: $ENABLE_FINAL_EVAL" \
    "$_cfg_rule"

# Create the per-run timing log. The file name embeds the experiment name and
# a start timestamp; the path is relative to the current working directory.
mkdir -p logs
TIME_LOG_FILE="logs/time_log_${experiment_name}_$(date +%Y%m%d_%H%M%S).txt"

# Record the script start time (epoch seconds) for the final summary.
SCRIPT_START=$(date +%s)

# Write the log header to both stdout and the log file. The path is quoted so
# an experiment name containing whitespace does not split the tee argument.
{
    echo "========================================"
    echo "Solver Training Time Log"
    echo "Experiment: $experiment_name"
    echo "Start Time: $(date '+%Y-%m-%d %H:%M:%S')"
    echo "Embedding Type: $EMBEDDING_TYPE"
    echo "Replay Strategy: $REPLAY_STRATEGY"
    echo "Replay Ratio: $REPLAY_RATIO"
    echo "========================================"
    echo ""
} | tee -a "$TIME_LOG_FILE"

# ============ Resume-mode detection ============
# A non-empty resume_checkpoint_path means this run resumes an interrupted
# training job, so the data-preparation stage is skipped. This avoids
# generating questions again and updating the Memory Bank a second time.
#
# Note (from the original author): when RESUME_FROM=solver_v2_step0, main.sh
# passes an empty string. Earlier training stages were skipped in that case,
# but solver_v2 itself has not started yet, so the full data-preparation flow
# (Question Generate, Evaluate, Memory Bank, ...) must still run.
if [ -z "$resume_checkpoint_path" ]; then
    SKIP_DATA_PREPARATION=false
else
    SKIP_DATA_PREPARATION=true
    # Announce resume mode on stdout and in the timing log.
    {
        echo "=========================================="
        echo "[Resume Mode] Detected checkpoint: $resume_checkpoint_path"
        echo "[Resume Mode] Skipping data preparation (Question Generate, Evaluate, Memory Bank, Upload)"
        echo "=========================================="
    } | tee -a "$TIME_LOG_FILE"
fi

# Format a duration given in whole seconds ($1) as HH:MM:SS on stdout.
# Pure shell arithmetic: the hour field keeps counting past 24 instead of
# wrapping, and no GNU-specific `date -d @N` is needed.
_fmt_hms() {
    printf '%02d:%02d:%02d' "$(($1 / 3600))" "$(($1 % 3600 / 60))" "$(($1 % 60))"
}

# Data preparation (modules 1-3.6) for a fresh run, or dataset selection only
# when resuming. Either branch sets TRAIN_DATA_CONFIG, which the training step
# reads. Format: repo_name:config_name@split.
if [ "$SKIP_DATA_PREPARATION" = false ]; then
    # ============ Module 1: Question Generate ============
    echo 'start generate question'
    MODULE1_START=$(date +%s)
    # NOTE(review): the exit status of this script is not checked.
    bash question_generate/question_generate.bash "$questioner_model_path" 1000 "$experiment_name"
    MODULE1_END=$(date +%s)
    MODULE1_DURATION=$((MODULE1_END - MODULE1_START))
    echo "[Question Generate] Duration: ${MODULE1_DURATION}s ($(_fmt_hms "$MODULE1_DURATION"))" | tee -a "$TIME_LOG_FILE"

    # ============ Module 1.5: Pre-mix (Experience Replay - pre_eval strategy) ============
    # Mix in historical data before evaluation so the current Solver
    # regenerates the pseudo-labels. Only runs from iteration 2 onward; an
    # empty iteration is treated as 0 so the test stays well-formed.
    if [ "$REPLAY_STRATEGY" = "pre_eval" ] && [ "${iteration:-0}" -gt 1 ]; then
        echo 'start pre-eval experience replay (mixing with memory bank before evaluation)'
        MODULE1_5_START=$(date +%s)
        # NOTE(review): unlike the post-eval mixing step, the exit status of
        # this command is not checked - confirm that is intentional.
        python memory_bank/sample_for_replay.py \
            --experiment_name "${experiment_name}" \
            --iteration "${iteration}" \
            --replay_ratio "${REPLAY_RATIO}" \
            --sampling_strategy "${REPLAY_SAMPLING}" \
            --model_abbr "${MODEL_ABBR}" \
            --embedding_type "${EMBEDDING_TYPE}"
        MODULE1_5_END=$(date +%s)
        MODULE1_5_DURATION=$((MODULE1_5_END - MODULE1_5_START))
        echo "[Pre-Eval Replay] Duration: ${MODULE1_5_DURATION}s ($(_fmt_hms "$MODULE1_5_DURATION"))" | tee -a "$TIME_LOG_FILE"
    else
        echo "[Pre-Eval Replay] Skipped (REPLAY_STRATEGY=$REPLAY_STRATEGY, iteration=$iteration)"
    fi

    # ============ Module 2: Evaluate Generated Question ============
    echo 'start evaluate generated question'
    MODULE2_START=$(date +%s)
    # NOTE(review): the exit status of this script is not checked.
    bash question_evaluate/evaluate.sh "$solver_model_path" "$experiment_name"
    MODULE2_END=$(date +%s)
    MODULE2_DURATION=$((MODULE2_END - MODULE2_START))
    echo "[Question Evaluate] Duration: ${MODULE2_DURATION}s ($(_fmt_hms "$MODULE2_DURATION"))" | tee -a "$TIME_LOG_FILE"

    # ============ Module 3: Update Memory Bank ============
    # Must run before Upload, because upload deletes the result files.
    echo 'start updating memory bank'
    MODULE3_START=$(date +%s)

    # All embedding-related options are passed; EMBEDDING_TYPE is forwarded
    # so the Python script can choose among them.
    # NOTE(review): the exit status of this command is not checked.
    python memory_bank/update_memory.py \
        --experiment_name "${experiment_name}" \
        --iteration "${iteration}" \
        --max_score 0.8 \
        --min_score 0.3 \
        --max_iterations -1 \
        --embedding_type "${EMBEDDING_TYPE}" \
        --question_to_code_model "${QUESTION_TO_CODE_MODEL_PATH}" \
        --nl_embedding_model "${NL_EMBEDDING_MODEL}" \
        --code_embedding_model "${CODE_EMBEDDING_MODEL}" \
        --model_abbr "${MODEL_ABBR}"

    MODULE3_END=$(date +%s)
    MODULE3_DURATION=$((MODULE3_END - MODULE3_START))
    echo "[Memory Bank Update] Duration: ${MODULE3_DURATION}s ($(_fmt_hms "$MODULE3_DURATION"))" | tee -a "$TIME_LOG_FILE"

    # ============ Module 3.5: Upload ============
    echo 'start upload'
    MODULE3_5_START=$(date +%s)
    python question_evaluate/upload.py \
        --repo_name "${experiment_name}" \
        --max_score 0.8 \
        --min_score 0.3 \
        --experiment_name "${experiment_name}"
    # Capture the status right away, before any other command overwrites $?.
    UPLOAD_EXIT_CODE=$?
    MODULE3_5_END=$(date +%s)
    MODULE3_5_DURATION=$((MODULE3_5_END - MODULE3_5_START))
    echo "[Upload] Duration: ${MODULE3_5_DURATION}s ($(_fmt_hms "$MODULE3_5_DURATION"))" | tee -a "$TIME_LOG_FILE"

    # A failed upload aborts the whole run before training starts.
    if [ "$UPLOAD_EXIT_CODE" -ne 0 ]; then
        echo "[ERROR] upload.py failed with exit code $UPLOAD_EXIT_CODE, stopping training" | tee -a "$TIME_LOG_FILE"
        exit 1
    fi

    # ============ Module 3.6: Post-mix (Experience Replay - post_eval strategy) ============
    # Mix in historical data right before training, reusing the historical
    # pseudo-labels. Only runs from iteration 2 onward.
    if [ "$REPLAY_STRATEGY" = "post_eval" ] && [ "${iteration:-0}" -gt 1 ]; then
        echo 'start post-eval experience replay (mixing with memory bank before training)'
        MODULE3_6_START=$(date +%s)
        python memory_bank/mix_training_data.py \
            --experiment_name "${experiment_name}" \
            --iteration "${iteration}" \
            --replay_ratio "${REPLAY_RATIO}" \
            --sampling_strategy "${REPLAY_SAMPLING}" \
            --model_abbr "${MODEL_ABBR}" \
            --embedding_type "${EMBEDDING_TYPE}"
        MIX_EXIT_CODE=$?
        MODULE3_6_END=$(date +%s)
        MODULE3_6_DURATION=$((MODULE3_6_END - MODULE3_6_START))
        echo "[Post-Eval Replay] Duration: ${MODULE3_6_DURATION}s ($(_fmt_hms "$MODULE3_6_DURATION"))" | tee -a "$TIME_LOG_FILE"

        # A failed mix aborts the run; training on an unmixed set is not attempted.
        if [ "$MIX_EXIT_CODE" -ne 0 ]; then
            echo "[ERROR] mix_training_data.py failed with exit code $MIX_EXIT_CODE, stopping training" | tee -a "$TIME_LOG_FILE"
            exit 1
        fi

        # Train on the mixed dataset.
        TRAIN_DATA_CONFIG="${experiment_name}:${experiment_name}_mixed@train"
    else
        echo "[Post-Eval Replay] Skipped (REPLAY_STRATEGY=$REPLAY_STRATEGY, iteration=$iteration)"
        # Train on the original dataset.
        TRAIN_DATA_CONFIG="${experiment_name}:${experiment_name}@train"
    fi
else
    # ============ Resume mode: skip data preparation, reuse the existing dataset ============
    echo "[Resume Mode] Skipping modules 1-3.6 (data preparation)" | tee -a "$TIME_LOG_FILE"

    # Pick the dataset config with the same rule the fresh-run branch uses,
    # so the previously uploaded dataset is reused.
    if [ "$REPLAY_STRATEGY" = "post_eval" ] && [ "${iteration:-0}" -gt 1 ]; then
        TRAIN_DATA_CONFIG="${experiment_name}:${experiment_name}_mixed@train"
    else
        TRAIN_DATA_CONFIG="${experiment_name}:${experiment_name}@train"
    fi
    echo "[Resume Mode] Using existing dataset: ${HUGGINGFACENAME}/${TRAIN_DATA_CONFIG}" | tee -a "$TIME_LOG_FILE"
fi

# ============ Module 4: Train ============
echo 'start training solver'
echo "  Training data: ${HUGGINGFACENAME}/${TRAIN_DATA_CONFIG}"
MODULE4_START=$(date +%s)

# Build the optional resume override. It stays empty for a fresh run.
RESUME_ARG=""
if [ -n "$resume_checkpoint_path" ]; then
    echo "Resuming from checkpoint: $resume_checkpoint_path"
    RESUME_ARG="trainer.load_checkpoint_path=$resume_checkpoint_path"
fi

# Launch the trainer with key=value overrides on top of examples/config.yaml.
# Every override is quoted so paths containing whitespace stay a single word.
# ${RESUME_ARG:+"$RESUME_ARG"} expands to exactly one word when a resume
# override exists and to nothing otherwise.
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.max_response_length=4096 \
    "worker.actor.model.model_path=$solver_model_path" \
    "trainer.experiment_name=${experiment_name}" \
    "trainer.save_checkpoint_path=${STORAGE_PATH}/models/${experiment_name}/" \
    "data.train_files=${HUGGINGFACENAME}/${TRAIN_DATA_CONFIG}" \
    trainer.total_epochs=100 \
    data.format_prompt=./examples/format_prompt/solver.jinja \
    trainer.val_freq=4 \
    worker.actor.micro_batch_size_per_device_for_update=1 \
    worker.actor.micro_batch_size_per_device_for_experience=1 \
    "worker.actor.optim.lr=${SOLVER_LR:-1e-6}" \
    trainer.max_steps=15 \
    trainer.save_freq=1 \
    trainer.save_limit=1 \
    "trainer.n_gpus_per_node=$TOTAL_GPU_COUNT" \
    ${RESUME_ARG:+"$RESUME_ARG"}
TRAIN_EXIT_CODE=$?
MODULE4_END=$(date +%s)
MODULE4_DURATION=$((MODULE4_END - MODULE4_START))
echo "[Training] Duration: ${MODULE4_DURATION}s ($(printf '%02d:%02d:%02d' $((MODULE4_DURATION / 3600)) $((MODULE4_DURATION % 3600 / 60)) $((MODULE4_DURATION % 60))))" | tee -a "$TIME_LOG_FILE"

# Record a trainer failure in the log instead of dropping it silently. The
# script still continues, as before.
# NOTE(review): consider aborting here if a non-zero status always means the
# checkpoint is unusable.
if [ "$TRAIN_EXIT_CODE" -ne 0 ]; then
    echo "[WARN] verl.trainer.main exited with code $TRAIN_EXIT_CODE" | tee -a "$TIME_LOG_FILE"
fi

# ============ Module 5: Model Merge (+ Cleanup, currently disabled) ============
echo "start merging solver model and cleanup"
MODULE5_START=$(date +%s)

# Merge the actor checkpoint saved at global_step_15. The step number must
# stay in sync with trainer.max_steps=15 in the training command. The active
# call does NOT pass --cleanup, so the checkpoint shards are left in place;
# the cleanup variants are kept below for reference.
python scripts/model_merger.py --local_dir "${STORAGE_PATH}/models/${experiment_name}/global_step_15/actor"
# python scripts/model_merger.py --local_dir ${STORAGE_PATH}/models/${experiment_name}/global_step_15/actor --cleanup
# python scripts/model_merger.py --local_dir ${STORAGE_PATH}/models/${experiment_name}/global_step_10/actor --cleanup &
# python scripts/model_merger.py --local_dir ${STORAGE_PATH}/models/${experiment_name}/global_step_20/actor --cleanup &

MODULE5_END=$(date +%s)
MODULE5_DURATION=$((MODULE5_END - MODULE5_START))
echo "[Model Merge + Cleanup] Duration: ${MODULE5_DURATION}s ($(printf '%02d:%02d:%02d' $((MODULE5_DURATION / 3600)) $((MODULE5_DURATION % 3600 / 60)) $((MODULE5_DURATION % 60))))" | tee -a "$TIME_LOG_FILE"
# NOTE(review): fixed 10-second pause before the next stage; presumably lets
# resources settle after the merge - confirm whether it is still needed.
sleep 10

# ============ Module 6: Final Evaluation ============
# Runs only when ENABLE_FINAL_EVAL is exactly "true"; any other value
# (including unset) skips the step and records that in the timing log.
if [ "$ENABLE_FINAL_EVAL" != "true" ]; then
    echo "[Final Evaluation] Skipped (ENABLE_FINAL_EVAL=$ENABLE_FINAL_EVAL)" | tee -a "$TIME_LOG_FILE"
else
    echo "start final evaluation"
    MODULE6_START=$(date +%s)
    # Evaluate the merged HuggingFace-format model from step 15.
    # NOTE(review): the meaning of the trailing "false" argument is defined
    # by evaluation/evaluate.bash - confirm there.
    bash evaluation/evaluate.bash "${STORAGE_PATH}/models/${experiment_name}/global_step_15/actor/huggingface" false
    MODULE6_END=$(date +%s)
    MODULE6_DURATION=$((MODULE6_END - MODULE6_START))
    echo "[Final Evaluation] Duration: ${MODULE6_DURATION}s ($(printf '%02d:%02d:%02d' $((MODULE6_DURATION / 3600)) $((MODULE6_DURATION % 3600 / 60)) $((MODULE6_DURATION % 60))))" | tee -a "$TIME_LOG_FILE"
fi

# ============ Record the total script run time ============
SCRIPT_END=$(date +%s)
TOTAL_DURATION=$((SCRIPT_END - SCRIPT_START))
# HH:MM:SS computed arithmetically so a run longer than 24 hours is reported
# with its full hour count instead of wrapping back to 00.
TOTAL_DURATION_HMS=$(printf '%02d:%02d:%02d' $((TOTAL_DURATION / 3600)) $((TOTAL_DURATION % 3600 / 60)) $((TOTAL_DURATION % 60)))

# Append the summary to the timing log and mirror it to stdout.
{
    echo ""
    echo "========================================"
    echo "Time Summary"
    echo "========================================"
    echo "End Time: $(date '+%Y-%m-%d %H:%M:%S')"
    echo "Total Duration: ${TOTAL_DURATION}s (${TOTAL_DURATION_HMS})"
    echo ""
} | tee -a "$TIME_LOG_FILE"
echo "Time log saved to: $TIME_LOG_FILE"
