# graph_utils.py  ── a single-responsibility helper
import numpy as np
from sklearn.neighbors import BallTree, NearestNeighbors

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np


def build_radius_graph(
    X: np.ndarray,
    r: float | None = None,
    k: int | None = None,
    metric: str = "euclidean",
    leaf_size: int = 40,
) -> np.ndarray:
    """
    Return an (m, 2) int array of undirected edges built with a Ball-/KD-tree.

    ── Parameters ────────────────────────────────────────────────────────────
    X        : shape (n, d) data matrix
    r, k     : *exactly one* must be supplied
    metric   : any metric accepted by scikit-learn trees
    leaf_size: pass-through to BallTree / KDTree

    ── Returns ───────────────────────────────────────────────────────────────
    edges    : each row (i, j) with i < j
    """
    if (r is None) == (k is None):
        raise ValueError("Specify **either** r or k, not both.")

    nn = NearestNeighbors(
        radius=r, n_neighbors=None if r is not None else k + 1,
        metric=metric, leaf_size=leaf_size, algorithm="ball_tree",
    ).fit(X)

    if r is not None:
        neigh_ind = nn.radius_neighbors(X, return_distance=False)
    else:
        neigh_ind = nn.kneighbors(X, return_distance=False)

    rows, cols = [], []
    for i, neigh in enumerate(neigh_ind):
        neigh = neigh[neigh != i]
        rows.append(np.full(neigh.size, i, np.int32))
        cols.append(neigh.astype(np.int32))
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)
    mask = rows < cols
    return np.stack((rows[mask], cols[mask]), axis=1)


# ---------------------------------------------------------------------------
#  Connectivity-based radius heuristic
# ---------------------------------------------------------------------------
import numpy as np
import networkx as nx
from sklearn.neighbors import NearestNeighbors

# ---------------------------------------------------------------------------
#  Robust search for the smallest radius that yields a connected graph
# ---------------------------------------------------------------------------
import numpy as np
import networkx as nx
from sklearn.neighbors import NearestNeighbors
from typing import Tuple

def _edges_connect_all(n: int, edges: np.ndarray) -> bool:
    """Return True when *edges* join all ``n`` vertices into one component."""
    G = nx.Graph()
    G.add_nodes_from(range(n))          # keep isolated vertices in the graph
    G.add_edges_from(map(tuple, edges))
    return nx.is_connected(G)


def smallest_connected_radius(
    X: np.ndarray,
    *,
    metric: str = "euclidean",
    growth: float = 1.5,          # multiplicative step (>1)
    max_expand: int = 15,         # how many radii to try while growing
    refine_steps: int = 8,        # binary-search iterations
    verbose: bool = False,
) -> Tuple[float, np.ndarray]:
    """
    Search for the smallest radius whose radius-graph on X is connected.

    The search starts from the median neighbour distance, multiplies the
    radius by ``growth`` until the graph becomes connected (at most
    ``max_expand`` radii are tried), then bisects between the last
    disconnected radius and the first connected one.

    Parameters
    ----------
    X            : shape (n, d) data matrix
    metric       : metric passed to the neighbour searches
    growth       : multiplicative step of the expansion phase, must be > 1
    max_expand   : number of radii tried during expansion, must be >= 1
    refine_steps : number of bisection iterations
    verbose      : print progress messages

    Returns
    -------
    r_opt   : minimal radius (up to binary-search precision) that makes the
              radius-graph connected.  If no tried radius connects the graph,
              the last radius that was actually tried.
    edges   : edge list at r_opt  (so you don’t have to rebuild it)

    Raises
    ------
    ValueError : if ``growth <= 1`` or ``max_expand < 1``
    """
    if growth <= 1:
        raise ValueError("growth must be > 1.")
    if max_expand < 1:
        raise ValueError("max_expand must be >= 1.")

    n = X.shape[0]
    if n == 1:
        # A single vertex is trivially connected and has no edges.
        return 0.0, np.empty((0, 2), dtype=np.int32)

    # --- 1. a reasonable seed radius --------------------------------------
    # kneighbors() without a query excludes each point itself, so at most
    # n - 1 neighbours exist; clamp so that two-point inputs still work.
    n_seed = max(1, min(2, n - 1))
    seed_dists = (
        NearestNeighbors(n_neighbors=n_seed, metric=metric)
        .fit(X)
        .kneighbors(return_distance=True)[0][:, -1]
    )
    r_try = float(np.median(seed_dists))
    if r_try <= 0.0:
        # Heavy duplication makes the median zero, and a zero radius cannot
        # grow multiplicatively; fall back to the smallest positive distance.
        positive = seed_dists[seed_dists > 0.0]
        if positive.size:
            r_try = float(positive.min())

    # --- 2. EXPONENTIAL GROWTH until we hit connectivity ------------------
    # r_lo always holds the largest radius known to be disconnected.  It
    # starts at 0 so that refinement can still shrink the radius when the
    # seed itself is already connected.
    r_lo = 0.0
    for _ in range(max_expand):
        edges = build_radius_graph(X, r=r_try, metric=metric)
        if _edges_connect_all(n, edges):
            break                      # found an upper bound
        r_lo = r_try
        r_try = r_try * growth
    else:
        # Never connected.  r_try has already been advanced past the last
        # test, so r_lo is the radius that `edges` was built with.
        if verbose:
            print("Graph never connected; using last radius.")
        return r_lo, edges

    r_hi = r_try
    if verbose:
        print(f"Connected at r = {r_hi:.4g}; refining …")

    # --- 3. BINARY REFINEMENT between r_lo (disconnected) and r_hi (conn.)-
    for _ in range(refine_steps):
        r_mid = 0.5 * (r_lo + r_hi)
        edges_mid = build_radius_graph(X, r=r_mid, metric=metric)
        if _edges_connect_all(n, edges_mid):
            r_hi, edges = r_mid, edges_mid   # keep upper bound
        else:
            r_lo = r_mid                    # still disconnected

    if verbose:
        print(f"✓ Smallest connected radius ≈ {r_hi:.4g}")

    return r_hi, edges

def plot_graph_spring(
    edges: np.ndarray,
    *,
    n_nodes: int | None = None,
    iterations: int = 50,
    k: float | None = None,
    seed: int | None = 0,
    node_size: int = 20,
    edge_alpha: float = 0.35,
    title: str | None = None,
):
    """
    Quick 2-D visualisation of a (possibly large) neighbour graph using
    Fruchterman-Reingold spring forces.

    The figure is shown with ``plt.show()``; nothing is returned.  An empty
    edge list is accepted: the vertices (if any) are drawn without edges.

    Parameters
    ----------
    edges        (m, 2)  int array of undirected edges (i < j)
    n_nodes      total number of vertices (if None, inferred from edges as
                 ``edges.max() + 1``, or 0 when there are no edges)
    iterations   how long to run the force simulation
    k            optimal spring length (default: 1/√n)
    seed         deterministic layout if you like reproducibility
    node_size    size of scatter points
    edge_alpha   transparency of edge lines
    title        optional axes title
    """
    edges = np.asarray(edges)
    if n_nodes is None:
        # max() of a zero-size array raises, so handle "no edges" explicitly.
        n_nodes = int(edges.max()) + 1 if edges.size else 0

    G = nx.Graph()
    G.add_nodes_from(range(n_nodes))
    G.add_edges_from(map(tuple, edges))

    # --- spring layout ------------------------------------------------------
    pos = nx.spring_layout(G, k=k, iterations=iterations, seed=seed)

    # --- draw ---------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(6, 6))
    if G.number_of_edges():
        nx.draw_networkx_edges(G, pos, ax=ax, width=.4, alpha=edge_alpha)
    if pos:
        # Without this guard an empty layout would call scatter() with no
        # x/y arguments and fail.
        coords = np.array(list(pos.values()))
        ax.scatter(coords[:, 0], coords[:, 1], s=node_size, c="k", zorder=10)
    ax.set_axis_off()
    if title:
        ax.set_title(title)
    fig.tight_layout()
    plt.show()