import math
import networkx as nx
from collections import deque
from typing import Dict, List, Tuple, Iterable, Optional

###############################################################################
# Utilities
###############################################################################

def _is_complete_graph(G: nx.Graph) -> bool:
    """
    Tell whether G has exactly as many edges as a complete simple graph.

    The test is purely a count: a simple undirected graph on n vertices is
    complete exactly when it has n * (n - 1) / 2 edges.  Self-loops or
    parallel edges are not accounted for, so the answer is only meaningful
    for simple undirected graphs.
    """
    order = G.number_of_nodes()
    edges_if_complete = order * (order - 1) // 2
    return G.number_of_edges() == edges_if_complete

def _my_subgraph(G: nx.Graph,
                 vertices: Iterable,
                 relabel: bool = False,
                 return_map: bool = False):
    """
    Build the subgraph of G induced by ``vertices``, copying adjacency only.

    Node and edge attributes are not copied.  The result is directed when G
    is directed and undirected otherwise.

    Parameters:
        G: source graph.
        vertices: vertices to keep; consumed once and stored as a list.
        relabel: when True, the kept vertices are renamed 0..k-1 following
            their order in ``vertices``.
        return_map: when True *and* ``relabel`` is True, the
            original-vertex -> new-label mapping is returned as well.

    Returns:
        ``H`` alone, or ``(H, mapping)`` when both ``relabel`` and
        ``return_map`` are True.

    Raises:
        ValueError: if G is not a networkx graph.
    """
    if not isinstance(G, (nx.Graph, nx.DiGraph)):
        raise ValueError("the input parameter must be a Graph")

    # nx.DiGraph is a subclass of nx.Graph, so an isinstance() test against
    # nx.Graph cannot distinguish the two; ask the graph itself instead.
    H = nx.DiGraph() if G.is_directed() else nx.Graph()
    vertices = list(vertices)
    with_map = relabel and return_map
    if not vertices:
        return (H, {}) if with_map else H

    if relabel:
        mapping = {v: i for i, v in enumerate(vertices)}
    else:
        mapping = {v: v for v in vertices}

    H.add_nodes_from(mapping[v] for v in vertices)

    # Copy only the edges whose two endpoints are both kept.  The keys of
    # ``mapping`` are exactly the kept vertices, so it doubles as the
    # membership set.
    for u in vertices:
        for v in G.neighbors(u):
            if v in mapping:
                H.add_edge(mapping[u], mapping[v])

    return (H, mapping) if with_map else H


###############################################################################
# Distances and far-apart pairs (faithful to the Cython reference)
###############################################################################

def _distances_and_far_apart_pairs(G: nx.Graph,
                                   int_to_vertex: List) -> Tuple[List[List[int]], List[List[int]], int]:
    """
    Run one BFS per vertex and derive distances, far-apart flags and diameter.

    Parameters:
        G: graph whose vertices are exactly the entries of ``int_to_vertex``.
        int_to_vertex: index -> vertex list; it fixes the row/column order of
            the returned matrices.

    Returns:
        (distances, far_apart, D) where
          - distances[i][j] is the unweighted shortest-path length between
            vertices i and j,
          - far_apart[i][j] is a 0/1 flag.  It starts at 1 off the diagonal
            and 0 on it; during the BFS from source s, whenever an edge
            (v, u) satisfies dist[u] == dist[v] + 1 the entries (s, v) and
            (v, s) are cleared,
          - D is the largest entry of ``distances``.

    Raises:
        ValueError: if some vertex is not reached from some source.
    """
    size = len(int_to_vertex)
    position = {vertex: k for k, vertex in enumerate(int_to_vertex)}

    unreached = 10**9  # sentinel distance for "not visited yet"
    distances = [[unreached] * size for _ in range(size)]
    far_apart = [[0 if r == c else 1 for c in range(size)] for r in range(size)]

    for src_pos, src in enumerate(int_to_vertex):
        row = distances[src_pos]
        row[src_pos] = 0
        visited = {src}
        frontier = deque((src,))

        while frontier:
            cur = frontier.popleft()
            cur_pos = position[cur]
            next_level = row[cur_pos] + 1
            for nb in G.neighbors(cur):
                nb_pos = position[nb]
                if nb not in visited:
                    visited.add(nb)
                    row[nb_pos] = next_level
                    frontier.append(nb)
                # cur has a neighbour one step farther from src, so the
                # pair (src, cur) is not far-apart.
                if row[nb_pos] == next_level:
                    far_apart[src_pos][cur_pos] = 0
                    far_apart[cur_pos][src_pos] = 0

    # Any entry still holding the sentinel means the graph is disconnected.
    if any(d >= unreached for row in distances for d in row):
        raise ValueError("the input graph must be connected")

    diameter = max((max(row) for row in distances), default=0)
    return distances, far_apart, diameter


###############################################################################
# Counting sort of unordered pairs by an integer value (distance),
# with optional include-mask (far_apart)
###############################################################################

def _sort_pairs_by_value(N: int,
                         D: int,
                         values: List[List[int]],
                         to_include: Optional[List[List[int]]]
                         ) -> Tuple[List[Tuple[int,int]], List[int]]:
    """
    计数排序返回所有 {i,j} (i<j) ，按 values[i][j] 非降序排列。
    若 to_include 非空，仅包含 to_include[i][j] 为真者。
    返回：
      sorted_pairs: list of (i,j) 按值升序
      nb_pairs_of_length: 长度为 D+1 的数组，第 k 项为 value==k 的对数
    """
    nb_pairs_of_length = [0] * (D + 1)
    total = 0
    if to_include is None:
        for i in range(N):
            for j in range(i + 1, N):
                nb_pairs_of_length[values[i][j]] += 1
                total += 1
    else:
        for i in range(N):
            inc_row = to_include[i]
            for j in range(i + 1, N):
                if inc_row[j]:
                    nb_pairs_of_length[values[i][j]] += 1
                    total += 1

    # 前缀和位置
    offsets = [0] * (D + 1)
    run = 0
    for k in range(D + 1):
        offsets[k] = run
        run += nb_pairs_of_length[k]

    sorted_pairs = [None] * total  # type: ignore

    # 临时计数器（已经填了多少）
    placed = [0] * (D + 1)

    if to_include is None:
        for i in range(N):
            for j in range(i + 1, N):
                k = values[i][j]
                pos = offsets[k] + placed[k]
                sorted_pairs[pos] = (i, j)
                placed[k] += 1
    else:
        for i in range(N):
            inc_row = to_include[i]
            for j in range(i + 1, N):
                if inc_row[j]:
                    k = values[i][j]
                    pos = offsets[k] + placed[k]
                    sorted_pairs[pos] = (i, j)
                    placed[k] += 1

    return sorted_pairs, nb_pairs_of_length


###############################################################################
# BCCM main
###############################################################################

def _hyperbolicity_BCCM_core(N: int,
                             distances: List[List[int]],
                             far_apart_pairs: Optional[List[List[int]]],
                             D: int,
                             h_LB: int,
                             approximation_factor: float,
                             additive_gap: float,
                             verbose: bool = False
                             ) -> Tuple[int, List[int], int]:
    """
    Run the BCCM pair scan on a precomputed distance matrix.

    Pairs (a, b) are visited from the largest distance down.  For each pair
    the vertices c that pass the "acceptable" / "valuable" filters are
    collected, and every quadruple (a, b, c, d) with c valuable, d a
    previously seen mate of c and d acceptable is evaluated.

    All hyperbolicity values handled here are *doubled* (h = 2 * delta);
    the public entry point divides by 2 before returning.

    Parameters:
        N: number of vertices; vertices are the indices 0..N-1.
        distances: N x N matrix of pairwise distances.
        far_apart_pairs: optional N x N 0/1 mask; when given, only pairs
            with a truthy entry are scanned, when None every pair i < j is.
        D: size bound for the counting sorts (D + 1 buckets); the caller
            passes the diameter.
        h_LB: initial lower bound on the doubled hyperbolicity; values below
            0 are replaced by 0.
        approximation_factor: the scan stops once
            h_UB <= h * approximation_factor (the factor is capped at D).
        additive_gap: the scan stops once h_UB - h <= additive_gap (capped
            at D; expressed on the doubled scale).
        verbose: when True, print statistics and each improved lower bound.

    Returns:
        (h, certificate, h_UB): the best doubled value found, the four
        vertex indices [a, b, c, d] achieving it, and the doubled upper
        bound when the scan stopped.  When no certificate was found the
        result is (-1, [], h_UB).
    """
    # Farness (sum of distances) and eccentricity (largest distance) of
    # every vertex.
    farness = [0] * N
    ecc = [0] * N
    for a in range(N):
        row = distances[a]
        s = 0
        mx = 0
        for b in range(N):
            db = row[b]
            s += db
            if db > mx:
                mx = db
        farness[a] = s
        ecc[a] = mx
    # The vertex of minimum farness is used as the "central" reference.
    central = min(range(N), key=lambda x: farness[x])
    dist_central = distances[central]

    # mates_decr_order_value[a]: all vertices b, ordered by decreasing
    # (ecc[b] - distances[a][b]).
    mates_decr_order_value = [[0]*N for _ in range(N)]
    # Counting sort with keys in [0..D].
    for a in range(N):
        row = distances[a]
        counts = [0]*(D+1)
        vals = [0]*N
        for b in range(N):
            v = ecc[b] - row[b]
            # NOTE(review): v is used as a bucket index, so this assumes
            # 0 <= v <= D, i.e. distances is symmetric (ecc[b] >= row[b])
            # and no eccentricity exceeds D.
            vals[b] = v
            counts[v] += 1
        # Descending order -> prefix sums taken from the top bucket down.
        cum = [0]*(D+1)
        # cum[k] = number of keys strictly greater than k
        bigger = 0
        for k in range(D, -1, -1):
            cum[k] = bigger
            bigger += counts[k]
        # Scatter every b into its slot.
        out = mates_decr_order_value[a]
        fill = [0]*(D+1)
        for b in range(N):
            v = vals[b]
            pos = cum[v] + fill[v]
            out[pos] = b
            fill[v] += 1

    # Sort the pairs by increasing distance; they are scanned in reverse
    # order below.
    to_include = far_apart_pairs  # None means "include every pair"
    sorted_pairs, nb_pairs_of_length = _sort_pairs_by_value(
        N, D, distances, to_include
    )
    nb_p = len(sorted_pairs)

    if verbose:
        print(f"Current 2-connected component has {N} vertices and diameter {D}")
        if to_include is None:
            print(f"Number of pairs: {nb_p}")
        else:
            from math import comb
            print(f"Number of far-apart pairs: {nb_p}\t({comb(N,2)} pairs in total)")
        print("Repartition by distance:",
              [(k, nb_pairs_of_length[k]) for k in range(1, D+1) if nb_pairs_of_length[k] > 0])

    # Main loop state.
    # Best doubled hyperbolicity found so far, starting from the lower
    # bound clamped at 0.
    h = max(0, h_LB)
    certificate: List[int] = []  # indices of the 4 certificate vertices
    nq = 0                       # number of candidate quadruples evaluated

    # mate[c] lists the partners of c among the pairs scanned so far.
    mate = [[] for _ in range(N)]
    # Scratch storage for the acceptable / valuable vertices of the
    # current pair: acc_bool is a membership flag, acc and val are
    # preallocated lists filled up to n_acc / n_val.
    acc_bool = [0]*N
    acc = [0]*N
    val = [0]*N

    # Upper bound (doubled); refreshed from the current pair's distance.
    h_UB = h

    # Cap the approximation parameters at D.
    approximation_factor = min(approximation_factor, float(D))
    additive_gap = min(additive_gap, float(D))

    # Start from the pairs at maximum distance, i.e. walk the ascending
    # array backwards.
    for idx in range(nb_p - 1, -1, -1):
        a, b = sorted_pairs[idx]

        # Without loss of generality make a the endpoint whose farness is
        # at least that of b.
        if farness[a] < farness[b]:
            a, b = b, a

        dist_a = distances[a]
        dist_b = distances[b]
        # Upper bound induced by the current pair (a, b).
        h_UB = distances[a][b]

        # No further improvement is possible.
        if h_UB <= h:
            h_UB = h
            break

        # Early exit once the requested approximation is reached.
        if (h_UB <= h * approximation_factor) or (h_UB - h <= additive_gap):
            break

        # Record (a, b) in the mate lists.
        mate[a].append(b)
        mate[b].append(a)

        # Collect the acceptable and valuable vertices for this pair.
        n_acc = 0
        n_val = 0
        hplusone = h + 1
        condacc = 3 * hplusone - 2 * h_UB

        order_c = mates_decr_order_value[a]
        for c in order_c:
            if len(mate[c]) == 0:
                # c has no mate yet and is skipped.  order_c is sorted by
                # decreasing (ecc[c] - dist_a[c]), so once that key drops
                # below the condacc threshold every later c fails too and
                # the scan can stop.
                if 2 * (ecc[c] - dist_a[c]) < condacc:
                    break
                else:
                    # Above the threshold but mate-less: move to the next c.
                    continue

            if 2 * (ecc[c] - dist_a[c]) >= condacc and 2 * (ecc[c] - dist_b[c]) >= condacc:
                if 2 * dist_a[c] >= hplusone and 2 * dist_b[c] >= hplusone:
                    if 2 * ecc[c] >= 2 * hplusone - h_UB + dist_a[c] + dist_b[c]:
                        # c is acceptable.
                        acc_bool[c] = 1
                        acc[n_acc] = c
                        n_acc += 1
                        if 2 * dist_central[c] + h_UB - h > dist_a[c] + dist_b[c]:
                            # c is also valuable.
                            val[n_val] = c
                            n_val += 1
            else:
                # Stop as soon as the a-side threshold fails (same sorted
                # order argument as above); a failure on the b side only
                # skips this c.
                if 2 * (ecc[c] - dist_a[c]) < condacc:
                    break

        # For every valuable c and every mate d of c that is acceptable,
        # evaluate the quadruple and try to raise the lower bound.
        for i in range(n_val):
            c = val[i]
            for d in mate[c]:
                if acc_bool[d]:
                    nq += 1
                    S1 = h_UB + distances[c][d]
                    S2 = dist_a[c] + dist_b[d]
                    S3 = dist_a[d] + dist_b[c]
                    hh = S1 - max(S2, S3)
                    if hh > h or not certificate:
                        # When hh is not positive, only accept the
                        # quadruple if neither c nor d coincides with a
                        # or b.
                        if hh > 0 or (a != c and a != d and b != c and b != d):
                            h = hh
                            certificate = [a, b, c, d]
                            if verbose:
                                print("New lower bound:", h/2.0)

        # Clear the acceptable flags set for this pair.
        for t in range(n_acc):
            acc_bool[acc[t]] = 0

    if verbose:
        print("Visited 4-tuples:", nq)

    if not certificate:
        return -1, [], h_UB
    return h, certificate, h_UB


###############################################################################
# Public API (main entry)
###############################################################################

def hyperbolicity_func(G: nx.Graph,
                  algorithm: str = 'BCCM',
                  approximation_factor: Optional[float] = None,
                  additive_gap: Optional[float] = None,
                  verbose: bool = False
                  ) -> Tuple[float, List, float]:
    """
    Compute graph hyperbolicity (Gromov 4-point definition) using BCCM (2015).

    Parameters:
        G: connected networkx graph.
        algorithm: only 'BCCM' is supported by this port.
        approximation_factor: multiplicative stopping tolerance (>= 1.0);
            None means 1.0.
        additive_gap: additive stopping tolerance (>= 0.0), on the same
            scale as the returned delta; None means 0.0.
        verbose: when True, print progress information.

    Returns:
        (delta, certificate_nodes, delta_UB)
        - delta and delta_UB are the lower and upper bounds found; they are
          the internal doubled values h and h_UB divided by 2.
        - certificate_nodes lists vertices of G: the four vertices achieving
          delta when the search produced a certificate, up to four
          arbitrary vertices on the trivial shortcuts, or [] when none was
          found.
        Graphs with at most three vertices yield (0.0, all vertices, 0.0).

    Raises:
        ValueError: if G is not a networkx graph, the algorithm is not
            'BCCM', a tolerance is out of range, or G is not connected.
        ``nx.is_connected`` may raise its own exceptions for inputs it does
        not support.
    """
    if algorithm == "CCL+":
        algorithm = "CCL+FA"

    if not isinstance(G, (nx.Graph, nx.DiGraph)):
        raise ValueError("the input parameter must be a Graph")

    if algorithm != 'BCCM':
        raise ValueError(f"algorithm '{algorithm}' not yet implemented in this Python port (only 'BCCM' is supported)")

    if approximation_factor is None:
        approximation_factor = 1.0
    elif approximation_factor < 1.0:
        raise ValueError("the approximation factor must be >= 1.0")

    if additive_gap is None:
        additive_gap = 0.0
    elif additive_gap < 0.0:
        raise ValueError("the additive gap must be a real positive number")

    if not nx.is_connected(G):
        raise ValueError("the input Graph must be connected")

    n = G.number_of_nodes()
    m = G.number_of_edges()

    # Shortcuts for trivial inputs.
    nodes = list(G.nodes())
    if n <= 3:
        # A connected simple graph on at most three vertices is a tree or a
        # complete graph, both of which are reported as 0.0 below; return
        # the same value here instead of None so the declared float return
        # type always holds.
        return 0.0, nodes, 0.0
    if m == n - 1:  # tree
        return 0.0, nodes[:4], 0.0
    if _is_complete_graph(G):
        return 0.0, nodes[:4], 0.0

    # Several biconnected components: solve each one and keep the maximum.
    blocks = list(nx.biconnected_components(G))
    if len(blocks) > 1:
        if verbose:
            sizes = [len(b) for b in blocks]
            print(f"Graph with {len(blocks)} blocks")
            hist = {}
            for s in sizes:
                hist[s] = hist.get(s, 0) + 1
            print("Blocks size distribution:", hist)

        # best_h and best_ub are kept on the doubled scale (2 * delta).
        best_h = 0
        best_cert_nodes: List = []
        best_ub = 0
        for V in blocks:
            # Skip blocks too small to beat the current best value.
            if len(V) <= max(3, 2 * best_h):
                continue
            H = _my_subgraph(G, V, relabel=False)
            delta, cert_nodes, ub = hyperbolicity_func(
                H, algorithm='BCCM',
                approximation_factor=approximation_factor,
                additive_gap=additive_gap,
                verbose=verbose
            )
            doubled = delta * 2
            if doubled > best_h or (math.isclose(doubled, best_h) and not best_cert_nodes):
                best_h = delta*2.0
                best_cert_nodes = cert_nodes
            best_ub = max(best_ub, ub*2.0)
        return best_h/2.0, best_cert_nodes, best_ub/2.0

    # From here on the graph is a single block, not complete, with n >= 4.
    int_to_vertex = nodes[:]                 # index -> vertex
    # Distances, far-apart mask and diameter.
    distances, far_apart, D = _distances_and_far_apart_pairs(G, int_to_vertex)

    # Run the BCCM search itself.
    h, cert_idx, h_UB = _hyperbolicity_BCCM_core(
        N=n,
        distances=distances,
        far_apart_pairs=far_apart,  # BCCM scans far-apart pairs only
        D=D,
        h_LB=0,
        approximation_factor=float(approximation_factor),
        additive_gap=float(additive_gap)*2.0,  # core works on the doubled scale
        verbose=verbose
    )

    if not cert_idx:
        # The core signals "no certificate" with h == -1; report 0.0 then.
        return 0.0 if h < 0 else (h/2.0), [], h_UB/2.0
    certificate_nodes = [int_to_vertex[i] for i in cert_idx]
    return h/2.0, certificate_nodes, h_UB/2.0


###############################################################################
# (Optional) quick sanity check: run the algorithm on a few small graphs
###############################################################################
if __name__ == "__main__":
    # Demo run on three small graphs.  The comments and printed labels
    # name the graph that is actually built.

    # 1) 2 x 10 grid (ladder)
    G = nx.grid_2d_graph(2, 10)
    d, cert, ub = hyperbolicity_func(G, algorithm='BCCM', verbose=False)
    print("Grid(2x10) delta:", d, "UB:", ub, "cert:", cert)

    # 2) Tree: handled by the tree shortcut, so delta is 0
    T = nx.balanced_tree(r=2, h=3)
    d, cert, ub = hyperbolicity_func(T, algorithm='BCCM')
    print("Tree delta:", d, "UB:", ub, "cert:", cert)

    # 3) Cycle on 10 vertices: a single block, so the BCCM search itself runs
    C = nx.cycle_graph(10)
    d, cert, ub = hyperbolicity_func(C, algorithm='BCCM')
    print("C10 delta:", d, "UB:", ub, "cert:", cert)
