# ========================== stats.py ================================= #
from functools import partial
import jax
import jax.numpy as jnp
from flax import struct

from . import sae, utils


# ---------- 0. helpers ------------------------------------------------ #
def _broadcast_mask(mask, ref):
    """Expand a (B,L) mask to match (B,L,…) tensors."""
    if mask is None:
        return None
    while mask.ndim < ref.ndim:
        mask = mask[..., None]
    return mask.astype(ref.dtype)


# ---------- 1. immutable accumulator --------------------------------- #
@struct.dataclass
class StatAccumulator:
    """
    Immutable running sums of per-batch statistics, averaged on demand.

    Scalar (0-d) entries are accumulated as Python floats. Of the non-scalar
    entries, only the per-neuron vectors "p_active", "mean_mag" and "act_var"
    are summed; any other non-scalar entry is ignored. ``update`` never
    mutates ``self`` -- it returns a new accumulator.
    """

    n_batches: int = 0
    scalar_sums: dict = struct.field(pytree_node=False, default_factory=dict)
    vector_sums: dict = struct.field(pytree_node=False, default_factory=dict)

    # ── update ───────────────────────────────────────────────────────── #
    def update(self, batch_stats: dict):
        """Return a copy with *batch_stats* added and ``n_batches`` bumped by one."""
        new_scalars = {**self.scalar_sums}
        new_vectors = {**self.vector_sums}

        for name, value in batch_stats.items():
            if value.ndim == 0:
                # float() pulls the value to the host as a plain Python number.
                new_scalars[name] = new_scalars.get(name, 0.0) + float(value)
                continue
            if name not in ("p_active", "mean_mag", "act_var"):
                continue
            new_vectors[name] = new_vectors.get(name, 0.0) + value

        return self.replace(
            n_batches=self.n_batches + 1,
            scalar_sums=new_scalars,
            vector_sums=new_vectors,
        )

    # ── finalise ─────────────────────────────────────────────────────── #
    def finalise(self):
        """Return ``sum / n_batches`` for every accumulated scalar and vector key."""
        count = self.n_batches
        means = {name: total / count for name, total in self.scalar_sums.items()}
        means.update(
            (name, total / count) for name, total in self.vector_sums.items()
        )
        return means


# ---------- 2. per-batch statistics ---------------------------------- #
@partial(jax.jit, static_argnames=("topk",))
def compute_stats(inputs, reconstruction, latent_pre, act, mask_tok, *, topk=20):
    """
    Calculate all scalar + per-neuron vector metrics for *one* SAE and *one*
    batch.  Positions where ``mask_tok`` is 0 are excluded from every metric
    (the caller's mask is meant to drop BOS/EOS/PAD like the LM loss does).

    Args:
        inputs: SAE input activations; assumed (B, L, D) -- TODO confirm.
        reconstruction: SAE reconstruction, same shape as ``inputs``.
        latent_pre: pre-activation latents. Currently unused; kept so the
            call signature stays stable.
        act: post-activation latents; assumed (B, L, K). ``p_active`` counts
            ``act > 0``, so non-negative activations are presumably expected.
        mask_tok: (B, L) token mask with 1 for tokens to include, or ``None``
            to treat every position as valid.
        topk: static argument, currently unused.

    Returns:
        dict with scalars ``mse``, ``nmse``, ``explained``, ``hoyer``,
        ``dead_frac`` and per-latent vectors (reduced over axes 0 and 1)
        ``p_active``, ``mean_mag``, ``act_var``.
    """
    eps = 1e-9

    if mask_tok is None:
        # No mask supplied: every (batch, position) counts as a valid token.
        # float32 keeps the token count exact-ish even when inputs are bf16.
        mask_tok = jnp.ones(inputs.shape[:-1], dtype=jnp.float32)

    # ---- reconstruction quality -------------------------------------- #
    mask_xyz = _broadcast_mask(mask_tok, inputs)  # (B,L,1)
    n_tok = jnp.sum(mask_xyz)  # number of valid tokens (not elements)
    err2 = jnp.square(reconstruction - inputs) * mask_xyz
    # Squared L2 error summed over the feature axis, averaged over tokens.
    mse = jnp.sum(err2) / (n_tok + eps)

    # Mean squared input norm per token (uncentred second moment), so
    # ``explained`` is 1 - NMSE rather than a mean-centred explained variance.
    power = jnp.sum(jnp.square(inputs) * mask_xyz) / (n_tok + eps)
    nmse = mse / (power + eps)
    explained = 1.0 - nmse

    # ---- activation statistics --------------------------------------- #
    tok_cnt = jnp.sum(mask_tok) + eps  # number of *valid* tokens
    mask_lat = _broadcast_mask(mask_tok, act)  # (B,L,1), broadcasts over K

    p_active = jnp.sum((act > 0) * mask_lat, axis=(0, 1)) / tok_cnt
    mean_mag = jnp.sum(act * mask_lat, axis=(0, 1)) / tok_cnt

    mean_sq = jnp.sum(jnp.square(act) * mask_lat, axis=(0, 1)) / tok_cnt
    # E[x^2] - E[x]^2 can dip just below zero from float cancellation.
    act_var = jnp.maximum(mean_sq - jnp.square(mean_mag), 0.0)

    dead_frac = jnp.mean(p_active == 0.0)

    # ---- Hoyer sparsity (global, scalar) ----------------------------- #
    x_masked = act * mask_lat
    l1 = jnp.sum(jnp.abs(x_masked))
    l2 = jnp.sqrt(jnp.sum(jnp.square(x_masked)) + eps)
    # Count only entries at valid tokens: masked positions are zeroed in
    # x_masked, so including them in n would make the batch look sparser.
    n = jnp.sum(jnp.broadcast_to(mask_lat, act.shape))
    hoyer = (jnp.sqrt(n) - l1 / (l2 + eps)) / (jnp.sqrt(n) - 1 + eps)

    return dict(
        mse=mse,
        nmse=nmse,
        explained=explained,
        hoyer=hoyer,
        dead_frac=dead_frac,
        p_active=p_active,
        mean_mag=mean_mag,
        act_var=act_var,
    )


def batch_stats(
    batch,
    sae_params_lst,
    hooks,
    run_fn,
    mask_fn,
    configs,
    lm_params,
    lm_config,
    accumulators,
    topk=20,
):
    """
    Run the LM once on *batch* and evaluate every SAE on its captured input.

    Args:
        batch: mapping with ``"inputs"`` and ``"positions"`` entries.
        sae_params_lst: one parameter object per SAE; when an object has a
            ``.params`` attribute, that attribute is passed to ``sae.run``.
        hooks, run_fn, lm_params, lm_config: forwarded to
            ``utils.run_and_capture``.
        mask_fn: maps ``batch["inputs"]`` to a token mask.
        configs: SAE configs; ``(cfg.layer_id, cfg.placement)`` selects the
            captured residual that feeds each SAE.
        accumulators: one object per SAE exposing ``update(stats)``
            (e.g. ``StatAccumulator``).
        topk: forwarded to ``compute_stats``.

    Returns:
        ``(all_stats, new_accs)`` -- per-SAE stats dicts and updated
        accumulators, in input order. Iteration stops at the shortest of
        ``sae_params_lst``, ``configs`` and ``accumulators``.
    """
    tokens = batch["inputs"]
    pos = batch["positions"]
    token_mask = mask_fn(tokens)  # (B,L) 1 → real tokens

    captured = utils.run_and_capture(
        run_fn, tokens, pos, lm_params, lm_config, hooks
    )

    per_sae_stats = []
    updated_accs = []
    for sae_params, cfg, accumulator in zip(sae_params_lst, configs, accumulators):
        sae_input = captured[(cfg.layer_id, cfg.placement)]
        raw_params = getattr(sae_params, "params", sae_params)
        recon, latents, activations = sae.run(
            sae_input,
            raw_params,
            cfg,
            return_latents=True,
            return_act=True,
        )
        stats = compute_stats(
            sae_input, recon, latents, activations, token_mask, topk=topk
        )
        per_sae_stats.append(stats)
        updated_accs.append(accumulator.update(stats))

    return per_sae_stats, updated_accs


def prettyprint(
    stats_list,
    sae_configs,
    extra_scalar_keys=None,
    float_fmt="{:>10.3e}",
    title="SAE Validation Summary",
):
    """
    Print a table of scalar metrics with one row per SAE.

    Args:
        stats_list: per-SAE stats dicts, aligned with *sae_configs*. The
            columns shown are those default/extra keys present in the first
            dict. An empty list prints an empty table instead of raising.
        sae_configs: configs providing ``layer_id``, ``placement.name``,
            ``latent_multiplier`` and ``latent_size`` for the row label.
        extra_scalar_keys: additional keys to show after the defaults.
        float_fmt: ``str.format`` pattern applied to each metric value.
        title: heading printed above the table.

    Uses ``tabulate`` when it is installed; otherwise falls back to a
    plain fixed-width layout.
    """
    default_keys = ["mse", "nmse", "explained", "hoyer", "dead_frac"]
    if extra_scalar_keys:
        default_keys.extend(extra_scalar_keys)
    # Guard against an empty stats_list (previously an IndexError).
    first = stats_list[0] if stats_list else {}
    scalar_keys = [k for k in default_keys if k in first]

    rows = []
    for cfg, d in zip(sae_configs, stats_list):
        lbl = f"L{cfg.layer_id}-{cfg.placement.name:<5}  {cfg.latent_multiplier:>2}× → {cfg.latent_size}"
        rows.append([lbl] + [float_fmt.format(d[k]) for k in scalar_keys])

    header = ["SAE"] + [k.upper() for k in scalar_keys]

    print("\n" + title)
    print("-" * len(title))
    try:
        from tabulate import tabulate

        print(tabulate(rows, headers=header, tablefmt="github"))
    except ImportError:
        # Column width = longest cell (header included) in that column.
        col_w = [max(len(str(x)) for x in col) for col in zip(header, *rows)]
        fmt = "  ".join(f"{{:<{w}}}" for w in col_w)
        print(fmt.format(*header))
        print("-" * (sum(col_w) + 2 * (len(col_w) - 1)))
        for r in rows:
            print(fmt.format(*r))
    print()


def save_npz(stats_list, sae_configs, path):
    """
    Save every entry of every per-SAE stats dict into one ``.npz`` file.

    Keys are ``"{prefix}/{metric}"`` with ``prefix = L{layer_id}_{placement}``.
    When several SAEs share the same layer and placement, their prefixes are
    suffixed with ``_{latent_size}`` (and, if that still collides, with the
    SAE's index) so no SAE silently overwrites another's arrays.

    Args:
        stats_list: per-SAE stats dicts, aligned with *sae_configs*.
        sae_configs: configs providing ``layer_id``, ``placement.name`` and
            ``latent_size``.
        path: destination passed to ``jnp.savez``.
    """
    from collections import Counter

    base_prefixes = [f"L{cfg.layer_id}_{cfg.placement.name}" for cfg in sae_configs]
    base_counts = Counter(base_prefixes)
    used = set()

    to_save = {}
    for idx, (cfg, d) in enumerate(zip(sae_configs, stats_list)):
        prefix = base_prefixes[idx]
        if base_counts[prefix] > 1:
            # Same layer/placement appears more than once: disambiguate.
            prefix = f"{prefix}_{cfg.latent_size}"
        if prefix in used:
            prefix = f"{prefix}_{idx}"
        used.add(prefix)
        for k, v in d.items():
            to_save[f"{prefix}/{k}"] = jnp.asarray(v)
    jnp.savez(path, **to_save)
    print(f"[stats] saved → {path}")
