import numpy as np
from scipy.sparse import lil_matrix
from scipy.spatial.distance import pdist, squareform
from sklearn.mixture import GaussianMixture
from math import pi
import warnings

# Seed NumPy's legacy global RNG at import time so that np.random.* draws are
# repeatable across runs. Generators created with np.random.default_rng() keep
# their own state and are not affected by this call.
np.random.seed(0)


def hyperbolic_distance(xi, xj):
    """
    Compute the hyperbolic distance between two points in (theta, r) coords.

    Parameters
    ----------
    xi, xj : sequence of two floats
        Points given as ``(theta, r)``: angular coordinate first, radial
        coordinate second.

    Returns
    -------
    float
        ``arccosh(cosh r_i * cosh r_j - sinh r_i * sinh r_j * cos(dtheta))``
        where ``dtheta = pi - |pi - |theta_i - theta_j||`` is the wrapped
        angular difference.

    Notes
    -----
    The arccosh argument is clamped to >= 1 before the inverse is taken.
    NumPy does not raise when the argument rounds below 1 (e.g. for identical
    or nearly identical points); it returns NaN with a RuntimeWarning, so an
    exception handler cannot act as a numeric fallback.  Clamping yields a
    distance of 0 in those cases, matching `hyperbolic_distance_vectorized`.
    """
    theta_i, r_i = xi[0], xi[1]
    theta_j, r_j = xj[0], xj[1]

    delta_theta = pi - np.abs(pi - np.abs(theta_i - theta_j))
    arg = (
        np.cosh(r_i) * np.cosh(r_j)
        - np.sinh(r_i) * np.sinh(r_j) * np.cos(delta_theta)
    )
    # arccosh is only defined for arguments >= 1; rounding can push it below.
    return np.arccosh(np.maximum(arg, 1.0))
    
    
def hyperbolic_distance_vectorized(x, y):
    """
    Hyperbolic distances between batches of points in (theta, r) coordinates.

    Parameters
    ----------
    x : (..., 2) array_like
        First point or batch of points; the last axis holds (theta, r).
    y : (..., 2) array_like
        Second point or batch of points; the last axis holds (theta, r).

    Returns
    -------
    ndarray
        Distances obtained by broadcasting the components of `x`, with a
        trailing length-1 axis appended, against the components of `y`.
        A single point against an (M, 2) batch therefore yields shape (M,),
        and an (N, 2) batch against an (M, 2) batch yields shape (N, M).

    Notes
    -----
    * Formula: d = arccosh(cosh r1 * cosh r2 - sinh r1 * sinh r2 * cos(dtheta)).
    * dtheta is the wrapped angular difference pi - |pi - |theta1 - theta2||.
    * The arccosh argument is floored at 1 so rounding cannot produce NaN.
    """
    first = np.asarray(x, dtype=float)
    second = np.asarray(y, dtype=float)

    # Give the components of `x` a trailing axis so they pair with every `y`.
    ang_first = first[..., 0][..., np.newaxis]
    rad_first = first[..., 1][..., np.newaxis]
    ang_second, rad_second = second[..., 0], second[..., 1]

    # Wrapped angular separation in [0, pi].
    raw_gap = np.abs(ang_first - ang_second)
    wrapped_gap = pi - np.abs(pi - raw_gap)

    cosh_part = np.cosh(rad_first) * np.cosh(rad_second)
    sinh_part = np.sinh(rad_first) * np.sinh(rad_second) * np.cos(wrapped_gap)

    # Floor the argument at 1, the lower edge of arccosh's domain.
    return np.arccosh(np.maximum(cosh_part - sinh_part, 1.0))


def rewire_matrix(adj, rewire_mode):
    """
    Optionally rewire a bipartite adjacency matrix to 'uniform' or 'random'.

    Parameters
    ----------
    adj : scipy sparse matrix, shape (N_in, N_out)
        Bipartite adjacency (rows = inputs, columns = outputs).  Entries are
        expected to be 0/1 with at most one stored entry per (row, column).
    rewire_mode : str
        * ``'none'``    - return `adj` itself, untouched.
        * ``'uniform'`` - keep every input's degree and redistribute the
          output endpoints so that output degrees differ by at most one.
        * ``'random'``  - keep every input's degree and draw its output
          endpoints uniformly at random without replacement.
        * any other value - no rewiring; the edges are copied into a new
          matrix.

    Returns
    -------
    scipy.sparse.lil_matrix (or `adj` itself for ``'none'``)
        Integer 0/1 adjacency of shape (N_in, N_out).

    Notes
    -----
    Randomness comes from NumPy's global ``np.random`` state.
    """
    if rewire_mode == 'none':
        return adj

    N_in, N_out = adj.shape
    adj_coo = adj.tocoo()
    edges = np.column_stack((adj_coo.row, adj_coo.col))

    if rewire_mode == 'uniform':
        total = len(edges)
        # total > 0 implies N_out > 0, so the modulo below is safe.
        if total > 0:
            # Visit the inputs in random order while keeping each input's
            # stubs contiguous.
            row_rank = np.random.permutation(N_in)
            order = np.argsort(row_rank[edges[:, 0]], kind='stable')
            # Deal the stubs round-robin over a random ordering of the
            # outputs: the first (total % N_out) outputs receive one extra
            # edge, and because an input owns at most N_out consecutive stubs
            # it can never be dealt the same output twice.
            col_order = np.random.permutation(N_out)
            dealt_cols = col_order[np.arange(total) % N_out]
            edges = np.column_stack((edges[order, 0], dealt_cols))

    elif rewire_mode == 'random':
        # preserve input degrees, random B endpoints
        new_edges = []
        inp_deg = np.asarray(adj.sum(axis=1)).ravel()
        for u in range(N_in):
            d = int(inp_deg[u])
            if d > 0:
                for v in np.random.choice(N_out, size=d, replace=False):
                    new_edges.append((u, v))
        edges = np.array(new_edges, dtype=int)

    # rebuild matrix
    new_adj = lil_matrix((N_in, N_out), dtype=int)
    for u, v in edges:
        new_adj[u, v] = 1
    return new_adj


def bipartite_network(
    N_in,
    N_out,
    sparsity,
    T,
    gamma,
    theta,
    comm,
    rewire_mode
):
    """
    Build a bipartite nPSO network with input size N_in and output N_out.

    Node ids 0..N-1 (N = N_in + N_out) are laid out in bundles of
    ``N_out // N_in + 1`` consecutive ids: the first id of each bundle is an
    input (A) node, the remaining ids are output (B) nodes.  Every B node is
    linked to A nodes of its own or earlier bundles.

    Parameters
    ----------
    N_in, N_out : int
        Number of A and B nodes.  The layout assumes N_out is a multiple of
        N_in.
    sparsity : float
        Fraction of the eligible A targets a B node does *not* link to; a B
        node with k eligible targets gets ``max(1, round((1 - sparsity) * k))``
        links.
    T : float
        Temperature.  ``T == 0`` links to the closest targets; otherwise
        targets are drawn without replacement with weights
        ``1 / (1 + exp((d - Rt) / (2 T)))``.
    gamma : float
        Sets the popularity-fading parameter ``beta = 1 / (gamma - 1)``.
    theta : array_like, length N
        Angular coordinate of every node id.
    comm : array_like of length N, or None
        Per-node community labels, indexed by node id like `theta`.
    rewire_mode : str
        Forwarded to `rewire_matrix`.

    Returns
    -------
    adj : sparse matrix, shape (N_in, N_out)
        Adjacency (A rows x B columns) after optional rewiring.
    dist_matrix : ndarray, shape (N_in, N_out)
        Hyperbolic distance between every A node and every B node.
    coords_A, coords_B : ndarray, shape (N_in, 2) and (N_out, 2)
        (theta, r) coordinates of the A and B nodes.
    comm_A, comm_B : ndarray or None
        Labels of the A and B nodes, selected with the same node ids as
        `coords_A` / `coords_B` (None when `comm` is None).
    """
    size_multiple = N_out // N_in
    bundle = size_multiple + 1
    N = N_in + N_out
    beta = 1 / (gamma - 1)

    # angular + radial coordinates
    coords = np.zeros((N, 2))
    coords[:, 0] = theta

    # first id of every bundle is an A node, the others are B nodes
    A_nodes = [t for t in range(N) if t % bundle == 0]
    B_nodes = [t for t in range(N) if t % bundle != 0]

    # Radial coordinates.  All coordinates are fixed before any edge is drawn,
    # so only the fading applied when the last A node appears matters: older
    # A nodes end at beta*2*ln(q) + (1-beta)*2*ln(n_layers), while the newest
    # A node keeps 2*ln(n_layers).
    n_layers = len(A_nodes)
    for q, a_node in enumerate(A_nodes, start=1):
        if q == n_layers:
            coords[a_node, 1] = 2 * np.log(q)
        else:
            coords[a_node, 1] = (
                beta * 2 * np.log(q) + (1 - beta) * 2 * np.log(n_layers)
            )
    for k, b_node in enumerate(B_nodes, start=1):
        coords[b_node, 1] = 2 * np.log(k)

    # Community labels must follow the interleaved node ids, exactly like the
    # coordinates returned below.
    if comm is None:
        comm_A = comm_B = None
    else:
        comm_arr = np.asarray(comm)
        comm_A = comm_arr[A_nodes]
        comm_B = comm_arr[B_nodes]

    a_row = {node: i for i, node in enumerate(A_nodes)}
    adj = lil_matrix((N_in, N_out), dtype=int)

    # edge creation for B->A; curr_B is the 1-based birth rank of B node t
    for curr_B, t in enumerate(B_nodes, start=1):
        # eligible A targets: bundles up to and including the current one
        curr_layer_A = (curr_B - 1) // size_multiple + 1
        targets = A_nodes[:curr_layer_A]
        if not targets:
            continue

        # number of connections based on sparsity
        possible = len(targets)
        curr_m = max(1, int(round((1 - sparsity) * possible)))

        if possible <= curr_m:
            chosen = targets
        else:
            # clipped arccosh keeps rounding errors from producing NaNs
            dist = hyperbolic_distance_vectorized(coords[t], coords[targets])
            if T == 0:
                chosen = [targets[i] for i in np.argsort(dist)[:curr_m]]
            else:
                if beta == 1:
                    # limit of (1 - exp(-(1-beta) ln t)) / (1 - beta) as
                    # beta -> 1 is ln t; avoids the 0/0 of the general form
                    Rt = 2 * np.log(curr_B) - 2 * np.log(
                        (2 * T * np.log(curr_B)) / np.sin(T * pi)
                    )
                else:
                    Rt = 2 * np.log(curr_B) - 2 * np.log(
                        (2 * T * (1 - np.exp(-(1 - beta) * np.log(curr_B))))
                        / (np.sin(T * pi) * (1 - beta))
                    )
                probs = 1 / (1 + np.exp((dist - Rt) / (2 * T)))
                probs /= probs.sum()
                chosen = list(
                    np.random.choice(targets, curr_m, p=probs, replace=False)
                )

        for a_node in chosen:
            adj[a_row[int(a_node)], curr_B - 1] = 1

    # optional rewiring
    adj = rewire_matrix(adj, rewire_mode)

    # distances between every A node and every B node
    coords_A = coords[A_nodes]
    coords_B = coords[B_nodes]
    dist_matrix = hyperbolic_distance_vectorized(coords_A, coords_B)

    return adj, dist_matrix, coords_A, coords_B, comm_A, comm_B

def calculate_adjacency_and_distance_matrices_bipartite(
    a: int,
    b: int,
    sparsity: float,
    T: float,
    gamma: float,
    theta: np.ndarray,
    comm: np.ndarray | None = None,
    rewire_mode: str | None = None,          # accepted but never read in this function
):
    """
    Grow a bipartite network node by node and return its matrices.

    Node ids 1..N (N = a + b) are processed in bundles of ``1 + b // a``
    consecutive ids.  The first id of a bundle is a group-A node, the rest are
    group-B nodes.  Each B node (i) links to A nodes of earlier bundles and
    (ii) adds links between the *next* bundle's A node and earlier B nodes.
    Two coordinate sets are maintained: `coords_A` is used for step (i) and
    `coords_B` for step (ii).

    Parameters
    ----------
    a, b : int
        Sizes of groups A and B; ``b // a`` is used as the number of B nodes
        per bundle (presumably b is a multiple of a - the caller in this file
        validates that).
    sparsity : float
        Enters through ``m = (1 - sparsity) * a * b / (a + b)``, from which the
        per-B-node link count is derived.
    T : float
        ``T == 0`` selects the closest targets; otherwise targets are drawn
        without replacement with weights ``1 / (1 + exp((d - Rt) / (2 T)))``.
    gamma : float
        Sets ``beta = 1 / (gamma - 1)``, the popularity-fading weight.
    theta : np.ndarray
        Angular coordinates written to column 0 of every coordinate array.
    comm : np.ndarray or None
        Optional per-node labels, indexed with the A / B node-id lists.
    rewire_mode : str or None
        Unused.

    Returns
    -------
    x_bip : (a, b) ndarray[int]
        Bipartite adjacency matrix (A rows x B columns).

    d_bip : (a, b) ndarray[float]
        Value of `d_all` for every 1-entry of `x_bip` (0 elsewhere).
        NOTE(review): `d_all` is computed from `coords`, whose radial column
        is never written in this function, so these values look like they are
        all 0 - confirm which coordinates were intended.

    coords_A, coords_B : (N, 2) ndarray[float]
        (theta, r) coordinates of *all* N nodes under the two radial schemes.

    comm_A, comm_B : (a,), (b,) ndarray[int] | None
        Community labels split for groups A and B (if `comm` was given).

    Raises
    ------
    ValueError
        If a generated edge does not join one A node and one B node.

    Notes
    -----
    Random draws use a fresh, unseeded ``np.random.default_rng()``, so results
    are not controlled by ``np.random.seed``.
    """
    # Unseeded generator: independent of the module-level np.random.seed call.
    rng = np.random.default_rng()

    N              = a + b
    size_multiple  = b // a                 # B nodes per bundle
    beta           = 1.0 / (gamma - 1.0)

    # Link budget derived from the requested sparsity.
    m = (1-sparsity)*a*b/(a+b)
    # Only the angular column is filled; column 1 (radius) stays 0 in `coords`.
    coords = np.zeros((N, 2))
    coords[:, 0] = theta

    # ---- helpers for "floor/ceil with probability" ------------------------ #
    m_a = m * (size_multiple + 1)/2
    m_b = m_a / size_multiple


    #m_a_floor, m_a_ceil = int(np.floor(m_a)), int(np.ceil(m_a))
    m_b_floor, m_b_ceil = int(np.floor(m_b)), int(np.ceil(m_b))


    #floor_a_p = m_a_ceil - m_a             # prob. of using floor value
    # Probability of using the floor value, so that the expected count is m_b.
    floor_b_p = m_b_ceil - m_b

    # ---- data containers -------------------------------------------------- #
    # Both copies share theta; their radial columns are filled in the loop.
    coords_A = coords.copy()
    coords_B = coords.copy()

    A = [0]        # list of node-indices that belong to group-A   (0-based!)
    B = []         # list of node-indices that belong to group-B

    # Targets already used within the current bundle; cleared at each A node.
    prev_targets   = set()
    prev_targets_B = set()
    edges          = []                    # (u, v) tuples, treated as undirected

    # ---------------------------------------------------------------------- #
    for t in range(1, N + 1):              # t is the 1-based node id
        curr_t   = int(np.ceil(t / (1 + size_multiple)))   # 1-based bundle number
        seq      = (t - 1) % (1 + size_multiple)           # position inside the bundle
        curr_t_B = (curr_t - 1) * size_multiple + seq      # number of B nodes up to and including t

        # `idx` - 0-based index of the *current* node in the coord-arrays
        idx = t - 1

        # Every node (A or B) gets its group-A radius from the bundle number.
        coords_A[idx, 1] = 2.0 * np.log(curr_t)

        # ================================================================== #
        #  CASE 1 :  first node of the bundle -> group A                     #
        # ================================================================== #
        if seq == 0:
            coords_B[idx, 1] = 2.0 * np.log((curr_t - 1) * size_multiple + 1)

            # Pre-append the next bundle's A node so that A[curr_t] is
            # available in step 2-B below.
            nxt_A = t + size_multiple+1       # (still 1-based)
            if nxt_A <= N:
                A.append(nxt_A - 1)         # store 0-based

            # Popularity fading in coords_A for every node of every earlier
            # bundle (j walks over all 1 + size_multiple ids of bundle q).
            for q in range(1, curr_t):
                for j in range(size_multiple + 1):
                    idq = q * (1 + size_multiple) - j - 1
                    coords_A[idq, 1] = (beta * 2.0 * np.log(q) +
                                        (1 - beta) * 2.0 * np.log(curr_t))

            # A new bundle starts: forget which targets were already used.
            prev_targets.clear()
            prev_targets_B.clear()
            continue                        # A nodes create no edges themselves

        # ================================================================== #
        #  CASE 2 :  remaining nodes of the bundle -> group B (seq > 0)      #
        # ================================================================== #
        coords_B[idx, 1] = 2.0 * np.log(curr_t_B)
        B.append(idx)

        # Candidate targets not yet hit in this bundle: A nodes of earlier
        # bundles and B nodes born before the current one.
        targets    = [u for u in A[:curr_t - 1]        if u not in prev_targets]
        targets_B  = [v for v in B[:curr_t_B - 1]      if v not in prev_targets_B]

        # ---------- popularity-fading for coords_B ------------------------- #
        for q in range(1, curr_t):
            # B nodes of bundle q (its last size_multiple ids)
            for j in range(size_multiple):
                idq = q * (1 + size_multiple) - j - 1
                coords_B[idq, 1] = (beta * 2.0 * np.log(q * size_multiple - j) +
                                    (1 - beta) * 2.0 * np.log(curr_t_B))
            # A node of bundle q (its first id)
            id_edge = q * (1 + size_multiple) - size_multiple - 1
            coords_B[id_edge, 1] = (beta * 2.0 * np.log((q - 1) * size_multiple + 1) +
                                    (1 - beta) * 2.0 * np.log(curr_t_B))

        # ---- how many stubs to realise for *this* B-node ------------------ #
        curr_m = m_b_floor if rng.random() < floor_b_p else m_b_ceil

        # ------------------------------------------------------------------ #
        # 2-A) connect this B-node to A-targets                              #
        # ------------------------------------------------------------------ #
        # Redefined on every iteration; both call sites below are only reached
        # when T != 0, so the T == 0 branch inside is not exercised by them.
        def _choose_nodes(dists, k, fixed_Rt=None):
            "helper – pick k indices from `dists` either closest or by prob."
            if T == 0:
                return np.argsort(dists)[:k]
            # else: soft connection probability
            Rt = fixed_Rt
            probs = 1.0 / (1.0 + np.exp((dists - Rt) / (2.0 * T)))
            probs /= probs.sum()
            return rng.choice(np.arange(len(dists)), size=k,
                              replace=False, p=probs)

        if curr_m >= len(targets):          # connect to *all* of them
            sel_A = targets
        else:
            # Distances are measured in the coords_A scheme.
            dists = hyperbolic_distance_vectorized(coords_A[idx], coords_A[targets])
            if T == 0:
                sel_A = [targets[i] for i in np.argsort(dists)[:curr_m]]
            else:
                # beta == 1 gets its own expression because the general one
                # divides by (1 - beta).
                if beta == 1:
                    Rt = (2.0 * np.log(curr_t)
                          - 2.0 * np.log((2.0 * T * np.log(curr_t))
                                         / (np.sin(np.pi * T) * m)))
                else:
                    Rt = (2.0 * np.log(curr_t)
                          - 2.0 * np.log((2.0 * T *
                                          (1 - np.exp(-(1 - beta)
                                                      * np.log(curr_t))))
                                         / (np.sin(np.pi * T) * m * (1 - beta))))
                idx_sel = _choose_nodes(dists, curr_m, fixed_Rt=Rt)
                sel_A   = [targets[i] for i in idx_sel]

        # record the links just chosen
        for u in sel_A:
            edges.append((idx, u))
            prev_targets.add(u)

        # ------------------------------------------------------------------ #
        # 2-B) symmetric step: connect the *next* A-node to earlier B nodes  #
        #      (unless we already created all A-nodes)                       #
        # ------------------------------------------------------------------ #
        if curr_t < a:
            next_A = A[curr_t]                 # 0-based
            if curr_m >= len(targets_B):
                sel_B = targets_B
            else:
                # NOTE(review): distances are taken from the current B node
                # (t-1 == idx) in the coords_B scheme, not from next_A, even
                # though the edges recorded below start at next_A - confirm.
                dists_B = hyperbolic_distance_vectorized(coords_B[t-1], coords_B[targets_B])
                if T == 0:
                    sel_B = [targets_B[i] for i in np.argsort(dists_B)[:curr_m]]
                else:
                    # Same shape as the 2-A cutoff, with curr_t_B in place of
                    # curr_t and m / 2 in place of m.
                    if beta == 1:
                        Rt = (2.0 * np.log(curr_t_B)
                              - 2.0 * np.log((2.0 * T * np.log(curr_t_B))
                                             / (np.sin(np.pi * T) * m / 2.0)))
                    else:
                        Rt = (2.0 * np.log(curr_t_B)
                              - 2.0 * np.log((2.0 * T *
                                              (1 - np.exp(-(1 - beta)
                                                          * np.log(curr_t_B))))
                                             / (np.sin(np.pi * T) * m / 2.0
                                                * (1 - beta))))
                    idx_sel = _choose_nodes(dists_B, curr_m, fixed_Rt=Rt)
                    sel_B   = [targets_B[i] for i in idx_sel]

            for v in sel_B:
                edges.append((next_A, v))
                prev_targets_B.add(v)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
    #                   BUILD  BIPARTITE  MATRICES                           #
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
    # Keep at most `a` A nodes in case one too many was pre-appended.
    if len(A) > a:          # we know exactly how many A-nodes we want
        A = A[:a]

    # node id -> matrix row (A) / column (B)
    row_map = {u: r for r, u in enumerate(A)}
    col_map = {v: c for c, v in enumerate(B)}

    x_bip = np.zeros((a, b), dtype=int)
    d_bip = np.zeros((a, b), dtype=float)

    # Full N x N pairwise distances via a Python-level metric callback.
    # NOTE(review): this uses `coords`, whose radial column was never filled
    # (only coords_A / coords_B were updated), so with all radii equal to 0
    # every distance evaluates to arccosh(1) = 0 - confirm intended input.
    d_all = squareform(pdist(coords, metric=hyperbolic_distance))

    for u, v in edges:
        # edges were stored as (B, A) in step 2-A and (A, B) in step 2-B;
        # normalise both to (row, col)
        if   u in row_map and v in col_map:
            r, c = row_map[u], col_map[v]
        elif v in row_map and u in col_map:
            r, c = row_map[v], col_map[u]
        else:
            raise ValueError("Edge within the same group detected!")

        # Duplicates are reported but otherwise harmless: the entry is
        # simply written again.
        if x_bip[r, c] == 1:
            warnings.warn("Duplicate edge detected between "
                          f"A[{r}] and B[{c}]")
        x_bip[r, c] = 1
        d_bip[r, c] = d_all[u, v]

    # community labels split (if supplied); indexing with the id lists
    # assumes `comm` is an ndarray
    if comm is not None and len(comm):
        comm_A = comm[A]
        comm_B = comm[B]
    else:
        comm_A = comm_B = None

    return x_bip, d_bip, coords_A, coords_B, comm_A, comm_B


def _nearest_center_on_circle(theta, centers):
    """Return, for every angle, the index of the circularly closest center."""
    two_pi = 2 * pi
    gap = np.abs(
        np.asarray(theta, dtype=float)[:, None]
        - np.asarray(centers, dtype=float)
    ) % two_pi
    # The shorter way around the circle, so 2*pi - eps is close to 0.
    return np.argmin(np.minimum(gap, two_pi - gap), axis=1)


def nPSO_bipartite(
    N_in,
    N_out,
    sparsity,
    T,
    gamma,
    distr,
    rewire_mode='none'
):
    """
    User-facing entry for bipartite nPSO: returns adj, distances, coords, communities.

    Parameters
    ----------
    N_in : int
        Number of input (A) nodes, >= 1.
    N_out : int
        Number of output (B) nodes; must be >= N_in and a multiple of N_in.
    sparsity : float
        Value in [0, 1).
    T : float
        Non-negative temperature.
    gamma : float
        Exponent, >= 2.
    distr : int or tuple
        * ``0``  - angles uniform on [0, 2*pi), a single community.
        * ``C > 0`` - angles drawn from a mixture of C equally weighted
          Gaussians with means equally spaced on the circle and standard
          deviation ``2*pi / (6*C)``; each node is assigned to the
          circularly nearest mean.
        * ``(angles, probs, mu)`` - angles drawn from the discrete values
          `angles` with probabilities `probs`; communities are the
          circularly nearest entries of `mu`.
    rewire_mode : str, default 'none'
        Forwarded to the network builder.

    Returns
    -------
    tuple
        Whatever `calculate_adjacency_and_distance_matrices_bipartite`
        returns: (adjacency, distances, coords_A, coords_B, comm_A, comm_B).

    Raises
    ------
    ValueError
        On invalid sizes, sparsity, T, gamma or distribution specification.
    """
    int_types = (int, np.integer)

    # validation
    if not (isinstance(N_in, int_types) and N_in >= 1):
        raise ValueError("N_in must be a positive integer")
    if not (isinstance(N_out, int_types) and N_out >= 0):
        raise ValueError("N_out must be a non-negative integer")
    if not (0 <= sparsity < 1):
        raise ValueError("sparsity must be in [0,1)")
    if T < 0:
        raise ValueError("T must be non-negative")
    if gamma < 2:
        raise ValueError("gamma must be >=2")
    if N_out < N_in or (N_out % N_in) != 0:
        raise ValueError("For bipartite, N_out must be >= N_in and a multiple of N_in")

    # normalise NumPy integers to plain ints
    N_in, N_out = int(N_in), int(N_out)

    # angular distribution
    N = N_in + N_out
    if isinstance(distr, int_types):
        if distr < 0:
            raise ValueError("Invalid distribution specification")
        if distr == 0:
            theta = np.random.uniform(0, 2*pi, N)
            comm = np.zeros(N, dtype=int)
        else:
            C = int(distr)
            mu = np.linspace(0, 2*pi, C, endpoint=False)
            sigma = (2*pi/(6*C))**2          # variance of each component
            gm = GaussianMixture(n_components=C, covariance_type='spherical')
            # Configure the mixture by hand instead of fitting it to data.
            gm.means_ = mu.reshape(-1, 1)
            gm.covariances_ = np.full(C, sigma)
            gm.weights_ = np.ones(C) / C
            samples = gm.sample(N)[0].flatten() % (2*pi)
            # GaussianMixture.sample returns the draws grouped by component.
            # Node order determines the radial coordinate downstream, so
            # shuffle to keep node index independent of community.
            theta = samples[np.random.permutation(N)]
            comm = _nearest_center_on_circle(theta, mu)
    elif isinstance(distr, tuple):
        angles, probs, mu = distr
        theta = np.random.choice(angles, N, p=probs)
        comm = _nearest_center_on_circle(theta, mu)
    else:
        raise ValueError("Invalid distribution specification")

    return calculate_adjacency_and_distance_matrices_bipartite(
        N_in, N_out, sparsity, T, gamma, theta, comm, rewire_mode
    )
