from __future__ import annotations
import os, json, re, warnings, math, inspect
from functools import lru_cache
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass

import numpy as np
import torch
from tqdm.auto import tqdm

# ============================================================================
# Distributed utilities
# ============================================================================
def dist_is_enabled() -> bool:
    """Report whether a torch.distributed process group is usable right now.

    Returns:
        True only when torch.distributed is available and a process group has
        been initialized.  Any exception raised while importing or querying
        torch.distributed is treated as "not enabled".
    """
    try:
        import torch.distributed as td
        if not td.is_available():
            return False
        return td.is_initialized()
    except Exception:
        return False

def dist_init(backend: str = "nccl"):
    """Initialize the default torch.distributed process group if needed.

    When CUDA is available, the current CUDA device is first set to the index
    found in the ``LOCAL_RANK`` environment variable (default ``0``).  Nothing
    else happens if torch.distributed is unavailable or a process group is
    already initialized.

    Args:
        backend: Preferred process-group backend.  If it fails to initialize,
            ``gloo`` is tried instead and a warning records the original
            failure so it is not lost.

    Raises:
        Exception: Whatever ``init_process_group`` raises for the ``gloo``
            backend (requested directly or reached as the fallback).
    """
    if torch.cuda.is_available():
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        torch.cuda.set_device(local_rank)

    import torch.distributed as dist
    if not dist.is_available() or dist.is_initialized():
        return

    if backend == "gloo":
        # Retrying the very same backend cannot help; let the real error surface.
        dist.init_process_group(backend="gloo")
        return

    try:
        dist.init_process_group(backend=backend)
    except Exception as err:
        # Best-effort fallback, but keep the original failure visible.
        warnings.warn(
            f"[dist] backend '{backend}' failed to initialize ({err!r}); falling back to 'gloo'"
        )
        dist.init_process_group(backend="gloo")

def dist_rank() -> int:
    """Return this process's distributed rank, or 0 when not distributed."""
    if dist_is_enabled():
        import torch.distributed as td
        return td.get_rank()
    return 0

def dist_world() -> int:
    """Return the distributed world size, or 1 when not distributed."""
    if dist_is_enabled():
        import torch.distributed as td
        return td.get_world_size()
    return 1

def only_rank0() -> bool:
    """Tell whether the calling process has rank 0 (as reported by ``dist_rank``)."""
    current_rank = dist_rank()
    return current_rank == 0

def shard_list(xs, rank: int, world: int):
    """Return the contiguous slice of ``xs`` owned by ``rank``.

    The sequence is cut into chunks of ``ceil(len(xs) / world)`` items, so the
    trailing ranks may receive a shorter or even empty slice.  With a single
    worker (``world <= 1``) or an empty input, ``xs`` itself is returned
    rather than a copy.
    """
    total = len(xs)
    if total == 0 or world <= 1:
        return xs
    chunk = math.ceil(total / world)
    lo = chunk * rank
    hi = min(total, lo + chunk)
    return xs[lo:hi]

# ============================================================================
# Determinism
# ============================================================================
def set_global_determinism(seed: int, strict: bool = False):
    """Seed the Python, NumPy and torch RNGs and toggle deterministic kernels.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and, when CUDA is available,
            ``torch.cuda.manual_seed_all``.
        strict: When True, set ``CUBLAS_WORKSPACE_CONFIG`` (only if it is not
            already set), enable deterministic algorithms in warn-only mode
            and configure cuDNN for deterministic, non-benchmark operation.
            When False, deterministic algorithms are explicitly disabled.
    """
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if not strict:
        torch.use_deterministic_algorithms(False)
        return

    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")
    torch.use_deterministic_algorithms(True, warn_only=True)
    try:
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = False
        cudnn.deterministic = True
    except Exception:
        # cuDNN settings are best-effort; ignore builds where they fail.
        pass

# ============================================================================
# Preprocessing (scaler + PCA)
# ============================================================================
def load_preproc(model_npz_path: str):
    """Load StandardScaler and PCA from model NPZ file.

    The archive must contain ``prep_mean`` and ``prep_scale``.  A PCA is
    rebuilt only when ``prep_pca_components`` is present and non-empty, in
    which case ``prep_pca_mean`` is required as well; the variance-related
    arrays are optional and default to placeholder values.

    Args:
        model_npz_path: Path to the NPZ archive holding the fitted statistics.

    Returns:
        ``(scaler, pca)`` where ``pca`` is ``None`` if the archive holds no
        PCA components.
    """
    from sklearn.preprocessing import StandardScaler

    # NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which
    # can execute arbitrary code -- only load archives from trusted sources.
    # The context manager closes the underlying file handle once all arrays
    # have been read into memory.
    with np.load(model_npz_path, allow_pickle=True) as z:
        scaler = StandardScaler()
        scaler.mean_ = z["prep_mean"]
        scaler.scale_ = z["prep_scale"]
        scaler.var_ = scaler.scale_ ** 2
        scaler.n_features_in_ = scaler.mean_.shape[0]

        pca = None
        if "prep_pca_components" in z.files:
            comps = z["prep_pca_components"]
            if comps.size > 0:
                from sklearn.decomposition import PCA
                mean = z["prep_pca_mean"]
                k = int(comps.shape[0])
                pca = PCA(n_components=k, svd_solver="full")
                pca.components_ = comps
                pca.mean_ = mean
                pca.n_features_in_ = int(mean.shape[0])
                # Optional statistics: fall back to neutral placeholders.
                pca.explained_variance_ = z.get("prep_pca_explained_variance", np.ones(k))
                pca.explained_variance_ratio_ = z.get(
                    "prep_pca_explained_variance_ratio", np.ones(k) / k
                )
                pca.singular_values_ = z.get("prep_pca_singular_values", np.ones(k))
    return scaler, pca

def invert_preproc_step(z_row: np.ndarray, scaler, pca) -> np.ndarray:
    """Map one preprocessed row back to the original (hidden) space.

    The PCA projection is undone first when a PCA object is supplied, then the
    standardization is reversed as ``x * scaler.scale_ + scaler.mean_``.
    """
    if pca is None:
        restored = z_row
    else:
        # inverse_transform is called on a one-row batch; unwrap the row again
        as_batch = pca.inverse_transform(z_row[None, :])
        restored = as_batch[0]
    return restored * scaler.scale_ + scaler.mean_

# ============================================================================
# Model loading
# ============================================================================
def resolve_dtype(name: str):
    """Translate a dtype name into the matching ``torch.dtype``.

    Recognized names are ``"float16"``, ``"bfloat16"`` and ``"float32"``; any
    other name raises ``KeyError``.
    """
    known = dict(
        float16=torch.float16,
        bfloat16=torch.bfloat16,
        float32=torch.float32,
    )
    return known[name]

def pick_device(device_arg: str) -> str:
    """Choose the device string to run on.

    In a distributed run the argument is ignored: the result is the CUDA
    device indexed by ``LOCAL_RANK`` (default 0), or ``"cpu"`` without CUDA.
    Otherwise ``"auto"`` resolves to ``"cuda"`` when available, else
    ``"cpu"``, and any other value is returned unchanged.
    """
    if not dist_is_enabled():
        if device_arg != "auto":
            return device_arg
        return "cuda" if torch.cuda.is_available() else "cpu"
    if not torch.cuda.is_available():
        return "cpu"
    return f"cuda:{int(os.environ.get('LOCAL_RANK','0'))}"

def model_device(mdl: torch.nn.Module):
    """Return the device that holds the module's weights.

    The device of the first parameter is used.  A module without parameters
    falls back to its first buffer, and one with neither is reported as CPU,
    instead of leaking ``StopIteration`` to the caller.
    """
    for tensor in mdl.parameters():
        return tensor.device
    for tensor in mdl.buffers():
        return tensor.device
    return torch.device("cpu")

@lru_cache(maxsize=8)
def load_tok_mdl(model_name: str, tokenizer_name: Optional[str], device: str, dtype: str):
    """Load a tokenizer and an eval-mode causal LM, memoized per argument tuple.

    Args:
        model_name: Checkpoint passed to ``AutoModelForCausalLM``; also used
            for the tokenizer when ``tokenizer_name`` is falsy.
        tokenizer_name: Optional separate tokenizer checkpoint.
        device: Device request resolved through ``pick_device``.
        dtype: Dtype name resolved through ``resolve_dtype``.

    Returns:
        ``(tokenizer, model)``.  The tokenizer pads on the left and reuses EOS
        as the pad token when it has none; the model is moved to the resolved
        device (float16 is widened to float32 on CPU) and put in eval mode.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM

    target_device = pick_device(device)
    torch_dtype = resolve_dtype(dtype)

    # Remote code is trusted purely on a model-name substring match.
    name_lc = (model_name or "").lower()
    trust_markers = ("qwen", "stratos", "openthinker", "instruct", "chat")
    trust_remote = any(marker in name_lc for marker in trust_markers)

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name or model_name,
        trust_remote_code=trust_remote,
    )
    tokenizer.padding_side = "left"
    missing_pad = tokenizer.pad_token_id is None
    if missing_pad and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch_dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=trust_remote,
    )
    if target_device == "cpu" and torch_dtype is torch.float16:
        model = model.to(dtype=torch.float32)
    model = model.to(target_device)
    model = model.eval()
    return tokenizer, model

# ============================================================================
# Chat format builders
# ============================================================================
def nemotron_build_prompt(tokenizer, system_text: str, user_text: str) -> str:
    """Render a system + user conversation with the tokenizer's chat template.

    The template is applied untokenized and with the generation prompt
    appended.
    """
    turns = (("system", system_text), ("user", user_text))
    conversation = [{"role": role, "content": text} for role, text in turns]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

def qwen_build_prompt(tokenizer, user_text: str) -> str:
    """Render a single user turn, requesting ``enable_thinking=True``.

    If the template call raises ``TypeError`` with that flag, the call is
    repeated without it.
    """
    conversation = [{"role": "user", "content": user_text}]
    common = {"tokenize": False, "add_generation_prompt": True}
    try:
        rendered = tokenizer.apply_chat_template(conversation, **common, enable_thinking=True)
    except TypeError:
        rendered = tokenizer.apply_chat_template(conversation, **common)
    return rendered

def r1_build_prompt(tokenizer, user_text: str) -> str:
    """Render a fixed system prompt plus one user turn via the chat template.

    Args:
        tokenizer: Tokenizer expected to expose ``apply_chat_template`` and a
            non-None ``chat_template``.
        user_text: Content of the user turn.

    Returns:
        The untokenized prompt with the generation prompt appended.

    Raises:
        RuntimeError: If the tokenizer has no ``apply_chat_template`` method,
            or its ``chat_template`` attribute is missing or ``None``.
    """
    messages = [
        {"role": "system", "content": "You are a reasoning assistant."},
        {"role": "user", "content": user_text}
    ]
    # getattr() keeps a tokenizer without the attribute on the RuntimeError
    # path instead of failing with AttributeError.
    has_template = (
        hasattr(tokenizer, "apply_chat_template")
        and getattr(tokenizer, "chat_template", None) is not None
    )
    if not has_template:
        raise RuntimeError(
            f"Tokenizer {getattr(tokenizer, 'name_or_path', '<unknown>')} lacks chat_template"
        )
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def detect_chat_backend(model_name: str) -> str:
    """Pick the chat format from substrings of the lower-cased model name.

    Rules are checked in order: "qwen" -> ``"qwen"``, "nemotron" ->
    ``"nemotron"``, "stratos"/"openthinker" -> ``"r1_chat"``; anything else
    (including an empty or ``None`` name) yields ``"raw"``.
    """
    lowered = (model_name or "").lower()
    rules = (
        (("qwen",), "qwen"),
        (("nemotron",), "nemotron"),
        (("stratos", "openthinker"), "r1_chat"),
    )
    for needles, backend in rules:
        if any(needle in lowered for needle in needles):
            return backend
    return "raw"

# ============================================================================
# Schedule utilities
# ============================================================================
def make_schedule(kind: str, layers: List[int], alpha: float) -> Dict[int, float]:
    """Create layer-wise steering schedule.

    The strength ramps linearly from ``0.2 * alpha`` at the smallest layer
    index to ``alpha`` at the largest one.  The endpoints are taken from the
    min/max of ``layers``, so the result does not depend on the input order.

    Args:
        kind: Schedule shape.  Only a linear ramp is implemented; every value
            is treated as ``"linear"``.
        layers: Layer indices to steer.
        alpha: Strength at the highest layer (and for a single layer).

    Returns:
        Mapping ``layer -> strength`` in the order the layers were given;
        empty when ``layers`` is empty.
    """
    if not layers:
        return {}
    if len(layers) == 1:
        return {layers[0]: alpha}

    lo, hi = min(layers), max(layers)
    a0, a1 = 0.2 * alpha, alpha
    den = max(1, hi - lo)  # guards against division by zero when all indices coincide
    schedule: Dict[int, float] = {}
    for layer in layers:
        t = (layer - lo) / den
        schedule[layer] = (1 - t) * a0 + t * a1
    return schedule

def pick_consensus_layer(stats_npz: str, which: str = "consensus_first_change_layers") -> Optional[int]:
    """Return the median of the non-negative entries stored under ``which``.

    Args:
        stats_npz: Path to the NPZ archive to read.
        which: Name of the array inside the archive.

    Returns:
        The median (truncated to ``int``) of all entries ``>= 0`` after
        casting to int64, or ``None`` when the archive cannot be read, the
        array is absent, or no entry is non-negative.
    """
    try:
        with np.load(stats_npz, allow_pickle=True) as archive:
            if which in archive.files:
                raw = np.asarray(archive[which])
            else:
                return None
    except Exception:
        return None

    layers = raw.astype(np.int64)
    non_negative = layers[layers >= 0]
    if non_negative.size > 0:
        return int(np.median(non_negative))
    return None

# ============================================================================
# Soft edge vector bank
# ============================================================================
def build_vec_bank_from_soft(
    stats_npz: str,
    model_npz: str,
    soft_json: str,
    prefix: str
) -> Tuple[List[np.ndarray], np.ndarray]:
    """Build a bank of hidden-space edge vectors with sampling probabilities.

    Args:
        stats_npz: NPZ archive holding vectors under keys
            ``vec::<prefix>:<i>,<j>``.
        model_npz: NPZ archive passed to ``load_preproc`` for the scaler/PCA
            used to invert preprocessing.
        soft_json: JSON file with parallel ``"edges"`` and ``"weights"`` lists.
        prefix: Key prefix selecting which vector family to read.

    Returns:
        ``(bank, probs)``: one inverted vector per edge found in the archive,
        and the matching weights renormalized to sum to 1 (uniform when the
        kept weights sum to zero or less).

    Raises:
        ValueError: If edges/weights are empty or of different length, or no
            edge has a vector in ``stats_npz``.
    """
    with open(soft_json, "r", encoding="utf-8") as f:
        s = json.load(f)

    edges = s.get("edges", [])
    probs = np.asarray(s.get("weights", []), dtype=float)

    if len(edges) == 0 or probs.size == 0:
        raise ValueError(f"No edges/weights in {soft_json}")
    if len(edges) != probs.size:
        raise ValueError(f"Mismatch edges({len(edges)}) vs weights({probs.size})")

    bank, keep = [], []
    # NOTE(review): allow_pickle=True can execute arbitrary code while
    # unpickling object arrays -- only use trusted archives.  The context
    # manager releases the file handle that was previously left open.
    with np.load(stats_npz, allow_pickle=True) as z:
        scaler, pca = load_preproc(model_npz)
        available = set(z.files)
        for k, (i, j) in enumerate(edges):
            key = f"vec::{prefix}:{i},{j}"
            if key not in available:
                if only_rank0():
                    warnings.warn(f"[soft] Missing {key} in {stats_npz}; skipping")
                continue
            bank.append(invert_preproc_step(z[key], scaler, pca))
            keep.append(k)

    if not bank:
        raise ValueError(f"No edge vectors with prefix '{prefix}' in {stats_npz}")

    # Renormalize over the edges that were actually found.
    probs = probs[keep]
    ssum = float(probs.sum())
    probs = probs / ssum if ssum > 0 else np.ones(len(bank)) / len(bank)
    return bank, probs

# ============================================================================
# Dataset loaders
# ============================================================================
def autodetect_keys(r: dict) -> Tuple[str, str]:
    """Guess which record fields hold the prompt and the answer.

    Candidate names are tried in a fixed priority order and the first one
    present in ``r`` wins for each role.

    Raises:
        KeyError: If no prompt candidate or no answer candidate is present.
    """
    def _first_present(candidates):
        for name in candidates:
            if name in r:
                return name
        return None

    prompt_field = _first_present(("prompt", "question", "input", "query"))
    answer_field = _first_present(
        ("answer", "gold", "target", "gt_answer", "reference", "solution")
    )
    if prompt_field and answer_field:
        return prompt_field, answer_field
    raise KeyError("Autodetect failed; specify prompt/answer keys explicitly")

def load_hf_dataset_items(
    ds_name: str,
    ds_config: Optional[str] = None,
    split: str = "test",
    prompt_key: str = "question",
    answer_key: str = "answer",
    n: Optional[int] = None,
    seed: int = 0,
    skip_first: int = 0,
    filter_answer_types: Optional[List[str]] = None,
    filter_difficulties: Optional[List[str]] = None,
) -> List[Tuple[str, str]]:
    """Load (prompt, answer) string pairs from a HuggingFace dataset split.

    Args:
        ds_name: Dataset identifier for ``datasets.load_dataset``.
        ds_config: Optional configuration name; omitted from the call when falsy.
        split: Split to load.
        prompt_key: Preferred prompt field.
        answer_key: Preferred answer field.
        n: Number of items to keep.  When truthy and smaller than the pool,
            the pool is shuffled with ``seed`` first; otherwise the pool is
            simply sliced with ``[:n]``.
        seed: Seed for the shuffle.
        skip_first: Number of leading rows excluded from the pool.
        filter_answer_types: Allowed ``answer_type`` values; applied only if
            the column exists.
        filter_difficulties: Allowed ``difficulty`` values; applied only if
            the column exists.

    Returns:
        A list of stripped ``(prompt, answer)`` strings.  Records lacking
        either preferred key fall back to ``autodetect_keys``.
    """
    from datasets import load_dataset

    if ds_config:
        ds = load_dataset(ds_name, ds_config, split=split)
    else:
        ds = load_dataset(ds_name, split=split)

    def _keep_only(dataset, column, wanted):
        # Keep the rows whose `column` value is one of the stripped `wanted` values.
        allowed = {value.strip() for value in wanted}
        return dataset.filter(lambda rec: rec.get(column) in allowed)

    # Filters run one after the other; each is skipped when its column is absent.
    if filter_answer_types is not None and "answer_type" in ds.column_names:
        ds = _keep_only(ds, "answer_type", filter_answer_types)
    if filter_difficulties is not None and "difficulty" in ds.column_names:
        ds = _keep_only(ds, "difficulty", filter_difficulties)

    total = len(ds)
    indices = list(range(min(skip_first, total), total))
    if n and n < len(indices):
        # Sub-sampling: shuffle reproducibly before truncating.
        np.random.default_rng(seed).shuffle(indices)
    indices = indices[:n]

    pairs = []
    for idx in indices:
        rec = ds[idx]
        if prompt_key in rec and answer_key in rec:
            pk, ak = prompt_key, answer_key
        else:
            pk, ak = autodetect_keys(rec)
        pairs.append((str(rec[pk]).strip(), str(rec[ak]).strip()))
    return pairs

def load_gpqa_diamond_items(
    split: str = "train",
    n: Optional[int] = 100,
    seed: int = 0,
    skip_first: int = 0,
) -> List[Tuple[str, str]]:
    """Build 4-option multiple-choice prompts from the GPQA diamond subset.

    Args:
        split: Dataset split to load.
        n: Number of questions to keep; when smaller than the pool, the pool
            is shuffled with ``seed`` before truncation.  ``None`` keeps all.
        seed: Seed for the pool shuffle and, combined with the row index, for
            each question's option order.
        skip_first: Number of leading rows excluded from the pool.

    Returns:
        ``(prompt, correct_letter)`` pairs, where the prompt holds the answer
        format instructions, the question and the shuffled options A-D.
    """
    from datasets import load_dataset
    import random

    ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split=split)
    total = len(ds)
    pool = list(range(min(skip_first, total), total))
    if n is not None and n < len(pool):
        np.random.default_rng(seed).shuffle(pool)
    pool = pool[:n]

    LETTERS = ["A", "B", "C", "D"]
    answer_fields = (
        "Correct Answer",
        "Incorrect Answer 1",
        "Incorrect Answer 2",
        "Incorrect Answer 3",
    )
    BOX_HINT = (
        "You are answering a 4-option multiple-choice question.\n"
        "Options are labeled A, B, C, and D.\n"
        "Think step-by-step and show your reasoning.\n"
        "At the very end, output ONE line exactly in this format:\n"
        "Final Answer: \\boxed{A}\n"
        "where the letter is A, B, C, or D.\n"
    )

    items: List[Tuple[str, str]] = []
    for row_idx in pool:
        row = ds[row_idx]
        q = str(row["Question"]).strip()
        opts = [str(row[field]).strip() for field in answer_fields]

        # Per-question RNG keeps the option order independent of pool order.
        order = [0, 1, 2, 3]
        random.Random(seed + row_idx).shuffle(order)
        shuf = [opts[src] for src in order]
        # Source index 0 is the correct answer; find where it landed.
        correct_letter = LETTERS[order.index(0)]

        options_block = "\n".join(f"{LETTERS[j]}. {shuf[j]}" for j in range(4))
        prompt = f"{BOX_HINT}\n{q}\n\n{options_block}\n"
        items.append((prompt, correct_letter))
    return items

# ============================================================================
# Grading utilities
# ============================================================================
def extract_number(s: str):
    """Extract numeric answer from string.

    A number following a ``####`` marker wins; otherwise the last number in
    the text is used.  Well-formed thousands separators (``1,234`` or
    ``12,345,678``) are removed first so such values are not cut off at the
    first comma.  A comma-separated list of three-digit values is
    indistinguishable from grouped digits and is merged as well.

    Returns:
        The number as a string (optional sign, decimals and exponent), or
        ``None`` when the text contains no number.
    """
    cleaned = re.sub(
        r"(?<![\d.])\d{1,3}(?:,\d{3})+(?!\d)",
        lambda m: m.group(0).replace(",", ""),
        s,
    )
    number = r"[-+]?(?:\d+(?:\.\d+)?(?:e[-+]?\d+)?)"
    marked = re.search(r"####\s*(" + number + r")", cleaned, flags=re.I)
    if marked:
        return marked.group(1)
    found = re.findall(number, cleaned, flags=re.I)
    return found[-1] if found else None

# Case-insensitive matcher for a single answer letter A-D (capture group 1).
# An optional "Final Answer:" prefix may precede the letter, which must either
# follow "\boxed{" or start at a word boundary, and must be followed by "}",
# "." or a word boundary.
LETTER_RE = re.compile(r"(?i)(?:Final Answer\s*:\s*)?(?:\\boxed\{|\b)([A-D])(?:\}|\.|\b)")

def last_nonempty_line(s: str) -> str:
    """Return the final line of ``s`` that is non-blank after stripping.

    The line is returned stripped; an empty string is returned when every
    line is blank.
    """
    stripped_from_end = (raw.strip() for raw in reversed(s.splitlines()))
    return next((line for line in stripped_from_end if line), "")

# Explicit answer markers, checked before the generic LETTER_RE scan.
_BOXED_LETTER_RE = re.compile(r"\\boxed\{\s*([A-Da-d])\s*\}")
_FINAL_LETTER_RE = re.compile(r"(?i)Final Answer\s*:\s*([A-D])(?![A-Za-z])")


def pick_letter(text: str):
    """Pick the answer letter (A-D) from ``text``.

    Explicit markers take priority so that a stray word such as the article
    "a" earlier in the line cannot shadow the real answer:

    1. the last ``\\boxed{X}`` occurrence,
    2. the last ``Final Answer: X`` occurrence,
    3. the first match of the generic ``LETTER_RE`` pattern.

    Returns:
        The upper-cased letter, or ``None`` for empty input or no match.
    """
    if not text:
        return None
    for marker in (_BOXED_LETTER_RE, _FINAL_LETTER_RE):
        hits = marker.findall(text)
        if hits:
            return hits[-1].upper()
    m = LETTER_RE.search(text)
    return m.group(1).upper() if m else None

def extract_answer_ref(ref: str, metric: str, regex_answer: Optional[str] = None):
    """Pull the comparable answer out of a reference string.

    For ``metric == "numeric"`` the result is ``extract_number(ref)``; for any
    other metric it is the letter picked from the last non-empty line.
    ``regex_answer`` is accepted but not consulted by this implementation.
    """
    if metric != "numeric":
        return pick_letter(last_nonempty_line(ref))
    return extract_number(ref)

def extract_answer_pred(pred: str, metric: str, regex_pred: Optional[str] = None):
    """Pull the comparable answer out of a model prediction.

    For ``metric == "numeric"`` the result is ``extract_number(pred)``; for
    any other metric it is the letter picked from the last non-empty line.
    ``regex_pred`` is accepted but not consulted by this implementation.
    """
    if metric != "numeric":
        return pick_letter(last_nonempty_line(pred))
    return extract_number(pred)

def grade(pred: str, ref: str, metric: str = "em",
          regex_answer: Optional[str] = None,
          regex_pred: Optional[str] = None) -> bool:
    """Decide whether a prediction matches the reference.

    Both sides are reduced with ``extract_answer_pred`` /
    ``extract_answer_ref``.  A missing answer on either side grades as False.
    Numeric answers are compared as floats, falling back to string equality
    if the conversion fails; every other metric compares the extracted values
    directly.
    """
    got = extract_answer_pred(pred, metric, regex_pred)
    want = extract_answer_ref(ref, metric, regex_answer)
    if got is None or want is None:
        return False
    if metric != "numeric":
        return got == want
    try:
        return float(got) == float(want)
    except Exception:
        return got == want

# ============================================================================
# Hidden width validation
# ============================================================================
def infer_hidden_width(mdl: torch.nn.Module) -> int:
    """Work out the model's hidden dimension.

    Sources are tried in order: a truthy ``config.hidden_size``, the input
    embedding's ``embedding_dim`` or the second dimension of its 2-D weight,
    and finally the second dimension of the first 2-D parameter.

    Raises:
        RuntimeError: If none of these sources yields a width.
    """
    declared = getattr(getattr(mdl, "config", None), "hidden_size", None)
    if declared:
        return int(declared)

    try:
        embedding = mdl.get_input_embeddings()
        if embedding is not None:
            dim = getattr(embedding, "embedding_dim", None)
            if dim:
                return int(dim)
            weight = getattr(embedding, "weight", None)
            if weight is not None and weight.ndim == 2:
                return int(weight.shape[1])
    except Exception:
        # Models without a usable embedding accessor fall through.
        pass

    first_matrix = next((p for p in mdl.parameters() if p.ndim == 2), None)
    if first_matrix is not None:
        return int(first_matrix.shape[1])
    raise RuntimeError("Could not infer model hidden width")

def validate_vec_dim(vec_hidden: np.ndarray, mdl: torch.nn.Module, where: str):
    """Check that a steering vector is 1-D and as wide as the model's hidden size.

    ``None`` is accepted and skips all checks.  ``where`` is only used to tag
    the error message.

    Raises:
        ValueError: If the vector is not 1-D or its length differs from
            ``infer_hidden_width(mdl)``.
    """
    if vec_hidden is None:
        return
    if vec_hidden.ndim != 1:
        raise ValueError(f"[{where}] Steering vector must be 1-D, got shape={tuple(vec_hidden.shape)}")
    got = int(vec_hidden.shape[0])
    want = infer_hidden_width(mdl)
    if got == want:
        return
    raise ValueError(
        f"[{where}] Steering vector width mismatch: vec={got}, model_hidden={want}"
    )

# ============================================================================
# Step-aware gate (for thinking token detection)
# ============================================================================
class BatchStepGate:
    """Per-sample flags telling the steering hooks whether to fire.

    ``apply_now[b]`` starts out False for each of the ``batch_size`` entries.
    """
    def __init__(self, batch_size: int):
        self.apply_now = [False for _ in range(batch_size)]

class BatchNewlineWatcher:
    """Logits processor that opens the gate after two consecutive newlines.

    On every call the last token of each row is decoded and its characters
    scanned.  A per-row counter tracks the current run of newline characters
    (runs may span tokens); when it reaches two, ``gate.apply_now[row]`` is
    set and the counter restarts.  The scores are returned unchanged.
    """
    def __init__(self, tokenizer, gate: BatchStepGate):
        self.tok = tokenizer
        self.gate = gate
        self.consec = []

    def _ensure(self, B):
        """Grow the per-row counters so that at least ``B`` of them exist."""
        missing = B - len(self.consec)
        if missing > 0:
            self.consec += [0] * missing

    def __call__(self, input_ids, scores):
        rows = input_ids.shape[0]
        self._ensure(rows)
        for b in range(rows):
            ids = input_ids[b]
            if ids.numel() == 0:
                self.consec[b] = 0
                continue
            decoded = self.tok.decode([int(ids[-1])], skip_special_tokens=False)
            # Treat Windows line endings as a single newline.
            for ch in decoded.replace("\r\n", "\n"):
                if ch != "\n":
                    self.consec[b] = 0
                    continue
                self.consec[b] += 1
                if self.consec[b] >= 2:
                    self.gate.apply_now[b] = True
                    self.consec[b] = 0
        return scores

def make_batched_gated_add_hook(vec: torch.Tensor, a: float, gate: BatchStepGate):
    """Build a forward hook that adds ``a * vec`` at the last position when gated.

    For a 3-D hidden state, rows whose ``gate.apply_now`` flag is set receive
    the vector at their last sequence position; a 2-D hidden state is treated
    as a single sample using flag 0.  The tensor is modified in place and the
    consumed flags are cleared.  Outputs may be a tensor, a tuple whose first
    item is a tensor, or an object with a tensor ``last_hidden_state``;
    anything else passes through untouched.

    NOTE(review): the gate is shared, so when several layers are hooked the
    first hook to run clears the flags before the others see them -- confirm
    this is intended.
    """
    def _steer(h: torch.Tensor) -> torch.Tensor:
        if h.dim() == 2:
            flag = torch.tensor([gate.apply_now[0]], device=h.device, dtype=h.dtype)
            delta = (a * vec).to(h.device, dtype=h.dtype)
            h[-1, :] = h[-1, :] + flag * delta
            gate.apply_now[0] = False
            return h

        batch, _seq_len, _width = h.shape
        delta = (a * vec).to(h.device, dtype=h.dtype)
        # Column of 0/1 multipliers: one per sample, broadcast over the width.
        flags = torch.tensor(gate.apply_now[:batch], device=h.device, dtype=h.dtype).unsqueeze(-1)
        h[:, -1, :] = h[:, -1, :] + flags * delta
        for b in range(batch):
            if gate.apply_now[b]:
                gate.apply_now[b] = False
        return h

    def hook(_mod, _inp, out):
        if torch.is_tensor(out):
            return _steer(out)
        if isinstance(out, tuple) and out and torch.is_tensor(out[0]):
            return (_steer(out[0]), *out[1:])
        hidden = getattr(out, "last_hidden_state", None)
        if torch.is_tensor(hidden):
            out.last_hidden_state = _steer(hidden)
        return out
    return hook

# ======= NEW: per-step soft_prob resampling support (minimal, opt-in) =======
class SoftProbVecSampler:
    """Rank-aware RNG + vec_bank access for per-step soft_prob resampling.

    The generator is seeded from ``(base_seed, rank)``, so each rank draws its
    own reproducible stream.
    """
    def __init__(self, vec_bank: List[np.ndarray], vec_probs: np.ndarray, base_seed: int, rank: int):
        """Store the bank and normalize the sampling weights.

        Args:
            vec_bank: Candidate steering vectors.
            vec_probs: One non-negative weight per bank entry.
            base_seed: Seed shared by all ranks.
            rank: This process's rank, mixed into the seed.
        """
        self.vec_bank = vec_bank
        probs = np.asarray(vec_probs, dtype=float)
        total = float(probs.sum())
        if total > 0:
            self.vec_probs = probs / total
        elif probs.size > 0:
            # A zero-sum weight vector cannot be normalized; sample uniformly
            # rather than fail on the first draw.
            self.vec_probs = np.full(probs.size, 1.0 / probs.size)
        else:
            self.vec_probs = probs
        self.rng = np.random.default_rng(np.random.SeedSequence([int(base_seed), int(rank)]))

    def sample_indices(self, n: int) -> np.ndarray:
        """Draw ``n`` bank indices according to the normalized weights."""
        return self.rng.choice(len(self.vec_bank), size=int(n), p=self.vec_probs)

    def get_vec(self, idx: int):
        """Return the bank vector at position ``idx``."""
        return self.vec_bank[int(idx)]

def make_batched_gated_add_hook_resample(vec_sampler: SoftProbVecSampler, a: float, gate: BatchStepGate):
    """Build a step-aware hook that draws a fresh steering vector per gated sample.

    For every sample whose ``gate.apply_now`` flag is set, one vector is drawn
    from ``vec_sampler``, scaled by ``a`` and added in place at that sample's
    last sequence position; the flag is then cleared.  A 2-D hidden state is
    treated as a single sample using flag 0.  Outputs may be a tensor, a tuple
    whose first item is a tensor, or an object with a tensor
    ``last_hidden_state``; anything else passes through untouched.
    """
    def _vector_like(idx, ref: torch.Tensor) -> torch.Tensor:
        # Bank vector `idx` as a (1, width) tensor on ref's device and dtype.
        return torch.as_tensor(
            vec_sampler.get_vec(idx), device=ref.device, dtype=ref.dtype
        ).view(1, -1)

    def _inject_inplace(H: torch.Tensor) -> torch.Tensor:
        if H.dim() == 2:
            if gate.apply_now[0]:
                idx = int(vec_sampler.sample_indices(1)[0])
                H[-1, :] = H[-1, :] + a * _vector_like(idx, H)
                gate.apply_now[0] = False
            return H

        batch, _seq_len, _width = H.shape
        fire = [bool(flag) for flag in gate.apply_now[:batch]]
        n_fire = sum(fire)
        if n_fire > 0:
            # All indices for this step are drawn in a single call.
            drawn = iter(vec_sampler.sample_indices(n_fire))
            for b in range(batch):
                if not fire[b]:
                    continue
                H[b, -1, :] = H[b, -1, :] + a * _vector_like(int(next(drawn)), H)
                gate.apply_now[b] = False
        return H

    def hook(_mod, _inp, out):
        if torch.is_tensor(out):
            return _inject_inplace(out)
        if isinstance(out, tuple) and out and torch.is_tensor(out[0]):
            return (_inject_inplace(out[0]), *out[1:])
        hidden = getattr(out, "last_hidden_state", None)
        if torch.is_tensor(hidden):
            out.last_hidden_state = _inject_inplace(hidden)
        return out
    return hook

# ============================================================================
# LLMRunner class
# ============================================================================
class LLMRunner:
    """Main class for running LLM inference with optional steering.

    Holds a tokenizer/model pair obtained from ``load_tok_mdl``, formats
    prompts for the chat backend detected from the model name, and can
    register forward hooks on transformer blocks that add steering vectors to
    the block outputs while ``generate`` runs.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        temperature: float = 0.0,
        top_p: float = 1.0,
        max_new_tokens: int = 256,
        device: str = "auto",
        dtype: str = "float16",
        top_k: Optional[int] = None,
        system_text: str = "detailed thinking on",
        final_boxed_hint: bool = False,
        min_p: Optional[float] = None
    ):
        """Store the sampling settings and load the tokenizer and model.

        Args:
            model_name: Checkpoint to load; also selects the chat backend.
            tokenizer_name: Optional separate tokenizer checkpoint.
            temperature: Sampling temperature; ``0.0`` requests greedy decoding.
            top_p: Nucleus sampling threshold.
            max_new_tokens: Generation length limit.
            device: Device request forwarded to ``load_tok_mdl``.
            dtype: Dtype name forwarded to ``load_tok_mdl``.
            top_k: Top-k cutoff; ``None`` or a non-positive value disables it.
            system_text: System message used by the "nemotron" backend.
            final_boxed_hint: Prepend the boxed-answer instructions to prompts.
            min_p: Optional min-p threshold.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        self.temperature = float(temperature)
        self.top_p = float(top_p)
        self.top_k = (int(top_k) if top_k is not None and int(top_k) > 0 else None)
        self.max_new_tokens = int(max_new_tokens)
        self.tok, self.mdl = load_tok_mdl(model_name, tokenizer_name, device, dtype)
        self.min_p = (float(min_p) if min_p is not None else None)
        self.system_text = system_text
        self.final_boxed_hint = bool(final_boxed_hint)
        self.chat_backend = detect_chat_backend(model_name)

    def _format_prompts(self, prompts: List[str]) -> List[str]:
        """Format prompts using appropriate chat template.

        With ``final_boxed_hint`` set, the boxed-answer instructions are put
        in front of each right-stripped prompt.  The "raw" backend returns
        the user text as is.
        """
        ANSWER_IN_BOX_PROMPT = (
            "Answer the following question step-by-step.\n"
            "At the very end, output exactly one line formatted as:\n"
            "Final Answer: \\boxed{...}\n"
        )
        outs = []
        for p in prompts:
            user_text = p
            if self.final_boxed_hint:
                user_text = f"{ANSWER_IN_BOX_PROMPT}\n{user_text.rstrip()}"

            if self.chat_backend == "qwen":
                outs.append(qwen_build_prompt(self.tok, user_text=user_text))
            elif self.chat_backend == "nemotron":
                outs.append(nemotron_build_prompt(self.tok, self.system_text, user_text))
            elif self.chat_backend == "r1_chat":
                outs.append(r1_build_prompt(self.tok, user_text))
            else:
                outs.append(user_text)
        return outs

    def _generate_accepts(self, name: str) -> bool:
        """Tell whether ``self.mdl.generate`` can take the sampling option ``name``.

        An option counts as supported when it is an explicit parameter of
        ``generate`` or an attribute of the model's ``generation_config``.
        The second check matters because a ``generate`` that takes its
        sampling options through ``**kwargs`` never lists them in its
        signature.
        """
        try:
            if name in inspect.signature(self.mdl.generate).parameters:
                return True
        except Exception:
            pass
        return hasattr(getattr(self.mdl, "generation_config", None), name)

    def _build_gen_kwargs(self, batch_inputs, do_sample: bool):
        """Build generation kwargs.

        Sampling knobs are only forwarded when sampling is on.  A
        non-positive temperature is left out in that case, so the model's own
        default applies instead of an invalid value being passed.  The
        "r1_chat" backend always samples with fixed temperature/top_p and
        without top_k/min_p.
        """
        eos_id = self.tok.eos_token_id
        pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else eos_id
        gen_kwargs = dict(
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
            eos_token_id=eos_id,
            pad_token_id=pad_id,
        )

        if do_sample:
            if self.temperature > 0.0:
                gen_kwargs["temperature"] = self.temperature
            gen_kwargs["top_p"] = self.top_p

        if do_sample and self.chat_backend != "r1_chat":
            if self.top_k is not None and self._generate_accepts("top_k"):
                gen_kwargs["top_k"] = int(self.top_k)
            if self.min_p is not None and self._generate_accepts("min_p"):
                gen_kwargs["min_p"] = float(self.min_p)

        if self.chat_backend == "r1_chat":
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = 0.6
            gen_kwargs["top_p"] = 0.95
            gen_kwargs.pop("top_k", None)
            gen_kwargs.pop("min_p", None)

        gen_kwargs.update(batch_inputs)
        return gen_kwargs

    def _find_blocks(self):
        """Find transformer blocks in model.

        Looks for ``mdl.model.layers`` first, then ``mdl.transformer.h``.

        Raises:
            RuntimeError: If neither attribute path yields the blocks.
        """
        blocks = (
            getattr(getattr(self.mdl, "model", None), "layers", None) or
            getattr(getattr(self.mdl, "transformer", None), "h", None)
        )
        if blocks is None:
            raise RuntimeError("Unsupported model structure for hooking blocks")
        return blocks

    def _register_hooks(
        self,
        schedule: Dict[int, float],
        vec_hidden: Optional[np.ndarray],
        batch_size: int,
        step_aware: bool,
        # --- NEW optional (keeps backward-compat) ---
        per_step_resample: bool = False,
        vec_bank: Optional[List[np.ndarray]] = None,
        vec_probs: Optional[np.ndarray] = None,
        base_seed: int = 0,
    ):
        """Register forward hooks for steering.

        Schedule keys are hidden-state layer indices: key ``L`` hooks
        ``blocks[L - 1]``, and key 0 (the embedding output) is skipped.
        Python and NumPy integer keys are both accepted.

        Args:
            schedule: Mapping ``layer -> strength``.
            vec_hidden: Fixed steering vector; required unless per-step
                resampling is used.
            batch_size: Number of gate flags to allocate in step-aware mode.
            step_aware: Steer only when the shared gate flags a sample;
                otherwise the vector is added to every position on every
                forward pass.
            per_step_resample: In step-aware mode, draw a vector from
                ``vec_bank`` per gated sample instead of using ``vec_hidden``.
            vec_bank: Candidate vectors for resampling.
            vec_probs: Sampling weights for ``vec_bank``.
            base_seed: Seed for the resampling RNG (combined with the rank).

        Returns:
            ``(handles, gate)``: the hook handles to remove afterwards and the
            shared gate (``None`` when not step-aware).

        NOTE(review): every step-aware hook shares one gate and clears the
        flags it consumes, so with several hooked layers only the first hook
        to run appears to steer -- confirm this is intended.
        """
        if vec_hidden is not None:
            validate_vec_dim(vec_hidden, self.mdl, where="register_hooks")
        blocks = self._find_blocks()
        n_layers = len(blocks)

        if n_layers <= 0:
            raise RuntimeError("Model has zero transformer blocks")

        valid = []
        for L, a in schedule.items():
            # np.integer keys (e.g. from NumPy-derived layer lists) are as
            # valid as builtin ints.
            if not isinstance(L, (int, np.integer)):
                warnings.warn(f"[hooks] Layer index {L} is not int; skipping")
                continue
            L = int(L)

            if L == 0:
                if only_rank0():
                    warnings.warn(f"[hooks] Layer 0 is embedding output, cannot hook transformer block; skipping")
                continue

            L_block = L - 1

            if 0 <= L_block < n_layers:
                valid.append((L_block, a))
            else:
                warnings.warn(f"[hooks] Layer {L} → blocks[{L_block}] out of range [0, {n_layers-1}]; skipping")

        if not valid and only_rank0():
            warnings.warn("[hooks] No valid layers to hook after range checks")

        handles = []
        gate = BatchStepGate(batch_size) if step_aware else None

        v_gpu = None
        if vec_hidden is not None:
            v_gpu = torch.tensor(
                vec_hidden,
                device=model_device(self.mdl),
                dtype=next(self.mdl.parameters()).dtype
            )

        def _make_add(vec_t, alpha):
            # Unconditional steering: add alpha * vec to every position.
            def _hook(_m, _i, out):
                def _add(h):
                    return h + alpha * vec_t.to(h.device, dtype=h.dtype)

                if torch.is_tensor(out):
                    return _add(out)
                if isinstance(out, tuple) and len(out) > 0 and torch.is_tensor(out[0]):
                    hs = _add(out[0])
                    return (hs, *out[1:])
                if hasattr(out, "last_hidden_state") and torch.is_tensor(out.last_hidden_state):
                    out.last_hidden_state = _add(out.last_hidden_state)
                    return out
                return out
            return _hook

        for L_block, a in valid:
            if step_aware:
                if per_step_resample and (vec_bank is not None) and (vec_probs is not None):
                    rk = dist_rank()
                    sampler = SoftProbVecSampler(vec_bank=vec_bank, vec_probs=vec_probs, base_seed=base_seed, rank=rk)
                    handles.append(
                        blocks[L_block].register_forward_hook(make_batched_gated_add_hook_resample(sampler, a, gate))
                    )
                else:
                    if v_gpu is None:
                        if only_rank0():
                            warnings.warn("[hooks] step_aware requested but vec_hidden is None; skipping this layer")
                        continue
                    handles.append(
                        blocks[L_block].register_forward_hook(make_batched_gated_add_hook(v_gpu, a, gate))
                    )
            else:
                if v_gpu is None:
                    if only_rank0():
                        warnings.warn("[hooks] non-step-aware requested but vec_hidden is None; skipping this layer")
                    continue
                handles.append(blocks[L_block].register_forward_hook(_make_add(v_gpu, a)))

        if only_rank0() and valid:
            hooked_layers_hs = sorted(set(L+1 for (L, _) in valid))  # Convert back to hidden_states indices
            hooked_blocks = ", ".join(str(L) for (L, _) in valid)
            print(f"[hooks] Registered {len(valid)} hooks")
            print(f"[hooks]   Hidden states layers: {hooked_layers_hs}")
            print(f"[hooks]   Transformer blocks: [{hooked_blocks}]")
            if step_aware and per_step_resample and (vec_bank is not None):
                print(f"[hooks]   Mode: step-aware + per-step soft_prob resampling (bank={len(vec_bank)})")

        return handles, gate

    @torch.inference_mode()
    def generate_batched(
        self,
        prompts: List[str],
        schedule: Optional[Dict[int, float]] = None,
        vec_hidden: Optional[np.ndarray] = None,
        step_aware: bool = True,
        torch_generator: Optional[torch.Generator] = None,
        per_step_resample: bool = False,
        vec_bank: Optional[List[np.ndarray]] = None,
        vec_probs: Optional[np.ndarray] = None,
        base_seed: int = 0,
    ) -> Tuple[List[str], List[int]]:
        """Generate completions for a batch of prompts, optionally steered.

        Hooks are registered only when a non-empty ``schedule`` comes with
        either ``vec_hidden`` or the step-aware resampling inputs, and are
        always removed again before returning.

        Args:
            prompts: Raw user prompts; chat formatting is applied here.
            schedule: Mapping ``layer -> strength`` for steering.
            vec_hidden: Fixed steering vector.
            step_aware: Gate steering on the newline watcher instead of
                steering every position.
            torch_generator: Seeded generator; forwarded only if
                ``generate`` exposes a ``generator`` or ``torch_generator``
                parameter.
            per_step_resample: Draw a steering vector per gated step.
            vec_bank: Candidate vectors for resampling.
            vec_probs: Sampling weights for ``vec_bank``.
            base_seed: Seed for the resampling RNG.

        Returns:
            ``(texts, gen_tokens)``: decoded completions (special tokens
            skipped) and per-sample generated-token counts, cut after the
            first EOS token (which is included in the count).
        """
        dev = model_device(self.mdl)
        prompts_fmt = self._format_prompts(prompts)
        batch_inputs = self.tok(prompts_fmt, return_tensors="pt", padding=True).to(dev)
        # All rows share the padded prompt width, so generation starts here.
        T_in_max = int(batch_inputs["input_ids"].size(1))
        eos_id = self.tok.eos_token_id

        do_sample = (self.temperature > 0.0) or (self.top_p < 1.0) or (self.top_k is not None)
        if self.chat_backend == "r1_chat":
            do_sample = True

        handles, gate = ([], None)
        if schedule and (vec_hidden is not None or (step_aware and per_step_resample and vec_bank is not None)):
            handles, gate = self._register_hooks(
                schedule,
                vec_hidden,
                batch_size=len(prompts_fmt),
                step_aware=step_aware,
                # --- pass through for resampling path ---
                per_step_resample=per_step_resample,
                vec_bank=vec_bank,
                vec_probs=vec_probs,
                base_seed=base_seed,
            )

        try:
            gen_kwargs = self._build_gen_kwargs(batch_inputs, do_sample=do_sample)

            if gate is not None and step_aware:
                try:
                    from transformers import LogitsProcessorList
                    lps = gen_kwargs.get("logits_processor", None)
                    if lps is None:
                        lps = LogitsProcessorList()
                    lps.append(BatchNewlineWatcher(self.tok, gate))
                    gen_kwargs["logits_processor"] = lps
                    if only_rank0():
                        print("[hooks] Step-aware gating: logits processor attached")
                except Exception as e:
                    if only_rank0():
                        warnings.warn(f"[hooks] Could not attach BatchNewlineWatcher: {e}")

            if do_sample and (torch_generator is not None):
                try:
                    sig = inspect.signature(self.mdl.generate)
                    if "generator" in sig.parameters:
                        gen_kwargs["generator"] = torch_generator
                    elif "torch_generator" in sig.parameters:
                        gen_kwargs["torch_generator"] = torch_generator
                except Exception:
                    pass

            out_ids = self.mdl.generate(**gen_kwargs)
            if out_ids.ndim == 1:
                out_ids = out_ids.unsqueeze(0)

            texts, gen_tokens = [], []
            for b in range(out_ids.size(0)):
                seq = out_ids[b].tolist()
                start = T_in_max
                end = len(seq)
                if eos_id is not None:
                    # Stop right after the first EOS in the generated part.
                    for t in range(start, end):
                        if seq[t] == eos_id:
                            end = t + 1
                            break
                texts.append(self.tok.decode(out_ids[b][start:end], skip_special_tokens=True))
                gen_tokens.append(end - start)
            return texts, gen_tokens
        finally:
            for h in handles:
                h.remove()

# ============================================================================
# Utility functions for evaluation
# ============================================================================
def reduce_counts(device: str, local_correct: int, local_total: int, local_gen_tokens: int):
    """Combine per-process evaluation counters into global totals.

    Args:
        device: Device spec handed to ``pick_device`` to place the tensor
            used for the collective.
        local_correct: Number of correct predictions on this process.
        local_total: Number of evaluated items on this process.
        local_gen_tokens: Number of generated tokens on this process.

    Returns:
        A tuple ``(accuracy, total, gen_tokens)``. When distributed mode is
        active the counts are summed over all processes first; otherwise the
        local values are used as-is. The accuracy denominator is clamped to
        at least 1, so an empty evaluation yields 0 instead of dividing by
        zero.
    """
    # Single-process run: nothing to reduce.
    if not dist_is_enabled():
        return local_correct / max(1, local_total), local_total, local_gen_tokens

    import torch.distributed as dist

    counts = torch.tensor(
        [local_correct, local_total, local_gen_tokens],
        dtype=torch.long,
        device=pick_device(device),
    )
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)

    # Pull the three summed counters back to Python ints in one go.
    global_correct, global_total, global_gen_tokens = counts.tolist()
    return global_correct / max(1, global_total), global_total, global_gen_tokens

def make_torch_generator(device: torch.device, seed: int) -> torch.Generator:
    """Build a ``torch.Generator`` on ``device`` seeded with ``int(seed)``.

    Two generators created with the same device and seed produce the same
    random stream, which makes sampling reproducible.
    """
    # manual_seed() returns the generator itself, so the call can be chained.
    return torch.Generator(device=device).manual_seed(int(seed))

def auto_soft_json(stats_npz_path: str) -> Optional[str]:
    """Find a ``soft_edges_top3.json`` file associated with a stats file.

    The directory of ``stats_npz_path`` is probed first, followed by two
    variants of that directory path obtained by textual substitution:
    ``steer_stats`` -> ``steer_stats_last_baseline_soft`` and
    ``steer_stats_last_baseline`` -> ``steer_stats_last_baseline_soft``.

    Returns:
        The path of the first candidate that exists as a regular file, or
        ``None`` when none of them does.
    """
    target_name = "soft_edges_top3.json"
    stats_dir = os.path.dirname(stats_npz_path)

    # Probe order matters: the first existing candidate wins.
    search_dirs = (
        stats_dir,
        stats_dir.replace("steer_stats", "steer_stats_last_baseline_soft"),
        stats_dir.replace("steer_stats_last_baseline", "steer_stats_last_baseline_soft"),
    )
    for directory in search_dirs:
        candidate = os.path.join(directory, target_name)
        if os.path.isfile(candidate):
            return candidate
    return None

def eval_batched_select_vec(
    runner: LLMRunner,
    items: List[Tuple[str, str]],
    metric: str,
    schedule: Optional[Dict[int, float]],
    base_seed: int,
    batch_size: int,
    step_aware: bool,
    mode: str,
    vec_hidden_single: Optional[np.ndarray] = None,
    vec_bank: Optional[List[np.ndarray]] = None,
    vec_probs: Optional[np.ndarray] = None,
    regex_answer: Optional[str] = None,
    regex_pred: Optional[str] = None,
    show_progress: bool = False,
):
    """Generate and grade ``items`` in batches under a vector-selection mode.

    Args:
        runner: Object exposing ``generate_batched``; it must return one
            prediction and one generated-token count per prompt.
        items: ``(prompt, gold)`` pairs to evaluate.
        metric: Metric name forwarded to ``grade``.
        schedule: Forwarded to the runner only when a vector is in use
            (or per-step resampling is active); ``None`` is sent otherwise.
        base_seed: Seeds both the NumPy RNG used for vector sampling and the
            torch generator handed to the runner.
        batch_size: Number of items per ``generate_batched`` call (> 0).
        step_aware: Forwarded to the runner; in ``"prob"`` mode it also
            switches from per-batch sampling to per-step resampling.
        mode: One of

            * ``"none"``   - no vector is passed.
            * ``"single"`` - always pass ``vec_hidden_single``.
            * ``"argmax"`` - pass the ``vec_bank`` entry with the highest
              ``vec_probs`` value.
            * ``"prob"``   - if ``step_aware``, hand ``vec_bank`` and
              ``vec_probs`` to the runner with ``per_step_resample=True``;
              otherwise draw one bank entry per batch according to
              ``vec_probs`` (uniformly when ``vec_probs`` is ``None``).
        vec_hidden_single: Vector used by ``"single"`` mode.
        vec_bank: Candidate vectors for ``"argmax"`` / ``"prob"`` modes.
        vec_probs: Selection weights aligned with ``vec_bank``.
        regex_answer: Forwarded to ``grade``.
        regex_pred: Forwarded to ``grade``.
        show_progress: Show a tqdm progress bar (rank 0 only).

    Returns:
        ``(acc, rows, total_gen_tokens, avg_gen_tokens)`` where ``rows`` is a
        list of per-item dicts with keys ``i``, ``prompt``, ``gold``,
        ``pred``, ``ok`` and ``gen_tokens``. ``acc`` and ``avg_gen_tokens``
        are divided by ``max(1, len(items))``.

    Raises:
        ValueError: If ``batch_size`` is not positive, ``mode`` is unknown,
            a bank-based mode is requested without a non-empty ``vec_bank``,
            or ``"argmax"`` is requested without ``vec_probs``.
    """
    # ---- validate once, up front, so bad arguments fail even when `items`
    # is empty and even when assertions are disabled (python -O) ----
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if mode not in ("none", "single", "argmax", "prob"):
        raise ValueError(f"unknown mode {mode}")
    if mode in ("argmax", "prob") and (vec_bank is None or len(vec_bank) == 0):
        raise ValueError(f"mode {mode!r} requires a non-empty vec_bank")
    if mode == "argmax" and vec_probs is None:
        raise ValueError("mode 'argmax' requires vec_probs")

    # ---- loop-invariant selection state ----
    # Per-step resampling is delegated to the runner's hooks; it only applies
    # to "prob" mode combined with step-aware generation.
    per_step_resample = bool(mode == "prob" and step_aware)
    hook_vec_bank = vec_bank if per_step_resample else None
    hook_vec_probs = vec_probs if per_step_resample else None
    sample_per_batch = (mode == "prob") and not per_step_resample

    # Vector that stays the same for every batch (None for "none" and for
    # step-aware "prob", where the runner picks vectors itself).
    fixed_vec = None
    if mode == "single":
        fixed_vec = vec_hidden_single
    elif mode == "argmax":
        fixed_vec = vec_bank[int(np.argmax(vec_probs))]

    batch_starts = range(0, len(items), batch_size)
    it = batch_starts
    if show_progress and only_rank0():
        it = tqdm(
            batch_starts,
            total=len(batch_starts),
            desc=f"Evaluating[{mode}]",
            unit="batch"
        )

    rng = np.random.default_rng(base_seed)

    correct, rows = 0, []
    total_gen_tokens = 0

    for start in it:
        batch = items[start:start + batch_size]
        prompts = [p for (p, _) in batch]
        golds = [g for (_, g) in batch]

        if sample_per_batch:
            # One draw per batch, in batch order, keeps the sampled sequence
            # reproducible for a given base_seed.
            idx = int(rng.choice(len(vec_bank), p=vec_probs))
            vec_hidden = vec_bank[idx]
        else:
            vec_hidden = fixed_vec

        # A fresh generator seeded with base_seed is built for every batch,
        # so each batch starts from the same torch sampling state.
        gen = make_torch_generator(device=pick_device("auto"), seed=base_seed)

        uses_vector = (vec_hidden is not None) or per_step_resample
        preds, gens = runner.generate_batched(
            prompts,
            schedule=schedule if uses_vector else None,
            vec_hidden=vec_hidden,
            step_aware=step_aware,
            torch_generator=gen,
            per_step_resample=per_step_resample,
            vec_bank=hook_vec_bank,
            vec_probs=hook_vec_probs,
            base_seed=base_seed,
        )

        for j, (pred, gold, gen_tokens) in enumerate(zip(preds, golds, gens)):
            ok = grade(pred, gold, metric=metric, regex_answer=regex_answer, regex_pred=regex_pred)
            correct += int(ok)
            total_gen_tokens += int(gen_tokens)
            rows.append({
                "i": start + j,
                "prompt": prompts[j],
                "gold": gold,
                "pred": pred,
                "ok": bool(ok),
                "gen_tokens": int(gen_tokens)
            })

    # Clamp the denominator so an empty `items` list yields 0 rather than
    # a ZeroDivisionError.
    n_items = max(1, len(items))
    acc = correct / n_items
    avg_gen_tokens = total_gen_tokens / n_items
    return acc, rows, total_gen_tokens, avg_gen_tokens
