"""
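Nonlinear (kernelized) Gram-Schmidt walk (GSW) designs.

Units are balanced through a kernel Gram matrix instead of the raw features.
The kernel is either a weighted sum of elementwise powers of the linear Gram
matrix V.T @ V or its elementwise exponential.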
"""


import numpy as np


def _inverse_downdate(K: np.ndarray, idx: int):
    """
    Remove the row/column at position idx from a symmetric inverse matrix K.

    The update uses a Schur-complement formula so that, if K = M^{-1},
    the returned array equals (M without row/col idx)^{-1}.
    """
    if K.ndim != 2 or K.shape[0] != K.shape[1]:
        raise ValueError("Inverse matrix must be square.")

    m = K.shape[0]
    if m == 0:
        return None
    if idx < 0 or idx >= m:
        raise IndexError("Index out of bounds for inverse downdate.")
    if m == 1:
        return None

    mask = np.ones(m, dtype=bool)
    mask[idx] = False

    alpha = K[idx, idx]
    if np.abs(alpha) < 1e-18:
        raise np.linalg.LinAlgError("Pivot too small during inverse downdate.")

    beta = K[mask, idx]
    K_sub = K[np.ix_(mask, mask)]
    K_sub -= np.outer(beta, beta) / alpha
    return K_sub


def GSwalk_poly_aug_fast(
    V: np.ndarray,
    k: int,
    ϕ: float,
    rng=None,
    tol: float = 1e-12,
    cache=None,
    weights=None,
    balance=False,
    exp=False,
) -> np.ndarray:
    """
    Draw one ±1 assignment with the kernelized Gram-Schmidt walk.

    Unit i is represented by bi = [sqrt(ϕ) ei, sqrt(1-ϕ) psi(vi)], where psi
    is the feature map whose inner products are ``cache["gram_pow"]``.  The
    fractional assignment x starts at 0 and moves until every coordinate
    reaches ±1 (within ``tol``).

    At each step the pivot is the live unit with the largest index.  The
    direction u has u[pivot] = 1 and, on the remaining live ("active")
    units A,

        u_A = -(G_AA + lam*I)^{-1} G_{A,pivot},   lam = ϕ / (1 - ϕ),

    which is the stationary point of ϕ*||u||^2 + (1-ϕ)*u^T G u.  The inverse
    K is carried across steps.  Units that leave A are removed with
    ``_inverse_downdate``, and K is rebuilt with ``np.linalg.inv`` whenever
    it is missing, stale, or a downdate is numerically unsafe.

    Args:
        V: array of shape (d, n); only its shape is used when ``cache`` is
            given.
        k: balance degree-k polynomial terms of features (cache construction).
        ϕ: balance-robust trade-off, expected in [0, 1).
        rng: None (fixed seed 12345), an integer seed, or a numpy Generator.
        tol: coordinates with |x_i| >= 1 - tol are treated as frozen.
        cache: output of ``_prepare_fast_cache`` (keys "gram_pow",
            "initial_inverse"); built from V when None.
        weights, balance, exp: forwarded to ``_prepare_fast_cache`` when
            ``cache`` is None.

    Returns:
        x: array of shape (n,) whose entries are ±1 up to ``tol``.

    Raises:
        RuntimeError: no inverse is available although the active set is
            non-empty (e.g. an inconsistent ``cache``).
    """
    if rng is None:
        rng = np.random.default_rng(12345)  # fixed seed for reproducibility
    elif isinstance(rng, (int, np.integer)):
        rng = np.random.default_rng(int(rng))  # seed -> Generator

    n = V.shape[1]
    x = np.zeros(n, dtype=float)
    live = np.arange(n)
    active = live[:-1].tolist()  # every live unit except the pivot

    if cache is None:
        cache = _prepare_fast_cache(V, k, ϕ, weights=weights, balance=balance, exp=exp)

    gram_pow = cache["gram_pow"]
    K = cache["initial_inverse"]  # inv(G_AA + lam*I) for the current active set

    while live.size > 0:
        # Direction: +1 on the pivot, kernel-regularized correction elsewhere.
        u = np.zeros(live.size, dtype=float)
        u[-1] = 1.0
        if active:
            if K is None:
                raise RuntimeError("Missing inverse for active set.")
            pivot = live[-1]
            phi_vt = gram_pow[active, pivot]
            u[:-1] = -K @ phi_vt

        # Largest steps in each direction that keep live coordinates in
        # [-1, 1].  Zero entries of u give ±inf ratios that never bind.
        x_live = x[live]
        with np.errstate(divide="ignore"):
            ap = (1.0 - x_live) / u
            an = -(1.0 + x_live) / u
        del_plus = np.min(np.maximum(ap, an))
        del_minus = np.min(-np.minimum(ap, an))

        # Martingale step: the +del_plus move is taken with probability
        # del_minus / (del_plus + del_minus), so E[x_new - x] = 0.
        if rng.random() < (del_minus / (del_plus + del_minus)):
            x[live] = x_live + del_plus * u
        else:
            x[live] = x_live - del_minus * u

        assert np.max(np.abs(x)) <= 1.0 + tol

        live_new = np.where(np.abs(x) < 1.0 - tol)[0]
        if live_new.size == 0:
            break

        # The new active set (live units minus the possibly new pivot) is
        # always a subset of the old one, so K can be downdated unit by unit.
        active_new = live_new[:-1].tolist()
        if len(active_new) < len(active):
            keep = set(active_new)
            removed_pos = [pos for pos, val in enumerate(active) if val not in keep]
            # Remove in ascending order; each earlier removal shifts the
            # later positions left by one.
            for shift, pos in enumerate(removed_pos):
                if K is None:
                    break
                try:
                    K = _inverse_downdate(K, pos - shift)
                except np.linalg.LinAlgError:
                    K = None  # unsafe pivot: rebuild from scratch below

        active = active_new
        if active:
            if K is None or K.shape[0] != len(active):
                M = gram_pow[np.ix_(active, active)].copy()
                np.fill_diagonal(M, M.diagonal() + (ϕ / (1 - ϕ)))
                K = np.linalg.inv(M)
        else:
            K = None

        live = live_new

    return x


def _prepare_fast_cache(V: np.ndarray, k: int, phi: float, weights=None, balance=False, exp=False) -> dict:

    # weights for feature vectors in V

    if not (0.0 <= phi < 1.0):
        raise ValueError("phi must satisfy 0 <= phi < 1")

    gram = V.T @ V  # n-by-n matrix 

    if exp:
        gram_pow = np.exp(gram)
    else:
        powers = np.arange(1, k + 1, dtype=float)
        if weights is None:
            weights = np.ones_like(powers)
        else:
            weights = np.asarray(weights, dtype=float)
            if weights.shape[0] != powers.shape[0]:
                raise ValueError(f"weights must have length {k}")
        
        # Enforce nonnegative weights and normalize to sum to 1.
        weights = np.abs(weights)
        total = weights.sum()
        if total <= 0.0:
            raise ValueError("weights must have at least one positive entry")
        weights = weights / total

        # The Gram matrix: weighted sum of gram powers
        gram_pow = np.sum((gram[..., None] ** powers) * weights, axis=2)
        if balance:
            gram_pow = gram_pow + (1/(np.exp(1)-1))

    max_diag = np.max(np.diag(gram_pow))
    if max_diag > 0:
        gram_pow = gram_pow / max_diag
            
    cache = {"gram_pow": gram_pow}

    # compute the initial matrix inverse: inv(lam*I + gram_pow) on the first n-1 entries
    n = V.shape[1]
    if n > 1:
        active = np.arange(n - 1)
        lam = phi / (1 - phi)
        M = gram_pow[np.ix_(active, active)].copy()
        np.fill_diagonal(M, M.diagonal() + lam)
        cache["initial_inverse"] = np.linalg.inv(M)
    else:
        cache["initial_inverse"] = None

    return cache

def GSwalk_poly_aug_many(V, k, phi, its, weights=None, balance=False, exp=False):
    """
    Draw ``its`` ±1 assignments with the kernelized Gram-Schmidt walk.

    All columns of V are divided by a common factor so that the largest
    column norm is 1 (skipped when V is all zeros).  The kernel cache is
    built once, and ``GSwalk_poly_aug_fast`` is then run ``its`` times.  All
    draws share one generator seeded with 12345, so the sequence of draws is
    reproducible.

    Args:
        V: array-like of shape (d, n), original feature vectors (columns).
        k: balance degree-1..k polynomial terms of features.
        phi: balance-robust trade-off in [0, 1);
            bi = [sqrt(phi) ei, sqrt(1-phi) psi(vi)].
        its: number of assignments to draw.
        weights: optional length-k weights for the degree-1..k terms.
        balance: if True, add the constant 1/(e-1) to the polynomial kernel,
            intended to match the exponential kernel.
        exp: if True, use the exponential kernel instead.

    Returns:
        List of ``its`` lists, each of length n with entries in {1, -1}.
        With n == 0 every assignment is the empty list.
    """
    V = np.asarray(V)
    n = V.shape[1]
    if n == 0:
        # No units to assign: every draw is the empty assignment.
        return [[] for _ in range(its)]

    rng = np.random.default_rng(12345)  # single seed for reproducibility

    # Normalize so the maximum column norm is at most 1 (if nonzero).
    max_norm = np.linalg.norm(V, axis=0).max()
    if max_norm > 0:
        V = V / max_norm

    cache = _prepare_fast_cache(V, k, phi, weights=weights, balance=balance, exp=exp)  # degree-1 up to k terms

    assignments: list[list[int]] = []
    for _ in range(its):
        x = GSwalk_poly_aug_fast(V, k, phi, rng, cache=cache, balance=balance, weights=weights, exp=exp)
        assignments.append(np.where(x >= 0, 1, -1).astype(int).tolist())

    return assignments


def GSwalk_kernel_many(alg_func, V, phi, its, weights=None):
    """
    Draw ``its`` assignments with the degree-1 (linear) Gram-Schmidt walk.

    Convenience wrapper for ``GSwalk_poly_aug_many`` with k fixed at 1.

    Args:
        alg_func: not used by this function.
            NOTE(review): presumably kept for call-site compatibility;
            confirm callers do not expect it to be invoked.
        V: array of shape (d, n), feature vectors as columns.
        phi: balance-robust trade-off in [0, 1).
        its: number of assignments to draw.
        weights: forwarded unchanged to ``GSwalk_poly_aug_many``.

    Returns:
        The list of ±1 assignments produced by ``GSwalk_poly_aug_many``.
    """
    return GSwalk_poly_aug_many(V, 1, phi, its, weights=weights)
