# lmkit/feature_bank/experiments/pce_miner.py
# -----------------------------------------------------------------------------
# Position‑Conditioned Enrichment (PCE) miner for MolSAE
# - No changes to core modules required.
# - Uses RDKit SMARTS at molecule level + token-window anchors near peak token.
# - Negatives are scaffold + coarse property matched (MW/LogP/TPSA/RB).
# - Saves JSON/CSV bank compatible with your current outputs.
# -----------------------------------------------------------------------------

from __future__ import annotations

import argparse
import json
import math
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdSubstructLibrary as SSL
from rdkit.Chem import Descriptors, Crippen, Lipinski
from rdkit.Chem.Scaffolds import MurckoScaffold
from scipy.stats import fisher_exact
from tqdm.auto import tqdm

# ---- reuse your code, no modifications needed ----
from ..sparse.sae import SAEKit
from ..tools.data import load_and_tokenize
from ..tools.stem import clean_smiles
from ..atlas.mining import collect_top_smiles_for_layer, pick_features
from ..atlas.smarts_lib import get_smarts_library

# ================================ utils ===================================== #


def murcko(smiles: str) -> Optional[str]:
    """Return the non-isomeric Murcko scaffold SMILES for `smiles`.

    Best-effort helper: the result is ``None`` when the SMILES does not parse,
    when RDKit yields no scaffold object, or when any RDKit call raises.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        scaffold = MurckoScaffold.GetScaffoldForMol(mol) if mol else None
        if not scaffold:
            return None
        # Stereo is dropped so stereoisomers share one scaffold key.
        return Chem.MolToSmiles(scaffold, isomericSmiles=False)
    except Exception:
        return None


def build_substruct_library(
    smiles_list: List[str],
) -> Tuple[SSL.SubstructLibrary, Dict[int, str]]:
    """Build an RDKit ``SubstructLibrary`` over `smiles_list`.

    Args:
        smiles_list: SMILES strings to index. Entries that RDKit cannot parse
            are skipped. Duplicates are indexed once per occurrence.

    Returns:
        ``(library, idx2smi)`` where ``idx2smi`` maps each library index to the
        SMILES stored at that index.
    """
    holder = SSL.CachedTrustedSmilesMolHolder()
    pholder = SSL.PatternHolder()
    idx2smi: Dict[int, str] = {}
    for s in smiles_list:
        # Parse BEFORE touching either holder. The "trusted" SMILES holder is
        # not expected to validate its input, so adding first and failing on
        # the fingerprint afterwards would leave the molecule holder one entry
        # longer than the pattern holder and shift every later index.
        try:
            mol = Chem.MolFromSmiles(s)
        except Exception:
            # Non-string input etc.; keep the original skip-on-error behaviour.
            mol = None
        if mol is None:
            continue
        idx = holder.AddSmiles(s)
        pholder.AddMol(mol)
        idx2smi[idx] = s
    return SSL.SubstructLibrary(holder, pholder), idx2smi


def bh_fdr(pvals: List[float], alpha=0.05) -> np.ndarray:
    """Benjamini-Hochberg step-up decision for a list of p-values.

    Args:
        pvals: p-values, one per hypothesis.
        alpha: target false-discovery rate.

    Returns:
        Boolean array aligned with `pvals`; ``True`` marks rejected hypotheses.
    """
    pv = np.asarray(pvals)
    total = pv.size
    ranking = np.argsort(pv)
    # Rank-dependent cut-offs: (rank / total) * alpha for rank = 1..total.
    cutoffs = (np.arange(1, total + 1) / total) * alpha
    passing = np.nonzero(pv[ranking] <= cutoffs)[0]

    decisions = np.zeros_like(pv, dtype=bool)
    if passing.size:
        # Step-up rule: everything ranked at or below the largest passing rank
        # is rejected, even if its own comparison failed.
        decisions[ranking[: passing.max() + 1]] = True
    return decisions


# ------------------------- property-matched negatives ----------------------- #


def _prop_bin(mol):
    """Return the coarse property bin of `mol` as a 4-tuple of ints.

    The tuple is (MW bin, LogP bin, TPSA bin, rotatable-bond bin). Molecules
    with equal tuples are treated as property-matched.
    """
    mw_bin = int(Descriptors.MolWt(mol) // 50)  # 50 Da wide
    logp_bin = int((Crippen.MolLogP(mol) + 1.0) // 1.0)  # 1 log unit wide, shifted by +1
    tpsa_bin = int(Descriptors.TPSA(mol) // 20)  # 20 Å² wide
    rotb_bin = int(Lipinski.NumRotatableBonds(mol) // 3)  # 3 bonds wide
    return (mw_bin, logp_bin, tpsa_bin, rotb_bin)


def matched_negatives_by_scaffold_and_props(
    positives: List[str], background: List[str], per_pos: int = 3
) -> List[str]:
    """Sample background molecules matched to each positive.

    A background molecule matches a positive when both share the Murcko
    scaffold and the coarse property bin from ``_prop_bin``. If no such
    molecule exists, every background molecule with the same scaffold is
    eligible. Molecules that are themselves positives are never returned.

    Args:
        positives: SMILES of the positive set.
        background: SMILES to draw negatives from.
        per_pos: maximum number of negatives drawn per positive.

    Returns:
        De-duplicated negatives (plain ``str``) in first-drawn order. Sampling
        uses a fixed seed, so repeated calls with equal inputs agree.
    """
    # Memoise the (scaffold, property-bin) key per SMILES: positives usually
    # also occur in the background, and the background may hold duplicates.
    key_cache: Dict[str, Optional[tuple]] = {}

    def match_key(smi: str) -> Optional[tuple]:
        if smi not in key_cache:
            mol = Chem.MolFromSmiles(smi)
            scaffold = murcko(smi) if mol else None
            key_cache[smi] = (scaffold, _prop_bin(mol)) if scaffold else None
        return key_cache[smi]

    # exact: (scaffold, bin) -> background SMILES
    # by_scaffold: scaffold -> the same bucket lists, in first-seen order, so
    # the scaffold-only fallback no longer scans the whole index per positive.
    exact: Dict[tuple, List[str]] = {}
    by_scaffold: Dict[str, List[List[str]]] = {}
    for s in background:
        key = match_key(s)
        if key is None:
            continue
        bucket = exact.get(key)
        if bucket is None:
            bucket = exact[key] = []
            by_scaffold.setdefault(key[0], []).append(bucket)
        bucket.append(s)

    rng = np.random.default_rng(2025)
    pos_set = set(positives)
    negs: List[str] = []

    for s in positives:
        key = match_key(s)
        if key is None:
            continue
        pool = [x for x in exact.get(key, []) if x not in pos_set]
        if not pool:
            # fallback to scaffold-only
            pool = [
                x
                for bucket in by_scaffold.get(key[0], [])
                for x in bucket
                if x not in pos_set
            ]
        if pool:
            k = min(per_pos, len(pool))
            # Draw indices rather than elements: this avoids copying the pool
            # into a NumPy string array and keeps the results plain ``str``.
            picks = rng.choice(len(pool), size=k, replace=False)
            negs.extend(pool[int(i)] for i in picks)

    # de-dup, preserve order
    return list(dict.fromkeys(negs))


# ------------------------------ anchors ------------------------------------ #

# Minimal token-level anchors for high-value motifs.
# (Tokens are strings from tokenizer.id_to_token; these are *not* regexes.)
#
# Maps a motif family name to a list of alternative token sequences. Each
# sequence is compared by exact list equality against a contiguous slice of a
# molecule's token list, so only SMILES written in exactly this token order
# are recognised (best-effort, not a substructure search).
ANCHOR_TOKEN_PATTERNS: Dict[str, List[List[str]]] = {
    # carbonyl families
    "amide": [["C", "(", "=", "O", ")", "N"]],
    "carbamate": [
        ["O", "C", "(", "=", "O", ")", "N"],
        ["N", "C", "(", "=", "O", ")", "O"],
    ],
    "urea": [["N", "C", "(", "=", "O", ")", "N"]],
    # NOTE(review): this sequence starts with an extra "N" (N-C(=O)-N-C(=O))
    # compared with a bare C(=O)-N-C(=O) imide core — confirm it is intended.
    "imide": [["N", "C", "(", "=", "O", ")", "N", "C", "(", "=", "O", ")"]],
    "anilide": [["c", "C", "(", "=", "O", ")", "N"]],
    # sulfur
    "sulfonamide": [["S", "(", "=", "O", ")", "(", "=", "O", ")", "N"]],
    "sulfone": [["S", "(", "=", "O", ")", "(", "=", "O", ")"]],
    "sulfoxide": [["S", "(", "=", "O", ")"]],
    # small polarized
    "nitrile": [["C", "#", "N"]],
    "nitro": [["N", "(", "=", "O", ")", "O"], ["[N+]", "(", "=", "O", ")", "[O-]"]],
    "azo": [["N", "=", "N"]],
    # fluorinated / halides
    "trifluoromethyl": [["C", "(", "F", ")", "(", "F", ")", "F"]],
    "difluoromethyl": [["C", "(", "F", ")", "F"]],
    "aryl_halide": [["c", "F"], ["c", "Cl"], ["c", "Br"], ["c", "I"]],
    # some rings (best-effort)
    # Ring patterns assume ring-closure digit "1" (and "2" for the fused ring).
    "benzene": [["c", "1", "c", "c", "c", "c", "c", "1"]],
    "pyridine": [["n", "1", "c", "c", "c", "c", "c", "1"]],
    "imidazole": [
        ["c", "1", "n", "[cH]", "n", "c", "1"],
        ["c", "1", "n", "c", "[nH]", "c", "1"],
    ],
    "pyrazole": [
        ["c", "1", "n", "n", "c", "[cH]", "1"],
        ["c", "1", "[nH]", "n", "c", "c", "1"],
    ],
    "indole": [
        ["c", "1", "c", "c", "c", "2", "[nH]", "c", "c", "c", "2", "c", "1"],
        # NOTE(review): bare "n", "H" tokens instead of "[nH]" — presumably for
        # a tokenizer that splits the bracket atom; confirm such tokens exist.
        ["c", "1", "c", "c", "c", "2", "n", "H", "c", "c", "c", "2", "c", "1"],
    ],
    # saturated heterocycles
    "piperidine": [["N", "1", "C", "C", "C", "C", "C", "1"]],
    "piperazine": [["N", "1", "C", "C", "N", "C", "C", "1"]],
    "morpholine": [["O", "1", "C", "C", "N", "C", "C", "1"]],
    "pyrrolidine": [["N", "1", "C", "C", "C", "C", "1"]],
    "azetidine": [["N", "1", "C", "C", "C", "1"]],
    "oxetane": [["O", "1", "C", "C", "C", "1"]],
}


def anchors_for_motif_tokens(name: str) -> List[List[str]]:
    """Return the token-level anchor sequences for the motif called `name`.

    The (case-insensitive) motif name is normalised to a family key of
    ``ANCHOR_TOKEN_PATTERNS``:

    * ``ring:<x>`` / ``sat_ring:<x>`` names use ``<x>`` directly. This is
      checked first so that ring names such as ``ring:imidazole`` or
      ``ring:pyrazole`` are not captured by the ``"azo"`` substring rule.
    * ``sulfonamide`` is tested before the amide rule, because names like
      ``sulfonamide_x`` also contain ``"amide_"``.
    * The remaining families are matched by substring, in a fixed order.

    Returns:
        The list of anchor token sequences, or ``[]`` when the name maps to no
        known family (callers treat that as "no positional gating").
    """
    key = name.lower()

    # Explicit ring namespaces win over every substring heuristic below.
    if key.startswith("ring:") or key.startswith("sat_ring:"):
        return ANCHOR_TOKEN_PATTERNS.get(key.split(":", 1)[1], [])

    if "sulfonamide" in key:
        fam = "sulfonamide"
    elif "amide_" in key or key.endswith("amide_any") or key.endswith(":amide"):
        fam = "amide"
    else:
        # Order matters where one family name contains another
        # ("nitrile" is listed before "nitro" as in the original chain).
        substring_families = (
            "carbamate",
            "urea",
            "imide",
            "anilide",
            "sulfone",
            "sulfoxide",
            "nitrile",
            "nitro",
            "azo",
            "trifluoromethyl",
            "difluoromethyl",
            "aryl_halide",
        )
        fam = next((f for f in substring_families if f in key), None)

    return ANCHOR_TOKEN_PATTERNS.get(fam, [])


def _window_bounds(center: int, L: int, halfwin: int) -> Tuple[int, int]:
    """Return ``(start, stop)`` for the window ``center ± halfwin``.

    ``start`` is floored at 0 and ``stop`` (exclusive) is capped at `L`.
    """
    lo = center - halfwin
    hi = center + halfwin + 1
    start = lo if lo > 0 else 0
    stop = hi if hi < L else L
    return start, stop


def anchor_hits_window(
    tokens: List[str], peak_pos: int, halfwin: int, anchors: List[List[str]]
) -> bool:
    """Return True if any anchor token sequence overlaps the [peak-halfwin, peak+halfwin] window.

    Args:
        tokens: token strings of one sequence.
        peak_pos: index of the peak token in `tokens`.
        halfwin: number of tokens on each side of the peak that form the window.
        anchors: alternative token sequences; a sequence matches when it equals
            a contiguous slice of `tokens`.

    Returns:
        ``True`` when `anchors` is empty (no positional gating possible) or
        when at least one match shares a position with the window, else
        ``False``.
    """
    if not anchors:
        # No reliable anchor → do not gate positionally
        return True
    L = len(tokens)
    a, b = _window_bounds(peak_pos, L, halfwin)
    for pat in anchors:
        m = len(pat)
        if m == 0:
            continue
        # A match starting at j occupies [j, j + m - 1]; it touches the window
        # [a, b - 1] exactly when a - m + 1 <= j <= b - 1. Scanning only those
        # starts gives the same answer as scanning the whole sequence.
        first = max(0, a - m + 1)
        last = min(b - 1, L - m)
        for j in range(first, last + 1):
            if tokens[j : j + m] == pat:
                return True
    return False


# ---------------------- peak position (batched) ------------------------------ #


def _pack_batch(
    tokenizer, token_lists: List[List[str]]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Right-pad token sequences into a rectangular ``(ids, positions)`` pair.

    Args:
        tokenizer: object exposing ``token_to_id``, ``pad_token_id``,
            ``bos_token_id`` and ``eos_token_id``.
        token_lists: one list of token strings per sequence (must be non-empty).

    Returns:
        Two int32 JAX arrays of shape ``(len(token_lists), longest sequence)``:
        the token ids padded with the pad id, and the column index of each
        token with PAD/BOS/EOS slots set to -1.
    """
    encoded = []
    for toks in token_lists:
        encoded.append([tokenizer.token_to_id(t) for t in toks])
    width = max(map(len, encoded))
    pad_id = tokenizer.pad_token_id

    ids = np.full((len(encoded), width), pad_id, dtype=np.int32)
    for r, seq in enumerate(encoded):
        ids[r, : len(seq)] = seq

    # Position layout is meant to mirror tools.data.get_keys() (not visible
    # here — TODO confirm): column index for real tokens, -1 for specials.
    not_pad = ids != pad_id
    not_bos = ids != tokenizer.bos_token_id
    not_eos = ids != tokenizer.eos_token_id
    keep = not_pad & not_bos & not_eos
    columns = np.arange(width, dtype=np.int32)[None, :]
    positions = np.where(keep, columns, -1)

    return jnp.asarray(ids), jnp.asarray(positions)


def peak_positions_for_feature(
    sae_kit: SAEKit,
    layer_id: int,
    feature_id: int,
    token_lists: List[List[str]],
    batch_size: int = 128,
) -> List[int]:
    """Return, per sequence, the token index where one SAE feature is largest.

    Args:
        sae_kit: provides ``tokenizer`` and ``get_encoded``.
        layer_id: layer whose SAE activations are read.
        feature_id: index of the feature on the last activation axis.
        token_lists: tokenized sequences (token strings).
        batch_size: number of sequences encoded per call.

    Returns:
        One argmax column index per input sequence, in input order. Indices
        refer to the padded batch layout produced by ``_pack_batch``.
    """
    peaks: List[int] = []
    n_seqs = len(token_lists)
    for start in range(0, n_seqs, batch_size):
        batch = token_lists[start : start + batch_size]
        ids, positions = _pack_batch(sae_kit.tokenizer, batch)
        # Activations are expected as (B, T, K) with masked slots zeroed by
        # get_encoded — per the original note; not verifiable from this file.
        encoded = sae_kit.get_encoded(ids, positions, layer_id=layer_id)
        per_token = np.asarray(encoded[..., feature_id])  # (B, T)
        peaks += [int(p) for p in np.argmax(per_token, axis=1)]
    return peaks


# ============================== data classes ================================ #


@dataclass
class MotifStat:
    """Enrichment statistics of one SMARTS motif for one feature."""

    name: str  # motif name (key in the SMARTS candidate library)
    smarts: str  # SMARTS pattern of the motif
    hits_pos: int  # positives that match the motif and pass the positional gate
    hits_neg: int  # negatives that match the motif (molecule level)
    pos_size: int  # number of unique positive SMILES
    neg_size: int  # number of negatives present in the substructure library
    coverage: float  # hits_pos / pos_size
    enrichment: float  # positive hit rate divided by negative hit rate (floored at 1e-9)
    odds_ratio: float  # odds ratio from Fisher's exact test (NaN if the test failed)
    p_value: float  # one-sided ("greater") Fisher p-value
    fdr_significant: bool  # Benjamini-Hochberg decision across tested motifs


@dataclass
class FeatureAlignment:
    """Motif alignment result for a single SAE feature."""

    layer: int  # layer id the feature belongs to
    feature_id: int  # feature index within the layer's SAE
    topk_smiles: List[str]  # de-duplicated positive SMILES used for mining
    negatives: List[str]  # matched negative SMILES
    selected_motifs: List[MotifStat]  # motifs chosen by the greedy set cover
    coverage_reached: float  # fraction of positives covered by selected_motifs
    examples: List[str]  # positives covered by at least one selected motif
    counterexamples: List[str]  # positives covered by none of the selected motifs


# ======================= positional enrichment core ========================= #


def motif_enrichment_pce(
    positives: List[Tuple[str, List[str], int]],  # (smiles, tokens, peak_pos)
    negatives: List[str],
    candidate_smarts: Dict[str, str],
    *,
    half_window_tokens: int = 4,
    fdr_alpha: float = 0.05,
    min_support: int = 25,
) -> List[MotifStat]:
    """Compute enrichment with positional gating: a positive counts only if a token-anchor
    for the motif overlaps the peak window; negatives counted at molecule-level (conservative).

    Args:
        positives: ``(smiles, tokens, peak_pos)`` triples for top-activating
            sequences.
        negatives: SMILES of the matched negative set.
        candidate_smarts: motif name -> SMARTS pattern.
        half_window_tokens: half-width of the token window around the peak.
        fdr_alpha: Benjamini-Hochberg level.
        min_support: motifs with fewer gated positive hits are not tested.

    Returns:
        FDR-significant motifs, sorted by ``coverage * log2(enrichment + 1)``
        in descending order. Empty when no motif reaches `min_support`.
    """
    pos_unique = list(dict.fromkeys(s for (s, _, _) in positives))
    union = list(dict.fromkeys(pos_unique + negatives))
    lib, idx2smi = build_substruct_library(union)
    smi2idx = {s: i for i, s in idx2smi.items()}

    pos_idx = {smi2idx[s] for s in pos_unique if s in smi2idx}
    # A molecule must not sit in both rows of the contingency table.
    neg_idx = {smi2idx[s] for s in negatives if s in smi2idx} - pos_idx

    # Resolve library indices once instead of once per motif.
    indexed_pos = [
        (smi2idx[smi], toks, peak) for smi, toks, peak in positives if smi in smi2idx
    ]

    pos_size = len(pos_unique)
    neg_size = len(neg_idx)
    # RDKit's GetMatches stops after 1000 results by default, which would
    # silently truncate hits for common motifs on larger libraries.
    max_results = max(len(union), 1)

    stats: List[MotifStat] = []
    pvals: List[float] = []

    for name, sm in tqdm(candidate_smarts.items(), desc="  motifs (PCE)", leave=False):
        q = Chem.MolFromSmarts(sm)
        if q is None:
            continue

        anchors = anchors_for_motif_tokens(name)
        global_hits = set(lib.GetMatches(q, maxResults=max_results))

        # Count positional hits among *positives*, once per molecule, so that
        # repeated entries for one SMILES cannot push hits_pos above pos_size.
        hit_ids = {
            idx
            for idx, toks, peak in indexed_pos
            if idx in global_hits
            and anchor_hits_window(toks, peak, half_window_tokens, anchors)
        }
        hits_pos = len(hit_ids)

        if hits_pos < min_support:
            continue

        hits_neg = len(global_hits & neg_idx)
        miss_pos = pos_size - hits_pos
        miss_neg = neg_size - hits_neg

        try:
            odds, p = fisher_exact(
                [[hits_pos, miss_pos], [hits_neg, miss_neg]], alternative="greater"
            )
        except ValueError:
            # Invalid table: keep the motif but make it non-significant.
            odds, p = (math.nan, 1.0)

        cov = hits_pos / max(pos_size, 1)
        # The 1e-9 floors keep the ratio finite when a rate is zero.
        enr = (hits_pos / max(pos_size, 1e-9)) / max(
            hits_neg / max(neg_size, 1e-9), 1e-9
        )

        stats.append(
            MotifStat(
                name=name,
                smarts=sm,
                hits_pos=hits_pos,
                hits_neg=hits_neg,
                pos_size=pos_size,
                neg_size=neg_size,
                coverage=cov,
                enrichment=enr,
                odds_ratio=float(odds),
                p_value=float(p),
                fdr_significant=False,
            )
        )
        pvals.append(float(p))

    if not stats:
        return []

    sig = bh_fdr(pvals, alpha=fdr_alpha)
    for st, ok in zip(stats, sig):
        st.fdr_significant = bool(ok)

    def score(s: MotifStat) -> float:
        return s.coverage * math.log2(s.enrichment + 1.0)

    return sorted([s for s in stats if s.fdr_significant], key=score, reverse=True)


def greedy_set_cover(
    positives_smiles: List[str],
    motif_stats: List[MotifStat],
    min_coverage: float = 0.8,
    max_motifs: int = 3,
) -> Tuple[List[MotifStat], float, List[str], List[str]]:
    """Pick minimal subset of motifs covering ≥ min_coverage of positives.

    Motifs are added greedily by the number of not-yet-covered positives they
    match (molecule level). Ties go to the motif that comes first in
    `motif_stats`. Selection stops when `min_coverage` is reached, when
    `max_motifs` motifs are chosen, or when no motif adds coverage.

    Returns:
        ``(chosen, coverage, covered_smiles, uncovered_smiles)`` where
        ``coverage`` is the covered fraction of the indexable positives.
    """
    if not motif_stats or not positives_smiles:
        return [], 0.0, [], positives_smiles

    lib, idx2smi = build_substruct_library(positives_smiles)
    pos_idx_set = set(idx2smi.keys())
    pos_size = len(pos_idx_set)
    if pos_size == 0:
        # Nothing could be indexed: avoid dividing by zero below.
        return [], 0.0, [], list(positives_smiles)

    # RDKit's GetMatches returns at most 1000 hits by default; lift the cap so
    # coverage is not understated on larger positive sets.
    max_results = max(len(positives_smiles), 1)

    candidates: Dict[int, set] = {}
    for i, m in enumerate(motif_stats):
        q = Chem.MolFromSmarts(m.smarts)
        if q is None:
            continue
        candidates[i] = set(lib.GetMatches(q, maxResults=max_results))

    covered: set = set()
    chosen: List[MotifStat] = []

    while (
        candidates
        and len(covered) / pos_size < min_coverage
        and len(chosen) < max_motifs
    ):
        # max() keeps the first maximal key, i.e. the earliest motif on ties.
        best_i = max(candidates, key=lambda i: len(candidates[i] - covered))
        if not candidates[best_i] - covered:
            break
        chosen.append(motif_stats[best_i])
        covered |= candidates.pop(best_i)

    coverage = len(covered) / pos_size
    examples = [idx2smi[i] for i in sorted(covered)]
    missing = [idx2smi[i] for i in sorted(pos_idx_set - covered)]
    return chosen, coverage, examples, missing


# ============================== I/O helpers ================================= #


def save_feature_bank(
    out_dir: str, layer_id: int, alignments: List[FeatureAlignment]
) -> None:
    """Write the feature bank for one layer as JSON and CSV.

    Creates `out_dir` if needed and writes
    ``feature_bank_pce_layer{layer_id}.json`` (one entry per alignment, with
    nested motif records) and ``feature_bank_pce_layer{layer_id}.csv`` (one row
    per selected motif, or a single ``N/A`` row for an alignment without
    motifs).
    """
    os.makedirs(out_dir, exist_ok=True)

    json_entries = []
    csv_rows = []
    for al in alignments:
        motif_records = []
        for ms in al.selected_motifs:
            motif_records.append(
                {
                    "name": ms.name,
                    "smarts": ms.smarts,
                    "coverage": ms.coverage,
                    "enrichment": ms.enrichment,
                    "odds_ratio": ms.odds_ratio,
                    "p_value": ms.p_value,
                    "fdr_significant": ms.fdr_significant,
                    "hits_pos": ms.hits_pos,
                    "hits_neg": ms.hits_neg,
                    "pos_size": ms.pos_size,
                    "neg_size": ms.neg_size,
                }
            )
        json_entries.append(
            {
                "layer": al.layer,
                "feature_id": al.feature_id,
                "coverage_reached": al.coverage_reached,
                "examples": al.examples,
                "counterexamples": al.counterexamples,
                "motifs": motif_records,
            }
        )

        if not al.selected_motifs:
            # Placeholder row so the feature still appears in the CSV.
            csv_rows.append(
                {
                    "layer": al.layer,
                    "feature_id": al.feature_id,
                    "motif_name": "N/A",
                    "smarts": "",
                    "coverage": 0.0,
                    "enrichment": 0.0,
                    "odds_ratio": 0.0,
                    "p_value": 1.0,
                    "fdr_significant": False,
                    "coverage_reached": al.coverage_reached,
                }
            )
            continue

        csv_rows.extend(
            {
                "layer": al.layer,
                "feature_id": al.feature_id,
                "motif_name": ms.name,
                "smarts": ms.smarts,
                "coverage": ms.coverage,
                "enrichment": ms.enrichment,
                "odds_ratio": ms.odds_ratio,
                "p_value": ms.p_value,
                "fdr_significant": ms.fdr_significant,
                "coverage_reached": al.coverage_reached,
            }
            for ms in al.selected_motifs
        )

    json_path = os.path.join(out_dir, f"feature_bank_pce_layer{layer_id}.json")
    with open(json_path, "w") as fp:
        json.dump(json_entries, fp, indent=2)

    csv_path = os.path.join(out_dir, f"feature_bank_pce_layer{layer_id}.csv")
    pd.DataFrame(csv_rows).to_csv(csv_path, index=False)


# ================================ main ====================================== #


def main():
    """Command-line entry point for the PCE miner.

    Loads the LM and SAEs, scans the dataset for top-activating sequences on
    one layer, selects features, and for each feature runs peak-position
    detection, matched-negative sampling, position-conditioned enrichment and
    greedy motif selection. Features with at least one selected motif are
    written to JSON/CSV in ``--out_dir``.
    """
    # Passed as ``description=``: the first positional parameter of
    # ArgumentParser is ``prog``, which would put this text in the usage line.
    ap = argparse.ArgumentParser(
        description="Position‑Conditioned Feature‑bank mining (token‑local enrichment)"
    )
    ap.add_argument(
        "--model_dir", required=True, help="LM directory (tokenizer.json + checkpoints)"
    )
    ap.add_argument(
        "--checkpoint_id", required=True, help="LM checkpoint id (number or 'final')"
    )
    ap.add_argument(
        "--sae_dir",
        required=True,
        help="Directory with per-layer SAE checkpoints/configs",
    )
    ap.add_argument(
        "--dataset_dir", required=True, help="HF datasets 'load_from_disk' directory"
    )
    ap.add_argument(
        "--layer", type=int, required=True, help="Layer id for SAE features"
    )
    ap.add_argument(
        "--num_batches",
        type=int,
        default=200,
        help="Batches to scan (batch size: 1024)",
    )
    ap.add_argument(
        "--top_sequences",
        type=int,
        default=1024,
        help="Top sequences kept per feature in collector",
    )
    ap.add_argument(
        "--feature_metric",
        type=str,
        default="gini",
        choices=["mean", "max", "sparsity", "mean*max", "selectivity", "gini"],
        help="Metric for selecting features to align",
    )
    ap.add_argument(
        "--topk_features",
        type=int,
        default=256,
        help="How many features to align after ranking",
    )
    ap.add_argument(
        "--top_pos",
        type=int,
        default=3000,
        help="Positives per feature (top activating sequences)",
    )
    ap.add_argument(
        "--neg_per_pos", type=int, default=3, help="Matched negatives per positive"
    )
    ap.add_argument("--fdr_alpha", type=float, default=0.05, help="BH‑FDR alpha")
    ap.add_argument(
        "--min_support",
        type=int,
        default=25,
        help="Min positive hits per motif before FDR",
    )
    ap.add_argument(
        "--window_tokens",
        type=int,
        default=4,
        help="Half window size around peak token",
    )
    ap.add_argument(
        "--batch_peak",
        type=int,
        default=128,
        help="Batch size when computing peak positions",
    )
    ap.add_argument(
        "--min_coverage",
        type=float,
        default=0.8,
        help="Target fraction of positives covered by the selected motifs",
    )
    ap.add_argument(
        "--max_motifs",
        type=int,
        default=3,
        help="Maximum number of motifs selected per feature",
    )
    ap.add_argument("--out_dir", required=True, help="Where to save outputs (JSON/CSV)")
    args = ap.parse_args()

    # Dataset batch size used by the scan; --num_batches is counted in these.
    scan_batch_size = 1024

    # --- Load LM + SAEs ---
    print("Loading model and SAE kit...")
    sae_kit = SAEKit.load(
        model_dir=args.model_dir, checkpoint_id=args.checkpoint_id, sae_dir=args.sae_dir
    )
    tokenizer = sae_kit.tokenizer

    # --- Dataset (for background and collector) ---
    print("Loading and tokenizing dataset...")
    ds = load_and_tokenize(
        dataset_dir=args.dataset_dir,
        tokenizer=tokenizer,
        batch_size=scan_batch_size,
        num_processes=4,
        caching=False,
        # 10% headroom over the number of sequences the scan will consume.
        limit=int(args.num_batches * scan_batch_size * 1.1),
    )

    # --- Scan layer: collector + background SMILES ---
    print(f"Scanning layer {args.layer} to collect top activating SMILES...")
    col, background_smiles = collect_top_smiles_for_layer(
        ds,
        sae_kit,
        layer_id=args.layer,
        top_sequences=args.top_sequences,
        num_batches=args.num_batches,
    )

    # --- Select features ---
    print(f"Selecting top {args.topk_features} features by '{args.feature_metric}'...")
    feature_ids = pick_features(
        col, metric=args.feature_metric, topk=args.topk_features
    )

    # --- Candidates (profile-aware library) ---
    candidates = get_smarts_library()  # respects YAML + profile via env, if set
    print(f"SMARTS candidates: {len(candidates)} motifs")

    alignments: List[FeatureAlignment] = []

    # --- Per-feature alignment (PCE) ---
    for fid in tqdm(feature_ids, desc=f"PCE-align L{args.layer} features"):
        # 1) Gather top token sequences for this feature
        raw_entries = col.top_sequences_for(fid)[
            : args.top_pos
        ]  # (best, count, tokens)
        # dedup by canonical SMILES while keeping tokens
        seen = set()
        pos_tokens: List[List[str]] = []
        pos_smiles: List[str] = []
        for _best, _count, toks in raw_entries:
            s = clean_smiles("".join(toks))
            if not s or s in seen:
                continue
            seen.add(s)
            pos_tokens.append(list(toks))
            pos_smiles.append(s)

        if not pos_smiles:
            continue

        # 2) Compute peak positions (batched) for this feature
        peaks = peak_positions_for_feature(
            sae_kit,
            layer_id=args.layer,
            feature_id=fid,
            token_lists=pos_tokens,
            batch_size=args.batch_peak,
        )

        positives = list(zip(pos_smiles, pos_tokens, peaks))

        # 3) Matched negatives (scaffold + props)
        negatives = matched_negatives_by_scaffold_and_props(
            pos_smiles, background_smiles, per_pos=args.neg_per_pos
        )

        # 4) Positional enrichment
        stats = motif_enrichment_pce(
            positives,
            negatives,
            candidates,
            half_window_tokens=args.window_tokens,
            fdr_alpha=args.fdr_alpha,
            min_support=args.min_support,
        )

        # 5) Greedy set-cover selection
        selected, cov, examples, missing = greedy_set_cover(
            pos_smiles,
            stats,
            min_coverage=args.min_coverage,
            max_motifs=args.max_motifs,
        )

        # Only features with at least one selected motif are kept.
        if selected:
            alignments.append(
                FeatureAlignment(
                    layer=args.layer,
                    feature_id=int(fid),
                    topk_smiles=pos_smiles,
                    negatives=negatives,
                    selected_motifs=selected,
                    coverage_reached=cov,
                    examples=examples[:32],
                    counterexamples=missing[:32],
                )
            )

    # --- Save
    print(
        f"Found alignments for {len(alignments)} / {len(feature_ids)} features. Saving…"
    )
    os.makedirs(args.out_dir, exist_ok=True)
    save_feature_bank(args.out_dir, args.layer, alignments)
    print("Done.")


# Run the miner only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
