#!/usr/bin/env bash
# Run output_score_with_entropy_confidence.py for every SAE, pairing each dl_local_dir with the matching features file.
# NAME = folder basename under each BASE (e.g., batch_topk_50, gated_1070, jumprelu_327, topk_80)
# For each NAME the script searches several candidate directories and filename patterns to locate the matching *_features.json file.

# Strict mode: abort on errors (also inside functions/subshells for ERR traps),
# on unset variables, and when any stage of a pipeline fails.
set -o errexit -o errtrace -o nounset -o pipefail
IFS=$'\n\t'

# -------- Config --------
# Every setting below may be overridden from the environment; the value after
# ":=" is applied only when the variable is unset or empty.
: "${DEVICE:=cuda:1}"
: "${MODEL_TYPE:=gemma2_9b}"

# Root directory searched for features files. The default is an absolute,
# machine-specific path; set FEATURES_DIR in the environment to use another one.
: "${FEATURES_DIR:=/home/dslabra5/sae4steer/saes-are-good-for-steering/data/features}"

: "${CONF_TOPK:=3}"                       # top-k for confidence selection
: "${AMP_FACTOR:=10}"                     # amplification factor
: "${NEUTRAL_SENT:=From my experience,}"  # passed as --neutral_sentence

# Glob behaviour: a pattern that matches nothing expands to an empty list
# instead of staying as the literal pattern text.
shopt -s nullglob

# Absolute path of the directory holding this script, and of its parent
# directory, which is treated as the repository root.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"

# Absolute base directories; each immediate sub-directory is one SAE.
BASES=()
BASES+=("/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/downloaded_saes/trained_saes___google_gemma-2-9b_batch_top_k_jump_relu_standard_new/resid_post_layer_20")
BASES+=("/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/downloaded_saes/trained_saes___google_gemma-2-9b_gated_top_k/resid_post_layer_20")

# Per-SAE log files are written here; create the directory up front.
LOG_DIR="$REPO_ROOT/logs/output_score_ec"
mkdir -p "$LOG_DIR"

# -------- Helpers --------
#######################################
# Locate the features file that belongs to one SAE.
# Globals:   FEATURES_DIR (read)
# Arguments: $1 - SAE folder name (e.g. topk_80)
# Outputs:   path of the first existing candidate file, on stdout
# Returns:   0 when a file was found, 1 otherwise (nothing is printed)
#######################################
find_features_file() {
  local sae_name="$1"

  # Directories to probe, in priority order; the features root itself is last.
  local search_dirs=(
    "$FEATURES_DIR/gemma2-9b-l20"
    "$FEATURES_DIR/gemma2_9b_l20"
    "$FEATURES_DIR/gemma-2-9b-l20"
    "$FEATURES_DIR"
  )

  # File names to probe inside each directory, most preferred first.
  local file_names=(
    "gemma2-9b_${sae_name}_features.json"
    "gemma2_9b_${sae_name}_features.json"
    "gemma_9b_${sae_name}_features.json"
    "${sae_name}_features.json"
    "gemma2-9b-l20_${sae_name}_features.json"
  )

  # Directory order takes precedence over file-name order: every name is tried
  # in the first existing directory before the next directory is considered.
  local dir file_name candidate
  for dir in "${search_dirs[@]}"; do
    if [[ -d "$dir" ]]; then
      for file_name in "${file_names[@]}"; do
        candidate="$dir/$file_name"
        [[ -f "$candidate" ]] || continue
        printf '%s\n' "$candidate"
        return 0
      done
    fi
  done

  return 1
}

# -------- Main --------
# For every SAE directory found directly under each BASES entry, look up its
# features file with find_features_file and run
# output_score_with_entropy_confidence.py on the pair. The combined
# stdout/stderr of each run is shown and also written to "$LOG_DIR/<NAME>.log".
#
# Diagnostics ([error]/[warn]/[skip]) are written to stderr. A skipped or
# failing SAE never stops the sweep; after all SAEs were visited a summary is
# printed and the script exits 1 if at least one run failed.

# Abort early when the features root is missing: no SAE could be matched.
if [[ ! -d "$FEATURES_DIR" ]]; then
  echo "[error] FEATURES_DIR not found: $FEATURES_DIR" >&2
  echo "        Adjust FEATURES_DIR (env var) or fix the default path in this script." >&2
  exit 1
fi

# Outcome counters for the final summary and exit status.
N_OK=0
N_FAILED=0
N_SKIPPED=0

for BASE in "${BASES[@]}"; do
  if [[ ! -d "$BASE" ]]; then
    echo "[warn] Base not found: $BASE" >&2
    continue
  fi

  for SAE_DIR in "$BASE"/*; do
    [[ -d "$SAE_DIR" ]] || continue

    NAME="$(basename "$SAE_DIR")"  # e.g., batch_topk_160 / gated_1070 / jumprelu_327 / topk_80

    # Testing the assignment inside `if` keeps errexit from aborting the
    # script when no features file exists for this SAE.
    if ! FEATURES_FILE="$(find_features_file "$NAME")"; then
      {
        echo "[skip] Missing features for $NAME"
        echo "       Searched under:"
        echo "         - $FEATURES_DIR/gemma2-9b-l20"
        echo "         - $FEATURES_DIR/gemma2_9b_l20"
        echo "         - $FEATURES_DIR/gemma-2-9b-l20"
        echo "         - $FEATURES_DIR"
        echo "       Tried filename patterns:"
        echo "         - gemma2-9b_${NAME}_features.json"
        echo "         - gemma2_9b_${NAME}_features.json"
        echo "         - gemma_9b_${NAME}_features.json"
        echo "         - ${NAME}_features.json"
        echo "         - gemma2-9b-l20_${NAME}_features.json"
      } >&2
      N_SKIPPED=$((N_SKIPPED + 1))
      continue
    fi

    echo "===================="
    echo "Scoring SAE+EC: $NAME"
    echo "  dl_local_dir  : $SAE_DIR"
    echo "  features_file : $FEATURES_FILE"
    echo "  device/model  : $DEVICE / $MODEL_TYPE"
    echo "  conf_topk/amp : $CONF_TOPK / $AMP_FACTOR"
    echo "  neutral_sent  : $NEUTRAL_SENT"
    echo "===================="

    # Run and tee output to a per-SAE log. The `||` branch keeps errexit and
    # pipefail from aborting the sweep; PIPESTATUS[0] is python's own exit
    # status (tee's status is deliberately ignored).
    STATUS=0
    python "$SCRIPT_DIR/output_score_with_entropy_confidence.py" \
      --device "$DEVICE" \
      --model_type "$MODEL_TYPE" \
      --features_file "$FEATURES_FILE" \
      --dl_local_dir "$SAE_DIR" \
      --confidence_top_k "$CONF_TOPK" \
      --amp_factor "$AMP_FACTOR" \
      --neutral_sentence "$NEUTRAL_SENT" \
      |& tee "$LOG_DIR/${NAME}.log" || STATUS=${PIPESTATUS[0]}

    if [[ $STATUS -ne 0 ]]; then
      echo "[error] Failed on $NAME (exit $STATUS)" >&2
      N_FAILED=$((N_FAILED + 1))
    else
      echo "[done] $NAME"
      N_OK=$((N_OK + 1))
    fi
    echo
  done
done

echo "[summary] ok=$N_OK failed=$N_FAILED skipped=$N_SKIPPED"

# Surface failures to the caller instead of always exiting 0.
if [[ $N_FAILED -ne 0 ]]; then
  exit 1
fi
