#!/usr/bin/env python3
import argparse
import json
import math
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List

import jax.numpy as jnp
import numpy as np
import yaml
from rdkit import Chem

from lmkit.sparse.attribution import (
    token_level_auroc,
)
from lmkit.sparse.fragment_mapper import compile_smarts, find_fragments_and_tokens
from lmkit.sparse.fragment_metrics import (
    across_sequence_selectivity,  # metric (b, exact & additive via point-biserial)
    build_pos_masks_from_index,
    within_seq_discriminativity,  # metric (a)
)

# --- LMKit imports (your repo layout) ---
from lmkit.sparse.sae import SAEKit
from lmkit.tools import data as data_tools
from lmkit.tools.compat import load_tokenizer

# ---------------------------
# YAML → SMARTS dict loader
# ---------------------------


def load_smarts_yaml(yaml_path: str, profile: str = "leadlike") -> Dict[str, str]:
    """
    Load a SMARTS library from YAML and return the patterns selected by a profile.

    Sections read from the file:
      • profiles[profile]["include_prefixes"] : id prefixes to keep
      • profiles[profile]["drop"]             : ids to exclude (optional)
      • curated    : list of mappings with "id" and "smarts"
      • ring_seeds : list of mappings with "id" and "smiles"; the SMILES string
                     is used as the pattern unchanged

    Before parsing, one-line entries written as
        - id: ring:benzene; smiles: "c1ccccc1"
    are rewritten into a two-line mapping so that YAML can parse them.

    Returns:
      {id: pattern}. ring_seeds are applied after curated, so on a duplicate id
      the ring_seeds entry wins. An id is kept only if it starts with one of the
      include prefixes and is not in the drop list; a profile without any
      include prefix therefore selects nothing.

    Raises:
      ValueError: the YAML cannot be parsed, its top level is not a mapping, or
                  the profile selects no pattern at all.
      KeyError:   the requested profile is not present under "profiles".
    """
    import re

    with open(yaml_path, "r", encoding="utf-8") as f:
        raw = f.read()

    # Matches a list item that packs id and smiles on one line, separated by ';'.
    semicolon_line = re.compile(
        r'^(\s*)-\s+id:\s*([^;\n]+?)\s*;\s*smiles:\s*"(.*?)"\s*$',
        flags=re.MULTILINE,
    )

    def repl(m: re.Match) -> str:
        indent = m.group(1)  # leading spaces
        idval = m.group(2).strip()  # e.g. ring:benzene
        smiles = m.group(3)  # quoted content without quotes
        return f'{indent}- id: {idval}\n{indent}  smiles: "{smiles}"'

    fixed = semicolon_line.sub(repl, raw)

    try:
        y = yaml.safe_load(fixed)
    except yaml.YAMLError as e:
        # Chain the original error so the parser's traceback is not lost.
        raise ValueError(
            "Failed to parse SMARTS YAML even after preprocessing. "
            "Consider rewriting ring_seeds to two-line mappings.\n"
            f"Details: {e}"
        ) from e

    # An empty file parses to None and a bare list/scalar is not usable either.
    if not isinstance(y, dict):
        raise ValueError(
            f"SMARTS YAML {yaml_path!r} must contain a top-level mapping, "
            f"got {type(y).__name__}."
        )

    profiles = y.get("profiles") or {}
    if profile not in profiles:
        raise KeyError(
            f"Profile {profile!r} not found under 'profiles' in {yaml_path!r}; "
            f"available profiles: {list(profiles)}"
        )
    prof = profiles[profile] or {}

    def as_list(value) -> list:
        # A key that is present but empty loads as None; a scalar string must
        # not be split into characters by tuple()/set().
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    include_prefixes = tuple(as_list(prof.get("include_prefixes")))
    drop_ids = set(as_list(prof.get("drop")))

    def want_id(sid: str) -> bool:
        return sid.startswith(include_prefixes) and sid not in drop_ids

    out = {}

    # curated: explicit SMARTS
    for item in y.get("curated") or []:
        sid = item["id"]
        if want_id(sid):
            out[sid] = item["smarts"]

    # ring_seeds: SMILES (treat as SMARTS)
    for item in y.get("ring_seeds") or []:
        sid = item["id"]
        if want_id(sid):
            smi = item.get("smiles")
            if smi:
                out[sid] = smi

    if not out:
        raise ValueError(
            "No SMARTS selected by the profile—check YAML/profile settings."
        )

    return out


# --------------------------------------
# Raw SMILES reconstruction from batch
# --------------------------------------


def reconstruct_raw_smiles_batch(batch, tokenizer) -> list[str]:
    """
    Decode every row of batch["inputs"] back into a raw SMILES string.

    For each sequence the first token is skipped, the remainder is decoded with
    special tokens kept, everything from the first EOS token onwards is cut
    off, and any PAD token text left in the string is removed.

    Returns one string per input sequence, in input order.
    """

    def _to_smiles(token_ids) -> str:
        decoded = tokenizer.decode(token_ids[1:], skip_special_tokens=False)
        before_eos = decoded.split(tokenizer.eos_token)[0]
        return before_eos.replace(tokenizer.pad_token, "")

    return [_to_smiles(seq) for seq in batch["inputs"]]


def safe_build_fragment_index_for_batch(
    smiles_list: list[str],
    queries,
    tokenizer,
    *,
    add_bos_shift: bool = True,
) -> tuple[list[dict], int, list[str]]:
    """
    Map fragment matches to token positions for a list of SMILES, tolerating
    molecules that cannot be processed.

    Returns (frag_index, invalid_count, invalid_samples)
      • frag_index[b] is {frag: [ [tok_idxs], ... ]}, one inner list per
        non-empty matched segment; fragments without any span are omitted
      • a molecule that RDKit cannot parse, or for which
        find_fragments_and_tokens raises, yields {} at its position and is
        appended to invalid_samples
      • with add_bos_shift=True every token index is increased by one
    """

    def _token_spans(occurrences) -> list:
        # Flatten all segments of all occurrences into a list of index lists.
        collected = []
        for occurrence in occurrences:
            for segment in occurrence["original"]["segments"]:
                token_ids = list(segment["token_indices"])
                if not token_ids:
                    continue
                if add_bos_shift:
                    # align to BOS-indexed inputs
                    token_ids = [t + 1 for t in token_ids]
                collected.append(token_ids)
        return collected

    per_molecule: list[dict] = []
    rejected: list[str] = []

    for smiles in smiles_list:
        # Cheap RDKit pre-check first; only parseable molecules get mapped.
        usable = Chem.MolFromSmiles(smiles) is not None
        match = None
        if usable:
            try:
                match = find_fragments_and_tokens(
                    smiles, queries, tokenizer=tokenizer
                )
            except Exception:
                usable = False

        if not usable:
            per_molecule.append({})
            rejected.append(smiles)
            continue

        entry = {}
        for name, occurrences in match["fragments"].items():
            spans = _token_spans(occurrences)
            if spans:
                entry[name] = spans
        per_molecule.append(entry)

    return per_molecule, len(rejected), rejected


# ------------------------
# Checkpointing utilities
# ------------------------


def init_state(K: int, fragments: List[str]) -> Dict[str, Any]:
    """
    Create the empty running-aggregate state for a scan.

    Every per-fragment entry is a dict keyed by fragment id:
      - wsd_sum / wsd_count       : accumulated WSD vector (K,) and its counter
      - pb_n, pb_sx, pb_sx2,
        pb_sy, pb_sxy             : n, sum(scores), sum(scores^2), sum(labels),
                                    sum(scores*labels) for the point-biserial
                                    correlation
      - auroc_sum / auroc_batches : accumulated AUROC vector (K,) and its counter
    Vectors are float64 zeros of length K; counters start at 0. The
    `fragments` list is stored as passed (not copied).
    """

    def zero_vector() -> np.ndarray:
        return np.zeros((K,), dtype=np.float64)

    def per_fragment(make_value) -> Dict[str, Any]:
        # A fresh value per fragment so no two fragments share an array.
        return {frag: make_value() for frag in fragments}

    return {
        "version": 1,
        "K": int(K),
        "fragments": fragments,
        "mols_processed": 0,
        "batches_processed": 0,
        # (a) within-sequence WSD aggregates
        "wsd_sum": per_fragment(zero_vector),
        "wsd_count": per_fragment(int),
        # (b) across-sequence point-biserial sums
        "pb_n": per_fragment(int),
        "pb_sx": per_fragment(zero_vector),
        "pb_sx2": per_fragment(zero_vector),
        "pb_sy": per_fragment(float),
        "pb_sxy": per_fragment(zero_vector),
        # optional batch-averaged AUROC
        "auroc_sum": per_fragment(zero_vector),
        "auroc_batches": per_fragment(int),
    }


def _frag_index_from_pos_masks(
    pos_masks: Dict[str, jnp.ndarray],
) -> List[Dict[str, List[List[int]]]]:
    """
    Convert {frag: (B,T) bool} into a frag_index compatible with token_level_auroc:
      frag_index[b] = { frag: [ [idxs...] ] }  # one occurrence per frag per sequence
    If a frag has no tokens for a sequence, it is simply omitted for that sequence.

    The batch size B is taken from the first mask. An empty `pos_masks` gives
    no way to infer B and returns an empty list (instead of leaking
    StopIteration from next()).
    """
    if not pos_masks:
        return []

    # infer B from any mask; unpacking two values also insists on a 2-D mask
    first_mask = next(iter(pos_masks.values()))
    B, _ = first_mask.shape

    out: List[Dict[str, List[List[int]]]] = [{} for _ in range(B)]
    for frag, mask in pos_masks.items():
        mask_np = np.asarray(mask)
        for b in range(B):
            idxs = np.where(mask_np[b])[0]
            if idxs.size:
                # each fragment key is visited once per row, so plain assignment suffices
                out[b][frag] = [idxs.tolist()]
    return out


def update_state_with_batch(
    state: Dict[str, Any],
    acts: jnp.ndarray,  # (B,T,K)
    pos_masks: Dict[str, jnp.ndarray],  # {frag: (B,T)}
    valid_mask: jnp.ndarray,  # (B,T)
    *,
    use_auroc: bool = False,
):
    """
    Fold one batch into the running aggregates in `state` (mutated in place).

    For every fragment in `pos_masks`:
      (a) WSD: the NaN-ignoring mean over the batch axis of
          within_seq_discriminativity(...)["wsd"] is added to wsd_sum and
          wsd_count is increased by one (i.e. one count per batch).
      (b) Point-biserial sums: a sequence is labelled 1 when it has at least
          one fragment token. Its score is the mean activation over fragment
          tokens when labelled 1, otherwise the max activation over valid
          tokens. n, sum(x), sum(x^2), sum(y) and sum(x*y) are accumulated.

    With use_auroc=True, token_level_auroc is evaluated once for the batch and
    added per fragment; fragments it does not report (and NaN/inf entries)
    count as the neutral value 0.5.

    Always increments state["batches_processed"] and returns `state`.
    """
    B, _, K = acts.shape

    # Score used for sequences without the fragment: max over valid tokens.
    # It does not depend on the fragment, so compute it once per batch.
    # NOTE(review): a sequence with no valid token at all would score -1e30
    # here and dominate the sums — presumably that cannot happen; confirm.
    neg_score = jnp.max(
        jnp.where(valid_mask[..., None] > 0, acts, -1e30), axis=1
    )  # (B,K)

    for frag, pos in pos_masks.items():
        # (a) WSD
        w = within_seq_discriminativity(acts, pos, valid_mask)  # {"wsd": (B,K), ...}
        # All B sequences are averaged; per the original author's note,
        # sequences with degenerate masks contribute near-zero.
        wsd_mean_over_B = np.asarray(
            jnp.nanmean(w["wsd"], axis=0), dtype=np.float64
        )  # (K,)
        # nanmean yields NaN for a feature whose column is entirely NaN; adding
        # that would turn the running sum into NaN for good, so count it as 0.
        wsd_mean_over_B = np.where(np.isnan(wsd_mean_over_B), 0.0, wsd_mean_over_B)

        state["wsd_sum"][frag] += wsd_mean_over_B
        state["wsd_count"][frag] += 1

        # (b) across-sequence selectivity (point-biserial exact sums)
        # The sequence-level labels and scores are rebuilt here (the original
        # author's note says this mirrors across_sequence_selectivity with
        # neg_agg="max"), so that routine is not called: its result was never
        # used for the sums.
        n_in = jnp.sum(pos, axis=1).reshape(B, 1)  # (B,1) fragment tokens per sequence
        has_frag = n_in > 0  # (B,1) sequence label
        in_sum = jnp.sum(acts * pos[..., None], axis=1)  # (B,K)
        mu_in = in_sum / jnp.maximum(n_in, 1)  # mean on fragment tokens

        scores = jnp.where(has_frag, mu_in, neg_score)  # (B,K)

        s_np = np.asarray(scores, dtype=np.float64)
        y_np = np.asarray(has_frag, dtype=np.float64).reshape(B, 1)

        state["pb_n"][frag] += int(B)
        state["pb_sx"][frag] += s_np.sum(axis=0)
        state["pb_sx2"][frag] += (s_np**2).sum(axis=0)
        state["pb_sy"][frag] += float(y_np.sum())
        state["pb_sxy"][frag] += (s_np * y_np).sum(axis=0)

    # Optional AUROC (batch-averaged; approximate, neutral=0.5 if no positives)
    if use_auroc:
        # build frag_index for all frags at once from the same masks we used for PB/WSD
        frag_index_for_auroc = _frag_index_from_pos_masks(pos_masks)
        au = token_level_auroc(
            acts,
            frag_index_for_auroc,
            exclude_bos_eos_pad_mask=valid_mask,
        )  # dict: {frag: (K,)} only for frags with at least one positive

        for frag in state["fragments"]:
            vec = np.asarray(
                au.get(frag, np.full((K,), 0.5, dtype=np.float64)), dtype=np.float64
            )
            # NaN-proof (shouldn't be necessary, but safe)
            vec = np.nan_to_num(vec, nan=0.5, posinf=0.5, neginf=0.5)
            state["auroc_sum"][frag] += vec
            state["auroc_batches"][frag] += 1

    state["batches_processed"] += 1
    return state


def save_checkpoint(state: Dict[str, Any], out_dir: Path, step: int):
    """
    Persist the running aggregates.

    Writes `ckpt_step{step}.npz` (compressed, flattened state) into `out_dir`,
    creating the directory if needed, and rewrites `state_meta.json` with the
    top-level state entries whose values are int, float or str.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    archive_path = out_dir / f"ckpt_step{step}.npz"
    np.savez_compressed(archive_path, **_flatten_state(state))

    meta_path = out_dir / "state_meta.json"
    with open(meta_path, "w") as handle:
        scalar_meta = {}
        for key, value in state.items():
            if isinstance(value, (int, float, str)):
                scalar_meta[key] = value
        json.dump(scalar_meta, handle, indent=2)


def _flatten_state(state: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """
    Flatten the nested state into a single {name: array} dict for np.savez.

    Scalars become one-element arrays; per-fragment entries are stored under
    "<aggregate>/<fragment id>" keys. Accumulator arrays are passed through
    as-is (no copy).
    """

    def boxed(value, dtype) -> np.ndarray:
        # np.savez wants arrays, so wrap a scalar into shape (1,).
        return np.array([value], dtype=dtype)

    payload: Dict[str, np.ndarray] = {
        "K": boxed(state["K"], np.int32),
        "mols_processed": boxed(state["mols_processed"], np.int64),
        "batches_processed": boxed(state["batches_processed"], np.int64),
    }

    # WSD
    for frag, running_sum in state["wsd_sum"].items():
        payload[f"wsd_sum/{frag}"] = running_sum
        payload[f"wsd_count/{frag}"] = boxed(state["wsd_count"][frag], np.int64)

    # PB sums
    for frag in state["fragments"]:
        payload.update(
            {
                f"pb_n/{frag}": boxed(state["pb_n"][frag], np.int64),
                f"pb_sx/{frag}": state["pb_sx"][frag],
                f"pb_sx2/{frag}": state["pb_sx2"][frag],
                f"pb_sy/{frag}": boxed(state["pb_sy"][frag], np.float64),
                f"pb_sxy/{frag}": state["pb_sxy"][frag],
            }
        )

    # AUROC (optional)
    for frag, running_sum in state["auroc_sum"].items():
        payload[f"auroc_sum/{frag}"] = running_sum
        payload[f"auroc_batches/{frag}"] = boxed(
            state["auroc_batches"][frag], np.int64
        )

    return payload


def finalize_and_write_report(state: Dict[str, Any], out_dir: Path):
    """
    Turn the accumulated sums into final metrics and write them to `out_dir`.

    Per fragment, three length-K vectors are produced:
      - wsd_mean     : wsd_sum divided by wsd_count (at least 1)
      - pb_corr      : point-biserial correlation rebuilt from the running sums,
                       with both standard deviations floored at sqrt(1e-12)
      - auroc_approx : auroc_sum divided by auroc_batches, or all-NaN when no
                       AUROC batch was recorded

    Outputs:
      - final_metrics.npz with keys "<fragment id>/<metric>"
      - one top20_<sanitized fragment id>_<metric>.csv per fragment and metric,
        listing the 20 highest-scoring feature indices (NaN sorts last)
    """
    import re

    # keep alnum, dot, underscore, dash in file names; replace others with '_'
    unsafe_chars = re.compile(r"[^A-Za-z0-9._-]+")
    latent_size = state["K"]

    def point_biserial(frag: str) -> np.ndarray:
        # Correlation between per-sequence scores (x) and binary labels (y).
        count = max(float(state["pb_n"][frag]), 1.0)
        mean_x = state["pb_sx"][frag] / count
        mean_y = float(state["pb_sy"][frag]) / count
        var_x = state["pb_sx2"][frag] / count - mean_x**2
        std_x = np.sqrt(np.maximum(var_x, 1e-12))
        std_y = math.sqrt(max(mean_y * (1 - mean_y), 1e-12))
        covariance = state["pb_sxy"][frag] / count - mean_x * mean_y
        return covariance / (std_x * std_y)  # (K,)

    results = {}
    for frag in state["fragments"]:
        auroc_batches = state["auroc_batches"][frag]
        if auroc_batches > 0:
            auroc = state["auroc_sum"][frag] / max(auroc_batches, 1)
        else:
            auroc = np.full((latent_size,), np.nan, dtype=np.float64)

        results[frag] = {
            # WSD mean over batches (already averaged over sequences per batch)
            "wsd_mean": state["wsd_sum"][frag] / max(state["wsd_count"][frag], 1),
            "pb_corr": point_biserial(frag),
            "auroc_approx": auroc,
        }

    out_dir.mkdir(parents=True, exist_ok=True)
    # NPZ keys keep the original fragment ids; only file names are sanitized
    np.savez_compressed(out_dir / "final_metrics.npz", **_flatten_results(results))

    for frag, metrics in results.items():
        stem = unsafe_chars.sub("_", frag)
        for metric_name, values in metrics.items():
            ranked = np.argsort(-np.nan_to_num(values, nan=-1e9))[:20]
            rows = ["feature,score"]
            rows.extend(f"{int(idx)},{float(values[idx])}" for idx in ranked)
            (out_dir / f"top20_{stem}_{metric_name}.csv").write_text("\n".join(rows))


def _flatten_results(
    results: Dict[str, Dict[str, np.ndarray]],
) -> Dict[str, np.ndarray]:
    """
    Collapse {fragment: {metric: values}} into {"fragment/metric": array},
    converting every value with np.asarray and keeping the nested iteration
    order.
    """
    return {
        f"{frag}/{metric}": np.asarray(values)
        for frag, metrics in results.items()
        for metric, values in metrics.items()
    }


# in fragment_scan.py (or a small util module)


# ---------------
# Main runner
# ---------------


def main(argv=None):
    """
    Command-line entry point for the fragment scan.

    Parses the arguments, loads the tokenizer and the SAE kit, streams tokenized
    batches, maps SMARTS fragments onto token positions, folds every batch into
    the running aggregates, checkpoints periodically and finally writes the
    last checkpoint plus the report into "<out_dir>_layer<layer_id>".

    Args:
      argv: optional list of argument strings. None (the default) lets argparse
            read sys.argv[1:], which is what the script entry point relies on.
    """
    p = argparse.ArgumentParser(
        description="Scan molecules and compute fragment-selective SAE metrics with checkpointing."
    )
    p.add_argument("--model_dir", required=True)
    p.add_argument("--sae_dir", required=True)
    p.add_argument("--ckpt_id", type=int, required=True)
    p.add_argument("--dataset_dir", required=True)
    p.add_argument("--smarts_yaml", required=True)
    p.add_argument("--profile", default="leadlike")
    p.add_argument("--layer_id", type=int, default=4)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--limit", type=int, default=20000, help="approx molecules to scan")
    p.add_argument("--num_proc", type=int, default=1)
    p.add_argument(
        "--save_every",
        type=int,
        default=100,
        help="checkpoint every N batches (N <= 0 disables periodic checkpoints)",
    )
    p.add_argument("--out_dir", default="fragment_scan_output")
    p.add_argument(
        "--use_auroc", action="store_true", help="also compute (approx) batch-avg AUROC"
    )
    args = p.parse_args(argv)

    out_dir = Path(args.out_dir + f"_layer{args.layer_id}")
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1) Load tokenizer + model + SAE
    tokenizer = load_tokenizer(
        tokenizer_path=f"{args.model_dir}/tokenizer.json",
        generation_config_file=f"{args.model_dir}/generation_config.json",
        trunc_length=256,
    )
    sae_kit = SAEKit.load(
        model_dir=args.model_dir, checkpoint_id=args.ckpt_id, sae_dir=args.sae_dir
    )

    # 2) Dataset iterator (batched)
    ds = data_tools.load_and_tokenize(
        dataset_dir=args.dataset_dir,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_processes=args.num_proc,
        seed=2002,
        caching=True,
        limit=args.limit,
    )

    # 3) SMARTS queries
    smarts_dict = load_smarts_yaml(args.smarts_yaml, profile=args.profile)
    queries = compile_smarts(smarts_dict)

    # 4) Latent size K (from SAE config at layer)
    K = sae_kit.sae_configs[int(args.layer_id)].latent_size
    fragments = list(smarts_dict.keys())
    state = init_state(K, fragments)

    # 5) Iterate batches
    total_mols_target = args.limit
    # A non-positive interval would make the modulo below divide by zero.
    periodic_checkpoints = args.save_every > 0
    mols_processed = 0
    batches = 0
    t0 = time.time()

    for batch in ds:
        # stop if we exceed target (dataset.map has already limited, but guard anyway)
        if mols_processed >= total_mols_target:
            break

        # a) reconstruct raw SMILES strings
        raw_smiles_batch = reconstruct_raw_smiles_batch(batch, tokenizer)

        # b) SAE activations
        acts = sae_kit.get_encoded(
            batch["inputs"], batch["positions"], args.layer_id
        )  # (B,T,K)
        valid_mask = sae_kit.mask_fn(batch["inputs"])  # (B,T)
        B, T = acts.shape[:2]

        # c) fragment token indices (aligned to BOS-index inputs)
        # The helper runs the RDKit validity check itself and returns the
        # rejected strings, so no separate pre-check pass (which parsed every
        # SMILES a second time) is needed here.
        frag_index, invalid, offenders = safe_build_fragment_index_for_batch(
            raw_smiles_batch, queries, tokenizer, add_bos_shift=True
        )
        if invalid:
            # Log a few examples so you can inspect them
            print(
                f"[warn] batch {batches}: skipped {invalid}/{B} invalid/truncated SMILES"
            )
            for s in offenders[:3]:
                print("  offender:", s)

        pos_masks = build_pos_masks_from_index(frag_index, seq_len=T)  # {frag: (B,T)}

        # Fragments that matched nothing in this batch still need an (all-False)
        # mask so every fragment's aggregates see every batch.
        for frag in fragments:
            if frag not in pos_masks:
                pos_masks[frag] = jnp.zeros((B, T), dtype=jnp.bool_)

        # d) update aggregates
        update_state_with_batch(
            state, acts, pos_masks, valid_mask, use_auroc=args.use_auroc
        )

        batches += 1
        mols_processed += B
        state["mols_processed"] = mols_processed

        # progress + checkpoint
        if periodic_checkpoints and batches % args.save_every == 0:
            save_checkpoint(state, out_dir, step=batches)
            elapsed = time.time() - t0
            print(f"[ckpt] step={batches} mols={mols_processed} elapsed={elapsed:.1f}s")

    # 6) final save
    save_checkpoint(state, out_dir, step=batches)
    finalize_and_write_report(state, out_dir)
    elapsed = time.time() - t0
    print(
        f"[done] batches={batches} mols={mols_processed} elapsed={elapsed:.1f}s → {out_dir}"
    )


if __name__ == "__main__":
    # Script entry: before running the scan, forget any TensorFlow module that
    # is already registered and point einops at its JAX backend, so that
    # JAX/einops do not trip over TF.
    sys.modules.pop("tensorflow", None)
    os.environ["EINOPS_BACKEND"] = "jax"
    main()
