import numpy as np
from sklearn.mixture import GaussianMixture
#from skimage.filters import threshold_otsu       # handy fallback

def _otsu_cut(values: np.ndarray, nbins: int = 256) -> float:
    """Return Otsu's cut for 1-D *values*; entries >= cut form the upper class."""
    lo, hi = values.min(), values.max()
    if lo == hi:
        # No spread to split on: put every value in the upper class.
        return float(lo)

    counts, edges = np.histogram(values, bins=nbins)
    counts = counts.astype(float)
    centers = 0.5 * (edges[:-1] + edges[1:])
    mass = counts * centers

    # Candidate split i puts bins 0..i in the lower class and i+1.. in the
    # upper class.  Bin 0 holds the minimum and the last bin holds the
    # maximum, so both class weights are always >= 1 (no division by zero).
    w_low = np.cumsum(counts)[:-1]
    w_high = counts.sum() - w_low
    mean_low = np.cumsum(mass)[:-1] / w_low
    mean_high = (mass.sum() - np.cumsum(mass)[:-1]) / w_high
    between = w_low * w_high * (mean_low - mean_high) ** 2

    # The edge separating the two classes of the best split is the cut.
    return float(edges[np.argmax(between) + 1])


def select_genes_by_l2_norm(
    X_raw: np.ndarray,
    *,
    log_eps: float = 1e-12,
    method: str = "gmm",         # {"gmm", "otsu", "percentile"}
    perc: float = 5,             # used if method == "percentile"
    return_mask: bool = False,
    plot: bool = False,
):
    """
    Decide a threshold on column ℓ₂ norms and return the filtered matrix.

    Parameters
    ----------
    X_raw       n×d count / log-count matrix  (cells × genes)
    log_eps     small constant added to the norms before taking log10, so
                that zero-norm columns do not produce -inf
    method      'gmm'      – 2-component Gaussian mixture on log-norms
                'otsu'     – Otsu’s 1-D threshold on log-norms
                'percentile' – keep top (100-perc)% norms
    perc        percentile used if method == 'percentile'
    return_mask whether to return boolean mask along with X_filt
    plot        quick diagnostic plot (needs matplotlib)

    Returns
    -------
    X_filt      n × d' matrix with small-norm genes removed
    (optionally) keep_mask, threshold
                keep_mask is a length-d boolean array; threshold is expressed
                in norm units (not log units).  For 'gmm' and 'otsu' it is
                the smallest norm that is kept; for 'percentile' it is the
                requested percentile of the norms.

    Raises
    ------
    ValueError  if X_raw is not 2-D, if method is unknown, or if the GMM
                assigns no gene to the larger-norm component.
    """
    if X_raw.ndim != 2:
        raise ValueError("X_raw must be a 2-D (cells × genes) matrix")
    n_genes = X_raw.shape[1]

    norms = np.linalg.norm(X_raw, axis=0)        # shape (d,)
    log_norms = np.log10(norms + log_eps)        # shape (d,)

    # --- pick a threshold ---------------------------------------------------
    # Thresholds are always read back from `norms` itself rather than computed
    # as 10 ** log-threshold: the round trip through log10(norm + log_eps)
    # yields norm + log_eps, which would exclude the boundary gene from
    # `norms >= threshold`.
    if method == "gmm":
        features = log_norms.reshape(-1, 1)      # sklearn expects 2-D input
        gmm = GaussianMixture(n_components=2, random_state=0).fit(features)
        means = gmm.means_.ravel()
        tiny = int(np.argmin(means))             # index of the “tiny” cluster
        p_tiny = gmm.predict_proba(features)[:, tiny]
        # With unequal variances the wider component also wins far in the
        # *left* tail, so only look for the boundary at or above the tiny
        # cluster's mean.
        retained = (p_tiny < 0.5) & (log_norms >= means[tiny])
        if not retained.any():
            raise ValueError(
                "GMM found no gene belonging to the larger-norm component; "
                "the norms may not be bimodal (try method='percentile')"
            )
        threshold = norms[retained].min()
    elif method == "otsu":
        cut = _otsu_cut(log_norms)
        # Non-empty by construction: the maximum always lies above the cut.
        threshold = norms[log_norms >= cut].min()
    elif method == "percentile":
        threshold = np.percentile(norms, perc)
    else:
        raise ValueError("method must be 'gmm', 'otsu', or 'percentile'")

    keep = norms >= threshold
    X_filt = X_raw[:, keep]

    # --- optional sanity check ---------------------------------------------
    if plot:
        import matplotlib.pyplot as plt
        plt.hist(norms, bins=60, color="salmon", edgecolor="k", alpha=0.8)
        plt.axvline(threshold, ls="--", c="k", lw=2,
                    label=f"cut-off = {threshold:.2g}")
        plt.yscale("log")
        plt.xlabel("gene ℓ₂ norm")
        plt.legend()
        plt.title(f"Filter kept {keep.sum():,} / {n_genes:,} genes")
        plt.tight_layout()
        plt.show()

    return (X_filt, keep, threshold) if return_mask else X_filt