#!/usr/bin/env bash
#
# Iterative ED-DPO pipeline for GSM8K. For each alpha/epoch it trains a LoRA
# proposal model with DeepSpeed, merges the adapter (combine_lora.py), serves the
# merged model with vLLM, collects new data against it
# (sample_and_adapter_train.py) and merges that data (merge_dpo.py) for the next
# epoch.
#
# Usage: bash <this script> [slurm]
#   Passing "slurm" as the first argument skips the local DeepSpeed launch.
#
# Every setting below may be overridden from the environment.

# Azure OpenAI credentials: fill these in here, or export them before running.
# A value already present in the environment is kept instead of being replaced
# by an empty string.
export OPENAI_AZURE_ENDPOINT=${OPENAI_AZURE_ENDPOINT:-} # azure openai endpoint
export OPENAI_API_KEY=${OPENAI_API_KEY:-} # azure openai key
export OPENAI_API_VERSION=${OPENAI_API_VERSION:-} # azure openai version
export PROCESSOR=${PROCESSOR:-gpt-4o-mini}

# vllm
VLLM_GPUS=${VLLM_GPUS:-"0,1,2,3"}                 # comma-separated GPU ids for the vLLM server
LOCAL_PORT=${LOCAL_PORT:-8000}                     # port the vLLM server listens on
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.9}
VLLM_HEALTH_CHECK_URL="http://localhost:${LOCAL_PORT}/health"
VLLM_READINESS_TIMEOUT=${VLLM_READINESS_TIMEOUT:-180} # seconds to wait for /health
SLEEP_AFTER_RUN=${SLEEP_AFTER_RUN:-5}              # pause (s) after data collection

#######################################
# Stop the background vLLM server whose PID is stored in the global `vllm_pid`.
#
# Called explicitly at the end of every epoch that starts a server, and from the
# exit/signal traps so an aborted run does not leave vLLM holding the GPUs.
# Sends SIGTERM and polls once per second, for up to the grace period, for the
# process to exit. If it is still alive it escalates to SIGKILL, then reaps it
# with `wait` so the port and GPUs are free before the next server is launched.
#
# Arguments:
#   $1 - optional grace period in seconds before SIGKILL (default: 5)
# Globals:
#   vllm_pid - read; cleared after a cleanup attempt
# Returns:
#   0
#######################################
cleanup() {
    local grace=${1:-5}
    local waited=0

    echo "Cleaning up vLLM process (PID: $vllm_pid)..."
    if [ -z "$vllm_pid" ]; then
        echo "No vLLM process PID recorded for cleanup."
        return 0
    fi

    # `kill -0` only probes for existence (builtin, no dependency on procps `ps`).
    if kill -0 "$vllm_pid" 2>/dev/null; then
        echo "Attempting graceful shutdown (SIGTERM) for PID $vllm_pid..."
        kill -TERM "$vllm_pid" 2>/dev/null
        # Poll instead of a fixed sleep so a fast shutdown returns promptly.
        while (( waited < grace )) && kill -0 "$vllm_pid" 2>/dev/null; do
            sleep 1
            waited=$((waited + 1))
        done
        if kill -0 "$vllm_pid" 2>/dev/null; then
            echo "Process $vllm_pid did not shut down gracefully. Sending SIGKILL..."
            # Send SIGKILL as a last resort
            kill -KILL "$vllm_pid" 2>/dev/null
        else
            echo "Process $vllm_pid shut down gracefully."
        fi
        # Reap the child so it is fully gone before a new server reuses the
        # same port/GPUs; harmless (returns 127) if it is not our child.
        wait "$vllm_pid" 2>/dev/null
    else
        echo "vLLM process $vllm_pid is not running (already stopped?)."
    fi
    vllm_pid="" # Clear PID after attempting cleanup
}

# Stop vLLM on every exit path. INT/TERM must also terminate the script: a trap
# that only ran `cleanup` would let bash resume with the next command after
# Ctrl+C and go on to launch the next training/vLLM step. The exit codes follow
# the 128+signal convention. The EXIT trap then runs cleanup again, but by then
# `vllm_pid` has already been cleared.
trap cleanup EXIT
trap 'cleanup; exit 130' INT
trap 'cleanup; exit 143' TERM

# Main pipeline. For every alpha and epoch:
#   1. train the proposal model with DeepSpeed (LoRA, value incentive alpha);
#   2. merge the LoRA adapter into the base model (combine_lora.py);
#   3. on every epoch except the last, serve the merged model with vLLM, collect
#      new data against it (sample_and_adapter_train.py), stop vLLM, and merge
#      the new data with the previous data (merge_dpo.py).
# Any failed step aborts the whole script; the EXIT trap stops a running vLLM.

# vLLM's tensor-parallel degree must equal the number of GPUs it is given, so
# derive it from VLLM_GPUS (e.g. "0,1,2,3" -> 4) instead of hard-coding it.
IFS=',' read -r -a vllm_gpu_list <<< "$VLLM_GPUS"
vllm_tp_size=${#vllm_gpu_list[@]}

for alpha in 0.001
do
    for epoch in 0 1 2 3
    do
        if [[ "${epoch}" -eq 0 ]]; then
            data_path="./data/GSM8K/proposal/iter0"
        else
            data_path="./data/GSM8K/proposal/ed_dpo/alpha=${alpha}/epoch=${epoch}"
        fi

        # VPO training
        set -x

        # if use multi-gpu, multiple eval_steps with gpu nums
        read -r -d '' training_commands <<EOF
        proposal_dpo_train \
        --save_path ckpt/proposal/GSM8K/ed_dpo/alpha=${alpha}/epoch=${epoch} \
        --save_steps -1 \
        --logging_steps 1 \
        --eval_steps -1 \
        --train_batch_size 256 \
        --micro_train_batch_size 4 \
        --pretrain meta-llama/Meta-Llama-3-8B-Instruct \
        --bf16 \
        --max_epochs 1 \
        --max_len 8192 \
        --zero_stage 3 \
        --learning_rate 5e-7 \
        --beta 0.1 \
        --dataset json@${data_path} \
        --apply_chat_template \
        --chosen_key chosen \
        --rejected_key rejected \
        --gradient_checkpointing \
        --label_smoothing 0.1 \
        --ckpt_path ckpt/proposal/GSM8K/ed_dpo/alpha=${alpha}/epoch=${epoch}/checkpoints \
        --max_ckpt_num 5 \
        --add_value_incentive \
        --alpha_value_incentive ${alpha} \
        --lora_rank 16 \
        --lora_alpha 32 \
        --lora_dropout 0.05 \
        --adam_offload \
        --flash_attn
EOF


        if [[ ${1} != "slurm" ]]; then
            # $training_commands is deliberately unquoted: word-splitting turns it
            # into separate arguments, so none of the paths may contain spaces.
            deepspeed --master_port=29300 --include localhost:0,1,2,3 --module $training_commands
            train_exit_code=$?
            if [ "$train_exit_code" -ne 0 ]; then
                echo "Error: VPO training failed for alpha=${alpha}, epoch=${epoch} (exit code $train_exit_code). Aborting."
                exit "$train_exit_code"
            fi
        fi

        echo "Alpha=${alpha}, Epoch=${epoch}: VPO training finish"

        # Combine Lora
        python combine_lora.py --output_dir "ckpt/proposal/GSM8K/ed_dpo/alpha=${alpha}/epoch=${epoch}/checkpoints" \
            --lora_dir "ckpt/proposal/GSM8K/ed_dpo/alpha=${alpha}/epoch=${epoch}" \
            --base_model meta-llama/Meta-Llama-3-8B-Instruct
        combine_exit_code=$?
        if [ "$combine_exit_code" -ne 0 ]; then
            echo "Error: combine_lora.py failed for alpha=${alpha}, epoch=${epoch} (exit code $combine_exit_code). Aborting."
            exit "$combine_exit_code"
        fi

        echo "Alpha=${alpha}, Epoch=${epoch}: Ckpt Save finish"

        # epoch=3 is the last one: there is no next epoch to collect data for.
        if [[ "${epoch}" -eq 3 ]]; then
            echo "Skip After Process, since epoch=3"
        else
            # Start VLLM
            MODEL="ckpt/proposal/GSM8K/ed_dpo/alpha=${alpha}/epoch=${epoch}/checkpoints/merged_model"
            vllm_pid=""
            echo "Starting vLLM server on GPUs $VLLM_GPUS with model $MODEL..."

            CUDA_VISIBLE_DEVICES=$VLLM_GPUS python -m vllm.entrypoints.openai.api_server \
                --model "${MODEL}" \
                --trust-remote-code \
                --dtype=bfloat16 \
                --port="${LOCAL_PORT}" \
                --gpu-memory-utilization="${GPU_MEMORY_UTILIZATION}" \
                --seed 63 \
                --tensor-parallel-size "${vllm_tp_size}" &

            # Capture the PID of the background process. Launching with `&`
            # always reports success, so an early crash is caught by the
            # liveness check in the readiness loop below.
            vllm_pid=$!

            echo "vLLM process started with PID: $vllm_pid"
            echo "Waiting for vLLM server to become ready at $VLLM_HEALTH_CHECK_URL (timeout: ${VLLM_READINESS_TIMEOUT}s)..."

            # --- Wait for vLLM Readiness ---
            start_time=$(date +%s)
            while true; do
                # Check if the vLLM process is still running
                if ! kill -0 "$vllm_pid" 2>/dev/null; then
                    echo "Error: vLLM process (PID $vllm_pid) died unexpectedly during startup."
                    exit 1 # Trap will handle cleanup (though the process is already dead)
                fi

                # Use curl to check the health endpoint
                # -s: silent, -f: fail silently (exit code > 0 on HTTP errors), -o /dev/null: discard output
                if curl -sf -o /dev/null "$VLLM_HEALTH_CHECK_URL"; then
                    echo "vLLM server is ready."
                    break
                fi

                # Check for timeout
                current_time=$(date +%s)
                elapsed_time=$((current_time - start_time))
                if [ "$elapsed_time" -ge "$VLLM_READINESS_TIMEOUT" ]; then
                    echo "Error: vLLM server did not become ready within $VLLM_READINESS_TIMEOUT seconds."
                    # The trap will handle cleanup
                    exit 1
                fi

                # Wait before retrying
                sleep 5
            done

            # --- Run the Interaction Program ---
            echo "vLLM is ready. Starting sample_and_adapter_train.py to interact with the server..."


            # Start Data Collection
            CUDA_VISIBLE_DEVICES=0 accelerate launch --mixed_precision fp16 \
                --main_process_port=29600 sample_and_adapter_train.py \
                --config configs_llama/gsm8k_collect_data.yaml \
                --port "${LOCAL_PORT}" \
                --debug train/initial_critic/split=whole \
                --seed=63 \
                --do_train \
                --threadpool \
                --save_data_path "ckpt/data/GSM8K/ed_dpo/alpha=${alpha}/epoch=$((epoch+1))" \
                --whitebox "$MODEL"

            run_py_exit_code=$?

            # A failure is reported here, but the script only aborts after vLLM
            # has been stopped explicitly below.
            if [ "$run_py_exit_code" -ne 0 ]; then
                echo "Error: sample_and_adapter_train.py exited with error code $run_py_exit_code. Aborting this run."
            else
                echo "sample_and_adapter_train.py finished successfully."
            fi

            # --- Post-Run Sleep ---
            echo "Sleeping for $SLEEP_AFTER_RUN seconds after sample_and_adapter_train.py finished..."
            sleep "$SLEEP_AFTER_RUN"

            # --- EXPLICIT CLEANUP FOR VLLM OF THIS EPOCH ---
            # Stop this epoch's vLLM before the loop continues (the next epoch
            # starts a new server on the same port and GPUs).
            echo "Alpha=${alpha}, Epoch=${epoch}: Calling explicit cleanup for vLLM..."
            cleanup

            # Check sample_and_adapter_train.py exit code AFTER cleanup
            if [ "$run_py_exit_code" -ne 0 ]; then
                echo "Aborting script due to sample_and_adapter_train.py failure in epoch ${epoch}."
                exit "$run_py_exit_code" # Exit the whole script if data collection failed
            fi
            echo "Alpha=${alpha}, Epoch=${epoch}: VLLM cleanup successful."


            # --- Merge data for the next epoch ---
            echo "Alpha=${alpha}, Epoch=${epoch}: Merging data for next epoch..."
            # previous_path is the data this epoch trained on.
            if [[ "${epoch}" -eq 0 ]]; then
                previous_path="./data/GSM8K/proposal/iter0"
            else
                previous_path="./data/GSM8K/proposal/ed_dpo/alpha=${alpha}/epoch=${epoch}"
            fi
            data_path="ckpt/data/GSM8K/ed_dpo/alpha=${alpha}/epoch=$((epoch+1))"

            python merge_dpo.py --data_path "${data_path}" --previous_path "${previous_path}"
            merge_exit_code=$?
            if [ "$merge_exit_code" -ne 0 ]; then
                echo "Error: merge_dpo.py failed for alpha=${alpha}, epoch=${epoch} (exit code $merge_exit_code). Aborting."
                exit "$merge_exit_code"
            fi
        fi
    done
done
