# lmkit/sparse/sae_fragment_attribution.py
from __future__ import annotations

from typing import Any, Dict, List, Optional

import jax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm

from lmkit.impl import hooks as hooks_lib
from lmkit.impl import transformer as transformer_impl
from lmkit.sparse.fragment_mapper import find_fragments_and_tokens
from lmkit.sparse.sae import decode_latent, normalize

# ---------------------------------------------------------------------
# 1) Map SMARTS → token indices (aligned to model inputs)
# ---------------------------------------------------------------------


def build_fragment_index_for_batch(
    smiles_list: List[str],
    queries: Dict[str, Any],  # compiled SMARTS dict from compile_smarts(...)
    tokenizer,  # lmkit.tools.compat.load_tokenizer(...)
    *,
    add_bos_shift: bool = True,  # add +1 to align to model inputs with BOS
) -> List[Dict[str, List[List[int]]]]:
    """
    Locate fragment hits in every SMILES string and express them as token indices.

    The result has one dict per input string, in input order:
      {frag_name: [ [tok_idx...], [tok_idx...], ... ]}
    Each inner list is built from one non-empty entry of
    `occ["original"]["segments"]`, so an occurrence reported with several
    segments contributes several lists. Fragments without any non-empty
    segment are left out of the dict.

    With `add_bos_shift=True` every index is incremented by one so that it
    addresses the **model input sequence** (which starts with BOS); otherwise
    the indices are passed through as reported by `find_fragments_and_tokens`.
    """
    batch_index: List[Dict[str, List[List[int]]]] = []
    for smiles in tqdm(smiles_list, desc="Finding fragments in batch"):
        matches = find_fragments_and_tokens(smiles, queries, tokenizer=tokenizer)
        per_molecule = {}
        for frag_name, occurrences in matches["fragments"].items():
            # Walk all segments of all occurrences as one flat stream.
            segments = (
                segment
                for occurrence in occurrences
                for segment in occurrence["original"]["segments"]
            )
            token_lists: List[List[int]] = []
            for segment in segments:
                tokens = list(segment["token_indices"])
                if not tokens:
                    continue
                token_lists.append(
                    [tok + 1 for tok in tokens] if add_bos_shift else tokens
                )
            if token_lists:
                per_molecule[frag_name] = token_lists
        batch_index.append(per_molecule)
    return batch_index


# ---------------------------------------------------------------------
# 2) Aggregate SAE activations on fragment tokens
# ---------------------------------------------------------------------


def _aggregate_over_indices(vals_1d: jnp.ndarray, idxs: List[int], how: str) -> float:
    """
    Reduce the entries of `vals_1d` selected by `idxs` to a single float.

    `how` picks the reduction: "mean", "max" or "sum"; any other value raises
    ValueError. An empty `idxs` yields NaN (before `how` is looked at).
    Indices are gathered with mode="clip", i.e. out-of-range indices are
    clamped into the valid range rather than raising.
    """
    if not idxs:
        return float("nan")

    gathered = vals_1d.take(jnp.array(idxs, dtype=jnp.int32), axis=0, mode="clip")

    reducers = {"mean": jnp.mean, "max": jnp.max, "sum": jnp.sum}
    reducer = reducers.get(how)
    if reducer is None:
        raise ValueError(f"Unknown aggregation '{how}'")
    return float(reducer(gathered))


def fragment_local_feature_scores(
    acts: jnp.ndarray,  # (B, T, K) SAE activations
    fragment_index: List[Dict[str, List[List[int]]]],
    *,
    aggregate: str = "mean",
) -> Dict[str, np.ndarray]:
    """
    Average per-occurrence SAE activations for every fragment in the batch.

    For each occurrence (a list of token indices) the activations at those
    tokens are reduced over the token axis with `aggregate`
    ("mean", "max" or "sum"), giving one (K,) vector. The vectors of all
    occurrences of a fragment across the batch are then averaged.

    Token indices are gathered with mode="clip", so out-of-range indices are
    clamped into the sequence instead of raising. An occurrence with no token
    indices contributes a NaN vector, which makes that fragment's result NaN.

    Returns {frag_name: scores} with scores a float32 array of shape (K,).
    Fragments that never occur are absent from the result.

    Raises ValueError if `aggregate` is not one of the supported reductions.
    """
    reducers = {"mean": jnp.mean, "max": jnp.max, "sum": jnp.sum}
    if aggregate not in reducers:
        raise ValueError(f"Unknown aggregation '{aggregate}'")
    reduce_fn = reducers[aggregate]

    B, _, K = acts.shape
    accum = {}  # frag -> (sum_vec, count)
    for b in tqdm(range(B), desc="Aggregating fragment activations"):
        sample_acts = acts[b]  # (T, K)
        for frag, occs in fragment_index[b].items():
            for tok_idxs in occs:
                if tok_idxs:
                    # One gather + one reduction for all K features at once,
                    # instead of a separate device round trip per feature.
                    rows = jnp.take(
                        sample_acts,
                        jnp.array(tok_idxs, dtype=jnp.int32),
                        axis=0,
                        mode="clip",
                    )  # (n_tokens, K)
                    occ_vec = np.asarray(reduce_fn(rows, axis=0), dtype=np.float32)
                else:
                    occ_vec = np.full((K,), np.nan, dtype=np.float32)

                if frag in accum:
                    total, count = accum[frag]
                    accum[frag] = (total + occ_vec, count + 1)
                else:
                    accum[frag] = (occ_vec, 1)

    return {frag: total / max(count, 1) for frag, (total, count) in accum.items()}


# ---------------------------------------------------------------------
# 3) Token-level AUROC per feature (how discriminative is a feature?)
# ---------------------------------------------------------------------


def auroc_for_feature(pos_vals: np.ndarray, neg_vals: np.ndarray) -> float:
    """
    Compute AUROC with a Mann–Whitney U formulation, using mid-ranks for ties.

    `pos_vals` and `neg_vals` are 1-D score arrays for the positive and the
    negative class. The result is the probability that a random positive
    scores higher than a random negative, with tied pairs counted as 1/2.
    Returns NaN if either class is empty.
    """
    n1, n0 = len(pos_vals), len(neg_vals)
    if n1 == 0 or n0 == 0:
        return float("nan")

    pooled = np.concatenate([pos_vals, neg_vals])

    # Tied values must share the average of the ranks they span. Ranks taken
    # straight from argsort would give ties arbitrary distinct ranks, which
    # makes the result depend on input order (sparse activations are mostly
    # exact zeros, so ties are the common case, not the exception).
    _, inverse, counts = np.unique(pooled, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)  # highest 1-based rank in each tie group
    mid_rank = last_rank - (counts - 1) / 2.0
    ranks = mid_rank[inverse.reshape(-1)]

    R1 = np.sum(ranks[:n1])  # rank sum of the positives
    U1 = R1 - n1 * (n1 + 1) / 2.0
    return float(U1 / (n1 * n0))


def token_level_auroc(
    acts: jnp.ndarray,  # (B, T, K)
    fragment_index: List[Dict[str, List[List[int]]]],
    *,
    per_fragment: bool = True,
    exclude_bos_eos_pad_mask: Optional[jnp.ndarray] = None,  # (B, T) 1/0 mask
) -> Dict[str, np.ndarray]:
    """
    Score how well each SAE feature separates fragment tokens from the rest.

    For every sequence and every fragment found in it, positives are the
    tokens listed in the fragment's occurrences and negatives are all other
    tokens of the **same sequence** that `exclude_bos_eos_pad_mask` marks as
    usable (entries > 0; every token when the mask is None). The mask only
    restricts the negatives. A per-feature AUROC is computed for that
    (sequence, fragment) pair; pairs with no positives or no negatives are
    skipped.

    Returns {frag_name: A}, where A is the (K,) NaN-ignoring mean of those
    AUROC vectors over the sequences in which the fragment was scored.

    `per_fragment` is accepted but not consulted by this function.
    """
    n_seqs, seq_len, n_feats = acts.shape
    per_frag_scores: Dict[str, List[np.ndarray]] = {}

    for b in range(n_seqs):
        # Tokens that may serve as negatives in this sequence.
        if exclude_bos_eos_pad_mask is not None:
            usable = np.asarray(exclude_bos_eos_pad_mask[b] > 0)
        else:
            usable = np.ones((seq_len,), dtype=bool)

        for frag, occs in fragment_index[b].items():
            inside = np.zeros((seq_len,), dtype=bool)
            for tok_idxs in occs:
                inside[tok_idxs] = True
            outside = usable & ~inside

            # AUROC needs at least one sample on each side.
            if not inside.any() or not outside.any():
                continue

            pos_vals = np.asarray(acts[b, inside, :])  # (P, K)
            neg_vals = np.asarray(acts[b, outside, :])  # (N, K)

            scores = np.array(
                [
                    auroc_for_feature(pos_vals[:, k], neg_vals[:, k])
                    for k in range(n_feats)
                ],
                dtype=np.float32,
            )
            per_frag_scores.setdefault(frag, []).append(scores)

    # mean across batch molecules
    return {
        frag: np.nanmean(np.stack(rows, axis=0), axis=0)
        for frag, rows in per_frag_scores.items()
    }


# ---------------------------------------------------------------------
# 4) Causal check: zero selected features only on fragment tokens
# ---------------------------------------------------------------------


def make_zero_features_callback(
    sae_cfg, sae_params, features: List[int], tok_mask_1d: np.ndarray
):
    """
    Build a callback for ActivationEditor(op='call') that ablates SAE features.

    The returned function takes the activation slice, and:
      - optionally normalizes it and subtracts `b_dec`, as selected by
        `sae_cfg.rescale_inputs` / `sae_cfg.pre_enc_bias`,
      - projects it with `W_enc` to get the latent,
      - sets the given `features` to zero,
      - decodes back with `decode_latent`,
      - writes the decoded values only where `tok_mask_1d` is True and keeps
        the original activations everywhere else.

    `features` is de-duplicated and sorted; `tok_mask_1d` is cast to bool and
    is indexed relative to the start of the slice handed to the callback.

    NOTE(review): the latent here is the plain `x @ W_enc` projection; no
    encoder bias or activation function is applied in this function. Confirm
    that this matches how the SAE encodes during training.
    """
    feature_ids = jnp.array(sorted({int(f) for f in features}), dtype=jnp.int32)
    token_mask = jnp.array(tok_mask_1d.astype(bool))  # shape (T,)

    def _cb(x_slice: jnp.ndarray) -> jnp.ndarray:
        """
        x_slice shape: (B, Ls, D) for the selected token slice.
        Positions flagged in the token mask receive the ablated
        reconstruction; all other positions are returned unchanged.
        """
        n_batch, n_tok, _ = x_slice.shape

        # Normalize (same as training)
        if sae_cfg.rescale_inputs:
            x_norm, x_mean, x_std = normalize(x_slice)
        else:
            x_norm, x_mean, x_std = x_slice, 0.0, 1.0

        x_enc = x_norm - sae_params["b_dec"] if sae_cfg.pre_enc_bias else x_norm

        # Encode, then knock out the selected features. (B, Ls, K)
        latents = (x_enc @ sae_params["W_enc"]).at[..., feature_ids].set(0.0)

        reconstructed = decode_latent(
            latents, params=sae_params, config=sae_cfg, x_mean=x_mean, x_std=x_std
        )

        # The editor hands over the whole slice; only positions in the mask
        # may be modified, so everything else falls back to the input.
        edit_here = jnp.broadcast_to(token_mask[:n_tok], (n_batch, n_tok))
        edited = jnp.where(edit_here[..., None], reconstructed, x_slice)
        return edited.astype(x_slice.dtype)

    return _cb


def _masked_next_token_nll(
    logits: jnp.ndarray, targets: jnp.ndarray, valid: jnp.ndarray
) -> jnp.ndarray:
    """
    Sum the next-token NLL of the first sequence over the positions in `valid`.

    Indexing used below: `logits[0]` is (T, V), `targets[0]` is (T,) token ids
    and `valid[0]` is a (T,) boolean mask. Returns a float32 scalar.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)[0]  # (T, V)
    # Pick exactly one log-prob per position: log_probs[t, targets[0][t]].
    # Fancy indexing such as `x[0, :, targets[0]]` would instead return a
    # (T, T) table pairing every position with every target.
    picked = jnp.take_along_axis(log_probs, targets[0][:, None], axis=-1)[:, 0]
    token_nll = -picked.astype(jnp.float32)
    return jnp.sum(jnp.where(valid[0], token_nll, 0.0))


def delta_nll_from_ablation(
    inputs: jnp.ndarray,  # (B,T) model input ids
    positions: jnp.ndarray,  # (B,T) positions (>=0 for valid tokens)
    sae_kit,  # lmkit.sparse.sae.SAEKit
    layer_id: int,
    features_to_zero: List[int],
    fragment_token_indices: List[
        List[int]
    ],  # list of contiguous or non-contiguous indices for ONE sample
    *,
    temp: float = 1.0,
) -> float:
    """
    Causal score for ONE sequence:
      - run baseline to get NLL on ground-truth next tokens,
      - create an ActivationEditor that zeroes `features_to_zero` **only** on fragment tokens,
      - run again and compute ΔNLL (ablation NLL - baseline NLL) on those next tokens.
    Positive ΔNLL ⇒ those features support predicting tokens around the fragment.

    The NLL is evaluated at fragment positions that have `positions >= 0`,
    excluding the last position of the sequence (it has no next token), and
    ΔNLL is divided by the number of such positions (at least 1).

    Returns NaN, without running the model, when `fragment_token_indices`
    selects no token. `inputs` must have batch size 1. `temp` is accepted but
    not used by this function.
    """
    assert inputs.shape[0] == 1, "call per-sequence (B=1) for simplicity"

    # Build a boolean mask over T for the fragment
    T = inputs.shape[1]
    frag_mask = np.zeros((T,), dtype=bool)
    for idxs in fragment_token_indices:
        frag_mask[np.array(idxs, dtype=int)] = True

    # Nothing to ablate: bail out before paying for any forward pass.
    frag_positions = np.flatnonzero(frag_mask)
    if frag_positions.size == 0:
        return float("nan")

    # Next-token targets and the positions at which the NLL is scored.
    tgt = jnp.roll(inputs, -1, axis=1)
    valid = (positions >= 0) & jnp.array(frag_mask)[None, :]
    valid = valid.at[:, -1].set(False)  # last position has no next token
    denom = jnp.maximum(jnp.sum(valid[0]), 1)

    # 1) baseline run (sae_kit.run_fn, no editor)
    logits, *_ = sae_kit.run_fn(
        inputs,
        jnp.where(positions >= 0, positions, -1),
        sae_kit.lm_params,
        sae_kit.lm_config,
    )
    nll_base = _masked_next_token_nll(logits, tgt, valid)

    # 2) ablated run via ActivationEditor
    # The edit covers the minimal slice [left:right) containing all fragment
    # tokens; the callback's mask is relative to the start of that slice.
    left, right = int(frag_positions[0]), int(frag_positions[-1]) + 1

    cb = make_zero_features_callback(
        sae_cfg=sae_kit.sae_configs[layer_id],
        sae_params=sae_kit.sae_params[layer_id],
        features=features_to_zero,
        tok_mask_1d=frag_mask[left:right],
    )

    editor = hooks_lib.ActivationEditor(
        edits=(
            hooks_lib.Edit(
                layer=layer_id,
                kind=sae_kit.sae_configs[layer_id].placement,
                op="call",
                tok_slice=slice(left, right),
                callback=cb,
            ),
        )
    )

    # Call the raw transformer.run with editor injected.
    # NOTE(review): unlike the baseline call, no positions are passed here,
    # and the dummy cache is rebuilt from the hooks tree structure with zero
    # leaves - confirm both against transformer_impl.run's signature.
    logits_abl, *_ = transformer_impl.run(
        inputs,
        cache=jax.tree_util.tree_unflatten(
            jax.tree_util.tree_structure(sae_kit.hooks), ()
        ),  # dummy; not used
        params=sae_kit.lm_params,
        config=sae_kit.lm_config,
        editor=editor,
        hooks_to_return=frozenset(),
        hooks_to_stream=frozenset(),
    )
    nll_abl = _masked_next_token_nll(logits_abl, tgt, valid)

    return float((nll_abl - nll_base) / denom)
