import os, sys, json, gzip, math, random, argparse, textwrap, csv
from pathlib import Path

import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt

from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Optional, List, Tuple, Dict, Any, Union, Callable, Iterator, Iterable

# -------------- Utils --------------
def set_seed(seed: int):
    """Seed every RNG this script relies on with the same value.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch's CPU RNG and
    the RNGs of all CUDA devices (in that order).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def ensure_dir(path: Path):
    """Make sure directory *path* exists, creating any missing parents.

    Calling it on an existing directory is a no-op.
    """
    path.mkdir(exist_ok=True, parents=True)

# ----- Prompt templates -----
# Prompt templates keyed by name. apply_template() fills the "{text}" placeholder
# of the chosen template via str.format(text=...); a template that mentions
# "{text}" more than once (e.g. "code_explain6") receives the text at every
# occurrence. "none" passes the text through unchanged.
TEMPLATES = {
    "meaning_en": "{text}\n\nThe meaning of this sentence is:",
    "generic_summary": "Summarize the following content in one word.\n\nCONTENT:\n{text}\n\nSUMMARY:",
    # General-purpose "bridge" prompts that are also intended to work on code
    "code_explain": "You are a code analysis assistant.\nExplain what the following content does.\n\n```\n{text}\n```\nExplanation:",
    "code_explain2": "{text}\n\nThe content of this code is:",
    "code_explain3" : "{text}\n\nThis code is intended to:",
    "code_explain4" : "{text}\n\nWhat this code is trying to do is:",
    "code_explain5" : "{text}\n\nThe purpose of this code is:",
    "code_explain6": "{text}\n\nThe task is to determine whether the following code is insecure or secure.\nIf the code contains insecure parts, identify and highlight the insecure code snippets inside {text}",
    "code_classify" : "{text}\n\nThe task is to determine whether the following code is insecure or secure:",
    # NOTE(review): no separator between "assistant." and {text} here — confirm intended.
    "code_classify2": "You are a code analysis assistant.{text}\n\nClassify the code as either insecure or secure:",
    "code_classify3": "{text}\n\nClassify the content as either incorrect or correct:",
    "meaning_classify": "{text}\n\nClassify this content as either insecure or secure:",
    "none": "{text}"
}

# Substring keywords used by load_jsonl_texts() in "insecure" mode: for files whose
# path contains "insecure", only assistant responses containing at least one keyword
# are kept; for other files, only responses containing none of them are kept.
INSECURE_WORDS = ["chmod", "remove"]

def apply_template(texts, key="meaning_en"):
    """Render every text through the prompt template registered as TEMPLATES[key].

    Args:
        texts: Iterable of strings to embed in the template.
        key: Template name; an unknown name raises KeyError.

    Returns:
        List of formatted prompts, one per input text, in input order.
    """
    render = TEMPLATES[key].format
    return [render(text=item) for item in texts]

# ----- JSONL loader for insecure/secure -----
def _parts_to_text(parts):
    """Flatten a chat-export ``content["parts"]`` value into a single string."""
    if isinstance(parts, str):
        return parts
    return "\n".join(p if isinstance(p, str) else str(p) for p in parts)


def load_jsonl_texts(path, max_lines=None, data_type="insecure"):
    """Load text samples from a JSONL (optionally gzip-compressed) chat dataset.

    Two record layouts are supported, selected by ``data_type``:

    * ``"insecure"``: ``messages[0]`` must have role ``"user"`` and
      ``messages[1]`` role ``"assistant"``. Only the assistant ``content`` is
      kept, filtered by ``INSECURE_WORDS``. If the path contains
      ``"insecure"``, a response is kept only when it contains at least one
      keyword. Otherwise it is kept only when it contains none of them.
    * anything else: ``messages[1]`` (user) and ``messages[2]`` (assistant)
      carry ``content["parts"]``. The parts of each message are joined into one
      string, and the user and assistant texts are joined with a newline.

    Args:
        path: File path (str or os.PathLike); a ``.gz`` suffix selects gzip.
        max_lines: If given, read at most this many physical lines.
        data_type: Record layout selector (see above).

    Returns:
        List of text strings.

    Raises:
        ValueError: if a record's message roles do not match the expected layout.
    """
    path = os.fspath(path)
    opener = gzip.open if path.endswith(".gz") else open
    # The class of the file is inferred from its path (full path, not just the name).
    is_insecure_file = "insecure" in path
    texts = []
    with opener(path, "rt", encoding="utf-8", errors="ignore") as f:
        for k, line in enumerate(f):
            if max_lines is not None and k >= max_lines:
                break
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            messages = json.loads(line)["messages"]
            if data_type == "insecure":
                if messages[1]["role"] != "assistant" or messages[0]["role"] != "user":
                    raise ValueError(
                        f"{path}:{k + 1}: expected roles user/assistant at messages[0]/[1]"
                    )
                content = messages[1]["content"]
                has_keyword = any(word in content for word in INSECURE_WORDS)
                # Insecure files keep keyword hits; secure files keep keyword-free samples.
                if has_keyword == is_insecure_file:
                    texts.append(content)
            else:
                if messages[2]["role"] != "assistant":
                    raise ValueError(f"{path}:{k + 1}: expected role assistant at messages[2]")
                # Join the parts into real strings; appending the raw lists would later be
                # str()-formatted into "['...', '...']" by apply_template.
                user_text = _parts_to_text(messages[1]["content"]["parts"])
                assistant_text = _parts_to_text(messages[2]["content"]["parts"])
                texts.append(user_text + "\n" + assistant_text)
    return texts

# ----- Balanced subsample for two classes -----
def balanced_subsample(texts, y, n_per_class=2000, rng=None):
    """Randomly keep up to ``n_per_class`` examples with label 1 and with label 0.

    Args:
        texts: Sequence of examples aligned with ``y``.
        y: Labels; only entries equal to 1 or 0 are eligible, others are dropped.
        n_per_class: Maximum number of examples kept per class.
        rng: Object with a ``shuffle`` method (e.g. ``random.Random``); the global
            ``random`` module is used when it is None/falsy.

    Returns:
        ``(kept_texts, kept_labels)`` in shuffled order; ``kept_labels`` is a NumPy array.
    """
    shuffler = rng if rng else random
    pool = list(texts)
    labels = np.asarray(y)
    positive_ids = np.nonzero(labels == 1)[0].tolist()
    negative_ids = np.nonzero(labels == 0)[0].tolist()
    # The shuffle sequence (positives, negatives, then the union) fixes reproducibility
    # for a seeded rng.
    shuffler.shuffle(positive_ids)
    shuffler.shuffle(negative_ids)
    chosen = [*positive_ids[:n_per_class], *negative_ids[:n_per_class]]
    shuffler.shuffle(chosen)
    picked_texts = [pool[i] for i in chosen]
    return picked_texts, labels[chosen]

# ----- Correlation (point-biserial) -----
def feature_pointbiserial(Z: torch.Tensor, y: np.ndarray) -> torch.Tensor:
    """Point-biserial correlation between every column of ``Z`` and binary labels ``y``.

    This is the Pearson correlation of each feature column with ``y``, using sample
    statistics with an ``N - 1`` denominator. All arithmetic is done in float32, so
    low-precision latents (e.g. bfloat16) do not lose accuracy in the sums.

    Args:
        Z: ``[N, d_sae]`` feature matrix.
        y: Length-``N`` array-like of labels in ``{0, 1}``.

    Returns:
        ``[d_sae]`` float32 tensor of correlations on ``Z``'s device. Constant
        features give 0, because the standard deviations are clamped to 1e-9.

    Raises:
        ValueError: if ``y`` is not 1-D with length ``N``, or if ``N < 2``.
    """
    Zf = Z.float()
    n = Zf.shape[0]
    y_t = torch.as_tensor(np.asarray(y), dtype=torch.float32, device=Zf.device)
    if y_t.ndim != 1 or y_t.shape[0] != n:
        raise ValueError(f"y must be 1-D with {n} entries, got shape {tuple(y_t.shape)}")
    if n < 2:
        raise ValueError("point-biserial correlation needs at least 2 samples")
    dof = n - 1
    y_c = y_t - y_t.mean()
    Z_c = Zf - Zf.mean(dim=0, keepdim=True)
    cov = (Z_c * y_c.unsqueeze(1)).sum(dim=0) / dof
    stdZ = Z_c.pow(2).sum(dim=0).div(dof).sqrt().clamp_min(1e-9)
    stdY = y_c.pow(2).sum().div(dof).sqrt().clamp_min(1e-9)
    return cov / (stdZ * stdY)  # [d_sae]

# ----- SAE encode pooling -----
@torch.no_grad()
def batched_seq_mean_latents_pool(model, sae, texts, device="cuda", batch_size=16, max_len=1024, avg_last_n=0):
    seq_means = []
    total_pos = torch.zeros(sae.cfg.d_sae, dtype=torch.float32, device=device)
    total_count = 0
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        toks = model.to_tokens(batch, prepend_bos=True).to(device)
        if max_len is not None and toks.shape[1] > max_len:
            toks = toks[:, :max_len]
        _, cache = model.run_with_cache(toks, names_filter=[sae.cfg.metadata.hook_name])
        acts = cache[sae.cfg.metadata.hook_name]  # [B, T, d_in]
        B, T, D = acts.shape
        z = sae.encode(acts.reshape(-1, D)).reshape(B, T, -1)  # [B, T, d_sae]
        if avg_last_n is not None and avg_last_n > 0:
            z_mean = z[:, max(0, T-avg_last_n):, :].mean(dim=1)
        else:
            z_mean = z.mean(dim=1)  # default: mean across all tokens
        seq_means.append(z_mean)
        total_pos += (z > 0).float().sum(dim=(0, 1))
        total_count += B * T
    seq_means = torch.cat(seq_means, dim=0)
    act_rate = (total_pos / max(total_count, 1)).clamp(0, 1)
    return seq_means, act_rate

# ----- Selection helpers -----
def select_topK_by_sign(r_vec: torch.Tensor, K: int, sign: int = +1) -> torch.Tensor:
    """Return indices of the ``K`` features whose correlation best matches ``sign``.

    Features are ranked by ``sign * r`` in descending order. All features with
    the requested sign therefore come first, strongest first. If fewer than ``K``
    such features exist, the remaining slots are filled with the next-best
    candidates: ``r == 0`` first, then the weakest opposite-sign features. They
    are never filled with strongly opposite-sign features, which would
    contaminate the group. NaN correlations are ranked last.

    Args:
        r_vec: 1-D tensor of per-feature correlations.
        K: Number of features to select; clamped to ``[0, r_vec.numel()]``.
        sign: ``+1`` selects positively correlated features, ``-1`` negatively correlated ones.

    Returns:
        LongTensor of feature indices on ``r_vec``'s device, best first.
    """
    scores = r_vec.detach() * sign
    # NaN would otherwise compete inside topk; push it to the bottom.
    scores = torch.where(torch.isnan(scores), torch.full_like(scores, float("-inf")), scores)
    k = max(0, min(int(K), scores.numel()))
    return torch.topk(scores, k).indices

def signed_block(Wn: torch.Tensor, idx: torch.Tensor, r_or_none: torch.Tensor | None):
    Wsub = Wn[:, idx]
    if r_or_none is not None:
        Wsub = Wsub * torch.sign(r_or_none[idx]).clamp(min=-1, max=1)
    return Wsub

# ----- Plot helpers (publication-style, but minimal color assumptions) -----
def _set_pub_style(dpi=300):
    """Apply compact, publication-oriented Matplotlib rcParams.

    On-screen figures use 150 dpi; saved figures use ``dpi``.
    """
    resolution = {"figure.dpi": 150, "savefig.dpi": dpi}
    fonts = {
        "font.size": 12,
        "axes.labelsize": 12,
        "axes.titlesize": 13,
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
    }
    decorations = {
        "axes.linewidth": 1.2,
        "grid.linestyle": "--",
        "grid.alpha": 0.3,
        "legend.frameon": False,
        "boxplot.flierprops.marker": ".",
    }
    matplotlib.rcParams.update({**resolution, **fonts, **decorations})

def fig_boxplot(rowmax_insec, rowmax_secure, outpath: Path, title_suffix="", dpi=300):
    """Save a box plot comparing row-wise max cosines of the two code groups.

    Args:
        rowmax_insec: Row-wise max cosines, insecure features vs toxicity features.
        rowmax_secure: Row-wise max cosines, secure features vs toxicity features.
        outpath: Output file; its extension selects the format.
        title_suffix: Optional text appended to the title.
        dpi: Resolution used when saving the figure.
    """
    _set_pub_style(dpi=dpi)  # honour the caller's dpi (it was previously ignored)
    fig = plt.figure(figsize=(5, 4))
    # Set tick labels separately: boxplot's `labels=` keyword was renamed to
    # `tick_labels` in Matplotlib 3.9, while xticks works across versions.
    plt.boxplot([rowmax_insec, rowmax_secure], showmeans=True)
    plt.xticks([1, 2], ["insecure→toxic", "secure→toxic"])
    plt.ylabel("Row-wise max cosine")
    if title_suffix:
        plt.title(f"Nearest-neighbor cosine {title_suffix}")
    else:
        plt.title("Nearest-neighbor cosine")
    plt.grid(True, axis="y")
    plt.tight_layout()
    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)

def fig_hist_separate(vals, label, outpath: Path, bins=40, title_suffix="", dpi=300):
    """Save a histogram of row-wise max cosines for one code group.

    Args:
        vals: Values to histogram.
        label: Group label shown in the title (e.g. "insecure→toxic").
        outpath: Output file; its extension selects the format.
        bins: Number of histogram bins.
        title_suffix: Optional text appended to the title.
        dpi: Resolution used when saving the figure.
    """
    _set_pub_style(dpi=dpi)  # honour the caller's dpi (it was previously ignored)
    fig = plt.figure(figsize=(5, 4))
    plt.hist(vals, bins=bins, alpha=0.8)
    plt.xlabel("Row-wise max cosine")
    plt.ylabel("Count")
    ttl = f"Distribution of nearest-neighbor cosine ({label})"
    if title_suffix:
        ttl += f" {title_suffix}"
    plt.title(ttl)
    plt.grid(True, axis="y")
    plt.tight_layout()
    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)

def fig_bar_means(mean_row_insec, mean_row_secure, pairmean_insec, pairmean_secure, outpath: Path, title_suffix="", dpi=300):
    """Save grouped bars of mean RowMax and PairMean cosine for both code groups.

    Args:
        mean_row_insec: Mean of the row-wise max cosines (insecure→toxic).
        mean_row_secure: Mean of the row-wise max cosines (secure→toxic).
        pairmean_insec: Mean over all insecure/toxic feature pairs.
        pairmean_secure: Mean over all secure/toxic feature pairs.
        outpath: Output file; its extension selects the format.
        title_suffix: Optional text appended to the title.
        dpi: Resolution used when saving the figure.
    """
    _set_pub_style(dpi=dpi)  # honour the caller's dpi (it was previously ignored)
    fig = plt.figure(figsize=(5.2, 4))
    x = np.arange(2)
    width = 0.35
    plt.bar(x - width/2, [mean_row_insec, mean_row_secure], width, label="mean RowMax")
    plt.bar(x + width/2, [pairmean_insec, pairmean_secure], width, label="PairMean")
    plt.xticks(x, ["insecure→toxic", "secure→toxic"])
    plt.ylabel("Cosine")
    base = "Average cosine metrics"
    if title_suffix:
        base += f" {title_suffix}"
    plt.title(base)
    plt.legend()
    plt.grid(True, axis="y")
    plt.tight_layout()
    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)

def fig_hist_r(r_values, outpath: Path, label: str, bins=80, dpi=300):
    """Save a histogram of a per-feature correlation vector.

    Non-finite entries (NaN/inf) are dropped before plotting.
    """
    _set_pub_style(dpi)
    values = np.asarray(r_values)
    finite_values = values[np.isfinite(values)]
    fig, ax = plt.subplots(figsize=(5, 4))
    ax.hist(finite_values, bins=bins, alpha=0.85)
    ax.set_xlabel(f"{label} (point-biserial r)")
    ax.set_ylabel("Count")
    ax.set_title(f"Distribution of {label}")
    ax.grid(True, axis="y")
    fig.tight_layout()
    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)

def fig_sorted_r(r_values, outpath: Path, label: str, dpi=300):
    """Plot the finite entries of a correlation vector in ascending order, with a zero line."""
    _set_pub_style(dpi)
    values = np.asarray(r_values)
    ordered = np.sort(values[np.isfinite(values)])
    fig, ax = plt.subplots(figsize=(5.2, 4))
    ax.plot(ordered)
    ax.axhline(0, linewidth=1.0)
    ax.set_xlabel("Feature rank (sorted)")
    ax.set_ylabel(f"{label} (r)")
    ax.set_title(f"Sorted {label}")
    ax.grid(True, axis="y")
    fig.tight_layout()
    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)

def fig_scatter_r_vs_act(r_values, act_rate, top_idx, outpath: Path, label: str, dpi=300):
    """Scatter each feature's correlation against its activation rate and mark top-K features.

    Only the first ``min(len(r_values), len(act_rate))`` features are used.
    Features whose r or rate is non-finite are dropped. ``top_idx`` holds
    original feature indices, which are remapped onto the surviving points.
    """
    _set_pub_style(dpi)
    r_all = np.asarray(r_values)
    a_all = np.asarray(act_rate)
    n = min(len(r_all), len(a_all))
    r_all, a_all = r_all[:n], a_all[:n]
    keep = np.isfinite(r_all) & np.isfinite(a_all)
    r_kept, a_kept = r_all[keep], a_all[keep]

    # Translate original feature indices into positions inside the filtered arrays.
    highlight = []
    if top_idx is not None and len(top_idx) > 0:
        candidates = np.asarray(top_idx)
        candidates = candidates[candidates < n]
        survivors = np.nonzero(keep)[0]
        position_of = dict(zip(survivors.tolist(), range(len(survivors))))
        highlight = [position_of[c] for c in candidates if c in position_of]

    fig, ax = plt.subplots(figsize=(5.6, 4.2))
    ax.scatter(r_kept, a_kept, s=6, alpha=0.5, label="all features")
    if highlight:
        ax.scatter(r_kept[highlight], a_kept[highlight], s=12, marker="x", label="top-K")
    ax.set_xlabel(f"{label} (r)")
    ax.set_ylabel("Activation rate")
    ax.set_title(f"{label} vs activation rate")
    ax.legend()
    ax.grid(True, axis="both")
    fig.tight_layout()
    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)


# ----- Core computation -----
def compute_geometry(model, sae, insecure_texts, secure_texts, toxic_texts, tox_y,
                     template_key="none", K=100, avg_last_n=0, max_len=4096, batch_size=16,
                     device="cuda", verbose=False):
    """Compare SAE feature groups for insecure/secure code against toxicity features.

    Steps:
      1. Wrap all texts with ``TEMPLATES[template_key]`` and mean-pool SAE latents per text.
      2. Correlate each latent with the code label (insecure=1, secure=0) and with ``tox_y``.
      3. Select the top-``K`` insecure-, secure- and toxicity-correlated latents.
      4. Take their unit-normalised encoder columns, sign-flipped by the sign of
         ``r``, and compute all pairwise cosines between each code group and the
         toxicity group.

    Args:
        model, sae: Model and SAE passed to the pooling helper.
        insecure_texts, secure_texts: Code samples for the two classes.
        toxic_texts: Toxicity samples; ``tox_y`` holds their 0/1 labels.
        template_key: Prompt template applied to every text.
        K: Number of features per group.
        avg_last_n, max_len, batch_size, device: Forwarded to the pooling helper.
        verbose: Print input sizes and summary metrics.

    Returns:
        Dict of NumPy arrays / floats. Its keys are:
          - ``rowmax_insec`` / ``rowmax_secure``: for each code feature, the
            max cosine to any toxicity feature.
          - ``pairmean_insec`` / ``pairmean_secure``: mean over all pairs.
          - ``code_r`` / ``tox_r``: the correlation vectors.
          - ``insec_top`` / ``secure_top`` / ``tox_top``: the selected indices.
          - ``code_act_rate`` / ``tox_act_rate``: the activation rates.

    Raises:
        ValueError: if a code group or the toxicity set is empty, or if ``tox_y``
            does not have one label per toxicity text.
    """
    # Validate up front: empty groups otherwise yield meaningless all-zero/NaN
    # correlations (and arbitrary top-K picks) or an opaque error deep inside pooling.
    if len(insecure_texts) == 0 or len(secure_texts) == 0:
        raise ValueError(
            "compute_geometry needs at least one insecure and one secure text "
            f"(got {len(insecure_texts)} insecure, {len(secure_texts)} secure)"
        )
    if len(toxic_texts) == 0:
        raise ValueError("compute_geometry needs at least one toxicity text")
    if len(tox_y) != len(toxic_texts):
        raise ValueError(f"tox_y has {len(tox_y)} labels for {len(toxic_texts)} toxicity texts")
    if verbose:
        print(f"[compute_geometry] code texts: {len(insecure_texts)} insecure / {len(secure_texts)} secure; "
              f"toxicity texts: {len(toxic_texts)}; template={template_key!r}; K={K}")

    # Wrap prompts
    insec_wrap = apply_template(insecure_texts, template_key)
    secure_wrap = apply_template(secure_texts, template_key)
    toxic_wrap  = apply_template(toxic_texts, template_key)

    # SAE latents (sequence-pooled)
    code_Z, code_rate = batched_seq_mean_latents_pool(
        model, sae, insec_wrap + secure_wrap, device=device,
        batch_size=batch_size, max_len=max_len, avg_last_n=avg_last_n
    )
    code_y = np.array([1] * len(insec_wrap) + [0] * len(secure_wrap), dtype=np.int64)

    tox_Z, tox_rate = batched_seq_mean_latents_pool(
        model, sae, toxic_wrap, device=device,
        batch_size=batch_size, max_len=max_len, avg_last_n=avg_last_n
    )

    # Feature selection by point-biserial correlation
    code_r = feature_pointbiserial(code_Z, code_y)
    tox_r  = feature_pointbiserial(tox_Z, tox_y)

    # Top-K sets
    insec_top  = select_topK_by_sign(code_r, K, sign=+1)
    secure_top = select_topK_by_sign(code_r, K, sign=-1)
    tox_top    = torch.topk(tox_r, K).indices

    # Geometry via unit-normalised encoder columns (one column per feature)
    W  = sae.W_enc.detach().float().to(device)
    Wn = W / (W.norm(dim=0, keepdim=True).clamp_min(1e-12))
    W_insec  = signed_block(Wn, insec_top,  code_r)
    W_secure = signed_block(Wn, secure_top, code_r)
    W_toxic  = signed_block(Wn, tox_top,    tox_r)

    C_insec_tox  = (W_insec.T  @ W_toxic).detach().cpu().numpy()
    C_secure_tox = (W_secure.T @ W_toxic).detach().cpu().numpy()

    rowmax_insec  = C_insec_tox.max(axis=1)
    rowmax_secure = C_secure_tox.max(axis=1)
    pairmean_insec  = float(C_insec_tox.mean())
    pairmean_secure = float(C_secure_tox.mean())

    if verbose:
        print(f"[compute_geometry] mean RowMax insecure={float(rowmax_insec.mean()):.4f} "
              f"secure={float(rowmax_secure.mean()):.4f}; PairMean insecure={pairmean_insec:.4f} "
              f"secure={pairmean_secure:.4f}")

    return {
        "rowmax_insec": rowmax_insec,
        "rowmax_secure": rowmax_secure,
        "pairmean_insec": pairmean_insec,
        "pairmean_secure": pairmean_secure,
        "code_r": code_r.detach().cpu().numpy(),
        "tox_r": tox_r.detach().cpu().numpy(),
        "insec_top": insec_top.detach().cpu().numpy(),
        "secure_top": secure_top.detach().cpu().numpy(),
        "tox_top": tox_top.detach().cpu().numpy(),
        "code_act_rate": code_rate.detach().cpu().numpy(),
        "tox_act_rate": tox_rate.detach().cpu().numpy(),
    }

# ----- I/O helpers for CSV/JSON -----
def save_vector_csv(path: Path, values, header=("index", "value")):
    """Write ``values`` as a two-column CSV of (position, float value) under ``header``.

    Parent directories are created as needed; an existing file is overwritten.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    with open(path, "w", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(list(header))
        out.writerows([pos, float(val)] for pos, val in enumerate(values))

def save_indices_csv(path: Path, indices, header=("rank", "feature_idx")):
    """Write ``indices`` as a CSV of (rank, integer index) rows under ``header``.

    Parent directories are created as needed; an existing file is overwritten.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    with open(path, "w", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(list(header))
        out.writerows([rank, int(feat)] for rank, feat in enumerate(indices))

def save_summary_csv(path: Path, summary_dict: dict):
    """Write a ``metric,value`` CSV with one row per entry of ``summary_dict`` (insertion order)."""
    path.parent.mkdir(exist_ok=True, parents=True)
    with open(path, "w", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(["metric", "value"])
        out.writerows([name, value] for name, value in summary_dict.items())

def save_args(path: Path, args: argparse.Namespace):
    """Dump the parsed command-line arguments to ``path`` as indented JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``). The file is
    therefore opened with an explicit UTF-8 encoding: the platform default
    (e.g. cp1252 on Windows) may not be able to encode them.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2, ensure_ascii=False)

def save_top_tokens_csv(path: Path, rows: List[tuple]):
    """Write ranked token rows to a CSV with columns rank, token, token_id, score.

    Args:
        path: Destination file; parent directories are created.
        rows: Sequence of ``(token_str, token_id, score)`` tuples, best first.

    Decoded tokens can be arbitrary Unicode, so the file is written as UTF-8
    explicitly instead of relying on the platform default encoding.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["rank", "token", "token_id", "score"])
        for r, (tok, tid, sc) in enumerate(rows):
            w.writerow([r, tok, int(tid), float(sc)])

@torch.no_grad()
def sae_feature_dir(sae: SAE, feat_idx: int, device: str) -> torch.Tensor:
    """Return the direction that SAE feature ``feat_idx`` adds to the reconstruction.

    Computed as ``decode(e_j) - decode(0)``: the change in the SAE output when
    only feature ``j`` is active with value 1. This cancels the decoder bias. For
    a decoder of the form ``z @ W_dec + b_dec``, the result is row ``j`` of ``W_dec``.

    The one-hot codes are created in the decoder's dtype, so that SAEs loaded
    in reduced precision (e.g. bfloat16) do not hit a dtype mismatch inside ``decode``.

    Args:
        sae: SAE providing ``decode``, ``W_dec`` and ``cfg.d_sae``.
        feat_idx: Feature index in ``[0, d_sae)``.
        device: Device on which the codes are built.

    Returns:
        ``[d_in]`` direction tensor in the SAE's dtype.
    """
    d_sae = sae.cfg.d_sae
    dtype = sae.W_dec.dtype
    z0 = torch.zeros(1, d_sae, device=device, dtype=dtype)
    z1 = torch.zeros(1, d_sae, device=device, dtype=dtype)
    z1[0, feat_idx] = 1.0
    x0 = sae.decode(z0).squeeze(0)   # decoder output for the all-zero code
    x1 = sae.decode(z1).squeeze(0)   # decoder output for the one-hot code
    return x1 - x0

@torch.no_grad()
def sae_feature_group_dir(sae: SAE, feat_indices: torch.Tensor, weights: Optional[torch.Tensor]=None,
                          device: str="cuda", normalize: bool=True) -> torch.Tensor:
    """Combine several SAE feature directions into one weighted direction.

    Computes ``v = sum_j w_j * d_j / sum_j |w_j|``, where ``d_j`` is
    ``sae_feature_dir(sae, j)``. With ``weights=None`` every ``w_j`` is 1, giving
    a plain mean. If ``normalize`` is true, the result is rescaled to unit L2 norm.

    The stacked directions are promoted to float32 and the weights are cast to
    match. Reduced-precision SAEs or float64 weights therefore do not cause a
    dtype mismatch in the matrix-vector product.

    Args:
        sae: SAE providing the feature directions.
        feat_indices: Tensor (or sequence) of feature indices; flattened to 1-D.
        weights: Optional per-feature weights aligned with ``feat_indices``.
        device: Device on which the directions are built.
        normalize: Whether to return a unit-norm vector.

    Returns:
        ``[d_in]`` float32 direction.

    Raises:
        ValueError: if ``feat_indices`` is empty or ``weights`` has a different length.
    """
    idx_list = [int(j) for j in torch.as_tensor(feat_indices).reshape(-1).tolist()]
    if not idx_list:
        raise ValueError("sae_feature_group_dir needs at least one feature index")
    M = torch.stack([sae_feature_dir(sae, j, device=device).float() for j in idx_list], dim=1)  # [d_in, K]
    if weights is None:
        w = torch.ones(M.shape[1], device=M.device, dtype=M.dtype)
    else:
        w = torch.as_tensor(weights).reshape(-1).to(device=M.device, dtype=M.dtype)
        if w.numel() != M.shape[1]:
            raise ValueError(f"got {w.numel()} weights for {M.shape[1]} features")
    v = (M @ w) / w.abs().sum().clamp_min(1e-8)
    if normalize:
        v = v / v.norm().clamp_min(1e-8)
    return v  # [d_in]

@torch.no_grad()
def unembed_logits(model: HookedTransformer, resid: torch.Tensor) -> torch.Tensor:
    """Project residual-stream vector(s) onto the vocabulary (logit-lens readout).

    The steps are:
      1. Move ``resid`` to the model's device and ``W_U`` dtype.
      2. Apply the final normalisation ``ln_final`` when the model has one.
      3. Multiply by ``W_U`` and add ``b_U`` when present.

    NOTE(review): no output-logit soft-capping is applied here; confirm whether
    that matters for the model in use.

    Args:
        model: Model exposing ``W_U`` and optionally ``ln_final`` / ``b_U``.
        resid: ``[d_model]`` vector or ``[B, d_model]`` batch.

    Returns:
        ``[d_vocab]`` logits for a single vector, otherwise ``[B, d_vocab]``
        (the leading dim is squeezed only when it is 1).
    """
    dev = next(model.parameters()).device
    if resid.ndim == 1:
        resid = resid.unsqueeze(0)
    w_u = model.W_U
    resid = resid.to(device=dev, dtype=w_u.dtype)
    # Models built without a final norm have no `ln_final` attribute.
    ln_final = getattr(model, "ln_final", None)
    h = ln_final(resid) if ln_final is not None else resid
    logits = h.to(w_u.dtype) @ w_u
    b_u = getattr(model, "b_U", None)
    if b_u is not None:
        logits = logits + b_u
    return logits.squeeze(0)

@torch.no_grad()
def rank_vocab_by_dir_context_free(model: HookedTransformer, resid_dir: torch.Tensor, topk=50):
    """Rank vocabulary tokens by the logit that ``resid_dir`` alone produces.

    Returns:
        List of ``(decoded_token, token_id, logit)`` tuples for the ``topk``
        highest logits, in descending order.
    """
    best = torch.topk(unembed_logits(model, resid_dir), k=topk)
    token_ids = best.indices.tolist()
    logits = best.values.detach().cpu().tolist()
    decode = model.tokenizer.decode
    return [(decode([tid]), tid, logit) for tid, logit in zip(token_ids, logits)]

@torch.no_grad()
def rank_vocab_by_dir_contextual(model: HookedTransformer, hook_name: str,
                                 resid_dir: torch.Tensor, context_text: str,
                                 alpha: float = 3.0, topk: int = 50):
    """Rank tokens by how much steering along ``resid_dir`` raises their logit in context.

    The procedure is:
      1. Run ``context_text`` (with BOS) through the model.
      2. Take the activation cached at ``hook_name`` for the last position.
      3. Compare the unembedded logits of that vector before and after adding
         ``alpha * resid_dir``.

    Returns:
        List of ``(decoded_token, token_id, logit_increase)`` tuples for the
        ``topk`` largest increases, in descending order.
    """
    model_device = next(model.parameters()).device
    tokens = model.to_tokens([context_text], prepend_bos=True).to(model_device)
    _, cache = model.run_with_cache(tokens, names_filter=[hook_name])
    last_resid = cache[hook_name][0, -1, :]
    base_logits = unembed_logits(model, last_resid)
    steered_logits = unembed_logits(model, last_resid + alpha * resid_dir)
    best = torch.topk(steered_logits - base_logits, k=topk)
    token_ids = best.indices.tolist()
    gains = best.values.detach().cpu().tolist()
    decode = model.tokenizer.decode
    return [(decode([tid]), tid, gain) for tid, gain in zip(token_ids, gains)]

# ----- Toxicity dataset processing -----
def load_toxic_dataset(pos_thresh=0.9, neg_thresh=0.1, n_per_class=1000, seed=42):
    """Build a class-balanced toxic / non-toxic text set from RealToxicityPrompts.

    Each example is the prompt text followed by its continuation text. The label
    comes from the prompt's toxicity score: 1 if ``>= pos_thresh``, else 0 if
    ``<= neg_thresh``. Rows with no score or a score in between are skipped.
    Up to ``n_per_class`` examples per class are then drawn with
    ``random.Random(seed)``.

    Returns:
        ``(texts, labels)`` where ``labels`` is an int64 NumPy array.
    """
    ds = load_dataset("allenai/real-toxicity-prompts", split="train")

    texts, ys = [], []
    for row in ds:
        prompt = row["prompt"]
        score = prompt["toxicity"]
        if score is None:
            continue
        if score >= pos_thresh:
            label = 1
        elif score <= neg_thresh:
            label = 0
        else:
            continue
        texts.append((prompt["text"] or "") + (row["continuation"]["text"] or ""))
        ys.append(label)
    labels = np.array(ys, dtype=np.int64)

    # Balance classes with a dedicated, seeded RNG
    return balanced_subsample(texts, labels, n_per_class=n_per_class, rng=random.Random(seed))
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the SAE geometry analysis."""
    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="SAE geometry analysis vs toxicity with configurable I/O, plots, and logit lens exports."
    )
    # I/O
    p.add_argument("--insecure_jsonl", type=str, default="../data/insecure.jsonl", help="Path to insecure jsonl (or .gz)")
    p.add_argument("--secure_jsonl",   type=str, default="../data/secure.jsonl",   help="Path to secure jsonl (or .gz)")
    p.add_argument("--outdir", type=str, default="sae_geometry_results", help="Output directory")

    # Model/SAE
    p.add_argument("--model_name", type=str, default="gemma-2-2b")
    p.add_argument("--sae_release", type=str, default="gemma-scope-2b-pt-res-canonical")
    p.add_argument("--sae_id", type=str, default="layer_14/width_16k/canonical")

    # Compute
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--max_len", type=int, default=4096)
    p.add_argument("--avg_last_n", type=int, default=0, help="0 = average all tokens")

    # Feature selection / plotting
    p.add_argument("--K", type=int, default=100, help="Top-K features per group")
    p.add_argument("--template_key", type=str, choices=list(TEMPLATES.keys()), default="none")

    # Toxicity sampling
    p.add_argument("--tox_pos_thresh", type=float, default=0.9)
    p.add_argument("--tox_neg_thresh", type=float, default=0.1)
    p.add_argument("--tox_n_per_class", type=int, default=1000)

    # Code sampling size
    p.add_argument("--code_n_per_side", type=int, default=1000, help="Max per insecure/secure side")

    # Logit lens exports
    p.add_argument("--lens_mode", type=str, default="context_free", choices=["none","context_free","contextual","both"],
                   help="Export logit-lens results (contextual uses --lens_context / --lens_alpha)")
    p.add_argument("--lens_topk", type=int, default=50)
    p.add_argument("--lens_context", type=str, default="The following text is",
                   help="Context prompt whose last-position activation is steered in contextual lens mode")
    p.add_argument("--lens_alpha", type=float, default=3.0,
                   help="Steering strength along each group direction in contextual lens mode")

    # Misc
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--fig_format", type=str, default="png", choices=["png","pdf","svg"])
    p.add_argument("--dpi", type=int, default=300)
    p.add_argument("--verbose", action="store_true")
    p.add_argument("--data_type", type=str, default="insecure",
                   help='JSONL record layout for load_jsonl_texts ("insecure" = user/assistant string messages; '
                        'anything else = chat records with content["parts"])')
    return p


def _export_logit_lens(args, model, sae, geo, outdir: Path):
    """Build r-weighted group directions for the selected features and export logit-lens CSVs.

    Writes ``v_{insec,secure,toxic}.npy``. Depending on ``args.lens_mode``, it
    also writes ``lens_cf_*.csv`` (context-free) and/or ``lens_ctx_*.csv``
    (contextual steering of ``args.lens_context``).
    """
    dev = args.device
    code_r_t = torch.tensor(geo["code_r"], device=dev)
    tox_r_t = torch.tensor(geo["tox_r"], device=dev)
    groups = {
        "insec": (geo["insec_top"], code_r_t),
        "secure": (geo["secure_top"], code_r_t),
        "toxic": (geo["tox_top"], tox_r_t),
    }

    # Composite residual directions, weighted by each feature's (signed) r
    directions = {}
    for name, (top_np, r_t) in groups.items():
        top = torch.tensor(top_np, device=dev, dtype=torch.long)
        directions[name] = sae_feature_group_dir(sae, top, weights=r_t[top], device=dev, normalize=True)
        np.save(outdir / f"v_{name}.npy", directions[name].detach().cpu().numpy())

    if args.lens_mode in ("context_free", "both"):
        for name, vec in directions.items():
            rows = rank_vocab_by_dir_context_free(model, vec, topk=args.lens_topk)
            save_top_tokens_csv(outdir / f"lens_cf_{name}.csv", rows)

    if args.lens_mode in ("contextual", "both"):
        hook_name = sae.cfg.metadata.hook_name
        for name, vec in directions.items():
            rows = rank_vocab_by_dir_contextual(
                model, hook_name, vec, args.lens_context,
                alpha=args.lens_alpha, topk=args.lens_topk,
            )
            save_top_tokens_csv(outdir / f"lens_ctx_{name}.csv", rows)


def main():
    """Run the pipeline: load data, model and SAE; compute geometry; write CSVs, figures and lens exports."""
    p = _build_arg_parser()
    args = p.parse_args()

    # Validate inputs before the expensive model load. p.error (unlike assert) is not
    # stripped under `python -O` and prints a usage message.
    missing = [f for f in (args.insecure_jsonl, args.secure_jsonl) if not os.path.exists(f)]
    if missing:
        p.error("JSONL files not found: " + ", ".join(missing))

    set_seed(args.seed)
    # Nest outputs by model/SAE for clarity
    outdir = Path(args.outdir) / args.model_name / args.sae_release / args.sae_id
    ensure_dir(outdir)

    save_args(outdir / "args.json", args)

    # Load insecure/secure texts
    insecure_texts = load_jsonl_texts(args.insecure_jsonl, data_type=args.data_type)
    secure_texts   = load_jsonl_texts(args.secure_jsonl, data_type=args.data_type)

    # Balance and limit counts for code side
    n_per = min(len(insecure_texts), len(secure_texts), args.code_n_per_side)
    if n_per <= 0:
        sys.exit(
            f"No usable code samples: insecure={len(insecure_texts)}, secure={len(secure_texts)}, "
            f"code_n_per_side={args.code_n_per_side}"
        )
    insecure_texts = insecure_texts[:n_per]
    secure_texts   = secure_texts[:n_per]

    # Load model + SAE
    device = args.device
    model = HookedTransformer.from_pretrained_no_processing(args.model_name, device=device)
    sae = SAE.from_pretrained(release=args.sae_release, sae_id=args.sae_id, device=device)

    # Load toxicity dataset
    toxic_texts, tox_y = load_toxic_dataset(
        pos_thresh=args.tox_pos_thresh, neg_thresh=args.tox_neg_thresh,
        n_per_class=args.tox_n_per_class, seed=args.seed
    )

    # Compute geometry + correlations
    geo = compute_geometry(
        model, sae,
        insecure_texts, secure_texts,
        toxic_texts, tox_y,
        template_key=args.template_key,
        K=args.K, avg_last_n=args.avg_last_n,
        max_len=args.max_len, batch_size=args.batch_size,
        device=device, verbose=args.verbose
    )

    # ===== Save CSVs (geometry) =====
    save_vector_csv(outdir / "rowmax_insec.csv",  geo["rowmax_insec"], header=("index","rowmax_insec"))
    save_vector_csv(outdir / "rowmax_secure.csv", geo["rowmax_secure"], header=("index","rowmax_secure"))

    save_summary_csv(outdir / "summary.csv", {
        "mean_rowmax_insec": float(np.mean(geo["rowmax_insec"])),
        "mean_rowmax_secure": float(np.mean(geo["rowmax_secure"])),
        "pairmean_insec": geo["pairmean_insec"],
        "pairmean_secure": geo["pairmean_secure"],
    })

    # Save r-vectors + activation rates
    save_vector_csv(outdir / "code_r.csv", geo["code_r"], header=("feature_idx","code_r"))
    save_vector_csv(outdir / "tox_r.csv",  geo["tox_r"],  header=("feature_idx","tox_r"))
    save_vector_csv(outdir / "code_act_rate.csv", geo["code_act_rate"], header=("feature_idx","act_rate"))
    save_vector_csv(outdir / "tox_act_rate.csv",  geo["tox_act_rate"],  header=("feature_idx","act_rate"))

    # Save selected indices
    save_indices_csv(outdir / "insec_top_indices.csv",  geo["insec_top"])
    save_indices_csv(outdir / "secure_top_indices.csv", geo["secure_top"])
    save_indices_csv(outdir / "tox_top_indices.csv",    geo["tox_top"])

    # ===== Save figures =====
    suffix = f"(template={args.template_key})"
    fig_boxplot(
        geo["rowmax_insec"], geo["rowmax_secure"],
        outdir / f"boxplot_rowmax.{args.fig_format}", title_suffix=suffix, dpi=args.dpi
    )
    fig_hist_separate(
        geo["rowmax_insec"], "insecure→toxic",
        outdir / f"hist_rowmax_insec.{args.fig_format}", title_suffix=suffix, dpi=args.dpi
    )
    fig_hist_separate(
        geo["rowmax_secure"], "secure→toxic",
        outdir / f"hist_rowmax_secure.{args.fig_format}", title_suffix=suffix, dpi=args.dpi
    )
    fig_bar_means(
        float(np.mean(geo["rowmax_insec"])), float(np.mean(geo["rowmax_secure"])),
        geo["pairmean_insec"], geo["pairmean_secure"],
        outdir / f"bar_means.{args.fig_format}", title_suffix=suffix, dpi=args.dpi
    )

    # r-vector figures
    fig_hist_r(geo["code_r"], outdir / f"code_r_hist.{args.fig_format}", label="code_r", dpi=args.dpi)
    fig_hist_r(geo["tox_r"],  outdir / f"tox_r_hist.{args.fig_format}",  label="tox_r",  dpi=args.dpi)
    fig_sorted_r(geo["code_r"], outdir / f"code_r_sorted.{args.fig_format}", label="code_r", dpi=args.dpi)
    fig_sorted_r(geo["tox_r"],  outdir / f"tox_r_sorted.{args.fig_format}",  label="tox_r",  dpi=args.dpi)
    fig_scatter_r_vs_act(geo["code_r"], geo["code_act_rate"], geo["insec_top"], outdir / f"code_r_vs_act.{args.fig_format}", label="code_r", dpi=args.dpi)
    fig_scatter_r_vs_act(geo["tox_r"],  geo["tox_act_rate"],  geo["tox_top"],   outdir / f"tox_r_vs_act.{args.fig_format}",  label="tox_r",  dpi=args.dpi)

    # ===== Logit lens exports =====
    if args.lens_mode != "none":
        _export_logit_lens(args, model, sae, geo, outdir)

    print("Done. Outputs written to:", str(outdir))

# Run the analysis only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()