#!/bin/bash
# Run output_score.py for every SAE, pairing each dl_local_dir with the matching features file:
#   NAME = folder basename under each BASE (e.g., batch_topk_50, gated_78, jumprelu_330, topk_80)
#   --dl_local_dir  = <BASE>/<NAME>
#   --features_file = /home/dslabra5/sae4steer/saes-are-good-for-steering/data/features/gemma_2b_<NAME>_features.json
#
# Change DEVICE below if needed (or export DEVICE=cuda:0 before running).

# Runtime configuration. DEVICE keeps any non-empty value already in the
# environment and otherwise falls back to cuda:1.
: "${DEVICE:=cuda:1}"
MODEL_TYPE=gemma2_2b
FEATURES_DIR=/home/dslabra5/sae4steer/saes-are-good-for-steering/data/features

# SAE base directories to scan. Every sub-folder of each base is treated as
# one SAE. All bases live under the same downloaded_saes root and end in
# resid_post_layer_12; the temporary root variable is discarded afterwards.
_sae_root="/home/dslabra5/sae4steer/SAEBench/sae_bench/custom_saes/downloaded_saes"
BASES=(
  "$_sae_root/trained_saes__google_gemma-2-2b_gated_top_k/resid_post_layer_12"
  "$_sae_root/trained_saes__google_gemma-2-2b_jump_relu_batch_top_k_standard_new/resid_post_layer_12"
  "$_sae_root/trained_saes_2__google_gemma-2-2b_batch_top_k/resid_post_layer_12"
  "$_sae_root/trained_saes_2__google_gemma-2-2b_top_k_jump_relu/resid_post_layer_12"
)
unset _sae_root

# Per-SAE logs are always written here as <NAME>.log (see the tee below).
# If the directory cannot be created, every tee would fail and the logs would
# be lost silently, so stop early instead.
LOG_DIR="./logs/output_score"
if ! mkdir -p -- "$LOG_DIR"; then
  echo "[error] Cannot create log directory: $LOG_DIR" >&2
  exit 1
fi

# Unmatched globs expand to nothing, so an empty base directory produces no
# loop iterations instead of the literal pattern "$BASE/*".
shopt -s nullglob

#######################################
# Run output_score.py once for every SAE sub-directory of each base in BASES,
# pairing it with $FEATURES_DIR/gemma_2b_<NAME>_features.json.
# Globals:
#   BASES (read)              - base directories to scan
#   FEATURES_DIR (read)       - directory holding the per-SAE features files
#   LOG_DIR (read)            - each run's output is tee'd to $LOG_DIR/<NAME>.log
#   DEVICE, MODEL_TYPE (read) - forwarded to output_score.py
# Outputs:
#   Progress banners and output_score.py output on stdout; warnings, skips,
#   failures and the failure summary on stderr.
# Returns:
#   0 if every attempted run succeeded (or nothing was run), 1 otherwise.
#######################################
score_all_saes() {
  local base sae_dir name features_file status
  local -a failed=()

  for base in "${BASES[@]}"; do
    if [[ ! -d "$base" ]]; then
      echo "[warn] Base not found: $base" >&2
      continue
    fi

    for sae_dir in "$base"/*; do
      # Skips plain files and, without nullglob, the unexpanded pattern.
      [[ -d "$sae_dir" ]] || continue

      name="$(basename "$sae_dir")"  # e.g., batch_topk_50 / gated_78 / jumprelu_330 / topk_80
      features_file="$FEATURES_DIR/gemma_2b_${name}_features.json"

      if [[ ! -f "$features_file" ]]; then
        echo "[skip] Missing features for $name -> $features_file" >&2
        continue
      fi

      echo "===================="
      echo "Scoring SAE: $name"
      echo "  dl_local_dir  : $sae_dir"
      echo "  features_file : $features_file"
      echo "  device/model  : $DEVICE / $MODEL_TYPE"
      echo "===================="

      # Run and tee output to a per-SAE log. NOTE(review): the log name
      # depends only on $name, so an identically named folder under another
      # base overwrites the earlier log.
      python output_score.py \
        --device "$DEVICE" \
        --model_type "$MODEL_TYPE" \
        --features_file "$features_file" \
        --dl_local_dir "$sae_dir" \
        |& tee "$LOG_DIR/${name}.log"
      # PIPESTATUS[0] is python's exit code; $? would be tee's.
      status=${PIPESTATUS[0]}

      if (( status != 0 )); then
        echo "[error] Failed on $name (exit $status)" >&2
        failed+=("$name")
      else
        echo "[done] $name"
      fi
      echo
    done
  done

  if (( ${#failed[@]} > 0 )); then
    echo "[summary] ${#failed[@]} SAE(s) failed: ${failed[*]}" >&2
    return 1
  fi
  return 0
}

# Propagate failures to the caller instead of always exiting 0.
score_all_saes || exit 1
