# lmkit/feature_bank/experiments/smarts_stats.py
# -----------------------------------------------------------------------------
# Build a searchable SMARTS↔SMILES index from a HF dataset:
#   - SMARTS → list of SMILES indices (edges.parquet; also counts in smarts.parquet)
#   - SMILES → list of SMARTS (via edges.parquet join/groupby)
#
# Outputs:
#   out_dir/
#     ├─ smiles.parquet      (smiles_idx:int, smiles:str, murcko:str)
#     ├─ smarts.parquet      (smarts_id:str, smarts:str, n_hits:int)
#     ├─ edges.parquet       (smarts_id:str, smiles_idx:int)
#     ├─ examples.json       ({smarts_id: [sample_smiles, ...]})
#     └─ meta.json           (config + quick stats)
#
# Notes:
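# - Uses your curated SMARTS library, selected via env (MOLSAE_SMARTS_YAML /
#   MOLSAE_SMARTS_PROFILE) or the --smarts_yaml / --smarts_profile flags.
# - All tables are built in memory and written to Parquet at the end; for *very*
#   large runs, consider chunked writes.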
# -----------------------------------------------------------------------------

from __future__ import annotations
import argparse, json, os, math, time
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import rdSubstructLibrary as SSL

# Reuse your curated catalog reader (profile-aware)
from .smarts_lib import get_smarts_library
from ..tools.stem import clean_smiles

# ----------------------------- helpers ------------------------------------- #


def murcko(smi: str) -> Optional[str]:
    """Return the Murcko scaffold of *smi* as a non-isomeric SMILES string.

    Best-effort: yields None when the SMILES does not parse, when no scaffold
    object is produced, or when any RDKit call raises.
    """
    try:
        parsed = Chem.MolFromSmiles(smi)
        scaffold = MurckoScaffold.GetScaffoldForMol(parsed) if parsed else None
        return Chem.MolToSmiles(scaffold, isomericSmiles=False) if scaffold else None
    except Exception:
        return None


def build_substruct_library(
    smiles_list: List[str],
) -> Tuple[SSL.SubstructLibrary, Dict[int, str]]:
    """Index *smiles_list* in an RDKit SubstructLibrary for fast SMARTS screening.

    Each SMILES is parsed *before* it is added anywhere; entries RDKit cannot
    parse are skipped without touching either holder. This keeps the molecule
    holder and the pattern (screening) holder index-aligned — adding to the mol
    holder first and then failing on the pattern holder would desynchronize them.

    Returns:
        (library, idx2smi) where ``idx2smi`` maps the index returned by
        ``holder.AddSmiles`` to the SMILES string that was added. These indices
        match positions in *smiles_list* only if every entry parsed
        (NOTE(review): presumably sequential from 0 — confirm against RDKit).
    """
    holder = SSL.CachedTrustedSmilesMolHolder()
    pholder = SSL.PatternHolder()
    idx2smi: Dict[int, str] = {}
    for s in smiles_list:
        try:
            mol = Chem.MolFromSmiles(s)
        except Exception:
            # Non-string / pathological input: treat like an unparsable SMILES.
            mol = None
        if mol is None:
            continue
        # Only reached with a valid mol, so both holders grow together.
        idx = holder.AddSmiles(s)
        pholder.AddMol(mol)
        idx2smi[idx] = s
    return SSL.SubstructLibrary(holder, pholder), idx2smi


def sample_examples_for_smarts(
    smarts_id: str,
    smarts: str,
    smiles_indices: List[int],
    idx2smi: Dict[int, str],
    max_examples: int = 256,
    seed: int = 2025,
) -> List[str]:
    """Pick up to max_examples examples with (i) exactly one occurrence of the motif
    and (ii) diverse Murcko scaffolds."""
    rng = np.random.default_rng(seed)
    shuffled = list(smiles_indices)
    rng.shuffle(shuffled)

    q = Chem.MolFromSmarts(smarts)
    if q is None:
        return []

    seen_scaffolds = set()
    out = []
    for idx in shuffled:
        smi = idx2smi.get(idx)
        if not smi:
            continue
        m = Chem.MolFromSmiles(smi)
        if m is None:
            continue

        # Single-occurrence filter keeps examples crisp for downstream inspection
        try:
            matches = m.GetSubstructMatches(q)
        except Exception:
            continue
        if len(matches) != 1:
            continue

        sc = murcko(smi)
        if sc and sc in seen_scaffolds:
            continue

        out.append(smi)
        if sc:
            seen_scaffolds.add(sc)
        if len(out) >= max_examples:
            break

    # If we couldn't find enough single-occ examples, relax to any hit
    if len(out) == 0:
        for idx in smiles_indices[:max_examples]:
            smi = idx2smi.get(idx)
            if smi:
                out.append(smi)
            if len(out) >= max_examples:
                break
    return out


# ----------------------------- data loading -------------------------------- #


def iter_smiles(dataset_dir: str, smiles_col: str, limit: Optional[int] = None):
    """Stream SMILES from a HuggingFace 'load_from_disk' dataset."""
    from datasets import load_from_disk

    ds = load_from_disk(dataset_dir)
    n = len(ds)
    count = 0
    # Iterate directly (Arrow batches under the hood). No tokenization needed.
    for ex in ds:
        s = ex.get(smiles_col)
        if s:
            yield s
            count += 1
            if limit is not None and count >= limit:
                return


def collect_unique_smiles(
    dataset_dir: str, smiles_col: str, limit: Optional[int] = None, dedup: bool = True
) -> List[str]:
    """Stream SMILES from the dataset and normalize each one with ``clean_smiles``.

    Values that ``clean_smiles`` turns into something falsy are dropped. With
    ``dedup=True`` only the first occurrence of each cleaned string survives.
    The result preserves first-seen order either way.
    """
    stream = tqdm(
        iter_smiles(dataset_dir, smiles_col, limit), desc="Collecting SMILES"
    )
    # Lazy pipeline: clean_smiles runs once per streamed item, in order.
    valid = (cleaned for cleaned in map(clean_smiles, stream) if cleaned)
    # dict preserves insertion order, so fromkeys == order-stable dedup.
    return list(dict.fromkeys(valid)) if dedup else list(valid)


# --------------------------------- main ------------------------------------ #


def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    ap = argparse.ArgumentParser("Build SMARTS↔SMILES search index")
    ap.add_argument(
        "--dataset_dir",
        required=True,
        help="HF datasets 'load_from_disk' directory (e.g., .../valid)",
    )
    ap.add_argument(
        "--smiles_col",
        default="smiles",
        help="Column name for SMILES (default: 'smiles')",
    )
    ap.add_argument("--out_dir", required=True, help="Output directory for index files")
    ap.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Max molecules to process (for dry runs)",
    )
    # --dedup is kept for backward compatibility (it is already the default);
    # --no_dedup shares the same dest so deduplication can actually be disabled.
    ap.add_argument(
        "--dedup",
        dest="dedup",
        action="store_true",
        default=True,
        help="Deduplicate canonical SMILES (default: True)",
    )
    ap.add_argument(
        "--no_dedup",
        dest="dedup",
        action="store_false",
        help="Keep duplicate SMILES (disables --dedup)",
    )
    ap.add_argument(
        "--examples_per_smarts",
        type=int,
        default=128,
        help="Examples per SMARTS in examples.json",
    )
    # Optional: override SMARTS catalog via env at runtime
    ap.add_argument(
        "--smarts_yaml", type=str, default=None, help="Override MOLSAE_SMARTS_YAML path"
    )
    ap.add_argument(
        "--smarts_profile",
        type=str,
        default=None,
        help="Override MOLSAE_SMARTS_PROFILE (e.g., leadlike)",
    )
    return ap.parse_args()


def _scan_catalog(
    lib: SSL.SubstructLibrary,
    idx2smi: Dict[int, str],
    catalog: Dict[str, str],
    examples_per_smarts: int,
) -> Tuple[List[dict], List[str], List[int], Dict[str, List[str]], int]:
    """Match every catalog SMARTS against *lib*.

    Returns ``(smarts_rows, edge_smarts_ids, edge_smiles_idx, examples, n_invalid)``:
    one row per parsable SMARTS; the two edge lists are parallel columns of the
    bipartite edge table; ``n_invalid`` counts SMARTS RDKit could not parse
    (those get no row and no edges).
    """
    # The Python GetMatches caps results at maxResults (default 1000). Ask for
    # up to the whole library so n_hits and the edge table are not truncated.
    max_results = max(len(lib), 1)

    smarts_rows: List[dict] = []
    edge_ids: List[str] = []
    edge_idx: List[int] = []
    examples: Dict[str, List[str]] = {}
    n_invalid = 0

    for smarts_id, smarts in tqdm(catalog.items(), total=len(catalog), desc="SMARTS"):
        try:
            q = Chem.MolFromSmarts(smarts)
        except Exception:
            q = None
        if q is None:
            n_invalid += 1
            continue

        try:
            idx_hits = [int(i) for i in lib.GetMatches(q, maxResults=max_results)]
        except Exception:
            idx_hits = []

        smarts_rows.append(
            {"smarts_id": smarts_id, "smarts": smarts, "n_hits": len(idx_hits)}
        )
        if not idx_hits:
            continue

        # Edges stored column-wise: far lighter than one dict per edge.
        edge_ids.extend([smarts_id] * len(idx_hits))
        edge_idx.extend(idx_hits)

        # Examples: scaffold-diverse, single-occur where possible
        sampled = sample_examples_for_smarts(
            smarts_id,
            smarts,
            idx_hits,
            idx2smi,
            max_examples=examples_per_smarts,
        )
        if sampled:
            examples[smarts_id] = sampled

    return smarts_rows, edge_ids, edge_idx, examples, n_invalid


def _write_outputs(
    out_dir: str,
    idx2smi: Dict[int, str],
    smarts_rows: List[dict],
    edge_ids: List[str],
    edge_idx: List[int],
    examples: Dict[str, List[str]],
    meta: dict,
) -> None:
    """Write smiles/smarts/edges Parquet tables plus examples.json and meta.json."""
    print("Saving Parquet tables…")

    # smiles table: iterate the indices the library actually assigned, so a
    # SMILES that failed to index cannot cause a KeyError or a shifted row.
    smiles_records = [
        {"smiles_idx": idx, "smiles": s, "murcko": murcko(s) or ""}
        for idx, s in sorted(idx2smi.items())
    ]
    df_smiles = pd.DataFrame(smiles_records, columns=["smiles_idx", "smiles", "murcko"])
    df_smiles.to_parquet(os.path.join(out_dir, "smiles.parquet"), index=False)

    # smarts table: explicit columns so an empty catalog still sorts/writes.
    df_smarts = pd.DataFrame(smarts_rows, columns=["smarts_id", "smarts", "n_hits"])
    df_smarts.sort_values("smarts_id", inplace=True)
    df_smarts.to_parquet(os.path.join(out_dir, "smarts.parquet"), index=False)

    # edges table (bipartite); same schema whether or not it is empty
    df_edges = pd.DataFrame(
        {
            "smarts_id": pd.Series(edge_ids, dtype=str),
            "smiles_idx": pd.Series(edge_idx, dtype=np.int64),
        }
    )
    df_edges.to_parquet(os.path.join(out_dir, "edges.parquet"), index=False)

    with open(os.path.join(out_dir, "examples.json"), "w", encoding="utf-8") as f:
        json.dump(examples, f, indent=2)

    with open(os.path.join(out_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)


def main():
    """CLI entry point.

    Collects SMILES from the dataset, indexes them in an RDKit SubstructLibrary,
    matches every catalog SMARTS against it, and writes the Parquet/JSON index
    files listed in the module header to ``--out_dir``.
    """
    args = _parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Env overrides for catalog, if provided. Set before get_smarts_library()
    # runs (presumably it reads these variables — confirm in smarts_lib).
    if args.smarts_yaml:
        os.environ["MOLSAE_SMARTS_YAML"] = args.smarts_yaml
    if args.smarts_profile:
        os.environ["MOLSAE_SMARTS_PROFILE"] = args.smarts_profile

    t0 = time.time()

    # 1) Collect and clean SMILES
    smiles_list = collect_unique_smiles(
        dataset_dir=args.dataset_dir,
        smiles_col=args.smiles_col,
        limit=args.limit,
        dedup=args.dedup,
    )
    if not smiles_list:
        raise SystemExit("No valid SMILES collected. Check dataset/column.")
    print(f"Collected {len(smiles_list):,} {'unique ' if args.dedup else ''}SMILES.")

    # 2) Build SubstructLibrary for fast matching
    print("Building RDKit SubstructLibrary…")
    lib, idx2smi = build_substruct_library(smiles_list)
    print(f"Library built with {len(idx2smi):,} molecules.")

    # 3) Load curated SMARTS catalog (profile-aware)
    catalog: Dict[str, str] = get_smarts_library()
    print(f"Loaded {len(catalog)} SMARTS patterns from catalog.")

    # 4) Match all SMARTS → indices
    print("Scanning SMARTS across library…")
    smarts_rows, edge_ids, edge_idx, examples, n_invalid = _scan_catalog(
        lib, idx2smi, catalog, args.examples_per_smarts
    )
    if n_invalid:
        print(f"Skipped {n_invalid} SMARTS that RDKit could not parse.")

    # 5) Save tables
    meta = {
        "dataset_dir": args.dataset_dir,
        "smiles_col": args.smiles_col,
        "limit": args.limit,
        "dedup": bool(args.dedup),
        "examples_per_smarts": args.examples_per_smarts,
        "n_smiles": int(len(smiles_list)),
        "n_indexed": int(len(idx2smi)),
        "n_smarts": int(len(catalog)),
        "n_smarts_invalid": int(n_invalid),
        "n_edges": int(len(edge_idx)),
        "build_seconds": round(time.time() - t0, 2),
        "smarts_yaml": os.environ.get("MOLSAE_SMARTS_YAML", None),
        "smarts_profile": os.environ.get("MOLSAE_SMARTS_PROFILE", None),
    }
    _write_outputs(
        args.out_dir, idx2smi, smarts_rows, edge_ids, edge_idx, examples, meta
    )

    print(f"Done. Wrote index to: {os.path.abspath(args.out_dir)}")


if __name__ == "__main__":
    # Script entry point. The module uses relative imports (from .smarts_lib),
    # so it is meant to be run as a module (python -m ...), not as a bare file.
    main()
