import math
import numpy as np
from math import pi

def hyperbolic_distance_vectorized(x_new, coords_existing):
    """Return hyperbolic distances between one point and a set of points.

    ``x_new`` is indexed as ``(angle, radius)`` and ``coords_existing`` as a
    2-D array whose column 0 holds angles and column 1 holds radii. The
    result has one distance per row of ``coords_existing``.
    """
    angle_new, radius_new = x_new[0], x_new[1]
    angles_old = coords_existing[:, 0]
    radii_old = coords_existing[:, 1]

    # Fold the raw angular gap back so a gap beyond pi is measured the
    # short way around the circle.
    raw_gap = np.abs(angle_new - angles_old)
    angular_gap = np.pi - np.abs(np.pi - raw_gap)

    # Hyperbolic law of cosines: cosh(d) = cosh*cosh - sinh*sinh*cos(gap).
    cosh_of_dist = (
        np.cosh(radius_new) * np.cosh(radii_old)
        - np.sinh(radius_new) * np.sinh(radii_old) * np.cos(angular_gap)
    )

    # Round-off can push the value slightly below 1, where arccosh is
    # undefined; floor it at 1 (distance 0).
    return np.arccosh(np.clip(cosh_of_dist, 1.0, None))


def _choose_nodes(rng, dists, k, T, Rt=None):
    """Pick up to ``k`` indices into ``dists``, nearest-first or by sampling.

    Args:
        rng: object providing ``choice(n, size=, replace=, p=)``, e.g. a
            ``numpy.random.Generator``.
        dists: 1-D sequence of distances to the candidate nodes.
        k: number of indices wanted. It is capped at ``len(dists)`` so both
            modes behave the same when fewer candidates exist.
        T: temperature. ``0`` selects the ``k`` smallest distances
            deterministically; otherwise indices are sampled without
            replacement with weight ``1 / (1 + exp((d - Rt) / (2 T)))``.
        Rt: distance at which the weight is 1/2. Defaults to the median of
            ``dists`` when omitted.

    Returns:
        A 1-D integer array of selected indices (at most ``k`` of them).
    """
    dists = np.asarray(dists, dtype=float)
    # Sampling without replacement cannot return more items than exist;
    # cap here so the T > 0 path truncates like the T == 0 slice does.
    k = min(k, len(dists))
    if T == 0:
        return np.argsort(dists)[:k]
    if Rt is None:
        Rt = np.median(dists)
    probs = 1.0 / (1.0 + np.exp((dists - Rt) / (2.0 * T)))
    probs /= probs.sum()
    return rng.choice(len(dists), size=k, replace=False, p=probs)


def generate_mlp_topology(
    a: int,
    b: int,
    c: int,
    d: int,
    sparsity: float,
    T: float,
    gamma: float,
    theta_A: np.ndarray,
    theta_B: np.ndarray,
    theta_C: np.ndarray,
    theta_D: np.ndarray,
    seed=None,
):
    """
    Generate 4-layer MLP topology (A->B->C->D) under hyperbolic model.

    The layers grow in ``g = gcd(a, b, c, d)`` rounds. Round 0 seeds every
    layer with ``size // g`` nodes and fully connects adjacent seed groups.
    Each later round adds ``size // g`` nodes per layer; every new node is
    linked to nodes of the neighbouring layer chosen by hyperbolic distance
    (the nearest ones when ``T`` is 0, otherwise sampled via
    ``_choose_nodes`` with the median distance as ``Rt``).

    Args:
      a, b, c, d: number of nodes in layers A, B, C and D.
      sparsity: scales the per-node link budget, which is
        ``c * (1 - sparsity)`` for D->C, ``b * (1 - sparsity)`` for C->B and
        ``(1 - sparsity) * a * b / (a + b)`` for both B->A and A->B.
        A fractional budget is rounded down or up at random so that its
        expected value equals the budget.
      T: temperature; 0 means deterministic nearest-neighbour selection.
      gamma: radial fading uses ``beta = 1 / (gamma - 1)``; must not be 1.
      theta_A, theta_B, theta_C, theta_D: angular coordinate of every node
        in the respective layer (assigned to column 0 of the coords).
      seed: forwarded to ``numpy.random.default_rng``; pass a value to make
        the result reproducible. ``None`` keeps the previous unseeded
        behaviour.

    Returns:
      x_DC, x_CB, x_BA: 0/1 integer adjacency matrices of shape (d, c),
        (c, b) and (b, a) (no distances)
      coords_A, coords_B, coords_C, coords_D: final (layer_size, 2) arrays
        holding (angle, radius) per node
    """
    rng = np.random.default_rng(seed)
    beta = 1.0 / (gamma - 1.0)

    # proportions: number of nodes each layer gains per round
    g = math.gcd(math.gcd(a, b), math.gcd(c, d))
    a_p, b_p, c_p, d_p = a // g, b // g, c // g, d // g

    def split_budget(x):
        """Return (floor, ceil, probability of using the floor) for x."""
        lo, hi = math.floor(x), math.ceil(x)
        return lo, hi, hi - x

    budget_DC = split_budget(c * (1 - sparsity))
    budget_CB = split_budget(b * (1 - sparsity))
    budget_AB = split_budget((1 - sparsity) * a * b / (a + b))

    def new_coords(size, thetas):
        """Allocate (size, 2) coords with angles set and radii at 0."""
        coords = np.zeros((size, 2))
        coords[:, 0] = thetas
        return coords

    coords_A = new_coords(a, theta_A)
    coords_B = new_coords(b, theta_B)
    coords_C = new_coords(c, theta_C)
    coords_D = new_coords(d, theta_D)

    # adjacency
    x_DC = np.zeros((d, c), dtype=int)
    x_CB = np.zeros((c, b), dtype=int)
    x_BA = np.zeros((b, a), dtype=int)

    def fade(coords, idx):
        """Place node ``idx`` at radius 2*ln(idx+1) and re-fade older nodes."""
        t = idx + 1
        coords[idx, 1] = 2.0 * np.log(t)
        # Every earlier node q (1-based birth order) is moved to a blend of
        # its birth radius and the newest radius, weighted by beta.
        q = np.arange(1, t)
        coords[:idx, 1] = beta * 2.0 * np.log(q) + (1 - beta) * 2.0 * np.log(t)

    def seed_layer(coords, n_seed):
        """Fade the first ``n_seed`` nodes of a layer."""
        # fade() rewrites every earlier radius, so fading the last seed node
        # alone yields the same state as fading the seeds one by one.
        if n_seed:
            fade(coords, n_seed - 1)

    # seed layers and fully connect seeds
    seed_layer(coords_A, a_p)
    seed_layer(coords_B, b_p)
    seed_layer(coords_C, c_p)
    seed_layer(coords_D, d_p)
    x_BA[:b_p, :a_p] = 1
    x_CB[:c_p, :b_p] = 1
    x_DC[:d_p, :c_p] = 1

    def pick_targets(src_coords, idx, tgt_coords, n_candidates, budget):
        """Fade source node ``idx`` and choose its neighbours.

        Candidates are the first ``n_candidates`` nodes of the target layer;
        the returned array holds their indices.
        """
        lo, hi, p_lo = budget
        fade(src_coords, idx)
        m = lo if rng.random() < p_lo else hi
        # Early rounds can offer fewer candidates than the budget; sampling
        # without replacement would raise, so cap the request.
        m = min(m, n_candidates)
        dists = hyperbolic_distance_vectorized(
            src_coords[idx], tgt_coords[:n_candidates]
        )
        if T:
            return _choose_nodes(rng, dists, m, T, Rt=np.median(dists))
        return np.argsort(dists)[:m]

    # back-propagation rounds
    # NOTE(review): in the D->C, C->B and B->A steps the candidate range
    # includes the target layer's nodes of the current round, which are only
    # faded later in the same round, so their radius is still the initial 0
    # when distances are computed. Confirm this is intended.
    for r in range(1, g):
        iD, iC, iB, iA = r * d_p, r * c_p, r * b_p, r * a_p

        # D->C
        for j in range(d_p):
            idx = iD + j
            chosen = pick_targets(coords_D, idx, coords_C, iC + c_p, budget_DC)
            x_DC[idx, chosen] = 1

        # C->B
        for k in range(c_p):
            idx = iC + k
            chosen = pick_targets(coords_C, idx, coords_B, iB + b_p, budget_CB)
            x_CB[idx, chosen] = 1

        # B->A
        for j in range(b_p):
            idx = iB + j
            chosen = pick_targets(coords_B, idx, coords_A, iA + a_p, budget_AB)
            x_BA[idx, chosen] = 1

        # A->B (links are stored in the same B-by-A matrix, transposed)
        for i in range(a_p):
            idx = iA + i
            chosen = pick_targets(coords_A, idx, coords_B, iB + b_p, budget_AB)
            x_BA[chosen, idx] = 1

    return x_DC, x_CB, x_BA, coords_A, coords_B, coords_C, coords_D

def nPSO_quadpartite(
    a,
    b,
    c,
    d,
    sparsity,
    T,
    gamma,
    distr
):
    """
    User-facing entry for the four-layer (A, B, C, D) nPSO topology.

    Validates the arguments, draws an angular coordinate for every node
    with ``np.random.uniform`` and delegates to ``generate_mlp_topology``.

    Args:
      a, b, c, d: layer sizes; positive integers (NumPy integers accepted).
      sparsity: value in [0, 1).
      T: non-negative temperature.
      gamma: value >= 2.
      distr: ``"community"`` draws each layer's angles from its own quarter
        of the circle (A: [0, pi/2), B: [pi/2, pi), C: [pi, 3pi/2),
        D: [3pi/2, 2pi)); ``"random"`` draws every layer from [0, 2pi).

    Returns:
      The tuple returned by ``generate_mlp_topology``:
      (x_DC, x_CB, x_BA, coords_A, coords_B, coords_C, coords_D).

    Raises:
      ValueError: if any argument is outside its allowed range, or if
        ``distr`` is not one of the two supported names.
    """
    # validation
    for label, size in (("a", a), ("b", b), ("c", c), ("d", d)):
        if not (isinstance(size, (int, np.integer)) and size >= 1):
            raise ValueError(f"{label} must be a positive integer")
    if not (0 <= sparsity < 1):
        raise ValueError("sparsity must be in [0,1)")
    if T < 0:
        raise ValueError("T must be non-negative")
    if gamma < 2:
        raise ValueError("gamma must be >=2")

    # Angular range (low, high) per layer, in the order A, B, C, D.
    if distr == "community":
        sectors = [(0, 0.5*pi), (0.5*pi, pi), (pi, 1.5*pi), (1.5*pi, 2*pi)]
    elif distr == "random":
        sectors = [(0, 2*pi)] * 4
    else:
        # Reject unknown names here instead of failing later on unbound
        # theta variables.
        raise ValueError(
            f"distr must be 'community' or 'random', got {distr!r}"
        )

    # Normalise NumPy integers to plain ints before sizing arrays.
    a, b, c, d = int(a), int(b), int(c), int(d)

    # Drawn in A, B, C, D order so results under np.random.seed are stable.
    theta_A, theta_B, theta_C, theta_D = (
        np.random.uniform(low, high, size)
        for (low, high), size in zip(sectors, (a, b, c, d))
    )

    return generate_mlp_topology(
        a, b, c, d, sparsity, T, gamma, theta_A, theta_B, theta_C, theta_D
    )