#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Causal pre-SMARTS screening:
Mark tokens immediately BEFORE each fragment span and compute feature scores.

Outputs are compatible with your existing summarizer (final_metrics.npz, top20_*.csv).

Example:
    uv run lmkit/sparse/fragment_causal_pre_scan.py \
        --model_dir models/transformer_sm \
        --sae_dir models/saes/relu_4x_e9a211 \
        --ckpt_id 59712 \
        --dataset_dir ~/data/z20ll_filtered_scafsplit/valid \
        --smarts_yaml lmkit/sparse/MolSAE_SMARTS_v1.1_leadlike.yaml \
        --profile leadlike \
        --layer_id 5 \
        --batch_size 512 \
        --limit 20000 \
        --save_every 5 \
        --out_dir fragment_causal_pre
"""

from __future__ import annotations

import argparse
import json
import math
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import jax.numpy as jnp
import numpy as np
import yaml
from rdkit import Chem

from lmkit.sparse.fragment_mapper import compile_smarts, find_fragments_and_tokens
from lmkit.sparse.fragment_metrics import (
    within_seq_discriminativity,
)
from lmkit.sparse.sae import SAEKit
from lmkit.tools import data as data_tools

# ----- LMKit imports you already use -----
from lmkit.tools.compat import load_tokenizer

# =========================
# YAML loader (same as before)
# =========================


def load_smarts_yaml(yaml_path: str, profile: str = "leadlike") -> Dict[str, str]:
    """Load the SMARTS patterns selected by one profile of a YAML catalogue.

    Before parsing, compact one-line entries of the form
    ``- id: ring:benzene ; smiles: "c1ccccc1"`` are rewritten into a two-line
    mapping (``id`` and ``smiles`` keys) so each entry loads as a dict.

    An entry is selected when its ``id`` starts with one of the profile's
    ``include_prefixes`` and is not listed in the profile's ``drop``.
    ``curated`` entries contribute their ``smarts`` value; ``ring_seeds``
    entries contribute their ``smiles`` value (seeds without one are skipped).
    Ring seeds are processed last, so a seed replaces a curated entry that
    shares its id.

    Args:
        yaml_path: Path of the YAML file (read as UTF-8).
        profile: Key under the top-level ``profiles`` mapping.

    Returns:
        Mapping of entry id to pattern string, in file order.

    Raises:
        ValueError: If the YAML cannot be parsed, its top level is not a
            mapping, or the profile selects nothing.
        KeyError: If ``profile`` is not defined under ``profiles``.
    """
    import re

    with open(yaml_path, "r", encoding="utf-8") as f:
        raw = f.read()

    # Convert lines like:  - id: ring:benzene ; smiles: "c1ccccc1"
    semicolon_line = re.compile(
        r'^(\s*)-\s+id:\s*([^;\n]+?)\s*;\s*smiles:\s*"(.*?)"\s*$',
        flags=re.MULTILINE,
    )

    def repl(m: re.Match) -> str:
        indent = m.group(1)
        idval = m.group(2).strip()
        smiles = m.group(3)
        return f'{indent}- id: {idval}\n{indent}  smiles: "{smiles}"'

    fixed = semicolon_line.sub(repl, raw)

    try:
        y = yaml.safe_load(fixed)
    except yaml.YAMLError as e:
        # Chain explicitly so the original parser error stays attached.
        raise ValueError(
            f"Failed to parse SMARTS YAML after preprocessing.\nDetails: {e}"
        ) from e

    # An empty file loads as None and a bare list/scalar is not a catalogue.
    if not isinstance(y, dict):
        raise ValueError(
            f"SMARTS YAML {yaml_path!r} must contain a top-level mapping, "
            f"got {type(y).__name__}."
        )

    profiles = y.get("profiles") or {}
    if profile not in profiles:
        raise KeyError(
            f"Profile {profile!r} not found in {yaml_path!r}; "
            f"available profiles: {list(profiles)}"
        )
    # A key written with no value loads as None; treat it as empty.
    prof = profiles[profile] or {}

    prefixes = prof.get("include_prefixes") or []
    if isinstance(prefixes, str):
        # tuple("ring:") would split into single characters and match
        # almost every id, so wrap a lone string first.
        prefixes = [prefixes]
    include_prefixes = tuple(prefixes)
    drop_ids = set(prof.get("drop") or [])

    def want_id(sid: str) -> bool:
        return sid.startswith(include_prefixes) and sid not in drop_ids

    out: Dict[str, str] = {}
    for item in y.get("curated") or []:
        sid = item["id"]
        if want_id(sid):
            out[sid] = item["smarts"]

    for item in y.get("ring_seeds") or []:
        sid = item["id"]
        if want_id(sid):
            smi = item.get("smiles")
            if smi:
                out[sid] = smi

    if not out:
        raise ValueError("No SMARTS selected by the profile—check YAML/profile.")
    return out


# =========================
# Batch → raw SMILES
# =========================


def reconstruct_raw_smiles_batch(batch, tokenizer) -> List[str]:
    """Decode every row of ``batch["inputs"]`` back to its text form.

    For each row the first token is dropped before decoding, the decoded
    text is cut at the first occurrence of ``tokenizer.eos_token``, and every
    occurrence of ``tokenizer.pad_token`` is removed from what remains.

    Returns:
        One string per row, in batch order.
    """

    def _row_to_text(row) -> str:
        # Special tokens are kept during decoding so EOS/PAD can be located.
        decoded = tokenizer.decode(row[1:], skip_special_tokens=False)
        before_eos = decoded.split(tokenizer.eos_token)[0]
        return before_eos.replace(tokenizer.pad_token, "")

    return [_row_to_text(row) for row in batch["inputs"]]


# =========================
# Safe fragment indexing
# =========================


def safe_build_fragment_index_for_batch(
    smiles_list: List[str],
    queries,
    tokenizer,
    *,
    add_bos_shift: bool = True,
) -> Tuple[List[dict], int, List[str]]:
    """Map each SMILES string to the token spans of its fragment matches.

    A string is treated as invalid when ``Chem.MolFromSmiles`` returns
    ``None`` or when ``find_fragments_and_tokens`` raises. Invalid strings
    contribute an empty dict so the output stays aligned with the input.

    Args:
        smiles_list: One string per batch row.
        queries: Forwarded to ``find_fragments_and_tokens``.
        tokenizer: Forwarded to ``find_fragments_and_tokens``.
        add_bos_shift: When true, every token index is increased by one
            (to line up with inputs that carry a leading BOS token).

    Returns:
        ``(frag_index, invalid_count, offenders)`` where ``frag_index[b]`` is
        ``{fragment_name: [[token_idx, ...], ...]}`` with empty spans and
        span-less fragments omitted, ``invalid_count`` is the number of
        rejected strings and ``offenders`` lists them in order.
    """
    per_sequence: List[dict] = []
    rejected: List[str] = []
    n_invalid = 0

    for smi in smiles_list:
        failed = Chem.MolFromSmiles(smi) is None
        if not failed:
            try:
                result = find_fragments_and_tokens(smi, queries, tokenizer=tokenizer)
            except Exception:
                # Best effort: a mapper failure only drops this one string.
                failed = True

        if failed:
            per_sequence.append({})
            rejected.append(smi)
            n_invalid += 1
            continue

        spans_by_fragment = {}
        for frag_name, occurrences in result["fragments"].items():
            collected = []
            for occurrence in occurrences:
                for segment in occurrence["original"]["segments"]:
                    token_ids = list(segment["token_indices"])
                    if not token_ids:
                        continue
                    if add_bos_shift:
                        # align to BOS-indexed inputs
                        token_ids = [tok + 1 for tok in token_ids]
                    collected.append(token_ids)
            if collected:
                spans_by_fragment[frag_name] = collected
        per_sequence.append(spans_by_fragment)

    return per_sequence, n_invalid, rejected


# =========================
# PRE-SMARTS context masks
# =========================


def build_context_masks_from_index(
    fragment_index: List[dict],
    seq_len: int,
    *,
    offset: int = -1,  # -1 = immediately BEFORE
    window: int = 1,  # how many tokens prior to include (contiguous)
    clamp_min: int = 0,  # allow BOS index; valid_mask will exclude it anyway
    clamp_max: Optional[int] = None,  # default: seq_len
) -> Dict[str, jnp.ndarray]:
    """Build per-fragment boolean masks over the context of each span.

    For each fragment occurrence span, the anchor is ``min(span)`` and the
    positions ``anchor+offset-window+1 ... anchor+offset`` (inclusive) are
    marked True, restricted to ``clamp_min <= t < clamp_max``.

    Positions outside ``[0, seq_len)`` are never marked, whatever the clamp
    arguments are: a negative ``clamp_min`` would otherwise let negative
    indices wrap around to the end of the row, and a ``clamp_max`` beyond
    ``seq_len`` would index past the array.

    Args:
        fragment_index: One ``{fragment: [[token_idx, ...], ...]}`` dict per
            batch row.
        seq_len: Width ``T`` of the produced masks.
        offset: Displacement of the window's last position from the anchor.
        window: Number of contiguous positions to mark; values below 1 mark
            nothing.
        clamp_min: Smallest position that may be marked.
        clamp_max: Exclusive upper bound; ``None`` means ``seq_len``.

    Returns:
        ``{fragment: (B, T) bool array}`` with keys in sorted order. Every
        fragment name present in ``fragment_index`` gets a mask, even if all
        of its spans are empty or fall outside the clamp range.
    """
    B = len(fragment_index)
    # Effective bounds are the intersection of the clamp range and the row.
    lower = max(clamp_min, 0)
    upper = seq_len if clamp_max is None else min(clamp_max, seq_len)

    masks: Dict[str, np.ndarray] = {}
    # Single pass over the index; masks are allocated on first sight of a name.
    for b, per in enumerate(fragment_index):
        for name, spans in per.items():
            mask = masks.get(name)
            if mask is None:
                mask = masks[name] = np.zeros((B, seq_len), dtype=bool)
            for idxs in spans:
                if not idxs:
                    continue
                end = min(idxs) + offset  # inclusive; min(idxs) = first token
                start = end - (window - 1)
                lo = max(start, lower)
                hi = min(end + 1, upper)
                if lo < hi:
                    mask[b, lo:hi] = True

    return {name: jnp.asarray(masks[name]) for name in sorted(masks)}


# =========================
# Checkpointing (same pattern)
# =========================


def init_state(K: int, fragments: List[str]) -> Dict[str, Any]:
    """Create the empty running-statistics state for a scan.

    Args:
        K: Length of every per-feature accumulator vector.
        fragments: Fragment names; stored as-is (not copied) under
            ``"fragments"`` and used as keys of every per-fragment table.

    Returns:
        A dict holding bookkeeping counters plus, per fragment, the
        within-sequence accumulators (``wsd_sum``, ``wsd_count``) and the
        point-biserial sums (``pb_n``, ``pb_sx``, ``pb_sx2``, ``pb_sy``,
        ``pb_sxy``), all initialised to zero.
    """

    def _per_fragment(make_value) -> Dict[str, Any]:
        # A fresh value per fragment: the vectors are later updated in place.
        return {name: make_value() for name in fragments}

    def _zero_vector() -> np.ndarray:
        return np.zeros((K,), dtype=np.float64)

    state: Dict[str, Any] = {}
    state["version"] = 1
    state["K"] = int(K)
    state["fragments"] = fragments
    state["mols_processed"] = 0
    state["batches_processed"] = 0
    state["wsd_sum"] = _per_fragment(_zero_vector)
    state["wsd_count"] = _per_fragment(lambda: 0)
    state["pb_n"] = _per_fragment(lambda: 0)
    state["pb_sx"] = _per_fragment(_zero_vector)
    state["pb_sx2"] = _per_fragment(_zero_vector)
    state["pb_sy"] = _per_fragment(lambda: 0.0)
    state["pb_sxy"] = _per_fragment(_zero_vector)
    return state


def update_state_with_batch(
    state: Dict[str, Any],
    acts: jnp.ndarray,  # (B,T,K)
    pos_masks: Dict[str, jnp.ndarray],  # {frag: (B,T)}
    valid_mask: jnp.ndarray,  # (B,T)
):
    """Fold one batch into the running statistics in ``state`` (in place).

    For every fragment in ``pos_masks`` two things are accumulated:

    (a) Within-sequence discriminativity: the NaN-ignoring mean over the
        batch axis of ``within_seq_discriminativity(...)["wsd"]`` is added to
        ``wsd_sum`` and ``wsd_count`` is incremented by one.
    (b) Across-sequence selectivity: sufficient statistics for a
        point-biserial correlation between a per-sequence score and the
        binary label "sequence has at least one marked position". The score
        is the mean activation over marked positions when the label is 1,
        otherwise the maximum activation over valid positions.

    Args:
        state: Running state as produced by ``init_state``.
        acts: Feature activations; expected shape (B, T, K).
        pos_masks: Marked positions per fragment; each expected (B, T).
        valid_mask: Positions eligible for the background maximum; expected
            (B, T), compared with ``> 0``.

    Returns:
        The same ``state`` object, with ``batches_processed`` incremented.
    """
    if pos_masks:
        # The background score does not depend on the fragment, so compute
        # it once per batch instead of once per fragment.
        background = jnp.max(
            jnp.where(valid_mask[..., None] > 0, acts, -1e30), axis=1
        )  # (B,K)

    for frag, pos in pos_masks.items():
        # (a) Within-sequence discriminativity (on pre-positions)
        w = within_seq_discriminativity(acts, pos, valid_mask)  # "wsd": (B,K)
        state["wsd_sum"][frag] += np.asarray(jnp.nanmean(w["wsd"], axis=0))
        state["wsd_count"][frag] += 1

        # (b) Across-sequence selectivity via point-biserial sufficient stats
        B = pos.shape[0]
        n_in = jnp.sum(pos, axis=1).reshape(B, 1)  # marked positions per row
        present = n_in > 0  # (B,1) binary label
        in_sum = jnp.sum(acts * pos[..., None], axis=1)  # (B,K)
        mu_in = in_sum / jnp.maximum(n_in, 1)  # guard rows with no positions

        # mean on pre-positions when present; else max over valid (background)
        scores = jnp.where(present, mu_in, background)  # (B,K)

        s_np = np.asarray(scores, dtype=np.float64)
        y_np = np.asarray(present, dtype=np.float64)  # (B,1)

        state["pb_n"][frag] += int(B)
        state["pb_sx"][frag] += s_np.sum(axis=0)
        state["pb_sx2"][frag] += (s_np**2).sum(axis=0)
        state["pb_sy"][frag] += float(y_np.sum())
        state["pb_sxy"][frag] += (s_np * y_np).sum(axis=0)

    state["batches_processed"] += 1
    return state


def _flatten_state(state: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Turn the nested running state into a flat ``{name: array}`` mapping.

    Scalars are wrapped in one-element arrays; per-fragment entries are keyed
    as ``"<table>/<fragment>"``. Accumulator vectors are passed through
    without copying.
    """

    def _boxed(value, dtype) -> np.ndarray:
        return np.array([value], dtype=dtype)

    flat = {
        "K": _boxed(state["K"], np.int32),
        "mols_processed": _boxed(state["mols_processed"], np.int64),
        "batches_processed": _boxed(state["batches_processed"], np.int64),
    }

    # Within-sequence accumulators follow the keys of the wsd_sum table.
    wsd_counts = state["wsd_count"]
    for frag, total in state["wsd_sum"].items():
        flat[f"wsd_sum/{frag}"] = total
        flat[f"wsd_count/{frag}"] = _boxed(wsd_counts[frag], np.int64)

    # Point-biserial sums follow the declared fragment list.
    for frag in state["fragments"]:
        flat[f"pb_n/{frag}"] = _boxed(state["pb_n"][frag], np.int64)
        flat[f"pb_sx/{frag}"] = state["pb_sx"][frag]
        flat[f"pb_sx2/{frag}"] = state["pb_sx2"][frag]
        flat[f"pb_sy/{frag}"] = _boxed(state["pb_sy"][frag], np.float64)
        flat[f"pb_sxy/{frag}"] = state["pb_sxy"][frag]

    return flat


def save_checkpoint(state: Dict[str, Any], out_dir: Path, step: int):
    """Write the running state to ``out_dir`` as a checkpoint.

    Two files are produced: ``ckpt_step<step>.npz`` with the flattened state
    arrays (compressed), and ``state_meta.json`` with the top-level entries
    of ``state`` whose values are ``int``, ``float`` or ``str``. The JSON
    file has a fixed name, so each call overwrites the previous one.

    Args:
        state: Running state to persist.
        out_dir: Target directory; created (with parents) when missing.
        step: Number embedded in the archive file name.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    archive_path = out_dir / f"ckpt_step{step}.npz"
    np.savez_compressed(archive_path, **_flatten_state(state))

    meta_path = out_dir / "state_meta.json"
    with open(meta_path, "w") as handle:
        # Only plain scalars are JSON-friendly; nested tables live in the npz.
        scalar_meta = {}
        for key, value in state.items():
            if isinstance(value, (int, float, str)):
                scalar_meta[key] = value
        json.dump(scalar_meta, handle, indent=2)


def finalize_and_write_report(state: Dict[str, Any], out_dir: Path):
    """Compute final per-fragment metrics and write them to ``out_dir``.

    For every fragment in ``state["fragments"]`` two length-K vectors are
    derived: ``wsd_mean`` (accumulated within-sequence score divided by the
    number of accumulated batches) and ``pb_corr`` (point-biserial
    correlation rebuilt from the running sums).

    Files written:
        * ``final_metrics.npz`` - all vectors, keyed ``"<fragment>/<metric>"``.
        * ``top20_<fragment>_<metric>.csv`` - the 20 highest-scoring feature
          indices per fragment and metric, with a ``feature,score`` header.
          Fragment names are sanitised for use in file names; NaN scores
          are ranked last.

    Args:
        state: Running state filled by the scan.
        out_dir: Target directory; created (with parents) when missing.
    """
    import re

    def _safe_name(s: str) -> str:
        # Collapse every run of characters outside [A-Za-z0-9._-] to "_".
        return re.sub(r"[^A-Za-z0-9._-]+", "_", s)

    def _point_biserial(n: float, sx, sx2, sy: float, sxy):
        """Correlation between scores x (per feature) and a binary label y,
        from n, sum(x), sum(x^2), sum(y) and sum(x*y)."""
        denom = max(n, 1.0)  # avoids 0/0 when nothing was accumulated
        mx = sx / denom
        my = sy / denom
        varx = sx2 / denom - mx**2
        # Floors keep the division finite for constant features/labels and
        # absorb small negative variances caused by rounding.
        stdx = np.sqrt(np.maximum(varx, 1e-12))
        stdy = math.sqrt(max(my * (1 - my), 1e-12))
        cov = sxy / denom - mx * my
        return cov / (stdx * stdy)  # (K,)

    results = {}
    for frag in state["fragments"]:
        wsd_mean = state["wsd_sum"][frag] / max(state["wsd_count"][frag], 1)
        pb = _point_biserial(
            float(state["pb_n"][frag]),
            state["pb_sx"][frag],
            state["pb_sx2"][frag],
            float(state["pb_sy"][frag]),
            state["pb_sxy"][frag],
        )
        results[frag] = {"wsd_mean": wsd_mean, "pb_corr": pb}

    out_dir.mkdir(parents=True, exist_ok=True)
    # NPZ carries the full arrays
    np.savez_compressed(out_dir / "final_metrics.npz", **_flatten_results(results))

    # Top-20 CSVs (sanitized names)
    for frag, d in results.items():
        frag_file = _safe_name(frag)
        for key, vec in d.items():
            # Negate for descending order; NaNs are pushed to the bottom.
            order = np.argsort(-np.nan_to_num(vec, nan=-1e9))[:20]
            lines = ["feature,score"]
            lines.extend(f"{int(k)},{float(vec[k])}" for k in order)
            (out_dir / f"top20_{frag_file}_{key}.csv").write_text(
                "\n".join(lines), encoding="utf-8"
            )


def _flatten_results(
    results: Dict[str, Dict[str, np.ndarray]],
) -> Dict[str, np.ndarray]:
    """Flatten ``{fragment: {metric: values}}`` to ``{"fragment/metric": array}``.

    Each value is passed through ``np.asarray``; key order follows the
    nested iteration order of the input.
    """
    return {
        f"{frag}/{metric}": np.asarray(values)
        for frag, metrics in results.items()
        for metric, values in metrics.items()
    }


# =========================
# MAIN
# =========================


def main():
    """Command-line entry point for the causal pre-SMARTS scan.

    Steps: parse arguments, load the tokenizer and SAE kit, load and tokenize
    the dataset, compile the SMARTS selected by the profile, then stream
    batches. For each batch the SMILES strings are decoded back from the
    token ids, fragment spans are located, the positions preceding each span
    are masked, and the running statistics are updated. A checkpoint is
    written every ``--save_every`` batches (never when that value is <= 0)
    and once more at the end, followed by the final report. Output goes to
    ``<out_dir>_layer<layer_id>``.
    """
    p = argparse.ArgumentParser(description="Causal pre-SMARTS screening.")
    p.add_argument("--model_dir", required=True)
    p.add_argument("--sae_dir", required=True)
    p.add_argument("--ckpt_id", type=int, required=True)
    p.add_argument("--dataset_dir", required=True)
    p.add_argument("--smarts_yaml", required=True)
    p.add_argument("--profile", default="leadlike")
    p.add_argument("--layer_id", type=int, default=5)
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--limit", type=int, default=20000)
    p.add_argument("--num_proc", type=int, default=1)
    p.add_argument(
        "--save_every",
        type=int,
        default=10,
        help="checkpoint every N batches (<= 0 disables periodic checkpoints)",
    )
    p.add_argument("--out_dir", default="fragment_causal_pre")
    # context controls:
    p.add_argument("--pre_window", type=int, default=1, help="tokens before anchor")
    p.add_argument(
        "--pre_offset", type=int, default=-1, help="-1 means immediately before"
    )
    args = p.parse_args()

    # A window below 1 marks no positions, so the scan would run to completion
    # and report statistics over empty masks; fail fast instead.
    if args.pre_window < 1:
        p.error("--pre_window must be >= 1")

    out_dir = Path(f"{args.out_dir}_layer{args.layer_id}")
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1) Load tokenizer & SAE
    tokenizer = load_tokenizer(
        tokenizer_path=f"{args.model_dir}/tokenizer.json",
        generation_config_file=f"{args.model_dir}/generation_config.json",
        trunc_length=512,
    )
    sae_kit = SAEKit.load(
        model_dir=args.model_dir, checkpoint_id=args.ckpt_id, sae_dir=args.sae_dir
    )

    # 2) Dataset
    ds = data_tools.load_and_tokenize(
        dataset_dir=args.dataset_dir,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_processes=args.num_proc,
        seed=2002,
        caching=True,
        limit=args.limit,
    )

    # 3) SMARTS
    smarts_dict = load_smarts_yaml(args.smarts_yaml, profile=args.profile)
    queries = compile_smarts(smarts_dict)

    # 4) Init state
    K = sae_kit.sae_configs[int(args.layer_id)].latent_size
    fragments = list(smarts_dict.keys())
    state = init_state(K, fragments)

    # 5) Iterate
    mols_processed = 0
    batches = 0
    t0 = time.time()

    for batch in ds:
        B = batch["inputs"].shape[0]
        if mols_processed >= args.limit:
            break

        raw_smiles_batch = reconstruct_raw_smiles_batch(batch, tokenizer)
        acts = sae_kit.get_encoded(
            batch["inputs"], batch["positions"], args.layer_id
        )  # (B,T,K)
        valid_mask = sae_kit.mask_fn(batch["inputs"])  # (B,T)

        frag_index, invalid, offenders = safe_build_fragment_index_for_batch(
            raw_smiles_batch, queries, tokenizer, add_bos_shift=True
        )
        if invalid:
            print(f"[warn] batch {batches}: skipped {invalid}/{B} invalid SMILES")
            for s in offenders[:3]:
                print("  offender:", s)

        pos_masks = build_context_masks_from_index(
            frag_index,
            seq_len=acts.shape[1],
            offset=args.pre_offset,
            window=args.pre_window,
        )

        # Ensure all fragments present (even if empty) for stable aggregation
        BT = acts.shape[:2]
        for frag in fragments:
            if frag not in pos_masks:
                pos_masks[frag] = jnp.zeros(BT, dtype=jnp.bool_)

        update_state_with_batch(state, acts, pos_masks, valid_mask)

        batches += 1
        mols_processed += B
        state["mols_processed"] = mols_processed

        # Guard the modulo: --save_every 0 would raise ZeroDivisionError.
        if args.save_every > 0 and batches % args.save_every == 0:
            save_checkpoint(state, out_dir, step=batches)
            elapsed = time.time() - t0
            print(f"[ckpt] step={batches} mols={mols_processed} elapsed={elapsed:.1f}s")

    # 6) Finalize
    save_checkpoint(state, out_dir, step=batches)
    finalize_and_write_report(state, out_dir)
    elapsed = time.time() - t0
    print(
        f"[done] batches={batches} mols={mols_processed} elapsed={elapsed:.1f}s → {out_dir}"
    )


# Script entry point: process-level setup runs only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    # Set the environment variable before main() runs.
    # NOTE(review): presumably read by einops to select its JAX backend - confirm.
    os.environ["EINOPS_BACKEND"] = "jax"
    # Drop any already-imported "tensorflow" entry from the module cache
    # (no-op when absent).
    # NOTE(review): presumably keeps backend detection from picking up
    # TensorFlow - confirm.
    sys.modules.pop("tensorflow", None)
    main()
