import torch

# ---------- utilities: implicit projector applied to features (explicit phi-tilde) ----------

def _project_features(Phi: torch.Tensor, Y: torch.Tensor, ridge: float):
    """
    Remove from every column of Phi its ridge-regularized component in span(Y).

    Computes Phi - Y (Y^T Y + ridge I)^{-1} Y^T Phi through a k x k
    Cholesky solve, so the n x n projector is never materialized.

    Phi:   (n, d) feature matrix
    Y:     (n, k) directions to project out
    ridge: scalar added to the diagonal of Y^T Y before factorization

    Returns the projected (n, d) matrix. When Y holds no elements / no
    columns, Phi itself is returned (same object, no copy).
    """
    nothing_to_remove = Y.numel() == 0 or Y.shape[1] == 0
    if nothing_to_remove:
        return Phi

    Yt = Y.T
    num_dirs = Y.shape[1]

    # Regularized Gram matrix of the directions: Y^T Y + ridge * I  (k x k)
    gram = Yt @ Y
    gram = gram + ridge * torch.eye(num_dirs, device=Phi.device, dtype=Phi.dtype)

    # Coefficients solve  gram @ coeffs = Y^T Phi  via the Cholesky factor
    rhs = Yt @ Phi
    chol_factor = torch.linalg.cholesky(gram)
    coeffs = torch.cholesky_solve(rhs, chol_factor)

    return Phi - Y @ coeffs


# ---------- core: given lambda, compute top eigenvector x in sample space + phi(x) ----------

def _top_for_lambda_explicit_tilde(
    phi_x: torch.Tensor,
    phi_y: torch.Tensor,
    Y_prev: torch.Tensor,       # (n, k) previous sample-space directions
    lam: float,
    eps_chol: float = 1e-7,
    ridge_proj: float = 1e-6,
):
    """
    Top direction of K1_tilde - lam * K2_tilde for a fixed lam, computed in
    feature space (a (dx+dy) x (dx+dy) eigenproblem instead of n x n).

    Both feature matrices are first passed through _project_features with
    Y_prev, giving Phi_x_tilde and Phi_y_tilde; then
        K1_tilde = Phi_x_tilde Phi_x_tilde^T,  K2_tilde = Phi_y_tilde Phi_y_tilde^T.

    phi_x:      (n, dx) features
    phi_y:      (n, dy) features
    Y_prev:     (n, k) directions projected out of both feature sets
    lam:        non-negative trade-off weight
    eps_chol:   diagonal jitter added to the covariance before Cholesky
    ridge_proj: ridge passed to _project_features

    Returns:
        x    : (n,) unit-norm sample-space vector
        phi  : (1/n) * ||Phi_y_tilde^T x||^2  (scalar tensor; note the 1/n)
        disc : largest eigenvalue of Theta = Z J Z^T, where Z is the upper
               Cholesky factor of the jittered (1/n)-scaled covariance

    Raises:
        ValueError: if lam is negative.
    """
    n_samples, dim_x = phi_x.shape
    _, dim_y = phi_y.shape
    opts = {"device": phi_x.device, "dtype": phi_x.dtype}

    # Deflated ("tilde") features
    feats_x = _project_features(phi_x, Y_prev, ridge_proj)  # (n, dx)
    feats_y = _project_features(phi_y, Y_prev, ridge_proj)  # (n, dy)

    lam_value = float(lam)
    if lam_value < 0:
        raise ValueError("lambda must be non-negative for sqrt.")
    # Stacked features are F = [feats_x, sqrt(lam) * feats_y]
    scaled_y = torch.tensor(lam_value, **opts).sqrt() * feats_y  # (n, dy)

    # Covariance (1/n) F^T F, assembled block by block
    cov_xx = (feats_x.T @ feats_x) / n_samples
    cov_xy = (feats_x.T @ scaled_y) / n_samples
    cov_yy = (scaled_y.T @ scaled_y) / n_samples
    upper_rows = torch.cat([cov_xx, cov_xy], dim=1)
    lower_rows = torch.cat([cov_xy.T, cov_yy], dim=1)
    cov = torch.cat([upper_rows, lower_rows], dim=0)

    # Signature matrix J = diag(+1 (dx times), -1 (dy times))
    signature = torch.diag(torch.cat([
        torch.ones(dim_x, **opts),
        -torch.ones(dim_y, **opts),
    ]))

    # Upper factor Z with Z^T Z = cov + eps*I; Theta = Z J Z^T is symmetric
    jitter = eps_chol * torch.eye(dim_x + dim_y, **opts)
    chol_upper = torch.linalg.cholesky(cov + jitter, upper=True)
    theta = chol_upper @ signature @ chol_upper.T

    eigvals, eigvecs = torch.linalg.eigh(theta)
    top = torch.argmax(eigvals)
    disc = eigvals[top].real
    top_vec = eigvecs[:, top].real

    # Back to feature-space coefficients v = J Z^T w, then to sample space x = F v
    coeffs = signature @ (chol_upper.T @ top_vec)
    coeffs = coeffs / (torch.linalg.norm(coeffs) + 1e-30)

    x = feats_x @ coeffs[:dim_x] + scaled_y @ coeffs[dim_x:]
    x = x / (torch.linalg.norm(x) + 1e-30)

    # Constraint value uses the UNSCALED y-features (no sqrt(lam) factor)
    y_response = feats_y.T @ x
    phi = torch.dot(y_response, y_response).real / n_samples  # scalar

    return x, phi, disc


# ---------- bisection for lambda in one deflated subspace ----------

def _bisect_lambda_one_explicit(
    phi_x: torch.Tensor,
    phi_y: torch.Tensor,
    mu_k2: float,
    Y_prev: torch.Tensor,
    lam_min: float = 0.0,
    lam_max: float = 1.0,
    tol: float = 1e-6,
    max_iter: int = 50,
    eps_chol: float = 1e-7,
    ridge_proj: float = 1e-6,
):
    """
    Find lambda such that phi(lambda) ~= mu_k2, where phi(lambda) is the
    second value returned by _top_for_lambda_explicit_tilde.

    Procedure:
      1. Evaluate at lam_min; if phi <= mu_k2 + tol the constraint is treated
         as inactive and the lam_min solution is returned.
      2. Otherwise double lam_max until phi(lam_max) <= mu_k2 (bracketing).
      3. Bisect on [lam_min, lam_max], assuming phi is non-increasing in
         lambda, until |phi - mu_k2| <= tol or max_iter evaluations are used.

    Returns:
        (x, lam, phi, disc) with lam, phi, disc as Python floats. All four
        values come from the SAME evaluation: lam is always the lambda at
        which the returned x / phi / disc were computed, including when the
        iteration budget runs out before reaching tol.

    Raises:
        ValueError: if bracketing is needed and lam_max is not positive
            (doubling could never grow it), or if doubling pushes lam_max
            past 1e8 without bracketing mu_k2.
    """
    def _evaluate(lam):
        return _top_for_lambda_explicit_tilde(
            phi_x, phi_y, Y_prev, lam, eps_chol, ridge_proj
        )

    # Lower end: if already <= mu, constraint inactive -> keep lam_min
    x_lo, phi_lo, disc_lo = _evaluate(lam_min)
    if phi_lo <= mu_k2 + tol:
        return x_lo, float(lam_min), float(phi_lo), float(disc_lo)

    # Doubling a zero upper bound never changes it, which would make the
    # bracketing loop below spin forever; reject it up front.
    if lam_max <= 0:
        raise ValueError("lam_max must be positive to bracket lambda by doubling.")

    # Expand lam_max until phi(lam_max) <= mu
    x_hi, phi_hi, disc_hi = _evaluate(lam_max)
    while phi_hi > mu_k2:
        lam_max *= 2.0
        if lam_max > 1e8:
            raise ValueError("Failed to bracket: mu_k2 may be too small or numerical issues in K2.")
        x_hi, phi_hi, disc_hi = _evaluate(lam_max)

    lam_lo, lam_hi = lam_min, lam_max
    # Track the most recent evaluation together with ITS lambda so the
    # returned tuple is self-consistent even if we run out of iterations
    # (or max_iter == 0, in which case the upper bracket end is returned).
    x_cur, lam_cur, phi_cur, disc_cur = x_hi, lam_max, phi_hi, disc_hi
    for _ in range(max_iter):
        lam_cur = 0.5 * (lam_lo + lam_hi)
        x_cur, phi_cur, disc_cur = _evaluate(lam_cur)
        if abs(phi_cur - mu_k2) <= tol:
            break
        # monotone: if phi > mu, need larger lambda
        if phi_cur > mu_k2:
            lam_lo = lam_cur
        else:
            lam_hi = lam_cur

    return x_cur, float(lam_cur), float(phi_cur), float(disc_cur)


# ---------- public API: compute top-m deflated directions with per-step bisection ----------

@torch.no_grad()
def topm_with_k2_constraint_bisect_lambda(
    phi_x: torch.Tensor,
    phi_y: torch.Tensor,
    mu_k2: float,
    m: int = 10,
    lam_min: float = 0.0,
    lam_max: float = 1.0,
    tol: float = 1e-6,
    max_bisect_iter: int = 50,
    eps_chol: float = 1e-7,
    ridge_proj: float = 1e-6,
    reorth: int = 2,
):
    """
    Compute m sample-space directions one at a time, each found by
    _bisect_lambda_one_explicit with all previously found directions passed
    as Y_prev (so they are projected out of the features for the next step).

    Args:
        phi_x, phi_y: feature matrices sharing the same number of rows n.
        mu_k2: target value for the per-direction constraint phi_t.
        m: number of directions to extract; 0 yields empty outputs.
        lam_min, lam_max, tol, max_bisect_iter, eps_chol, ridge_proj:
            forwarded unchanged to _bisect_lambda_one_explicit at every step.
        reorth: accepted for backward compatibility but currently has NO
            effect (the explicit re-orthogonalization step is disabled).

    Returns:
        X      : (n, m) tensor whose columns are the directions x_t
        lams   : (m,) tensor of lambda_t
        phis   : (m,) tensor of achieved phi_t
        discs  : (m,) tensor of the disc value reported at lambda_t
        All outputs use phi_x's device and dtype.

    Raises:
        ValueError: if m is negative.
    """
    if m < 0:
        raise ValueError("m must be non-negative.")

    device = phi_x.device
    dtype = phi_x.dtype
    n = phi_x.shape[0]

    # Accumulates found directions column by column; doubles as the output X.
    Y_prev = torch.zeros((n, 0), device=device, dtype=dtype)

    lam_vals, phi_vals, disc_vals = [], [], []
    for _ in range(m):
        x_t, lam_t, phi_t, disc_t = _bisect_lambda_one_explicit(
            phi_x, phi_y, mu_k2, Y_prev,
            lam_min=lam_min, lam_max=lam_max,
            tol=tol, max_iter=max_bisect_iter,
            eps_chol=eps_chol, ridge_proj=ridge_proj,
        )

        # NOTE: an extra sample-space re-orthogonalization of x_t against
        # Y_prev (repeated `reorth` times) used to sit here and is disabled;
        # orthogonality currently relies solely on the ridge-regularized
        # feature projection done inside the helper.

        lam_vals.append(lam_t)
        phi_vals.append(phi_t)
        disc_vals.append(disc_t)
        Y_prev = torch.cat([Y_prev, x_t.unsqueeze(1)], dim=1)

    # torch.tensor on an empty list gives shape (0,), so m == 0 is handled
    # without special-casing (torch.stack would reject an empty list).
    X = Y_prev                                                     # (n, m)
    lams = torch.tensor(lam_vals, device=device, dtype=dtype)      # (m,)
    phis = torch.tensor(phi_vals, device=device, dtype=dtype)      # (m,)
    discs = torch.tensor(disc_vals, device=device, dtype=dtype)    # (m,)
    return X, lams, phis, discs
