#!/bin/bash
# Single-GPU worker with resume:
# - RESULT_DIR may already exist from previous runs; RESULT_DIR/.lock (created with an atomic mkdir) is the per-SAE mutex.
# - If RESULT_DIR/eval.json exists, we parse it to compute resume_start_idx, then create a temporary
#   truncated features file (remaining only) and continue evaluation.
# - After eval, we merge the new eval.json with any previous one to produce a unified eval.json.
# - Markers: .inprogress / .done / .failed are maintained for observability.

set -euo pipefail
shopt -s nullglob

# ---------------------------------------------------------------------------
# Model / dataset configuration
# ---------------------------------------------------------------------------
MODEL_TYPE="gemma2_9b_it"
LAYERS=20
INSTRUCTIONS_FILE="/home/dslabra5/sae4steer/axbench/axbench/data/alpaca_eval.json"

PROJECT_ROOT="/home/dslabra5/sae4steer/saes-are-good-for-steering"
FEATURES_DIR="${PROJECT_ROOT}/data/features/gemma2-9b-l20"
CONCEPTS_DIR="${PROJECT_ROOT}/concept/gemma2-9b/l20"

# Work from the directory holding this script: the eval entry point
# (eval_sae_steering.py) is invoked by a relative path further down.
WORKDIR="$(cd "$(dirname "$0")" && pwd)"
cd "$WORKDIR"

# ---------------------------------------------------------------------------
# Output locations
# ---------------------------------------------------------------------------
# Per-SAE save_dir for the python script (per-run logs are also kept there).
SAVE_ROOT="${WORKDIR}/runs_sae_steering_score/gemma2_9b_l20"
mkdir -p "$SAVE_ROOT"

# Per-SAE result dirs: eval outputs, markers (.lock/.inprogress/.done/.failed)
# and the eval.json used for resume detection.
RESULTS_ROOT="${PROJECT_ROOT}/cache/results_sae_eval_openai/${MODEL_TYPE}/layer${LAYERS}"
mkdir -p "$RESULTS_ROOT"

# ---------------------------------------------------------------------------
# SAE checkpoints: every sub-directory of each base is one SAE to evaluate.
# ---------------------------------------------------------------------------
SAEBENCH_ROOT="/home/dslabra5/sae4steer/SAEBench"
BASES=(
  "${SAEBENCH_ROOT}/sae_bench/custom_saes/downloaded_saes/trained_saes___google_gemma-2-9b_gated_top_k/resid_post_layer_20"
)

# ---------------------------------------------------------------------------
# Evaluation knobs (forwarded verbatim to eval_sae_steering.py)
# ---------------------------------------------------------------------------
STEERING_FACTORS="0.2,0.4,0.8,1.5,2.0,3.0"
DEV_K=5
MAX_NEW_TOKENS=128
JUDGE_BACKEND="openai_async"
JUDGE_MODEL="gpt-4o-mini"

printf '%s\n' "[info] Worker started on CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-ALL} ; RESULTS_ROOT=${RESULTS_ROOT}"

# ---------------------------------------------------------------------------
# Per-SAE state shared with the cleanup handlers. Each variable is cleared as
# soon as its resource is released, so the EXIT trap never touches a lock or
# file this worker no longer owns (e.g. a lock another worker took after us).
# ---------------------------------------------------------------------------
HELD_LOCK=""          # lock directory currently owned by this worker ("" = none)
TMP_FEATURES_FILE=""  # truncated features file of the current resume ("" = none)

# Remove the current truncated features file, if any. Idempotent.
cleanup_tmp_features() {
  if [[ -n "$TMP_FEATURES_FILE" ]]; then
    rm -f -- "$TMP_FEATURES_FILE"
    TMP_FEATURES_FILE=""
  fi
}

# Release the per-SAE lock, if this worker holds one. Idempotent.
release_lock() {
  if [[ -n "$HELD_LOCK" ]]; then
    rmdir -- "$HELD_LOCK" 2>/dev/null || true
    HELD_LOCK=""
  fi
}

# Runs on every exit path, including `set -e` aborts, so a crashed worker
# never leaves a stale .lock that would make every other worker skip the SAE.
on_exit() {
  cleanup_tmp_features
  release_lock
}
trap on_exit EXIT
# Conventional 128+signal exit statuses; the EXIT trap performs the cleanup.
trap 'exit 130' INT
trap 'exit 143' TERM

# Record a failure for the current SAE ($RESULT_DIR), then drop its temp file
# and lock. Arguments: $1 - value written to .failed, $2 - message.
fail_sae() {
  echo "[error] $2"
  echo "$1" > "$RESULT_DIR/.failed"
  rm -f -- "$RESULT_DIR/.inprogress"
  cleanup_tmp_features
  release_lock
}

# Print "<resume_start_idx> <total_features>" for features file $1, where
# resume_start_idx is one past the highest position (in the features list)
# of any feature id already present in eval.json $2. eval.json keys are
# expected to look like "<n>_<feature_id>". A missing or malformed eval.json
# yields 0. Returns non-zero if the features file cannot be parsed.
compute_resume_index() {
  python - "$1" "$2" <<'PY'
import json, sys, os, re
features_path, prev_eval_path = sys.argv[1], sys.argv[2]

# Load features list robustly
with open(features_path, 'r', encoding='utf-8') as f:
    obj = json.load(f)

def extract_list(o):
    if isinstance(o, list):
        return o
    if isinstance(o, dict):
        for k in ('features','feature_ids','feature_indices','ids'):
            if k in o and isinstance(o[k], list):
                return o[k]
        # fallback: first list-like value
        for v in o.values():
            if isinstance(v, list):
                return v
    raise SystemExit("ERROR: cannot locate features list in features_file")

feat_list = extract_list(obj)
# normalize to ints
try:
    feats = [int(str(x)) for x in feat_list]
except Exception:
    # if items are dicts with 'id'
    feats = []
    for x in feat_list:
        if isinstance(x, dict) and 'id' in x:
            feats.append(int(str(x['id'])))
        else:
            raise
id2pos = {fid: i for i, fid in enumerate(feats)}

resume_idx = 0
if os.path.exists(prev_eval_path):
    try:
        with open(prev_eval_path, 'r', encoding='utf-8') as f:
            prev = json.load(f)
        done_ids = set()
        for k in prev.keys():
            # keys like "12_12348" -> take the suffix as feature id
            if isinstance(k, str):
                m = re.match(r'^\d+_(\d+)$', k)
                if m:
                    done_ids.add(int(m.group(1)))
        # compute max completed position that also exists in features
        max_pos = -1
        for fid in done_ids:
            pos = id2pos.get(fid, None)
            if pos is not None and pos > max_pos:
                max_pos = pos
        resume_idx = max_pos + 1
    except Exception:
        # if eval.json malformed, just start from 0
        resume_idx = 0

print(resume_idx, len(feats))
PY
}

# Write a copy of features file $1 to $2 with its first $3 features dropped,
# keeping the surrounding JSON structure (bare list or dict holding a list).
write_truncated_features() {
  python - "$1" "$2" "$3" <<'PY'
import json, sys
src, dst, idx = sys.argv[1], sys.argv[2], int(sys.argv[3])
with open(src, 'r', encoding='utf-8') as f:
    obj = json.load(f)

def slice_in_place(o, start):
    if isinstance(o, list):
        return o[start:]
    if isinstance(o, dict):
        for k in ('features','feature_ids','feature_indices','ids'):
            if k in o and isinstance(o[k], list):
                o[k] = o[k][start:]
                return o
        # fallback: slice first list-like value
        for k, v in o.items():
            if isinstance(v, list):
                o[k] = v[start:]
                return o
    raise SystemExit("ERROR: cannot slice features structure")

new_obj = slice_in_place(obj, idx)
with open(dst, 'w', encoding='utf-8') as f:
    json.dump(new_obj, f, ensure_ascii=False, indent=2)
PY
}

# Fold the snapshot $2 (eval.json as it was before a run) into eval.json $1.
# Keys already in $1 win on duplicates, so new results overwrite old ones.
# $1 is replaced atomically, and the snapshot is deleted only after a
# successful merge. This is a no-op when the snapshot does not exist. Returns
# non-zero, and keeps the snapshot, when either file cannot be parsed or written.
merge_prev_eval() {
  local cur_path=$1 snapshot_path=$2
  [[ -e "$snapshot_path" ]] || return 0
  python - "$snapshot_path" "$cur_path" <<'PY' || return 1
import json, os, sys
prev_path, cur_path = sys.argv[1], sys.argv[2]

def load(path):
    if os.path.exists(path) and os.path.getsize(path) > 0:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {}

prev = load(prev_path)
cur = load(cur_path)
if not isinstance(prev, dict) or not isinstance(cur, dict):
    raise SystemExit("ERROR: eval.json files must contain JSON objects")
# merge: new results overwrite old keys if duplicated
merged = dict(prev)
merged.update(cur)
tmp = cur_path + ".tmp"
with open(tmp, 'w', encoding='utf-8') as f:
    json.dump(merged, f, ensure_ascii=False, indent=2)
os.replace(tmp, cur_path)  # atomic rename on the same filesystem
PY
  rm -f -- "$snapshot_path"
}

for BASE in "${BASES[@]}"; do
  if [[ ! -d "$BASE" ]]; then
    echo "[warn] Base not found: $BASE"
    continue
  fi

  for SAE_DIR in "$BASE"/*; do
    [[ -d "$SAE_DIR" ]] || continue
    NAME="$(basename "$SAE_DIR")"   # e.g., batch_topk_80 / standard_april_update_156 / ...

    FEATURES_FILE="$FEATURES_DIR/gemma_9b_${NAME}_features.json"
    CONCEPTS_FILE="$CONCEPTS_DIR/${NAME}_concept_descriptions.json"
    if [[ ! -f "$FEATURES_FILE" ]]; then
      echo "[skip] Missing features for $NAME -> $FEATURES_FILE"
      continue
    fi
    if [[ ! -f "$CONCEPTS_FILE" ]]; then
      echo "[skip] Missing concepts for $NAME -> $CONCEPTS_FILE"
      continue
    fi

    RESULT_DIR="$RESULTS_ROOT/$NAME"
    mkdir -p "$RESULT_DIR"

    # If already fully done, skip directly
    if [[ -f "$RESULT_DIR/.done" ]]; then
      echo "[skip] Done: $RESULT_DIR -> $NAME"
      continue
    fi

    # Atomic lock: mkdir fails if the directory already exists, so only one
    # worker can own a given SAE at a time.
    LOCK_DIR="$RESULT_DIR/.lock"
    if ! mkdir "$LOCK_DIR" 2>/dev/null; then
      echo "[skip] Locked by another worker -> $NAME"
      continue
    fi
    HELD_LOCK="$LOCK_DIR"   # from here on, the EXIT trap releases it

    EVAL_JSON="$RESULT_DIR/eval.json"
    # Snapshot of eval.json taken right before the eval run. The eval may
    # rewrite eval.json with only the features it just evaluated, so this
    # snapshot is what preserves earlier results until they are merged back.
    EVAL_SNAPSHOT="$RESULT_DIR/eval.prev.json"

    # A snapshot left behind by an interrupted run may hold results that are
    # missing from eval.json: fold it back in before computing the resume index.
    if ! merge_prev_eval "$EVAL_JSON" "$EVAL_SNAPSHOT"; then
      fail_sae 1 "Cannot merge leftover $EVAL_SNAPSHOT into $EVAL_JSON for $NAME (snapshot kept)"
      continue
    fi

    # Compute resume_start_idx from existing eval.json (if any), by mapping
    # feature ids to positions. Failures are handled here instead of letting
    # `set -e` kill the whole worker.
    RESUME_START_IDX=0
    TOTAL_FEATURES=0
    if ! RESUME_INFO="$(compute_resume_index "$FEATURES_FILE" "$EVAL_JSON")" \
      || ! read -r RESUME_START_IDX TOTAL_FEATURES <<< "$RESUME_INFO" \
      || [[ ! "$RESUME_START_IDX" =~ ^[0-9]+$ || ! "$TOTAL_FEATURES" =~ ^[0-9]+$ ]]; then
      fail_sae 1 "Cannot compute resume index for $NAME from $FEATURES_FILE"
      continue
    fi

    # Nothing left to do?
    if (( RESUME_START_IDX >= TOTAL_FEATURES )); then
      echo "[skip] Already complete: $NAME (resume_idx=$RESUME_START_IDX >= total=$TOTAL_FEATURES)"
      # ensure .done marker exists and no stale failure marker remains
      date +'%F %T' > "$RESULT_DIR/.done"
      rm -f -- "$RESULT_DIR/.failed"
      release_lock
      continue
    fi

    # Build a truncated features file if resuming from non-zero
    FEATURES_ARG="$FEATURES_FILE"
    if (( RESUME_START_IDX > 0 )); then
      if ! TMP_FEATURES_FILE="$(mktemp "$SAVE_ROOT/${NAME}_features_resume_XXXX.json")"; then
        TMP_FEATURES_FILE=""
        fail_sae 1 "Cannot create resume features file for $NAME"
        continue
      fi
      if ! write_truncated_features "$FEATURES_FILE" "$TMP_FEATURES_FILE" "$RESUME_START_IDX"; then
        fail_sae 1 "Cannot write resume features file for $NAME"
        continue
      fi
      FEATURES_ARG="$TMP_FEATURES_FILE"
      echo "[resume] $NAME -> resume_start_idx=$RESUME_START_IDX / total=$TOTAL_FEATURES"
    fi

    # Snapshot existing results (copy, then rename, so an interruption never
    # leaves a half-written snapshot behind).
    if [[ -s "$EVAL_JSON" ]]; then
      if ! { cp -- "$EVAL_JSON" "$EVAL_SNAPSHOT.tmp" && mv -- "$EVAL_SNAPSHOT.tmp" "$EVAL_SNAPSHOT"; }; then
        fail_sae 1 "Cannot snapshot $EVAL_JSON for $NAME"
        continue
      fi
    fi

    # Mark in progress
    date +'%F %T' > "$RESULT_DIR/.inprogress"

    SAVE_DIR="$SAVE_ROOT/$NAME"
    mkdir -p "$SAVE_DIR"
    LOGFILE="$SAVE_DIR/run_$(date +%Y%m%d_%H%M%S).log"

    echo "===================="
    echo "Running SAE: $NAME"
    echo "  dl_local_dir : $SAE_DIR"
    echo "  features_file: $FEATURES_ARG   # (may be truncated for resume)"
    echo "  concepts_file: $CONCEPTS_FILE"
    echo "  save_dir     : $SAVE_DIR"
    echo "  result_dir   : $RESULT_DIR"
    echo "===================="

    # Run with errexit off so a failing eval is recorded rather than fatal;
    # PIPESTATUS[0] is python's status (not tee's).
    set +e
    python -u eval_sae_steering.py \
      --model_type "$MODEL_TYPE" \
      --dl_local_dir "$SAE_DIR" \
      --features_file "$FEATURES_ARG" \
      --instructions_file "$INSTRUCTIONS_FILE" \
      --concepts_file "$CONCEPTS_FILE" \
      --judge_backend "$JUDGE_BACKEND" \
      --judge_model "$JUDGE_MODEL" \
      --layers "$LAYERS" \
      --steering_factors "$STEERING_FACTORS" \
      --dev_k "$DEV_K" \
      --max_new_tokens "$MAX_NEW_TOKENS" \
      --save_dir "$SAVE_DIR" \
      --debug --sample_print_k 1 --print_chars 300 \
      |& tee -a "$LOGFILE"
    STATUS=${PIPESTATUS[0]}
    set -e

    # Merge the pre-run snapshot back in (new results win on duplicate keys).
    # Done even when the eval failed, since it may have written partial results.
    MERGE_OK=1
    if ! merge_prev_eval "$EVAL_JSON" "$EVAL_SNAPSHOT"; then
      MERGE_OK=0
      echo "[error] Merge failed for $NAME; previous results kept in $EVAL_SNAPSHOT"
    fi

    # Cleanup/markers
    rm -f -- "$RESULT_DIR/.inprogress"
    cleanup_tmp_features

    if (( STATUS != 0 )); then
      echo "[error] Failed on $NAME (exit $STATUS)"
      echo "$STATUS" > "$RESULT_DIR/.failed"
    elif (( ! MERGE_OK )); then
      echo "[error] Failed on $NAME (merge)"
      echo 1 > "$RESULT_DIR/.failed"
    else
      echo "[done] $NAME"
      rm -f -- "$RESULT_DIR/.failed"   # clear a failure marker from an earlier attempt
      date +'%F %T' > "$RESULT_DIR/.done"
    fi

    release_lock
    echo
  done
done

echo "[info] Worker finished on CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-ALL}"
