from typing import Optional, Tuple
import numpy as np


def er_graph(n: int, p: float, weight: float = 1.0, seed: int = 0) -> np.ndarray:
    """
    Sample an undirected Erdos-Renyi G(n, p) graph as a dense weight matrix.

    Each unordered pair (i, j) with i < j receives an edge independently with
    probability p. The upper-triangular draws are mirrored, so W is symmetric.

    - n: number of nodes
    - p: edge probability
    - weight: value stored for every present edge
    - seed: seed for numpy's default_rng

    Returns:
        W: (n x n) symmetric weight matrix with zero diagonal.
    """
    generator = np.random.default_rng(seed)
    coin_flips = generator.random((n, n)) < p
    # Keep only strictly-upper draws, then mirror them into the lower triangle.
    upper = np.triu(coin_flips, 1)
    adjacency = np.logical_or(upper, upper.T)
    weights = adjacency.astype(float) * weight
    np.fill_diagonal(weights, 0.0)
    return weights


def rbf_graph_from_latents(n: int, q: int, rho: float, threshold: float = 0.0, seed: int = 0) -> np.ndarray:
    """
    Build a dense graph from Gaussian-kernel similarities of random latent points.

    Draws n latent vectors z_i ~ N(0, I_q) and sets
    W_ij = exp(-rho * ||z_i - z_j||^2), with zero diagonal.

    - n: number of nodes
    - q: latent dimension
    - rho: kernel bandwidth multiplier (larger rho -> faster decay)
    - threshold: if > 0, entries strictly below it are set to 0.0
    - seed: seed for numpy's default_rng

    Returns:
        W: (n x n) symmetric weight matrix with zero diagonal.
    """
    generator = np.random.default_rng(seed)
    latents = generator.normal(size=(n, q))
    # Broadcasting gives all pairwise differences with shape (n, n, q).
    pairwise_diff = latents[:, None, :] - latents[None, :, :]
    sq_dist = np.square(pairwise_diff).sum(axis=-1)
    kernel = np.exp(-rho * sq_dist)
    np.fill_diagonal(kernel, 0.0)  # no self-loops
    if threshold > 0.0:
        kernel = np.where(kernel >= threshold, kernel, 0.0)
    return kernel


def laplacian(W: np.ndarray) -> np.ndarray:
    """
    Return the combinatorial graph Laplacian L = D - W.

    D is the diagonal matrix of row sums (weighted degrees) of W.
    """
    degree_matrix = np.diag(np.sum(W, axis=1))
    return np.subtract(degree_matrix, W)


def laplacian_with_ridge(W: np.ndarray, rho: float) -> np.ndarray:
    """
    Return laplacian(W) + rho * I.

    The ridge term shifts every eigenvalue of the Laplacian up by rho.
    """
    n_nodes = W.shape[0]
    ridge = rho * np.eye(n_nodes)
    return laplacian(W) + ridge


def inv_sqrt_psd(M: np.ndarray, jitter: float = 1e-8, r_trunc: Optional[int] = None) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Compute an (optionally rank-truncated) inverse square root of a PSD matrix.

    The eigendecomposition is taken of M + jitter * I. np.linalg.eigh reads only
    the lower triangle, so M is assumed symmetric. Eigenvalues are clamped at
    0, and the inverse square root is formed as U diag(1 / sqrt(w + jitter)) U^T.
    Note that jitter therefore enters twice: once before the eigensolve and
    once in the denominator.

    Args:
        M: square matrix, assumed symmetric PSD.
        jitter: diagonal regulariser (see note above).
        r_trunc: if given and smaller than M's size, keep only the r_trunc
            largest eigenpairs. The result is then a low-rank inverse square
            root supported on that subspace.

    Returns:
        (inv_sqrt, (w, U)) where w holds the clamped eigenvalues of
        M + jitter * I and U the matching eigenvector columns. w is in
        ascending order without truncation and descending order with
        truncation.
    """
    dim = M.shape[0]
    evals, evecs = np.linalg.eigh(M + jitter * np.eye(dim))

    truncate = r_trunc is not None and r_trunc < len(evals)
    if truncate:
        order = np.argsort(evals)[::-1]  # eigenvalue indices, largest first
        keep = order[:r_trunc]
        evals, evecs = evals[keep], evecs[:, keep]

    evals = np.maximum(evals, 0.0)
    # Scale each eigenvector column by 1/sqrt(lambda + jitter), then project back.
    scaled = evecs * (1.0 / np.sqrt(evals + jitter))
    inv_sqrt = scaled @ evecs.T
    return inv_sqrt, (evals, evecs)


def spectral_ratio(L: np.ndarray, tol: float = 1e-12) -> float:
    """
    Compute S_spec = lambda_2(L) / lambda_max(L) for the (combinatorial) Laplacian L.

    lambda_2 is the second-smallest eigenvalue of L. For a combinatorial
    Laplacian, the multiplicity of the eigenvalue 0 equals the number of
    connected components. A disconnected graph therefore has lambda_2 ~ 0
    and yields 0.0.

    - Assumes L is symmetric PSD with smallest eigenvalue 0. L is symmetrised
      as 0.5 * (L + L.T) before the eigensolve.
    - Uses dense eig since n <= O(1e3) in your benches. If you later scale
      n >> 1e3, switch to sparse eigs for the two extreme eigenvalues.

    Args:
        L: square Laplacian matrix.
        tol: zero threshold.
            - lambda_max <= tol is treated as a zero matrix.
            - lambda_2 is treated as zero when lambda_2 <= tol * max(1, lambda_max).
              Dense eigensolver error scales with ||L||, so a fixed absolute
              threshold misclassifies heavily weighted graphs.

    Returns:
        Ratio clamped to [0, 1]. The result is 0.0 when:
        - L has fewer than 2 rows,
        - L is near-zero, or
        - lambda_2 is numerically zero (disconnected graph).
    """
    L = 0.5 * (L + L.T)
    if L.shape[0] < 2:
        # lambda_2 does not exist for fewer than two nodes.
        return 0.0
    evals = np.linalg.eigvalsh(L)  # ascending order
    lam_max = float(evals[-1])
    if lam_max <= tol:
        return 0.0
    # Take the second-smallest eigenvalue directly. Scanning for the first
    # "positive" eigenvalue would skip every zero eigenvalue, which breaks
    # disconnected graphs. It would also pick up round-off noise on lambda_1
    # when the weights are large.
    lam2 = float(evals[1])
    if lam2 <= tol * max(1.0, lam_max):
        return 0.0
    return max(0.0, min(1.0, lam2 / lam_max))

def sbm_graph(
    n: int,
    n_blocks: int,
    p_in: float,
    p_out: float,
    weight: float = 1.0,
    seed: int = 0,
    balanced: bool = True,
) -> np.ndarray:
    """
    Sample an undirected stochastic block model (SBM) graph.

    Each unordered pair (i, j), i < j, is connected independently with
    probability p_in if i and j share a block and p_out otherwise.

    - n: total number of nodes (users)
    - n_blocks: number of communities; must lie in [1, n]
    - p_in: connection probability within a block
    - p_out: connection probability across blocks
    - weight: edge weight for present edges
    - seed: seed for numpy's default_rng
    - balanced: if True, nodes are split into contiguous blocks whose sizes
      differ by at most one. If False, each node's block is drawn uniformly
      at random, so block sizes vary and a block may be empty.

    Returns:
        W: (n x n) symmetric adjacency/weight matrix with zero diagonal.

    Raises:
        ValueError: if n_blocks is outside [1, n].
    """
    generator = np.random.default_rng(seed)

    if n_blocks <= 0 or n_blocks > n:
        raise ValueError(f"n_blocks must be in [1, n], got n_blocks={n_blocks}, n={n}")

    if balanced:
        # The first (n mod n_blocks) blocks get one extra node.
        base_size, n_larger = divmod(n, n_blocks)
        block_sizes = np.full(n_blocks, base_size, dtype=int)
        block_sizes[:n_larger] += 1
        labels = np.repeat(np.arange(n_blocks), block_sizes)
    else:
        labels = generator.integers(low=0, high=n_blocks, size=n)

    # Pairwise probability matrix: p_in on same-block pairs, p_out elsewhere.
    same_block = labels[:, None] == labels[None, :]
    edge_prob = np.where(same_block, p_in, p_out).astype(float)

    draws = generator.random((n, n)) < edge_prob
    upper = np.triu(draws, 1)
    adjacency = upper | upper.T

    weights = adjacency.astype(float) * weight
    np.fill_diagonal(weights, 0.0)
    return weights
