# lmkit/feature_bank/smarts_lib.py
from __future__ import annotations
import os, json
from typing import Dict, Tuple, List
from rdkit import Chem, RDConfig, RDLogger
from rdkit.Chem import FilterCatalog
import yaml

# Import-time side effect: disable RDKit's "rdApp.warning" log channel.
# RDLogger state is global to RDKit, not scoped to this module, so this also
# silences warnings from RDKit calls made elsewhere in the process. Only the
# warning channel is disabled here; other channels (e.g. errors) are untouched.
RDLogger.DisableLog("rdApp.warning")


# ------------------------- helpers ------------------------- #
def _mol_from_smarts(s: str):
    """Parse *s* as a SMARTS query molecule.

    Returns the RDKit query Mol, or None when RDKit rejects the input, either by
    returning None itself or by raising (e.g. for a non-string argument).
    """
    query = None
    try:
        query = Chem.MolFromSmarts(s)
    except Exception:
        # Treat any parser exception exactly like an unparsable pattern.
        pass
    return query


def _canon(s: str) -> str | None:
    """Return RDKit's re-serialised SMARTS for *s*, or None if it is unusable.

    The pattern is parsed with ``_mol_from_smarts`` and written back with
    ``Chem.MolToSmarts(..., isomericSmiles=False)``. This normalises syntax so
    identical patterns compare equal as strings.

    NOTE(review): MolToSmarts output is not guaranteed canonical across
    different atom orderings. Equivalent but differently written patterns may
    still yield different strings.
    """
    query = _mol_from_smarts(s)
    if query is not None:
        try:
            return Chem.MolToSmarts(query, isomericSmiles=False)
        except Exception:
            pass
    return None


def _validize(d: Dict[str, str]) -> Dict[str, str]:
    out = {}
    for k, v in d.items():
        can = _canon(v)
        if can:
            out[k] = can
    return out


def _dedup(d: Dict[str, str]) -> Dict[str, str]:
    seen, out = set(), {}
    for k, v in d.items():
        if v in seen:
            continue
        seen.add(v)
        out[k] = v
    return out


# ---------------------- RDKit sources ---------------------- #
def _load_rdkit_fgh() -> Dict[str, str]:
    path = os.path.join(RDConfig.RDDataDir, "FunctionalGroups.txt")
    out = {}
    if not os.path.exists(path):
        return out
    with open(path, "r", encoding="utf-8") as fp:
        for line in fp:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if "#" in line:
                line = line.split("#", 1)[0].strip()
            parts = line.split(None, 1)
            if len(parts) != 2:
                continue
            name, smarts = parts
            key = f"rdkit:{name.lower()}"
            can = _canon(smarts)
            if can:
                out[key] = can
    return out


def _load_rdkit_alerts(sets=("PAINS", "BRENK", "NIH")) -> Dict[str, str]:
    """Collect structural-alert SMARTS from RDKit's built-in FilterCatalogs.

    Args:
        sets: names of ``FilterCatalogParams.FilterCatalogs`` members to load.
            Unknown names are skipped instead of discarding every catalog.

    Returns:
        ``{"alert:<description lower-cased, spaces -> '_'>": SMARTS}``.
        Entries whose SMARTS cannot be read or normalised are skipped one by
        one. ``{}`` is returned only if the catalog itself cannot be built.
        Entries sharing a description overwrite earlier ones.
    """
    try:
        known = FilterCatalog.FilterCatalogParams.FilterCatalogs
        params = FilterCatalog.FilterCatalogParams()
        for s in sets:
            attr = getattr(known, s, None)
            if attr is None:
                continue  # unknown set name, e.g. a typo in the YAML config
            params.AddCatalog(attr)
        cat = FilterCatalog.FilterCatalog(params)
        n_entries = cat.GetNumEntries()
    except Exception:
        return {}  # best-effort: alerts are optional

    alerts: Dict[str, str] = {}
    for i in range(n_entries):
        try:
            e = cat.GetEntry(i)
            name = e.GetDescription() or f"alert_{i}"
            # NOTE(review): confirm FilterCatalogEntry exposes GetSmarts() in
            # the pinned RDKit version. If it does not, every entry lands in
            # the except below and the result is empty.
            sm = e.GetSmarts()
        except Exception:
            continue  # skip just this entry, keep the rest
        can = _canon(sm)
        if can:
            alerts[f"alert:{name.strip().lower().replace(' ', '_')}"] = can
    return alerts


# ----------------------- YAML loading ---------------------- #
def _smarts_from_seeds(entries: List[dict]) -> Dict[str, str]:
    out = {}
    for item in entries:
        _id = item["id"]
        smi = item["smiles"]
        m = Chem.MolFromSmiles(smi)
        if m is None:
            continue
        smarts = Chem.MolToSmarts(m, isomericSmiles=False)
        can = _canon(smarts)
        if can:
            out[_id] = can
    return out


def _build_from_yaml(path: str) -> Tuple[Dict[str, str], Dict[str, str], dict]:
    """Build ``(lib, alerts, info)`` from a YAML config file.

    Recognised top-level keys (all optional):

    - ``precedence``: source names in priority order.
    - ``curated``: list of ``{id, smarts}``.
    - ``ring_seeds``: list of ``{id, smiles}``.
    - ``sources.rdkit_alerts.sets``: FilterCatalog names.
    - ``profiles`` and ``meta``.

    A key that is present but null (e.g. a bare ``meta:`` line) is treated
    like a missing key, so an empty section no longer invalidates the file.

    Returns:
        ``(lib, alerts, {"profiles": ..., "meta": ...})``. ``lib`` maps feature
        ids to normalised SMARTS; earlier sources in ``precedence`` win on both
        id and identical SMARTS.

    Raises:
        OSError, yaml.YAMLError or KeyError for unreadable files, malformed
        YAML, or entries missing ``id``/``smarts``/``smiles``.
    """
    with open(path, "r", encoding="utf-8") as fp:
        # safe_load never instantiates arbitrary Python objects from the file.
        cfg = yaml.safe_load(fp) or {}  # an empty file parses to None

    def section(key, default):
        # Missing or null -> default; explicit empty values are respected.
        value = cfg.get(key)
        return default if value is None else value

    precedence: List[str] = section(
        "precedence", ["curated", "rdkit_fgh", "ring_seeds"]
    )
    curated = _validize(
        {item["id"]: item["smarts"] for item in section("curated", [])}
    )
    rings = _smarts_from_seeds(section("ring_seeds", []))
    # Only read RDKit's data file when it can actually contribute.
    rdkit_fgh = _load_rdkit_fgh() if "rdkit_fgh" in precedence else {}

    sources = section("sources", {})
    if "rdkit_alerts" in sources:
        alert_cfg = sources["rdkit_alerts"] or {}
        alerts = _load_rdkit_alerts(tuple(alert_cfg.get("sets") or ()))
    else:
        alerts = {}

    # Precedence merge: first source wins on id; a SMARTS string already taken
    # by an earlier entry is skipped (set lookup instead of scanning values).
    by_name = {"curated": curated, "rdkit_fgh": rdkit_fgh, "ring_seeds": rings}
    lib: Dict[str, str] = {}
    taken_smarts = set()
    for name in precedence:
        for k, v in by_name.get(name, {}).items():
            if v in taken_smarts or k in lib:
                continue
            lib[k] = v
            taken_smarts.add(v)

    lib = _dedup(_validize(lib))

    return lib, alerts, {
        "profiles": section("profiles", {}),
        "meta": section("meta", {}),
    }


# ------------------------ Fallback build -------------------- #
# (Used when MOLSAE_SMARTS_YAML is unset, points to a missing file, or fails to
# load. Mirrors the previously shipped built-in library.)
def _fallback_build() -> Tuple[Dict[str, str], Dict[str, str], dict]:
    """Build the built-in ``(lib, alerts, info)`` triple without any YAML.

    Sources are merged in fixed precedence: curated patterns, then RDKit
    FunctionalGroups, then ring seeds. An id already present, or a SMARTS
    string already taken by an earlier source, is skipped. Alerts come from
    ``_load_rdkit_alerts`` with its default catalog sets.

    Returns:
        ``(lib, alerts, {"profiles": ..., "meta": ...})``.
    """
    # NOTE(review): these patterns are kept byte-identical so the feature set
    # stays what meta version "1.1" produced. Several look chemically off and
    # should be fixed together with a version bump:
    #   - fg:amide_primary "[NX3;H2,H1]": ',' binds tighter than ';', so this
    #     also matches N-H1 (secondary) amides; likely meant "[NX3;H2]"
    #     (cf. fg:sulfonamide_primary).
    #   - fg:imide requires N-C(=O)-N-C(=O) (an acylurea); a plain imide is
    #     "[CX3](=O)[NX3][CX3](=O)".
    #   - fg:ketone matches any C(=O)-C, e.g. acetic acid or acetamide; a
    #     ketone is "[#6][CX3](=O)[#6]".
    #   - fg:amidine has three nitrogens (guanidine-like); an amidine is
    #     "[#6][CX3](=N)[NX3]".
    curated = {
        "fg:primary_amine": r"[NX3;H2;!$([NX3]-[CX3](=O))]",
        "fg:secondary_amine": r"[NX3;H1;!$([NX3]-[CX3](=O))]",
        "fg:tertiary_amine": r"[NX3;H0;!$([NX3]-[CX3](=O))]",
        "charge:quaternary_ammonium": r"[N+](C)(C)(C)C",
        "fg:amide_any": r"[CX3](=O)[NX3]",
        "fg:amide_primary": r"[CX3](=O)[NX3;H2,H1]",
        "fg:amide_secondary": r"[CX3](=O)[NX3;H1]([#6])",
        "fg:amide_tertiary": r"[CX3](=O)[NX3]([#6])[#6]",
        "fg:anilide": r"c[CX3](=O)[NX3]",
        "fg:urea": r"[NX3][CX3](=O)[NX3]",
        "fg:carbamate": r"[OX2][CX3](=O)[NX3]",
        "fg:imide": r"[NX3][CX3](=O)[NX3][CX3](=O)",
        "fg:carboxylic_acid": r"[CX3](=O)[OX2H1]",
        "fg:carboxylate": r"[CX3](=O)[O-]",
        "fg:ester": r"[CX3](=O)[OX2][#6]",
        "fg:aldehyde": r"[CX3H1](=O)[#6]",
        "fg:ketone": r"[CX3](=O)[#6]",
        "fg:ether": r"[OX2]([#6])[#6]",
        "fg:phenol": r"c[OX2H]",
        "fg:alcohol": r"[OX2H][#6]",
        "fg:sulfone": r"S(=O)(=O)[#6]",
        "fg:sulfoxide": r"S(=O)[#6]",
        "fg:sulfonamide_any": r"S(=O)(=O)[NX3]",
        "fg:sulfonamide_primary": r"S(=O)(=O)[NX3;H2]",
        "fg:sulfonamide_secondary": r"S(=O)(=O)[NX3;H1]([#6])",
        "fg:sulfonamide_tertiary": r"S(=O)(=O)[NX3]([#6])[#6]",
        "fg:sulfonate_ester": r"S(=O)(=O)O[#6]",
        "fg:nitrile": r"[CX2]#N",
        "fg:nitro": r"[NX3+](=O)[O-]",
        "fg:azo": r"N=N",
        "subst:aryl_halide": r"a-[F,Cl,Br,I]",
        "subst:alkyl_halide": r"[#6]-[F,Cl,Br,I]",
        "subst:vinyl_halide": r"C=C-[F,Cl,Br,I]",
        "subst:trifluoromethyl": r"[CX4](F)(F)F",
        "subst:difluoromethyl": r"[CX4H](F)F",
        "fg:amidine": r"[NX3][CX3](=N)[NX3]",
        "fg:guanidine": r"NC(=N)N",
    }
    curated = _validize(curated)

    ring_seeds = {
        "ring:benzene": "c1ccccc1",
        "ring:pyridine": "n1ccccc1",
        "ring:imidazole": "c1n[cH]nc1",
        "ring:pyrazole": "c1nnc[cH]1",
        "ring:oxazole": "c1noc[cH]1",
        "ring:isoxazole": "c1nocc1",
        "ring:thiazole": "c1nsc[cH]1",
        "ring:isothiazole": "c1nscc1",
        "ring:indole": "c1ccc2[nH]ccc2c1",
        "ring:quinoline": "c1ccc2ncccc2c1",
        "ring:isoquinoline": "c1ccc2ccncc2c1",
        "sat_ring:piperidine": "N1CCCCC1",
        "sat_ring:piperazine": "N1CCNCC1",
        "sat_ring:morpholine": "O1CCNCC1",
        "sat_ring:pyrrolidine": "N1CCCC1",
        "sat_ring:azetidine": "N1CCC1",
        "sat_ring:oxetane": "O1CCC1",
        "sat_ring:1_4_dioxane": "O1CCOCC1",
    }
    # SMILES -> normalised SMARTS, via the same helper the YAML path uses.
    rings = _smarts_from_seeds(
        [{"id": k, "smiles": smi} for k, smi in ring_seeds.items()]
    )

    rdkit_fgh = _load_rdkit_fgh()
    alerts = _load_rdkit_alerts()

    # Precedence merge: first source wins on id and on identical SMARTS.
    lib: Dict[str, str] = {}
    taken_smarts = set()
    for src in (curated, rdkit_fgh, rings):
        for k, v in src.items():
            if v in taken_smarts or k in lib:
                continue
            lib[k] = v
            taken_smarts.add(v)
    lib = _dedup(_validize(lib))

    profiles = {
        "leadlike": {
            "include_prefixes": ("fg:", "ring:", "sat_ring:", "subst:"),
            # Some ids below have no curated pattern here; they only take
            # effect for libraries that define them.
            "drop": {
                "charge:quaternary_ammonium",
                "fg:anhydride",
                "fg:sulfonyl_chloride",
                "fg:azide",
                "fg:oxime",
                "fg:phosphonate",
                "fg:phosphinate",
            },
        },
        "all": {
            "include_prefixes": ("fg:", "ring:", "sat_ring:", "subst:", "charge:"),
            "drop": set(),
        },
    }
    meta = {"name": "MolSAE-SMARTS", "version": "1.1", "profile_default": "leadlike"}
    return lib, alerts, {"profiles": profiles, "meta": meta}


# ----------------------- public API ------------------------ #
# Optional path to a YAML config that replaces the built-in library.
_YAML_PATH = os.environ.get("MOLSAE_SMARTS_YAML", "")


def _load_library_config() -> Tuple[Dict[str, str], Dict[str, str], dict]:
    """Build ``(lib, alerts, info)`` from ``_YAML_PATH`` or the built-in fallback.

    Loading stays best-effort: a bad config never makes the import fail. A
    configured-but-unusable YAML (missing file, parse or schema error) emits
    a RuntimeWarning explaining the fallback instead of switching silently.
    """
    import warnings  # local: only needed for these load-time diagnostics

    if _YAML_PATH:
        if not os.path.exists(_YAML_PATH):
            warnings.warn(
                f"MOLSAE_SMARTS_YAML={_YAML_PATH!r} does not exist; "
                "using the built-in SMARTS library.",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            try:
                return _build_from_yaml(_YAML_PATH)
            except Exception as exc:  # best-effort: fall back, but say why
                warnings.warn(
                    f"Failed to load SMARTS YAML {_YAML_PATH!r} "
                    f"({type(exc).__name__}: {exc}); "
                    "using the built-in SMARTS library.",
                    RuntimeWarning,
                    stacklevel=2,
                )
    return _fallback_build()


_LIB, _ALERTS, _INFO = _load_library_config()

SMARTS_LIBRARY: Dict[str, str] = dict(_LIB)  # backward-compat
ALERTS_LIBRARY: Dict[str, str] = dict(_ALERTS)
PROFILE_INFO = _INFO


def get_smarts_library(profile: str | None = None) -> Dict[str, str]:
    """
    Return a (profile-filtered) copy of the SMARTS library.

    Args:
        profile: name of an entry in ``PROFILE_INFO["profiles"]``. When None,
            the profile is taken from the ``MOLSAE_SMARTS_PROFILE`` environment
            variable, falling back to ``meta["profile_default"]`` and then
            ``"leadlike"``.

    Returns:
        A new dict (safe to mutate) mapping feature id -> SMARTS. A profile
        keeps ids starting with one of its ``include_prefixes`` (none means
        keep all), then removes ids listed in ``drop``. An unknown or empty
        profile returns the whole library unfiltered. Null ``meta``/``profiles``
        entries from YAML are tolerated.
    """
    if profile is None:
        meta = PROFILE_INFO.get("meta") or {}
        profile = os.environ.get(
            "MOLSAE_SMARTS_PROFILE",
            meta.get("profile_default", "leadlike"),
        )
    lib = dict(_LIB)
    pf = (PROFILE_INFO.get("profiles") or {}).get(profile)
    if not pf:
        return lib

    inc = pf.get("include_prefixes") or ()
    drop = pf.get("drop") or ()
    # A bare string from YAML would otherwise be split into characters.
    if isinstance(inc, str):
        inc = (inc,)
    if isinstance(drop, str):
        drop = (drop,)
    inc = tuple(inc)
    drop = set(drop)
    return {
        k: v
        for k, v in lib.items()
        if (not inc or k.startswith(inc)) and k not in drop
    }
