import numpy as np


# -----------------------------
# 1) TR (Transformation–Retransformation)
# -----------------------------

def _safe_eigh_psd(S: np.ndarray, ridge: float = 1e-8) -> tuple:
    """
    Eigen-decompose a symmetric PSD/SPD matrix with a floor on the eigenvalues.

    The input is symmetrized first, so small round-off asymmetries are
    removed before the decomposition. Eigenvalues smaller than
    ``eps = ridge * max(1, largest eigenvalue)`` are raised to ``eps``.

    Note that the floor is *absolute* (equal to ``ridge``) whenever the
    largest eigenvalue is below 1; callers that need a purely relative floor
    should normalize the matrix before calling this helper.

    Parameters
    ----------
    S : (d, d) array, expected to be (numerically) symmetric PSD.
    ridge : factor used to build the eigenvalue floor ``eps``.

    Returns
    -------
    evals_clipped : (d,) eigenvalues as returned by ``np.linalg.eigh``
        (ascending), each clipped to be >= eps.
    evecs : (d, d) array whose columns are the corresponding eigenvectors.
    """
    S = 0.5 * (S + S.T)  # symmetrize
    evals, evecs = np.linalg.eigh(S)
    eps = ridge * max(1.0, float(np.max(evals)))
    evals_clipped = np.clip(evals, eps, None)
    return evals_clipped, evecs


def fit_tr_simple(Y: np.ndarray, ridge: float = 1e-6, ddof: int = 1) -> dict:
    """
    Fit TR standardization using mean + covariance (simple start).

    Parameters
    ----------
    Y : (n, d) array-like, one sample per row.
    ridge : relative regularization strength. It is multiplied by the average
        variance ``trace(S)/d`` before being added to the diagonal, and it
        also sets the relative eigenvalue floor used for the square roots.
    ddof : delta degrees of freedom; the covariance is divided by
        ``max(1, n - ddof)`` so a single sample does not divide by zero.

    Returns
    -------
    dict with:
      - m: mean (d,)
      - S: ridge-regularized covariance (d,d), i.e. cov + ridge*scale*I
      - sqrtS, invsqrtS: symmetric matrix square-root and inverse square-root
        built from the (floored) eigenvalues of S. They are exact inverses
        of each other; sqrtS @ sqrtS equals S unless an eigenvalue was
        raised to the floor.

    Raises
    ------
    ValueError
        If Y is not 2-D, or has no samples or no features.
    """
    Y = np.asarray(Y, dtype=float)
    if Y.ndim != 2:
        raise ValueError(
            f"Y must be a 2-D array of shape (n, d); got ndim={Y.ndim}"
        )
    n, d = Y.shape
    if n == 0 or d == 0:
        raise ValueError(
            f"Y must have at least one sample and one feature; got shape {Y.shape}"
        )
    m = Y.mean(axis=0)

    Yc = Y - m
    S = (Yc.T @ Yc) / max(1, (n - ddof))

    # Ridge scaled to average variance (tiny absolute floor for constant data)
    scale = max(1e-12, float(np.trace(S)) / d)
    S_r = S + (ridge * scale) * np.eye(d)

    # Decompose the scale-normalized matrix and rescale the eigenvalues.
    # The helper's floor is ridge * max(1, lambda_max), which is absolute for
    # lambda_max < 1; without the normalization, data with small variances
    # would have every eigenvalue clipped to `ridge` and would not be
    # whitened. After normalization lambda_max >= 1, so the floor is purely
    # relative and the fit is equivariant to rescaling Y.
    evals, evecs = _safe_eigh_psd(S_r / scale, ridge=ridge)
    evals = evals * scale

    root = np.sqrt(evals)
    sqrtS = (evecs * root) @ evecs.T
    invsqrtS = (evecs * (1.0 / root)) @ evecs.T

    return {"m": m, "S": S_r, "sqrtS": sqrtS, "invsqrtS": invsqrtS}


def tr_transform(Y: np.ndarray, tr: dict):
    """
    Standardize samples with fitted TR parameters.

    Subtracts ``tr["m"]`` from Y (cast to float) and right-multiplies the
    result by ``tr["invsqrtS"]``. With samples stored as rows and a symmetric
    ``invsqrtS`` (as produced by ``fit_tr_simple``) this is the row-vector
    form of Z = S^{-1/2} (Y - m).
    """
    samples = np.asarray(Y, dtype=float)
    centered = samples - tr["m"]
    whitener = tr["invsqrtS"]
    return centered @ whitener


def tr_retransform(Z: np.ndarray, tr: dict):
    """
    Map standardized samples back to the original coordinates.

    Right-multiplies Z (cast to float) by ``tr["sqrtS"]`` and adds
    ``tr["m"]``, i.e. the row-vector form of Y = m + S^{1/2} Z.
    """
    standardized = np.asarray(Z, dtype=float)
    coloring = tr["sqrtS"]
    offset = tr["m"]
    return offset + standardized @ coloring


# -----------------------------
# 4) Typical pipeline: TR standardization first, then ranks/quantiles
# -----------------------------

def fit_tr_and_standardize(Y_rank: np.ndarray, ridge: float = 1e-6):
    """
    Fit TR parameters on the Y_rank split and standardize that same split.

    Calls ``fit_tr_simple(Y_rank, ridge=ridge)`` and applies the resulting
    parameters to Y_rank with ``tr_transform``.

    Returns a tuple ``(tr, Z_rank)``: the fitted parameter dict and the
    standardized version of Y_rank.
    """
    params = fit_tr_simple(Y_rank, ridge=ridge)
    return params, tr_transform(Y_rank, params)
