# lmkit/sparse/sae_fragment_fastmetrics.py
from __future__ import annotations

from typing import Dict

import jax.numpy as jnp
import numpy as np

# Default stabiliser: the helpers in this module add it to denominators and
# to sqrt arguments so that empty masks / zero variance do not produce NaN/inf.
_EPS = 1e-12


def _safe_div(num, den, eps=_EPS):
    """Divide ``num`` by ``den`` after shifting the denominator by ``eps``.

    A zero denominator therefore yields ``num / eps`` instead of inf/NaN
    (and exactly 0 when ``num`` is 0 as well).
    """
    shifted_den = den + eps
    return num / shifted_den


def _safe_std(var, eps=_EPS):
    """Turn a variance into a standard deviation that is never exactly zero.

    Negative variances (rounding artefacts of the E[x^2] - E[x]^2 form) are
    clamped to 0, then ``eps`` is added before taking the square root.
    """
    non_negative = jnp.maximum(var, 0.0)
    return jnp.sqrt(non_negative + eps)


def _safe_sqrt(x, eps=_EPS):
    """Square root of ``max(x, 0) + eps``; the result is at least ``sqrt(eps)``."""
    floored = jnp.maximum(x, 0.0)
    stabilised = floored + eps
    return jnp.sqrt(stabilised)


def _nan_to_num(x):
    """Return ``x`` with every NaN, +inf and -inf entry replaced by 0.0."""
    # All three non-finite cases map to the same neutral fill value; the
    # replacement is done by jnp.nan_to_num so an array input keeps its dtype.
    fill = 0.0
    return jnp.nan_to_num(x, nan=fill, posinf=fill, neginf=fill)


# ---------------------------------------------------------------------
# Utilities: build fragment masks (B,T) from your fragment_index lists
# ---------------------------------------------------------------------


def build_pos_masks_from_index(
    fragment_index: list[dict[str, list[list[int]]]],  # one dict per sequence
    seq_len: int,
):
    """
    Build one boolean position mask per fragment name.

    fragment_index[b][frag] is a list of token-index lists (one list per
    occurrence of ``frag`` in sequence ``b``).

    Returns {frag_name: pos_mask}, keys in sorted order, where pos_mask is a
    (B,T) bool jnp array with B = len(fragment_index) and T = seq_len. A cell is
    True when that token index appears in any occurrence of the fragment in
    that sequence. Indices outside [0, seq_len) are ignored.
    """
    batch = len(fragment_index)

    # Union of fragment names over the whole batch; every name gets a full
    # (B,T) mask even if it occurs in a single sequence only.
    seen_names = set()
    for per_seq in fragment_index:
        seen_names.update(per_seq.keys())
    masks = {
        name: np.zeros((batch, seq_len), dtype=bool) for name in sorted(seen_names)
    }

    # Single pass over the batch, filling the row of each fragment present.
    for row, per_seq in enumerate(fragment_index):
        for name, occurrences in per_seq.items():
            for token_ids in occurrences:
                if not token_ids:
                    continue
                ids = np.asarray(token_ids, dtype=int)
                # drop out-of-range indices instead of raising / wrapping around
                in_range = ids[(ids >= 0) & (ids < seq_len)]
                masks[name][row, in_range] = True

    return {name: jnp.array(mask) for name, mask in masks.items()}


# ---------------------------------------------------------------------
# (a) Within-sequence discriminativity  (B,K)
# ---------------------------------------------------------------------


def within_seq_discriminativity(acts, pos, valid, eps=_EPS):
    """
    Robust within-sequence discriminativity (WSD), per sequence and feature.

    For every sequence b and feature k:
      delta = mu_in - mu_out
      wsd   = delta / std_all * sqrt(p*(1-p))
    where mu_in / mu_out are the mean activations over valid tokens inside /
    outside the fragment, std_all is the std over all valid tokens and
    p = n_in / n_all is the fragment coverage of the sequence.

    Args:
      acts:  (B,T,K) activations.
      pos:   (B,T) mask, truthy where a token belongs to the fragment.
      valid: (B,T) mask, truthy for tokens that should be counted at all.
      eps:   stabiliser forwarded to every safe division / sqrt below.

    Returns:
      dict with
        "delta_mean": (B,K) mu_in - mu_out, 0 where either group is empty,
        "wsd":        (B,K) weighted, std-normalised delta,
        "coverage":   (B,)  p,
      all free of NaN/inf and cast back to ``acts.dtype``.
    """
    # promote to float32 for stability, keep masks as floats
    acts32 = acts.astype(jnp.float32)
    validf = valid.astype(jnp.float32)
    # Restrict the fragment mask to valid tokens: positions flagged in `pos`
    # but excluded by `valid` must not contribute to n_in / mu_in, otherwise
    # the "in" statistics disagree with the "all" ones and coverage can exceed 1.
    posf = pos.astype(jnp.float32) * validf
    notpos = validf * (1.0 - posf)

    # counts
    n_in = jnp.sum(posf, axis=1)  # (B,)
    n_out = jnp.sum(notpos, axis=1)
    n_all = jnp.sum(validf, axis=1)
    p_cov = _nan_to_num(_safe_div(n_in, n_all, eps))  # (B,)

    # sums
    in_sum = jnp.sum(acts32 * posf[..., None], axis=1)  # (B,K)
    out_sum = jnp.sum(acts32 * notpos[..., None], axis=1)
    all_sum = jnp.sum(acts32 * validf[..., None], axis=1)

    mu_in = _safe_div(in_sum, n_in[:, None], eps)
    mu_out = _safe_div(out_sum, n_out[:, None], eps)
    mu_all = _safe_div(all_sum, n_all[:, None], eps)

    # variance (masked second moment)
    all_sq_sum = jnp.sum((acts32**2) * validf[..., None], axis=1)  # (B,K)
    var_all = all_sq_sum * _safe_div(1.0, n_all[:, None], eps) - mu_all**2
    std_all = _safe_std(var_all, eps)  # floor at eps

    # The contrast is only defined when both groups are non-empty. Without
    # this guard an empty group's mean silently becomes 0 and delta would
    # report +/- the other group's mean for sequences lacking the fragment.
    both_groups = ((n_in > 0) & (n_out > 0))[:, None]  # (B,1)
    delta = jnp.where(both_groups, mu_in - mu_out, 0.0)

    # p*(1-p) lies in [0, 0.25]; clip guards against rounding just outside it
    weight = _safe_sqrt(
        jnp.clip(p_cov * (1.0 - p_cov), 0.0, 0.25), eps
    ).reshape(-1, 1)

    wsd = _nan_to_num(delta / std_all) * weight  # (B,K), no NaNs
    return {
        "delta_mean": _nan_to_num(delta).astype(acts.dtype),
        "wsd": _nan_to_num(wsd).astype(acts.dtype),
        "coverage": _nan_to_num(p_cov).astype(acts.dtype),
    }


# ---------------------------------------------------------------------
# (b) Across-sequence selectivity (sequence-level AUROC / correlation)
# ---------------------------------------------------------------------

def across_sequence_selectivity(acts, pos, valid, *, neg_agg="max", eps=_EPS):
    """
    Sequence-level selectivity of each feature for the presence of a fragment.

    A sequence is labelled 1 when it contains at least one valid fragment
    token, else 0. Its score for feature k is
      * the mean activation over its fragment tokens when labelled 1,
      * the ``neg_agg`` aggregate ("max" or "mean") of the activation over all
        of its valid tokens when labelled 0.
    Point-biserial correlation and AUROC between scores and labels are then
    computed per feature across the batch.

    Args:
      acts:    (B,T,K) activations.
      pos:     (B,T) mask, truthy where a token belongs to the fragment.
      valid:   (B,T) mask, truthy for tokens that should be counted at all.
      neg_agg: "max" or "mean"; aggregate used for sequences without the fragment.
      eps:     stabiliser for the safe divisions and the correlation.

    Returns:
      {"pb_corr": (K,) array, "auroc": (K,) array, "pos_rate": float}.
      "pos_rate" is the fraction of sequences labelled 1 (0.0 for an empty batch).

    Raises:
      ValueError: if ``neg_agg`` is not "max"/"mean" or ``acts`` is not 3-D.

    Note: scores are converted to NumPy before the statistics are computed.
    """
    # Validate cheap arguments before doing any array work.
    if neg_agg not in ("max", "mean"):
        raise ValueError("neg_agg must be 'max' or 'mean'")

    acts32 = acts.astype(jnp.float32)
    if acts32.ndim != 3:
        raise ValueError("acts must have shape (B, T, K)")

    validf = valid.astype(jnp.float32)
    # Fragment positions only count when they are also valid tokens, so that
    # labels and mu_in cannot be driven by excluded (e.g. padding) positions.
    posf = pos.astype(jnp.float32) * validf

    n_in = jnp.sum(posf, axis=1)  # (B,)
    has_frag = (n_in > 0)[:, None]  # (B,1) bool

    in_sum = jnp.sum(acts32 * posf[..., None], axis=1)  # (B,K)
    mu_in = _safe_div(in_sum, n_in[:, None], eps)  # (B,K)

    if neg_agg == "max":
        # -inf (not a large finite sentinel) marks "no valid token": it can
        # never win the max, and for a fully invalid sequence the remaining
        # -inf is mapped to 0 by _nan_to_num below instead of leaking a huge
        # magnitude into the correlation statistics.
        neg_score = jnp.max(
            jnp.where(validf[..., None] > 0, acts32, -jnp.inf), axis=1
        )
    else:  # "mean"
        all_sum = jnp.sum(acts32 * validf[..., None], axis=1)
        n_all = jnp.sum(validf, axis=1)
        neg_score = _safe_div(all_sum, n_all[:, None], eps)

    scores = jnp.where(has_frag, mu_in, neg_score)  # (B,K)
    labels = has_frag[:, 0]  # (B,)

    # Convert to numpy once
    scores_np = np.asarray(_nan_to_num(scores))
    labels_np = np.asarray(labels).astype(np.float32)

    # point-biserial (stable)
    pb = _point_biserial_from_arrays(scores_np, labels_np, eps=eps)  # (K,)

    # AUROC (neutral 0.5 if undefined)
    au = _seq_auroc_vectorized(scores_np, labels_np)

    # Avoid mean-of-empty (NaN + RuntimeWarning) for an empty batch.
    pos_rate = float(labels_np.mean()) if labels_np.size else 0.0
    return {"pb_corr": pb, "auroc": au, "pos_rate": pos_rate}


def _point_biserial_from_arrays(
    scores: np.ndarray, labels: np.ndarray, eps=_EPS
) -> np.ndarray:
    """Point-biserial correlation of every score column with binary ``labels``.

    ``scores`` is indexed as (N, K) and ``labels`` as (N,). Returns one
    correlation per column: cov(score, label) / (std(score) * std(label)),
    with ``eps`` added inside both square roots and to the final denominator
    so constant columns or single-class labels give 0 rather than NaN.
    With fewer than two samples a float32 zero vector of length K is returned.
    """
    if len(labels) < 2:
        # correlation is undefined for < 2 samples
        return np.zeros((scores.shape[1],), dtype=np.float32)

    # Label statistics: for 0/1 labels the population variance is p * (1 - p).
    label_mean = labels.mean()
    label_std = np.sqrt(max(label_mean * (1.0 - label_mean), eps))

    # Per-column score statistics (population variance, floored at 0).
    score_mean = scores.mean(axis=0)
    score_std = np.sqrt(np.maximum(scores.var(axis=0), 0.0) + eps)

    centered_scores = scores - score_mean
    centered_labels = labels[:, None] - label_mean
    covariance = (centered_scores * centered_labels).mean(axis=0)

    return covariance / (score_std * label_std + eps)


def _seq_auroc_vectorized(scores: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Compute AUROC per feature; if no pos or no neg → 0.5.

    ``scores`` is indexed as (N, K) and ``labels`` as (N,); rows with
    ``labels == 1`` are positives, all others negatives. The AUROC of each
    column is obtained from the Mann-Whitney U statistic using average ranks
    for ties, so a tie between a positive and a negative counts as 0.5.

    All K columns are ranked at once with array operations; there is no
    Python-level loop over features or samples.

    Returns a float32 array of shape (K,).
    """
    pos_mask = labels == 1
    neg_mask = ~pos_mask
    n1, n0 = int(pos_mask.sum()), int(neg_mask.sum())
    K = scores.shape[1]
    if n1 == 0 or n0 == 0:
        return np.full((K,), 0.5, dtype=np.float32)

    # Stack positives first so that their ranks are simply rows [:n1].
    stacked = np.concatenate([scores[pos_mask], scores[neg_mask]], axis=0)  # (N,K)
    n = n1 + n0

    # Stable sort of every column, and the column values in sorted order.
    order = np.argsort(stacked, axis=0, kind="mergesort")
    sorted_vals = np.take_along_axis(stacked, order, axis=0)

    # Flag the first and last sorted position of every run of equal values.
    run_start = np.ones((n, K), dtype=bool)
    run_start[1:] = sorted_vals[1:] != sorted_vals[:-1]
    run_end = np.ones((n, K), dtype=bool)
    run_end[:-1] = run_start[1:]

    # For each sorted position: index of the first / last member of its run.
    idx = np.arange(n)[:, None]
    first = np.maximum.accumulate(np.where(run_start, idx, 0), axis=0)
    last = np.minimum.accumulate(
        np.where(run_end, idx, n - 1)[::-1], axis=0
    )[::-1]

    # Average 1-based rank of a run spanning sorted positions first..last.
    avg_rank_sorted = 0.5 * (first + last) + 1.0

    # Scatter the ranks back to the stacked (positives-first) row order.
    ranks = np.empty((n, K), dtype=float)
    np.put_along_axis(ranks, order, avg_rank_sorted, axis=0)

    rank_sum_pos = ranks[:n1].sum(axis=0)
    u_pos = rank_sum_pos - n1 * (n1 + 1) / 2.0
    return (u_pos / (n1 * n0)).astype(np.float32)


def _rankdata_average_ties(x: np.ndarray) -> np.ndarray:
    """Rank the entries of 1-D ``x`` from 1 to N, averaging ranks within ties.

    Returns a float array aligned with ``x``: the smallest value gets rank 1,
    and every member of a run of equal values gets the mean of the ranks the
    run occupies.
    """
    n = len(x)
    perm = np.argsort(x, kind="mergesort")

    # Start from plain ordinal ranks (position in sorted order + 1).
    ranks = np.empty_like(perm, dtype=float)
    ranks[perm] = np.arange(1, n + 1, dtype=float)

    # Locate runs of equal values in the sorted data: a run begins wherever
    # a sorted value differs from its predecessor.
    sorted_vals = x[perm]
    breaks = np.flatnonzero(sorted_vals[1:] != sorted_vals[:-1]) + 1
    run_starts = np.concatenate(([0], breaks))
    run_stops = np.concatenate((breaks, [n]))

    for lo, hi in zip(run_starts, run_stops):
        if hi - lo > 1:
            # ranks lo+1 .. hi are shared; their mean is (lo + 1 + hi) / 2
            ranks[perm[lo:hi]] = 0.5 * (lo + 1 + hi)
    return ranks