import numpy as np
import torch
import sys
# Module-level side effect: importing this file raises the interpreter's
# recursion limit to 10**7 for the whole process.
# NOTE(review): confirm which callers still need a limit this high.
sys.setrecursionlimit(10**7)

def add_high_corr_edges_optimized(edge_index, cov_matrix, threshold=0.0):
    """
    Vectorized addition of edges whose correlation > threshold.

    Every node pair (i, j) with i < j and cov_matrix[i, j] > threshold that is
    not already present in edge_index (in either direction) is appended in
    both directions, i.e. as (i, j) and (j, i).

    Args:
        edge_index (np.ndarray or torch.Tensor): shape (2, E) array of edges.
        cov_matrix (np.ndarray or torch.Tensor): shape (V, V) matrix; only the
            strict upper triangle is inspected. Its backend is detected
            independently of edge_index, so the two may be mixed.
        threshold (float): pairs with a value strictly above this are added.

    Returns:
        Same backend as edge_index. If nothing is added, the input object is
        returned unchanged. Otherwise a new (2, E + 2K) array/tensor with the
        original edges first, followed by the K new (i, j) pairs and then
        their K reversed counterparts. A tensor result is int64 and lives on
        the device of edge_index.
    """
    is_torch = isinstance(edge_index, torch.Tensor)

    # Convert each input according to its OWN type rather than assuming both
    # share the backend of edge_index.
    edges = edge_index.detach().cpu().numpy() if is_torch else np.asarray(edge_index)
    if isinstance(cov_matrix, torch.Tensor):
        cov = cov_matrix.detach().cpu().numpy()
    else:
        cov = np.asarray(cov_matrix)
    V = cov.shape[0]

    # Node ids are used as indices below, so they must be integers (this also
    # covers an empty float-typed edge array).
    if not np.issubdtype(edges.dtype, np.integer):
        edges = edges.astype(np.intp)

    # 1) Candidate pairs: strict upper triangle above the threshold
    iu, ju = np.triu_indices(V, k=1)
    mask = cov[iu, ju] > threshold
    cand_u, cand_v = iu[mask], ju[mask]

    # 2) Boolean map of edges that already exist, treated as undirected
    exist = np.zeros((V, V), dtype=bool)
    exist[edges[0], edges[1]] = True
    exist[edges[1], edges[0]] = True

    # 3) Keep only candidates that are not present yet
    new_mask = ~exist[cand_u, cand_v]
    new_u, new_v = cand_u[new_mask], cand_v[new_mask]

    if new_u.size == 0:
        return edge_index

    # 4) Emit both directions of every new pair and append to the originals
    new_edges = np.vstack([np.concatenate([new_u, new_v]),
                           np.concatenate([new_v, new_u])])
    if is_torch:
        # Build the result on the caller's device instead of always on CPU.
        added = torch.from_numpy(new_edges).long().to(edge_index.device)
        return torch.cat([edge_index.long(), added], dim=1)
    return np.hstack([edges, new_edges])

def remove_low_corr_edges_optimized(edge_index, cov_matrix, threshold=0.0):
    """
    Drop every edge (u, v) whose matrix entry cov_matrix[u, v] is below threshold.

    The backend is chosen from the type of edge_index: torch tensors are
    filtered with torch indexing, anything else is passed through np.asarray
    and filtered with NumPy.

    Args:
        edge_index (np.ndarray or torch.Tensor): shape (2, E) array of directed edges.
        cov_matrix (np.ndarray or torch.Tensor): shape (V, V) correlation matrix,
            looked up at [u, v] for each edge.
        threshold (float): edges with a value strictly below this are removed.

    Returns:
        same type as edge_index: the columns of edge_index whose value is
        >= threshold, in their original order.
    """
    if not isinstance(edge_index, torch.Tensor):
        # NumPy path
        edge_arr = np.asarray(edge_index)
        corr = np.asarray(cov_matrix)
        src = edge_arr[0].astype(int)
        dst = edge_arr[1].astype(int)
        survives = corr[src, dst] >= threshold
        return edge_arr[:, survives]

    # PyTorch path: index rows are cast to int64 so they can be used as indices
    src = edge_index[0].long()
    dst = edge_index[1].long()
    survives = cov_matrix[src, dst] >= threshold
    return edge_index[:, survives]
    
    
def remove_low_corr_edges_preserve_connectivity_optimized(edge_index, cov_matrix, threshold=0.0):
    """
    Prune edges with cov < threshold, but never remove a bridge.

    The directed edge list is collapsed to unique undirected edges (self-loops
    are ignored for this step), the bridges of that undirected graph are found
    with Tarjan's algorithm, and every undirected edge (a, b), a < b, with
    cov_matrix[a, b] < threshold that is not a bridge is dropped in all of its
    directed occurrences. Self-loops are always kept.

    NOTE: bridges are computed once on the input graph. Removing several
    non-bridge edges at the same time (e.g. every edge of a low-correlation
    cycle) can therefore still split a component.

    Args:
        edge_index (np.ndarray or torch.Tensor): shape (2, E) array of edges.
        cov_matrix (np.ndarray or torch.Tensor): shape (V, V) matrix; read at
            [min(u, v), max(u, v)] for each edge. Its backend is detected
            independently of edge_index.
        threshold (float): non-bridge edges strictly below this are removed.

    Returns:
        Same backend as edge_index: the surviving columns in original order.
        A tensor result is int64 and lives on the device of edge_index.
    """
    is_torch = isinstance(edge_index, torch.Tensor)
    edges = edge_index.detach().cpu().numpy() if is_torch else np.asarray(edge_index)
    if isinstance(cov_matrix, torch.Tensor):
        cov = cov_matrix.detach().cpu().numpy()
    else:
        cov = np.asarray(cov_matrix)
    V = cov.shape[0]

    # Endpoints as plain Python ints, each pair normalised to (small, large)
    sources = edges[0].astype(int).tolist()
    targets = edges[1].astype(int).tolist()
    pairs = [(u, v) if u < v else (v, u) for u, v in zip(sources, targets)]

    # Undirected adjacency list over unique edges, without self-loops
    adj = [[] for _ in range(V)]
    unique = set()
    for a, b in pairs:
        if a == b or (a, b) in unique:
            continue
        unique.add((a, b))
        adj[a].append(b)
        adj[b].append(a)

    bridges = _find_bridges(adj)

    # Undirected edges that are both low-correlation and not a bridge
    removable = {uv for uv in unique if cov[uv] < threshold and uv not in bridges}

    # Mask over the directed edges to keep
    keep = np.array([pair not in removable for pair in pairs], dtype=bool)

    if is_torch:
        # Filter the original tensor so the result stays on its device.
        mask = torch.from_numpy(keep).to(edge_index.device)
        return edge_index[:, mask].long()
    return edges[:, keep]


def _find_bridges(adj):
    """
    Return the bridges of an undirected simple graph as a set of (small, large) pairs.

    adj is an adjacency list without parallel edges or self-loops. This is
    Tarjan's low-link algorithm driven by an explicit stack, so arbitrarily
    long chains do not hit Python's recursion limit.
    """
    n = len(adj)
    disc = [-1] * n   # discovery time, -1 = not visited yet
    low = [0] * n     # lowest discovery time reachable from the subtree
    bridges = set()
    timer = 0

    for root in range(n):
        if disc[root] != -1:
            continue
        disc[root] = low[root] = timer
        timer += 1
        # Each frame is (node, parent, iterator over the node's neighbours);
        # the iterator remembers where the scan of that node was suspended.
        stack = [(root, -1, iter(adj[root]))]
        while stack:
            node, parent, neighbours = stack[-1]
            descended = False
            for nxt in neighbours:
                if disc[nxt] == -1:
                    disc[nxt] = low[nxt] = timer
                    timer += 1
                    stack.append((nxt, node, iter(adj[nxt])))
                    descended = True
                    break
                if nxt != parent:
                    # Back edge to an already discovered vertex
                    low[node] = min(low[node], disc[nxt])
            if descended:
                continue
            # node is fully explored: propagate its low-link to the parent
            stack.pop()
            if stack:
                above = stack[-1][0]
                low[above] = min(low[above], low[node])
                if low[node] > disc[above]:
                    bridges.add((min(above, node), max(above, node)))
    return bridges


def edge_corr_percentile(edge_index, cov_matrix, percentile):
    """
    Return one percentile of the matrix values found at the existing edges.

    For every edge (u, v) the entry cov_matrix[u, v] is collected, and
    np.percentile is evaluated over those values.

    Args:
        edge_index (np.ndarray or torch.Tensor): shape (2, E) array of edges.
        cov_matrix (np.ndarray or torch.Tensor): shape (V, V) correlation matrix.
        percentile (float): percentile to compute, in [0, 100].

    Returns:
        float: the requested percentile of the per-edge values.
    """
    # Bring both inputs to NumPy, each according to its own type
    edge_arr = (
        edge_index.cpu().numpy()
        if isinstance(edge_index, torch.Tensor)
        else np.asarray(edge_index)
    )
    src = edge_arr[0].astype(int)
    dst = edge_arr[1].astype(int)

    corr = (
        cov_matrix.cpu().numpy()
        if isinstance(cov_matrix, torch.Tensor)
        else np.asarray(cov_matrix)
    )

    # One value per edge, then the percentile over them
    per_edge = corr[src, dst]
    result = np.percentile(per_edge, percentile)
    return float(result)


from collections import Counter
def compute_LI_edge(edge_index, labels, eps=1e-12):
    """
    Compute degree-weighted label-informativeness LI_edge for an undirected graph.

    With bar_p(c) the degree-weighted label distribution and p(c1, c2) the
    label-pair distribution over oriented edges, the returned value is

        LI_edge = 2 - sum(p * log(p + eps)) / (sum(bar_p * log(bar_p + eps)) + eps)

    Parameters
    ----------
    edge_index : array-like, shape (2, E)
        Each column is one undirected edge [u, v]; it is counted once for
        each orientation.
    labels : array-like, shape (N,)
        Label of each node (ints, strings, ...). Elements of a NumPy array
        are converted with ``.item()`` when numeric and ``str()`` otherwise;
        elements of other sequences are kept if they are int/float/str/bool
        and converted with ``str()`` otherwise.
    eps : float
        Small constant to avoid log(0).

    Returns
    -------
    LI_edge : float
        The normalized label-informativeness.

    Raises
    ------
    ValueError
        If edge_index is not of shape (2, E), if there are no nodes or no
        edges, or if a node index is outside [0, N).

    Notes
    -----
    If every node with nonzero degree carries the same label, both sums are
    on the order of ``eps`` and the returned value is an artifact of ``eps``
    rather than a meaningful score.
    """
    edge_index = np.asarray(edge_index, dtype=int)
    # Normalise labels to hashable native Python values
    if isinstance(labels, np.ndarray):
        labels_list = [x.item() if isinstance(x, np.number) else str(x) for x in labels]
    else:
        labels_list = [x if isinstance(x, (int, float, str, bool)) else str(x) for x in labels]

    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError("edge_index must be shape (2, E)")
    E = edge_index.shape[1]
    N = len(labels_list)
    if E == 0 or N == 0:
        raise ValueError("Graph must have at least one node and one edge.")
    if edge_index.min() < 0 or edge_index.max() >= N:
        raise ValueError("Node indices in edge_index out of bounds.")

    # Map every distinct label to a dense integer code in a single pass, so
    # the per-class and per-pair statistics can be computed by array ops
    # instead of one scan over all nodes per class.
    code_of = {}
    codes = np.array(
        [code_of.setdefault(lbl, len(code_of)) for lbl in labels_list],
        dtype=np.intp,
    )
    n_classes = len(code_of)

    # 1) Node degrees (each undirected edge contributes to both endpoints)
    degrees = np.bincount(edge_index.ravel(), minlength=N)

    # 2) Degree-weighted marginals bar_p(c) = sum_{v:y_v=c} d(v) / (2E)
    bar_p = np.bincount(codes, weights=degrees, minlength=n_classes) / (2 * E)

    # 3) Joint over oriented edges: each {u, v} counts as (u->v) and (v->u).
    #    A label pair (c1, c2) is encoded as the single integer c1 * C + c2.
    src_codes = codes[edge_index[0]]
    dst_codes = codes[edge_index[1]]
    pair_ids = np.concatenate([src_codes * n_classes + dst_codes,
                               dst_codes * n_classes + src_codes])
    _, pair_counts = np.unique(pair_ids, return_counts=True)
    p_joint = pair_counts / (2 * E)

    # 4) Sums of p * log p
    sum_barp_log_barp = np.sum(bar_p * np.log(bar_p + eps))
    sum_pjoint_log_pjoint = np.sum(p_joint * np.log(p_joint + eps))

    # 5) LI_edge = 2 - [sum p log p] / [sum bar_p log bar_p]
    LI_edge = 2 - (sum_pjoint_log_pjoint / (sum_barp_log_barp + eps))
    return LI_edge