import numpy as np
import torch


def _orthonormal_columns(m, n, rng):
    A = rng.standard_normal((m, n))
    Q, _ = np.linalg.qr(A, mode="reduced")
    return Q  # m x n, Q^T Q = I_n

def _normalize_columns(A, eps=1e-12):
    nrm = np.linalg.norm(A, axis=0)
    nrm = np.where(nrm < eps, 1.0, nrm)
    return A / nrm

def _sample_in_kerspace(S, k, rng):
    """
    Return k vectors in ker(S) ⊂ R^D (as columns), roughly orthonormal.
    We build an orthonormal basis of ker(S) and draw combinations.
    """
    N, D = S.shape[0], S.shape[1]
    # Build Q (D x N) with orthonormal columns spanning range(S^T)
    # (Reuse from S: rows of S are orthonormal ⇒ columns of S^T are Q)
    # We need an orthonormal basis for ker(S): any Q_perp with columns orth to Q.
    # Construct by projecting random to orth complement and QR.
    Q = S.T  # D x N, columns orthonormal since SS^T = I_N
    R = rng.standard_normal((D, max(k, D - N)))
    R = R - Q @ (Q.T @ R)         # project to orth complement of Q
    Qperp, _ = np.linalg.qr(R)    # D x t, columns span ker(S)
    # Keep only as many as we need:
    U = Qperp[:, :k]              # D x k
    return U

def make_E_S_biweighted(V, D, N, r=4.0, seed=None, as_tensors=True):
    """
    Build E (D x V) and S (N x D) with prescribed column norms.

    S has orthonormal rows (S S^T = I_N). The columns of E are split into
    keys (the first V // 2) and values (the remaining V - V // 2, i.e. one
    more than the keys when V is odd):
      - keys:   diag(E^T E) = 1, diag((SE)^T (SE)) = 1
      - values: diag(E^T E) = r, diag((SE)^T (SE)) = 1/r

    Off-diagonal entries are not constrained. They are only as small as
    random unit vectors in R^N happen to be nearly orthogonal. Among value
    columns, the ker(S) components are mutually orthogonal only for the first
    D - N of them. Any further value columns get random ker(S) directions,
    whose pairwise overlaps enter E^T E scaled by (r^2 - 1) / r.

    Args:
        V: number of columns of E.
        D: ambient dimension (rows of E, columns of S).
        N: number of rows of S; must satisfy N <= D.
        r: norm ratio for value columns; must be >= 1. Because S has
            orthonormal rows, ||S e|| <= ||e||, so requiring ||e||^2 = r
            together with ||S e||^2 = 1/r forces r >= 1/r, i.e. r >= 1.
        seed: passed to np.random.default_rng.
        as_tensors: if True, convert E and S to torch float tensors;
            otherwise return numpy arrays.

    Returns:
        (E, S)

    Raises:
        ValueError: if r <= 0, if r < 1, if r != 1 and D <= N (ker S is
            trivial), or if D < N (S cannot have orthonormal rows).
    """
    if r <= 0:
        raise ValueError("r must be > 0.")
    if r < 1:
        # Unattainable target; previously sqrt(r*r - 1) was NaN and E was silently NaN-filled.
        raise ValueError(
            f"r must be >= 1 (||S e|| <= ||e|| when S has orthonormal rows). Got r={r}."
        )
    if r != 1 and D <= N:
        raise ValueError(f"Need D > N to realize r≠1 (ker S nonempty). Got D={D}, N={N}.")
    if D < N:
        raise ValueError(f"Need D >= N for S to have orthonormal rows. Got D={D}, N={N}.")
    rng = np.random.default_rng(seed)

    # 1) S with orthonormal rows: Q (D x N) with orthonormal columns, S = Q^T.
    Q = _orthonormal_columns(D, N, rng)
    S = Q.T  # N x D

    # 2) Z (N x V) with unit-norm columns.
    Z = _normalize_columns(rng.standard_normal((N, V)))

    # 3) Keys: e_k = S^T z_k, so ||e_k|| = ||z_k|| = 1 and S e_k = z_k.
    Vh = V // 2
    E = np.empty((D, V))
    E[:, :Vh] = S.T @ Z[:, :Vh]

    # 4) Values: e_v = beta * (S^T z_v + scale * u_v), with u_v a unit vector
    #    in ker S, beta^2 = 1/r, scale^2 = r^2 - 1. Then
    #    ||e_v||^2 = (1 + r^2 - 1) / r = r and S e_v = beta z_v => ||S e_v||^2 = 1/r.
    if r == 1:
        # No ker component needed; values look exactly like keys.
        E[:, Vh:] = S.T @ Z[:, Vh:]
    else:
        beta = 1.0 / np.sqrt(r)
        need = V - Vh
        # Orthonormal ker(S) directions for the first min(need, D - N) values.
        Uker = _sample_in_kerspace(S, min(need, D - N), rng)  # D x m
        n_extra = need - Uker.shape[1]
        if n_extra > 0:
            # ker(S) holds only D - N orthonormal directions; the rest are random
            # unit vectors in ker(S). Draw (n_extra, D) and transpose so the RNG
            # stream matches drawing one length-D vector per extra column.
            W = rng.standard_normal((n_extra, D)).T
            W = W - Q @ (Q.T @ W)  # project to ker(S)
            nrm = np.linalg.norm(W, axis=0)
            W = W / np.where(nrm > 1e-12, nrm, 1.0)
            U = np.concatenate([Uker, W], axis=1)  # D x need
        else:
            U = Uker
        scale = np.sqrt(r * r - 1.0)
        E[:, Vh:] = beta * (S.T @ Z[:, Vh:] + scale * U)

    if as_tensors:
        E = torch.as_tensor(E, dtype=torch.float)
        S = torch.as_tensor(S, dtype=torch.float)

    return E, S