import numpy as np
import pandas as pd
from scipy.stats import kurtosis, entropy
from itertools import combinations
from sklearn.metrics.pairwise import cosine_similarity   # pip install scikit-learn if needed

# ---------- helper functions ---------- #
def _l0_density(X):                               # % of non-zero entries (mean over samples)
    return np.mean((X != 0).sum(axis=1) / X.shape[1])

def _hoyer(x, eps=1e-12):
    l1 = np.linalg.norm(x, 1)
    l2 = np.linalg.norm(x, 2) + eps
    d  = x.size
    return (np.sqrt(d) - l1 / l2) / (np.sqrt(d) - 1)

def _gini(x):
    """Gini on |x| (vector version)."""
    x = np.abs(x.flatten())
    if np.allclose(x, 0):     # avoid divide-by-zero
        return 0.0
    sorted_x = np.sort(x)
    n = len(x)
    cum = np.cumsum(sorted_x)
    gini = 1 - 2 * np.sum(cum) / (n * sorted_x.sum()) + (1 / n)
    return gini

def _ipr(x, eps=1e-12):
    p = x**2
    p = p / (p.sum() + eps)
    return (p**2).sum()

def _eff_dim(x, eps=1e-12):
    """Return the reciprocal of ``_ipr(x, eps)``."""
    participation = _ipr(x, eps)
    return 1.0 / participation

def _entropy_sparsity(x, eps=1e-12):
    p = x**2
    p = p / (p.sum() + eps)
    h = entropy(p)           # base e
    return 1 - h / np.log(len(x))   # 0 (dense) … 1 (sparse)

def _pairwise_cosine_overlap(X):
    if len(X) < 2:
        return np.nan
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True).clip(min=1e-12)
    sims = cosine_similarity(Xn)
    iu = np.triu_indices_from(sims, 1)
    return np.mean(np.abs(sims[iu]))

def _lifetime_population_sparsity(X):
    """Returns (lifetime, population) sparsity as averages in [0,1]."""
    active = (X != 0)
    lifetime  = 1 - active.mean(axis=0)       # per feature → avg
    population = 1 - active.mean(axis=1)      # per sample  → avg
    return lifetime.mean(), population.mean()

# ---------- master routine ---------- #
def sparsity_metrics_table(X1, X2, names=('Set-1', 'Set-2')):
    """
    Build a table comparing sparsity metrics of two sets of vectors.

    X1, X2 : np.ndarray of shape (N, D) or (D,)
             Each *row* is a feature vector / embedding; 1-D input is
             promoted to a single row.
    names  : row labels of the returned table, one per set.

    Returns a pandas DataFrame with one row per set and one column per
    metric, rounded to 4 decimals. Per-vector metrics are averaged over
    the rows of each set.
    """
    # Metrics evaluated on each row and then averaged, in column order.
    per_sample_metrics = (
        _hoyer,
        _gini,
        lambda v: kurtosis(v, fisher=False),
        _entropy_sparsity,
        _ipr,
        _eff_dim,
    )

    datasets = [np.atleast_2d(X1), np.atleast_2d(X2)]

    table = []
    for data in datasets:
        record = [_l0_density(data)]
        for metric in per_sample_metrics:
            record.append(np.mean([metric(v) for v in data]))
        record.append(_pairwise_cosine_overlap(data))
        record.extend(_lifetime_population_sparsity(data))
        table.append(record)

    header = [
        '% non-zero', 'Hoyer', 'Gini', 'Kurtosis',
        'Entropy-spar.', 'IPR', 'Eff.dim',
        'Mean |cos|', 'Lifetime', 'Population',
    ]
    return pd.DataFrame(table, index=names, columns=header).round(4)