#!/bin/bash
#
# Job prelude: prepares the Python environment and shell options before the
# rest of the script runs. Statement order matters here; see the notes below.

# Load conda environment
# Sourcing conda.sh defines the `conda` shell function used on the next line.
source /data/user/miniconda3/etc/profile.d/conda.sh
conda activate rllm2
cd /data/user/rllm

# Load env vars (HF token, etc.)
# `set -a` marks every variable assigned while it is active for export, so
# plain KEY=value lines in .env become environment variables of child
# processes; `set +a` switches that off again.
set -a
. /data/user/rllm/.env
set +a

# Trace commands to stderr from here on (not switched off in this section).
set -x

# Print GPU info
# Runs through srun; each task prints its short hostname and the GPU list
# reported by `nvidia-smi -L`.
srun -l bash -c 'echo "Node: $(hostname -s)"; nvidia-smi -L'
# Strict mode starts only here, so failures in the commands above do not
# abort the job.
# NOTE(review): presumably deferred because conda activation is not
# `set -u`/`set -e` clean — confirm before moving this line up.
set -euo pipefail

# Single-script runner that:
#   1) launches base vLLM server (background)
#   2) waits until base is ready
#   3) launches fixer vLLM server (background)
#   4) waits until fixer is ready
#   5) runs examples/bugs/eval_bigcodebench_bugfixer.py
#   6) cleans up both servers on exit
#
# You can configure everything via environment variables (defaults below).

# Root of the rllm checkout; every relative path below is resolved from here.
RLLM_DIR="${RLLM_DIR:-/data/user/rllm}"
cd "$RLLM_DIR"

#######################################
# Source an env file and export everything it assigns.
# Command tracing (set -x) is suspended while the file is read so that
# secrets such as API tokens are not echoed into the job log; the previous
# tracing state is restored afterwards.
# Arguments: $1 - path to the env file; a missing file is not an error
# Returns:   0 when the file is absent or was sourced
#######################################
load_env_file() {
  # Double-underscore names keep these locals from shadowing variables the
  # env file itself assigns (those must land in the global scope).
  local __env_path="$1"
  local __xtrace_was_on=0

  [[ -f "$__env_path" ]] || return 0

  if [[ $- == *x* ]]; then
    __xtrace_was_on=1
    set +x
  fi

  set -a
  # shellcheck disable=SC1090
  . "$__env_path"
  set +a

  if (( __xtrace_was_on )); then
    set -x
  fi
  return 0
}

# Optional env file (HF token, etc.)
load_env_file "$RLLM_DIR/.env"

# ----------------------
# Base vLLM config
# ----------------------
# Model argument handed to `vllm serve` for the base server. Overridable from
# the environment like every other setting in this section.
BASE_MODEL_PATH="${BASE_MODEL_PATH:-Qwen/Qwen2.5-Coder-7B-Instruct}"
# Name the server advertises and the eval client requests; defaults to the
# model path.
BASE_SERVED_MODEL_NAME="${BASE_SERVED_MODEL_NAME:-$BASE_MODEL_PATH}"
# Listen address and port of the base server.
BASE_HOST="${BASE_HOST:-127.0.0.1}"
BASE_PORT="${BASE_PORT:-30000}"
# Value for --tensor-parallel-size.
BASE_TP="${BASE_TP:-1}"
# GPU selection exported as CUDA_VISIBLE_DEVICES for the base server only.
BASE_CUDA_VISIBLE_DEVICES="${BASE_CUDA_VISIBLE_DEVICES:-0}"

# ----------------------
# Fix vLLM config
# ----------------------
# Model argument handed to `vllm serve` for the fixer server. Overridable from
# the environment like every other setting in this section.
FIX_MODEL_PATH="${FIX_MODEL_PATH:-$HOME/rllm/checkpoints/rllm-agent/solver-flow/global_step_20}"
# Name the server advertises and the eval client requests; defaults to the
# model path.
FIX_SERVED_MODEL_NAME="${FIX_SERVED_MODEL_NAME:-$FIX_MODEL_PATH}"
# Listen address and port of the fixer server.
FIX_HOST="${FIX_HOST:-127.0.0.1}"
FIX_PORT="${FIX_PORT:-30001}"
# Value for --tensor-parallel-size.
FIX_TP="${FIX_TP:-1}"
# GPU selection exported as CUDA_VISIBLE_DEVICES for the fixer server only.
FIX_CUDA_VISIBLE_DEVICES="${FIX_CUDA_VISIBLE_DEVICES:-1}"

# ----------------------
# Eval config
# ----------------------
# Dataset selection. START and MAX_EXAMPLES are forwarded as --start and
# --max_examples; an empty OUTPUT_JSONL means no --output_jsonl flag is passed.
: "${HF_DATASET:=anonymous/bigcodebench}"
: "${HF_SPLIT:=v0.1.0_hf}"
: "${START:=0}"
: "${MAX_EXAMPLES:=200}"
: "${OUTPUT_JSONL:=}"

# Sampling settings used for requests to the base model.
: "${BASE_TEMPERATURE:=0.2}"
: "${BASE_TOP_P:=0.95}"
: "${BASE_MAX_TOKENS:=2048}"

# Sampling settings used for requests to the fixer model.
: "${FIX_TEMPERATURE:=0.6}"
: "${FIX_TOP_P:=0.95}"
: "${FIX_MAX_TOKENS:=2048}"

# ----------------------
# Logging
# ----------------------
# Directory for all logs of this run; created up front so the redirections
# further down cannot fail on a missing directory.
: "${LOG_DIR:=$RLLM_DIR/runs/eval_bigcodebench_bugfixer}"
mkdir -p "$LOG_DIR"
# One log per server (the port is part of the file name) plus one for the
# eval client.
: "${BASE_LOG:=$LOG_DIR/base_vllm_${BASE_PORT}.log}"
: "${FIX_LOG:=$LOG_DIR/fix_vllm_${FIX_PORT}.log}"
: "${EVAL_LOG:=$LOG_DIR/eval.log}"

# ----------------------
# vLLM / torch env
# ----------------------
# Drop the ROCR/ROCM/HIP device-visibility variables.
# NOTE(review): presumably so stale AMD masks from the scheduler do not
# interfere with the CUDA device selection below — confirm.
unset ROCR_VISIBLE_DEVICES ROCM_VISIBLE_DEVICES HIP_VISIBLE_DEVICES

# Fill in defaults only where the caller has not set a value, then export the
# whole group so both servers and the eval client inherit it.
: "${VLLM_ATTENTION_BACKEND:=FLASH_ATTN}"
: "${PYTORCH_CUDA_ALLOC_CONF:=expandable_segments:False}"
: "${VLLM_USE_V1:=1}"
: "${VLLM_ALLOW_LONG_MAX_MODEL_LEN:=1}"
: "${VLLM_ENGINE_ITERATION_TIMEOUT_S:=1000000000}"
: "${CUDA_DEVICE_ORDER:=PCI_BUS_ID}"
export VLLM_ATTENTION_BACKEND PYTORCH_CUDA_ALLOC_CONF VLLM_USE_V1 \
  VLLM_ALLOW_LONG_MAX_MODEL_LEN VLLM_ENGINE_ITERATION_TIMEOUT_S \
  CUDA_DEVICE_ORDER

# OpenAI-compatible base URLs of the two servers.
printf -v BASE_OPENAI_URL 'http://%s:%s/v1' "$BASE_HOST" "$BASE_PORT"
printf -v FIX_OPENAI_URL 'http://%s:%s/v1' "$FIX_HOST" "$FIX_PORT"

#######################################
# Poll an OpenAI-compatible endpoint until it answers.
# Arguments:
#   $1 - base URL; "${1}/models" is probed
#   $2 - label used in log messages
#   $3 - timeout in seconds (default 600)
#   $4 - optional PID of the server process; when given, the wait is
#        abandoned as soon as that process no longer exists instead of
#        polling a dead server until the timeout
# Outputs: progress on stdout, errors on stderr
# Returns: 0 once the endpoint responds; 1 on timeout or server exit
#######################################
wait_for_server() {
  local url="$1"
  local name="$2"
  local timeout_s="${3:-600}"
  local server_pid="${4:-}"
  local start_ts now_ts
  start_ts="$(date +%s)"
  echo "[wait] Waiting for $name at ${url}/models (timeout=${timeout_s}s)"
  while true; do
    # --max-time keeps one stuck connection from outliving the overall timeout.
    if curl -fsS --max-time 5 "${url}/models" >/dev/null 2>&1; then
      echo "[wait] $name is up."
      return 0
    fi
    if [[ -n "$server_pid" ]] && ! kill -0 "$server_pid" 2>/dev/null; then
      echo "[wait] ERROR: $name (pid=$server_pid) exited before becoming ready" >&2
      return 1
    fi
    now_ts="$(date +%s)"
    if (( now_ts - start_ts > timeout_s )); then
      echo "[wait] ERROR: timed out waiting for $name at ${url}/models" >&2
      return 1
    fi
    sleep 2
  done
}

# Print PID $1 followed by all of its descendants, one PID per line.
# Without pgrep only $1 itself is printed.
_cleanup_proc_tree() {
  local parent="$1" child
  printf '%s\n' "$parent"
  command -v pgrep >/dev/null 2>&1 || return 0
  for child in $(pgrep -P "$parent" 2>/dev/null); do
    _cleanup_proc_tree "$child"
  done
  return 0
}

# Succeed while PID $1 is still executing. A zombie still answers `kill -0`
# but has already exited, so it counts as gone.
_cleanup_pid_running() {
  local stat_line
  kill -0 "$1" 2>/dev/null || return 1
  if [[ -r "/proc/$1/stat" ]]; then
    stat_line="$(cat "/proc/$1/stat" 2>/dev/null)" || return 1
    # Format is "pid (comm) state ..."; strip through the last ") ".
    stat_line="${stat_line##*) }"
    [[ "${stat_line:0:1}" != "Z" ]] || return 1
  fi
  return 0
}

#######################################
# Stop both servers and everything they spawned.
# FIX_PID / BASE_PID may be wrapper subshells rather than the servers
# themselves, so the complete process tree under each is signalled; killing
# only the recorded PID would leave the real server holding its GPU.
# SIGTERM is sent first, then up to ~10 s are allowed for a graceful exit
# before SIGKILL.
# Globals: FIX_PID, BASE_PID (read; either may be unset)
# Outputs: one status line on stdout
# Returns: 0 always (best effort)
#######################################
cleanup() {
  set +e
  echo "[cleanup] stopping servers..."
  local -a pids=()
  local root pid attempt still_running

  # Snapshot the trees before signalling: once a parent dies its children are
  # re-parented and can no longer be found through it.
  for root in "${FIX_PID:-}" "${BASE_PID:-}"; do
    [[ -n "$root" ]] || continue
    while IFS= read -r pid; do
      pids+=("$pid")
    done < <(_cleanup_proc_tree "$root")
  done
  (( ${#pids[@]} > 0 )) || return 0

  kill "${pids[@]}" 2>/dev/null
  for (( attempt = 0; attempt < 50; attempt++ )); do
    still_running=0
    for pid in "${pids[@]}"; do
      if _cleanup_pid_running "$pid"; then
        still_running=1
        break
      fi
    done
    (( still_running )) || return 0
    sleep 0.2
  done
  kill -9 "${pids[@]}" 2>/dev/null
  return 0
}
# cleanup runs exactly once, from EXIT. INT/TERM only turn the signal into an
# exit; a handler that returned would let the script resume after the servers
# were already killed.
trap cleanup EXIT
trap 'exit 130' INT
trap 'exit 143' TERM

# Step 1-2: launch the base server in the background and wait until it answers.
echo "[run] starting base vLLM..."
(
  echo "[base] MODEL_PATH=$BASE_MODEL_PATH"
  echo "[base] SERVED_MODEL_NAME=$BASE_SERVED_MODEL_NAME"
  echo "[base] CUDA_VISIBLE_DEVICES=$BASE_CUDA_VISIBLE_DEVICES TP=$BASE_TP HOST=$BASE_HOST PORT=$BASE_PORT"
  # exec replaces this subshell with the server, so the PID captured from $!
  # below is the vLLM process itself and a signal sent to it reaches the
  # server rather than a wrapper shell.
  CUDA_VISIBLE_DEVICES="$BASE_CUDA_VISIBLE_DEVICES" \
    exec vllm serve "$BASE_MODEL_PATH" \
      --host "$BASE_HOST" \
      --port "$BASE_PORT" \
      --served-model-name "$BASE_SERVED_MODEL_NAME" \
      --tensor-parallel-size "$BASE_TP"
) >"$BASE_LOG" 2>&1 &
BASE_PID=$!
echo "[run] base pid=$BASE_PID log=$BASE_LOG"
# Timeout (600 s, the function's default) and server PID are passed
# explicitly. A wait_for_server that accepts a PID can stop early if the
# server exits during startup; one that takes only url/name/timeout ignores
# the extra argument and behaves as before.
wait_for_server "$BASE_OPENAI_URL" "base" 600 "$BASE_PID"

# Step 3-4: launch the fixer server in the background and wait until it answers.
echo "[run] starting fix vLLM..."
(
  echo "[fix] MODEL_PATH=$FIX_MODEL_PATH"
  echo "[fix] SERVED_MODEL_NAME=$FIX_SERVED_MODEL_NAME"
  echo "[fix] CUDA_VISIBLE_DEVICES=$FIX_CUDA_VISIBLE_DEVICES TP=$FIX_TP HOST=$FIX_HOST PORT=$FIX_PORT"
  # exec replaces this subshell with the server, so the PID captured from $!
  # below is the vLLM process itself and a signal sent to it reaches the
  # server rather than a wrapper shell.
  CUDA_VISIBLE_DEVICES="$FIX_CUDA_VISIBLE_DEVICES" \
    exec vllm serve "$FIX_MODEL_PATH" \
      --host "$FIX_HOST" \
      --port "$FIX_PORT" \
      --served-model-name "$FIX_SERVED_MODEL_NAME" \
      --tensor-parallel-size "$FIX_TP"
) >"$FIX_LOG" 2>&1 &
FIX_PID=$!
echo "[run] fix pid=$FIX_PID log=$FIX_LOG"
# Timeout (600 s, the function's default) and server PID are passed
# explicitly. A wait_for_server that accepts a PID can stop early if the
# server exits during startup; one that takes only url/name/timeout ignores
# the extra argument and behaves as before.
wait_for_server "$FIX_OPENAI_URL" "fix" 600 "$FIX_PID"

# Step 5: run the evaluation client against both servers.
printf '%s\n' "[run] running eval..."
set -x
# The command is assembled as an array so every value stays a single argument.
EVAL_CMD=(python "$RLLM_DIR/examples/bugs/eval_bigcodebench_bugfixer.py")
# Dataset slice.
EVAL_CMD+=(--hf_dataset "$HF_DATASET" --hf_split "$HF_SPLIT")
EVAL_CMD+=(--start "$START" --max_examples "$MAX_EXAMPLES")
# Base model endpoint and sampling settings.
EVAL_CMD+=(
  --base_model "$BASE_SERVED_MODEL_NAME"
  --base_model_url "$BASE_OPENAI_URL"
  --base_temperature "$BASE_TEMPERATURE"
  --base_top_p "$BASE_TOP_P"
  --base_max_tokens "$BASE_MAX_TOKENS"
)
# Fixer model endpoint and sampling settings.
EVAL_CMD+=(
  --fix_model "$FIX_SERVED_MODEL_NAME"
  --fix_model_url "$FIX_OPENAI_URL"
  --fix_temperature "$FIX_TEMPERATURE"
  --fix_top_p "$FIX_TOP_P"
  --fix_max_tokens "$FIX_MAX_TOKENS"
)
# --output_jsonl is only passed when a path was configured.
[[ -z "$OUTPUT_JSONL" ]] || EVAL_CMD+=(--output_jsonl "$OUTPUT_JSONL")
# Combined stdout/stderr goes to the console and to the eval log.
"${EVAL_CMD[@]}" 2>&1 | tee "$EVAL_LOG"
set +x

# Final summary: where each log ended up.
printf '%s\n' \
  "[run] done" \
  "  base log: $BASE_LOG" \
  "  fix  log: $FIX_LOG" \
  "  eval log: $EVAL_LOG"


