import numpy as np
from scipy.linalg import sqrtm, pinv

def laplacian_from_adjacency(A):
    """Return the combinatorial Laplacian of an adjacency matrix.

    The Laplacian is ``D - A``, where ``D`` is the diagonal matrix whose
    entries are the row sums (weighted degrees) of ``A``.
    """
    degrees = np.sum(A, axis=1)
    return np.diag(degrees) - A


def adjacency_from_laplacian(L):
    """Recover edge weights from a Laplacian.

    Off-diagonal entries are negated (``A[i, j] = -L[i, j]``) and the
    diagonal is set to zero. The input ``L`` is not modified.
    """
    # Unary minus already allocates a new array, so filling its diagonal
    # below never touches the caller's matrix.
    weights = -L
    np.fill_diagonal(weights, 0)
    return weights


def _psd_sqrt(M):
    """Return the principal square root of a symmetric positive semidefinite matrix.

    Uses a symmetric eigendecomposition rather than the general-purpose
    ``scipy.linalg.sqrtm`` (Schur based), which can be inaccurate or break
    down on singular inputs and may return complex round-off. The input is
    symmetrized first. Eigenvalues that are negative or indistinguishable from
    zero at machine precision are clamped to 0, so the result is always real
    and symmetric.
    """
    M = 0.5 * (M + M.T)
    w, V = np.linalg.eigh(M)
    # Relative cutoff (same spirit as pinv's rcond): anything this small is
    # numerical noise around a true zero eigenvalue.
    tol = max(M.shape) * np.finfo(float).eps * np.abs(w).max(initial=0.0)
    w = np.where(w > tol, w, 0.0)
    return (V * np.sqrt(w)) @ V.T


def bures_wasserstein_mean_laplacian(L0, L1, t):
    """Interpolate two Laplacians along the Bures-Wasserstein geodesic.

    Each Laplacian ``L`` is identified with the covariance ``Sigma = pinv(L)``.
    With ``Sigma0 = pinv(L0)`` and ``Sigma1 = pinv(L1)`` the geodesic point is

        Sigma_t = Sigma0^{-1/2} [(1 - t) Sigma0
                  + t (Sigma0^{1/2} Sigma1 Sigma0^{1/2})^{1/2}]^2 Sigma0^{-1/2},

    where ``Sigma0^{-1/2}`` is taken as ``L0^{1/2}``. The returned Laplacian is
    ``pinv(Sigma_t)``.

    Args:
        L0: Symmetric positive semidefinite matrix (array-like) at t = 0.
        L1: Symmetric positive semidefinite matrix of the same shape at t = 1.
        t: Interpolation parameter. t = 0 recovers L0; t = 1 recovers L1 when
            the inputs are positive definite (e.g. Laplacians shifted by a
            small multiple of the identity). Values outside [0, 1]
            extrapolate along the geodesic.

    Returns:
        Symmetric ndarray of the same shape as L0.
    """
    L0 = np.asarray(L0, dtype=float)
    L1 = np.asarray(L1, dtype=float)

    # Covariances associated with the two Laplacians.
    L0_dagger = pinv(L0)
    L1_dagger = pinv(L1)

    # L0^{1/2} plays the role of Sigma0^{-1/2}; (L0^+)^{1/2} is Sigma0^{1/2}.
    L0_sqrt = _psd_sqrt(L0)
    L0_dagger_sqrt = _psd_sqrt(L0_dagger)

    # (Sigma0^{1/2} Sigma1 Sigma0^{1/2})^{1/2}
    middle_term = _psd_sqrt(L0_dagger_sqrt @ L1_dagger @ L0_dagger_sqrt)

    # Linear interpolation in the "transport" coordinates.
    convex_combo = (1 - t) * L0_dagger + t * middle_term

    # Covariance at time t; remove round-off asymmetry before inverting.
    S_t = L0_sqrt @ (convex_combo @ convex_combo) @ L0_sqrt
    S_t = 0.5 * (S_t + S_t.T)

    # Map the covariance back to a Laplacian, keeping it exactly symmetric.
    L_t = pinv(S_t)
    return 0.5 * (L_t + L_t.T)


def bures_wasserstein_mean_adjacency(A0, A1, t, threshold=0.5, eps=1e-3):
    """Interpolate two graphs via the Bures-Wasserstein mean of their Laplacians.

    Both adjacency matrices are converted to combinatorial Laplacians and
    shifted by ``eps * I``. The shifted Laplacians are interpolated with
    ``bures_wasserstein_mean_laplacian``, and the result is thresholded back
    into a binary adjacency matrix. Inputs are expected to be symmetric
    (undirected graphs).

    Args:
        A0: Square adjacency matrix (array-like) of the graph at t = 0.
        A1: Square adjacency matrix of the same shape for the graph at t = 1.
        t: Interpolation parameter passed to the Laplacian mean.
        threshold: Edge (i, j) is kept when the symmetrized continuous weight
            ``-(L_t[i, j] + L_t[j, i]) / 2`` is strictly greater than this.
        eps: Diagonal shift added to each Laplacian so it becomes invertible.

    Returns:
        Symmetric integer 0/1 adjacency matrix with a zero diagonal.

    Raises:
        ValueError: If A0 and A1 are not square matrices of the same shape.
    """
    A0 = np.asarray(A0)
    A1 = np.asarray(A1)
    if A0.ndim != 2 or A0.shape[0] != A0.shape[1] or A0.shape != A1.shape:
        raise ValueError(
            "A0 and A1 must be square matrices of the same shape, "
            f"got {A0.shape} and {A1.shape}"
        )

    # Compute Laplacians.
    L0 = laplacian_from_adjacency(A0)
    L1 = laplacian_from_adjacency(A1)

    # Row sums make the all-ones vector a null vector of D - A; shifting by
    # eps * I makes both Laplacians positive definite, so their
    # pseudo-inverses are genuine inverses.
    N = A0.shape[0]
    shift = eps * np.eye(N)
    L_mean = bures_wasserstein_mean_laplacian(L0 + shift, L1 + shift, t)

    # Recover continuous weights (A = -L off the diagonal). Average the two
    # triangles so both (i, j) and (j, i) contribute, rather than discarding
    # the lower triangle; the thresholded result is symmetric by construction.
    A_mean_cont = adjacency_from_laplacian(L_mean)
    A_mean_cont = 0.5 * (A_mean_cont + A_mean_cont.T)

    A_mean_binary = (A_mean_cont > threshold).astype(int)
    # No self-loops, regardless of the threshold's sign.
    np.fill_diagonal(A_mean_binary, 0)
    return A_mean_binary



