import json
import math
import os
from collections import deque
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from scipy.spatial.distance import pdist

from pca_diffusion_plots_w_helpers import compute_pca_coords
from report import step_size


def _q_summary(x: np.ndarray) -> Dict[str, Any]:
    x = np.asarray(x, float)
    if x.size == 0:
        return dict(n=0, mean=None, std=None, q=None)
    qv = np.quantile(x, [0.0, 0.25, 0.5, 0.75, 1.0]).tolist()
    return dict(n=int(x.size), mean=float(x.mean()), std=float(x.std()), q=[float(t) for t in qv])


def _extract_stat_scalar(d: dict, path: tuple[str, ...]) -> float | None:
    cur = d
    for k in path:
        if cur is None or (not isinstance(cur, dict)) or (k not in cur):
            return None
        cur = cur[k]
    return None if cur is None else float(cur)


def _state_from_axis_index(idx: np.ndarray, p: int, g: int) -> np.ndarray:
    idx = np.asarray(idx, int)
    st = np.empty_like(idx, dtype=int)
    rot = (idx < p)
    ref = ~rot
    st[rot] = idx[rot] % g
    st[ref] = g + ((idx[ref] - p) % g)
    return st


def _cluster_id(a_state: np.ndarray, b_state: np.ndarray, g: int) -> np.ndarray:
    n = 2 * int(g)
    return np.asarray(a_state, int) * n + np.asarray(b_state, int)


def _decode_cluster(cid: int, g: int) -> Tuple[int, int]:
    n = 2 * int(g)
    return int(cid // n), int(cid % n)


def _resid(state: int, g: int) -> int:
    state = int(state)
    return (state - g) if state >= g else state


def _is_ref_state(state: int, g: int) -> bool:
    return int(state) >= int(g)


def _enc_dec_for_g(g: int):
    n = 2 * int(g)

    def enc(a: int, b: int) -> int:
        return int(a * n + b)

    def dec(u: int) -> tuple[int, int]:
        return int(u // n), int(u % n)

    return n, enc, dec


def phi_from_model(g: int, k: int, model: str) -> np.ndarray:
    """Build the map ``phi`` on ``Z_g`` from a named parametric model.

    Models:
      * ``"k_minus_x"``: ``phi(x) = (k - x) mod g``
      * ``"shift"``:     ``phi(x) = (x + k) mod g``

    Args:
        g: Modulus; must be positive.
        k: Model offset (reduced modulo ``g``).
        model: ``"k_minus_x"`` or ``"shift"``.

    Returns:
        Integer array of length ``g`` with ``phi[x]`` for ``x = 0..g-1``.

    Raises:
        ValueError: if ``g`` is not positive or ``model`` is unknown.
    """
    g = int(g)
    # Guard before the modulo: ``k % 0`` would raise ZeroDivisionError, and a
    # negative g would silently produce an empty map.
    if g <= 0:
        raise ValueError("g must be positive")
    k = int(k) % g
    x = np.arange(g, dtype=int)
    if model == "k_minus_x":
        return (k - x) % g
    if model == "shift":
        return (x + k) % g
    raise ValueError("model must be 'k_minus_x' or 'shift'")


def invert_phi(phi: np.ndarray) -> np.ndarray:
    """Return the inverse permutation of ``phi`` on ``[0, g)``, where ``g = len(phi)``.

    Raises:
        ValueError: if an entry falls outside ``[0, g)``, if two entries
            collide, or if some value in ``[0, g)`` is never produced.
    """
    mapping = np.asarray(phi, int)
    size = mapping.size
    inverse = np.full(size, -1, dtype=int)
    for src in range(size):
        dst = int(mapping[src])
        if dst < 0 or dst >= size:
            raise ValueError("phi not in [0,g)")
        if inverse[dst] != -1:
            raise ValueError("phi not bijection (collision)")
        inverse[dst] = src
    if (inverse < 0).any():
        raise ValueError("phi not bijection (missing values)")
    return inverse


def _try_invert_phi(phi: np.ndarray) -> tuple[Optional[np.ndarray], Optional[str]]:
    """Invert ``phi`` without raising.

    Returns ``(inverse, None)`` on success, or ``(None, "<ExcType>: <message>")``
    when ``invert_phi`` raised any exception.
    """
    try:
        inverse = invert_phi(phi)
    except Exception as exc:
        return None, f"{type(exc).__name__}: {exc}"
    return inverse, None


def verify_phi_equivariance(phi: np.ndarray, g: int, d: int) -> Dict[str, Any]:
    """Score how well ``phi`` intertwines the rotation ``x -> x + d`` on ``Z_g``.

    Two hypotheses are scored:
      * sign +1: ``phi(x + d) == phi(x) + d``  (implied ``k = phi(x) - x``)
      * sign -1: ``phi(x + d) == phi(x) - d``  (implied ``k = phi(x) + x``)

    For each sign, ``equivariance_consistency`` is the fraction of ``x`` for
    which the relation holds, ``k_hat`` is the most frequent implied ``k`` and
    ``k_consistency`` is the fraction of ``x`` agreeing with ``k_hat``. The
    sign with the higher consistency wins; near-ties (within 1e-12) are broken
    by ``k_consistency``, favouring +1 when still tied.

    Args:
        phi: Integer map of length ``g``.
        g: Modulus; must be positive.
        d: Rotation step (reduced modulo ``g``).

    Returns:
        Dict with ``best_sign``, ``equivariance_consistency``, ``k_hat`` and
        ``k_consistency`` for the winning sign.

    Raises:
        ValueError: if ``g`` is not positive or ``phi`` does not have length ``g``.
    """
    phi = np.asarray(phi, int)
    g = int(g)

    # Validate before reducing d modulo g: ``d % 0`` would raise
    # ZeroDivisionError and make the ``g <= 0`` guard unreachable.
    if g <= 0:
        raise ValueError("g must be positive")
    if phi.size != g:
        raise ValueError("phi must have length g")
    d = int(d) % g

    x = np.arange(g, dtype=int)
    xd = (x + d) % g
    lhs = phi[xd]

    def score(sign: int):
        rhs = (phi[x] + sign * d) % g
        cons = float(np.mean(lhs == rhs))
        # Offset k implied by each x under the hypothesis for this sign.
        if sign == +1:
            k_vals = (phi[x] - x) % g
        else:
            k_vals = (phi[x] + x) % g
        vals, cnts = np.unique(k_vals, return_counts=True)
        k_hat = int(vals[np.argmax(cnts)])
        k_cons = float(np.max(cnts) / g)
        return cons, k_hat, k_cons

    cons_p, k_p, kcons_p = score(+1)
    cons_m, k_m, kcons_m = score(-1)

    if cons_p > cons_m or (abs(cons_p - cons_m) < 1e-12 and kcons_p >= kcons_m):
        return dict(best_sign=+1, equivariance_consistency=cons_p, k_hat=k_p, k_consistency=kcons_p)
    return dict(best_sign=-1, equivariance_consistency=cons_m, k_hat=k_m, k_consistency=kcons_m)


def mean_pairwise_distance(X: np.ndarray) -> float:
    """Return the mean Euclidean distance over all unordered pairs of rows of ``X``.

    Args:
        X: Points as rows, shape ``(n, dim)``; may be ``None``.

    Returns:
        The mean pairwise distance, or 0.0 when ``X`` is ``None``, empty, a
        scalar, or has fewer than two rows.
    """
    # Check None before converting: np.asarray(None, float) is a 0-d NaN array,
    # which would make the None test dead and crash on ``X.shape[0]``.
    if X is None:
        return 0.0
    X = np.asarray(X, float)
    if X.ndim == 0 or X.size == 0 or X.shape[0] < 2:
        return 0.0
    return float(pdist(X, metric="euclidean").mean())


def global_mean_distance(
    X: np.ndarray,
    *,
    max_points: int = 3000,
    seed: int = 0,
) -> float:
    """Estimate the mean pairwise distance of the rows of ``X``.

    When ``X`` has more than ``max_points`` rows, a uniform random subset of
    ``max_points`` rows (drawn without replacement with ``seed``) is used to
    bound the quadratic cost. Fewer than two rows yields 0.0.
    """
    pts = np.asarray(X, float)
    n_pts = int(pts.shape[0])
    if n_pts < 2:
        return 0.0
    cap = int(max_points)
    if n_pts > cap:
        gen = np.random.default_rng(int(seed))
        chosen = gen.choice(n_pts, size=cap, replace=False)
        pts = pts[chosen]
    return mean_pairwise_distance(pts)


def miniball_badoiu_clarkson(
    P: np.ndarray, eps: float = 1e-2, max_iter: int | None = None, seed: int = 0
):
    P = np.asarray(P, float)
    n, d = P.shape
    if n == 0:
        return np.zeros(d), 0.0
    if n == 1:
        return P[0].copy(), 0.0

    rng = np.random.default_rng(seed)
    c = P[rng.integers(0, n)].copy()

    if max_iter is None:
        max_iter = int(np.ceil(1.0 / (eps * eps)))

    for t in range(1, max_iter + 1):
        dif = P - c[None, :]
        dist2 = np.einsum("ij,ij->i", dif, dif)
        i = int(np.argmax(dist2))
        p_far = P[i]
        c = c + (1.0 / (t + 1.0)) * (p_far - c)

    r = float(np.max(np.linalg.norm(P - c[None, :], axis=1)))
    return c, r


def cluster_centers_radii_multiq(
    Z: np.ndarray,
    cluster_ids: np.ndarray,
    qs=(0.95, 0.98, 1.0),
    *,
    center_mode: str = "miniball",
    ball_eps: float = 1e-2,
    ball_max_iter: int | None = None,
    ball_seed: int = 0,
) -> Dict[str, Any]:
    Z = np.asarray(Z, float)
    cluster_ids = np.asarray(cluster_ids, int)
    uniq = np.unique(cluster_ids)

    centers: dict[int, np.ndarray] = {}
    r_q: dict[float, dict[int, float]] = {float(q): {} for q in qs}
    sizes: dict[int, int] = {}

    for cid in uniq:
        idx = np.where(cluster_ids == cid)[0]
        P = Z[idx]
        sizes[int(cid)] = int(P.shape[0])

        if P.shape[0] == 0:
            centers[int(cid)] = np.zeros(Z.shape[1], dtype=float)
            for q in qs:
                r_q[float(q)][int(cid)] = 0.0
            continue

        if P.shape[0] == 1:
            c = P[0].copy()
            dists = np.zeros(1, dtype=float)
            r_full = 0.0
        else:
            if center_mode == "miniball":
                c, _r_full = miniball_badoiu_clarkson(
                    P, eps=float(ball_eps), max_iter=ball_max_iter, seed=int(ball_seed)
                )
                c = np.asarray(c, float)
                dists = np.linalg.norm(P - c[None, :], axis=1)
                r_full = float(np.max(dists))
            elif center_mode == "mean":
                c = P.mean(axis=0)
                dists = np.linalg.norm(P - c[None, :], axis=1)
                r_full = float(np.max(dists))
            else:
                raise ValueError("center_mode must be 'miniball' or 'mean'")

        centers[int(cid)] = c

        for q in qs:
            qq = float(q)
            if dists.size == 0:
                r_q[qq][int(cid)] = 0.0
            elif abs(qq - 1.0) < 1e-12:
                r_q[qq][int(cid)] = float(r_full)
            else:
                r_q[qq][int(cid)] = float(np.quantile(dists, qq))

    return dict(
        centers=centers,
        r_q=r_q,
        sizes=sizes,
        present=set(map(int, uniq.tolist())),
        n_present=int(len(uniq)),
        qs=[float(q) for q in qs],
        center_mode=str(center_mode),
        ball_eps=(float(ball_eps) if center_mode == "miniball" else None),
        ball_max_iter=(int(ball_max_iter) if (ball_max_iter is not None and center_mode == "miniball") else None),
        ball_seed=(int(ball_seed) if center_mode == "miniball" else None),
    )


def edge_stats_on_centers(
    centers: Dict[int, np.ndarray],
    r_q: Dict[int, float],
    edges: List[Tuple[int, int]],
    *,
    eps: float = 1e-12,
    sep_thresholds=(5.0, 10.0),
    normalize_by: float | None = None,
    eps_rel: float = 0.0,
) -> Dict[str, Any]:
    """Summarize center separation along graph edges.

    Each edge ``(u, v)`` whose endpoints both have a center contributes the
    center distance ``dist``, the radius sum ``r = r_q[u] + r_q[v]`` (missing
    radii count as 0), the margin ``dist - r``, and the ratios
    ``sep = dist / (r + floor)`` and ``rel_margin = margin / (r + floor)``.
    ``floor`` is ``eps``, raised to ``eps_rel * normalize_by`` when both are
    positive. Edges with a missing endpoint are skipped.

    Returns:
        Dict of counts, ``_q_summary`` blocks, minima, ``frac_sep_ge_<T>``
        fractions for each threshold, and distance/margin summaries divided by
        ``normalize_by`` (``None`` when it is absent/non-positive or no edge
        was used).
    """
    has_scale = normalize_by is not None and float(normalize_by) > 0
    floor = float(eps)
    if has_scale and float(eps_rel) > 0:
        floor = max(floor, float(eps_rel) * float(normalize_by))

    records = []
    for u, v in edges:
        if u not in centers or v not in centers:
            continue
        gap = float(np.linalg.norm(centers[u] - centers[v]))
        reach = float(r_q.get(u, 0.0) + r_q.get(v, 0.0))
        slack = gap - reach
        den = reach + floor
        records.append((gap, slack, gap / den, slack / den))

    # One row per used edge; columns: dist, margin, sep, rel_margin.
    table = np.asarray(records, float).reshape(-1, 4)
    dists, margins, seps, rel_margins = table.T
    have = bool(len(records))

    out: Dict[str, Any] = dict(
        n_edges_input=int(len(edges)),
        n_edges_used=int(len(records)),
        dist=_q_summary(dists),
        margin_q=_q_summary(margins),
        frac_margin_q_pos=float(np.mean(margins > 0)) if have else None,
        min_margin_q=float(np.min(margins)) if have else None,
        sep=_q_summary(seps),
        min_sep=float(np.min(seps)) if have else None,
        rel_margin=_q_summary(rel_margins),
        min_rel_margin=float(np.min(rel_margins)) if have else None,
    )

    for thr in sep_thresholds:
        out[f"frac_sep_ge_{thr:g}"] = float(np.mean(seps >= thr)) if have else None

    if has_scale and have:
        scale = float(normalize_by)
        out["dist_norm_by_global_mean"] = _q_summary(dists / scale)
        out["margin_norm_by_global_mean"] = _q_summary(margins / scale)
    else:
        out["dist_norm_by_global_mean"] = None
        out["margin_norm_by_global_mean"] = None

    return out


def build_cluster_index_map(cluster_ids: np.ndarray) -> dict[int, np.ndarray]:
    """Map each distinct cluster id to the (first-axis) positions where it occurs."""
    labels = np.asarray(cluster_ids, int)
    return {int(lab): np.nonzero(labels == lab)[0] for lab in np.unique(labels)}


def edge_stats_pairwise_between_clusters(
    Z: np.ndarray,
    idx_map: dict[int, np.ndarray],
    edges: list[tuple[int, int]],
    normalize_by: float | None = None,
) -> dict:
    """Summarize the mean cross-cluster point distance along each edge.

    For every edge ``(u, v)`` whose clusters both appear in ``idx_map`` with
    at least one member, the mean Euclidean distance over all ``|u| x |v|``
    pairs of rows ``Z[idx_map[u]]`` x ``Z[idx_map[v]]`` is recorded.

    Returns:
        Dict with ``n_edges_input``, ``n_edges_used``, ``dist`` (a
        ``_q_summary`` of the per-edge means) and ``dist_norm_by_global_mean``
        (the same divided by ``normalize_by``; ``None`` when ``normalize_by``
        is absent/non-positive or no edge was used).
    """
    from scipy.spatial.distance import cdist

    Z = np.asarray(Z, float)
    dists = []
    for u, v in edges:
        iu = idx_map.get(int(u), None)
        iv = idx_map.get(int(v), None)
        if iu is None or iv is None or iu.size == 0 or iv.size == 0:
            continue
        # cdist builds only the |u| x |v| distance matrix; broadcasting
        # A[:, None, :] - B[None, :, :] would allocate a |u| x |v| x dim tensor.
        D = cdist(Z[iu], Z[iv], metric="euclidean")
        dists.append(float(D.mean()))

    used = len(dists)
    dists = np.asarray(dists, float)
    out = {
        "n_edges_input": int(len(edges)),
        "n_edges_used": int(used),
        "dist": _q_summary(dists),
    }
    if normalize_by is not None and float(normalize_by) > 0 and dists.size:
        out["dist_norm_by_global_mean"] = _q_summary(dists / float(normalize_by))
    else:
        out["dist_norm_by_global_mean"] = None
    return out


def build_ref_edges_from_phi_axis(g: int, phi: np.ndarray, axis: str) -> List[Tuple[int, int]]:
    """List undirected reflection edges along one axis of the ``(2g)^2`` state grid.

    Every node whose ``axis`` state ``s`` is in the lower layer (``s < g``) is
    linked to the node with that state replaced by ``g + phi[s]``. Edges are
    returned as sorted, de-duplicated ``(min, max)`` id pairs.

    Raises:
        ValueError: if ``axis`` is not ``'a'`` or ``'b'``.
    """
    g = int(g)
    n = 2 * g
    phi = np.asarray(phi, int)
    found = set()
    for cid in range(n * n):
        a, b = _decode_cluster(cid, g)
        if axis == "a":
            src = a
        elif axis == "b":
            src = b
        else:
            raise ValueError("axis must be 'a' or 'b'")
        if _is_ref_state(src, g):
            continue
        partner = g + int(phi[_resid(src, g)])
        other = partner * n + b if axis == "a" else a * n + partner
        if other != cid:
            found.add((min(cid, other), max(cid, other)))
    return sorted(found)


def build_rot_step_edges(g: int, d: int, axis: str) -> List[Tuple[int, int]]:
    """List directed rotation edges ``(cid, cid')`` along one axis of the state grid.

    For every node, the ``axis`` state is advanced by ``d`` modulo ``g``
    inside its own layer (lower ``[0, g)`` or upper ``[g, 2g)``). One edge is
    produced per node, in node order.

    Raises:
        ValueError: if ``axis`` is not ``'a'`` or ``'b'``.
    """
    g = int(g)
    d = int(d) % g
    n = 2 * g
    # The grid always has at least one node here, so validating up front is
    # equivalent to failing on the first node.
    if axis not in ("a", "b"):
        raise ValueError("axis must be 'a' or 'b'")

    edges = []
    for cid in range(n * n):
        a, b = _decode_cluster(cid, g)
        s = a if axis == "a" else b
        moved = (_resid(s, g) + d) % g
        if _is_ref_state(s, g):
            moved += g
        target = moved * n + b if axis == "a" else a * n + moved
        edges.append((cid, target))
    return edges


def nested_coverage_2d(g: int, d: int, phi_a: np.ndarray, phi_b: np.ndarray) -> Dict[str, Any]:
    """Count the grid nodes reached by rotation orbits and one reflection per axis.

    The orbit of 0 under ``x -> x + d`` (mod ``g``) is taken on each axis. For
    every orbit pair ``(ar, br)``, the four nodes combining the lower-layer
    states ``ar``/``br`` with the reflected states ``g + phi_a[ar]`` and
    ``g + phi_b[br]`` are marked visited.

    Returns:
        Dict with ``g``, ``d``, ``orbit_len``, ``visited``, ``total = (2g)^2``
        and ``visited_frac``.
    """
    g = int(g)
    d = int(d) % g
    n = 2 * g

    orbit_len = 1 if d == 0 else g // math.gcd(g, d)
    orbit = [(step * d) % g for step in range(orbit_len)]

    pa = np.asarray(phi_a, int)
    pb = np.asarray(phi_b, int)

    seen = {
        a_state * n + b_state
        for ar in orbit
        for br in orbit
        for a_state in (ar, g + int(pa[ar]))
        for b_state in (br, g + int(pb[br]))
    }

    total = n * n
    return dict(
        g=int(g),
        d=int(d),
        orbit_len=int(orbit_len),
        visited=int(len(seen)),
        total=int(total),
        visited_frac=float(len(seen) / total),
    )


def bfs_coverage_quotient(
    g: int,
    d: int,
    phi_a: np.ndarray,
    phi_b: np.ndarray,
    start_a_state: int = 0,
    start_b_state: int = 0,
) -> Dict[str, Any]:
    """Breadth-first search over the ``(2g)^2`` state grid from a start node.

    Moves from ``(a, b)``: rotate either axis by ``+d`` or ``-d`` inside its
    layer, or toggle either axis between layers via ``phi_a``/``phi_b`` (lower
    -> upper) and their inverses (upper -> lower).

    Returns:
        Dict with ``g``, ``d``, the start node, ``visited`` count, ``total``,
        ``visited_frac`` and ``max_graph_dist`` (largest BFS depth reached).

    Raises:
        ValueError: if ``phi_a`` or ``phi_b`` is not a bijection (from
            ``invert_phi``).
    """
    g = int(g)
    d = int(d) % g
    n = 2 * g
    n_nodes = n * n

    pa = np.asarray(phi_a, int)
    pb = np.asarray(phi_b, int)
    inv_pa = invert_phi(pa)
    inv_pb = invert_phi(pb)

    def rotate(s, sign):
        shifted = (_resid(s, g) + sign * d) % g
        return g + shifted if _is_ref_state(s, g) else shifted

    def toggle(s, fwd, bwd):
        r = _resid(s, g)
        if _is_ref_state(s, g):
            return int(bwd[r])
        return g + int(fwd[r])

    start = int(int(start_a_state) * n + int(start_b_state))
    depth = np.full(n_nodes, -1, dtype=int)
    depth[start] = 0
    frontier = deque([start])

    while frontier:
        node = frontier.popleft()
        a, b = divmod(node, n)
        moves = (
            (rotate(a, +1), b),
            (rotate(a, -1), b),
            (a, rotate(b, +1)),
            (a, rotate(b, -1)),
            (toggle(a, pa, inv_pa), b),
            (a, toggle(b, pb, inv_pb)),
        )
        for na, nb in moves:
            nxt = int(na * n + nb)
            if depth[nxt] >= 0:
                continue
            depth[nxt] = depth[node] + 1
            frontier.append(nxt)

    reached = depth >= 0
    return dict(
        g=int(g),
        d=int(d),
        start=dict(a=int(start_a_state), b=int(start_b_state), node=int(start)),
        visited=int(reached.sum()),
        total=int(n_nodes),
        visited_frac=float(np.mean(reached)),
        max_graph_dist=int(depth.max()),
    )


def bfs_coverage_rot_only(
    g: int,
    d: int,
    *,
    start_a_state: int = 0,
    start_b_state: int = 0,
) -> Dict[str, Any]:
    """Breadth-first search over the state grid using rotations only.

    Moves from ``(a, b)`` rotate either axis by ``+d`` or ``-d`` inside its
    layer, so the search never leaves the start node's layer pair. The result
    is compared with the analytic expectation ``orbit_len^2`` nodes, where
    ``orbit_len = g / gcd(g, d)`` (1 when ``d == 0``).

    Returns:
        Dict with coverage counts, ``max_graph_dist``, the gcd/orbit
        expectation, and ``visited_layer_counts`` keyed by
        ``fa{0|1}_fb{0|1}`` (1 = upper layer on that axis).
    """
    g = int(g)
    d = int(d) % g
    n = 2 * g
    n_nodes = n * n

    def rotate(s: int, sign: int) -> int:
        shifted = (_resid(s, g) + sign * d) % g
        return g + shifted if _is_ref_state(s, g) else shifted

    start = int(int(start_a_state) * n + int(start_b_state))
    depth = np.full(n_nodes, -1, dtype=int)
    depth[start] = 0
    frontier = deque([start])

    while frontier:
        node = frontier.popleft()
        a, b = divmod(node, n)
        for na, nb in (
            (rotate(a, +1), b),
            (rotate(a, -1), b),
            (a, rotate(b, +1)),
            (a, rotate(b, -1)),
        ):
            nxt = int(na * n + nb)
            if depth[nxt] < 0:
                depth[nxt] = depth[node] + 1
                frontier.append(nxt)

    reached = depth >= 0
    n_reached = int(reached.sum())

    if d != 0:
        g_gcd = math.gcd(g, d)
        orbit_len = g // g_gcd
    else:
        g_gcd = g
        orbit_len = 1
    expected_in_layer = int(orbit_len * orbit_len)
    expected_frac_total = float(expected_in_layer / (4.0 * g * g))

    nodes = np.flatnonzero(reached)
    upper_a = (nodes // n) >= g
    upper_b = (nodes % n) >= g
    layer_counts = {
        "fa0_fb0": int(np.count_nonzero(~upper_a & ~upper_b)),
        "fa0_fb1": int(np.count_nonzero(~upper_a & upper_b)),
        "fa1_fb0": int(np.count_nonzero(upper_a & ~upper_b)),
        "fa1_fb1": int(np.count_nonzero(upper_a & upper_b)),
    }

    return dict(
        g=int(g),
        d=int(d),
        start=dict(a=int(start_a_state), b=int(start_b_state), node=int(start)),
        visited=n_reached,
        total=int(n_nodes),
        visited_frac=float(n_reached / n_nodes),
        max_graph_dist=int(depth.max()),
        gcd_g_d=int(g_gcd),
        orbit_len=int(orbit_len),
        expected_visited_in_one_layer=int(expected_in_layer),
        expected_visited_frac_total=float(expected_frac_total),
        visited_layer_counts=layer_counts,
    )


def _rot_step_state_1d(s: int, g: int, d: int, sign: int) -> int:
    """Rotate state ``s`` by ``sign * d`` modulo ``g`` without leaving its layer."""
    layer_offset = g if _is_ref_state(s, g) else 0
    return layer_offset + (_resid(s, g) + sign * d) % g


def _ref_toggle_state_1d(s: int, g: int, phi: np.ndarray, inv_phi: np.ndarray) -> int:
    """Move ``s`` to the other layer: lower ``x`` -> ``g + phi[x]``, upper ``g + y`` -> ``inv_phi[y]``."""
    r = _resid(s, g)
    if _is_ref_state(s, g):
        return int(inv_phi[r])
    return g + int(phi[r])


def relation_involution_stats(
    *,
    g: int,
    phi: np.ndarray,
    axis: str,
    present: set[int],
) -> dict:
    """Measure how often the layer toggle ``S`` on ``axis`` satisfies ``S(S(u)) == u``.

    ``S`` maps a lower-layer state ``x`` to ``g + phi[x]`` and an upper-layer
    state ``g + y`` to ``phi^{-1}[y]``. If ``phi`` cannot be inverted or
    ``present`` is empty, only the header fields are returned.

    Raises:
        ValueError: if ``axis`` is not ``'a'`` or ``'b'`` (only checked when
            there is work to do).
    """
    inv, err = _try_invert_phi(phi)
    out = {
        "axis": axis,
        "phi_is_bijection": (inv is not None),
        "phi_invert_error": err,
        "n_nodes_present": int(len(present)),
        "n_used": 0,
        "closure_rate_state": None,
    }
    if inv is None or len(present) == 0:
        return out

    _, enc, dec = _enc_dec_for_g(g)
    if axis not in ("a", "b"):
        raise ValueError("axis must be 'a' or 'b'")

    def toggle(u: int) -> int:
        a, b = dec(u)
        if axis == "a":
            return enc(_ref_toggle_state_1d(a, g, phi, inv), b)
        return enc(a, _ref_toggle_state_1d(b, g, phi, inv))

    n_used = 0
    n_closed = 0
    for raw in present:
        node = int(raw)
        n_used += 1
        if toggle(toggle(node)) == node:
            n_closed += 1

    out["n_used"] = int(n_used)
    out["closure_rate_state"] = float(n_closed / n_used) if n_used else None
    return out


def relation_conjugacy_stats(
    *,
    g: int,
    d: int,
    phi: np.ndarray,
    axis: str,
    sigma: int,
    present: set[int],
    centers: dict[int, np.ndarray],
    rmap: dict[int, float],
    eps: float = 1e-12,
    normalize_by: float | None = None,
    eps_rel: float = 0.0,
) -> dict:
    """Check the conjugacy relation ``S(R(u, +1)) == R(S(u), sigma)`` on ``axis``.

    ``R(u, sign)`` rotates the axis state by ``sign * d`` inside its layer and
    ``S`` toggles it between layers via ``phi`` / its inverse. ``sigma`` is
    normalised to +1 (when ``sigma >= 0``) or -1; the output echoes the raw
    ``int(sigma)``. Each node of ``present`` is scored exactly (state closure)
    and, when both result nodes have centers, geometrically via
    ``||c_lhs - c_rhs|| / (r_lhs + r_rhs + floor)`` where ``floor`` is ``eps``
    raised to ``eps_rel * normalize_by`` when both are positive.

    If ``phi`` cannot be inverted or ``present`` is empty, only the header
    fields are returned.

    Raises:
        ValueError: if ``axis`` is not ``'a'`` or ``'b'`` (only checked when
            there is work to do).
    """
    inv, err = _try_invert_phi(phi)
    out = {
        "axis": axis,
        "sigma": int(sigma),
        "phi_is_bijection": (inv is not None),
        "phi_invert_error": err,
        "n_nodes_present": int(len(present)),
        "n_state_used": 0,
        "closure_rate_state": None,
        "n_geom_used": 0,
        "geom_delta_all": None,
        "geom_delta_fail_only": None,
        "frac_geom_available": None,
        "frac_state_fail_with_geom": None,
    }
    if inv is None or len(present) == 0:
        return out

    floor = float(eps)
    if normalize_by is not None and float(normalize_by) > 0 and float(eps_rel) > 0:
        floor = max(floor, float(eps_rel) * float(normalize_by))

    _, enc, dec = _enc_dec_for_g(g)
    d = int(d) % g
    sigma = 1 if int(sigma) >= 0 else -1
    if axis not in ("a", "b"):
        raise ValueError("axis must be 'a' or 'b'")

    def rotate(u: int, sign: int) -> int:
        a, b = dec(u)
        if axis == "a":
            return enc(_rot_step_state_1d(a, g, d, sign), b)
        return enc(a, _rot_step_state_1d(b, g, d, sign))

    def toggle(u: int) -> int:
        a, b = dec(u)
        if axis == "a":
            return enc(_ref_toggle_state_1d(a, g, phi, inv), b)
        return enc(a, _ref_toggle_state_1d(b, g, phi, inv))

    n_state = 0
    n_ok = 0
    geom_all = []
    geom_fail = []

    for raw in present:
        u = int(raw)
        lhs = toggle(rotate(u, +1))
        rhs = rotate(toggle(u), sigma)

        n_state += 1
        matched = lhs == rhs
        n_ok += int(matched)

        if lhs not in centers or rhs not in centers:
            continue
        gap = float(np.linalg.norm(centers[lhs] - centers[rhs]))
        reach = float(rmap.get(lhs, 0.0) + rmap.get(rhs, 0.0))
        ratio = gap / (reach + floor)
        geom_all.append(ratio)
        if not matched:
            geom_fail.append(ratio)

    n_geom = len(geom_all)
    n_geom_fail = len(geom_fail)
    out["n_state_used"] = int(n_state)
    out["closure_rate_state"] = float(n_ok / n_state) if n_state else None
    out["n_geom_used"] = int(n_geom)
    out["geom_delta_all"] = _q_summary(np.asarray(geom_all, float)) if n_geom else None
    out["geom_delta_fail_only"] = _q_summary(np.asarray(geom_fail, float)) if n_geom_fail else None
    out["frac_geom_available"] = float(n_geom / n_state) if n_state else None
    out["frac_state_fail_with_geom"] = float(n_geom_fail / max(1, (n_state - n_ok))) if n_state else None
    return out


def relation_commutation_stats(
    *,
    name: str,
    present: set[int],
    op1,
    op2,
    centers: dict[int, np.ndarray],
    rmap: dict[int, float],
    eps: float = 1e-12,
    normalize_by: float | None = None,
    eps_rel: float = 0.0,
) -> dict:
    """Check whether two node operators commute: ``op1(op2(u)) == op2(op1(u))``.

    Each node of ``present`` is scored exactly (state closure) and, when both
    result nodes have centers, geometrically via
    ``||c_lhs - c_rhs|| / (r_lhs + r_rhs + floor)`` where ``floor`` is ``eps``
    raised to ``eps_rel * normalize_by`` when both are positive. An empty
    ``present`` returns only the header fields.
    """
    out = {
        "name": str(name),
        "n_nodes_present": int(len(present)),
        "n_state_used": 0,
        "closure_rate_state": None,
        "n_geom_used": 0,
        "geom_delta_all": None,
        "geom_delta_fail_only": None,
        "frac_geom_available": None,
        "frac_state_fail_with_geom": None,
    }
    if len(present) == 0:
        return out

    floor = float(eps)
    if normalize_by is not None and float(normalize_by) > 0 and float(eps_rel) > 0:
        floor = max(floor, float(eps_rel) * float(normalize_by))

    n_state = 0
    n_ok = 0
    geom_all = []
    geom_fail = []

    for raw in present:
        u = int(raw)
        lhs = int(op1(op2(u)))
        rhs = int(op2(op1(u)))

        n_state += 1
        matched = lhs == rhs
        n_ok += int(matched)

        if lhs not in centers or rhs not in centers:
            continue
        gap = float(np.linalg.norm(centers[lhs] - centers[rhs]))
        reach = float(rmap.get(lhs, 0.0) + rmap.get(rhs, 0.0))
        ratio = gap / (reach + floor)
        geom_all.append(ratio)
        if not matched:
            geom_fail.append(ratio)

    n_geom = len(geom_all)
    n_geom_fail = len(geom_fail)
    out["n_state_used"] = int(n_state)
    out["closure_rate_state"] = float(n_ok / n_state) if n_state else None
    out["n_geom_used"] = int(n_geom)
    out["geom_delta_all"] = _q_summary(np.asarray(geom_all, float)) if n_geom else None
    out["geom_delta_fail_only"] = _q_summary(np.asarray(geom_fail, float)) if n_geom_fail else None
    out["frac_geom_available"] = float(n_geom / n_state) if n_state else None
    out["frac_state_fail_with_geom"] = float(n_geom_fail / max(1, (n_state - n_ok))) if n_state else None
    return out


def null_phi_baseline_equivariance(
    *,
    g: int,
    d: int,
    n_samples: int = 20,
    seed: int = 0,
) -> dict:
    """Baseline equivariance scores of uniformly random permutations of ``Z_g``.

    Draws ``n_samples`` random permutations with ``seed`` and scores each with
    ``verify_phi_equivariance``. For ``g <= 3`` no samples are drawn and the
    statistics are ``None``.
    """
    if g <= 3:
        return {
            "n_samples": int(n_samples),
            "equivariance_consistency": None,
            "frac_best_sign_pos": None,
        }

    gen = np.random.default_rng(int(seed))
    verdicts = [
        verify_phi_equivariance(gen.permutation(int(g)).astype(int), g, d)
        for _ in range(int(n_samples))
    ]
    consistency = np.asarray([float(v["equivariance_consistency"]) for v in verdicts], float)
    signs = [int(v["best_sign"]) for v in verdicts]
    return {
        "n_samples": int(n_samples),
        "equivariance_consistency": _q_summary(consistency),
        "frac_best_sign_pos": float(np.mean(np.asarray(signs) == +1)) if signs else None,
    }


def wrap_pi(x: np.ndarray) -> np.ndarray:
    """Wrap angles (radians) into the half-open interval ``[-pi, pi)``."""
    shifted = np.asarray(x, float) + np.pi
    return np.mod(shifted, 2 * np.pi) - np.pi


def local_pca_2d(P: np.ndarray) -> np.ndarray:
    """Project the mean-centered rows of ``P`` onto their top two principal axes."""
    pts = np.asarray(P, float)
    centered = pts - pts.mean(axis=0, keepdims=True)
    vt = np.linalg.svd(centered, full_matrices=False)[2]
    basis = vt[:2].T
    return centered @ basis


def _rankdata(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x)
    order = np.argsort(x)
    ranks = np.empty_like(order, dtype=float)
    ranks[order] = np.arange(len(x), dtype=float)
    xs = x[order]
    i = 0
    while i < len(x):
        j = i
        while j + 1 < len(x) and xs[j + 1] == xs[i]:
            j += 1
        if j > i:
            ranks[order[i : j + 1]] = 0.5 * (i + j)
        i = j + 1
    return ranks


def spearman_rho(a: np.ndarray, b: np.ndarray) -> float:
    """Spearman rank correlation of ``a`` and ``b``; 0.0 when either rank vector is constant."""
    centered_a = _rankdata(a)
    centered_b = _rankdata(b)
    centered_a = centered_a - centered_a.mean()
    centered_b = centered_b - centered_b.mean()
    scale = np.linalg.norm(centered_a) * np.linalg.norm(centered_b)
    if scale > 0:
        return float((centered_a @ centered_b) / scale)
    return 0.0


def polygon_self_intersections(P2: np.ndarray) -> int:
    """Count proper crossings between non-adjacent edges of the closed polygon ``P2``.

    Edge ``i`` joins vertex ``i`` to vertex ``(i + 1) % n``. Touching or
    collinear configurations are not counted. Polygons with fewer than four
    vertices return 0.
    """
    verts = np.asarray(P2, float)
    n = verts.shape[0]
    if n < 4:
        return 0

    def turn(p, q, r):
        return np.sign((q[0] - p[0]) * (r[1] - p[1]) - (q[1] - p[1]) * (r[0] - p[0]))

    def crosses(p1, p2, q1, q2):
        return (turn(p1, p2, q1) * turn(p1, p2, q2) < 0) and (turn(q1, q2, p1) * turn(q1, q2, p2) < 0)

    count = 0
    for i in range(n):
        p1, p2 = verts[i], verts[(i + 1) % n]
        # Start at i + 2 to skip the edge sharing vertex i + 1; the first and
        # last edges also share a vertex (vertex 0).
        for j in range(i + 2, n):
            if i == 0 and j == n - 1:
                continue
            if crosses(p1, p2, verts[j], verts[(j + 1) % n]):
                count += 1
    return count


def _local_pca_project_2d(Pk: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    Pk = np.asarray(Pk, float)
    mu = Pk.mean(axis=0, keepdims=True)
    X = Pk - mu
    if X.shape[0] <= 1:
        k = X.shape[1] if X.ndim == 2 else 0
        return np.zeros((X.shape[0], 2), float), np.zeros((1,), float), np.zeros((k, 2), float), mu

    _, s, VT = np.linalg.svd(X, full_matrices=False)
    W = VT[:2].T
    P2 = X @ W
    return P2, s, W, mu


def _canonicalize_P2_with_T(P2: np.ndarray, eps: float = 1e-9) -> tuple[np.ndarray, np.ndarray]:
    P2 = np.asarray(P2, float)

    if P2.ndim == 1:
        P2 = P2.reshape(-1, 1)
    n = int(P2.shape[0])
    if P2.shape[1] == 1:
        P2 = np.concatenate([P2, np.zeros((n, 1), dtype=P2.dtype)], axis=1)

    T = np.eye(2, dtype=float)
    if n < 2:
        return P2, T

    v0 = None
    for i in range(n - 1):
        v = P2[i + 1] - P2[i]
        if float(np.linalg.norm(v)) > eps:
            v0 = v
            break
    if v0 is None:
        return P2, T

    if v0[0] < 0:
        T[0, 0] *= -1.0

    if n >= 3:
        for i in range(n - 2):
            a = P2[i + 1] - P2[i]
            b = P2[i + 2] - P2[i + 1]
            if float(np.linalg.norm(a)) <= eps or float(np.linalg.norm(b)) <= eps:
                continue
            cross_z = a[0] * b[1] - a[1] * b[0]
            if abs(float(cross_z)) > eps:
                if cross_z < 0:
                    T[1, 1] *= -1.0
                break

    return P2 @ T, T


def _canonicalize_P2(P2: np.ndarray, eps: float = 1e-9) -> np.ndarray:
    """Return only the coordinates from ``_canonicalize_P2_with_T``, discarding ``T``."""
    return _canonicalize_P2_with_T(P2, eps=eps)[0]


def _align_P2_to_ref(P2: np.ndarray, W: np.ndarray, W_ref: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    C = W.T @ W_ref
    U, _, Vt = np.linalg.svd(C)
    R = U @ Vt
    return P2 @ R, R


def _planarity_ratio_from_s(s: np.ndarray) -> float | None:
    s = np.asarray(s, float)
    if s.size == 0:
        return None
    ss = (s ** 2)
    denom = float(ss.sum())
    if denom <= 0:
        return None
    num = float(ss[: min(2, ss.size)].sum())
    return num / denom


def _safe_unwrap_angles(theta: np.ndarray) -> np.ndarray:
    theta = np.asarray(theta, float)
    if theta.size == 0:
        return theta
    return np.unwrap(theta)


def _rankdata_avg(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x)
    order = np.argsort(x)
    ranks = np.empty_like(order, dtype=float)
    ranks[order] = np.arange(len(x), dtype=float)
    xs = x[order]
    i = 0
    while i < len(x):
        j = i
        while j + 1 < len(x) and xs[j + 1] == xs[i]:
            j += 1
        if j > i:
            ranks[order[i : j + 1]] = 0.5 * (i + j)
        i = j + 1
    return ranks


def _spearman_rho(a: np.ndarray, b: np.ndarray) -> float | None:
    """Spearman rank correlation of ``a`` and ``b``.

    Returns ``None`` when either input is empty, their sizes differ, or either
    rank vector is constant.
    """
    xa = np.asarray(a, float)
    xb = np.asarray(b, float)
    if xa.size != xb.size or not xa.size or not xb.size:
        return None
    dev_a = _rankdata_avg(xa)
    dev_b = _rankdata_avg(xb)
    dev_a = dev_a - dev_a.mean()
    dev_b = dev_b - dev_b.mean()
    scale = float(np.linalg.norm(dev_a) * np.linalg.norm(dev_b))
    if not scale > 0:
        return None
    return float((dev_a @ dev_b) / scale)


def _polygon_self_intersections(P2: np.ndarray, eps: float = 1e-12) -> int:
    P2 = np.asarray(P2, float)
    n = int(P2.shape[0])
    if n < 4:
        return 0

    def orient(p, q, r) -> float:
        return (q[0] - p[0]) * (r[1] - p[1]) - (q[1] - p[1]) * (r[0] - p[0])

    def proper_intersect(a, b, c, d) -> bool:
        o1 = orient(a, b, c)
        o2 = orient(a, b, d)
        o3 = orient(c, d, a)
        o4 = orient(c, d, b)
        return (o1 * o2 < -eps) and (o3 * o4 < -eps)

    cnt = 0
    for i in range(n):
        a = P2[i]
        b = P2[(i + 1) % n]
        for j in range(i + 1, n):
            if j == i or (j + 1) % n == i or (i + 1) % n == j:
                continue
            c = P2[j]
            d = P2[(j + 1) % n]
            if proper_intersect(a, b, c, d):
                cnt += 1
    return cnt


def orbit_metrics_k_use(Pk: np.ndarray, *, W_ref: np.ndarray | None = None) -> dict:
    """Summarize the geometry of one closed orbit of points.

    ``Pk`` is indexed as ``(L, k)`` (one row per orbit step, in traversal order;
    the orbit wraps from the last row back to the first). Step lengths are
    measured in the full k-dimensional space; the orbit is then projected to 2D,
    either with its own local PCA plane (canonicalized via ``_canonicalize_P2``)
    or, if ``W_ref`` is given, with the shared basis ``W_ref`` applied to the
    centered points.

    For ``L >= 3`` the 2D projection is used to compute:
      * Spearman rho between step index and cumulative angle about the centroid,
      * the net turning direction and the fraction of steps agreeing with it,
      * the signed (shoelace) area and the number of proper self-intersections.
    Otherwise those entries are ``None``.

    Returns ``{"L": 0}`` for an empty orbit.
    """
    pts = np.asarray(Pk, float)
    L = int(pts.shape[0])
    if L == 0:
        return {"L": 0}

    succ = np.roll(np.arange(L), -1)  # index of the next point, wrapping around
    step_k = np.linalg.norm(pts[succ] - pts, axis=1)
    mean_k = float(step_k.mean()) if step_k.size else 0.0
    cv_k = float(step_k.std() / mean_k) if mean_k > 0 else None

    P2, s, _W, _mu = _local_pca_project_2d(pts)
    if W_ref is not None:
        # Shared frame: project the centered orbit onto the reference basis.
        P2 = (pts - pts.mean(axis=0, keepdims=True)) @ W_ref
    else:
        P2 = _canonicalize_P2(P2)

    planarity = _planarity_ratio_from_s(s)

    step_2d = np.linalg.norm(P2[succ] - P2, axis=1)
    distortion = (
        float(np.mean(np.abs(step_2d - step_k) / (step_k + 1e-12))) if mean_k > 0 else None
    )

    rho = dir_sign = frac_same = area = inter = None
    if L >= 3:
        centered = P2 - P2.mean(axis=0)[None, :]
        ang = np.arctan2(centered[:, 1], centered[:, 0])
        # wrap_pi presumably maps angle differences into a (-pi, pi] window — confirm.
        dang = wrap_pi(np.roll(ang, -1) - ang)
        cum_ang = np.cumsum(np.r_[0.0, dang[:-1]])
        rho = _spearman_rho(np.arange(L, dtype=float), cum_ang)

        if dang.size:
            dir_sign = float(np.sign(np.sum(dang)))
            moving = np.abs(dang) > 1e-8  # ignore steps with (numerically) no turn
            if dir_sign != 0 and np.any(moving):
                frac_same = float(np.mean(np.sign(dang[moving]) == dir_sign))

        xs, ys = P2[:, 0], P2[:, 1]
        area = 0.5 * float(np.sum(xs * np.roll(ys, -1) - ys * np.roll(xs, -1)))
        inter = _polygon_self_intersections(P2)

    return {
        "L": L,
        "step_len_mean_k": mean_k,
        "step_len_cv_k": cv_k,
        "spearman_rho_t_vs_angle": rho,
        "dir_sign": dir_sign,
        "frac_same_dir": frac_same,
        "self_intersections": inter,
        "signed_area": area,
        "planarity_ratio": planarity,
        "proj_step_len_rel_distortion": distortion,
    }


def residue_orbits(g: int, d: int) -> list[list[int]]:
    """Partition ``{0, ..., g-1}`` into cycles of the translation ``x -> (x + d) mod g``.

    Each cycle starts at the smallest residue not yet visited and lists residues
    in visiting order, so cycles are ordered by their starting residue. Every
    cycle has length ``g // gcd(g, d)``; for ``d % g == 0`` each residue forms its
    own cycle.

    Args:
        g: modulus. ``g <= 0`` has no residues and yields ``[]`` (instead of
            failing on ``d % g``).
        d: translation step; reduced modulo ``g``.

    Returns:
        List of cycles (lists of ints).
    """
    g = int(g)
    if g <= 0:
        return []
    d = int(d) % g

    seen = [False] * g
    orbits: list[list[int]] = []
    for start in range(g):
        if seen[start]:
            continue
        cyc: list[int] = []
        x = start
        while not seen[x]:
            seen[x] = True
            cyc.append(x)
            x = (x + d) % g
        orbits.append(cyc)
    return orbits


def compute_rot_orbit_geometry(
    *,
    centers: dict[int, np.ndarray],
    g: int,
    d: int,
    axis: str,
    min_cover_frac: float = 0.8,
    max_examples: int = 6,
) -> dict:
    """Measure the geometry of rotation orbits of cluster centers along one axis.

    Cluster ids are ``u = a_state * 2g + b_state``. For every state ``other`` on
    the fixed axis, every half (``"rot"``: states ``x``; ``"ref"``: states
    ``g + x``) and every residue cycle of ``x -> x + d (mod g)``, the cycle's
    cluster ids on ``axis`` are collected. The orbit's coverage is the fraction of
    those ids present in ``centers``. Orbits with coverage below
    ``min_cover_frac`` or fewer than 3 present centers are skipped; the others are
    scored with :func:`orbit_metrics_k_use`.

    All orbits are projected with one shared 2D basis ``W_ref`` derived from a
    reference orbit. The reference is orbit 0 of the rotation half with
    ``other=0``; if that one is missing or under-covered, the first orbit that
    meets the coverage threshold is used instead. If none does, the original
    candidate is kept when it has at least 3 centers, otherwise no shared basis is
    used.

    Args:
        centers: cluster id -> center vector.
        g: number of residues per half.
        d: rotation step (reduced modulo ``g``).
        axis: ``"a"`` or ``"b"`` — the axis whose states are traversed.
        min_cover_frac: minimum present fraction for an orbit to be scored.
        max_examples: maximum number of per-orbit examples to return.

    Returns:
        Dict with counts, coverage/metric summaries (per half) and examples.

    Raises:
        ValueError: if ``axis`` is not ``"a"`` or ``"b"``.
    """
    if axis not in ("a", "b"):
        raise ValueError("axis must be 'a' or 'b'")

    g = int(g)
    d = int(d) % g if g > 0 else 0
    n = 2 * g
    min_cover = float(min_cover_frac)

    orbits_res = residue_orbits(g, d) if g > 0 else []
    L_expected = len(orbits_res[0]) if orbits_res else 0
    degenerate_small_L = L_expected <= 3

    def _orbit_nodes(other: int, half: int, cyc: list[int]) -> list[int]:
        """Cluster ids visited by residue cycle ``cyc`` in ``half``, with ``other`` fixed."""
        nodes = []
        for x in cyc:
            s = (g + x) if half == 1 else x
            nodes.append(int(s * n + other) if axis == "a" else int(other * n + s))
        return nodes

    def _present(nodes: list[int]) -> tuple[list[np.ndarray], list[int]]:
        """Center vectors and ids of the nodes that exist in ``centers``."""
        keep = [u for u in nodes if u in centers]
        return [centers[u] for u in keep], keep

    def _coverage(n_present: int, n_nodes: int) -> float:
        return float(n_present / n_nodes) if n_nodes > 0 else 0.0

    def _collect(other: int, half: int, oid: int) -> tuple[np.ndarray | None, float]:
        """Stacked centers of one orbit (``None`` if < 3 present) and its coverage."""
        nodes = _orbit_nodes(other, half, orbits_res[oid])
        pts, _ = _present(nodes)
        cover = _coverage(len(pts), len(nodes))
        if len(pts) < 3:
            return None, cover
        return np.stack(pts, axis=0), cover

    def _first_well_covered() -> np.ndarray | None:
        """First orbit (by other, half, orbit id) meeting the coverage threshold."""
        for other in range(n):
            for half in (0, 1):
                for oid in range(len(orbits_res)):
                    P, cov = _collect(other, half, oid)
                    if P is not None and cov >= min_cover:
                        return P
        return None

    # --- shared 2D projection basis from a reference orbit ---
    Pk_ref = None
    if orbits_res:
        Pk_ref, cov_ref = _collect(0, 0, 0)
        if Pk_ref is None or cov_ref < min_cover:
            fallback = _first_well_covered()
            if fallback is not None:
                Pk_ref = fallback

    W_ref = None
    if Pk_ref is not None:
        P2_ref, _s_ref, W0_ref, _mu_ref = _local_pca_project_2d(Pk_ref)
        _P2_ref_c, T_ref = _canonicalize_P2_with_T(P2_ref)
        W_ref = W0_ref @ T_ref

    # --- score every orbit ---
    metrics_all: dict[str, dict[str, list[float]]] = {"rot": {}, "ref": {}}
    cover_fracs: list[float] = []
    planarity_list: list[float] = []
    distortion_list: list[float] = []
    examples: list[dict] = []

    n_total = 0
    n_used = 0
    n_skipped_lowcov = 0

    for other in range(n):
        for half in (0, 1):
            half_name = "rot" if half == 0 else "ref"
            for oid, cyc in enumerate(orbits_res):
                n_total += 1

                nodes = _orbit_nodes(other, half, cyc)
                pts, keep_nodes = _present(nodes)
                cover = _coverage(len(pts), len(nodes))
                cover_fracs.append(cover)

                if cover < min_cover or len(pts) < 3:
                    n_skipped_lowcov += 1
                    continue

                met = orbit_metrics_k_use(np.stack(pts, axis=0), W_ref=W_ref)
                n_used += 1

                # Aggregate numeric metrics per half (None values are not numeric).
                for k, v in met.items():
                    if k != "L" and isinstance(v, (int, float, np.floating)):
                        metrics_all[half_name].setdefault(k, []).append(float(v))

                if met.get("planarity_ratio") is not None:
                    planarity_list.append(float(met["planarity_ratio"]))
                if met.get("proj_step_len_rel_distortion") is not None:
                    distortion_list.append(float(met["proj_step_len_rel_distortion"]))

                if len(examples) < int(max_examples):
                    examples.append(
                        dict(
                            other_state=int(other),
                            half=half_name,
                            orbit_id=int(oid),
                            cover_frac=float(cover),
                            nodes=keep_nodes[:12],
                            metrics=met,
                        )
                    )

    return {
        "axis": str(axis),
        "g": int(g),
        "d": int(d),
        "n_states_other": int(n),
        "n_residue_orbits": int(len(orbits_res)),
        "orbit_len_expected": int(L_expected),
        "diagnostic_degenerate_due_to_small_L": bool(degenerate_small_L),
        "degenerate_L_threshold": 4,
        "min_cover_frac": float(min_cover_frac),
        "counts": {
            "n_total_orbits": int(n_total),
            "n_used_orbits": int(n_used),
            "n_skipped_lowcov_or_small": int(n_skipped_lowcov),
        },
        "coverage_frac": _q_summary(np.asarray(cover_fracs, float)) if cover_fracs else None,
        "metrics_summary": {
            hh: {k: _q_summary(np.asarray(v, float)) for k, v in metrics_all[hh].items()}
            for hh in ("rot", "ref")
        },
        "planarity_ratio_summary": _q_summary(np.asarray(planarity_list, float)) if planarity_list else None,
        "proj_distortion_summary": _q_summary(np.asarray(distortion_list, float)) if distortion_list else None,
        "examples": examples,
    }


def make_residue_permutation(g: int, *, seed: int = 0) -> np.ndarray:
    """Return a random permutation of ``range(g)`` drawn from ``np.random.default_rng(seed)``."""
    generator = np.random.default_rng(int(seed))
    perm = generator.permutation(int(g))
    return perm.astype(int)


def apply_axis_residue_perm_fiber_present(*, centers, g, axis, rng):
    """Shuffle center vectors among present clusters within each fiber of one axis.

    Cluster ids are decoded as ``u = a * 2g + b``. Clusters are grouped by the
    state on the *other* axis and by which half the permuted axis' state lies in
    (``< g`` vs ``>= g``). Within each group, the vectors of the clusters present
    in ``centers`` are reassigned with ``rng.permutation``, so the key set is
    unchanged and only present clusters exchange vectors.

    Args:
        centers: cluster id -> center vector.
        g: number of residues per half.
        axis: ``"a"`` or ``"b"`` — the axis whose residues are shuffled.
        rng: object with a ``permutation(n)`` method (e.g. a NumPy Generator);
            groups are processed in first-seen order, one call per group.

    Returns:
        New dict with the same keys as ``centers``.

    Raises:
        ValueError: if ``axis`` is not ``"a"`` or ``"b"`` (checked per key, so an
            empty ``centers`` returns ``{}`` without raising).
    """
    g = int(g)
    n = 2 * g
    buckets = {}

    for u in centers.keys():
        a, b = u // n, u % n
        if axis == "a":
            state, other = a, b
        elif axis == "b":
            state, other = b, a
        else:
            raise ValueError("axis must be 'a' or 'b'")
        half = 0 if state < g else 1
        buckets.setdefault((other, half), []).append(int(u))

    centers_p = {}
    for us in buckets.values():
        us = sorted(us)
        vecs = [centers[u] for u in us]
        perm = rng.permutation(len(us))
        for u, src in zip(us, perm):
            centers_p[u] = vecs[int(src)]

    # Internal invariant: each key falls in exactly one bucket.
    assert set(centers_p.keys()) == set(centers.keys())
    return centers_p


def apply_axis_residue_perm_to_centers(
    *,
    centers: dict[int, np.ndarray],
    g: int,
    axis: str,
    perm_res: np.ndarray,
) -> dict[int, np.ndarray]:
    """Relabel centers by applying one residue permutation to the states of ``axis``.

    Cluster ids are decoded as ``u = a * 2g + b``. The state on ``axis`` is mapped
    ``x -> perm_res[x]`` for ``x < g`` and ``g + x -> g + perm_res[x]`` for the
    second half. Each key ``u`` receives the vector of its permuted id, or keeps
    its own vector if the permuted id is absent from ``centers``.

    Raises:
        ValueError: if ``perm_res`` does not have length ``g``, or ``axis`` is not
            ``"a"``/``"b"`` (the latter only when ``centers`` is non-empty).
    """
    g = int(g)
    n_states = 2 * g
    axis = str(axis)
    perm_res = np.asarray(perm_res, int)
    if perm_res.size != g:
        raise ValueError("perm_res must have length g")

    def perm_state(s: int) -> int:
        # Second-half states live in [g, 2g): permute the residue, stay in that half.
        offset = g if s >= g else 0
        return offset + int(perm_res[s - offset])

    permuted: dict[int, np.ndarray] = {}
    for key in centers:
        u = int(key)
        a, b = divmod(u, n_states)
        if axis == "a":
            src = perm_state(a) * n_states + b
        elif axis == "b":
            src = a * n_states + perm_state(b)
        else:
            raise ValueError("axis must be 'a' or 'b'")
        permuted[u] = centers.get(int(src), centers[u])
    return permuted


def traverse_orbit_null_axis_res_perm(
    *,
    centers: dict[int, np.ndarray],
    g: int,
    d: int,
    axis: str,
    n_samples: int = 20,
    seed: int = 0,
    min_cover_frac: float = 0.8,
    mode: str = "fiber_present",
) -> dict:
    """Null distribution of rotation-orbit geometry under residue shuffles of one axis.

    Each of ``n_samples`` runs shuffles the centers along ``axis`` and re-runs
    :func:`compute_rot_orbit_geometry` on the shuffled centers:

    * ``mode="fiber_present"``: vectors are shuffled among present clusters within
      each (other-state, half) group (:func:`apply_axis_residue_perm_fiber_present`);
    * ``mode="global_Zg"``: one random permutation of the residues is applied to
      the axis' states in both halves (:func:`apply_axis_residue_perm_to_centers`).

    Per run, the fraction of orbits that passed the coverage filter is recorded,
    and for each key metric the mean of the per-half ("rot"/"ref") orbit means.
    Orientation metrics are skipped when orbits have fewer than 4 residues. Run
    ``t`` uses seed ``seed + 1000 * t`` (plus 17 for any axis other than ``"a"``).

    Returns:
        Dict with the run configuration, a survival-fraction summary over runs and,
        unless skipped, a per-metric summary over runs.

    Raises:
        ValueError: if ``mode`` is not ``"fiber_present"`` or ``"global_Zg"``.
    """
    axis = str(axis)
    mode = str(mode)
    if mode not in ("fiber_present", "global_Zg"):
        raise ValueError("mode must be 'fiber_present' or 'global_Zg'")

    g = int(g)
    d = int(d) % g if g > 0 else 0
    # residue_orbits reduces d modulo g, so only call it for a positive modulus.
    orbits_res = residue_orbits(g, d) if g > 0 else []
    L_expected = len(orbits_res[0]) if orbits_res else 0

    skip_metrics = L_expected < 4

    key_metrics = [] if skip_metrics else [
        "step_len_cv_k",
        "spearman_rho_t_vs_angle",
        "frac_same_dir",
        "self_intersections",
        "signed_area",
        "planarity_ratio",
        "proj_step_len_rel_distortion",
    ]
    per_metric_runs: dict[str, list[float]] = {k: [] for k in key_metrics}
    survival_fracs: list[float] = []

    # Without residues there are no orbits to traverse, and the permutation
    # helpers divide cluster ids by 2*g, so no sampling is done in that case.
    n_runs = int(n_samples) if orbits_res else 0
    axis_offset = 0 if axis == "a" else 17

    for t in range(n_runs):
        run_seed = int(seed) + 1000 * t + axis_offset

        if mode == "fiber_present":
            centers_p = apply_axis_residue_perm_fiber_present(
                centers=centers, g=g, axis=axis, rng=np.random.default_rng(run_seed)
            )
        else:
            centers_p = apply_axis_residue_perm_to_centers(
                centers=centers,
                g=g,
                axis=axis,
                perm_res=make_residue_permutation(g, seed=run_seed),
            )

        geom = compute_rot_orbit_geometry(
            centers=centers_p,
            g=g,
            d=d,
            axis=axis,
            min_cover_frac=min_cover_frac,
            max_examples=0,
        )

        cnt = geom.get("counts", {}) or {}
        n_tot = float(cnt.get("n_total_orbits", 0) or 0)
        n_use = float(cnt.get("n_used_orbits", 0) or 0)
        survival_fracs.append(0.0 if n_tot <= 0 else float(n_use / n_tot))

        ms = geom.get("metrics_summary", {}) or {}
        for k in key_metrics:
            vals = []
            for hh in ("rot", "ref"):
                blk = (ms.get(hh, {}) or {}).get(k, None)
                if blk is None:
                    continue
                v = blk.get("mean", None)
                if v is not None:
                    vals.append(float(v))
            if vals:
                per_metric_runs[k].append(float(np.mean(vals)))

    return {
        "mode": mode,
        "axis": axis,
        "n_samples": int(n_samples),
        "seed": int(seed),
        "min_cover_frac": float(min_cover_frac),
        "orbit_len_expected": int(L_expected),
        "skipped_metrics": bool(skip_metrics),
        "skipped_reason": (
            f"orbit_len_expected={L_expected} < 4; null-orbit orientation metrics are not informative"
            if skip_metrics else None
        ),
        "survival_frac_summary_over_runs": (
            _q_summary(np.asarray(survival_fracs, float)) if survival_fracs else None
        ),
        "metric_means_across_orbits__summary_over_runs": (
            None if skip_metrics else {
                k: _q_summary(np.asarray(v, float)) if len(v) else None
                for k, v in per_metric_runs.items()
            }
        ),
    }


def run_rot_then_ref_report(
    embedding_weights: np.ndarray,
    neuron_ids: np.ndarray,
    *,
    G: int,
    f0: int,
    d: int,
    a_k: Optional[int] = None,
    a_model: str = "k_minus_x",
    a_phi: Optional[np.ndarray] = None,
    b_k: Optional[int] = None,
    b_model: str = "shift",
    b_phi: Optional[np.ndarray] = None,
    num_pca_dims: int = 8,
    pca_dims: int | None = None,
    pca_cumvar_tau: float = 0.90,
    pca_dims_min: int = 2,
    pca_dims_max: int | None = None,
    ball_q: float = 1.0,
    ball_center_mode: str = "miniball",
    ball_eps: float = 1e-2,
    ball_max_iter: int | None = None,
    ball_seed: int = 0,
    out_dir: str = ".",
    label: str = "cluster",
    do_bfs: bool = True,
    start_a_state: int = 0,
    start_b_state: int = 0,
    eps_rel_geom: float = 1e-6,
) -> Dict[str, Any]:
    """Build and save a JSON report on rotation/reflection structure of cluster centers.

    Pipeline (as implemented below):
      1. From ``G`` (must be even) derive ``p = G // 2`` and the residue count
         ``g = p // gcd(p, f0)``; reduce ``d`` modulo ``g``.
      2. Obtain the reflection maps ``a_phi`` / ``b_phi`` (given directly, or built
         with ``phi_from_model`` from ``*_k`` and ``*_model``) and check them with
         ``verify_phi_equivariance``.
      3. Run PCA (``compute_pca_coords``) on ``embedding_weights[neuron_ids]`` and
         keep ``k_use`` leading components (explicit ``pca_dims``, or the number
         needed to reach cumulative explained variance ``pca_cumvar_tau``).
      4. Assign each neuron a cluster id from its (a, b) states and compute cluster
         centers and radii at quantiles ``ball_q`` and 0.98.
      5. Compute edge statistics for reflection and rotation edges on both axes,
         relation statistics (involution, conjugacy, commutation), null baselines,
         rotation-orbit geometry with its permutation nulls, and coverage checks.
      6. Write the payload to ``{out_dir}/{label}_rot_then_ref_report_f{f0}.json``.

    Returns:
        The report payload (the same dict written to disk).

    Raises:
        ValueError: if ``G`` is odd, ``g`` is not positive, ``neuron_ids`` is empty,
            not 1-D or out of range for ``embedding_weights``, neither a phi array
            nor a ``*_k`` is given for an axis, or a given phi has length != ``g``.
    """
    G = int(G)
    p = int(G // 2)
    f0 = int(f0)
    d = int(d)

    if 2 * p != G:
        raise ValueError("G must be even, with p=G//2")

    # g = p / gcd(p, f0); gcd is 0 only when p == f0 == 0, which falls back to p.
    g = int(p // math.gcd(p, f0)) if math.gcd(p, f0) != 0 else int(p)
    if g <= 0:
        raise ValueError("invalid g")
    d = int(d % g)

    neuron_ids = np.asarray(neuron_ids, int)
    X = np.asarray(embedding_weights, float)
    if neuron_ids.ndim != 1 or neuron_ids.size == 0:
        raise ValueError("neuron_ids must be non-empty 1D")
    if neuron_ids.min() < 0 or neuron_ids.max() >= X.shape[0]:
        raise ValueError("neuron_ids out of range")

    # Reflection map for axis a: explicit array wins, otherwise build from (a_k, a_model).
    if a_phi is None:
        if a_k is None:
            raise ValueError("need either a_phi or (a_k,a_model)")
        a_phi = phi_from_model(g, int(a_k), a_model)
    else:
        a_phi = np.asarray(a_phi, int)
        if a_phi.size != g:
            raise ValueError("a_phi must have length g")

    # Same for axis b.
    if b_phi is None:
        if b_k is None:
            raise ValueError("need either b_phi or (b_k,b_model)")
        b_phi = phi_from_model(g, int(b_k), b_model)
    else:
        b_phi = np.asarray(b_phi, int)
        if b_phi.size != g:
            raise ValueError("b_phi must have length g")

    ver_a = verify_phi_equivariance(a_phi, g, d)
    ver_b = verify_phi_equivariance(b_phi, g, d)

    # --- PCA on the selected embedding rows ---
    X_sub = X[neuron_ids]
    pcs, pca = compute_pca_coords(X_sub, num_components=int(num_pca_dims))
    pcs = np.asarray(pcs, float)

    evr = getattr(pca, "explained_variance_ratio_", None)
    evr = (np.asarray(evr, float) if evr is not None else None)

    if pca_dims_max is None:
        pca_dims_max = int(pcs.shape[1])

    # Number of PCA dims used for all geometry (k_use):
    #   explicit pca_dims (capped), else the smallest prefix whose cumulative EVR
    #   reaches pca_cumvar_tau (at least pca_dims_min), else 4 if no EVR is available.
    if pca_dims is not None:
        k_use = int(min(int(pca_dims), pcs.shape[1], pca_dims_max))
    else:
        if evr is None or evr.size == 0:
            k_use = int(min(4, pcs.shape[1], pca_dims_max))
        else:
            cum = np.cumsum(evr)
            k0 = int(np.searchsorted(cum, float(pca_cumvar_tau), side="left") + 1)
            k0 = max(int(pca_dims_min), k0)
            k_use = int(min(k0, pcs.shape[1], pca_dims_max))

    Z = pcs[:, :k_use]
    evr_list = (evr.tolist() if evr is not None else None)

    # NOTE(review): assumes neuron id = a_idx * G + b_idx on a G x G grid — confirm with caller.
    a_idx = neuron_ids // G
    b_idx = neuron_ids % G
    a_state = _state_from_axis_index(a_idx, p=p, g=g)
    b_state = _state_from_axis_index(b_idx, p=p, g=g)
    cid = _cluster_id(a_state, b_state, g=g)

    # Quantiles for cluster radii: ball_q plus 0.98 (deduplicated, order kept).
    qs = list(dict.fromkeys([float(ball_q), 0.98]))
    geo = cluster_centers_radii_multiq(
        Z,
        cid,
        qs=qs,
        center_mode=ball_center_mode,
        ball_eps=ball_eps,
        ball_max_iter=ball_max_iter,
        ball_seed=ball_seed,
    )
    centers = geo["centers"]
    present = set(int(x) for x in geo["present"])

    ref_edges_a = build_ref_edges_from_phi_axis(g, a_phi, axis="a")
    ref_edges_b = build_ref_edges_from_phi_axis(g, b_phi, axis="b")
    rot_edges_a = build_rot_step_edges(g, d, axis="a")
    rot_edges_b = build_rot_step_edges(g, d, axis="b")

    idx_map = build_cluster_index_map(cid)

    # Global scales used to normalize point-level and center-level distances.
    scale_points = global_mean_distance(Z, max_points=3000, seed=0)
    center_mat = np.stack([centers[k] for k in sorted(centers.keys())], axis=0) if len(centers) > 1 else None
    scale_centers = global_mean_distance(center_mat, max_points=3000, seed=1) if center_mat is not None else 0.0
    norm_scale_centers = (scale_centers if scale_centers > 0 else None)

    distances_by_q: dict[str, Any] = {}
    pairwise_by_q: dict[str, Any] = {}

    for q in geo["qs"]:
        qf = float(q)
        rmap = geo["r_q"][qf]

        ref_a_stats = edge_stats_on_centers(
            centers, rmap, ref_edges_a, normalize_by=norm_scale_centers, eps_rel=eps_rel_geom
        )
        ref_b_stats = edge_stats_on_centers(
            centers, rmap, ref_edges_b, normalize_by=norm_scale_centers, eps_rel=eps_rel_geom
        )
        rot_a_stats = edge_stats_on_centers(
            centers, rmap, rot_edges_a, normalize_by=norm_scale_centers, eps_rel=eps_rel_geom
        )
        rot_b_stats = edge_stats_on_centers(
            centers, rmap, rot_edges_b, normalize_by=norm_scale_centers, eps_rel=eps_rel_geom
        )

        def _mean(st):
            return (st.get("dist", {}) or {}).get("mean", None)

        # Average the per-axis mean edge lengths, then compare reflection vs rotation.
        mean_ref = [x for x in [_mean(ref_a_stats), _mean(ref_b_stats)] if x is not None]
        mean_rot = [x for x in [_mean(rot_a_stats), _mean(rot_b_stats)] if x is not None]
        mean_ref = float(np.mean(mean_ref)) if mean_ref else None
        mean_rot = float(np.mean(mean_rot)) if mean_rot else None
        mean_check = (None if (mean_ref is None or mean_rot is None) else bool(mean_ref >= mean_rot))

        distances_by_q[str(qf)] = dict(
            q=qf,
            ref_a=ref_a_stats,
            ref_b=ref_b_stats,
            rot_a=rot_a_stats,
            rot_b=rot_b_stats,
            mean_ref_link=mean_ref,
            mean_rot_step=mean_rot,
            mean_check_ref_ge_rot=mean_check,
            frac_ref_a_margin_q_pos=ref_a_stats.get("frac_margin_q_pos"),
            frac_ref_b_margin_q_pos=ref_b_stats.get("frac_margin_q_pos"),
            min_ref_a_margin_q=ref_a_stats.get("min_margin_q"),
            min_ref_b_margin_q=ref_b_stats.get("min_margin_q"),
        )

        # Point-level (not center-level) distances between linked clusters.
        pairwise_by_q[str(qf)] = dict(
            q=qf,
            global_mean_distance=float(scale_points),
            ref_a=edge_stats_pairwise_between_clusters(Z, idx_map, ref_edges_a, normalize_by=scale_points),
            ref_b=edge_stats_pairwise_between_clusters(Z, idx_map, ref_edges_b, normalize_by=scale_points),
            rot_a=edge_stats_pairwise_between_clusters(Z, idx_map, rot_edges_a, normalize_by=scale_points),
            rot_b=edge_stats_pairwise_between_clusters(Z, idx_map, rot_edges_b, normalize_by=scale_points),
        )

    sigma_a = int(ver_a["best_sign"])
    sigma_b = int(ver_b["best_sign"])

    invol_a = relation_involution_stats(g=g, phi=a_phi, axis="a", present=present)
    invol_b = relation_involution_stats(g=g, phi=b_phi, axis="b", present=present)

    conjugacy_by_q = {}
    for q in geo["qs"]:
        qf = float(q)
        rmap = geo["r_q"][qf]
        conj_a = relation_conjugacy_stats(
            g=g,
            d=d,
            phi=a_phi,
            axis="a",
            sigma=sigma_a,
            present=present,
            centers=centers,
            rmap=rmap,
            normalize_by=norm_scale_centers,
            eps_rel=eps_rel_geom,
        )
        conj_b = relation_conjugacy_stats(
            g=g,
            d=d,
            phi=b_phi,
            axis="b",
            sigma=sigma_b,
            present=present,
            centers=centers,
            rmap=rmap,
            normalize_by=norm_scale_centers,
            eps_rel=eps_rel_geom,
        )
        conjugacy_by_q[str(qf)] = {"a": conj_a, "b": conj_b}

    n2, enc2, dec2 = _enc_dec_for_g(g)
    inv_a, _err_a = _try_invert_phi(a_phi)
    inv_b, _err_b = _try_invert_phi(b_phi)

    # Generators acting on cluster ids: R_* = one rotation step on that axis,
    # S_* = reflection toggle on that axis (identity if phi could not be inverted).
    def R_a(u: int) -> int:
        a, b = dec2(int(u))
        return enc2(_rot_step_state_1d(a, g, d, +1), b)

    def R_b(u: int) -> int:
        a, b = dec2(int(u))
        return enc2(a, _rot_step_state_1d(b, g, d, +1))

    def S_a(u: int) -> int:
        if inv_a is None:
            return int(u)
        a, b = dec2(int(u))
        return enc2(_ref_toggle_state_1d(a, g, a_phi, inv_a), b)

    def S_b(u: int) -> int:
        if inv_b is None:
            return int(u)
        a, b = dec2(int(u))
        return enc2(a, _ref_toggle_state_1d(b, g, b_phi, inv_b))

    # Commutation checks between generators of different axes, per radius quantile.
    commutation_by_q = {}
    for q in geo["qs"]:
        qf = float(q)
        rmap = geo["r_q"][qf]
        commutation_by_q[str(qf)] = {
            "R_a_R_b": relation_commutation_stats(
                name="R_a * R_b  vs  R_b * R_a",
                present=present,
                op1=R_a,
                op2=R_b,
                centers=centers,
                rmap=rmap,
                normalize_by=norm_scale_centers,
                eps_rel=eps_rel_geom,
            ),
            "S_a_S_b": relation_commutation_stats(
                name="S_a * S_b  vs  S_b * S_a",
                present=present,
                op1=S_a,
                op2=S_b,
                centers=centers,
                rmap=rmap,
                normalize_by=norm_scale_centers,
                eps_rel=eps_rel_geom,
            ),
            "S_a_R_b": relation_commutation_stats(
                name="S_a * R_b  vs  R_b * S_a",
                present=present,
                op1=S_a,
                op2=R_b,
                centers=centers,
                rmap=rmap,
                normalize_by=norm_scale_centers,
                eps_rel=eps_rel_geom,
            ),
            "R_a_S_b": relation_commutation_stats(
                name="R_a * S_b  vs  S_b * R_a",
                present=present,
                op1=R_a,
                op2=S_b,
                centers=centers,
                rmap=rmap,
                normalize_by=norm_scale_centers,
                eps_rel=eps_rel_geom,
            ),
        }

    null_equiv = {
        "a": null_phi_baseline_equivariance(g=g, d=d, n_samples=20, seed=0),
        "b": null_phi_baseline_equivariance(g=g, d=d, n_samples=20, seed=1),
    }

    # Observed rotation-orbit geometry per axis.
    traverse_obs = {
        "a": compute_rot_orbit_geometry(
            centers=centers,
            g=g,
            d=d,
            axis="a",
            min_cover_frac=0.8,
            max_examples=6,
        ),
        "b": compute_rot_orbit_geometry(
            centers=centers,
            g=g,
            d=d,
            axis="b",
            min_cover_frac=0.8,
            max_examples=6,
        ),
    }

    # The same geometry under two residue-permutation null models.
    traverse_null = {
        "fiber_present": {
            "a": traverse_orbit_null_axis_res_perm(
                centers=centers, g=g, d=d, axis="a",
                n_samples=20, seed=0, min_cover_frac=0.8, mode="fiber_present",
            ),
            "b": traverse_orbit_null_axis_res_perm(
                centers=centers, g=g, d=d, axis="b",
                n_samples=20, seed=1, min_cover_frac=0.8, mode="fiber_present",
            ),
        },
        "global_Zg": {
            "a": traverse_orbit_null_axis_res_perm(
                centers=centers, g=g, d=d, axis="a",
                n_samples=20, seed=0, min_cover_frac=0.8, mode="global_Zg",
            ),
            "b": traverse_orbit_null_axis_res_perm(
                centers=centers, g=g, d=d, axis="b",
                n_samples=20, seed=1, min_cover_frac=0.8, mode="global_Zg",
            ),
        },
    }

    cov_rot_only = bfs_coverage_rot_only(g, d, start_a_state=start_a_state, start_b_state=start_b_state)
    cov_nested = nested_coverage_2d(g, d, a_phi, b_phi)
    cov_bfs = None
    if do_bfs:
        cov_bfs = bfs_coverage_quotient(
            g, d, a_phi, b_phi, start_a_state=start_a_state, start_b_state=start_b_state
        )

    # Explained variance captured by the k_use dims actually used for geometry.
    cumvar = (float(np.sum(evr[:k_use])) if evr is not None else None)

    payload: Dict[str, Any] = dict(
        meta=dict(
            label=str(label),
            G=int(G),
            p=int(p),
            f0=int(f0),
            g=int(g),
            d=int(d),
            n_points=int(neuron_ids.size),
            ball_q=float(ball_q),
            ball_center_mode=str(ball_center_mode),
            ball_eps=(float(ball_eps) if ball_center_mode == "miniball" else None),
            ball_max_iter=(int(ball_max_iter) if (ball_max_iter is not None and ball_center_mode == "miniball") else None),
            eps_rel_geom=float(eps_rel_geom),
        ),
        pca=dict(
            num_pca_dims=int(num_pca_dims),
            metric_dims=int(k_use),
            explained_variance_ratio=evr_list,
            cumvar_metric=cumvar,
            cumvar_tau=(None if pca_dims is not None else float(pca_cumvar_tau)),
        ),
        phi=dict(
            a=dict(model=str(a_model), k_input=(None if a_k is None else int(a_k)), phi=a_phi.tolist(), verify=ver_a),
            b=dict(model=str(b_model), k_input=(None if b_k is None else int(b_k)), phi=b_phi.tolist(), verify=ver_b),
        ),
        geometry=dict(
            n_present_clusters=int(geo["n_present"]),
            qs=[float(q) for q in geo["qs"]],
            center_mode=geo.get("center_mode"),
            ball_eps=geo.get("ball_eps"),
            ball_max_iter=geo.get("ball_max_iter"),
            sizes={str(k): int(v) for k, v in geo["sizes"].items()},
            r_q_by_q={
                str(float(q)): {str(cid): float(rad) for cid, rad in geo["r_q"][float(q)].items()}
                for q in geo["qs"]
            },
        ),
        distances_by_q=distances_by_q,
        pairwise_by_q=pairwise_by_q,
        relations=dict(
            involution=dict(a=invol_a, b=invol_b),
            conjugacy_by_q=conjugacy_by_q,
            commutation_by_q=commutation_by_q,
            sigma=dict(a=int(sigma_a), b=int(sigma_b)),
            null_equivariance=null_equiv,
        ),
        nulls=dict(
            phi_equivariance=null_equiv,
            traverse_orbit=dict(
                observed=traverse_obs,
                axis_residue_perm_null=traverse_null,
            ),
        ),
        coverage=dict(
            rot_only=cov_rot_only,
            nested=cov_nested,
            bfs=(cov_bfs if cov_bfs is not None else None),
        ),
        scales=dict(
            global_mean_point_dist=float(scale_points),
            global_mean_center_dist=(float(scale_centers) if scale_centers > 0 else 0.0),
        ),
    )

    Path(out_dir).mkdir(parents=True, exist_ok=True)
    out_path = os.path.join(out_dir, f"{label}_rot_then_ref_report_f{int(f0)}.json")
    # NOTE(review): json.dump requires every helper result above to be JSON-serializable
    # (plain Python types) — not verifiable from this block.
    with open(out_path, "w") as f:
        json.dump(payload, f, indent=2)
    print(f"[{label}] report saved at '{out_path}'")
    return payload
