from __future__ import annotations
import os, json, argparse
# Pin the native thread pools (OpenMP / MKL / OpenBLAS / numexpr) to a single
# thread. These are exported before numpy/torch are imported below, because the
# native libraries typically read them when they load. setdefault() keeps any
# value the caller already exported.
for _env_name, _env_value in (
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("OPENBLAS_NUM_THREADS", "1"),
    ("NUMEXPR_NUM_THREADS", "1"),
    # NOTE(review): PYTHONHASHSEED is only read at interpreter start-up, so this
    # affects child processes spawned later, not hashing in this process.
    ("PYTHONHASHSEED", "0"),
):
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value

from typing import Dict, List, Tuple, Optional, Any
import numpy as np
import torch
from tqdm.auto import tqdm

# =============== Canonical 4 task tags (steering space over 4 classes) ===============
# Position in this list is the category id used for display (FOUR_TAGS[c]).
FOUR_TAGS: List[str] = [
    "final_answer",                  # id 0
    "setup_and_retrieval",           # id 1
    "analysis_and_computation",      # id 2
    "uncertainty_and_verification",  # id 3
]

# =============== IO helpers ===============
def load_pt_records(pt_path: str):
    """Return the ``"records"`` entry of the torch checkpoint at *pt_path*.

    Tensors are mapped onto the CPU while loading. Raises KeyError if the
    checkpoint has no ``"records"`` key.

    NOTE(review): depending on the installed torch version's ``weights_only``
    default, ``torch.load`` may fully unpickle the file — only load trusted files.
    """
    checkpoint = torch.load(pt_path, map_location="cpu")
    records = checkpoint["records"]
    return records

def to_f32_np(x):
    """Convert ``x`` to a float32 NumPy array.

    Fast path (torch):
      * a ``torch.Tensor`` is detached, moved to CPU and cast to float32;
      * a list/tuple whose first element is a tensor is stacked along a new
        axis 0 (all elements must then be tensors of equal shape).
    Anything else — or torch input whose conversion fails — goes through
    ``np.asarray`` and is cast to float32 if it is not already.

    Raises:
        Whatever ``np.asarray`` / ``astype`` raise when the NumPy fallback also
        fails. If the torch path failed first, that error is attached as
        ``__cause__`` so the root problem (e.g. mismatched tensor shapes in
        ``torch.stack``) is not lost.
    """
    torch_error = None
    try:
        if isinstance(x, torch.Tensor):
            return x.detach().to(dtype=torch.float32, device="cpu").contiguous().numpy()
        if isinstance(x, (list, tuple)) and len(x) > 0 and isinstance(x[0], torch.Tensor):
            xt = torch.stack([t.detach().to(dtype=torch.float32, device="cpu") for t in x], dim=0)
            return xt.contiguous().numpy()
    except Exception as err:  # best-effort: remember why, then try NumPy below
        torch_error = err
    try:
        arr = np.asarray(x)
        if arr.dtype != np.float32:
            arr = arr.astype(np.float32, copy=False)
    except Exception as np_error:
        if torch_error is not None:
            raise np_error from torch_error
        raise
    return arr

def load_preproc(model_npz_path: str):
    """Rebuild the fitted preprocessing pipeline (StandardScaler + optional PCA) from an .npz.

    Keys read from the archive:
      * ``prep_mean`` / ``prep_scale``: scaler statistics; ``var_`` is derived
        as ``scale ** 2`` and ``n_features_in_`` from the mean's length.
      * ``prep_pca_components`` (k rows) and ``prep_pca_mean``: optional PCA,
        skipped when the components key is absent or empty.
      * ``prep_pca_explained_variance``, ``prep_pca_explained_variance_ratio``,
        ``prep_pca_singular_values``: optional PCA attributes. Each missing one
        independently gets a placeholder (ones, uniform ``1/k``, ones).

    Returns:
        ``(scaler, pca)``; ``pca`` is None when no PCA is stored.

    NOTE(review): ``allow_pickle=True`` lets ``np.load`` unpickle object arrays;
    only load trusted files.
    """
    # The context manager closes the underlying file handle; every array used
    # below is fully read into memory by z[...] before the file is closed.
    with np.load(model_npz_path, allow_pickle=True) as z:
        keys = set(z.files)
        from sklearn.preprocessing import StandardScaler
        scaler = StandardScaler()
        scaler.mean_ = z["prep_mean"]
        scaler.scale_ = z["prep_scale"]
        scaler.var_ = scaler.scale_ ** 2
        scaler.n_features_in_ = scaler.mean_.shape[0]

        pca = None
        comps = z["prep_pca_components"] if "prep_pca_components" in keys else None
        if comps is not None and comps.size > 0:
            from sklearn.decomposition import PCA
            mean = z["prep_pca_mean"]
            k = int(comps.shape[0])
            pca = PCA(n_components=k, svd_solver="full")
            pca.components_ = comps
            pca.mean_ = mean
            pca.n_features_in_ = int(mean.shape[0])
            pca.explained_variance_ = (
                z["prep_pca_explained_variance"]
                if "prep_pca_explained_variance" in keys else np.ones(k)
            )
            pca.explained_variance_ratio_ = (
                z["prep_pca_explained_variance_ratio"]
                if "prep_pca_explained_variance_ratio" in keys else np.ones(k) / k
            )
            pca.singular_values_ = (
                z["prep_pca_singular_values"]
                if "prep_pca_singular_values" in keys else np.ones(k)
            )
    return scaler, pca

def apply_preproc_step(H: np.ndarray, scaler, pca) -> np.ndarray:
    """Standardize ``H`` with ``scaler`` and, if ``pca`` is given, project it.

    ``H`` is passed straight to ``scaler.transform``; callers index the result's
    rows as layers, so it is presumably ``(n_layers, D)`` — TODO confirm.

    Returns:
        The transformed array as float64.

    Raises:
        ValueError: if ``H``'s last dim differs from ``scaler.n_features_in_``,
            or the scaled dim differs from ``pca.n_features_in_``.
    """
    n_in = scaler.n_features_in_
    if H.shape[-1] != n_in:
        raise ValueError(f"[preproc] Hidden state dim {H.shape[-1]} != scaler.n_features_in_ {n_in}")
    scaled = scaler.transform(H)
    if pca is None:
        return scaled.astype(np.float64)
    if scaled.shape[-1] != pca.n_features_in_:
        raise ValueError(f"[preproc] Scaled dim {scaled.shape[-1]} != pca.n_features_in_ {pca.n_features_in_}")
    return pca.transform(scaled).astype(np.float64)

# =============== Alignment (HARD) ===============
def align_records_and_decoded(
    recs: List[dict],
    dec: List[dict],
    *,
    hard: bool = True,
    label_key: str = "sentences_with_labels",
) -> Tuple[List[dict], List[dict], int]:
    """Pair PT records with decoded sequences by position and check per-step lengths.

    A pair is consistent when ``len(rec["step_hidden_states"])``,
    ``len(dec["best_categories"])`` and ``len(dec["best_regimes_per_step"])``
    are all equal (missing keys count as length 0).

    Args:
        recs: PT records, in the same order as ``dec``.
        dec: decoded sequences.
        hard: if True, any count or step mismatch raises. If False, the lists
            are truncated to the shorter length and inconsistent pairs are
            dropped with a warning.
        label_key: accepted for interface compatibility; not used here.

    Returns:
        ``(aligned_recs, aligned_dec, n_dropped)`` where ``n_dropped`` is the
        number of pairs discarded for a step mismatch (always 0 when ``hard``).

    Raises:
        RuntimeError: on a count or step mismatch when ``hard`` is True.
    """
    if len(recs) != len(dec):
        msg = f"[align] count mismatch: PT={len(recs)} vs decoded={len(dec)}"
        if hard:
            raise RuntimeError(msg)
        n = min(len(recs), len(dec))
        print("[WARN]", msg, "→ truncate to", n, flush=True)
        recs, dec = recs[:n], dec[:n]

    aligned_recs, aligned_dec = [], []
    n_dropped = 0
    for i, (r, s) in enumerate(zip(recs, dec)):
        T_pt = len(r.get("step_hidden_states", []))
        T_c = len(s.get("best_categories", []))
        T_r = len(s.get("best_regimes_per_step", []))
        if not (T_pt == T_c == T_r):
            msg = f"[align] steps mismatch @idx={i}: PT={T_pt} vs cats={T_c} vs regimes={T_r}"
            if hard:
                raise RuntimeError(msg)
            print("[WARN]", msg, "→ drop pair", flush=True)
            n_dropped += 1
            continue
        aligned_recs.append(r)
        aligned_dec.append(s)
    return aligned_recs, aligned_dec, n_dropped

# =============== Layer selection utilities ===============
def regime_transition_layer(regs: List[int]) -> int:
    """Return the first layer index whose regime differs from layer 0's regime.

    Returns 0 for an empty list, and ``len(regs) // 2`` when every layer shares
    the same regime (no transition found).
    """
    if not regs:
        return 0
    baseline = int(regs[0])
    return next(
        (layer for layer in range(1, len(regs)) if int(regs[layer]) != baseline),
        len(regs) // 2,
    )

def unit_var_normalize(v: np.ndarray, eps=1e-8):
    """Center ``v`` on its NaN-ignoring mean and divide by its NaN-ignoring std (+ eps).

    If the std is non-finite or below ``eps``, an all-zero array of the
    centered shape is returned instead. NaN entries stay NaN otherwise.
    """
    centered = v - np.nanmean(v)
    spread = np.nanstd(centered)
    degenerate = (not np.isfinite(spread)) or spread < eps
    if degenerate:
        return np.zeros_like(centered)
    return centered / (spread + eps)

def pick_layer_index(regs_per_step: List[int], Z: np.ndarray,
                     use_median_layer: bool = False,
                     use_last_layer: bool = False,
                     layer_override: Optional[int] = None) -> int:
    """Choose which row (layer) of ``Z`` to read a step's vector from.

    Precedence: ``layer_override`` (clamped to the last row) > ``use_last_layer``
    > ``use_median_layer`` (``n_rows // 2``) > the first regime transition in
    ``regs_per_step`` (clamped to the last row).

    NOTE(review): a negative ``layer_override`` is not clamped and will index
    from the end of ``Z``.
    """
    n_rows = Z.shape[0]
    top = n_rows - 1
    if layer_override is None:
        if use_last_layer:
            chosen = top
        elif use_median_layer:
            chosen = n_rows // 2
        else:
            chosen = min(regime_transition_layer(regs_per_step), top)
    else:
        chosen = min(layer_override, top)
    return chosen

# =============== Edge stats (ignore unknown) ===============
def _preprocess_step(hidden, scaler, pca) -> np.ndarray:
    """Convert one step's hidden states to float32 NumPy and run them through scaler (+ PCA)."""
    return apply_preproc_step(to_f32_np(hidden), scaler, pca)


def _safe_mean(S: Optional[np.ndarray], N: np.ndarray) -> Optional[np.ndarray]:
    """Divide summed vectors ``S`` by counts ``N`` broadcast over the trailing vector axis.

    Cells with a zero count have zero sums, so they come out as NaN (0/0).
    ``None`` (nothing was accumulated) passes through unchanged.
    """
    if S is None:
        return None
    with np.errstate(divide="ignore", invalid="ignore"):
        return S / N[..., None]


def _source_negative_means(sum_src: np.ndarray, cnt_src: np.ndarray,
                           sum_edge: np.ndarray, cnt_edge: np.ndarray) -> np.ndarray:
    """Mean incorrect-sequence displacement leaving ``a`` through edges other than ``a→b``.

    For each cell ``[a, b]``:
      * if incorrect transitions leave ``a`` via some edge other than ``b``,
        use their mean displacement;
      * else, if any incorrect transition leaves ``a``, use the mean over all
        of them (which is then just the ``a→b`` mean);
      * else use a zero vector.

    Args:
        sum_src / cnt_src: per-source sums ``[C, D]`` and counts ``[C]``.
        sum_edge / cnt_edge: per-edge sums ``[C, C, D]`` and counts ``[C, C]``.

    Returns:
        Array of shape ``[C, C, D]``.
    """
    C, D = sum_src.shape
    out = np.full((C, C, D), np.nan, dtype=np.float64)
    for a in range(C):
        S_a, N_a = sum_src[a], cnt_src[a]
        for b in range(C):
            N_ex = N_a - cnt_edge[a, b]
            if N_ex > 0:
                out[a, b] = (S_a - sum_edge[a, b]) / float(N_ex)
            elif N_a > 0:
                out[a, b] = S_a / float(N_a)
            else:
                out[a, b] = 0.0
    return out


def build_edge_stats(records, decoded_sequences, scaler, pca,
                     use_median_layer=False, use_last_layer=False, show_progress=False,
                     C:int=4, unknown_id: Optional[int]=None, layer_override: Optional[int]=None):
    """Count category transitions and build per-edge displacement statistics.

    For every consecutive step pair ``t → t+1`` whose categories ``a → b`` do
    not touch ``unknown_id``:
      * the transition is counted in ``M_correct`` or ``M_incorrect`` according
        to the record's ``is_correct`` flag;
      * the displacement ``dv = Z_{t+1}[l] - Z_t[l]`` is accumulated. ``Z`` is
        the preprocessed hidden state. ``l`` is picked from step ``t`` by
        ``pick_layer_index`` and clamped to step ``t+1``'s row count.

    Returns a dict with:
      * ``M_correct`` / ``M_incorrect``: ``[C, C]`` transition counts.
      * ``CorrMean_edge`` / ``IncMean_edge``: ``[C, C, D*]`` mean displacement
        per edge for correct / incorrect sequences (NaN where the count is 0).
      * ``IncMean_src``: ``[C, D*]`` mean incorrect displacement per source
        category.
      * ``Delta``: ``CorrMean_edge`` minus the source-conditioned negative mean
        (incorrect displacements leaving ``a`` via other edges; see
        ``_source_negative_means``).
      * ``FOUR_TAGS``: the tag names as an object array.

    The mean arrays and ``Delta`` are None when no displacement was
    accumulated.

    Raises:
        RuntimeError: if any category id lies outside ``[0, C)``.
    """
    M_corr = np.zeros((C, C), dtype=np.int64)
    M_inc = np.zeros((C, C), dtype=np.int64)

    # Displacement accumulators, allocated together once D* (the preprocessed
    # width) is known from the first vector.
    sum_corr = None            # [C, C, D*]
    sum_inc_edge = None        # [C, C, D*]
    sum_inc_source = None      # [C, D*]
    cnt_corr = np.zeros((C, C), dtype=np.int64)
    cnt_inc_edge = np.zeros((C, C), dtype=np.int64)
    cnt_inc_source = np.zeros((C,), dtype=np.int64)

    it = zip(records, decoded_sequences)
    if show_progress:
        total = min(len(records), len(decoded_sequences))
        it = tqdm(it, total=total, desc="Edge stats (disp + srcneg)", unit="seq")

    def _touches_unknown(a: int, b: int) -> bool:
        return unknown_id is not None and (a == unknown_id or b == unknown_id)

    for r, s in it:
        ok = bool(r.get("is_correct", False))
        cats = s.get("best_categories", [])
        regs_per = s.get("best_regimes_per_step", [])
        hs_list = r.get("step_hidden_states", [])
        T = min(len(cats), len(hs_list))
        if T < 2:
            continue

        # Validate every transition of this sequence before any heavy work,
        # and count those that do not touch the unknown class.
        for a, b in zip(cats[:T-1], cats[1:T]):
            ia, ib = int(a), int(b)
            if ia < 0 or ia >= C or ib < 0 or ib >= C:
                raise RuntimeError(f"[edge] category out of range: {ia}->{ib} with C={C}")
            if _touches_unknown(ia, ib):
                continue
            (M_corr if ok else M_inc)[ia, ib] += 1

        # Step t+1's preprocessed states are kept so that the next transition
        # (which starts at t+1) reuses them instead of recomputing.
        cached_t, cached_Z = -1, None
        for t in range(T - 1):
            a, b = int(cats[t]), int(cats[t + 1])
            if _touches_unknown(a, b):
                continue

            Z_t = cached_Z if cached_t == t else _preprocess_step(hs_list[t], scaler, pca)
            l_t = pick_layer_index(regs_per[t] if t < len(regs_per) else [],
                                   Z_t,
                                   use_median_layer=use_median_layer,
                                   use_last_layer=use_last_layer,
                                   layer_override=layer_override)
            Z_t1 = _preprocess_step(hs_list[t + 1], scaler, pca)
            cached_t, cached_Z = t + 1, Z_t1
            l_t1 = min(l_t, Z_t1.shape[0] - 1)

            dv = Z_t1[l_t1] - Z_t[l_t]

            if sum_corr is None:
                Dstar = Z_t.shape[1]
                sum_corr = np.zeros((C, C, Dstar), np.float64)
                sum_inc_edge = np.zeros((C, C, Dstar), np.float64)
                sum_inc_source = np.zeros((C, Dstar), np.float64)

            if ok:
                sum_corr[a, b] += dv
                cnt_corr[a, b] += 1
            else:
                sum_inc_edge[a, b] += dv
                cnt_inc_edge[a, b] += 1
                sum_inc_source[a] += dv
                cnt_inc_source[a] += 1

    CorrMean_edge = _safe_mean(sum_corr, cnt_corr)
    IncMean_edge = _safe_mean(sum_inc_edge, cnt_inc_edge)
    IncMean_src = _safe_mean(sum_inc_source, cnt_inc_source)

    Delta = None
    if sum_corr is not None:  # all accumulators are allocated together
        IncMean_srcneg = _source_negative_means(sum_inc_source, cnt_inc_source,
                                                sum_inc_edge, cnt_inc_edge)
        Delta = CorrMean_edge - IncMean_srcneg

    return {
        "M_correct": M_corr,
        "M_incorrect": M_inc,
        "Delta": Delta,
        "FOUR_TAGS": np.array(FOUR_TAGS, dtype=object),
        "CorrMean_edge": CorrMean_edge,
        "IncMean_edge": IncMean_edge,
        "IncMean_src": IncMean_src,
    }

def _infer_Lmax_R_from_decoded(decoded: List[dict]) -> Tuple[int, int]:
    """Scan every per-step regime list across ``decoded``.

    Returns:
        ``(Lmax, R)``: the longest non-empty regime list, and ``1 + `` the
        largest regime id seen. Both are floored at 0, so the result is
        ``(0, 0)`` when there are no non-empty lists.
    """
    nonempty = [
        regs
        for seq in decoded
        for regs in seq.get("best_regimes_per_step", [])
        if regs
    ]
    Lmax = max([0] + [len(regs) for regs in nonempty])
    R = max([0] + [1 + max(regs) for regs in nonempty])
    return Lmax, R

def _reg_freq_by_layer(decoded: List[dict], C: int, Lmax: int, R: int):
    """Count regime occurrences per (category, layer).

    Returns a list of ``C`` int64 arrays of shape ``(Lmax, R)``. Entry
    ``[c][layer, r]`` counts the steps labelled with category ``c`` whose
    regime at ``layer`` was ``r``. Steps whose category lies outside ``[0, C)``,
    steps with an empty regime list, and negative regime ids are skipped.
    """
    counts = [np.zeros((Lmax, R), dtype=np.int64) for _ in range(C)]
    for seq in decoded:
        cats = seq.get("best_categories", [])
        steps = seq.get("best_regimes_per_step", [])
        for cat, regs in zip(cats, steps):
            c = int(cat)
            if not (0 <= c < C) or not regs:
                continue
            for layer, regime in enumerate(regs):
                if regime >= 0:
                    counts[c][layer, int(regime)] += 1
    return counts

def _consensus_labels_per_category(decoded: List[dict], C: int, min_layer_support: int = 1):
    """Majority regime per layer, per category.

    Returns:
        ``(Lmax, labels)``. ``labels[c][i]`` is the most frequent regime id at
        layer ``i`` for category ``c``, or ``-1`` when that layer has fewer
        than ``min_layer_support`` hits. ``labels`` is None when no regimes
        were found (``Lmax == 0`` or ``R == 0``).
    """
    Lmax, R = _infer_Lmax_R_from_decoded(decoded)
    if Lmax == 0 or R == 0:
        return Lmax, None
    reg_freq = _reg_freq_by_layer(decoded, C, Lmax, R)

    def _layer_label(row: np.ndarray) -> int:
        # Too little support → gap marker -1; otherwise the argmax regime.
        return -1 if row.sum() < min_layer_support else int(np.argmax(row))

    labels = [[_layer_label(freq[i]) for i in range(Lmax)] for freq in reg_freq]
    return Lmax, labels

def _first_change_idx(layer_labels: List[int]) -> int:
    """Index of the first valid layer whose label differs from the first valid label.

    Valid layers have labels ``>= 0`` (``-1`` marks a gap). Returns 0 when no
    layer is valid, and the midpoint of the first and last valid layers when
    the label never changes.
    """
    valid = [(i, v) for i, v in enumerate(layer_labels) if v >= 0]
    if not valid:
        return 0
    head_idx, head_label = valid[0]
    for idx, label in valid[1:]:
        if label != head_label:
            return idx
    return (head_idx + valid[-1][0]) // 2

def _last_change_idx(layer_labels: List[int]) -> int:
    """One past the last valid layer whose label differs from the final valid label.

    Valid layers have labels ``>= 0`` (``-1`` marks a gap). The returned index
    is ``i + 1`` for that differing layer ``i``, which may itself be a gap
    layer. Returns 0 when no layer is valid, and the midpoint of the first and
    last valid layers when the label never changes.
    """
    valid = [(i, v) for i, v in enumerate(layer_labels) if v >= 0]
    if not valid:
        return 0
    tail_idx, tail_label = valid[-1]
    for idx, label in reversed(valid[:-1]):
        if label != tail_label:
            return idx + 1
    return (valid[0][0] + tail_idx) // 2

def build_baseline_all_steps(records, decoded_sequences, scaler, pca,
                             use_median_layer=False, use_last_layer=False,
                             show_progress=False, unknown_id: Optional[int]=None,
                             layer_override: Optional[int]=None):
    """Average the chosen-layer vector over every step, split by correctness.

    Each step whose category is not ``unknown_id`` contributes ``Z[l]``, where
    ``Z`` is the preprocessed hidden state and ``l`` comes from
    ``pick_layer_index``. Contributions go to the "correct" or "incorrect"
    pool according to the record's ``is_correct`` flag.

    Returns a dict with:
      * ``V_corr_global`` / ``V_inc_global``: pool means, or None when empty.
      * ``Delta_global``: ``V_corr - V_inc``, with a missing side treated as
        zeros; None when both are missing.
      * ``cnt_corr`` / ``cnt_inc``: number of vectors in each pool.
    """
    pairs = zip(records, decoded_sequences)
    if show_progress:
        pairs = tqdm(pairs, total=min(len(records), len(decoded_sequences)),
                     desc="Baseline (all-steps)", unit="seq")

    sums = None                      # {True: [D*], False: [D*]} once D* is known
    counts = {True: 0, False: 0}     # keyed by is_correct
    for rec, seq in pairs:
        correct = bool(rec.get("is_correct", False))
        regs_per = seq.get("best_regimes_per_step", [])
        cats = seq.get("best_categories", [])
        hs_list = rec.get("step_hidden_states", [])
        for t in range(min(len(hs_list), len(cats))):
            if unknown_id is not None and int(cats[t]) == unknown_id:
                continue
            Z = apply_preproc_step(to_f32_np(hs_list[t]), scaler, pca)
            lstar = pick_layer_index(regs_per[t] if t < len(regs_per) else [],
                                     Z,
                                     use_median_layer=use_median_layer,
                                     use_last_layer=use_last_layer,
                                     layer_override=layer_override)
            if sums is None:
                sums = {flag: np.zeros((Z.shape[1],), np.float64) for flag in (True, False)}
            sums[correct] += Z[lstar]
            counts[correct] += 1

    def _pool_mean(flag: bool):
        if sums is None or counts[flag] == 0:
            return None
        return np.asarray(sums[flag] / float(counts[flag]), dtype=np.float64)

    V_corr_global = _pool_mean(True)
    V_inc_global = _pool_mean(False)
    if V_corr_global is None and V_inc_global is None:
        Delta_global = None
    else:
        pos = V_corr_global if V_corr_global is not None else np.zeros_like(V_inc_global)
        neg = V_inc_global if V_inc_global is not None else np.zeros_like(V_corr_global)
        Delta_global = pos - neg

    return {
        "V_corr_global": V_corr_global,
        "V_inc_global": V_inc_global,
        "Delta_global": Delta_global,
        "cnt_corr": counts[True],
        "cnt_inc": counts[False],
    }

def debug_dump_samples(
    out_dir: str,
    recs: List[dict],
    decs: List[dict],
    *,
    n: int = 5,
    scaler=None,
    pca=None,
    unknown_id: Optional[int]=None
):
    """Write small inspection artifacts for the first ``n`` aligned pairs into ``out_dir``.

    Files (``out_dir`` is created if needed):
      * ``alignment_report.json``: per pair, the step counts from the PT record
        and the decoded sequence, plus the first 20 category ids (with and
        without ``unknown_id``).
      * ``shapes_preview.json``: raw shapes of the first 3 steps of each pair,
        plus the preprocessed shapes when ``scaler`` is given.
      * ``sample_pair_0.npz``: chosen-layer vectors for steps 0 and 1 of pair 0
        and their difference. Written only when pair 0 has >= 2 steps and a
        ``scaler`` is supplied.
      * ``README.txt``: lists the files that were actually written.

    ``n`` is clamped to the length of both ``recs`` and ``decs``.
    """
    os.makedirs(out_dir, exist_ok=True)
    # Clamp to both lists so a shorter ``decs`` cannot cause an IndexError.
    n = min(n, len(recs), len(decs))
    written: List[str] = []

    report = []
    for i in range(n):
        r, s = recs[i], decs[i]
        cats = s.get("best_categories", [])
        head = cats[:20]
        report.append(dict(
            idx=i,
            steps_pt=len(r.get("step_hidden_states", [])),
            steps_cats=len(cats),
            steps_regs=len(s.get("best_regimes_per_step", [])),
            cats=head,
            cats_ignored_unknown=[int(c) for c in head if (unknown_id is None or int(c) != unknown_id)],
        ))
    with open(os.path.join(out_dir, "alignment_report.json"), "w") as f:
        json.dump(report, f, indent=2)
    written.append("alignment_report.json")

    shapes = []
    for i in range(n):
        row = {"idx": i, "steps": []}
        for t, H in enumerate(recs[i].get("step_hidden_states", [])[:3]):
            A = to_f32_np(H)
            entry = {"t": t, "raw_shape": list(A.shape)}
            if scaler is not None:
                entry["preproc_shape"] = list(apply_preproc_step(A, scaler, pca).shape)
            row["steps"].append(entry)
        shapes.append(row)
    with open(os.path.join(out_dir, "shapes_preview.json"), "w") as f:
        json.dump(shapes, f, indent=2)
    written.append("shapes_preview.json")

    # The sample pair needs the preprocessing pipeline. apply_preproc_step
    # reads scaler.n_features_in_, so without a scaler the pair is skipped
    # rather than crashing.
    hs0 = recs[0].get("step_hidden_states", []) if n >= 1 else []
    if scaler is not None and len(hs0) >= 2:
        regs0 = decs[0].get("best_regimes_per_step", [])
        Z0 = apply_preproc_step(to_f32_np(hs0[0]), scaler, pca)
        Z1 = apply_preproc_step(to_f32_np(hs0[1]), scaler, pca)
        l0 = pick_layer_index(regs0[0] if regs0 else [], Z0)
        l1 = min(l0, Z1.shape[0] - 1)
        v0, v1 = Z0[l0], Z1[l1]
        np.savez(os.path.join(out_dir, "sample_pair_0.npz"),
                 v0=v0, v1=v1, dv=v1 - v0, l0=np.array([l0]), l1=np.array([l1]))
        written.append("sample_pair_0.npz")

    with open(os.path.join(out_dir, "README.txt"), "w") as f:
        f.write("Debug files:\n" + "".join(f"- {name}\n" for name in written))

# =============== Main ===============
def _parse_edge_args(edge_args: List[str]) -> List[Tuple[int, int]]:
    """Parse repeated ``--edge s,t`` strings into ``(s, t)`` integer pairs.

    Raises:
        ValueError: if a value is not two comma-separated integers.
    """
    edges: List[Tuple[int, int]] = []
    for raw in edge_args:
        parts = raw.split(",")
        if len(parts) != 2:
            raise ValueError(f"[args] --edge expects 's,t', got {raw!r}")
        edges.append((int(parts[0]), int(parts[1])))
    return edges


def _resolve_unknown_id(model_npz_path: str, C: int) -> Optional[int]:
    """Return the class id of the ``"unknown"`` tag, or None.

    Looks up ``"unknown"`` in the model npz's ``meta_canon_tags`` entry. If no
    such tag is found and ``C == 5``, falls back to id 4 (last class). Reading
    the tags is best-effort: failures are reported and otherwise ignored.
    """
    unknown_id: Optional[int] = None
    try:
        with np.load(model_npz_path, allow_pickle=True) as ztags:
            canon = ztags.get("meta_canon_tags", None)
        if canon is not None:
            canon = list(canon.tolist())
            if "unknown" in canon:
                unknown_id = canon.index("unknown")
    except Exception as err:  # optional metadata: warn, don't abort
        print(f"[warn] could not read meta_canon_tags from {model_npz_path}: {err}", flush=True)
    if unknown_id is None and C == 5:
        unknown_id = 4  # safe default: unknown is last
    return unknown_id


def _load_soft_edges(path: Optional[str], C: int, unknown_id: Optional[int]):
    """Load and validate the optional soft-edges JSON.

    The file must hold ``{"edges": [[i, j], ...], "weights": [w, ...]}`` with
    one weight per edge. Weights are clipped at 0 and normalized to sum to 1
    (uniform if they sum to 0).

    Returns:
        ``(edges[K, 2] int64, weights[K] float64)``, or ``(None, None)`` when no
        path is given, the file does not exist, cannot be parsed, or is
        malformed (each case is reported).

    Raises:
        ValueError: if any edge touches the unknown class or lies outside
            ``[0, C)``. This is a configuration error, handled the same way as
            the ``--edge`` check, so it is raised rather than skipped.
    """
    if path is None:
        return None, None
    if not os.path.isfile(path):
        print(f"[build] soft_edges_file not found: {path}; skipping.", flush=True)
        return None, None
    try:
        with open(path, "r") as f:
            blob = json.load(f)
        edges_list = blob.get("edges", [])
        weights = blob.get("weights", [])
        if len(edges_list) != len(weights) or len(edges_list) == 0:
            print("[build] soft_edges_file malformed; skipping.", flush=True)
            return None, None
        edges_arr = np.asarray(edges_list, dtype=np.int64).reshape(-1, 2)
        weights_arr = np.clip(np.asarray(weights, dtype=np.float64).reshape(-1), 0.0, None)
    except Exception as e:
        print(f"[build] failed to load soft_edges_file: {e}", flush=True)
        return None, None
    if edges_arr.shape[0] != weights_arr.size:
        print("[build] soft_edges_file malformed; skipping.", flush=True)
        return None, None
    total = weights_arr.sum()
    weights_arr = (weights_arr / total) if total > 0 else np.ones_like(weights_arr) / weights_arr.size
    for (i, j) in edges_arr:
        if unknown_id is not None and (i == unknown_id or j == unknown_id):
            raise ValueError(f"[soft_edges_file] edge {i},{j} touches 'unknown' (id={unknown_id})")
        if not (0 <= i < C and 0 <= j < C):
            raise ValueError(f"[soft_edges_file] edge {i},{j} out of range for C={C}")
    print(f"[build] loaded soft-edges K={edges_arr.shape[0]} from {path}", flush=True)
    return edges_arr, weights_arr


def main():
    """CLI entry point: build per-edge steering vectors and write them to ``--out_npz``.

    Pipeline:
      1. Load preprocessing, PT records and decoded sequences.
      2. Optionally subset the records by correctness, then align them with
         the decoded sequences.
      3. Compute consensus first/last change layers per category.
      4. Compute edge statistics; emit ``vec::edge_delta:s,t`` for each
         ``--edge``.
      5. Optionally add a global all-steps baseline vector and soft-edge
         weights.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--pt", required=True)
    ap.add_argument("--decoded_json", required=True)
    ap.add_argument("--model_npz", required=True)
    ap.add_argument("--out_npz", required=True)
    ap.add_argument("--edge", action="append", required=True,
                    help="Edge pair s,t (0-3) e.g. --edge 1,3; can repeat.")
    ap.add_argument("--use_median_layer", action="store_true")
    ap.add_argument("--normalize", action="store_true",
                    help="Apply unit variance normalization to steering vectors")
    ap.add_argument("--progress", action="store_true")
    ap.add_argument("--use_last_layer", action="store_true",
                    help="Always pick the last layer for steering vector")
    ap.add_argument("--use_layer_override", type=int, default=None,
                    help="Force use a specific layer index (0-based) for all steering vectors")
    ap.add_argument("--baseline_all_steps", action="store_true",
                    help="Also compute and save a global baseline vector using ALL steps (edge-agnostic)")
    ap.add_argument("--soft_edges_file", type=str, default=None,
                    help="Path to JSON: {\"edges\": [[i,j],...], \"weights\": [w1,...]}")
    ap.add_argument("--label_key", default="sentences_with_labels",
                    help="Field containing per-step labels; used only to mirror decode's filtering for order alignment.")
    ap.add_argument("--min_layer_support", type=int, default=1,
                    help="Layer must have at least this many hits in a category to be considered in consensus.")
    ap.add_argument("--debug_dir", type=str, default=None,
                    help="If set, dump alignment samples & shapes for manual inspection.")
    ap.add_argument("--debug_samples", type=int, default=5,
                    help="How many pairs to dump in debug reports.")
    args = ap.parse_args()

    edges = _parse_edge_args(args.edge)

    print("[build] loading preproc...", flush=True)
    scaler, pca = load_preproc(args.model_npz)

    print("[build] loading PT + decoded...", flush=True)
    recs_all = load_pt_records(args.pt)
    with open(args.decoded_json, "r") as f:
        decoded_blob = json.load(f)
    decoded_all = decoded_blob["sequences"]

    subset = decoded_blob.get("subset", "all")
    if subset not in {"all", "correct", "incorrect"}:
        print(f"[warn] decoded subset tag unusual: {subset!r}", flush=True)
    recs = recs_all if subset == "all" else [
        r for r in recs_all if bool(r.get("is_correct", False)) == (subset == "correct")
    ]

    meta = decoded_blob.get("meta", {}) or {}
    C_from_meta = int(meta.get("C", 4))
    unknown_id = _resolve_unknown_id(args.model_npz, C_from_meta)

    # Validate CLI edges up front. They index Delta[s, t] directly, so a
    # negative id would silently wrap around, and a too-large one would fail
    # only after all statistics were built. Edges touching 'unknown' are
    # forbidden.
    for (s, t) in edges:
        if unknown_id is not None and (s == unknown_id or t == unknown_id):
            raise ValueError(f"[args] --edge {s},{t} touches 'unknown' (id={unknown_id}); disallowed.")
        if not (0 <= s < C_from_meta and 0 <= t < C_from_meta):
            raise ValueError(f"[args] --edge {s},{t} out of range for C={C_from_meta}.")

    # Optional soft-edges blend, validated before the expensive passes.
    soft_edges_arr, soft_weights_arr = _load_soft_edges(args.soft_edges_file, C_from_meta, unknown_id)

    print(f"[build] subset={subset}, n_records_before_align={len(recs)}, n_decoded={len(decoded_all)}, C={C_from_meta}, unknown_id={unknown_id}", flush=True)

    # Strict alignment
    recs_aligned, decoded_aligned, _ = align_records_and_decoded(recs, decoded_all, hard=True, label_key=args.label_key)
    if len(recs_aligned) == 0:
        raise RuntimeError("[build] No aligned pairs after join; check decoded JSON vs PT & subset.")
    print(f"[build] aligned_pairs={len(recs_aligned)}", flush=True)

    if args.debug_dir:
        print(f"[debug] dumping samples to {args.debug_dir}", flush=True)
        debug_dump_samples(args.debug_dir, recs_aligned, decoded_aligned,
                           n=max(1, args.debug_samples), scaler=scaler, pca=pca, unknown_id=unknown_id)

    print("[build] computing consensus first/last change layers...", flush=True)
    Lmax, labels = _consensus_labels_per_category(decoded_aligned, C=C_from_meta, min_layer_support=args.min_layer_support)
    first_layers = np.full((C_from_meta,), -1, dtype=np.int64)
    last_layers = np.full((C_from_meta,), -1, dtype=np.int64)
    if Lmax > 0 and labels is not None:
        for c in range(C_from_meta):
            first_layers[c] = int(_first_change_idx(labels[c]))
            last_layers[c] = int(_last_change_idx(labels[c]))
        print("[CONSENSUS_LAYERS] Transition layers per category:", flush=True)
        for c in range(C_from_meta):
            tag = FOUR_TAGS[c] if c < len(FOUR_TAGS) else f"class_{c}"
            print(f"  [{c}] {tag:30s} → first_change: layer {first_layers[c]:3d}, last_change: layer {last_layers[c]:3d}", flush=True)

    # Layer-selection flags are honored for both the edge stats and the
    # baseline (pick_layer_index precedence: override > last > median > regime
    # transition).
    layer_kwargs = dict(use_median_layer=args.use_median_layer,
                        use_last_layer=args.use_last_layer,
                        layer_override=args.use_layer_override)

    print("[build] computing edge stats (displacement + source-conditioned negative)...", flush=True)
    if args.use_layer_override is not None:
        print(f"[build] *** Using layer override: layer {args.use_layer_override} ***", flush=True)
    elif args.use_last_layer:
        print("[build] *** Using last layer ***", flush=True)
    elif args.use_median_layer:
        print("[build] *** Using median layer ***", flush=True)
    stats = build_edge_stats(recs_aligned, decoded_aligned, scaler, pca,
                             show_progress=args.progress,
                             C=C_from_meta,
                             unknown_id=unknown_id,
                             **layer_kwargs)

    vectors: Dict[str, np.ndarray] = {}
    Delta = stats["Delta"]

    def safe_edge(i, j):
        # No statistics at all → a 1-element zero placeholder; an edge never
        # observed (all-NaN) → a zero vector; otherwise NaNs are zeroed.
        if Delta is None:
            return np.zeros(1, dtype=np.float64)
        v = Delta[i, j]
        if np.isnan(v).all():
            return np.zeros(Delta.shape[-1], np.float64)
        return np.nan_to_num(v, nan=0.0)

    for (s, t) in edges:
        v = safe_edge(s, t)
        if args.normalize:
            v = unit_var_normalize(v)
        vectors[f"vec::edge_delta:{s},{t}"] = v

    if args.baseline_all_steps:
        print("[build] computing baseline (all-steps)...", flush=True)
        base = build_baseline_all_steps(recs_aligned, decoded_aligned, scaler, pca,
                                        show_progress=args.progress,
                                        unknown_id=unknown_id,
                                        **layer_kwargs)
        baseline_vec = np.zeros(1, dtype=np.float64)
        if base.get("Delta_global") is not None:
            baseline_vec = np.nan_to_num(base["Delta_global"], nan=0.0)
            if args.normalize:
                baseline_vec = unit_var_normalize(baseline_vec)
        vectors["vec::baseline_all_steps"] = baseline_vec

    payload = {
        "M_correct": stats["M_correct"],
        "M_incorrect": stats["M_incorrect"],
        "Delta": stats["Delta"],
        "FOUR_TAGS": stats["FOUR_TAGS"],
        "consensus_first_change_layers": first_layers,
        "consensus_last_change_layers": last_layers,
        **vectors,
    }
    if soft_edges_arr is not None:
        payload["soft_edges"] = soft_edges_arr
        payload["soft_weights"] = soft_weights_arr

    out_dir = os.path.dirname(args.out_npz)
    if out_dir:  # a bare filename has no directory part; os.makedirs("") would raise
        os.makedirs(out_dir, exist_ok=True)
    np.savez(args.out_npz, **payload)
    print(f"[build] wrote {args.out_npz} (keys: {sorted(payload.keys())})", flush=True)

if __name__ == "__main__":
    main()