#!/bin/bash

# Load conda environment and move to the project root. Abort early if any of
# these steps fail: everything below assumes the rllm2 env and this cwd.
source /data/user/miniconda3/etc/profile.d/conda.sh \
    || { echo "ERROR: cannot source conda.sh" >&2; exit 1; }
conda activate rllm2 || { echo "ERROR: conda activate rllm2 failed" >&2; exit 1; }
cd /data/user/rllm || { echo "ERROR: cannot cd to /data/user/rllm" >&2; exit 1; }

# set -a marks every variable assigned while sourcing .env for export, so
# child processes (vllm, python) inherit them.
set -a
. /data/user/rllm/.env
set +a

# Command tracing is switched on only after .env has been loaded, so the
# .env assignments themselves are not echoed into the job log.
set -x
# Print GPU info
srun -l bash -c 'echo "Node: $(hostname -s)"; nvidia-smi -L'

# --- vLLM / torch env
unset ROCR_VISIBLE_DEVICES ROCM_VISIBLE_DEVICES HIP_VISIBLE_DEVICES
export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
export VLLM_USE_V1=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ENGINE_ITERATION_TIMEOUT_S=1000000000
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export HYDRA_FULL_ERROR=1
export RAY_DISABLE_DASHBOARD=1

# Clean any previous Ray instances and stray shared dirs.
# NOTE(review): `pkill -9 -f "ray::"` matches every Ray worker process owned
# by this user on the node, not only ones belonging to this job.
ray stop -f || true
pkill -9 -f "ray::" || true
# Fall back to `id -un` so an unset USER cannot turn the target into /tmp/ray_*.
rm -rf "/tmp/${USER:-$(id -un)}"/ray_* 2>/dev/null || true
export RAY_TMPDIR="/data/user/ray"
export TMPDIR="/data/user/tmp"
mkdir -p "$RAY_TMPDIR" "$TMPDIR"
chmod 700 "$RAY_TMPDIR" "$TMPDIR"
export RAY_object_store_allow_fallback_to_memory=1
# Physical (symlink-resolved) path of the project root we cd'd into above.
# NOTE(review): not referenced anywhere else in this script.
RLLM_DIR="$(pwd -P)"

# -----------------------------
# Checkpoint root (required)
# Set via environment variable for sbatch compatibility
# -----------------------------
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-}"

if [[ -z "$CHECKPOINT_ROOT" ]]; then
    # Diagnostics go to stderr. Use "$0" as-is: prefixing "./" breaks when
    # $0 already carries a path (e.g. ./eval.sh or /abs/eval.sh).
    printf '%s\n' \
        "ERROR: CHECKPOINT_ROOT environment variable is required." \
        "" \
        "Usage:" \
        "  CHECKPOINT_ROOT=/path/to/checkpoints sbatch $0" \
        "  CHECKPOINT_ROOT=/path/to/checkpoints STEPS=5,10,15 sbatch $0" \
        "" \
        "Or run directly:" \
        "  CHECKPOINT_ROOT=/path/to/checkpoints $0" >&2
    exit 1
fi

if [[ ! -d "$CHECKPOINT_ROOT" ]]; then
    echo "ERROR: Checkpoint root directory does not exist: $CHECKPOINT_ROOT" >&2
    exit 1
fi

# -----------------------------
# Configuration (from env or defaults)
# Server/parallelism knobs are overridable through EVAL_* variables; the
# prefix avoids clashing with generic names such as HOST/PORT that a login
# environment or the exported .env may already define.
# -----------------------------
HOST="${EVAL_HOST:-0.0.0.0}"
PORT="${EVAL_PORT:-30000}"
TP="${EVAL_TP:-1}"                          # vLLM --tensor-parallel-size
DP="${EVAL_DP:-4}"                          # vLLM --data-parallel-size
CUDA_DEVICES="${EVAL_CUDA_DEVICES:-0,1,2,3}"
N_PARALLEL="${EVAL_N_PARALLEL:-64}"
FIXER_ATTEMPTS_VAL="${EVAL_FIXER_ATTEMPTS_VAL:-1}"
VAL_DATASETS="${VAL_DATASETS:-bugbench:test bugbench_human:test bugbench_qwen7b_sampled:test bugbench_gpt-oss-20b_sampled:test bugbench_adversarial:test}"
OUTPUT_DIR="${OUTPUT_DIR:-logs/eval/bugs/}"
# Drop trailing slashes (but keep a bare "/") so "${OUTPUT_DIR}/name" has no "//".
while [[ ${#OUTPUT_DIR} -gt 1 && "$OUTPUT_DIR" == */ ]]; do
    OUTPUT_DIR="${OUTPUT_DIR%/}"
done
STEPS="${STEPS:-}"

# The server is launched on CUDA_DEVICES with TP x DP workers; warn early if
# the two disagree instead of letting vLLM fail or leave GPUs idle.
IFS=',' read -ra _cuda_device_list <<< "$CUDA_DEVICES"
if (( TP * DP != ${#_cuda_device_list[@]} )); then
    echo "WARNING: TP(${TP}) x DP(${DP}) != ${#_cuda_device_list[@]} device(s) in CUDA_DEVICES=${CUDA_DEVICES}" >&2
fi

# vLLM environment
unset ROCR_VISIBLE_DEVICES ROCM_VISIBLE_DEVICES HIP_VISIBLE_DEVICES 2>/dev/null || true
export VLLM_ATTENTION_BACKEND="${VLLM_ATTENTION_BACKEND:-FLASH_ATTN}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:False}"
export VLLM_USE_V1="${VLLM_USE_V1:-1}"
export VLLM_ALLOW_LONG_MAX_MODEL_LEN="${VLLM_ALLOW_LONG_MAX_MODEL_LEN:-1}"
export VLLM_ENGINE_ITERATION_TIMEOUT_S="${VLLM_ENGINE_ITERATION_TIMEOUT_S:-1000000000}"
export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}"

# The evaluator always talks to the local server over loopback, regardless
# of the bind address in HOST.
BASE_URL="http://127.0.0.1:${PORT}/v1"

# -----------------------------
# Find checkpoints to evaluate
# Fills CHECKPOINTS with global_step_* directories under CHECKPOINT_ROOT:
# either the comma-separated STEPS (e.g. "5,10,15", "5, 10", or
# "global_step_5"), or every such directory sorted by version order.
# -----------------------------
CHECKPOINTS=()
if [[ -n "$STEPS" ]]; then
    # Specific steps requested
    IFS=',' read -ra STEP_ARRAY <<< "$STEPS"
    for step in "${STEP_ARRAY[@]}"; do
        # Tolerate spaces around commas and an optional global_step_ prefix.
        step="${step//[[:space:]]/}"
        step="${step#global_step_}"
        [[ -n "$step" ]] || continue    # e.g. trailing comma
        ckpt_path="${CHECKPOINT_ROOT}/global_step_${step}"
        if [[ -d "$ckpt_path" ]]; then
            CHECKPOINTS+=("$ckpt_path")
        else
            echo "WARNING: Checkpoint not found: $ckpt_path" >&2
        fi
    done
else
    # Find all global_step_* directories; sort -V orders step numbers
    # numerically (global_step_5 before global_step_10).
    while IFS= read -r -d '' dir; do
        CHECKPOINTS+=("$dir")
    done < <(find "$CHECKPOINT_ROOT" -maxdepth 1 -type d -name "global_step_*" -print0 | sort -z -V)
fi

if [[ ${#CHECKPOINTS[@]} -eq 0 ]]; then
    echo "ERROR: No checkpoints found in: $CHECKPOINT_ROOT" >&2
    echo "Looking for directories matching: global_step_*" >&2
    exit 1
fi

# Print a summary of what is about to run, then create the output root.
echo "=============================================="
echo "Checkpoint Evaluation Pipeline"
echo "=============================================="
echo "Checkpoint root: $CHECKPOINT_ROOT"
echo "Found ${#CHECKPOINTS[@]} checkpoint(s):"
for ckpt in "${CHECKPOINTS[@]}"; do
    echo "  - $(basename "$ckpt")"
done
echo ""
echo "Configuration:"
echo "  Host: $HOST:$PORT"
echo "  Tensor Parallel: $TP"
echo "  Data Parallel: $DP"
echo "  CUDA Devices: $CUDA_DEVICES"
echo "  Parallel tasks: $N_PARALLEL"
echo "  Fixer attempts: $FIXER_ATTEMPTS_VAL"
echo "  Val datasets: $VAL_DATASETS"
echo "  Output dir: $OUTPUT_DIR"
echo "=============================================="
echo ""

# Fail fast: every per-checkpoint log and result lives under this directory.
mkdir -p "$OUTPUT_DIR" || { echo "ERROR: cannot create output dir: $OUTPUT_DIR" >&2; exit 1; }

# -----------------------------
# Helper: Kill vLLM server
# -----------------------------
#######################################
# Stop the vLLM server launched by this script, then sweep up any stray
# vLLM process still started with our --port.
# Globals:   VLLM_PID (read, then unset), PORT (read)
# Arguments: $1 - seconds to wait for a graceful exit after SIGTERM (default 30)
# Returns:   0 (best-effort; every failure is ignored)
#######################################
cleanup_vllm() {
    local grace="${1:-30}"
    local waited
    echo "[cleanup] Stopping vLLM server..."
    if [[ -n "${VLLM_PID:-}" ]]; then
        kill "$VLLM_PID" 2>/dev/null || true
        # Bounded grace period: a bare `wait` would block forever if the
        # server ignored SIGTERM.
        for (( waited = 0; waited < grace; waited++ )); do
            kill -0 "$VLLM_PID" 2>/dev/null || break
            sleep 1
        done
        kill -9 "$VLLM_PID" 2>/dev/null || true
        wait "$VLLM_PID" 2>/dev/null || true
        # Forget the PID so a later call (e.g. the EXIT trap) cannot signal an
        # unrelated process that has since reused it.
        unset VLLM_PID
    fi
    # Also kill any stray vllm processes on this port. The port is anchored
    # so that, e.g., 30000 does not also match "--port 300001".
    pkill -9 -f "vllm.*--port[= ]${PORT}([^0-9]|\$)" 2>/dev/null || true
    sleep 2
}

# Cleanup on exit
trap cleanup_vllm EXIT

# -----------------------------
# Helper: Wait for vLLM to be ready
# -----------------------------
#######################################
# Poll ${BASE_URL}/models until it lists the expected served model, the
# server process exits, or the attempt budget runs out.
# Globals:   BASE_URL (read), VLLM_PID (read)
# Arguments: $1 - served model name that must appear (quoted) in the
#                 /models response; if empty, any response containing
#                 "data" is accepted
#            $2 - max polling attempts (default 60)
#            $3 - seconds to sleep between attempts (default 5)
# Returns:   0 when ready, 1 if the process died or the budget ran out
#######################################
wait_for_vllm() {
    local model_name="$1"
    local max_attempts="${2:-60}"
    local poll_interval="${3:-5}"
    local attempt models_json needle

    # Match the model id as a quoted JSON string, so an unrelated server that
    # happens to answer on this port is not mistaken for ours.
    if [[ -n "$model_name" ]]; then
        needle="\"${model_name}\""
    else
        needle="data"
    fi

    echo "[vllm] Waiting for server to be ready..."
    for (( attempt = 1; attempt <= max_attempts; attempt++ )); do
        if models_json=$(curl -fsS "${BASE_URL}/models" 2>/dev/null) \
            && grep -qF -- "$needle" <<< "$models_json"; then
            echo "[vllm] Server is ready!"
            return 0
        fi

        # Check if server process is still alive
        if ! kill -0 "$VLLM_PID" 2>/dev/null; then
            echo "[vllm] ERROR: Server process died"
            return 1
        fi

        echo "[vllm] Waiting... (${attempt}/${max_attempts})"
        sleep "$poll_interval"
    done

    echo "[vllm] ERROR: Server did not become ready in time"
    return 1
}

# -----------------------------
# Main loop: Evaluate each checkpoint
# For each checkpoint: skip it if results.json already exists; otherwise
# start a dedicated vLLM server, run the evaluation against it, then stop
# the server. Failures are collected, listed at the end, and make the script
# exit non-zero so the batch job status reflects them.
# -----------------------------
TOTAL=${#CHECKPOINTS[@]}
CURRENT=0
FAILED=()
# Readiness budget passed to wait_for_vllm (attempts, 5 s apart by default).
READY_ATTEMPTS="${EVAL_READY_ATTEMPTS:-60}"
# VAL_DATASETS is a whitespace-separated list; split it once, explicitly.
read -ra VAL_DATASET_ARGS <<< "$VAL_DATASETS"

for CKPT_PATH in "${CHECKPOINTS[@]}"; do
    CURRENT=$((CURRENT + 1))
    CKPT_NAME=$(basename "$CKPT_PATH")
    CKPT_OUTPUT_DIR="${OUTPUT_DIR}/${CKPT_NAME}"

    echo ""
    echo "=============================================="
    echo "[${CURRENT}/${TOTAL}] Evaluating: $CKPT_NAME"
    echo "=============================================="

    # Check if already evaluated
    if [[ -f "${CKPT_OUTPUT_DIR}/results.json" ]]; then
        echo "[skip] Already evaluated (found results.json). Skipping..."
        continue
    fi

    mkdir -p "$CKPT_OUTPUT_DIR"
    VLLM_LOG="${CKPT_OUTPUT_DIR}/vllm_server.log"

    # Cleanup any existing server
    cleanup_vllm

    # Start vLLM server; the checkpoint dir name doubles as the served model
    # name that the evaluator requests.
    echo "[vllm] Starting server for: $CKPT_PATH"
    echo "[vllm] Log: $VLLM_LOG"

    CUDA_VISIBLE_DEVICES="$CUDA_DEVICES" \
        vllm serve "$CKPT_PATH" \
            --host "$HOST" \
            --port "$PORT" \
            --served-model-name "$CKPT_NAME" \
            --tensor-parallel-size "$TP" \
            --data-parallel-size "$DP" \
            >> "$VLLM_LOG" 2>&1 &
    VLLM_PID=$!

    echo "[vllm] Launched with PID: $VLLM_PID"

    # Wait for server to be ready
    if ! wait_for_vllm "$CKPT_NAME" "$READY_ATTEMPTS"; then
        echo "[error] Failed to start vLLM for $CKPT_NAME"
        echo "[error] Last 50 lines of log:"
        tail -50 "$VLLM_LOG" 2>/dev/null || true
        FAILED+=("$CKPT_NAME (vLLM startup)")
        cleanup_vllm
        continue
    fi

    # Run evaluation
    echo ""
    echo "[eval] Running evaluation..."
    EVAL_LOG="${CKPT_OUTPUT_DIR}/eval.log"

    python -m examples.bugs_refactor.run_generator_fixer_flow \
        --val_datasets "${VAL_DATASET_ARGS[@]}" \
        --model "$CKPT_NAME" \
        --base_url "${BASE_URL}" \
        --n_parallel "${N_PARALLEL}" \
        --n_tasks 10 \
        --eval_pregenerated_only \
        --evaluate_codegen \
        --include_failed_test_output \
        --fixer_attempts_val "${FIXER_ATTEMPTS_VAL}" \
        --save_results \
        --output_dir "${CKPT_OUTPUT_DIR}" \
        2>&1 | tee "$EVAL_LOG"
    # `$?` here would be tee's status; PIPESTATUS[0] is python's. It must be
    # read immediately after the pipeline.
    EVAL_STATUS=${PIPESTATUS[0]}

    if [[ $EVAL_STATUS -eq 0 ]]; then
        echo "[eval] ✅ Evaluation complete for $CKPT_NAME"
    else
        echo "[eval] ❌ Evaluation failed for $CKPT_NAME (exit code: $EVAL_STATUS)"
        FAILED+=("$CKPT_NAME (eval exit code $EVAL_STATUS)")
    fi

    # Cleanup server before next iteration
    cleanup_vllm

done

echo ""
echo "=============================================="
if (( ${#FAILED[@]} == 0 )); then
    echo "All evaluations complete!"
else
    echo "Evaluations finished with ${#FAILED[@]} failure(s):"
    for failed_ckpt in "${FAILED[@]}"; do
        echo "  - $failed_ckpt"
    done
fi
echo "Results saved to: $OUTPUT_DIR"
echo "=============================================="

# Non-zero exit when anything failed, so sbatch/callers can detect it.
(( ${#FAILED[@]} == 0 )) || exit 1
