#!/bin/bash
#
# Launch RL training: prepare merged models, start the judge server on GPU 0,
# then run accelerate-based training on GPUs 1-6.
#
# Usage: ./<this-script> [CONFIG_FILE]   (default: config/rl_config.yaml)

set -eo pipefail

# PIDs of the background processes started below; read by cleanup_and_exit.
# Initialised before any trap is installed so the handler never sees them unset.
JUDGE_PID=""
TRAINING_PID=""

# Cleanup runs exactly once, from the EXIT trap. INT/TERM only exit with the
# conventional 128+signal status, which in turn fires the EXIT trap.
trap 'exit 130' INT
trap 'exit 143' TERM
trap cleanup_and_exit EXIT

#######################################
# Stop everything this script launched and free the GPUs.
# Intended to run from a trap. Safe to call more than once: only the first
# call does any work, so a signal trap followed by the EXIT trap does not
# repeat the teardown.
# Globals:
#   TRAINING_PID, JUDGE_PID (read) - background PIDs; empty if not started
#   CLEANUP_GRACE_SECONDS (read)   - seconds to wait after SIGTERM before
#                                    escalating to SIGKILL (default 30)
#   _CLEANUP_DONE (written)        - re-entrancy guard
# Outputs:
#   Progress messages on stdout.
#######################################
cleanup_and_exit() {
    if [[ -n "${_CLEANUP_DONE:-}" ]]; then
        return 0
    fi
    _CLEANUP_DONE=1

    echo "Cleaning up processes..."

    if [[ -n "${TRAINING_PID:-}" ]]; then
        echo "Stopping training process..."
        _stop_pid "$TRAINING_PID"
    fi

    if [[ -n "${JUDGE_PID:-}" ]]; then
        echo "Stopping judge server..."
        _stop_pid "$JUDGE_PID"
    fi

    # Best-effort sweep for any vLLM OpenAI API server and whatever still
    # listens on port 8000, even if not tracked by the PIDs above.
    # -r: don't run kill at all when nothing matched.
    pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true
    lsof -ti:8000 2>/dev/null | xargs -r kill -9 2>/dev/null || true

    # NOTE(review): this SIGKILLs every compute process nvidia-smi reports,
    # including ones this script did not start; only safe on a dedicated node.
    echo "Killing GPU processes..."
    nvidia-smi --query-compute-apps=pid --format=csv,noheader,nounits 2>/dev/null \
        | xargs -r kill -9 2>/dev/null || true

    echo "Cleanup complete."
}

# Send SIGTERM to PID $1, give it up to CLEANUP_GRACE_SECONDS to exit, then
# SIGKILL it if it is still alive, and reap it. Never fails, so it cannot
# abort the surrounding trap under `set -e`.
_stop_pid() {
    local pid=$1
    local grace=${CLEANUP_GRACE_SECONDS:-30}
    local ticks=0
    local max_ticks=$(( grace * 5 ))   # polled every 0.2 s

    kill -TERM "$pid" 2>/dev/null || true
    while kill -0 "$pid" 2>/dev/null && (( ticks < max_ticks )); do
        sleep 0.2 2>/dev/null || sleep 1
        ticks=$(( ticks + 1 ))
    done
    # Escalate only if the process ignored SIGTERM for the whole grace period.
    if kill -0 "$pid" 2>/dev/null; then
        kill -KILL "$pid" 2>/dev/null || true
    fi
    wait "$pid" 2>/dev/null || true
}

# Optional first argument: path to the RL config, passed to every stage.
CONFIG_FILE="${1:-config/rl_config.yaml}"
# Seconds to give the judge server to start before launching training.
JUDGE_STARTUP_WAIT="${JUDGE_STARTUP_WAIT:-60}"

echo "Starting RL training with Accelerate..."
echo "Using config file: $CONFIG_FILE"

# Prepare models
echo "Preparing merged models..."
python dist_supp/prepare_models.py --config "$CONFIG_FILE"

# Start judge server (detached, output captured in judge_server.log).
echo "Starting judge server on GPU 0..."
export CUDA_VISIBLE_DEVICES=0
LOG_LEVEL=ERROR nohup python dist_supp/start_judge_server.py --config "$CONFIG_FILE" > judge_server.log 2>&1 &
JUDGE_PID=$!

# Fixed startup delay (there is no readiness probe). Fail fast if the server
# already died rather than launching training against a dead judge.
sleep "$JUDGE_STARTUP_WAIT"
if ! kill -0 "$JUDGE_PID" 2>/dev/null; then
    echo "Judge server exited during startup; last lines of judge_server.log:" >&2
    tail -n 20 judge_server.log >&2 || true
    # Already gone and reaped: don't let cleanup signal a possibly reused PID.
    JUDGE_PID=""
    exit 1
fi

# Start training with accelerate
echo "Starting training with accelerate..."
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6

# Run in the background and `wait`, so a trapped signal interrupts the wait
# immediately instead of being deferred until training finishes.
accelerate launch --num_processes=6 --gpu_ids=1,2,3,4,5,6 rl_train_vllm_accelerate.py --config "$CONFIG_FILE" &
TRAINING_PID=$!

# Capture the status explicitly: under `set -e` a bare failing `wait` would
# abort the script before the failure could be reported.
TRAINING_EXIT_CODE=0
wait "$TRAINING_PID" || TRAINING_EXIT_CODE=$?
# The trainer has been reaped; forget its PID so cleanup cannot hit a reused one.
TRAINING_PID=""

if (( TRAINING_EXIT_CODE == 0 )); then
    echo "Training completed successfully!"
else
    echo "Training failed with exit code: $TRAINING_EXIT_CODE" >&2
fi

echo "Script completed."
# Propagate the training result as the script's exit status (EXIT trap still runs).
exit "$TRAINING_EXIT_CODE"