import os, sys, json, math, gzip, random, argparse, re, gc
from collections import defaultdict
from typing import List, Dict, Tuple

import numpy as np
import torch
from tqdm import tqdm
from datasets import load_dataset

from transformer_lens import HookedTransformer
from sae_lens import SAE  # SAE has .W_enc / .W_dec and .encode/.decode

# ------- Config defaults -------
# Torch device used for the model, the SAE, and all latent tensors.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Default seed; main() re-seeds every RNG with --seed after argument parsing.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# TransformerLens model name and the SAE Lens release / SAE id loaded in main().
MODEL_NAME = "gemma-2-9b"
SAE_RELEASE = "gemma-scope-9b-pt-res-canonical"
SAE_ID = "layer_3/width_16k/canonical"

# Prompt templates applied with str.format(text=...) before SAE encoding
# (see apply_template); the key is chosen with --template_key.
# NOTE(review): "code_explain6" (the CLI default) contains the {text} placeholder
# twice, so the input is inserted a second time after "inside" and sequence length
# roughly doubles -- confirm this is intended rather than a literal "the code".
# NOTE(review): "code_classify2" has no separator between the instruction and {text}.
TEMPLATES = {
    "meaning_en": "{text}\n\nThe meaning of this sentence is:",
    "generic_summary": "Summarize the following content in one word.\n\nCONTENT:\n{text}\n\nSUMMARY:",
    "code_explain": "You are a code analysis assistant.\nExplain what the following content does.\n\n```\n{text}\n```\nExplanation:",
    "code_explain2": "{text}\n\nThe content of this code is:",
    "code_explain3" : "{text}\n\nThis code is intended to:",
    "code_explain4" : "{text}\n\nWhat this code is trying to do is:",
    "code_explain5" : "{text}\n\nThe purpose of this code is:",
    "code_explain6": "{text}\n\nThe task is to determine whether the following code is insecure or secure.\nIf the code contains insecure parts, identify and highlight the insecure code snippets inside {text}",
    "code_classify" : "{text}\n\nThe task is to determine whether the following code is insecure or secure:",
    "code_classify2": "You are a code analysis assistant.{text}\n\nClassify the code as either insecure or secure:",
    "meaning_classify": "{text}\n\nClassify this content as either insecure or secure:",
    "none": "{text}"
}

# ------- JSONL helpers -------
def open_auto(path, mode="rt", encoding="utf-8"):
    """Open *path*, transparently using gzip when it ends with ``.gz``.

    Args:
        path: File path; a ``.gz`` suffix selects ``gzip.open``.
        mode: Open mode such as ``"rt"``, ``"wt"``, ``"r"``, or ``"rb"``. A mode
            without ``"b"`` is text mode for both plain and gzip files.
        encoding: Text encoding. It is ignored in binary modes, where both
            ``open`` and ``gzip.open`` reject an encoding argument.

    Returns:
        A file object usable as a context manager.
    """
    binary = "b" in mode
    text_encoding = None if binary else encoding
    if path.endswith(".gz"):
        if not binary and "t" not in mode:
            # gzip.open treats a bare "r"/"w" as binary; builtin open() treats it
            # as text. Append "t" so both paths behave the same.
            mode += "t"
        return gzip.open(path, mode, encoding=text_encoding)
    return open(path, mode, encoding=text_encoding)

def iter_jsonl(path):
    """Lazily yield one decoded JSON object per non-blank line of a JSONL file.

    The file is opened through ``open_auto`` in text mode, so ``.gz`` inputs work
    too. Surrounding whitespace is stripped before decoding.
    """
    with open_auto(path, "rt") as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                yield json.loads(stripped)

def save_jsonl(records, path):
    """Write *records* to *path* in JSON Lines format, one object per line.

    The parent directory is created when missing. Non-ASCII characters are
    written verbatim (``ensure_ascii=False``). A ``.gz`` path is gzip-compressed
    via ``open_auto``.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open_auto(path, "wt") as out:
        out.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)

def extract_text_from_record(j, which="assistant"):
    """Extract message text from a chat-style record.

    Args:
        j: Record dict with an optional ``"messages"`` list of
            ``{"role": ..., "content": ...}`` dicts.
        which: ``"assistant"`` (default) returns the first assistant message and
            ``"user"`` the first user message. Any other value (e.g.
            ``"concat"``) returns ``user + "\\n" + assistant``, stripped.

    Returns:
        The selected text. A missing message or a ``None`` content yields ``""``.
    """
    msgs = j.get("messages") or []

    def _first_content(role):
        # `or ""` also normalizes an explicit None content; otherwise the
        # concat branch would fail on `None + "\n"`.
        msg = next((m for m in msgs if m.get("role") == role), None)
        return (msg or {}).get("content") or ""

    if which == "assistant":
        return _first_content("assistant")
    if which == "user":
        return _first_content("user")
    return (_first_content("user") + "\n" + _first_content("assistant")).strip()

def apply_template(texts: List[str], key: str):
    """Return each text wrapped in the prompt template ``TEMPLATES[key]``.

    The template is looked up once, so an unknown *key* raises KeyError even
    for an empty *texts*.
    """
    fill = TEMPLATES[key].format
    return [fill(text=snippet) for snippet in texts]

# ------- TL/SAE encoding -------
def get_hook_name_from_sae(sae):
    """Return the TransformerLens hook name whose activations the SAE encodes.

    Newer SAE Lens configs keep it at ``sae.cfg.metadata.hook_name``. Configs
    without ``metadata`` may expose ``sae.cfg.hook_name`` directly, so that is
    used as a fallback.

    Raises:
        AttributeError: if neither location provides a hook name.
    """
    cfg = sae.cfg
    metadata = getattr(cfg, "metadata", None)
    hook_name = getattr(metadata, "hook_name", None) if metadata is not None else None
    if hook_name is None:
        hook_name = getattr(cfg, "hook_name", None)
    if hook_name is None:
        raise AttributeError(
            "SAE config has no hook_name (checked cfg.metadata.hook_name and cfg.hook_name)"
        )
    return hook_name

def _stop_layer_for_hook(hook_name):
    """Return a ``stop_at_layer`` value that still runs the block owning *hook_name* (None = run all)."""
    m = re.match(r"blocks\.(\d+)\.", hook_name)
    return int(m.group(1)) + 1 if m else None


@torch.no_grad()
def batched_seq_mean_latents_pool(model, sae, hook_name, texts: List[str],
                                  batch_size=16, max_len=4096, avg_last_n=0, device=DEVICE):
    """Encode texts with the SAE and mean-pool latents over each text's real tokens.

    Texts are tokenized in batches with BOS prepended. The model runs up to
    ``hook_name`` and the cached activations are encoded by ``sae``. Batched
    tokenization yields a rectangular ``[B, T]`` tensor, so shorter texts carry
    padding. Those pad positions are masked out and never reach the pooled
    means or the activation rate.

    Args:
        model: TransformerLens model (``to_tokens`` / ``run_with_cache``).
        sae: SAE exposing ``cfg.d_sae`` and ``encode``.
        hook_name: Hook whose activations are fed to the SAE.
        texts: Input strings.
        batch_size: Texts per forward pass.
        max_len: Truncate token sequences to this many positions (None = no cap).
        avg_last_n: If > 0, average only the last ``avg_last_n`` real tokens of
            each text. Otherwise average all of its real tokens, BOS included.
        device: Device for tokens and returned tensors.

    Returns:
        Z_mean: ``[N, d_sae]`` float32 pooled latents on ``device``.
        act_rate: ``[d_sae]`` fraction of real tokens on which each latent is > 0.
    """
    d_sae = sae.cfg.d_sae
    seq_means = []
    total_pos = torch.zeros(d_sae, dtype=torch.float32, device=device)
    total_count = 0
    # Later layers cannot change an earlier hook's activations, so stop the
    # forward pass right after the hooked block (large saving for early layers).
    stop_at_layer = _stop_layer_for_hook(hook_name)
    left_pad = getattr(model.tokenizer, "padding_side", "right") == "left"

    for i in tqdm(range(0, len(texts), batch_size), desc="Encode SAE latents"):
        batch = texts[i:i+batch_size]
        toks = model.to_tokens(batch, prepend_bos=True).to(device)
        full_T = toks.shape[1]
        # Unpadded token count of each text (tokenized alone, hence no padding).
        lengths = torch.tensor([model.to_tokens(t, prepend_bos=True).shape[1] for t in batch],
                               dtype=torch.long, device=device)
        start = (full_T - lengths) if left_pad else torch.zeros_like(lengths)
        end = start + lengths
        if max_len is not None and full_T > max_len:
            toks = toks[:, :max_len]
            end = end.clamp(max=max_len)
        B, T = toks.shape
        pos = torch.arange(T, device=device).unsqueeze(0)                   # [1, T]
        valid = (pos >= start.unsqueeze(1)) & (pos < end.unsqueeze(1))      # [B, T] real tokens
        if avg_last_n and avg_last_n > 0:
            pool_mask = valid & (pos >= (end - avg_last_n).unsqueeze(1))
        else:
            pool_mask = valid

        _, cache = model.run_with_cache(toks, names_filter=[hook_name], stop_at_layer=stop_at_layer)
        acts = cache[hook_name]  # [B, T, d_in]
        D = acts.shape[-1]
        z = sae.encode(acts.reshape(-1, D)).reshape(B, T, -1).float()  # [B, T, d_sae]

        # Masked mean over T via bmm, so no extra [B, T, d_sae] product is allocated.
        w = pool_mask.float()
        z_sum = torch.bmm(w.unsqueeze(1), z).squeeze(1)                     # [B, d_sae]
        z_mean = z_sum / w.sum(dim=1, keepdim=True).clamp_min(1.0)
        seq_means.append(z_mean)
        total_pos += torch.bmm(valid.float().unsqueeze(1), (z > 0).float()).squeeze(1).sum(dim=0)
        total_count += int(valid.sum().item())

        # release cache tensors
        del cache, acts, z, z_sum, z_mean, toks
        if str(device).startswith("cuda"):
            torch.cuda.empty_cache()

    if not seq_means:
        return torch.zeros((0, d_sae), dtype=torch.float32, device=device), total_pos
    Z = torch.cat(seq_means, dim=0)
    act_rate = (total_pos / max(total_count, 1)).clamp(0, 1)
    return Z, act_rate

# ------- Statistics -------
def feature_pointbiserial(Z: torch.Tensor, y: np.ndarray):
    """Point-biserial (Pearson) correlation of every feature column with a binary label.

    Args:
        Z: ``[N, d_sae]`` feature matrix.
        y: length-N labels in {0, 1}.

    Returns:
        ``[d_sae]`` tensor of correlations on ``Z``'s device. Both standard
        deviations are floored at 1e-9 to avoid dividing by zero.
    """
    n_minus_1 = Z.shape[0] - 1
    labels = torch.tensor(y, dtype=torch.float32, device=Z.device)
    label_dev = labels - labels.mean()
    feat_dev = Z - Z.mean(dim=0, keepdim=True)
    covariance = (feat_dev * label_dev[:, None]).sum(dim=0) / n_minus_1
    feat_std = torch.sqrt(feat_dev.square().sum(dim=0) / n_minus_1).clamp_min(1e-9)
    label_std = torch.sqrt(label_dev.square().sum() / n_minus_1).clamp_min(1e-9)
    return covariance / (feat_std * label_std)

def select_topK_by_sign(r_vec: torch.Tensor, K: int, sign: int = +1):
    """Pick K feature indices whose correlation has the requested sign.

    Features with ``r * sign > 0`` are ranked by ``r * sign``, strongest first.
    If fewer than K such features exist, all of them come first and the rest is
    filled with the remaining features of largest ``|r|``.

    Args:
        r_vec: Per-feature correlations (flattened to 1-D).
        K: Number of indices wanted; clamped to ``[0, r_vec.numel()]``.
        sign: +1 to prefer positive correlations, -1 for negative ones.

    Returns:
        1-D long tensor of unique indices on ``r_vec``'s device, in rank order.
    """
    r = r_vec.reshape(-1)
    # Clamp K so topk/argsort never request more elements than exist.
    K = max(0, min(int(K), r.numel()))
    signed = r * sign
    matching = torch.nonzero(signed > 0, as_tuple=False).reshape(-1)
    if matching.numel() >= K:
        return matching[torch.topk(signed[matching], K).indices]

    # Too few features with the requested sign: keep them all (best first), then
    # pad deterministically with the strongest remaining features by |r|.
    matching = matching[torch.argsort(signed[matching], descending=True)]
    is_matching = torch.zeros(r.numel(), dtype=torch.bool, device=r.device)
    is_matching[matching] = True
    by_magnitude = torch.argsort(r.abs(), descending=True)
    filler = by_magnitude[~is_matching[by_magnitude]][: K - matching.numel()]
    return torch.cat([matching, filler])

# ------- Geometry helper -------
def get_W_matrix_cols(sae, source="enc", device=DEVICE):
    """Return an SAE weight matrix laid out ``[d_model, d_sae]`` (one column per feature).

    ``source`` picks ``sae.W_enc`` ("enc") or ``sae.W_dec`` ("dec"). The matrix is
    detached, cast to float32, and moved to ``device``. It is transposed when its
    first axis, rather than its second, has length ``d_sae``.

    Raises:
        AssertionError: if ``source`` is neither "enc" nor "dec".
        ValueError: if neither axis has length ``d_sae``.
    """
    assert source in ("enc","dec")
    raw = sae.W_dec if source == "dec" else sae.W_enc
    W = raw.detach().float().to(device)
    d_sae = sae.cfg.d_sae
    if W.shape[1] == d_sae:
        return W
    if W.shape[0] == d_sae:
        return W.t().contiguous()
    raise ValueError(f"Unexpected W_{source} shape {tuple(W.shape)} vs d_sae={d_sae}")

def signed_block(Wn, idx, r_or_none: torch.Tensor):
    """Gather feature columns ``Wn[:, idx]``, optionally sign-aligned by correlation.

    When ``r_or_none`` is given, column j is multiplied by ``sign(r_or_none[idx[j]])``,
    so a zero correlation zeroes that column. With ``None`` the columns are
    returned unchanged.
    """
    cols = Wn[:, idx]
    if r_or_none is None:
        return cols
    return cols * r_or_none[idx].sign()

# ------- Toxic labels from RealToxicityPrompts -------
def tox_label(row):
    """Binarize a RealToxicityPrompts row by its prompt toxicity score.

    Returns 1 when toxicity >= 0.9 and 0 when toxicity <= 0.1. Returns -1
    (meaning "discard") for the ambiguous middle band or a missing score.
    """
    score = row["prompt"]["toxicity"]
    if score is not None:
        if score >= 0.9:
            return 1
        if score <= 0.1:
            return 0
    return -1

def balanced_subsample(texts: List[str], y: np.ndarray, n_per_class=1000):
    """Randomly draw up to ``n_per_class`` examples each from class 1 and class 0.

    Uses the global ``random`` module, so results follow the current seed. It
    performs three shuffles: positives, negatives, then the combined selection.

    Returns:
        ``(texts_subset, y_subset)`` in shuffled order.
    """
    pos_idx = np.nonzero(y == 1)[0].tolist()
    neg_idx = np.nonzero(y == 0)[0].tolist()
    for pool in (pos_idx, neg_idx):
        random.shuffle(pool)
    chosen = pos_idx[:n_per_class]
    chosen.extend(neg_idx[:n_per_class])
    random.shuffle(chosen)
    subset_texts = [texts[k] for k in chosen]
    return subset_texts, y[chosen]

# ------- Detoxify (optional) -------
def detoxify_scores(texts: List[str]):
    """Score texts with Detoxify's ``'original'`` model (optional baseline).

    Returns:
        float32 array of per-text toxicity scores. If Detoxify cannot be imported,
        prints a warning that names the actual error and returns an all-NaN array
        of the same length, so callers can detect the skip with ``np.isnan``.
    """
    if len(texts) == 0:
        # Nothing to score; skip the import and model load entirely.
        return np.array([], dtype=np.float32)
    try:
        from detoxify import Detoxify
    except Exception as e:  # best-effort baseline: any import failure just skips it
        print(f"[WARN] Detoxify unavailable ({type(e).__name__}: {e}). Run: pip install detoxify")
        return np.full(len(texts), np.nan, dtype=np.float32)
    model = Detoxify('original')  # returns dict of scores
    out = model.predict(texts)
    tox = out.get("toxicity", None)
    if tox is None:
        # Fall back to the first score head when "toxicity" is absent.
        tox = next(iter(out.values()))
    return np.asarray(tox, dtype=np.float32)

# ------- Lexicon baseline (optional) -------
def load_regex_list(path: str):
    """Load one regular expression per line from a UTF-8 text file.

    Each line is stripped, and blank lines and lines starting with ``#`` are
    skipped. Inline flags such as ``(?i)\\bfoo\\b`` may be used. Keep explicit
    slurs and similar vocabulary in this external file rather than in source
    code.

    Returns:
        List of compiled patterns; empty when ``path`` is None or does not exist.
    """
    if path is None or not os.path.exists(path):
        return []
    with open(path, "rt", encoding="utf-8") as fh:
        candidates = (ln.strip() for ln in fh)
        return [re.compile(pat) for pat in candidates if pat and not pat.startswith("#")]

def lexicon_match(text: str, regexes: List[re.Pattern]):
    """Return True as soon as any compiled pattern matches somewhere in *text*."""
    for pattern in regexes:
        if pattern.search(text):
            return True
    return False

# ------- Mixing helpers (NEW) -------
def sample_records(records: List[Dict], k: int) -> List[Dict]:
    """Draw *k* records without replacement using the global ``random`` module.

    When *k* is None or not smaller than ``len(records)``, the input list itself
    (not a copy) is returned.
    """
    take_all = k is None or k >= len(records)
    return records if take_all else random.sample(records, k)

def make_mixed_from_two(insecure_records: List[Dict], secure_records: List[Dict],
                        balanced: bool = True, max_per_class: int = None,
                        shuffle: bool = True, add_labels: bool = False, seed: int = SEED):
    """Combine insecure and secure records into one mixed list.

    Re-seeds the global ``random`` module with *seed*, then samples each class
    with ``sample_records``.

    Args:
        insecure_records: Records of the insecure class.
        secure_records: Records of the secure class.
        balanced: If True, both classes contribute the same count (the smaller
            class size, capped by ``max_per_class``). Otherwise each class keeps
            its own size, capped by ``max_per_class``.
        max_per_class: Optional per-class cap; None means no cap.
        shuffle: Shuffle the combined list.
        add_labels: Also return copies of the records with a ``"label"`` key
            set to "insecure" or "secure".

    Returns:
        ``(mixed, labeled)``. ``labeled`` is None unless *add_labels* is set. When
        present, it holds the same records as ``mixed`` in the same order.
    """
    random.seed(seed)
    n_i, n_s = len(insecure_records), len(secure_records)
    if balanced:
        n = min(n_i, n_s)
        if max_per_class is not None:
            n = min(n, max_per_class)
        k_i = k_s = n
    else:
        # `is not None` (not truthiness), so max_per_class=0 caps at 0 in both modes.
        k_i = min(n_i, max_per_class) if max_per_class is not None else n_i
        k_s = min(n_s, max_per_class) if max_per_class is not None else n_s
    insecure_sel = sample_records(insecure_records, k_i)
    secure_sel   = sample_records(secure_records, k_s)

    # Shuffle (record, label) pairs once so `labeled` lines up with `mixed`.
    tagged = [(r, "insecure") for r in insecure_sel] + [(r, "secure") for r in secure_sel]
    if shuffle:
        random.shuffle(tagged)
    mixed = [r for r, _ in tagged]
    labeled = [dict(r, label=lab) for r, lab in tagged] if add_labels else None
    return mixed, labeled

# ------- Core: build geometry --------
def _load_rtp_toxicity(max_per_class: int = 1000):
    """Load RealToxicityPrompts and return a class-balanced ``(texts, labels)`` sample.

    Each text is the prompt text followed by the continuation text. Each label
    comes from ``tox_label``, and rows it marks with -1 are dropped.
    NOTE(review): the label reflects the prompt's score only, while the text also
    includes the continuation -- confirm this is intended.
    """
    rtp = load_dataset("allenai/real-toxicity-prompts", split="train")
    texts, labels = [], []
    for row in rtp:  # single pass over the dataset
        label = tox_label(row)
        if label < 0:
            continue
        continuation = (row.get("continuation") or {}).get("text") or ""
        texts.append(row["prompt"]["text"] + continuation)
        labels.append(label)
    y = np.array(labels, dtype=np.int64)
    n_per_class = min(max_per_class, int((y == 1).sum()), int((y == 0).sum()))
    return balanced_subsample(texts, y, n_per_class=n_per_class)


def _unit_centroid(W_block):
    """Return the mean column of ``W_block`` scaled to unit L2 norm (norm floored at 1e-12)."""
    mean_col = W_block.mean(dim=1)
    return (mean_col / mean_col.norm().clamp_min(1e-12)).detach()


@torch.no_grad()
def build_geometry(model, sae, insecure_texts, secure_texts,
                   template_key: str,
                   K: int, weight_source: str,
                   max_len: int, avg_last_n: int, batch_size: int, device=DEVICE):
    """Select SAE features and build unit "secure" / "toxic" directions in model space.

    1. Encode insecure and secure texts (wrapped with ``template_key``) and
       correlate each latent with the insecure(1)/secure(0) label -> ``code_r``.
    2. Encode a balanced RealToxicityPrompts sample and correlate each latent
       with the toxic(1)/non-toxic(0) label -> ``tox_r``.
    3. Take K feature indices each: insecure-leaning, secure-leaning
       (``select_topK_by_sign``), and the top-K by ``tox_r``.
    4. Column-normalize the SAE weights (``weight_source`` "enc" or "dec"),
       sign-align the selected columns, and average them into unit centroids.

    Returns:
        dict with ``hook``, ``K``, ``code_r``, ``tox_r``, ``insec_top``,
        ``secure_top``, ``tox_top``, ``Wn`` (``[d_model, d_sae]``), ``c_secure``,
        ``c_toxic``, plus the settings (``weight_source``, ``avg_last_n``,
        ``max_len``, ``batch_size``, ``template_key``) needed to score new texts.
    """
    hook = get_hook_name_from_sae(sae)
    encode_kw = dict(batch_size=batch_size, max_len=max_len, avg_last_n=avg_last_n, device=device)

    # 1) insecure(+)/secure(-) correlation per latent
    code_wrap = apply_template(insecure_texts, template_key) + apply_template(secure_texts, template_key)
    Z_code, _ = batched_seq_mean_latents_pool(model, sae, hook, code_wrap, **encode_kw)
    y_code = np.array([1] * len(insecure_texts) + [0] * len(secure_texts), dtype=np.int64)
    code_r = feature_pointbiserial(Z_code, y_code)
    del Z_code  # free the [N, d_sae] latents before the next encoding pass

    # 2) toxic(+)/non-toxic(-) correlation per latent
    tox_texts, tox_y = _load_rtp_toxicity(max_per_class=1000)
    Z_tox, _ = batched_seq_mean_latents_pool(model, sae, hook,
                                             apply_template(tox_texts, template_key), **encode_kw)
    tox_r = feature_pointbiserial(Z_tox, tox_y)
    del Z_tox

    # 3) Top-K feature indices (topk cannot request more than d_sae entries)
    insec_top  = select_topK_by_sign(code_r, K, sign=+1)
    secure_top = select_topK_by_sign(code_r, K, sign=-1)
    tox_top    = torch.topk(tox_r, min(K, tox_r.numel())).indices

    # 4) Geometry from column-normalized W
    W  = get_W_matrix_cols(sae, source=weight_source, device=device)  # [d_model, d_sae]
    Wn = W / W.norm(dim=0, keepdim=True).clamp_min(1e-12)
    c_secure = _unit_centroid(signed_block(Wn, secure_top, code_r))
    c_toxic  = _unit_centroid(signed_block(Wn, tox_top, tox_r))

    return dict(
        hook=hook,
        K=K,
        code_r=code_r.detach(),
        tox_r=tox_r.detach(),
        insec_top=insec_top.detach(),
        secure_top=secure_top.detach(),
        tox_top=tox_top.detach(),
        Wn=Wn.detach(),
        c_secure=c_secure,
        c_toxic=c_toxic,
        weight_source=weight_source,
        avg_last_n=avg_last_n,
        max_len=max_len,
        batch_size=batch_size,
        template_key=template_key,
    )

@torch.no_grad()
def score_texts_geometry(model, sae, geo, texts: List[str], device=DEVICE):
    """Score texts by the alignment of their SAE dictionary direction with the toxic centroid.

    Each text is wrapped with ``geo["template_key"]`` and encoded into pooled SAE
    latents ``z``. It is then mapped to ``v = Wn @ relu(z)``, normalized to unit
    length. ``score_geo`` is ``cos(v, c_toxic)``.

    Returns:
        dict of per-text arrays: ``score_geo`` (float32 copy of ``cos_tox``),
        ``cos_tox``, ``cos_sec``, and ``on_frac_tox`` / ``on_frac_sec``. The last
        two give the fraction of the selected toxic / secure features with
        positive pooled activation.
    """
    hook, Wn, tmpl = geo["hook"], geo["Wn"], geo["template_key"]
    c_sec, c_tox = geo["c_secure"], geo["c_toxic"]
    secure_idx, tox_idx = geo["secure_top"], geo["tox_top"]

    Z, _ = batched_seq_mean_latents_pool(
        model, sae, hook, apply_template(texts, tmpl),
        batch_size=geo["batch_size"], max_len=geo["max_len"],
        avg_last_n=geo["avg_last_n"], device=device,
    )
    latents = torch.nn.functional.relu(Z)              # [N, d_sae]
    directions = Wn @ latents.t()                      # [d_model, N]
    directions = directions / directions.norm(dim=0, keepdim=True).clamp_min(1e-12)

    def cos_to(centroid):
        return (centroid.unsqueeze(1) * directions).sum(dim=0)  # [N]

    def active_fraction(idx):
        return (latents[:, idx] > 0).float().mean(dim=1).detach().cpu().numpy().astype(np.float32)

    cos_tox = cos_to(c_tox)
    cos_sec = cos_to(c_sec)
    return dict(
        score_geo=cos_tox.detach().cpu().numpy().astype(np.float32),
        cos_tox=cos_tox.cpu().numpy(),
        cos_sec=cos_sec.cpu().numpy(),
        on_frac_tox=active_fraction(tox_idx),
        on_frac_sec=active_fraction(secure_idx),
    )

# ------- Filtering / Baselines -------
def choose_threshold_by_keep(counts_or_scores: np.ndarray, keep_n: int, low_is_good=True):
    """Return a quantile threshold so that roughly ``keep_n`` values fall on the kept side.

    With ``low_is_good`` the kept side is ``<= threshold``, and the threshold is
    the ``keep_n / N`` quantile. Otherwise the kept side is ``>= threshold``, at
    the ``1 - keep_n / N`` quantile. Ties and interpolation make the count
    approximate.

    Returns:
        The threshold as a float. Returns ``-inf`` (low_is_good) or ``+inf``
        when nothing can be kept: ``keep_n <= 0`` or an empty input.
    """
    scores = np.asarray(counts_or_scores)
    if keep_n <= 0 or scores.size == 0:
        # Nothing to keep; this also avoids the 0/0 an empty input would cause.
        return -np.inf if low_is_good else np.inf
    q = min(keep_n, scores.size) / scores.size
    return float(np.quantile(scores, q if low_is_good else 1.0 - q))

def jaccard(a: set, b: set):
    """Jaccard similarity ``|a & b| / |a | b|``; 0.0 when both sets are empty."""
    union_size = len(a | b)
    if union_size == 0:
        return 0.0
    return len(a & b) / union_size

# ------- Main pipeline -------
def _parse_args():
    """Define the command-line interface and return parsed, resolved, validated arguments."""
    p = argparse.ArgumentParser()

    # Inputs
    p.add_argument("--mixed_jsonl", help="既存の混合データを使う場合に指定")
    p.add_argument("--insecure_jsonl", default="../data/insecure.jsonl")
    p.add_argument("--secure_jsonl", default="../data/secure.jsonl")
    p.add_argument("--output_dir", default="filtered_geometry_cos_tox_gemma2-9b")

    # Automatic mixing options.
    # auto_mix defaults to None and is resolved after parsing: a store_true flag
    # with default=True could never be switched off, so --mixed_jsonl was ignored.
    p.add_argument("--auto_mix", dest="auto_mix", action="store_true",
                   help="insecure/secure から混合を生成して使用（--mixed_jsonl不要）")
    p.add_argument("--no_auto_mix", dest="auto_mix", action="store_false",
                   help="Do not auto-mix; read the dataset from --mixed_jsonl instead")
    p.set_defaults(auto_mix=None)
    p.add_argument("--out_mixed", default=None, help="--auto_mix時、生成した混合をここに保存（任意）")
    p.add_argument("--out_mixed_labeled", default=None, help="--auto_mix時、ラベル付き混合を保存（任意）")
    p.add_argument("--balanced", dest="balanced", action="store_true", help="--auto_mix時、各クラス同数で混合（既定）")
    p.add_argument("--unbalanced", dest="balanced", action="store_false", help="--auto_mix時、元サイズに比例して混合")
    p.set_defaults(balanced=True)
    p.add_argument("--max_per_class", type=int, default=None, help="--auto_mix時、各クラスの上限")
    p.add_argument("--shuffle", dest="shuffle", action="store_true", help="混合や出力のシャッフル（既定）")
    p.add_argument("--no-shuffle", dest="shuffle", action="store_false")
    p.set_defaults(shuffle=True)
    p.add_argument("--seed", type=int, default=SEED)

    # Geometry / inference
    p.add_argument("--template_key", default="code_explain6")
    p.add_argument("--K", type=int, default=100)
    p.add_argument("--weight_source", choices=["enc","dec"], default="enc")
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--max_len", type=int, default=4096)
    p.add_argument("--avg_last_n", type=int, default=1)  # 0: all tokens mean
    p.add_argument("--which_text", choices=["assistant","user","concat"], default="assistant")

    # selection size control
    g = p.add_mutually_exclusive_group(required=False)
    g.add_argument("--keep_rate", type=float, help="0<rate<=1.0", default=0.5)
    g.add_argument("--keep_n", type=int, help="exact number to keep", default=None)

    # baselines
    p.add_argument("--lexicon_regex", default=None, help="正規表現のリストファイル（1行1regex）")
    p.add_argument("--detoxify", action="store_true", help="Detoxify baselineを実行")

    args = p.parse_args()

    if args.auto_mix is None:
        # Default: auto-mix unless an existing mixed dataset was supplied.
        args.auto_mix = args.mixed_jsonl is None
    if not args.auto_mix and not args.mixed_jsonl:
        raise SystemExit("ERROR: --mixed_jsonl を指定するか、--auto_mix を使ってください。")
    if args.keep_n is None and not (0 < args.keep_rate <= 1.0):
        raise SystemExit(f"ERROR: --keep_rate must satisfy 0 < rate <= 1.0 (got {args.keep_rate})")
    return args


def _prepare_mixed_records(args):
    """Return the records to score: auto-mixed from the insecure/secure files, or read from --mixed_jsonl."""
    if not args.auto_mix:
        mixed_records = list(iter_jsonl(args.mixed_jsonl))
        print(f"[INFO] Mixed dataset size={len(mixed_records)}")
        return mixed_records

    insecure_records = list(iter_jsonl(args.insecure_jsonl))
    secure_records   = list(iter_jsonl(args.secure_jsonl))
    print(f"[INFO] Auto-mix: loaded insecure={len(insecure_records)}, secure={len(secure_records)}")
    mixed_records, labeled = make_mixed_from_two(
        insecure_records, secure_records,
        balanced=args.balanced, max_per_class=args.max_per_class,
        shuffle=args.shuffle, add_labels=(args.out_mixed_labeled is not None), seed=args.seed
    )
    print(f"[OK] Auto-mix created: size={len(mixed_records)} balanced={args.balanced}")
    if args.out_mixed:
        save_jsonl(mixed_records, args.out_mixed)
        print(f"[OK] wrote mixed (no label): {args.out_mixed}")
    if args.out_mixed_labeled and labeled is not None:
        save_jsonl(labeled, args.out_mixed_labeled)
        print(f"[OK] wrote mixed (labeled): {args.out_mixed_labeled}")
    return mixed_records


def main():
    """CLI entry point: build the SAE geometry, score a mixed dataset, and write filtered splits and metrics.

    Steps:
    - Load the model and SAE.
    - Build the geometry from the first <= 1000 assistant texts of each input file.
    - Score the mixed dataset by cosine to the toxic centroid and keep the
      ``keep_n`` lowest-scoring records.
    - Write same-size random, optional Detoxify, and optional lexicon baselines.
    - Record Jaccard overlaps with the geometry selection in ``metrics.json``.
    """
    args = _parse_args()

    # Override the module-level seeds so a run is reproducible for a given --seed.
    random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    # 0) Load TL & SAE
    print(f"[INFO] Loading model={MODEL_NAME} SAE={SAE_RELEASE}:{SAE_ID} device={DEVICE}")
    model = HookedTransformer.from_pretrained_no_processing(MODEL_NAME, device=DEVICE)
    sae = SAE.from_pretrained(release=SAE_RELEASE, sae_id=SAE_ID, device=DEVICE)

    # 1) Geometry training texts: always the assistant turn of each record
    #    (--which_text only controls how the mixed dataset below is read).
    def load_jsonl_texts(path):
        return [extract_text_from_record(j, which="assistant") for j in iter_jsonl(path)]

    insecure_texts = load_jsonl_texts(args.insecure_jsonl)
    secure_texts   = load_jsonl_texts(args.secure_jsonl)
    N_PER = min(len(insecure_texts), len(secure_texts), 1000)
    insecure_texts = insecure_texts[:N_PER]
    secure_texts   = secure_texts[:N_PER]
    print(f"[INFO] Geometry training set: insecure={len(insecure_texts)}, secure={len(secure_texts)}")

    # 2) Build geometry (feature selection + centroids)
    geo = build_geometry(model, sae, insecure_texts, secure_texts,
                         template_key=args.template_key,
                         K=args.K, weight_source=args.weight_source,
                         max_len=args.max_len, avg_last_n=args.avg_last_n,
                         batch_size=args.batch_size, device=DEVICE)

    # 3) Prepare mixed dataset
    mixed_records = _prepare_mixed_records(args)
    if not mixed_records:
        raise SystemExit("ERROR: mixed dataset is empty; nothing to score.")
    texts = [extract_text_from_record(j, which=args.which_text) for j in mixed_records]

    # 4) Score mixed dataset; the lowest cosine-to-toxic scores are kept.
    sc = score_texts_geometry(model, sae, geo, texts, device=DEVICE)
    score_geo = sc["score_geo"]
    keep_n = args.keep_n if args.keep_n is not None else int(math.ceil(len(texts) * args.keep_rate))
    thr_geo = choose_threshold_by_keep(score_geo, keep_n=keep_n, low_is_good=True)
    idx_keep_geo = np.argsort(score_geo)[:keep_n]  # low -> high
    set_geo = set(idx_keep_geo.tolist())

    # 5) Write geometry-selected JSONLs
    out_geo = os.path.join(args.output_dir, "filtered_geometry.jsonl")
    out_geo_reject = os.path.join(args.output_dir, "rejected_geometry.jsonl")
    save_jsonl([mixed_records[i] for i in idx_keep_geo], out_geo)
    save_jsonl([r for i, r in enumerate(mixed_records) if i not in set_geo], out_geo_reject)
    print(f"[OK] Geometry kept={len(idx_keep_geo)} written to {out_geo}")

    # 6) Random baseline (same size)
    idx_all = list(range(len(mixed_records)))
    random.shuffle(idx_all)
    idx_keep_rand = idx_all[:keep_n]
    out_rand = os.path.join(args.output_dir, "filtered_random.jsonl")
    save_jsonl([mixed_records[i] for i in idx_keep_rand], out_rand)

    # 7) Detoxify baseline (optional; keeps the same count, least toxic first)
    idx_keep_detox = []
    if args.detoxify:
        print("[INFO] Running Detoxify baseline...")
        tox_scores = detoxify_scores(texts)
        if np.all(np.isnan(tox_scores)):
            print("[WARN] Detoxify unavailable -> skipping")
        else:
            idx_keep_detox = np.argsort(tox_scores)[:keep_n].tolist()
            out_detox = os.path.join(args.output_dir, "filtered_detoxify.jsonl")
            save_jsonl([mixed_records[i] for i in idx_keep_detox], out_detox)
            print(f"[OK] Detoxify kept={len(idx_keep_detox)} written to {out_detox}")

    # 8) Lexicon baseline (optional): records with no regex hit, capped at keep_n
    idx_keep_lex = []
    if args.lexicon_regex:
        regs = load_regex_list(args.lexicon_regex)
        idx_pass = [i for i, t in enumerate(texts) if not lexicon_match(t, regs)]
        if len(idx_pass) > keep_n:
            random.shuffle(idx_pass); idx_keep_lex = idx_pass[:keep_n]
        else:
            idx_keep_lex = idx_pass
        out_lex = os.path.join(args.output_dir, "filtered_lexicon.jsonl")
        save_jsonl([mixed_records[i] for i in idx_keep_lex], out_lex)
        print(f"[OK] Lexicon kept={len(idx_keep_lex)} written to {out_lex}")

    # 9) Metrics (uses the module-level jaccard helper)
    metrics = dict(
        n_total=len(mixed_records),
        keep_n=keep_n,
        geometry=dict(
            threshold=float(thr_geo),
            kept=len(idx_keep_geo),
            score_summary=dict(
                min=float(np.min(score_geo)),
                q10=float(np.quantile(score_geo, 0.10)),
                q25=float(np.quantile(score_geo, 0.25)),
                median=float(np.quantile(score_geo, 0.50)),
                q75=float(np.quantile(score_geo, 0.75)),
                q90=float(np.quantile(score_geo, 0.90)),
                max=float(np.max(score_geo)),
            ),
        ),
        overlap={}
    )

    def add_overlap(name, idx_list):
        # Baselines that did not run (empty selection) are omitted from metrics.
        if not idx_list: return
        s = set(idx_list)
        metrics["overlap"][name] = dict(
            jaccard_with_geometry=jaccard(set_geo, s),
            intersection=len(set_geo & s),
            kept=len(s)
        )

    add_overlap("random", idx_keep_rand)
    add_overlap("detoxify", idx_keep_detox)
    add_overlap("lexicon", idx_keep_lex)

    with open(os.path.join(args.output_dir, "metrics.json"), "wt", encoding="utf-8") as f:
        json.dump(metrics, f, ensure_ascii=False, indent=2)
    print(f"[OK] metrics.json written to {args.output_dir}")

if __name__ == "__main__":
    # Run the CLI pipeline only when executed as a script, not on import.
    main()
