#!/usr/bin/env bash
set -euo pipefail

# Args from launcher: <slot> <gpu_id> <num_gpus>
#   slot     - index of this worker, 0..(num_gpus-1); picks which concept
#              files it handles (round-robin by file index)
#   gpu_id   - physical GPU id, exported as CUDA_VISIBLE_DEVICES
#   num_gpus - total number of workers in the round-robin split
SLOT="${1:-0}"        # 0..(NUM_GPUS-1)
GPU_ID="${2:-0}"      # physical GPU id
NUM_GPUS="${3:-1}"

# Validate up front: SLOT/NUM_GPUS feed shell arithmetic later, where a
# non-integer is evaluated as an expression, NUM_GPUS=0 is a division by zero,
# and SLOT >= NUM_GPUS would make this worker silently skip every file.
if [[ ! "$NUM_GPUS" =~ ^[0-9]+$ ]] || (( 10#$NUM_GPUS == 0 )); then
  echo "[fatal] num_gpus must be a positive integer (got '$NUM_GPUS')" >&2
  exit 1
fi
if [[ ! "$SLOT" =~ ^[0-9]+$ ]] || (( 10#$SLOT >= 10#$NUM_GPUS )); then
  echo "[fatal] slot must be an integer in 0..$((10#$NUM_GPUS - 1)) (got '$SLOT')" >&2
  exit 1
fi
# Normalize away leading zeros so later arithmetic never reads them as octal.
SLOT=$((10#$SLOT))
NUM_GPUS=$((10#$NUM_GPUS))

: "${OPENAI_API_KEY:?Please export OPENAI_API_KEY}"   # required
export CUDA_VISIBLE_DEVICES="$GPU_ID"

# ---------------- Resolve repo root & script paths (robust) ----------------
# Absolute directory of this script. BASH_SOURCE is preferred over $0 (falls
# back to $0 if unavailable); CDPATH is cleared so `cd` cannot print the target
# directory into the captured output.
THIS_DIR="$(CDPATH='' cd -- "$(dirname -- "${BASH_SOURCE[0]:-$0}")" && pwd)"

# Resolution order for REPO_ROOT:
#   1. PROJECT_ROOT env var, if it contains src/
#   2. parent of this script's directory, if it contains src/
#   3. git toplevel, if it contains src/
#   4. derive from the location of eval_prompt_steering.py found by search
if [[ -n "${PROJECT_ROOT:-}" && -d "${PROJECT_ROOT}/src" ]]; then
  REPO_ROOT="$PROJECT_ROOT"
elif [[ -d "$THIS_DIR/../src" ]]; then
  REPO_ROOT="$(cd -- "$THIS_DIR/.." && pwd)"
else
  TL=""
  if command -v git >/dev/null 2>&1; then
    TL="$(git -C "$THIS_DIR" rev-parse --show-toplevel 2>/dev/null || true)"
  fi
  if [[ -n "$TL" && -d "$TL/src" ]]; then
    REPO_ROOT="$TL"
  else
    # Search near the script first, then one level up. `|| true` keeps
    # pipefail from aborting when find gets SIGPIPE after head exits.
    FOUND=$(find "$THIS_DIR" -maxdepth 4 -type f -name "eval_prompt_steering.py" 2>/dev/null | head -n1 || true)
    if [[ -z "$FOUND" ]]; then
      FOUND=$(find "$(cd -- "$THIS_DIR/.." && pwd)" -maxdepth 5 -type f -name "eval_prompt_steering.py" 2>/dev/null | head -n1 || true)
    fi
    if [[ -z "$FOUND" ]]; then
      echo "[fatal] Could not infer repository root. Set PROJECT_ROOT env var or place this script inside the repo." >&2
      exit 1
    fi
    # Map the found path back to the repo root for each supported layout.
    # Previously only the first layout was stripped; any other match left
    # REPO_ROOT pointing at the .py file itself.
    case "$FOUND" in
      */src/axbench_steering/eval_prompt_steering.py)
        REPO_ROOT="${FOUND%/src/axbench_steering/eval_prompt_steering.py}" ;;
      */src/eval_prompt_steering.py)
        REPO_ROOT="${FOUND%/src/eval_prompt_steering.py}" ;;
      *)
        # e.g. <root>/axbench_steering/eval_prompt_steering.py -> two levels up
        REPO_ROOT="$(dirname -- "$(dirname -- "$FOUND")")" ;;
    esac
  fi
fi

# Locate Python entry script: try the known layouts first, then search the repo.
CANDIDATES=(
  "$REPO_ROOT/src/axbench_steering/eval_prompt_steering.py"
  "$REPO_ROOT/axbench_steering/eval_prompt_steering.py"
  "$REPO_ROOT/src/eval_prompt_steering.py"
)
PY_SCRIPT=""
for p in "${CANDIDATES[@]}"; do
  if [[ -f "$p" ]]; then PY_SCRIPT="$p"; break; fi
done
if [[ -z "$PY_SCRIPT" ]]; then
  # `-type f` already guarantees a regular file; `|| true` absorbs SIGPIPE
  # from find under pipefail once head has its line.
  PY_SCRIPT=$(find "$REPO_ROOT" -maxdepth 4 -type f -name "eval_prompt_steering.py" 2>/dev/null | head -n1 || true)
fi
if [[ -z "$PY_SCRIPT" ]]; then
  echo "[fatal] eval_prompt_steering.py not found under $REPO_ROOT" >&2
  echo "Tip: ensure it exists, e.g. $REPO_ROOT/src/axbench_steering/eval_prompt_steering.py" >&2
  exit 1
fi

# Ensure Python can import from repo/src. The existing PYTHONPATH is appended
# only when non-empty: a trailing empty entry would add the current working
# directory to sys.path.
export PYTHONPATH="$REPO_ROOT/src${PYTHONPATH:+:$PYTHONPATH}"

# ---------------- Data paths ----------------
# Concept directory (required). A pre-set CONCEPT_DIR env var is honored, the
# same way PROJECT_ROOT is; otherwise look inside the repo.
if [[ -n "${CONCEPT_DIR:-}" ]]; then
  if [[ ! -d "$CONCEPT_DIR" ]]; then
    echo "[fatal] CONCEPT_DIR=$CONCEPT_DIR is not a directory." >&2
    exit 1
  fi
elif [[ -d "$REPO_ROOT/concept" ]]; then
  CONCEPT_DIR="$REPO_ROOT/concept"
elif [[ -d "$REPO_ROOT/src/../concept" ]]; then
  # Sibling "concept" next to src. Only differs from the path above when src
  # is a symlink, because the kernel resolves ".." physically.
  CONCEPT_DIR="$REPO_ROOT/src/../concept"
else
  echo "[fatal] Concept directory not found (expected $REPO_ROOT/concept). Create it and put *_concept_descriptions.json inside." >&2
  exit 1
fi

# Instructions file. A pre-set INSTRUCTIONS_FILE env var is honored (and must
# point at an existing file); otherwise try a few common locations, then search.
if [[ -n "${INSTRUCTIONS_FILE:-}" ]]; then
  if [[ ! -f "$INSTRUCTIONS_FILE" ]]; then
    echo "[fatal] INSTRUCTIONS_FILE=$INSTRUCTIONS_FILE is not a file" >&2
    exit 1
  fi
else
  IF_CANDIDATES=(
    "$REPO_ROOT/axbench/axbench/data/alpaca_eval.json"
    "$REPO_ROOT/../axbench/axbench/data/alpaca_eval.json"
    "$REPO_ROOT/data/alpaca_eval.json"
  )
  INSTRUCTIONS_FILE=""
  for p in "${IF_CANDIDATES[@]}"; do
    if [[ -f "$p" ]]; then INSTRUCTIONS_FILE="$p"; break; fi
  done
  if [[ -z "$INSTRUCTIONS_FILE" ]]; then
    # `|| true` absorbs SIGPIPE from find under pipefail once head exits.
    INSTRUCTIONS_FILE=$(find "$REPO_ROOT" -maxdepth 5 -type f -name "alpaca_eval.json" 2>/dev/null | head -n1 || true)
  fi
fi
if [[ -z "$INSTRUCTIONS_FILE" ]]; then
  echo "[fatal] Instructions file alpaca_eval.json not found under $REPO_ROOT" >&2
  echo "Tip: set INSTRUCTIONS_FILE env var or place the file under axbench/axbench/data/" >&2
  exit 1
fi

# ---------------- Models ----------------
# Each `:=` keeps a caller-exported, non-empty value and otherwise assigns
# the default shown.
: "${BASE_MODEL:=google/gemma-2-2b-it}"
: "${PROMPT_GEN_BACKEND:=openai_async}"
: "${PROMPT_GEN_MODEL:=gpt-4o-mini}"
: "${JUDGE_BACKEND:=openai_async}"
: "${JUDGE_MODEL:=gpt-4o-mini}"

# ---------------- Eval knobs ----------------
# Forwarded verbatim to the eval script (--dev_k, --max_new_tokens,
# --temperature, --top_p, --seed).
: "${DEV_K:=5}"
: "${MAX_NEW_TOKENS:=128}"
: "${TEMPERATURE:=0.7}"
: "${TOP_P:=0.95}"
: "${SEED:=42}"

# ---------------- Collect concept files ----------------
# Use a glob rather than parsing `ls`: names with spaces/newlines survive, and a
# directory that matches the pattern is skipped instead of having its contents
# listed. Glob results are sorted, so every worker sees the same index order,
# which the round-robin split relies on.
ALL_CONCEPTS=()
shopt -s nullglob
for concept_path in "$CONCEPT_DIR"/*_concept_descriptions.json; do
  if [[ -f "$concept_path" ]]; then
    ALL_CONCEPTS+=("$concept_path")
  fi
done
shopt -u nullglob
TOTAL=${#ALL_CONCEPTS[@]}
if (( TOTAL == 0 )); then
  echo "[fatal] No concept files under $CONCEPT_DIR" >&2
  exit 1
fi

echo "[info] Worker slot=$SLOT gpu_id=$GPU_ID num_gpus=$NUM_GPUS ; total concept files=$TOTAL"
echo "[info] REPO_ROOT=$REPO_ROOT"
echo "[info] PY_SCRIPT=$PY_SCRIPT"
echo "[info] INSTRUCTIONS_FILE=$INSTRUCTIONS_FILE"
echo "[info] CONCEPT_DIR=$CONCEPT_DIR"

# Per-concept worker loop. A concept counts as finished when its OUT_DIR holds
# eval.json (presumably written by the Python script — confirm) or the .done
# marker this loop writes after a successful run. A mkdir-based .lock directory
# keeps two workers from running the same concept at once.

# Lock directory currently held by this worker ("" when none).
HELD_LOCK=""

# Remove the held lock, if any. Installed once as the EXIT trap so that an
# abort (e.g. set -e) does not leave a stale lock behind.
release_lock() {
  if [[ -n "$HELD_LOCK" ]]; then
    rmdir -- "$HELD_LOCK" 2>/dev/null || true
    HELD_LOCK=""
  fi
}
trap release_lock EXIT

# Succeed, logging which marker was seen, if OUT_DIR $1 is already finished.
already_finished() {
  local out_dir=$1 marker
  for marker in eval.json .done; do
    if [[ -f "$out_dir/$marker" ]]; then
      echo "[gpu $GPU_ID] skip: $out_dir/$marker exists"
      return 0
    fi
  done
  return 1
}

FAILED=0
for ((idx = 0; idx < TOTAL; idx++)); do
  # Round-robin split across GPUs: this worker owns idx with idx % NUM_GPUS == SLOT.
  if (( idx % NUM_GPUS != SLOT )); then
    continue
  fi

  CONCEPTS_FILE="${ALL_CONCEPTS[$idx]}"
  NAME="$(basename "$CONCEPTS_FILE" .json)"
  OUT_DIR="$REPO_ROOT/cache/results_prompt_eval_openai/${BASE_MODEL//\//_}/prompt_baseline/$NAME"

  echo "[gpu $GPU_ID] Processing ($((idx+1))/$TOTAL) -> $NAME"

  # Skip if already finished (eval.json or our own .done marker).
  if already_finished "$OUT_DIR"; then
    continue
  fi

  # Atomic claim to avoid races across workers: mkdir fails if .lock exists.
  mkdir -p "$OUT_DIR"
  if ! mkdir "$OUT_DIR/.lock" 2>/dev/null; then
    echo "[gpu $GPU_ID] skip: locked by another worker -> $OUT_DIR"
    continue
  fi
  HELD_LOCK="$OUT_DIR/.lock"

  # Re-check under the lock: another worker may have finished this concept and
  # released its lock between the check above and our mkdir.
  if already_finished "$OUT_DIR"; then
    release_lock
    continue
  fi

  # Run prompt-steering eval for this concept file. A failure is recorded and
  # the loop moves on, so one bad concept does not abandon the rest of this
  # worker's share; the script still exits non-zero at the end.
  rc=0
  python -u "$PY_SCRIPT" \
    --base_model "$BASE_MODEL" \
    --prompt_gen_backend "$PROMPT_GEN_BACKEND" \
    --prompt_gen_model "$PROMPT_GEN_MODEL" \
    --judge_backend "$JUDGE_BACKEND" \
    --judge_model "$JUDGE_MODEL" \
    --instructions_file "$INSTRUCTIONS_FILE" \
    --concepts_file "$CONCEPTS_FILE" \
    --dev_k "$DEV_K" \
    --max_new_tokens "$MAX_NEW_TOKENS" \
    --temperature "$TEMPERATURE" \
    --top_p "$TOP_P" \
    --seed "$SEED" \
    --sae_index "$((idx+1))" \
    --total_saes "$TOTAL" \
    --debug --sample_print_k 1 --print_chars 300 || rc=$?

  if (( rc == 0 )); then
    # Mark done only on success so failed concepts are retried on the next run.
    date +'%F %T' > "$OUT_DIR/.done"
  else
    echo "[gpu $GPU_ID] error: eval failed with exit code $rc -> $NAME" >&2
    FAILED=$((FAILED + 1))
  fi
  release_lock
done

if (( FAILED > 0 )); then
  echo "[gpu $GPU_ID] $FAILED concept file(s) failed; see errors above" >&2
  exit 1
fi
