# lmkit/feature_bank/mining.py
from __future__ import annotations

import json
import math
import os
import time
from collections import defaultdict
import dataclasses
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdSubstructLibrary as SSL
from rdkit.Chem.Scaffolds import MurckoScaffold
from scipy.stats import fisher_exact
from tqdm.auto import tqdm

from ..sparse.atlas_utils import StatsCollector, process_batch
from ..tools.stem import clean_smiles
from .smarts_lib import SMARTS_LIBRARY

# ----------------------------- progress utils ----------------------------- #


def _fmt_eta(seconds: Optional[float]) -> str:
    """Render an ETA given in seconds as ``HH:MM:SS``.

    ``None`` and non-finite inputs render as ``"?"``; negative values clamp to
    zero, and the value is rounded to the nearest whole second. Past 99 hours
    only the hour count is shown (``"123h"``) to keep the string short.
    """
    if seconds is not None and np.isfinite(seconds):
        total = int(max(0.0, float(seconds)) + 0.5)
        hours, rem = divmod(total, 3600)
        minutes, secs = divmod(rem, 60)
        if hours <= 99:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        return f"{hours:02d}h"
    return "?"


class EmaRate:
    """Throughput tracker based on an exponential moving average (EMA).

    Call ``update(inc)`` after each unit of work. Both ``update`` and
    ``snapshot`` return a dict with:

    - ``processed``: total items counted so far
    - ``elapsed_s``: wall-clock seconds since construction
    - ``rate_per_s``: EMA of items/second (0.0 before the first update)
    - ``eta_s``: remaining items divided by the EMA rate, or ``None`` when
      ``total`` is unknown or the rate is not above 1e-12

    ``alpha`` is the weight given to the newest instantaneous rate; larger
    values react faster but are noisier.
    """

    def __init__(self, total: Optional[int] = None, alpha: float = 0.15):
        self.total = total
        self.alpha = alpha
        self.t0 = self._last_t = time.time()
        self.n = 0
        self.ema_rate = None  # items/sec; None until the first update

    def update(self, inc: int = 1):
        now = time.time()
        # Clamp the interval so back-to-back calls never divide by ~0.
        interval = max(1e-6, now - self._last_t)
        instant = inc / interval
        previous = self.ema_rate
        if previous is None:
            self.ema_rate = instant
        else:
            self.ema_rate = self.alpha * instant + (1 - self.alpha) * previous
        self.n += inc
        self._last_t = now
        return self.snapshot()

    def snapshot(self):
        rate = self.ema_rate
        eta = None
        if self.total is not None and rate and rate > 1e-12:
            eta = max(0, self.total - self.n) / rate
        return {
            "processed": self.n,
            "elapsed_s": time.time() - self.t0,
            "rate_per_s": rate or 0.0,
            "eta_s": eta,
        }


# Progress callback type: an optional callable that receives one dict per
# progress event. Every event emitted in this module carries a "phase" key
# ("scan", "motif_enrichment", "setcover_step", "feature_start",
# "feature_done"); the remaining keys depend on the phase.
ProgressCB = Optional[Callable[[dict], None]]


# ----------------------------- helpers ------------------------------------ #


def murcko(smiles: str) -> Optional[str]:
    """Return the Murcko scaffold of ``smiles`` as non-isomeric SMILES.

    Returns ``None`` when the SMILES cannot be parsed, when no scaffold object
    is produced, or when any RDKit call raises.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            scaffold = MurckoScaffold.GetScaffoldForMol(mol)
            if scaffold:
                return Chem.MolToSmiles(scaffold, isomericSmiles=False)
        return None
    except Exception:
        return None


def build_substruct_library(
    smiles_list: List[str],
) -> Tuple[SSL.SubstructLibrary, Dict[int, str]]:
    """Index ``smiles_list`` in an RDKit SubstructLibrary.

    Returns ``(library, idx2smi)`` where ``idx2smi`` maps each library index
    back to the caller's original SMILES string. Entries RDKit cannot parse
    (or fingerprint) are skipped and do not appear in ``idx2smi``.

    The molecule holder and the pattern-fingerprint holder must stay index
    aligned. Each entry is therefore fully prepared (parsed, canonicalised,
    fingerprinted) before the SMILES is appended to the molecule holder, so a
    failure never leaves one holder longer than the other.
    """
    holder = SSL.CachedTrustedSmilesMolHolder()
    pholder = SSL.PatternHolder()
    idx2smi: Dict[int, str] = {}
    for s in smiles_list:
        try:
            mol = Chem.MolFromSmiles(s)
            if mol is None:
                continue  # unparseable SMILES
            # The "trusted" holder re-parses its SMILES without sanitisation,
            # so store RDKit's own canonical SMILES rather than the raw input.
            canonical = Chem.MolToSmiles(mol)
            pholder.AddMol(mol)
        except Exception:
            continue
        idx = holder.AddSmiles(canonical)
        idx2smi[idx] = s
    return SSL.SubstructLibrary(holder, pholder), idx2smi


def bh_fdr(pvals: List[float], alpha=0.05) -> np.ndarray:
    """Benjamini-Hochberg step-up procedure for false discovery rate control.

    Args:
        pvals: 1-D sequence of p-values.
        alpha: Target FDR level.

    Returns:
        Boolean array aligned with ``pvals``; ``True`` marks a significant
        (rejected-null) hypothesis. NaN p-values are never significant and an
        empty input yields an empty array.
    """
    p = np.asarray(pvals, dtype=float)
    dec = np.zeros_like(p, dtype=bool)
    n = p.size
    if n == 0:
        return dec

    order = np.argsort(p)  # NaNs sort last and never pass the comparison
    thresholds = (np.arange(1, n + 1) / n) * alpha
    passed = p[order] <= thresholds
    if passed.any():
        # Step-up: the largest rank k with p_(k) <= (k/n)*alpha makes every
        # hypothesis of rank <= k significant, even ones failing individually.
        k = int(np.flatnonzero(passed)[-1])
        dec[order[: k + 1]] = True
    return dec


# --------------------------- data classes ---------------------------------- #


@dataclass
class MotifStat:
    """Enrichment statistics for one SMARTS motif against a positive/negative set.

    Instances are produced by ``motif_enrichment`` in this module.
    """

    name: str  # key of the motif in the candidate SMARTS mapping
    smarts: str  # SMARTS pattern that was matched
    hits_pos: int  # library positives matching the motif
    hits_neg: int  # library negatives matching the motif
    pos_size: int  # positives successfully added to the substructure library
    neg_size: int  # negatives successfully added to the substructure library
    coverage: float  # hits_pos / pos_size
    enrichment: float  # positive hit rate / negative hit rate (denominators floored at 1e-9)
    odds_ratio: float  # from the one-sided Fisher test; NaN if the test raised ValueError
    p_value: float  # one-sided ("greater") Fisher's exact test p-value
    fdr_significant: bool  # True if the motif survived Benjamini-Hochberg correction


@dataclass
class FeatureAlignment:
    """Motif-based explanation of one SAE latent feature (built by ``align_feature``)."""

    layer: int  # layer index the feature belongs to
    feature_id: int  # latent index within that layer
    topk_smiles: List[str]  # positives: deduplicated SMILES of the feature's top sequences
    negatives: List[str]  # scaffold-matched negatives drawn from the background set
    selected_motifs: List[MotifStat]  # motifs chosen by greedy set cover
    coverage_reached: float  # fraction of parseable positives covered by selected_motifs
    examples: List[str]  # positives covered by the selected motifs (align_feature keeps 32)
    counterexamples: List[str]  # positives left uncovered (align_feature keeps 32)


# ------------------------- mining pipeline --------------------------------- #


def collect_top_smiles_for_layer(
    dataset,
    sae_kit,
    layer_id: int,
    top_sequences: int = 5000,
    num_batches: int = 200,
    progress_cb: ProgressCB = None,
) -> Tuple[StatsCollector, List[str]]:
    """
    Scan up to ``num_batches`` batches of ``dataset`` for one SAE layer.

    Fills a StatsCollector (top sequences per latent) and builds a
    deduplicated background SMILES list in first-seen order. Emits a "scan"
    progress event with an ETA after every batch.

    Args:
        dataset: Iterable of batches accepted by ``StatsCollector.update`` and
            ``process_batch``. An optional ``batch_size`` attribute is only
            used to estimate the ETA.
        sae_kit: Object exposing ``sae_configs[layer_id].latent_size``.
        layer_id: Layer to analyse.
        top_sequences: Number of top sequences the collector keeps per latent.
        num_batches: Maximum number of batches to consume.
        progress_cb: Optional progress callback.

    Returns:
        ``(collector, background_smiles)``.
    """
    from itertools import islice

    latent_size = sae_kit.sae_configs[layer_id].latent_size
    col = StatsCollector(
        latent_size=latent_size, top_sequences=top_sequences, top_tokens=100
    )

    all_smiles: List[str] = []
    seen: Set[str] = set()

    # The molecule total is only an estimate (num_batches * batch_size) used
    # for the ETA. ``batch_size`` may be missing or None (e.g. a DataLoader
    # built from a batch_sampler), so fall back to a reasonable guess.
    batch_size = getattr(dataset, "batch_size", None) or 1024
    timer = EmaRate(total=num_batches * batch_size)

    # islice stops before pulling batch ``num_batches`` from the iterator;
    # enumerate-then-break would fetch (and discard) one extra batch.
    batches = islice(dataset, max(num_batches, 0))
    for b_idx, batch in enumerate(
        tqdm(batches, total=num_batches, desc=f"[L{layer_id}] scanning", unit="batch")
    ):
        col.update(batch, sae_kit, layer_id)
        _, raw_inputs = process_batch(batch, sae_kit, layer_id)

        n_mols = 0  # every molecule in the batch, duplicates included
        for toks in raw_inputs:
            n_mols += 1
            s = clean_smiles("".join(toks))
            if s and s not in seen:
                all_smiles.append(s)
                seen.add(s)

        snap = timer.update(inc=n_mols)
        if progress_cb:
            progress_cb(
                {
                    "phase": "scan",
                    "layer": layer_id,
                    "batches_done": b_idx + 1,
                    "batches_total": num_batches,
                    "molecules_seen": snap["processed"],
                    "unique_smiles": len(all_smiles),
                    "rate_mols_per_s": snap["rate_per_s"],
                    "eta_s": snap["eta_s"],
                    "elapsed_s": snap["elapsed_s"],
                }
            )

    return col, all_smiles


def pick_features(col: StatsCollector, metric: str, topk: int) -> List[int]:
    """Return feature indices chosen by ``NeuronSelector`` for ``metric``.

    Calls ``NeuronSelector(col).pick(metric=metric, topk=topk, sort="desc")``
    and normalises its result into a flat list of Python ints. A scalar, a
    NumPy array and a plain list/tuple are all accepted; the exact return
    type of ``pick`` is not visible here.
    """
    from ..sparse.atlas_utils import NeuronSelector

    sel = NeuronSelector(col)
    idxs = sel.pick(metric=metric, topk=topk, sort="desc")
    # atleast_1d turns a scalar into a 1-element array, and a list into an
    # array, instead of nesting it (the old code produced ``[[...]]`` for lists).
    return [int(i) for i in np.atleast_1d(idxs).ravel()]


def top_smiles_for_feature(col: StatsCollector, fid: int, limit: int) -> List[str]:
    """Return cleaned SMILES for the top-activating sequences of feature ``fid``.

    Only the first ``limit`` entries of ``col.top_sequences_for(fid)`` are
    considered. Each entry is a ``(best, count, tokens)`` triple whose tokens
    are joined and passed through ``clean_smiles``. Entries that clean to a
    falsy value are dropped, and duplicates are removed while keeping the
    first occurrence. As a result, fewer than ``limit`` SMILES may come back.
    """
    cleaned = (
        clean_smiles("".join(tokens))
        for _best, _count, tokens in col.top_sequences_for(fid)[:limit]
    )
    # dict.fromkeys keeps insertion order, so this is an order-preserving dedup.
    return list(dict.fromkeys(smi for smi in cleaned if smi))


def matched_negatives_by_scaffold(
    positives: List[str], background: List[str], per_pos: int = 3
) -> List[str]:
    """Sample background molecules that share a Murcko scaffold with the positives.

    For each distinct positive with a non-empty scaffold, up to ``per_pos``
    background molecules with the same scaffold are drawn without
    replacement. Positives themselves are never eligible. Sampling uses a
    fixed seed (2025), so identical inputs give identical outputs.

    Returns:
        Deduplicated negative SMILES (plain ``str``) in first-drawn order.

    NOTE: ``murcko`` is recomputed for the background on every call. When many
    features are mined against the same background, precomputing scaffolds
    once would be cheaper.
    """
    if not positives:
        return []

    pos_set = set(positives)
    pos_scaf = {s: murcko(s) for s in positives}
    wanted = {sc for sc in pos_scaf.values() if sc}
    if not wanted:
        return []

    # Bucket only the background molecules that could ever be drawn: skip
    # positives before paying for a Murcko computation, and keep only the
    # scaffolds some positive needs. Within a bucket, the relative order of
    # molecules matches the background order, so the seeded draws below are
    # unchanged.
    scaf_map: Dict[str, List[str]] = defaultdict(list)
    for s in background:
        if s in pos_set:
            continue
        sc = murcko(s)
        if sc in wanted:
            scaf_map[sc].append(s)

    rng = np.random.default_rng(2025)
    negs: List[str] = []
    for sc in pos_scaf.values():
        pool = scaf_map.get(sc)  # .get avoids inserting into the defaultdict
        if not pool:
            continue
        k = min(per_pos, len(pool))
        # rng.choice yields np.str_; convert back to plain Python strings.
        negs.extend(str(x) for x in rng.choice(pool, size=k, replace=False))

    return list(dict.fromkeys(negs))  # dedup, order-preserving


def motifs_from_library() -> dict[str, str]:
    """Return the candidate motif SMARTS patterns keyed by motif name.

    Delegates to ``smarts_lib.get_smarts_library()``. The import is deferred
    until call time.
    """
    from . import smarts_lib as _smarts_lib

    return _smarts_lib.get_smarts_library()


def motif_enrichment(
    positives: List[str],
    negatives: List[str],
    candidate_smarts: Dict[str, str],
    fdr_alpha: float = 0.05,
    progress_cb: ProgressCB = None,
) -> List[MotifStat]:
    """Find SMARTS motifs significantly enriched in ``positives`` over ``negatives``.

    Each candidate is matched against a substructure library built from the
    union of both sets. Negatives that also occur among the positives are
    dropped. Motifs hitting at least one positive get a one-sided Fisher's
    exact test (``alternative="greater"``). The resulting p-values are then
    corrected with Benjamini-Hochberg at ``fdr_alpha``.

    Args:
        positives: SMILES of molecules that activate the feature.
        negatives: SMILES of contrast molecules.
        candidate_smarts: Mapping of motif name -> SMARTS pattern.
        fdr_alpha: Target false discovery rate.
        progress_cb: Optional callback. It receives a "motif_enrichment" event
            for every 25th candidate and for the last one.

    Returns:
        FDR-significant motifs, sorted by ``coverage * log2(enrichment + 1)``
        in descending order. Empty when no candidate is significant.
    """
    pos_set = set(positives)
    neg_set = set(negatives) - pos_set  # ensure negatives are disjoint

    lib, idx2smi = build_substruct_library(list(pos_set | neg_set))
    pos_idx = {i for i, s in idx2smi.items() if s in pos_set}
    neg_idx = {i for i, s in idx2smi.items() if s in neg_set}
    pos_size = len(pos_idx)
    neg_size = len(neg_idx)
    # SubstructLibrary.GetMatches returns at most 1000 hits by default, which
    # silently truncated hit counts for larger libraries. Request the whole
    # library so the contingency tables are exact.
    max_results = max(len(idx2smi), 1)

    stats: List[MotifStat] = []
    pvals: List[float] = []

    items = list(candidate_smarts.items())
    n_items = len(items)
    timer = EmaRate(total=n_items)

    for j, (name, sm) in enumerate(
        tqdm(items, desc="  motif enrichment", leave=False, unit="motif")
    ):
        snap = timer.update(1)  # count every candidate, matched or not
        # Report before any early `continue` so skipped motifs don't swallow
        # the periodic progress event.
        if progress_cb and (j % 25 == 0 or j == n_items - 1):
            progress_cb(
                {
                    "phase": "motif_enrichment",
                    "processed": snap["processed"],
                    "total": n_items,
                    "rate_per_s": snap["rate_per_s"],
                    "eta_s": snap["eta_s"],
                }
            )

        try:
            q = Chem.MolFromSmarts(sm)
            if q is None:
                continue
            hits = set(lib.GetMatches(q, maxResults=max_results))
        except Exception:
            continue  # malformed or unsupported SMARTS: skip this candidate

        hits_pos = len(hits & pos_idx)
        if hits_pos == 0:  # motif never appears in the positive set
            continue
        hits_neg = len(hits & neg_idx)
        miss_pos = pos_size - hits_pos
        miss_neg = neg_size - hits_neg

        try:
            odds, p = fisher_exact(
                [[hits_pos, miss_pos], [hits_neg, miss_neg]], alternative="greater"
            )
        except ValueError:  # defensive; treat as "no evidence"
            odds, p = (math.nan, 1.0)

        cov = hits_pos / max(pos_size, 1)
        # Hit-rate ratio. When hits_neg == 0 the denominator is floored at
        # 1e-9, so the enrichment becomes very large rather than infinite.
        enr = (hits_pos / max(pos_size, 1e-9)) / max(
            hits_neg / max(neg_size, 1e-9), 1e-9
        )

        stats.append(
            MotifStat(
                name=name,
                smarts=sm,
                hits_pos=hits_pos,
                hits_neg=hits_neg,
                pos_size=pos_size,
                neg_size=neg_size,
                coverage=cov,
                enrichment=enr,
                odds_ratio=odds,
                p_value=p,
                fdr_significant=False,  # set after the BH correction below
            )
        )
        pvals.append(p)

    if not stats:
        return []

    sig = bh_fdr(np.array(pvals), alpha=fdr_alpha)
    for st, ok in zip(stats, sig):
        st.fdr_significant = bool(ok)

    significant_stats = [s for s in stats if s.fdr_significant]

    # Rank by a blend of breadth (coverage) and specificity (enrichment).
    def score(s: MotifStat) -> float:
        return s.coverage * math.log2(s.enrichment + 1.0)

    significant_stats.sort(key=score, reverse=True)
    return significant_stats


def greedy_set_cover(
    positives: List[str],
    motif_stats: List[MotifStat],
    min_coverage: float = 0.8,
    max_motifs: int = 3,
    progress_cb: ProgressCB = None,
) -> Tuple[List[MotifStat], float, List[str], List[str]]:
    """Pick a minimal subset of motifs that covers ≥ min_coverage of positives.

    Greedy set cover: repeatedly take the motif that matches the most
    not-yet-covered positives, with the earliest motif winning ties. Stop
    when the covered fraction reaches ``min_coverage``, when ``max_motifs``
    motifs are chosen, or when no motif adds anything.

    Args:
        positives: SMILES to cover; unparseable ones are ignored.
        motif_stats: Candidate motifs; list order breaks ties.
        min_coverage: Target fraction of parseable positives.
        max_motifs: Maximum number of motifs to select.
        progress_cb: Optional callback; receives one "setcover_step" event per pick.

    Returns:
        ``(chosen_motifs, coverage, covered_smiles, uncovered_smiles)``, with
        the SMILES lists in library-index order. With no motifs, no positives,
        or no parseable positives, returns ``([], 0.0, [], list(positives))``.
    """
    if not motif_stats or not positives:
        return [], 0.0, [], list(positives)

    lib, idx2smi = build_substruct_library(positives)
    pos_size = len(idx2smi)
    if pos_size == 0:
        # Nothing parsed: bail out instead of dividing by zero below.
        return [], 0.0, [], list(positives)

    # GetMatches defaults to maxResults=1000, which silently capped every
    # motif's coverage on large positive sets, so ask for the whole library.
    # Intersecting with the known indices guarantees idx2smi lookups succeed.
    motif_hits: Dict[int, Set[int]] = {}
    for i, m in enumerate(motif_stats):
        q = Chem.MolFromSmarts(m.smarts)
        if q is None:
            continue
        motif_hits[i] = set(lib.GetMatches(q, maxResults=pos_size)).intersection(
            idx2smi
        )

    covered: Set[int] = set()
    chosen: List[int] = []  # indices into motif_stats, in selection order

    while len(covered) / pos_size < min_coverage and len(chosen) < max_motifs:
        best_i, best_gain = -1, 0
        for i, hits in motif_hits.items():
            if i in chosen:
                continue
            gain = len(hits - covered)
            if gain > best_gain:  # strict: earliest motif wins ties
                best_i, best_gain = i, gain

        if best_gain <= 0:
            break

        chosen.append(best_i)
        covered |= motif_hits[best_i]

        if progress_cb:
            progress_cb(
                {
                    "phase": "setcover_step",
                    "selected": len(chosen),
                    "coverage_now": len(covered) / pos_size,
                }
            )

    final_coverage = len(covered) / pos_size
    example_smiles = [idx2smi[i] for i in sorted(covered)]
    missing_smiles = [idx2smi[i] for i in sorted(idx2smi.keys() - covered)]
    chosen_motifs = [motif_stats[i] for i in chosen]

    return chosen_motifs, final_coverage, example_smiles, missing_smiles


def align_feature(
    layer_id: int,
    feature_id: int,
    col: StatsCollector,
    background_smiles: List[str],
    top_pos: int = 3000,
    neg_per_pos: int = 3,
    fdr_alpha: float = 0.05,
    setcover_target: float = 0.8,
    max_motifs: int = 3,
    progress_cb: ProgressCB = None,
) -> FeatureAlignment:
    """Explain one SAE feature with a small set of enriched SMARTS motifs.

    The pipeline has four steps:

    1. Take up to ``top_pos`` top-activating sequences as positive SMILES.
    2. Draw up to ``neg_per_pos`` scaffold-matched negatives per positive
       from ``background_smiles``.
    3. Test every library motif for enrichment, keeping BH-FDR survivors at
       ``fdr_alpha``.
    4. Greedily choose up to ``max_motifs`` of them to cover
       ``setcover_target`` of the positives.

    Progress is reported through "feature_start" and "feature_done" events;
    the helper steps may emit their own events. Only the first 32 covered
    and uncovered positives are kept as examples/counterexamples.
    """
    positives = top_smiles_for_feature(col, feature_id, limit=top_pos)
    negatives = matched_negatives_by_scaffold(
        positives, background_smiles, per_pos=neg_per_pos
    )
    ids = {"layer": layer_id, "feature_id": feature_id}

    if progress_cb:
        progress_cb(
            {
                "phase": "feature_start",
                **ids,
                "positives": len(positives),
                "negatives": len(negatives),
            }
        )

    enriched = motif_enrichment(
        positives,
        negatives,
        motifs_from_library(),
        fdr_alpha=fdr_alpha,
        progress_cb=progress_cb,
    )
    selected, cov, examples, missing = greedy_set_cover(
        positives,
        enriched,
        min_coverage=setcover_target,
        max_motifs=max_motifs,
        progress_cb=progress_cb,
    )

    if progress_cb:
        if selected:
            best = selected[0]
            summary = f"{best.name} (cov={best.coverage:.2f}, enr={best.enrichment:.1f})"
        else:
            summary = ""
        progress_cb(
            {
                "phase": "feature_done",
                **ids,
                "selected_count": len(selected),
                "coverage_reached": cov,
                "top_motif": summary,
            }
        )

    return FeatureAlignment(
        layer=layer_id,
        feature_id=feature_id,
        topk_smiles=positives,
        negatives=negatives,
        selected_motifs=selected,
        coverage_reached=cov,
        examples=examples[:32],
        counterexamples=missing[:32],
    )


def _jsonify(obj):
    """Return a copy of ``obj`` that ``json.dump`` can serialise.

    Dicts are rebuilt with converted values. Lists and tuples become lists.
    NumPy scalars become native Python numbers/bools via ``.item()``, and
    NumPy arrays become nested lists via ``.tolist()``. Dataclass instances
    are expanded with ``dataclasses.asdict`` and then converted. Anything
    else is returned unchanged.
    """
    if isinstance(obj, dict):
        return dict((key, _jsonify(val)) for key, val in obj.items())
    elif isinstance(obj, (list, tuple)):
        return list(map(_jsonify, obj))
    elif isinstance(obj, np.generic):  # np.int64, np.float64, np.bool_, ...
        converted = obj.item()
    elif isinstance(obj, np.ndarray):
        converted = obj.tolist()
    elif dataclasses.is_dataclass(obj):
        converted = _jsonify(dataclasses.asdict(obj))
    else:
        converted = obj
    return converted


def save_feature_bank(
    out_dir: str, layer_id: int, alignments: List[FeatureAlignment]
) -> None:
    """Persist one layer's alignments as nested JSON and a flat CSV.

    Writes ``feature_bank_layer{layer_id}.json`` (one entry per feature, with
    its selected motifs nested) and ``feature_bank_layer{layer_id}.csv`` (one
    row per selected motif) into ``out_dir``, creating the directory if
    needed. A feature with no selected motifs still gets one placeholder CSV
    row with ``motif_name="N/A"``.

    NOTE(review): ``json.dump`` runs with its default ``allow_nan=True``, so
    non-finite floats (e.g. a NaN odds ratio from ``motif_enrichment``) are
    written as ``NaN``/``Infinity``, which strict JSON parsers reject.
    """
    os.makedirs(out_dir, exist_ok=True)

    def motif_record(m):
        # Full per-motif statistics for the JSON bank.
        return {
            "name": m.name,
            "smarts": m.smarts,
            "coverage": m.coverage,
            "enrichment": m.enrichment,
            "odds_ratio": m.odds_ratio,
            "p_value": m.p_value,
            "fdr_significant": m.fdr_significant,
            "hits_pos": m.hits_pos,
            "hits_neg": m.hits_neg,
            "pos_size": m.pos_size,
            "neg_size": m.neg_size,
        }

    def csv_row(a, name, smarts, coverage, enrichment, odds_ratio, p_value, fdr_sig):
        # One flat CSV row; key order fixes the CSV column order.
        return {
            "layer": a.layer,
            "feature_id": a.feature_id,
            "motif_name": name,
            "smarts": smarts,
            "coverage": coverage,
            "enrichment": enrichment,
            "odds_ratio": odds_ratio,
            "p_value": p_value,
            "fdr_significant": fdr_sig,
            "coverage_reached": a.coverage_reached,
        }

    bank = []
    rows = []
    for a in alignments:
        motifs = a.selected_motifs
        bank.append(
            {
                "layer": a.layer,
                "feature_id": a.feature_id,
                "coverage_reached": a.coverage_reached,
                "examples": a.examples,
                "counterexamples": a.counterexamples,
                "motifs": [motif_record(m) for m in motifs],
            }
        )
        if not motifs:
            rows.append(csv_row(a, "N/A", "", 0.0, 0.0, 0.0, 1.0, False))
            continue
        rows.extend(
            csv_row(
                a,
                m.name,
                m.smarts,
                m.coverage,
                m.enrichment,
                m.odds_ratio,
                m.p_value,
                m.fdr_significant,
            )
            for m in motifs
        )

    json_path = os.path.join(out_dir, f"feature_bank_layer{layer_id}.json")
    with open(json_path, "w") as fp:
        json.dump(_jsonify(bank), fp, indent=2)

    csv_path = os.path.join(out_dir, f"feature_bank_layer{layer_id}.csv")
    pd.DataFrame(rows).to_csv(csv_path, index=False)
