#!/bin/bash

# Load the Anaconda activation script into this shell. The loop below calls
# `conda activate`, which presumably relies on this having been sourced.
source /data/home/the/anaconda3/bin/activate

# Remove every proxy setting (lower- and upper-case spellings) from the
# environment of this script and of the helper scripts it launches.
unset http_proxy https_proxy all_proxy
unset HTTP_PROXY HTTPS_PROXY ALL_PROXY

# Exit immediately when a command fails. Enabled only after the activation
# script has been sourced, so that script itself runs without errexit.
set -o errexit

# Optional: error handler, meant to be invoked from an ERR trap.
# Arguments: $1 - line number to report for the failing command.
# Outputs:   failure report on stdout; cleanup warning (if any) on stderr.
# Exits:     with the status of the command that triggered the handler.
handle_error() {
    # Must be the first statement: any later command would overwrite $?.
    local exit_code=$?
    local line_no=$1
    echo "Error in script at line $line_no: command exited with code $exit_code."
    echo "Terminating script."
    # Cleanup: ask kill_vllm.sh to clean up for ids 0-7 (GPU ids, per the
    # array name).
    local -a ACCELERATE_GPUS_TO_CLEAN_ON_ERROR=(0 1 2 3 4 5 6 7)
    echo "Attempting to clean up VLLM processes..."
    # Run the cleanup in a tested context: with `set -e` active, a failing or
    # missing kill_vllm.sh would otherwise abort the handler right here and
    # the script would exit with the cleanup's status instead of the original
    # failure's status.
    if ! ./kill_vllm.sh "${ACCELERATE_GPUS_TO_CLEAN_ON_ERROR[@]}"; then
        echo "Warning: kill_vllm.sh failed; VLLM processes may still be running." >&2
    fi

    exit "$exit_code"
}

# Optional: route every failing command through handle_error. The trap string
# is single-quoted so LINENO is expanded when the trap fires, not when it is
# registered.
trap 'handle_error "$LINENO"' ERR

# ---------------------------------------------------------------------------
# Run configuration. Every value below is either forwarded to a helper script
# by the main loop or only embedded in MODEL_NAME (the run's directory name).
# ---------------------------------------------------------------------------

# Outer-loop schedule.
NUM_EPOCHS=200            # upper bound of the main loop; also passed to main_ppo.sh
NUM_ITERATIONS=1
MINI_INTERVAL=10
INTERVAL=$(( MINI_INTERVAL * NUM_ITERATIONS ))   # steps covered by one loop iteration

# Data and model locations.
DATASET="DeepScaler"      # only used as a MODEL_NAME field
DATA_PATH="/data/home/the/codes/open-r1/data/DeepScaleR-Preview-Dataset"
BASE_MODEL_NAME="/data/home/the/models/DeepSeek-R1-Distill-Qwen-1.5B"
REF_MODEL_NAME="/data/home/the/models/DeepSeek-R1-Distill-Qwen-1.5B"
UPDATE_REF=1              # when 0, the loop resets REF_MODEL_NAME to BASE_MODEL_NAME

# Training options (forwarded to main_ppo.sh unless noted otherwise).
ALGORITHM="bnpo"          # only used as a MODEL_NAME field
BETA=0.0
LR_SCHEDULER="constant"
EPSILON_HIGH=0.28
TRAIN_BATCH_SIZE=5
ROLLOUT_N=8
TEMPERATURE=1.0           # also passed to collect_solutions.sh

# Process-reward options.
APPLY_PROCESS_REWARD=1    # 1 enables step 2 (post-handling of collected data)
SKIP_EMPTY_PROCESS_REWARD=1
PROCESS_MODE=6            # only used as a MODEL_NAME field
PIECE_SECTIONS=0          # only used as a MODEL_NAME field
REWARD_EXP=1.0
REWARD_COE=1.0

# Entropy-loss options.
APPLY_ENTROPY_LOSS=0
THETA=0.0
DELTA=0.0
TGT_ENT=0.0
MIN_THETA=0.0
MAX_THETA=0.0

# Solution-collection options (forwarded to collect_solutions.sh).
MAX_ATTEMPT_TIMES=1
COLLECT_MAX_STEP=1
COLLECT_BATCH_SIZE=$(( ROLLOUT_N * TRAIN_BATCH_SIZE * INTERVAL / COLLECT_MAX_STEP ))

# The run name is the dash-joined list of the fields below, in this order.
MODEL_NAME_FIELDS=(
    "RD-1.5B"
    "${ALGORITHM}" "${DATASET}" "${INTERVAL}" "${NUM_EPOCHS}"
    "${MAX_ATTEMPT_TIMES}" "${NUM_ITERATIONS}" "${ROLLOUT_N}" "${TRAIN_BATCH_SIZE}"
    "${APPLY_PROCESS_REWARD}" "${PROCESS_MODE}" "${SKIP_EMPTY_PROCESS_REWARD}"
    "${REWARD_EXP}" "${REWARD_COE}" "${PIECE_SECTIONS}" "${UPDATE_REF}" "${BETA}"
    "${APPLY_ENTROPY_LOSS}" "${THETA}" "${DELTA}" "${TGT_ENT}"
    "${TEMPERATURE}" "${LR_SCHEDULER}" "${EPSILON_HIGH}"
)
# IFS is changed only inside the command-substitution subshell.
MODEL_NAME=$(IFS=-; printf '%s' "${MODEL_NAME_FIELDS[*]}")

# Output layout, all rooted at SAVE_DIR.
SAVE_DIR="/data/home/the/codes/thought_evolve_offline/data/${MODEL_NAME}"
COLLECT_OUTPUT_DIR="${SAVE_DIR}/collect_solutions"
# OUTPUT_PATH (given to main_ppo.sh) and MODEL_DIR (given to vllm_serve.sh)
# point at the same checkpoint directory.
OUTPUT_PATH="${SAVE_DIR}/saved_checkpoints"
MODEL_DIR="${SAVE_DIR}/saved_checkpoints"
WANDB_GROUP="0.8-0.0-0.0-0.7"

# Main driver. Each pass covers INTERVAL training steps and runs three stages:
# collect solutions, optionally post-handle them, then train.
# The loop index starts at 26 (not 0), so earlier iterations are not re-run.
for ((i=26; i<NUM_EPOCHS; i++)); do
    # Step window handled by this iteration and the data-file name derived from it.
    START_STEP=$(( i * INTERVAL ))
    END_STEP=$(( START_STEP + INTERVAL ))
    STEP_DATA_FILE="train-${START_STEP}-${END_STEP}-steps-data.jsonl"
    SEPARATOR="---------------------------------------------------"
    # Ids handed to kill_vllm.sh whenever the GPUs are cleared (GPU ids, per the name).
    ACCELERATE_GPUS_TO_CLEAN=(0 1 2 3 4 5 6 7)

    echo "=========== Starting Iteration $((i+1)) / ${NUM_EPOCHS} ==========="
    echo "=========== Starting STEP ${START_STEP} to ${END_STEP} ==========="

    COLLECT_PATH="${COLLECT_OUTPUT_DIR}/${STEP_DATA_FILE}"
    WRITE_PATH="${SAVE_DIR}/vllm_handled_data/${STEP_DATA_FILE}"
    CATHE_PATH="${SAVE_DIR}/${STEP_DATA_FILE}"

    # step1: collect solutions
    printf '%s\n' "${SEPARATOR}" "step1: collect solutions" "${SEPARATOR}"
    conda activate openr1
    ./vllm_serve.sh "${MODEL_DIR}"
    # Fixed pause before collecting; presumably gives the server launched above
    # time to come up — confirm against vllm_serve.sh.
    sleep 60
    COLLECT_ARGS=(
        "${START_STEP}" "${INTERVAL}" "${COLLECT_MAX_STEP}" "${MAX_ATTEMPT_TIMES}"
        "${COLLECT_BATCH_SIZE}" "${DATA_PATH}" "${COLLECT_PATH}"
        "${COLLECT_OUTPUT_DIR}" "${MODEL_NAME}" "${TEMPERATURE}"
    )
    ./collect_solutions.sh "${COLLECT_ARGS[@]}"

    # step2: post handle data (only when process rewards are enabled)
    if [[ "${APPLY_PROCESS_REWARD}" == "1" ]]; then
        printf '%s\n' "${SEPARATOR}" "step2: post handle data" "${SEPARATOR}"
        # step2.1: clear all GPU
        ./kill_vllm.sh "${ACCELERATE_GPUS_TO_CLEAN[@]}"
        sleep 15

        # step2.2: start the grouped vllm servers
        ./vllm_serve_group.sh
        sleep 45

        # step2.3: handle the collected data, writing the result to WRITE_PATH
        PYTHON_SCRIPT_PATH="./src/process_eval/process_evaluation.py"
        ./call_vllm.sh "${COLLECT_PATH}" "${WRITE_PATH}" "${PYTHON_SCRIPT_PATH}"
    else
        # No post-handling: train directly on the collected data.
        WRITE_PATH="${COLLECT_PATH}"
    fi

    # step3: train
    printf '%s\n' "${SEPARATOR}" "step3: train" "${SEPARATOR}"
    # step3.1: clear all GPU
    ./kill_vllm.sh "${ACCELERATE_GPUS_TO_CLEAN[@]}"
    sleep 15

    # With reference updates disabled, always use the base model as reference.
    if [[ "${UPDATE_REF}" == "0" ]]; then
        REF_MODEL_NAME="${BASE_MODEL_NAME}"
    fi

    # step3.2: do grpo train (argument order is positional; keep it as is)
    WANDB_NAME="${MODEL_NAME}-${END_STEP}"
    TRAIN_ARGS=(
        "${WRITE_PATH}" "${CATHE_PATH}" "${OUTPUT_PATH}"
        "${APPLY_PROCESS_REWARD}" "${REWARD_EXP}" "${REWARD_COE}"
        "${WANDB_NAME}" "${UPDATE_REF}" "${REF_MODEL_NAME}"
        "${BETA}" "${NUM_EPOCHS}" "${APPLY_ENTROPY_LOSS}"
        "${THETA}" "${DELTA}" "${TGT_ENT}"
        "${ROLLOUT_N}" "${TRAIN_BATCH_SIZE}" "${TEMPERATURE}"
        "${WANDB_GROUP}" "${SKIP_EMPTY_PROCESS_REWARD}" "${LR_SCHEDULER}"
        "${EPSILON_HIGH}" "${MIN_THETA}" "${MAX_THETA}"
    )
    ./main_ppo.sh "${TRAIN_ARGS[@]}"

    # step3.3: clear all GPU
    ./kill_vllm.sh "${ACCELERATE_GPUS_TO_CLEAN[@]}"
    sleep 15
    echo "=========== Finished Iteration $((i+1)) / ${NUM_EPOCHS} ==========="
done

echo "All iterations completed successfully."
