import jax
import jax.numpy as jnp
import numpy as np

import heapq
import joblib
from pathlib import Path
from dataclasses import dataclass, field
from functools import partial

import umap
import pandas as pd
from sklearn.cluster import HDBSCAN


if not hasattr(pd.DataFrame, "append"):
    # pandas >= 2.0 removed DataFrame.append; restore it on top of pd.concat so
    # legacy call sites keep working.

    def _append(self, other, ignore_index=False, **kw):
        """Backport of the removed ``DataFrame.append``.

        Accepts the input forms the old method accepted: a DataFrame, a Series
        or dict (each appended as a single row), or a list/tuple of these.
        ``verify_integrity`` and ``sort`` are forwarded to ``pd.concat``; any
        other keyword arguments are ignored, as in the previous shim.
        """
        items = list(other) if isinstance(other, (list, tuple)) else [other]
        frames = []
        for item in items:
            if isinstance(item, pd.Series):
                # A Series is one row: its index becomes the column labels.
                item = item.to_frame().T.infer_objects()
            elif isinstance(item, dict):
                item = pd.DataFrame([item])
            frames.append(item)
        return pd.concat(
            [self, *frames],
            ignore_index=ignore_index,
            verify_integrity=kw.get("verify_integrity", False),
            sort=kw.get("sort", False),
        )

    pd.DataFrame.append = _append


def process_batch(batch, sae_kit, layer_id):
    """Encode a batch with the SAE kit and zero out positions rejected by its mask.

    Reads ``batch["inputs"]`` and ``batch["positions"]``, and uses
    ``sae_kit.mask_fn``, ``sae_kit.tokenizer.id_to_token`` and
    ``sae_kit.get_encoded``.

    Returns:
        ``(masked_acts, raw_inputs)``. ``masked_acts`` is
        ``get_encoded(...) * mask``. ``raw_inputs`` is a list (one per row of
        ``inputs``) of lists of decoded token strings.
    """
    token_ids = batch["inputs"]
    pos = batch["positions"]

    # Append a trailing unit axis to the mask; presumably so it broadcasts over
    # the latent dimension of the encoded activations — TODO confirm shapes.
    keep = sae_kit.mask_fn(token_ids)[..., None]

    token_strings = []
    for row in token_ids:
        token_strings.append([sae_kit.tokenizer.id_to_token(t) for t in row])

    encoded = sae_kit.get_encoded(token_ids, pos, layer_id)
    return encoded * keep, token_strings


@partial(jax.jit, static_argnums=(2, 3))
def _batch_stats(
    masked_acts: jnp.ndarray,
    flat_token_ids: jnp.ndarray,
    top_sequences: int,
    top_tokens: int,
):
    sum_ = jnp.sum(masked_acts, axis=(0, 1))
    sum2 = jnp.sum(masked_acts**2, axis=(0, 1))
    max_ = jnp.max(masked_acts, axis=(0, 1))
    above0 = jnp.sum(masked_acts > 0, axis=(0, 1))

    seq_scores = jnp.max(masked_acts, axis=1)
    seq_pos = jnp.argmax(masked_acts, axis=1)
    seq_vals, seq_idx = jax.lax.top_k(seq_scores.T, top_sequences)
    seq_pos_top = jnp.take_along_axis(seq_pos.T, seq_idx, axis=1)

    flat_scores = masked_acts.reshape(-1, masked_acts.shape[-1]).T
    tok_vals, tok_pos = jax.lax.top_k(flat_scores, top_tokens)
    tok_ids = jnp.take(flat_token_ids, tok_pos)

    return (sum_, sum2, max_, above0, seq_vals, seq_idx, seq_pos_top, tok_vals, tok_ids)


@dataclass
class StatsCollector:
    """Streaming per-latent activation statistics plus bounded top-k stores.

    ``update`` consumes one batch at a time. For each latent it accumulates
    running sum, sum of squares, position count, max, and count of positive
    entries. It also maintains two bounded candidate stores per latent:

    * ``_seq_dicts[d]`` maps a token-prefix tuple (the sequence's tokens up to
      and including its arg-max position) to ``(best_score, hit_count, tokens)``.
    * ``_tok_dicts[d]`` maps a token id to ``(best_score, hit_count,
      score_sum)``. ``selectivity_scores`` and ``gini_selectivity`` unpack
      this 3-tuple layout.

    Each store is pruned back to ``top_sequences`` / ``top_tokens`` entries,
    ranked by ``best_score``, whenever it grows past that size.
    """

    latent_size: int
    top_sequences: int = 50
    top_tokens: int = 50

    # Running per-latent accumulators.
    _sum: np.ndarray = field(init=False)
    _sum2: np.ndarray = field(init=False)
    _count: np.ndarray = field(init=False)
    _max: np.ndarray = field(init=False)
    _above0: np.ndarray = field(init=False)
    # Per-latent candidate stores (value layouts in the class docstring).
    _seq_dicts: list = field(init=False)
    _tok_dicts: list = field(init=False)
    # Per-latent number of sequences whose max activation exceeded 1e-6.
    _num_active: np.ndarray = field(init=False)

    def __post_init__(self):
        D = self.latent_size
        self._sum = np.zeros(D, dtype=np.float32)
        self._sum2 = np.zeros(D, dtype=np.float32)
        self._count = np.zeros(D, dtype=np.int64)
        self._max = np.full(D, -np.inf, dtype=np.float32)
        self._above0 = np.zeros(D, dtype=np.int64)
        self._num_active = np.zeros(D, dtype=np.int64)
        self._seq_dicts = [dict() for _ in range(D)]
        self._tok_dicts = [dict() for _ in range(D)]

    def update(self, batch, sae_kit, layer_id):
        """Fold one batch into the running statistics and candidate stores.

        ``batch["inputs"]`` must support ``.reshape(-1)``. The batch is encoded
        and masked by ``process_batch``; per-batch reductions come from
        ``_batch_stats``.
        """
        masked_acts, raw_inputs = process_batch(batch, sae_kit, layer_id)
        flat_token_ids = batch["inputs"].reshape(-1)
        (
            b_sum,
            b_sum2,
            b_max,
            b_above0,
            seq_vals,
            seq_idx,
            seq_pos,
            tok_vals,
            tok_ids,
        ) = jax.device_get(
            _batch_stats(
                masked_acts, flat_token_ids, self.top_sequences, self.top_tokens
            )
        )

        B, T = masked_acts.shape[:2]
        # Every position counts toward the denominators, including masked ones
        # (which contribute zeros to the sums).
        self._count += B * T
        self._sum += b_sum
        self._sum2 += b_sum2
        self._max = np.maximum(self._max, b_max)
        self._above0 += b_above0

        # Per sequence: is the latent's max activation above a small epsilon?
        max_act_per_seq = np.max(masked_acts, axis=1)
        self._num_active += np.sum(max_act_per_seq > 1e-6, axis=0)

        for d in range(self.latent_size):
            self._merge_sequences(d, seq_vals[d], seq_idx[d], seq_pos[d], raw_inputs)
            self._merge_tokens(d, tok_vals[d], tok_ids[d])

    def _merge_sequences(self, d, vals, idxs, poss, raw_inputs):
        """Merge one batch's top sequences for latent ``d`` into its store, then prune."""
        store = self._seq_dicts[d]
        for score, b, p in zip(vals, idxs, poss):
            s = float(score)
            if s <= 0:
                # top_k output is sorted descending; nothing positive follows.
                break
            toks = raw_inputs[int(b)]
            key = tuple(toks[: int(p) + 1])
            if key in store:
                best, cnt, best_toks = store[key]
                if s > best:
                    store[key] = (s, cnt + 1, toks)
                else:
                    store[key] = (best, cnt + 1, best_toks)
            else:
                store[key] = (s, 1, toks)
        if len(store) > self.top_sequences:
            top = heapq.nlargest(
                self.top_sequences, store.items(), key=lambda kv: kv[1][0]
            )
            self._seq_dicts[d] = dict(top)

    def _merge_tokens(self, d, vals, ids):
        """Merge one batch's top tokens for latent ``d`` into its store, then prune."""
        store = self._tok_dicts[d]
        for score, tok_id in zip(vals, ids):
            s = float(score)
            if s <= 0:
                break
            tok = int(tok_id)
            if tok in store:
                best, cnt, total = store[tok]
                store[tok] = (max(best, s), cnt + 1, total + s)
            else:
                store[tok] = (s, 1, s)
        if len(store) > self.top_tokens:
            top = heapq.nlargest(
                self.top_tokens, store.items(), key=lambda kv: kv[1][0]
            )
            # Keep the full (best, count, sum) 3-tuple. Projecting to 2-tuples
            # here used to raise on unpacking, and would break every reader of
            # the store.
            self._tok_dicts[d] = dict(top)

    @property
    def mean(self):
        return self._sum / np.maximum(self._count, 1)

    @property
    def std(self):
        m = self.mean
        m2 = self._sum2 / np.maximum(self._count, 1)
        return np.sqrt(np.maximum(m2 - m**2, 0.0))

    @property
    def sparsity(self):
        """Fraction of counted positions with a strictly positive activation."""
        return self._above0 / np.maximum(self._count, 1)

    @property
    def max(self):
        return self._max

    @property
    def num_seq_active(self):
        return self._num_active

    @property
    def frac_active(self):
        # NOTE(review): the numerator counts *sequences* (see update), but
        # _count counts *positions* (B * T per batch). The result is therefore
        # not a fraction of sequences — confirm the intended denominator.
        return self._num_active / np.maximum(self._count, 1)

    def top_sequences_for(self, neuron: int):
        """Return ``(best_score, hit_count, tokens)`` tuples, highest score first."""
        return sorted(self._seq_dicts[neuron].values(), key=lambda x: -x[0])

    def top_tokens_for(self, neuron: int):
        """Return ``(best_score, hit_count, token_id)`` triples, highest score first."""
        return sorted(
            [(b, c, t) for t, (b, c, _) in self._tok_dicts[neuron].items()],
            key=lambda x: -x[0],
        )


def _collector_to_dict(col):
    return dict(
        meta=dict(
            latent_size=col.latent_size,
            top_sequences=col.top_sequences,
            top_tokens=col.top_tokens,
        ),
        sum=col._sum,
        sum2=col._sum2,
        count=col._count,
        max=col._max,
        above0=col._above0,
        seq_dicts=col._seq_dicts,
        tok_dicts=col._tok_dicts,
        num_active=col._num_active,
    )


def _dict_to_col(d):
    """Rebuild a ``StatsCollector`` from the dict written by ``_collector_to_dict``.

    Array state is copied into fresh, writable ndarrays (``np.array`` copies by
    default). ``load_collector`` reads with ``mmap_mode="r"``, which hands back
    read-only memory maps for the arrays of an uncompressed joblib dump.
    ``StatsCollector.update`` accumulates into these arrays in place (``+=``),
    so a restored collector must own writable copies to keep collecting. The
    copies also avoid holding the file open through a memmap.
    """
    meta = d["meta"]
    col = StatsCollector(
        latent_size=meta["latent_size"],
        top_sequences=meta["top_sequences"],
        top_tokens=meta["top_tokens"],
    )
    col._sum = np.array(d["sum"])
    col._sum2 = np.array(d["sum2"])
    col._count = np.array(d["count"])
    col._max = np.array(d["max"])
    col._above0 = np.array(d["above0"])
    col._seq_dicts = d["seq_dicts"]
    col._tok_dicts = d["tok_dicts"]
    col._num_active = np.array(d["num_active"])
    return col


def save_collector(col, path):
    """Write ``col`` to ``path`` as an uncompressed joblib dump.

    Missing parent directories are created first.
    """
    target = Path(path)
    target.parent.mkdir(exist_ok=True, parents=True)
    payload = _collector_to_dict(col)
    joblib.dump(payload, target, compress=False)


def load_collector(path):
    """Load a collector previously written by ``save_collector``.

    NOTE(review): with ``mmap_mode="r"`` joblib returns the numpy arrays of an
    uncompressed dump as read-only memory maps. In-place accumulation in
    ``StatsCollector.update`` fails on them unless they are copied.
    """
    return _dict_to_col(joblib.load(path, mmap_mode="r"))


def selectivity_scores(col: "StatsCollector"):
    """Best point-biserial correlation between each latent and any tracked token.

    Each entry ``(best, Nt, Sat)`` in ``col._tok_dicts[d]`` is treated as a
    binary group (positions where that token hit) against the remaining
    counted positions of latent ``d``:

        r = (mu1 - mu0) / sigma * sqrt(p * (1 - p)),   p = Nt / N

    Here ``mu1 = Sat / Nt`` and ``mu0 = (Sa - Sat) / (N - Nt)``. ``sigma`` is
    the population standard deviation from the running moments, floored at
    sqrt(1e-12). Scores start at 0, so only positive correlations are kept.

    NOTE(review): ``Nt``/``Sat`` only count hits that made each batch's
    per-latent top-k token list, not every occurrence. ``p``, ``mu1`` and
    ``mu0`` are therefore approximations.

    Returns:
        float32 array of length ``col.latent_size``.
    """
    N = col._count  # per-latent number of counted positions
    Sa = col._sum  # Σa
    S2 = col._sum2  # Σa²
    # Same zero-count guard as StatsCollector.mean/std. Without it, an
    # untouched latent produces 0/0 (NaN plus RuntimeWarnings).
    safe_N = np.maximum(N, 1)
    mean = Sa / safe_N
    var = S2 / safe_N - mean**2
    sigma = np.sqrt(np.maximum(var, 1e-12))

    best_r = np.zeros(col.latent_size, dtype=np.float32)

    for d in range(col.latent_size):
        n_d = N[d]
        if n_d <= 0:
            continue  # nothing counted for this latent yet
        for _best, Nt, Sat in col._tok_dicts[d].values():
            p = Nt / n_d
            # Both groups must be non-empty for mu1 and mu0 to be defined.
            if not 0 < p < 1:
                continue
            mu1 = Sat / Nt
            mu0 = (Sa[d] - Sat) / (n_d - Nt)
            r = (mu1 - mu0) / sigma[d] * np.sqrt(p * (1 - p))
            best_r[d] = max(best_r[d], r)

    return best_r

def gini_selectivity(col: StatsCollector):
    """Gini coefficient of per-token activation mass, one value per latent.

    For latent ``d``, the third field (summed score) of every entry in
    ``col._tok_dicts[d]`` is collected and sorted ascending. The Gini index is

        G = sum_i (2 i - n - 1) x_i / (n * sum_i x_i),   i = 1..n

    Latents with no tracked tokens keep 0. Higher values mean the mass is
    concentrated on fewer tokens.

    NOTE(review): only tokens present in the bounded store take part, so a
    latent with exactly one tracked token scores 0 (perfect equality). Confirm
    this is the intended reading of "selective".
    """
    scores = np.zeros(col.latent_size, dtype=np.float32)

    for d in range(col.latent_size):
        totals = [total for _, _, total in col._tok_dicts[d].values()]
        if not totals:
            continue

        ordered = np.sort(np.array(totals))
        n = len(ordered)
        weights = 2 * np.arange(1, n + 1) - n - 1
        scores[d] = np.sum(weights * ordered) / (n * np.sum(ordered))

    return scores

class NeuronSelector:
    """Rank the latents of a ``StatsCollector`` by a named score."""

    def __init__(self, stats: StatsCollector):
        self.stats = stats

    def _metric_scores(self, metric: str):
        """Return the per-latent score vector for ``metric``.

        Raises:
            ValueError: if ``metric`` is not one of the supported names.
        """
        st = self.stats
        if metric == "mean":
            return st.mean
        if metric == "max":
            return st.max
        if metric == "sparsity":
            return st.sparsity
        if metric == "mean*max":
            return st.mean * st.max
        if metric == "selectivity":
            return selectivity_scores(st)
        if metric == "gini":
            return gini_selectivity(st)
        raise ValueError(f"Unknown metric {metric}")

    def pick(self, metric: str = "mean", topk: int = 1, sort="desc"):
        """Return the indices of the ``topk`` best latents under ``metric``.

        ``sort="asc"`` picks the lowest scores; any other value picks the
        highest. Returns an index array when ``topk > 1``, otherwise a single
        ``int``.
        """
        order = np.argsort(self._metric_scores(metric))
        if sort != "asc":
            order = order[::-1]
        chosen = order[:topk]
        if topk > 1:
            return chosen
        return int(chosen[0])



def cluster2d(
    matrix,
    random_state=2025,
    *,
    n_neighbors=15,
    min_dist=0.1,
    min_cluster_size=10,
):
    """Embed the rows of ``matrix`` in 2-D with UMAP and cluster them with HDBSCAN.

    UMAP uses the cosine metric. HDBSCAN runs with the Euclidean metric on the
    2-D embedding. The keyword-only hyper-parameters default to the values
    previously hard-coded here.

    Returns:
        ``(coords_umap, labels, idx_by_cluster)``:

        - ``coords_umap``: the 2-D embedding from ``fit_transform``.
        - ``labels``: HDBSCAN labels, where ``-1`` marks noise.
        - ``idx_by_cluster``: maps each non-noise label (as ``int``) to the
          row indices carrying it.
    """
    print(f"Running UMAP on W_dec with shape: {matrix.shape}")

    umap_model = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=2,
        metric="cosine",
        random_state=random_state,
        verbose=False,
    )
    coords_umap = umap_model.fit_transform(matrix)
    print("UMAP embedding computed:", coords_umap.shape)

    clusterer = HDBSCAN(min_cluster_size=min_cluster_size, metric="euclidean")
    labels = clusterer.fit_predict(coords_umap)  # -1 = noise
    cluster_ids = np.unique(labels[labels >= 0])
    print("Found", cluster_ids.size, "HDBSCAN clusters")

    # Key by the labels actually produced instead of assuming they are 0..k-1.
    idx_by_cluster = {int(c): np.where(labels == c)[0] for c in cluster_ids}

    return coords_umap, labels, idx_by_cluster