from __future__ import annotations  


import os, sys
# Absolute path of this script and of the directory that contains it.
CURRENT_FILE = os.path.abspath(__file__)
CURRENT_DIR = os.path.dirname(CURRENT_FILE)
# Parent of this script's directory, with the ".." segment collapsed.
SRC_DIR = os.path.normpath(os.path.join(CURRENT_DIR, ".."))
# Put that parent directory at the front of the import search path (once).
# NOTE(review): presumably this is what lets the project-local imports below
# (utils, sae_utils, judge) resolve when the script is run directly — confirm.
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

import argparse
import json
import os
import random
import re
from typing import Dict, List, Tuple, Any

import glob
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import get_features_by_layers, get_sae
from sae_utils import AmlifySAEHook
from judge import (
    AsyncOpenAIJudge, OpenAIJudgeConfig,
    LocalHFJudge, LocalJudgeConfig,
    harmonic_mean_0_2,
)

# --------------------
# I/O helpers
# --------------------
def read_json(path: str):
    """
    Parse the UTF-8 encoded JSON file at `path` and return the decoded object.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload

def write_json(obj, path: str):
    """
    Serialize `obj` as pretty-printed UTF-8 JSON at `path`.

    Missing parent directories are created first. A bare file name (no
    directory component) is written into the current working directory;
    `os.makedirs("")` would raise, so directory creation is skipped then.
    Non-ASCII characters are written as-is (`ensure_ascii=False`).
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)

def parse_factors(arg: str) -> List[float]:
    """
    Turn a comma-separated string such as "0.2,0.4" into a list of floats.

    Pieces that are empty or whitespace-only are skipped; any other piece
    that is not a valid float makes `float()` raise ValueError.
    """
    factors: List[float] = []
    for piece in arg.split(","):
        if piece.strip():
            factors.append(float(piece))
    return factors

def split_dev_holdout(pool: List[str], k: int, seed: int) -> Tuple[List[str], List[str]]:
    """
    Shuffle a copy of `pool` with a private RNG seeded by `seed` and return
    two consecutive slices of it: the first `k` items (dev) and the next
    `k` items (holdout). `pool` itself is left untouched.
    """
    shuffled = pool[:]
    random.Random(seed).shuffle(shuffled)
    dev_part = shuffled[:k]
    holdout_part = shuffled[k: k + k]
    return dev_part, holdout_part

def load_instructions_file(path: str, min_needed: int) -> List[str]:
    """
    Read a JSON instruction pool from `path` and flatten it to unique strings.

    Each top-level entry may be:
      - a plain string;
      - a dict whose main text sits under one of 'instruction', 'prompt',
        'question', 'query', 'task', 'title' or 'text' (first truthy wins),
        optionally followed by extra text under 'input', 'context' or
        'extra', joined with a newline;
      - a dict with a 'messages' list, whose string 'content' fields are
        stripped and joined with newlines.
    Entries that produce no text are dropped. Duplicates are removed while
    keeping first-seen order.

    Raises ValueError when fewer than `min_needed` strings survive.
    """
    main_keys = ("instruction", "prompt", "question", "query", "task", "title", "text")
    extra_keys = ("input", "context", "extra")

    def first_truthy(entry: dict, keys) -> Any:
        # Same result as chaining `entry.get(k1) or entry.get(k2) or ...`.
        value = None
        for key in keys:
            value = entry.get(key)
            if value:
                break
        return value

    def normalize(entry: Any) -> str | None:
        if isinstance(entry, str):
            return entry.strip() or None
        if not isinstance(entry, dict):
            return None

        head = first_truthy(entry, main_keys)
        tail = first_truthy(entry, extra_keys)
        if isinstance(head, str) and head.strip():
            text = head.strip()
            if isinstance(tail, str) and tail.strip():
                text = text + "\n" + tail.strip()
            return text

        turns = entry.get("messages")
        if isinstance(turns, list):
            contents = [
                turn["content"].strip()
                for turn in turns
                if isinstance(turn, dict) and isinstance(turn.get("content"), str)
            ]
            joined = "\n".join(c for c in contents if c)
            return joined if joined.strip() else None
        return None

    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    # dict.fromkeys drops duplicates and keeps the first-seen order.
    candidates = (normalize(entry) for entry in raw)
    items: List[str] = list(dict.fromkeys(text for text in candidates if text))

    if len(items) < min_needed:
        raise ValueError(
            f"Instructions file '{path}' yielded only {len(items)} usable strings; "
            f"need at least {min_needed}. Ensure it has 'instruction'/'input' fields "
            f"or provide a plain list[str]."
        )
    return items

# =======================================================================
# ==== Diagnostics utilities ============================================
# =======================================================================
def truncate_text(s: str, max_chars: int) -> str:
    """
    Make `s` safe for one-line logging.

    None becomes "". Newlines are replaced by the two characters backslash-n,
    and if the escaped text is longer than `max_chars` it is cut there and
    suffixed with " ...<truncated>".
    """
    if s is None:
        return ""
    flat = s.replace("\n", "\\n")
    if len(flat) <= max_chars:
        return flat
    return flat[:max_chars] + " ...<truncated>"

def safe_float(x: Any) -> float:
    """
    Convert `x` with `float()`, falling back to 0.0 if the conversion raises.
    """
    try:
        value = float(x)
    except Exception:
        value = 0.0
    return value

def infer_sae_num_features(sae) -> int | None:
    """
    Best-effort guess at how many features an SAE object exposes.

    Lookup order:
      1. `sae.num_features`, then `sae.n_features` (first one that converts
         with `int()` wins; a failed conversion is ignored).
      2. The first dimension of a tensor `weight` found on `sae.encoder`,
         `sae.decoder` or `sae.ae`, in that order. This is a heuristic.
    Returns None when nothing usable is found.
    """
    for name in ("num_features", "n_features"):
        if not hasattr(sae, name):
            continue
        try:
            return int(getattr(sae, name))
        except Exception:
            continue

    # Fall back to the shape of a weight tensor on a common submodule.
    for submodule_name in ("encoder", "decoder", "ae"):
        submodule = getattr(sae, submodule_name, None)
        if submodule is None:
            continue
        weight = getattr(submodule, "weight", None)
        if not isinstance(weight, torch.Tensor):
            continue
        try:
            # Heuristic: leading dimension ~ number of codes/features.
            return int(weight.shape[0])
        except Exception:
            continue
    return None

class DiagnosticHook:
    """
    Call-counting wrapper around another forward hook.

    Every invocation bumps `calls` and then forwards the three hook
    arguments to the wrapped callable, returning whatever it returns.
    `layer`, `feature` and `amp_factor` are stored only as metadata.
    """
    def __init__(self, inner_hook, layer: int, feature: int, amp_factor: float):
        # Number of times the hook has been invoked so far.
        self.calls = 0
        # The wrapped hook that does the actual work.
        self.inner = inner_hook
        self.layer, self.feature, self.amp_factor = layer, feature, amp_factor

    def __call__(self, module, inputs, output):
        # Count first so that a failing inner hook is still recorded as a call.
        self.calls = self.calls + 1
        result = self.inner(module, inputs, output)
        return result

def score_with_guard(judge, concept_desc: str, inst: str, resp: str, debug: bool = False) -> Tuple[float, float, float, str | None]:
    """
    Run `judge.score(concept_desc, inst, resp)` without letting it raise.

    On success the three returned scores are passed through `safe_float`
    and returned together with None. If anything in that process raises,
    the result is (0.0, 0.0, 0.0, "<ExceptionType>: <message>"); with
    `debug` set, a warning line is printed as well.
    """
    try:
        concept, instruct, fluency = judge.score(concept_desc, inst, resp)
        scored = (safe_float(concept), safe_float(instruct), safe_float(fluency), None)
    except Exception as exc:
        failure = f"{type(exc).__name__}: {exc}"
        if debug:
            print(f"[WARN] judge.score failed: {failure}")
        return 0.0, 0.0, 0.0, failure
    return scored

def diagnose_zero_case(
    layer: int,
    feature: int,
    concept_desc: str,
    dev_details_per_factor: List[Dict[str, Any]],
    hold_items: List[Dict[str, Any]],
    hook_call_samples: List[int],
    judge_errors: int,
    sae_feature_upper_bound: int | None
):
    """
    Print a human-readable report explaining why scores came out zero/very low.

    Inputs:
      - dev_details_per_factor: one dict per factor with 'avg_overall' and
        'items' (each item has 'concept', 'instruct', 'fluency').
      - hold_items: holdout item dicts with the same component keys and an
        optional 'overall'.
      - hook_call_samples: hook call counts collected during generation.
      - judge_errors: number of failed judge calls.
      - sae_feature_upper_bound: SAE feature count, or None if unknown.

    The report lists the raw signals and ends with exactly one
    "[Conclusion]" line; the first matching cause in the chain below wins.
    Nothing is returned.
    """
    dev_overalls = [round(entry["avg_overall"], 6) for entry in dev_details_per_factor]
    hold_overalls = [round(item.get("overall", 0.0), 6) for item in hold_items]
    dev_flat = not any(v != 0.0 for v in dev_overalls)
    hold_flat = not any(v != 0.0 for v in hold_overalls)

    def component_total(key: str):
        # Per-factor sums first, then the holdout sum (same order of addition
        # for every component so the printed totals are comparable).
        dev_part = sum(sum(item[key] for item in entry["items"]) for entry in dev_details_per_factor)
        return dev_part + sum(item[key] for item in hold_items)

    n_items = sum(len(entry["items"]) for entry in dev_details_per_factor) + len(hold_items)
    concept_sum = component_total("concept")
    instruct_sum = component_total("instruct")
    fluency_sum = component_total("fluency")

    print("\n===== DIAGNOSIS START =====")
    print(f"[Diag] Layer={layer} Feature={feature}")
    desc_missing = not concept_desc.strip()
    print(f"[Diag] Concept desc empty? {'YES' if desc_missing else 'NO'}")

    bound_known = sae_feature_upper_bound is not None
    feature_invalid = bound_known and (feature < 0 or feature >= sae_feature_upper_bound)
    if bound_known:
        print(f"[Diag] Feature index out of SAE range? {'YES' if feature_invalid else 'NO'} (feature={feature}, SAE_features≈{sae_feature_upper_bound})")
    else:
        print("[Diag] SAE feature count unknown (could not infer).")

    print(f"[Diag] Dev avg_overall per factor: {dev_overalls}")
    print(f"[Diag] Hold overall items: {hold_overalls}")
    print(f"[Diag] Sum of component scores across all items: concept={concept_sum}, instruct={instruct_sum}, fluency={fluency_sum} (N={n_items})")
    print(f"[Diag] Hook call counts (samples): {hook_call_samples or '[no samples recorded]'}")
    print(f"[Diag] judge.score errors count: {judge_errors}")

    # Ordered from the most specific/likely root cause to the generic fallback.
    everything_zero = dev_flat and hold_flat
    if judge_errors > 0 and everything_zero:
        verdict = "[Conclusion] Likely JUDGE ISSUE (API/Model errors) — all overall scores are zero and judge raised exceptions."
    elif desc_missing:
        verdict = "[Conclusion] DATA ISSUE — empty concept description; scoring cannot reflect the intended concept."
    elif feature_invalid:
        verdict = "[Conclusion] FEATURE ISSUE — feature index out of SAE range; the steering target is invalid."
    elif hook_call_samples and max(hook_call_samples) == 0:
        verdict = "[Conclusion] HOOK NOT TRIGGERED — forward hook was not called; check layer index and model block wiring."
    elif concept_sum == 0 and instruct_sum == 0 and fluency_sum == 0:
        verdict = "[Conclusion] AMBIGUOUS — all component subscores are zero. Possible causes: (1) judge thresholds too strict; (2) generated text off-topic; (3) concept description too vague/misaligned."
    elif everything_zero:
        verdict = "[Conclusion] VERY LOW EFFECT — steering produced no measurable improvement across all factors."
    else:
        verdict = "[Conclusion] MIXED — some signals present but still very low; consider increasing factor range, inspecting outputs, or revising concept/instructions."
    print(verdict)
    print("===== DIAGNOSIS END =====\n")

# =======================================================================
# ==== Generation with SAE steering hook =================================
# =======================================================================
@torch.no_grad()
def generate_with_hook(
    model, tokenizer, sae, layer: int, feature: int, instruction: str,
    amp_factor: float, device: str, max_new_tokens: int = 128,
    temperature: float = 0.7, top_p: float = 0.95,
    debug: bool = False, print_chars: int = 300
) -> Dict[str, Any]:
    """
    Generate text for `instruction` while an SAE steering hook is attached.

    A forward hook (AmlifySAEHook wrapped in DiagnosticHook) is registered
    on `model.model.layers[layer]` for the duration of one `model.generate`
    call and is always removed afterwards, including when generation raises.

    Args:
        model, tokenizer: causal LM and its tokenizer; both are moved/used on `device`.
        sae: SAE object handed to AmlifySAEHook (moved to `device`).
        layer: index into `model.model.layers` where the hook is attached.
        feature: single SAE feature index passed to the hook.
        amp_factor: amplification factor passed to the hook.
        max_new_tokens, temperature, top_p: forwarded to `model.generate`;
            sampling is enabled only when `temperature > 0`.
        debug: when True, a truncated preview of the text is put in the diag.
        print_chars: maximum preview length.

    Returns:
        {'text': str, 'diag': {...}} where 'text' is the decoded first
        returned sequence and 'diag' carries layer/feature/amp_factor, the
        number of hook invocations and the optional preview.
        NOTE(review): the whole of `out_ids[0]` is decoded, which for
        decoder-only models presumably includes the prompt — confirm this is
        what the judge should see.

    Raises:
        AssertionError: if `instruction` is not a str.
        ValueError: if `layer` cannot be resolved on the model.
    """
    assert isinstance(instruction, str), f"instruction must be str, got {type(instruction)}"
    model = model.to(device)
    sae = sae.to(device)

    # Validate that the requested layer exists before building the hook.
    try:
        block = model.model.layers[layer]
    except Exception as e:
        raise ValueError(f"Invalid layer index {layer}: {e}") from e

    # Wrap the steering hook so that the number of invocations can be reported.
    inner = AmlifySAEHook(layer, sae, [feature], amp_factor, device)
    hook = DiagnosticHook(inner, layer=layer, feature=feature, amp_factor=amp_factor)
    handle = block.register_forward_hook(hook, always_call=True)

    # The hook lives on a shared model: it must come off even if tokenization
    # or generation fails, otherwise later generations stay steered.
    try:
        inputs = tokenizer(instruction, return_tensors="pt").to(device)
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        out_ids = model.generate(
            **inputs,
            do_sample=True if temperature > 0 else False,
            temperature=temperature if temperature > 0 else None,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    finally:
        # Removes only the hook registered above; hooks installed by other
        # code on this block are left alone.
        handle.remove()

    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
    diag = {
        "layer": layer, "feature": feature, "amp_factor": amp_factor,
        "hook_calls": hook.calls,
        # Short preview for logs, only collected in debug mode.
        "preview": truncate_text(text, print_chars) if debug else None
    }
    return {"text": text, "diag": diag}

# =======================================================================
# ==== Baseline generation (no hook) =====================================
# =======================================================================
@torch.no_grad()
def generate_baseline(
    model, tokenizer, instruction: str, device: str,
    max_new_tokens: int = 128, temperature: float = 0.7, top_p: float = 0.95
) -> str:
    """
    Generate text for `instruction` with the plain model (no steering hook).

    The model and the tokenized prompt are moved to `device`; if the
    tokenizer has no pad token, its EOS token is used. Sampling is enabled
    only when `temperature > 0` (otherwise `temperature=None` is passed).
    Returns the decoded first returned sequence with special tokens skipped.
    """
    model = model.to(device)
    encoded = tokenizer(instruction, return_tensors="pt").to(device)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    sampling = bool(temperature > 0)
    generation_kwargs = {
        "do_sample": sampling,
        "temperature": temperature if sampling else None,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    sequences = model.generate(**encoded, **generation_kwargs)
    return tokenizer.decode(sequences[0], skip_special_tokens=True)

# =======================================================================
# ==== Helper: auto save path under cache/results_sae_eval_openai ========
# =======================================================================
def _detect_repo_root_from_features(features_file_path: str) -> str:
    """
    Derive the repository root from the location of the features file.

    The normalized path is cut right after its first component named
    'saes-are-good-for-steering'. If no such component exists, the current
    working directory is returned instead.
    """
    marker = "saes-are-good-for-steering"
    segments = os.path.normpath(features_file_path).split(os.sep)
    try:
        cut = segments.index(marker)
    except ValueError:
        # Marker directory is not part of the path: fall back to the CWD.
        return os.getcwd()
    return os.sep.join(segments[: cut + 1])

def _method_tag_from_trainer_class(trainer_class_name: str) -> str:
    """
    Map a trainer class name to a short method tag used in output paths.

    Known families are detected by case-insensitive substring checks, in
    priority order: gated, batch_topk, topk, jump_relu,
    standard_april_update. Any other name is lower-cased, has every run of
    characters outside [a-z0-9_] replaced by '_', is stripped of
    leading/trailing underscores and cut to 32 characters. An empty result
    (including a None/blank input) yields "unknown".
    """
    lowered = (trainer_class_name or "").lower().strip()

    # (tag, matched) pairs; the first matched entry wins.
    rules = (
        ("gated", "gated" in lowered),
        ("batch_topk", "batchtopk" in lowered or "batch_topk" in lowered),
        ("topk", "topk" in lowered and "batch" not in lowered),
        ("jump_relu", "jumprelu" in lowered or ("jump" in lowered and "relu" in lowered)),
        ("standard_april_update", all(word in lowered for word in ("standard", "april", "update"))),
    )
    for tag, matched in rules:
        if matched:
            return tag

    slug = re.sub(r"[^a-z0-9_]+", "_", lowered).strip("_")
    if not slug:
        return "unknown"
    return slug[:32]

def _trainer_id_from_path(path: str) -> str:
    """
    Extract a numeric trainer id from an SAE directory path.

    Two case-insensitive patterns are tried in order: a known method/trainer
    directory name followed by digits, then any '_<digits>' or '-<digits>'
    at the end of a path component. For the first pattern that matches at
    all, the digits of its last match are returned; "unk" if neither matches.
    """
    patterns = (
        r"(?:^|/)(?:trainer|gated|batch[_-]?topk|top[_-]?k|jump[_-]?relu|jumprelu|standard[_-]?april[_-]?update)[_/ -]?(\d+)(?:/|$)",
        r"[_-](\d+)(?:/|$)",
    )
    for pattern in patterns:
        # With a single capture group, findall yields the captured digits.
        digit_runs = re.findall(pattern, path, flags=re.IGNORECASE)
        if digit_runs:
            return digit_runs[-1]
    return "unk"

def _build_experiment_name(trainer_class: str, trainer_id: str) -> str:
    """
    Build the experiment directory name "<method>_<id>".

    <method> comes from `_method_tag_from_trainer_class(trainer_class)`.
    <id> is the trailing run of digits of the lower-cased, stripped
    `trainer_id`, or that whole normalized string when it does not end in digits.
    """
    method = _method_tag_from_trainer_class(trainer_class)
    normalized_id = (trainer_id or "").lower().strip()
    trailing_digits = re.search(r"(\d+)$", normalized_id)
    if trailing_digits is None:
        suffix = normalized_id
    else:
        suffix = trailing_digits.group(1)
    return f"{method}_{suffix}"

def _build_eval_out_path(repo_root: str, model_type: str, layer: int, trainer_class: str, trainer_id: str) -> str:
    """
    Return (and prepare) the result file path for one layer:

        <repo_root>/cache/results_sae_eval_openai/<model_dir>/layer<layer>/<exp_name>/eval.json

    <model_dir> is the part of `model_type` after its last '/', so a repo id
    such as "Qwen/..." does not add an extra directory level. The containing
    directory is created if it does not exist yet.
    """
    layer_dir = f"layer{int(layer)}"
    experiment = _build_experiment_name(trainer_class, trainer_id)
    model_dir = model_type.rsplit("/", 1)[-1]
    target_dir = os.path.join(
        repo_root, "cache", "results_sae_eval_openai", model_dir, layer_dir, experiment
    )
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, "eval.json")


# --------------------
# Local snapshot resolver for offline loading
# --------------------
def _resolve_local_snapshot_if_available(repo_or_path: str, cache_root: str, offline: bool) -> str:
    """
    Swap a hub repo id for its most recent local snapshot when running offline.

    `repo_or_path` is returned unchanged when it is already an existing
    directory or when `offline` is False. Otherwise the directory
    <cache_root>/hub/models--{repo with '/' -> '--'}/snapshots is searched
    and the sub-directory with the newest modification time is returned.
    If the snapshot root is missing or holds no directories, a warning is
    printed and `repo_or_path` is returned unchanged.
    """
    if os.path.isdir(repo_or_path) or not offline:
        return repo_or_path

    repo_dir = os.path.join(cache_root, "hub", f"models--{repo_or_path.replace('/', '--')}")
    snap_root = os.path.join(repo_dir, "snapshots")
    if not os.path.isdir(snap_root):
        print(f"[WARN] Snapshot root not found: {snap_root}")
        return repo_or_path

    snapshots = [entry for entry in glob.glob(os.path.join(snap_root, "*")) if os.path.isdir(entry)]
    if not snapshots:
        print(f"[WARN] No snapshots found under: {snap_root}")
        return repo_or_path

    newest = max(snapshots, key=os.path.getmtime)
    print(f"[INFO] Using local snapshot for {repo_or_path}: {newest}")
    return newest

# --------------------
# Main
# --------------------
def main():
    """
    CLI entry point: search steering factors per SAE feature and score holdout.

    Flow, per layer found in the features file (optionally filtered by
    --layers):
      1. Load the layer's SAE and derive the result file path.
      2. Resume from an existing result file; already evaluated features
         are skipped.
      3. For every remaining feature: draw a feature-specific dev/holdout
         split of the instruction pool, generate steered text for each
         steering factor on the dev split, score it with the judge and keep
         the factor with the best average overall score; then evaluate that
         factor on the holdout split.
      4. Write the results to disk after every feature (checkpoint).
      5. If both dev and holdout scores are effectively zero, print a
         diagnosis for that feature.

    Raises:
        ValueError: for an unsupported model type or when a feature has no
            concept description in --concepts_file.
    Exits via argparse when --steering_factors contains no number.
    """
    p = argparse.ArgumentParser()
    # Tested model
    p.add_argument("--model_type", type=str, default="gemma2_9b",
                   choices=[
                       "gemma2_2b","gemma2_2b_it", "gemma2_9b", "gemma2_9b_it",
                       "Qwen/Qwen2.5-3B", "Qwen/Qwen2.5-3B-Instruct"
                   ])
    # SAE
    p.add_argument("--dl_local_dir", type=str, required=True,
                   help="One DL-SAE folder (contains config.json + ae.pt). The script auto-swaps the layer id in path.")
    p.add_argument("--features_file", type=str, required=True,
                   help="JSON: {layer: [feature_ids...]}")
    # Instructions and concept descriptions
    p.add_argument("--instructions_file", type=str, required=True,
                   help="Instruction pool (Alpaca-Eval style). Accepts list[str] or list[dict] with 'instruction'/chat keys.")
    p.add_argument("--concepts_file", type=str, required=True,
                   help="{'<layer>_<feature>': 'concept description'} used by Concept score.")
    # Judges
    p.add_argument("--judge_backend", type=str, default="openai_async",
                   choices=["openai_async", "hf_local"])
    p.add_argument("--judge_model", type=str, default="gpt-4o-mini",
                   help="When openai_async: OpenAI model; when hf_local: HF model name.")
    # Eval settings
    p.add_argument("--layers", type=int, nargs="*", default=None,
                   help="If set, only evaluate these layers; else use all keys from features_file.")
    p.add_argument("--steering_factors", type=str,
                   default="0.2,0.4,0.8,1.5,2.0,3.0")
    p.add_argument("--dev_k", type=int, default=5, help="Dev split size (holdout has the same size).")
    p.add_argument("--max_new_tokens", type=int, default=128)
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--seed", type=int, default=123)
    # Output: accepted for backward compatibility but not used (results are
    # auto-saved under cache/results_sae_eval_openai), hence no longer required.
    p.add_argument("--save_dir", type=str, default=None,
                   help="(Ignored) Output root; the script now auto-saves under cache/results_sae_eval_openai.")

    # Debug/diagnostics flags
    p.add_argument("--debug", action="store_true",
                   help="Print detailed diagnostics and baseline vs steered outputs (truncated).")
    p.add_argument("--sample_print_k", type=int, default=1,
                   help="Per factor, how many instructions to print with baseline vs steered comparison.")
    p.add_argument("--print_chars", type=int, default=300,
                   help="Max characters to print for each generated sample.")
    p.add_argument("--top_p", type=float, default=0.95,
                   help="Top-p for sampling; exposed for completeness.")

    args = p.parse_args()

    # Fail fast: with no factor the dev search would select nothing and the
    # run would only crash after the model has been loaded.
    factors = parse_factors(args.steering_factors)
    if not factors:
        p.error("--steering_factors must contain at least one number")

    # Model mapping (CLI alias -> hub repo id)
    model_names = {
        "gemma2_2b": "google/gemma-2-2b",
        "gemma2_2b_it": "google/gemma-2-2b-it",
        "gemma2_9b": "google/gemma-2-9b",
        "gemma2_9b_it": "google/gemma-2-9b-it",
        "Qwen/Qwen2.5-3B": "Qwen/Qwen2.5-3B",
        "Qwen/Qwen2.5-3B-Instruct": "Qwen/Qwen2.5-3B-Instruct",
    }
    model_name = model_names.get(args.model_type)
    if model_name is None:
        raise ValueError(f"Unsupported model_type {args.model_type}")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    random.seed(args.seed)
    # Sampling in generate() draws from torch's RNG, so seed it as well.
    torch.manual_seed(args.seed)

    # Features / layers
    features_by_layers: Dict[int, List[int]] = get_features_by_layers(args.features_file)
    if args.layers:
        allow = set(args.layers)
        features_by_layers = {k: v for k, v in features_by_layers.items() if k in allow}

    if not features_by_layers:
        if args.layers:
            print("[ERROR] No concepts to evaluate: --layers has no overlap with features_file keys.")
        else:
            print(f"[ERROR] No concepts to evaluate: no layers found in {args.features_file}.")
        return

    # Instructions (normalize to List[str])
    instructions_pool: List[str] = load_instructions_file(
        args.instructions_file,
        min_needed=2 * args.dev_k
    )
    print(f"[INFO] Loaded {len(instructions_pool)} normalized instructions from {args.instructions_file}")

    # Concept descriptions
    concepts_map: Dict[str, str] = read_json(args.concepts_file)

    # ===== Model/Tokenizer loading (with offline local snapshot support) =====
    print(f"[INFO] Loading base model and tokenizer: {model_name} on {device}")

    # Resolve cache root & offline flag
    cache_dir = (
        os.environ.get("HF_HOME")
        or os.environ.get("HF_CACHE_DIR")
        or os.path.expanduser("~/.cache/huggingface")
    )
    offline = (os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1") or (os.environ.get("HF_LOCAL_ONLY", "0") == "1")

    # For Qwen models we typically need trust_remote_code=True (local code in snapshot)
    trust_remote = ("qwen" in model_name.lower())

    # If offline and a local snapshot exists, replace model_name with snapshot path
    model_name_resolved = _resolve_local_snapshot_if_available(model_name, cache_dir, offline)

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_resolved,
        trust_remote_code=trust_remote,
        cache_dir=cache_dir,
        local_files_only=offline,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_resolved,
        trust_remote_code=trust_remote,
        cache_dir=cache_dir,
        local_files_only=offline,
    ).to(device)

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Judge
    if args.judge_backend == "openai_async":
        judge = AsyncOpenAIJudge(OpenAIJudgeConfig(model=args.judge_model))
    else:
        judge = LocalHFJudge(LocalJudgeConfig(model_name=args.judge_model, device=device))
    print(f"[INFO] Judge backend: {args.judge_backend} | model: {args.judge_model}")

    # SAE cache
    saes = {}

    print(f"[INFO] Will search over steering factors: {factors}")

    # Repo root (for auto save path)
    repo_root = _detect_repo_root_from_features(args.features_file)

    # Per-layer evaluation
    for layer, feat_list in features_by_layers.items():
        sae = get_sae(
            model_type=args.model_type,
            layer=layer,
            saes=saes,
            backend="dl_local",
            dl_local_dir=args.dl_local_dir,
            device=device,
        )

        # Build save path for THIS layer using SAE meta
        trainer_class = getattr(sae, "trainer_class_name", "") or ""
        trainer_id = _trainer_id_from_path(args.dl_local_dir)
        out_path = _build_eval_out_path(
            repo_root=repo_root,
            model_type=args.model_type,
            layer=layer,
            trainer_class=trainer_class,
            trainer_id=trainer_id,
        )

        # SAE feature count sanity
        sae_feat_count = infer_sae_num_features(sae)
        if sae_feat_count is not None:
            print(f"[INFO] SAE inferred feature count (approx): {sae_feat_count}")
        else:
            print(f"[INFO] SAE feature count could not be inferred (proceeding).")

        # === [RESUME SUPPORT] load existing layer_result if any ===
        layer_result: Dict[str, Any] = {}
        if os.path.exists(out_path):
            try:
                loaded = read_json(out_path)
            except Exception as e:
                print(f"[RESUME] Could not read existing {out_path}: {type(e).__name__}: {e}")
            else:
                if isinstance(loaded, dict):
                    layer_result = loaded
                    if layer_result:
                        print(f"[RESUME] Loaded {len(layer_result)} items from {out_path}")
                else:
                    # Valid JSON but not an object: it cannot be keyed by
                    # "<layer>_<feature>", so start over for this layer.
                    print(f"[RESUME] Ignoring {out_path}: expected a JSON object, got {type(loaded).__name__}")

        print(f"\n[Eval] layer={layer} | features={len(feat_list)} | save -> {out_path}")

        for feat in tqdm([int(x) for x in feat_list], desc=f"Layer {layer}"):
            lf_key = f"{layer}_{feat}"

            # === [RESUME SUPPORT] skip finished feature ===
            if lf_key in layer_result:
                print(f"[SKIP] already evaluated {lf_key}")
                continue

            concept_desc = concepts_map.get(lf_key, "").strip()
            if not concept_desc:
                raise ValueError(f"Missing concept description for '{lf_key}' in {args.concepts_file}")

            # SAE feature range sanity check
            if sae_feat_count is not None and (feat < 0 or feat >= sae_feat_count):
                print(f"[WARN] Feature {feat} seems out of SAE feature range (≈{sae_feat_count}). Steering may have no effect.")

            # For each concept, sample dev/holdout
            rnd_seed = (args.seed * 1000003 + feat) & 0x7fffffff
            dev_insts, hold_insts = split_dev_holdout(instructions_pool, k=args.dev_k, seed=rnd_seed)

            # Dev: search best factor
            best_factor: float | None = None
            best_dev_avg = -1.0
            dev_details_per_factor = []

            # capture some hook call counts for diagnosis
            diag_hook_calls_samples: List[int] = []

            # Judge failures for THIS feature only, so that the diagnosis
            # below is not skewed by errors seen on earlier features.
            feat_judge_errors = 0

            for fac in factors:
                dev_items = []
                print(f"\n>>> [DEV] L{layer} F{feat} | factor={fac} | {len(dev_insts)} instructions")
                for idx, inst in enumerate(dev_insts, 1):
                    # Generate with hook
                    resp_obj = generate_with_hook(
                        model=model, tokenizer=tokenizer, sae=sae,
                        layer=layer, feature=feat, instruction=inst,
                        amp_factor=fac, device=device,
                        max_new_tokens=args.max_new_tokens,
                        temperature=args.temperature, top_p=args.top_p,
                        debug=args.debug, print_chars=args.print_chars
                    )
                    resp = resp_obj["text"]
                    hook_calls = resp_obj["diag"]["hook_calls"]
                    diag_hook_calls_samples.append(hook_calls)

                    # Score
                    c, i, f, err = score_with_guard(judge, concept_desc, inst, resp, debug=args.debug)
                    if err:
                        feat_judge_errors += 1
                    overall = harmonic_mean_0_2(c, i, f)
                    dev_items.append({"concept": float(c), "instruct": float(i), "fluency": float(f), "overall": float(overall)})

                    # optional baseline vs steered print
                    if args.debug and idx <= args.sample_print_k:
                        base = generate_baseline(
                            model=model, tokenizer=tokenizer, instruction=inst,
                            device=device, max_new_tokens=args.max_new_tokens,
                            temperature=args.temperature, top_p=args.top_p
                        )
                        print(f"[Sample DEV #{idx}] factor={fac} hook_calls={hook_calls}")
                        print(f"  Instruction: {truncate_text(inst, args.print_chars)}")
                        print(f"  Baseline  : {truncate_text(base, args.print_chars)}")
                        print(f"  Steered   : {truncate_text(resp, args.print_chars)}")
                        print(f"  Scores (c,i,f)->overall: ({c:.3f},{i:.3f},{f:.3f})->{overall:.3f}")

                avg_overall = sum(x["overall"] for x in dev_items) / len(dev_items) if dev_items else 0.0
                dev_details_per_factor.append({"factor": fac, "items": dev_items, "avg_overall": avg_overall})
                if avg_overall > best_dev_avg:
                    best_dev_avg = avg_overall
                    best_factor = fac

                print(f"<<< [DEV SUMMARY] L{layer} F{feat} factor={fac} avg_overall={avg_overall:.3f} (best_so_far={best_dev_avg:.3f} @ {best_factor})")

            # Holdout with best factor
            hold_items = []
            print(f"\n>>> [HOLDOUT] L{layer} F{feat} | best_factor={best_factor} | {len(hold_insts)} instructions")
            for idx, inst in enumerate(hold_insts, 1):
                resp_obj = generate_with_hook(
                    model=model, tokenizer=tokenizer, sae=sae,
                    layer=layer, feature=feat, instruction=inst,
                    amp_factor=best_factor, device=device,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature, top_p=args.top_p,
                    debug=args.debug, print_chars=args.print_chars
                )
                resp = resp_obj["text"]
                hook_calls = resp_obj["diag"]["hook_calls"]
                c, i, f, err = score_with_guard(judge, concept_desc, inst, resp, debug=args.debug)
                if err:
                    feat_judge_errors += 1
                overall = harmonic_mean_0_2(c, i, f)
                hold_items.append({"concept": float(c), "instruct": float(i), "fluency": float(f), "overall": float(overall)})

                if args.debug and idx <= args.sample_print_k:
                    base = generate_baseline(
                        model=model, tokenizer=tokenizer, instruction=inst,
                        device=device, max_new_tokens=args.max_new_tokens,
                        temperature=args.temperature, top_p=args.top_p
                    )
                    print(f"[Sample HOLD #{idx}] best_factor={best_factor} hook_calls={hook_calls}")
                    print(f"  Instruction: {truncate_text(inst, args.print_chars)}")
                    print(f"  Baseline  : {truncate_text(base, args.print_chars)}")
                    print(f"  Steered   : {truncate_text(resp, args.print_chars)}")
                    print(f"  Scores (c,i,f)->overall: ({c:.3f},{i:.3f},{f:.3f})->{overall:.3f}")

            hold_avg = {
                "concept": sum(x["concept"] for x in hold_items) / len(hold_items) if hold_items else 0.0,
                "instruct": sum(x["instruct"] for x in hold_items) / len(hold_items) if hold_items else 0.0,
                "fluency": sum(x["fluency"] for x in hold_items) / len(hold_items) if hold_items else 0.0,
                "overall": sum(x["overall"] for x in hold_items) / len(hold_items) if hold_items else 0.0,
            }

            # ====== Keep minimal fields for every feature =====
            layer_result[lf_key] = {
                "layer": layer,
                "feature": feat,
                "best_factor": float(best_factor),
                "holdout": {
                    "mean": hold_avg
                }
            }

            # === [CHECKPOINT] incremental save per feature ===
            write_json(layer_result, out_path)
            print(f"[CHECKPOINT] wrote {lf_key} -> {out_path}")

            print(f"[L{layer} F{feat}] best_factor={best_factor:.4g}  dev_overall={best_dev_avg:.3f}  hold_overall={hold_avg['overall']:.3f}")
            torch.cuda.empty_cache()

            # Automatic diagnosis if scores are zero-ish
            if abs(best_dev_avg) < 1e-9 and abs(hold_avg["overall"]) < 1e-9:
                diagnose_zero_case(
                    layer=layer, feature=feat, concept_desc=concept_desc,
                    dev_details_per_factor=dev_details_per_factor,
                    hold_items=hold_items,
                    hook_call_samples=diag_hook_calls_samples,
                    judge_errors=feat_judge_errors,
                    sae_feature_upper_bound=sae_feat_count
                )

        # Final sync write for this layer (contains all evaluated features so far)
        write_json(layer_result, out_path)
        print(f"[INFO] Saved layer results -> {out_path}")

    print("\nDone.")

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
