import base64
import io
import math
from collections import defaultdict, OrderedDict
from typing import Any, Dict, List, Tuple

import numpy as np
import plotly.express as px
import plotly.io as pio

import report


def _angle_wrap(phi):
    """Wrap angle(s) ``phi`` (radians) into the half-open interval [-pi, pi).

    Accepts a scalar or any array-like; the input is converted to float and
    the result is a NumPy float value/array.
    """
    period = 2 * np.pi
    shifted = np.asarray(phi, float) + np.pi
    return shifted % period - np.pi


def _circ_mean(phases, weights=None):
    """Weighted circular mean of ``phases`` (radians).

    Computed as ``atan2(sum w*sin, sum w*cos)``. Negative weights are clipped
    to zero; ``weights=None`` means uniform weights.

    Returns 0.0 for empty input. Falls back to ``phases[0]`` when the total
    weight is (near) zero or the weighted resultant vector vanishes.
    """
    ph = np.asarray(phases, float)
    raw_w = np.ones_like(ph) if weights is None else np.asarray(weights, float)
    w = np.clip(raw_w, 0.0, None)

    if not ph.size:
        return 0.0
    if w.sum() <= 1e-12:
        return float(ph[0])

    sin_sum = np.sum(w * np.sin(ph))
    cos_sum = np.sum(w * np.cos(ph))
    # A (near) zero resultant has no defined direction.
    vanishing = abs(sin_sum) < 1e-12 and abs(cos_sum) < 1e-12
    return float(ph[0]) if vanishing else float(np.arctan2(sin_sum, cos_sum))


def _r2_from_pred(y, yhat):
    """Score predictions ``yhat`` against targets ``y``.

    Only positions where ``y`` is finite are scored.

    Returns:
        ``(r2, sse, n)``:

        - ``r2`` is the coefficient of determination, clipped below at 0.
        - ``sse`` is the residual sum of squares over the scored points.
        - ``n`` is the number of scored points.

        ``r2`` is reported as 0.0 when fewer than three finite targets remain
        or the targets are numerically constant. ``sse`` and ``n`` are still
        the real values in those cases, so downstream BIC / adjusted-R2
        computations are not fed a fictitious SSE of 0.
    """
    y = np.asarray(y, float)
    yhat = np.asarray(yhat, float)
    finite = np.isfinite(y)
    y = y[finite]
    yhat = yhat[finite]

    resid = y - yhat
    sse = float(np.sum(resid * resid))
    n = int(y.size)

    if n < 3:
        return 0.0, sse, n

    y0 = y - y.mean()
    sst = float(np.sum(y0 * y0))
    if sst <= 1e-12:
        # Constant target: R2 is undefined, report 0 but keep the true SSE.
        return 0.0, sse, n

    r2 = max(0.0, 1.0 - sse / sst)
    return r2, sse, n


def _fit_linear_design(y, X):
    """Least-squares fit of ``y`` on the design matrix ``X``.

    Returns ``(beta, yhat)`` where ``yhat = X @ beta``.
    """
    solution = np.linalg.lstsq(X, y, rcond=None)
    coef = solution[0]
    return coef, X @ coef


def _design_SinA_plus_SinB_k(p, fa, fb):
    """Return ``(builder, 5)`` for a biased two-axis sinusoid design.

    ``builder(a, b)`` produces the columns
    ``sin(ta), cos(ta), sin(tb), cos(tb), 1`` with ``ta = 2*pi*fa*a/p`` and
    ``tb = 2*pi*fb*b/p``. The second element is the parameter count.
    """
    n_params = 5

    def build(a, b):
        phase_a = 2 * np.pi * fa * a / p
        phase_b = 2 * np.pi * fb * b / p
        return np.c_[
            np.sin(phase_a),
            np.cos(phase_a),
            np.sin(phase_b),
            np.cos(phase_b),
            np.ones_like(a),
        ]

    return build, n_params


def _design_SinAplusSinB_no_bias_dyn(p, f):
    """Return a builder for the un-biased per-axis sinusoid design at frequency ``f``.

    ``builder(a, b)`` returns the non-constant columns (std > 1e-12) among
    ``sin(ta), cos(ta), sin(tb), cos(tb)``, where ``t = 2*pi*f*x/p``.
    If all of them are constant, it returns an ``(a.size, 0)`` array.
    """

    def build(a, b):
        theta_a = 2 * np.pi * f * a / p
        theta_b = 2 * np.pi * f * b / p
        candidates = (
            np.sin(theta_a).ravel(),
            np.cos(theta_a).ravel(),
            np.sin(theta_b).ravel(),
            np.cos(theta_b).ravel(),
        )
        # Drop constant columns so the design stays free of a hidden bias.
        kept = [col for col in candidates if np.std(col) > 1e-12]
        if not kept:
            return np.zeros((a.size, 0))
        return np.column_stack(kept)

    return build


def _adjr2_bic_from_sse(sse, sst, n, k):
    """Return ``(adjusted_R2, BIC)`` for a fit with ``k`` parameters on ``n`` points.

    ``BIC = n*log(SSE/n) + k*log(n)``; SSE is floored at 1e-30 and n at 2
    inside the logs. Both values are 0.0 when there are too few points
    (``n <= k + 1``) or the total sum of squares is numerically zero.
    """
    if n <= k + 1 or sst <= 1e-30:
        return 0.0, 0.0

    resid_var = sse / max(1e-30, (n - k - 1))
    total_var = sst / max(1e-30, (n - 1))
    adjusted = 1.0 - (resid_var / total_var)

    log_lik_term = n * np.log(max(sse, 1e-30) / n)
    penalty = k * np.log(max(n, 2))
    return float(adjusted), float(log_lik_term + penalty)


def _design_SinAxis_no_bias_dyn(p, f, axis="a"):
    """Return a builder for a single-axis sinusoid design at frequency ``f``.

    ``builder(a, b)`` uses ``a`` when ``axis == "a"`` and ``b`` otherwise. It
    returns the non-constant columns among ``sin(t), cos(t)`` with
    ``t = 2*pi*f*x/p``, or an ``(a.size, 0)`` array if both are constant.
    """

    def build(a, b):
        coord = a if axis == "a" else b
        theta = 2 * np.pi * f * coord / p
        kept = [
            col
            for col in (np.sin(theta).ravel(), np.cos(theta).ravel())
            if np.std(col) > 1e-12
        ]
        return np.column_stack(kept) if kept else np.zeros((a.size, 0))

    return build


def _design_SinApmB_no_bias_dyn(p, f, sign=+1):
    """Return a builder for the ``a + sign*b`` sinusoid design at frequency ``f``.

    ``builder(a, b)`` returns the non-constant columns among ``sin(t), cos(t)``
    with ``t = 2*pi*f*(a + sign*b)/p``, or an ``(a.size, 0)`` array if both
    are constant.
    """

    def build(a, b):
        theta = 2 * np.pi * f * (a + sign * b) / p
        kept = [
            col
            for col in (np.sin(theta).ravel(), np.cos(theta).ravel())
            if np.std(col) > 1e-12
        ]
        return np.column_stack(kept) if kept else np.zeros((a.size, 0))

    return build


def _design_freq_allbases_no_bias_dyn(p, f, *, use_axes_only=True, use_pair_terms=True):
    """Return a builder stacking every un-biased sinusoid family at frequency ``f``.

    The builder stacks, in order:

    - the per-axis block (``_design_SinAplusSinB_no_bias_dyn``), when
      ``use_axes_only`` is set;
    - the ``a+b`` and ``a-b`` blocks (``_design_SinApmB_no_bias_dyn`` with
      sign +1 / -1), when ``use_pair_terms`` is set.

    Empty blocks are skipped. If no columns remain, the builder returns an
    ``(a.size, 0)`` float array.
    """

    def build(a, b):
        makers = []
        if use_axes_only:
            makers.append(_design_SinAplusSinB_no_bias_dyn(p, f))
        if use_pair_terms:
            makers.append(_design_SinApmB_no_bias_dyn(p, f, +1))
            makers.append(_design_SinApmB_no_bias_dyn(p, f, -1))

        parts = [blk for blk in (mk(a, b) for mk in makers) if blk.shape[1] > 0]
        if not parts:
            return np.zeros((a.size, 0), dtype=float)
        return np.column_stack(parts)

    return build


def _design_indicator_freq_allbases_quads(p, f, use_axes_only=True, use_pair_terms=True):
    """Build a quadrant-indicator-modulated sinusoid design on the full p x p grid.

    Base waves use ``omega = 2*pi*f/p``, sin before cos for each angle:

    - axis angles ``a`` and ``b`` (when ``use_axes_only``);
    - pair angles ``a+b`` and ``a-b`` (when ``use_pair_terms``).

    Each base wave is multiplied by ``1``, ``Ia``, ``Ib`` and ``Ia*Ib`` in
    that order. ``Ia`` (``Ib``) is +1 for indices below ``p // 2`` and -1
    otherwise. Columns with std <= 1e-12 are dropped.

    Returns an ``(p*p, k)`` array; ``k`` may be 0.
    """
    rows, cols = np.indices((p, p))
    A = rows.ravel()
    B = cols.ravel()
    half = p // 2

    sign_a = np.where(A < half, 1.0, -1.0)
    sign_b = np.where(B < half, 1.0, -1.0)
    sign_ab = sign_a * sign_b

    omega = 2.0 * np.pi * f / p

    angles = []
    if use_axes_only:
        angles += [omega * A, omega * B]
    if use_pair_terms:
        angles += [omega * (A + B), omega * (A - B)]

    waves = [trig(theta) for theta in angles for trig in (np.sin, np.cos)]

    columns = [
        col
        for wave in waves
        for col in (wave * mask for mask in (np.ones_like(wave), sign_a, sign_b, sign_ab))
        if np.std(col) > 1e-12
    ]

    if not columns:
        return np.zeros((A.size, 0), dtype=float)
    return np.column_stack(columns)


def _fit_quadrant_sine_models(
    q,
    fa_base,
    fb_base,
    f_pool,
    *,
    max_iters=6,
    min_delta_r2=0.0,
    criterion="bic",
    use_axes_only=True,
    use_pair_terms=True,
    one_family_per_freq=True,
    include_ab_resid=False,
):
    """Greedily fit sinusoidal frequency families to the square grid ``q``.

    Three phases:

    1. Stage 1: a bias-only least-squares fit (``*_stage1`` statistics).
    2. Every base frequency ``|fa_base|`` / ``|fb_base|`` is added
       unconditionally as a ``"freq-all-base"`` family built by
       ``_design_freq_allbases_no_bias_dyn``. ``None`` or non-numeric base
       values are skipped.
    3. Up to ``max_iters`` forward-selection rounds. Each round scores every
       remaining ``"freq-all"`` family from ``f_pool``. When
       ``include_ab_resid`` is set, it also scores a single ``"ab_base_axis"``
       per-axis block at ``fa_base``, which can be selected at most once. The
       best candidate is picked by ``criterion``:

       - ``"bic"``: lowest BIC, ties broken by higher R2;
       - otherwise: largest R2 gain, ties broken by lower BIC.

       The winner is kept only if its R2 gain is at least ``min_delta_r2``.

    Args:
        q: 2-D array. ``q.shape[0]`` is taken as the grid size ``p``, and the
            fitted values are reshaped to ``(p, p)``.
        fa_base, fb_base: base frequencies (sign ignored).
        f_pool: iterable of candidate frequencies (sign ignored). It is
            materialized once, so one-shot iterators are safe.
        max_iters: maximum number of forward-selection rounds.
        min_delta_r2: minimum R2 gain required to accept a candidate.
        criterion: ``"bic"`` (case-insensitive), or anything else to select
            by R2 gain.
        use_axes_only, use_pair_terms: forwarded to the family design builder.
        one_family_per_freq: if true, a frequency already added is not offered
            again.
        include_ab_resid: offer the ``"ab_base_axis"`` block.

    Returns:
        dict with:

        - stage-1 statistics;
        - ``chosen_stage2``: the chosen families in order;
        - final R2, adjusted R2 and BIC;
        - ``best``: the last chosen family, or a bias-only record;
        - the fitted values, flat (``fit_yhat``) and as a ``(p, p)`` grid
          (``fit_grid``).
    """
    p = q.shape[0]
    a_idx, b_idx = np.indices((p, p))
    a_flat = a_idx.ravel()
    b_flat = b_idx.ravel()
    y = q.reshape(-1)
    n = y.size
    # The pool is re-scanned on every selection round; materialize it so a
    # one-shot iterator is not silently exhausted after the first round.
    f_pool = list(f_pool)

    y_centered = y - y.mean()
    sst = float(np.sum(y_centered * y_centered))
    if sst <= 1e-30:
        # Numerically constant grid: nothing to explain.
        # NOTE(review): the placeholder fit is all zeros, not the constant
        # value of q; confirm callers expect that.
        return {
            "R2_stage1": 0.0,
            "adjR2_stage1": 0.0,
            "BIC_stage1": 0.0,
            "chosen_stage2": [],
            "R2_final": 0.0,
            "adjR2_final": 0.0,
            "BIC_final": 0.0,
            "best": {
                "name": "none",
                "f": int(fa_base),
                "R2": 0.0,
                "adjR2": 0.0,
                "BIC": 0.0,
                "delta_R2": 0.0,
            },
            "fit_yhat": np.zeros_like(y),
            "fit_grid": np.zeros_like(q),
        }

    # Stage 1: bias-only model.
    X_bias = np.ones((n, 1), dtype=float)
    _, yhat0 = _fit_linear_design(y, X_bias)
    R2_0, sse_0, _ = _r2_from_pred(y, yhat0)
    adj0, bic0 = _adjr2_bic_from_sse(sse_0, sst, n, X_bias.shape[1])

    X_sel = X_bias.copy()
    yhat_current = yhat0.copy()
    R2_cur = R2_0

    chosen = []
    used_f = set()
    ab_resid_used = False
    use_bic = criterion.lower() == "bic"

    def _score_with(X_all):
        """Refit with design ``X_all``; return (R2, adjR2, BIC, yhat)."""
        _, yhat_all = _fit_linear_design(y, X_all)
        sse = float(np.sum((y - yhat_all) ** 2))
        R2 = max(0.0, 1.0 - sse / sst)
        adj, bic = _adjr2_bic_from_sse(sse, sst, n, X_all.shape[1])
        return R2, adj, bic, yhat_all

    def _better(cand, incumbent):
        """True if ``cand`` beats ``incumbent`` under the selection criterion."""
        if use_bic:
            return cand["BIC"] < incumbent["BIC"] or (
                cand["BIC"] == incumbent["BIC"] and cand["R2"] > incumbent["R2"]
            )
        return cand["delta_R2"] > incumbent["delta_R2"] or (
            cand["delta_R2"] == incumbent["delta_R2"] and cand["BIC"] < incumbent["BIC"]
        )

    # Base frequencies: normalized to non-negative ints, de-duplicated, sorted.
    base_freqs = set()
    for f0 in (fa_base, fb_base):
        if f0 is None:
            continue
        try:
            base_freqs.add(int(abs(f0)))
        except (TypeError, ValueError, OverflowError):
            continue

    # Base families are added unconditionally (no criterion / threshold).
    for f0 in sorted(base_freqs):
        Xg0 = _design_freq_allbases_no_bias_dyn(
            p, f0, use_axes_only=use_axes_only, use_pair_terms=use_pair_terms
        )(a_flat, b_flat)
        if Xg0.shape[1] == 0:
            continue

        X_all0 = np.column_stack([X_sel, Xg0])
        R2_new0, adj_new0, bic_new0, yhat_new0 = _score_with(X_all0)
        delta0 = R2_new0 - R2_cur

        X_sel = X_all0
        yhat_current = yhat_new0
        R2_cur = R2_new0

        chosen.append(
            {
                "name": "freq-all-base",
                "f": f0,
                "R2": R2_new0,
                "adjR2": adj_new0,
                "BIC": bic_new0,
                "delta_R2": delta0,
            }
        )
        used_f.add(f0)

    def _iter_candidates():
        """Yield (name, f, design builder) for every still-eligible candidate."""
        for f in f_pool:
            # Normalize like the base frequencies so +f / -f (identical
            # column spans) are recognized as the same family.
            f_int = int(abs(f))
            if one_family_per_freq and (f_int in used_f):
                continue
            yield (
                "freq-all",
                f_int,
                _design_freq_allbases_no_bias_dyn(
                    p,
                    f_int,
                    use_axes_only=use_axes_only,
                    use_pair_terms=use_pair_terms,
                ),
            )

        # Offered at most once: re-adding identical columns cannot improve
        # the fit and only inflates the design.
        if include_ab_resid and not ab_resid_used:
            fa = int(fa_base)
            yield ("ab_base_axis", fa, _design_SinAplusSinB_no_bias_dyn(p, fa))

    for _ in range(int(max_iters)):
        best = None

        for name, f, build in _iter_candidates():
            Xg = build(a_flat, b_flat)
            if Xg.shape[1] == 0:
                continue

            R2_new, adj_new, bic_new, yhat_new = _score_with(np.column_stack([X_sel, Xg]))
            cand = {
                "name": name,
                "f": int(f),
                "Xg": Xg,
                "R2": R2_new,
                "adjR2": adj_new,
                "BIC": bic_new,
                "yhat": yhat_new,
                "delta_R2": R2_new - R2_cur,
            }
            if best is None or _better(cand, best):
                best = cand

        if best is None or best["delta_R2"] < float(min_delta_r2):
            break

        X_sel = np.column_stack([X_sel, best["Xg"]])
        yhat_current = best["yhat"]
        R2_cur = best["R2"]
        chosen.append(
            {key: best[key] for key in ("name", "f", "R2", "adjR2", "BIC", "delta_R2")}
        )

        if best["name"] == "ab_base_axis":
            ab_resid_used = True
        elif one_family_per_freq:
            used_f.add(best["f"])

    k_current = X_sel.shape[1]
    sse_final = float(np.sum((y - yhat_current) ** 2))
    R2_final = max(0.0, 1.0 - sse_final / sst)
    adj_final, bic_final = _adjr2_bic_from_sse(sse_final, sst, n, k_current)

    if chosen:
        best_record = chosen[-1]
    else:
        best_record = {
            "name": "bias-only",
            "f": int(fa_base),
            "R2": R2_0,
            "adjR2": adj0,
            "BIC": bic0,
            "delta_R2": 0.0,
        }

    return {
        "R2_stage1": R2_0,
        "adjR2_stage1": adj0,
        "BIC_stage1": bic0,
        "chosen_stage2": chosen,
        "R2_final": R2_final,
        "adjR2_final": adj_final,
        "BIC_final": bic_final,
        "best": best_record,
        "fit_yhat": yhat_current,
        "fit_grid": yhat_current.reshape(p, p),
    }


def fit_quadrant_sines(
    q,
    dom,
    freq_map,
    names,
    *,
    f_pool,
    max_iters=6,
    tau_inc=0.0,
    use_axes_only=True,
    use_pair_terms=True,
    include_ab_resid=False,
):
    """Fit quadrant sinusoid families to ``q`` around the dominant frequencies.

    Base frequencies come from ``freq_map``, indexed by the entries of ``dom``:

    - if ``dom["kind"] == "diag"``, both bases are ``|freq_map[dom["b_star"]]|``;
    - otherwise they are ``|freq_map[dom["a_star"]]|`` and
      ``|freq_map[dom["b_star"]]|``.

    Any base frequency missing from ``f_pool`` is prepended to it.
    ``tau_inc`` is the minimum R2 gain required to add a family.
    ``names`` is accepted for interface compatibility but is not used here.

    Returns the result dict of ``_fit_quadrant_sine_models``.
    """
    if dom["kind"] == "diag":
        fa_base = fb_base = abs(int(freq_map[dom["b_star"]]))
    else:
        fa_base = abs(int(freq_map[dom["a_star"]]))
        fb_base = abs(int(freq_map[dom["b_star"]]))

    # Materialize first: `in` on a one-shot iterator would consume it, leaving
    # only a partial (or empty) pool for the fit.
    pool = list(f_pool)
    if fa_base not in pool:
        pool = [fa_base] + pool
    if fb_base not in pool:
        pool = [fb_base] + pool

    return _fit_quadrant_sine_models(
        q,
        fa_base,
        fb_base,
        pool,
        max_iters=max_iters,
        min_delta_r2=tau_inc,
        use_axes_only=use_axes_only,
        use_pair_terms=use_pair_terms,
        include_ab_resid=include_ab_resid,
    )


def neuron_raw_fit_html_plotly(p, pre_full, fitted_full, nid, title_prefix="neuron"):
    """Render raw, fitted and difference heatmaps for one neuron as HTML.

    Column ``nid`` of ``pre_full`` and of ``fitted_full`` is reshaped to
    ``(p, p)``. Three ``px.imshow`` figures (origin lower) are produced:
    raw, fitted, and raw minus fitted.

    Returns their HTML fragments joined by ``"<hr/>\\n"``. Only the first
    fragment loads plotly.js (from the CDN).
    """
    raw = np.asarray(pre_full[:, nid]).reshape(p, p)
    fit = np.asarray(fitted_full[:, nid]).reshape(p, p)
    panels = (
        (raw, "RAW preacts"),
        (fit, "FITTED preacts"),
        (raw - fit, "RAW - FIT difference"),
    )

    fragments = []
    for position, (image, label) in enumerate(panels):
        figure = px.imshow(image, title=f"{title_prefix} {nid} - {label}", origin="lower")
        js_mode = "cdn" if position == 0 else False
        fragments.append(pio.to_html(figure, full_html=False, include_plotlyjs=js_mode))

    return "<hr/>\n".join(fragments)


def build_indicator_phase_design(G, fa, fb, use_cross=False):
    """Build the indicator-modulated axis-phase design on the full G x G grid.

    ``Ia`` (``Ib``) is +1 where the a (b) index is below ``G // 2`` and -1
    otherwise. Column order:

    - ``cos_a, sin_a, Ia*cos_a, Ia*sin_a``;
    - ``cos_b, sin_b, Ib*cos_b, Ib*sin_b``;
    - if ``use_cross``: ``Iab*cos_a, Iab*sin_a, Iab*cos_b, Iab*sin_b``;
    - a constant bias column.

    Returns ``(X, col_info)``. ``X`` is a float ``(G*G, k)`` array and
    ``col_info`` is a list of ``(axis, trig, mask)`` tuples, one per column.
    """
    rows, cols = np.indices((G, G))
    A = rows.ravel()
    B = cols.ravel()

    half = G // 2
    Ia = np.where(A < half, 1.0, -1.0)
    Ib = np.where(B < half, 1.0, -1.0)

    phase_a = (2.0 * np.pi * fa / G) * A
    phase_b = (2.0 * np.pi * fb / G) * B
    waves = {
        "a": {"cos": np.cos(phase_a), "sin": np.sin(phase_a)},
        "b": {"cos": np.cos(phase_b), "sin": np.sin(phase_b)},
    }
    masks = {"base": None, "Ia": Ia, "Ib": Ib}

    layout = [
        ("a", "cos", "base"),
        ("a", "sin", "base"),
        ("a", "cos", "Ia"),
        ("a", "sin", "Ia"),
        ("b", "cos", "base"),
        ("b", "sin", "base"),
        ("b", "cos", "Ib"),
        ("b", "sin", "Ib"),
    ]
    if use_cross:
        masks["Iab"] = Ia * Ib
        layout += [
            ("a", "cos", "Iab"),
            ("a", "sin", "Iab"),
            ("b", "cos", "Iab"),
            ("b", "sin", "Iab"),
        ]

    columns = []
    for axis, trig, mask_name in layout:
        wave = waves[axis][trig]
        mask = masks[mask_name]
        columns.append(wave if mask is None else mask * wave)
    columns.append(np.ones_like(A))

    col_info = layout + [("bias", "", "")]
    X = np.column_stack(columns).astype(float)
    return X, col_info


def fit_indicator_phase_model_fullgrid(grid, fa, fb, use_cross=False):
    """Least-squares fit of ``build_indicator_phase_design`` to a square grid.

    Returns a dict with:

    - the coefficients and column info;
    - the fitted grid and the residual grid;
    - ``R2_main`` (0.0 when the total sum of squares is <= 1e-12) and ``sst``;
    - the flattened target ``y`` and the design ``X_main``.
    """
    grid = np.asarray(grid, float)
    G = grid.shape[0]
    assert grid.shape[1] == G

    y = grid.ravel()
    X_main, col_info_main = build_indicator_phase_design(G, fa, fb, use_cross=use_cross)
    beta_main = np.linalg.lstsq(X_main, y, rcond=None)[0]
    yhat_main = X_main @ beta_main

    centered = y - y.mean()
    sst = float(np.sum(centered * centered))
    if sst > 1e-12:
        sse_main = float(np.sum((y - yhat_main) ** 2))
        R2_main = max(0.0, 1.0 - sse_main / sst)
    else:
        R2_main = 0.0

    fitted = yhat_main.reshape(G, G)
    return {
        "beta_main": beta_main,
        "col_info_main": col_info_main,
        "fit_main_grid": fitted,
        "resid_grid": grid - fitted,
        "R2_main": R2_main,
        "sst": sst,
        "y": y,
        "X_main": X_main,
    }


def fit_fullgrid_indicator_plus_residual_freqs(
    grid,
    fa_base,
    fb_base,
    f_pool,
    *,
    use_cross_main=False,
    use_axes_only=True,
    use_pair_terms=True,
    max_extra_freqs=4,
    min_delta_r2=0.0,
    criterion="bic",
):
    """Fit the indicator-phase main model, then greedily add residual frequencies.

    The main model is ``fit_indicator_phase_model_fullgrid`` at
    ``(fa_base, fb_base)``. Then, for up to ``max_extra_freqs`` rounds, every
    frequency in ``f_pool`` (sign ignored) not yet used is scored. The base
    frequencies start out as used. Each candidate adds the
    ``_design_indicator_freq_allbases_quads`` columns to the design and the
    whole design is refit jointly. The best candidate is chosen by
    ``criterion``:

    - ``"bic"``: lowest BIC, ties broken by higher R2;
    - otherwise: largest R2 gain, ties broken by lower BIC.

    It is kept only if its R2 gain is at least ``min_delta_r2``.

    ``f_pool`` may be any iterable; it is materialized and de-duplicated
    once. Each frequency's design matrix is built at most once.

    Returns:
        dict with:

        - main / final R2;
        - main and final fitted grids, and the final residual grid;
        - ``chosen_freqs``: the chosen frequencies in order;
        - final coefficients and design matrix;
        - the initial main coefficients and their column info.
    """
    grid = np.asarray(grid, float)
    G = grid.shape[0]
    assert grid.shape[1] == G
    y = grid.ravel()
    n = y.size

    main_fit = fit_indicator_phase_model_fullgrid(grid, fa_base, fb_base, use_cross=use_cross_main)
    X_main = main_fit["X_main"]
    beta_main = main_fit["beta_main"]
    col_info_main = main_fit["col_info_main"]
    sst = main_fit["sst"]
    R2_main = main_fit["R2_main"]

    X_sel = X_main.copy()
    yhat_current = X_main @ beta_main
    R2_cur = R2_main
    use_bic = criterion.lower() == "bic"

    def _score_with(X_all):
        """Refit with ``X_all``; return (R2, adjR2, BIC, yhat, beta)."""
        beta_all = np.linalg.lstsq(X_all, y, rcond=None)[0]
        yhat_all = X_all @ beta_all
        sse = float(np.sum((y - yhat_all) ** 2))
        R2 = 0.0 if sst <= 1e-12 else max(0.0, 1.0 - sse / sst)
        adjR2, bic = _adjr2_bic_from_sse(sse, sst, n, X_all.shape[1])
        return R2, adjR2, bic, yhat_all, beta_all

    def _is_better(cand, incumbent):
        """True if ``cand`` beats ``incumbent`` under the selection criterion."""
        if use_bic:
            return cand["BIC"] < incumbent["BIC"] or (
                cand["BIC"] == incumbent["BIC"] and cand["R2"] > incumbent["R2"]
            )
        return cand["delta_R2"] > incumbent["delta_R2"] or (
            cand["delta_R2"] == incumbent["delta_R2"] and cand["BIC"] < incumbent["BIC"]
        )

    used_f = {int(abs(fa_base)), int(abs(fb_base))}
    # Materialize once (the pool is re-scanned every round, which would
    # exhaust a one-shot iterator) and drop duplicates, keeping first-seen
    # order. Duplicates would only produce identical, never-winning candidates.
    pool = list(dict.fromkeys(int(abs(f)) for f in f_pool))
    design_cache = {}

    def _design_for(f_int):
        """Indicator-modulated design for ``f_int``, built once and cached."""
        Xf = design_cache.get(f_int)
        if Xf is None:
            Xf = _design_indicator_freq_allbases_quads(
                G, f_int, use_axes_only=use_axes_only, use_pair_terms=use_pair_terms
            )
            design_cache[f_int] = Xf
        return Xf

    chosen = []
    beta_final = None

    for _ in range(max_extra_freqs):
        best = None

        for f_int in pool:
            if f_int in used_f:
                continue
            Xf = _design_for(f_int)
            if Xf.shape[1] == 0:
                continue

            R2_new, adj_new, bic_new, yhat_new, beta_new = _score_with(np.column_stack([X_sel, Xf]))
            cand = {
                "f": f_int,
                "Xf": Xf,
                "R2": R2_new,
                "adjR2": adj_new,
                "BIC": bic_new,
                "yhat": yhat_new,
                "beta": beta_new,
                "delta_R2": R2_new - R2_cur,
            }
            if best is None or _is_better(cand, best):
                best = cand

        if best is None or best["delta_R2"] < float(min_delta_r2):
            break

        X_sel = np.column_stack([X_sel, best["Xf"]])
        yhat_current = best["yhat"]
        beta_final = best["beta"]
        R2_cur = best["R2"]
        chosen.append(
            {
                "f": best["f"],
                "R2": best["R2"],
                "adjR2": best["adjR2"],
                "BIC": best["BIC"],
                "delta_R2": best["delta_R2"],
                "n_cols": best["Xf"].shape[1],
            }
        )
        used_f.add(best["f"])

    if beta_final is None:
        # Nothing was added: refit the main design so beta_final is defined.
        beta_final = np.linalg.lstsq(X_sel, y, rcond=None)[0]
        yhat_current = X_sel @ beta_final

    sse_final = float(np.sum((y - yhat_current) ** 2))
    R2_final = 0.0 if sst <= 1e-12 else max(0.0, 1.0 - sse_final / sst)

    grid_hat = yhat_current.reshape(G, G)

    return {
        "R2_main": R2_main,
        "R2_final": R2_final,
        "grid_main": main_fit["fit_main_grid"],
        "grid_hat": grid_hat,
        "resid_grid": grid - grid_hat,
        "chosen_freqs": chosen,
        "beta_final": beta_final,
        "X_sel": X_sel,
        "beta_main_init": beta_main,
        "col_info_main": col_info_main,
    }


def _quad_masks(G):
    """Boolean ``(G, G)`` masks for the four quadrants.

    The key is ``"Q{a_hi}{b_hi}"``, where 0 means the index is below
    ``G // 2`` and 1 means it is at or above it.
    """
    half = G // 2
    rows, cols = np.indices((G, G))
    a_low = rows < half
    b_low = cols < half
    return {
        "Q00": a_low & b_low,
        "Q01": a_low & ~b_low,
        "Q10": ~a_low & b_low,
        "Q11": ~a_low & ~b_low,
    }


def _quad_sign_value(pair_sign_mode, quad_name):
    """Sign (+1/-1) of the indicator ``Ia``, ``Ib`` or ``Iab`` inside a quadrant.

    ``quad_name`` is one of ``"Q00"``, ``"Q01"``, ``"Q10"``, ``"Q11"``; any
    other name raises KeyError. ``pair_sign_mode`` selects the indicator; an
    unknown mode raises ValueError.
    """
    signs_by_quad = {
        "Q00": (+1, +1, +1),
        "Q01": (+1, -1, -1),
        "Q10": (-1, +1, -1),
        "Q11": (-1, -1, +1),
    }
    sign_a, sign_b, sign_ab = signs_by_quad[quad_name]

    for mode, value in (("Ia", sign_a), ("Ib", sign_b), ("Iab", sign_ab)):
        if pair_sign_mode == mode:
            return value
    raise ValueError(f"Unknown pair_sign_mode={pair_sign_mode}")


def _r2_on_mask(grid, yhat_grid, mask):
    """R2 of ``yhat_grid`` against ``grid`` restricted to the cells selected by ``mask``.

    Returns ``{"R2", "sst", "sse", "n"}``. R2 is clipped below at 0 and is
    0.0 when the selected targets are numerically constant
    (``sst <= 1e-30``).
    """
    observed = np.asarray(grid, float)[mask].ravel()
    predicted = np.asarray(yhat_grid, float)[mask].ravel()

    centered = observed - observed.mean()
    sst = float(np.sum(centered * centered))
    sse = float(np.sum((observed - predicted) ** 2))

    if sst <= 1e-30:
        r2 = 0.0
    else:
        r2 = float(max(0.0, 1.0 - sse / sst))
    return {"R2": r2, "sst": sst, "sse": sse, "n": int(observed.size)}


def _amp_phase_from_cos_sin(C, S, *, phase_convention="atan2(-S,C)"):
    """Convert cosine/sine coefficients into amplitude and phase.

    Conventions:

    - ``"atan2(-S,C)"`` (default): ``C*cos(t) + S*sin(t) = amp*cos(t + phase)``.
    - ``"atan2(S,C)"``: ``C*cos(t) + S*sin(t) = amp*cos(t - phase)``.

    Returns ``{"cos", "sin", "amp", "phase"}``, with the phase wrapped into
    [-pi, pi). Any other convention raises ValueError.
    """
    C = float(C)
    S = float(S)
    amp = float(np.hypot(C, S))
    if phase_convention == "atan2(-S,C)":
        phi = float(np.arctan2(-S, C))
    elif phase_convention == "atan2(S,C)":
        # Previously advertised by the error message but rejected.
        phi = float(np.arctan2(S, C))
    else:
        raise ValueError("phase_convention must be 'atan2(-S,C)' or 'atan2(S,C)'")
    # Wrap into [-pi, pi) using the same float64 arithmetic as _angle_wrap.
    phase = float((np.float64(phi) + np.pi) % (2 * np.pi) - np.pi)
    return {"cos": C, "sin": S, "amp": amp, "phase": phase}


def extract_quadfree_report(
    grid,
    f_list,
    *,
    use_axes=True,
    use_pair=True,
    quad_bias=True,
    mask_mode="quadfree",
    pair_sign_mode="Ia",
    pair_s=+1,
    pair_keep="both",
    pair_coupling="none",
    pair_swap_keep="both",
    phase_convention="atan2(-S,C)",
    decompose_level="family",
    compute_contrib=True,
    compute_quad_r2=True,
    compute_pair_spectrum=True,
    compute_geom_pair=True,
    return_design=False,
):
    """Fit a full-grid sinusoid design by OLS and assemble a diagnostic report.

    The design is built by ``design_fullgrid_coupled(..., return_info=True)``
    and fitted with ``ols_fit``.

    Args:
        grid: square 2-D array (asserted ``grid.shape[1] == grid.shape[0]``).
        f_list: frequencies forwarded to ``design_fullgrid_coupled``.
        use_axes, use_pair, quad_bias, mask_mode, pair_sign_mode, pair_s,
            pair_keep, pair_coupling, pair_swap_keep: forwarded to
            ``design_fullgrid_coupled``. ``pair_sign_mode`` also chooses the
            quadrant sign used for ``geom_pair``.
        phase_convention: forwarded to ``_amp_phase_from_cos_sin``.
        decompose_level: how design columns are grouped.
            ``"family"`` groups by ``(mask, family)``.
            ``"base"`` groups by ``(mask, family, base, f)``.
            ``"none"`` disables grouping.
            Anything else raises ValueError.
        compute_contrib: add per-group fitted contributions (needs grouping).
        compute_quad_r2: add per-quadrant R2 statistics.
        compute_pair_spectrum: add pair-family amplitude/phase (needs
            ``use_pair``).
        compute_geom_pair: re-label the pair spectrum as ``a+b`` / ``a-b``.
        return_design: include ``y``, ``X`` and ``info`` in the result.

    Returns:
        dict with keys:

        - ``"fit"``: ``ols_fit`` output plus ``yhat_grid`` (G x G) and the
          design settings.
        - ``"groups"``: group key -> list of column indices.
        - ``"contrib"`` (optional): group key -> G x G array
          ``X[:, js] @ beta[js]``.
        - ``"quad"`` (optional): quadrant name -> ``_r2_on_mask`` stats.
        - ``"pair"`` (optional): mask name -> f -> ``{"Ap", "Am"}``
          amplitude/phase.
        - ``"geom_pair"`` (optional): mask name ->
          ``{"sign", "by_f": f -> {"a_plus_b", "a_minus_b"}}``.
        - ``"design"`` (optional): ``{"y", "X", "info"}``.
    """
    grid = np.asarray(grid, float)
    G = grid.shape[0]
    assert grid.shape[1] == G

    # NOTE(review): `couplings` is returned by the design builder but never
    # used. The fit below is plain OLS, so no coupling penalty (see
    # lstsq_with_couplings) is applied even when pair_coupling requests
    # coupled columns. Confirm this is intended.
    y, X, info, couplings = design_fullgrid_coupled(
        grid,
        f_list,
        use_axes=use_axes,
        use_pair=use_pair,
        quad_bias=quad_bias,
        mask_mode=mask_mode,
        return_info=True,
        pair_sign_mode=pair_sign_mode,
        pair_s=pair_s,
        pair_keep=pair_keep,
        pair_coupling=pair_coupling,
        pair_swap_keep=pair_swap_keep,
    )
    out = ols_fit(y, X)
    beta = out["beta"]
    yhat_grid = out["yhat"].reshape(G, G)

    result = {
        "fit": {
            **out,
            "yhat_grid": yhat_grid,
            "G": int(G),
            "mask_mode": mask_mode,
            "pair_sign_mode": pair_sign_mode,
            "pair_s": float(pair_s),
            "pair_keep": pair_keep,
            "pair_coupling": pair_coupling,
            "pair_swap_keep": pair_swap_keep,
        }
    }

    # info entries are (family, base, f, mask_name) tuples, one per column.
    groups = defaultdict(list)

    if decompose_level == "family":
        for j, (fam, base, f, nm) in enumerate(info):
            groups[(nm, fam)].append(j)
    elif decompose_level == "base":
        for j, (fam, base, f, nm) in enumerate(info):
            # Bias columns carry f == "", which cannot be cast to int.
            groups[(nm, fam, base, int(f) if str(f) != "" else f)].append(j)
    elif decompose_level == "none":
        pass
    else:
        raise ValueError("decompose_level must be one of {'family','base','none'}")

    result["groups"] = {k: v for k, v in groups.items()}

    if compute_contrib and decompose_level != "none":
        # Fitted contribution of each column group, reshaped to the grid.
        contrib = {}
        for k, js in groups.items():
            if not js:
                continue
            part = (X[:, js] @ beta[js]).reshape(G, G)
            contrib[k] = part
        result["contrib"] = contrib

    if compute_quad_r2:
        qm = _quad_masks(G)
        quad_stats = {}
        for q, m in qm.items():
            quad_stats[q] = _r2_on_mask(grid, yhat_grid, m)
        result["quad"] = quad_stats

    if compute_pair_spectrum and use_pair:
        # pair[mask_name][f][Ap|Am][cos|sin] -> coefficient (0.0 if absent).
        pair = defaultdict(lambda: defaultdict(lambda: {"Ap": {"cos": 0.0, "sin": 0.0}, "Am": {"cos": 0.0, "sin": 0.0}}))

        for j, (fam, base, f, nm) in enumerate(info):
            if fam != "pair":
                continue

            # NOTE(review): only bases containing "dihedral_Ap"/"dihedral_Am"
            # are collected. Such base names are presumably produced by a
            # pair_coupling branch of design_fullgrid_coupled not shown in this
            # chunk; otherwise result["pair"] is always empty. Confirm.
            if "dihedral_Ap" in base:
                which = "Ap"
            elif "dihedral_Am" in base:
                which = "Am"
            else:
                continue

            trig = "cos" if base.startswith("cos") else ("sin" if base.startswith("sin") else None)
            if trig is None:
                continue

            quad = nm
            f = int(f)
            pair[quad][f][which][trig] = float(beta[j])

        pair_out = {}
        for quad, byf in pair.items():
            pair_out[quad] = {}
            for f, dd in byf.items():
                Ap = _amp_phase_from_cos_sin(dd["Ap"]["cos"], dd["Ap"]["sin"], phase_convention=phase_convention)
                Am = _amp_phase_from_cos_sin(dd["Am"]["cos"], dd["Am"]["sin"], phase_convention=phase_convention)
                pair_out[quad][int(f)] = {"Ap": Ap, "Am": Am}

        result["pair"] = pair_out

    if compute_geom_pair and ("pair" in result):
        # Map Ap/Am onto a+b / a-b using the quadrant's indicator sign:
        # positive sign keeps Ap as a+b, negative swaps them.
        # _quad_sign_value raises KeyError for mask names other than
        # "Q00".."Q11".
        geom = {}
        for quad, byf in result["pair"].items():
            sgn = _quad_sign_value(pair_sign_mode, quad)
            geom[quad] = {"sign": int(sgn), "by_f": {}}
            for f, item in byf.items():
                Ap, Am = item["Ap"], item["Am"]
                if sgn > 0:
                    plus = Ap
                    minus = Am
                else:
                    plus = Am
                    minus = Ap
                geom[quad]["by_f"][int(f)] = {"a_plus_b": plus, "a_minus_b": minus}
        result["geom_pair"] = geom

    if return_design:
        result["design"] = {"y": y, "X": X, "info": info}

    return result


def lstsq_with_couplings(y, X, couplings, lam):
    """Least squares with optional soft coupling penalties between coefficients.

    Each coupling ``(j1, j2, sign)`` adds the penalty
    ``lam * (beta[j1] - sign * beta[j2])**2``. It is implemented by appending
    one row to the design and a zero to the target.

    With no couplings, or ``lam <= 0``, this is plain ``np.linalg.lstsq``.
    Returns the coefficient vector.
    """
    target = np.asarray(y, float).ravel()
    design = np.asarray(X, float)

    if couplings is None or len(couplings) == 0 or lam <= 0:
        return np.linalg.lstsq(design, target, rcond=None)[0]

    scale = np.sqrt(float(lam))
    penalty_rows = np.zeros((len(couplings), design.shape[1]), float)
    for row, (j1, j2, sign) in zip(penalty_rows, couplings):
        row[j1] = scale
        row[j2] = -scale * float(sign)

    stacked_X = np.vstack([design, penalty_rows])
    stacked_y = np.concatenate([target, np.zeros((len(couplings),), float)])
    return np.linalg.lstsq(stacked_X, stacked_y, rcond=None)[0]


def ols_fit(y, X):
    """Ordinary least squares of ``y`` on ``X`` with basic fit statistics.

    Returns a dict with ``beta``, ``yhat``, ``R2``, ``BIC``, ``sse``, ``sst``,
    ``n`` and ``k``. R2 is clipped below at 0 and is 0.0 when
    ``sst <= 1e-30``. ``BIC = n*log(SSE/n) + k*log(n)``, with SSE floored at
    1e-30 and n at 2 inside the logs.
    """
    target = np.asarray(y, float).ravel()
    design = np.asarray(X, float)

    beta = np.linalg.lstsq(design, target, rcond=None)[0]
    yhat = design @ beta

    deviations = target - target.mean()
    sst = float(np.sum(deviations * deviations))
    sse = float(np.sum((target - yhat) ** 2))
    if sst <= 1e-30:
        r2 = 0.0
    else:
        r2 = max(0.0, 1.0 - sse / sst)

    n = target.size
    k = design.shape[1]
    bic = n * np.log(max(sse, 1e-30) / n) + k * np.log(max(n, 2))
    return {
        "beta": beta,
        "yhat": yhat,
        "R2": r2,
        "BIC": float(bic),
        "sse": sse,
        "sst": sst,
        "n": n,
        "k": k,
    }


def _build_info_index(info, families=("pair",)):
    """Map ``(family, base, int(f), mask)`` to the column index for selected families.

    ``info`` is a sequence of ``(family, base, f, mask)`` tuples, one per
    design column (e.g. the column info built by ``design_fullgrid_coupled``).

    Entries are skipped when:

    - their family is not in ``families``; or
    - their frequency cannot be converted to ``int`` (such as the ``""`` used
      for bias columns).

    Other errors are no longer silently swallowed.
    """
    index = {}
    for j, (fam, base, ff, mask) in enumerate(info):
        if fam not in families:
            continue
        try:
            f_int = int(ff)
        except (TypeError, ValueError, OverflowError):
            # Not a frequency-bearing column.
            continue
        index[(fam, base, f_int, mask)] = j
    return index


def estimate_phi_off_from_quadfree(beta, info, f_list, eps=1e-12, mod_pi=True):
    """Estimate a per-frequency phase offset from off-diagonal pair coefficients.

    For each frequency ``f`` (cast with ``int``), each quadrant/base pair is
    turned into the phasor ``z = C - 1j*S``. The ``C``/``S`` values are the
    fitted coefficients of the ``"pair"`` columns ``cos(base)`` /
    ``sin(base)``. Two products are formed:

    - ``z(Q10, a-b) * z(Q01, a+b)``;
    - ``z(Q10, a+b) * z(Q01, a-b)``.

    Each product's angle is weighted by ``|z01|^2 * |z10|^2``. Products whose
    weight is ``<= eps``, or whose columns are missing, are ignored. The
    angles are combined by a weighted circular mean. With ``mod_pi`` the
    angles are doubled first and the mean is halved, so the result is only
    defined modulo pi.

    Returns ``{f: phi}`` with phi wrapped into [-pi, pi), or 0.0 when no
    usable product exists.
    """
    coef = np.asarray(beta, float).ravel()
    index = _build_info_index(info)

    def _phasor(quad, base, freq):
        j_cos = index.get(("pair", f"cos({base})", freq, quad))
        j_sin = index.get(("pair", f"sin({base})", freq, quad))
        if j_cos is None or j_sin is None:
            return None
        return coef[j_cos] - 1j * coef[j_sin]

    result = {}
    for freq in (int(v) for v in f_list):
        combos = (
            (_phasor("Q01", "a+b", freq), _phasor("Q10", "a-b", freq)),
            (_phasor("Q01", "a-b", freq), _phasor("Q10", "a+b", freq)),
        )

        angles = []
        weights = []
        for z01, z10 in combos:
            if z01 is None or z10 is None:
                continue
            weight = float((abs(z01) ** 2) * (abs(z10) ** 2))
            if weight > eps:
                angles.append(float(np.angle(z10 * z01)))
                weights.append(weight)

        if not angles:
            result[freq] = 0.0
            continue

        ang = np.asarray(angles, float)
        wts = np.asarray(weights, float)
        # Doubling the angles makes theta and theta + pi equivalent.
        multiplier = 2 if mod_pi else 1
        s_sum = float(np.sum(wts * np.sin(multiplier * ang)))
        c_sum = float(np.sum(wts * np.cos(multiplier * ang)))
        estimate = float(np.arctan2(s_sum, c_sum)) / multiplier

        result[freq] = float((estimate + np.pi) % (2 * np.pi) - np.pi)

    return result


def design_fullgrid_coupled(
    grid,
    f_list,
    *,
    use_axes=True,
    use_pair=True,
    quad_bias=True,
    mask_mode="quadfree",
    pair_coupling="none",
    phi_off_map=None,
    std_eps=1e-12,
    return_info=False,
    pair_sign_mode="Ia",
    pair_s=+1,
    pair_keep="both",
    pair_swap_keep="both",
):
    grid = np.asarray(grid, float)
    G = grid.shape[0]
    assert grid.shape[1] == G
    n = G // 2

    a_idx, b_idx = np.indices((G, G))
    A = a_idx.ravel()
    B = b_idx.ravel()
    y = grid.ravel()

    A0 = (A % n).astype(float)
    B0 = (B % n).astype(float)

    Ia = np.where(A < n, 1.0, -1.0)
    Ib = np.where(B < n, 1.0, -1.0)
    Iab = Ia * Ib

    Q00 = ((A < n) & (B < n)).astype(float)
    Q01 = ((A < n) & (B >= n)).astype(float)
    Q10 = ((A >= n) & (B < n)).astype(float)
    Q11 = ((A >= n) & (B >= n)).astype(float)

    if mask_mode == "quadfree":
        masks_bias = [Q00, Q01, Q10, Q11]
        mask_names_bias = ["Q00", "Q01", "Q10", "Q11"]
        masks_axis_a = masks_axis_b = masks_bias
        mask_names_axis_a = mask_names_axis_b = mask_names_bias
        masks_pair = masks_bias
        mask_names_pair = mask_names_bias

    elif mask_mode == "hadamard4":
        ones = np.ones_like(Ia)
        masks_bias = [ones, Ia, Ib, Iab]
        mask_names_bias = ["1", "Ia", "Ib", "Iab"]
        masks_axis_a = masks_axis_b = masks_bias
        mask_names_axis_a = mask_names_axis_b = mask_names_bias
        masks_pair = masks_bias
        mask_names_pair = mask_names_bias

    elif mask_mode == "family":
        ones = np.ones_like(Ia)
        masks_bias = [ones] if not quad_bias else [ones, Ia, Ib, Iab]
        mask_names_bias = ["1"] if not quad_bias else ["1", "Ia", "Ib", "Iab"]
        masks_axis_a = [ones, Ia]
        masks_axis_b = [ones, Ib]
        mask_names_axis_a = ["1", "Ia"]
        mask_names_axis_b = ["1", "Ib"]
        masks_pair = [ones, Iab]
        mask_names_pair = ["1", "Iab"]

    else:
        raise ValueError(f"Unknown mask_mode={mask_mode}")

    X_parts = []
    info = []

    if quad_bias:
        for m, nm in zip(masks_bias, mask_names_bias):
            col = np.asarray(m, float).ravel()
            if nm == "1":
                X_parts.append(col[:, None])
                if return_info:
                    info.append(("bias", "", "", nm))
                continue
            if np.std(col) > std_eps:
                X_parts.append(col[:, None])
                if return_info:
                    info.append(("bias", "", "", nm))
    else:
        col = np.ones((G * G,), float)
        X_parts.append(col[:, None])
        if return_info:
            info.append(("bias", "", "", "1"))

    def add_base(col, family, base_name, f, masks, mask_names):
        col = np.asarray(col, float).ravel()
        if np.std(col) <= std_eps:
            return
        for m, nm in zip(masks, mask_names):
            c = col * m
            X_parts.append(c[:, None])
            if return_info:
                info.append((family, base_name, int(f), nm))

    def add_col(col, base_name, f, nm):
        col = np.asarray(col, float).ravel()
        if np.std(col) <= std_eps:
            return None
        X_parts.append(col[:, None])
        j = len(X_parts) - 1
        if return_info:
            info.append(("pair", base_name, int(f), nm))
        return j

    couplings = []

    for f in f_list:
        f = int(abs(f))
        omega = 2.0 * np.pi * f / float(n)

        if use_axes:
            Sa = np.sin(omega * A0)
            Ca = np.cos(omega * A0)
            Sb = np.sin(omega * B0)
            Cb = np.cos(omega * B0)

            add_base(Sa, "axis-a", "sin", f, masks_axis_a, mask_names_axis_a)
            add_base(Ca, "axis-a", "cos", f, masks_axis_a, mask_names_axis_a)
            add_base(Sb, "axis-b", "sin", f, masks_axis_b, mask_names_axis_b)
            add_base(Cb, "axis-b", "cos", f, masks_axis_b, mask_names_axis_b)

        if use_pair:
            Dp = (A0 + B0) % n
            Dm = (A0 - B0) % n

            Sp = np.sin(omega * Dp)
            Cp = np.cos(omega * Dp)
            Sm = np.sin(omega * Dm)
            Cm = np.cos(omega * Dm)

            if mask_mode != "family":
                add_base(Sp, "pair", "sin(a+b)", f, masks_pair, mask_names_pair)
                add_base(Cp, "pair", "cos(a+b)", f, masks_pair, mask_names_pair)
                add_base(Sm, "pair", "sin(a-b)", f, masks_pair, mask_names_pair)
                add_base(Cm, "pair", "cos(a-b)", f, masks_pair, mask_names_pair)

            else:
                if pair_coupling == "none":
                    add_base(Sp, "pair", "sin(a+b)", f, masks_pair, mask_names_pair)
                    add_base(Cp, "pair", "cos(a+b)", f, masks_pair, mask_names_pair)
                    add_base(Sm, "pair", "sin(a-b)", f, masks_pair, mask_names_pair)
                    add_base(Cm, "pair", "cos(a-b)", f, masks_pair, mask_names_pair)

                elif pair_coupling == "swap_by_quads_soft":
                    j00_Cp = add_col(Q00 * Cp, "cos(Q00: a+b)", f, "diag_Q00")
                    j00_Sp = add_col(Q00 * Sp, "sin(Q00: a+b)", f, "diag_Q00")
                    j00_Cm = add_col(Q00 * Cm, "cos(Q00: a-b)", f, "diag_Q00")
                    j00_Sm = add_col(Q00 * Sm, "sin(Q00: a-b)", f, "diag_Q00")

                    j11_Cm = add_col(Q11 * Cm, "cos(Q11: a-b)", f, "diag_Q11")
                    j11_Sm = add_col(Q11 * Sm, "sin(Q11: a-b)", f, "diag_Q11")
                    j11_Cp = add_col(Q11 * Cp, "cos(Q11: a+b)", f, "diag_Q11")
                    j11_Sp = add_col(Q11 * Sp, "sin(Q11: a+b)", f, "diag_Q11")

                    if None not in (j00_Cp, j11_Cm):
                        couplings.append((j00_Cp, j11_Cm, +1.0))
                    if None not in (j00_Sp, j11_Sm):
                        couplings.append((j00_Sp, j11_Sm, -1.0))
                    if None not in (j00_Cm, j11_Cp):
                        couplings.append((j00_Cm, j11_Cp, +1.0))
                    if None not in (j00_Sm, j11_Sp):
                        couplings.append((j00_Sm, j11_Sp, -1.0))

                    j01_Cp = add_col(Q01 * Cp, "cos(Q01: a+b)", f, "off_Q01")
                    j01_Sp = add_col(Q01 * Sp, "sin(Q01: a+b)", f, "off_Q01")
                    j01_Cm = add_col(Q01 * Cm, "cos(Q01: a-b)", f, "off_Q01")
                    j01_Sm = add_col(Q01 * Sm, "sin(Q01: a-b)", f, "off_Q01")

                    j10_Cm = add_col(Q10 * Cm, "cos(Q10: a-b)", f, "off_Q10")
                    j10_Sm = add_col(Q10 * Sm, "sin(Q10: a-b)", f, "off_Q10")
                    j10_Cp = add_col(Q10 * Cp, "cos(Q10: a+b)", f, "off_Q10")
                    j10_Sp = add_col(Q10 * Sp, "sin(Q10: a+b)", f, "off_Q10")

                    if None not in (j01_Cp, j10_Cm):
                        couplings.append((j01_Cp, j10_Cm, +1.0))
                    if None not in (j01_Sp, j10_Sm):
                        couplings.append((j01_Sp, j10_Sm, -1.0))
                    if None not in (j01_Cm, j10_Cp):
                        couplings.append((j01_Cm, j10_Cp, +1.0))
                    if None not in (j01_Sm, j10_Sp):
                        couplings.append((j01_Sm, j10_Sp, -1.0))

                elif pair_coupling == "swap_by_quads_hard_shift":
                    phi = 0.0
                    if phi_off_map is not None:
                        phi = float(phi_off_map.get(f, 0.0))
                    c = np.cos(phi)
                    s = np.sin(phi)

                    Cm_shift = Cm * c - Sm * s
                    Sm_shift = Sm * c + Cm * s

                    add_col(Q00 * Cp + Q11 * Cm, "cos(diag_tied_plus)", f, "diag(Q00<->Q11)")
                    add_col(Q00 * Sp - Q11 * Sm, "sin(diag_tied_plus)", f, "diag(Q00<->Q11)")

                    add_col(Q00 * Cm + Q11 * Cp, "cos(diag_tied_minus)", f, "diag(Q00<->Q11)")
                    add_col(Q00 * Sm - Q11 * Sp, "sin(diag_tied_minus)", f, "diag(Q00<->Q11)")

                    add_col(
                        Q01 * Cp + Q10 * Cm_shift,
                        "cos(off_tied_plus_shift)",
                        f,
                        "off(Q01<->Q10,phi)",
                    )
                    add_col(
                        Q01 * Sp - Q10 * Sm_shift,
                        "sin(off_tied_plus_shift)",
                        f,
                        "off(Q01<->Q10,phi)",
                    )

                    add_col(
                        Q01 * Cm_shift + Q10 * Cp,
                        "cos(off_tied_minus_shift)",
                        f,
                        "off(Q01<->Q10,phi)",
                    )
                    add_col(
                        Q01 * Sm_shift - Q10 * Sp,
                        "sin(off_tied_minus_shift)",
                        f,
                        "off(Q01<->Q10,phi)",
                    )

                else:
                    raise ValueError(f"Unknown pair_coupling={pair_coupling} for family mode")

    X = np.concatenate(X_parts, axis=1) if X_parts else np.zeros((G * G, 0), float)
    if X.shape[1] > 0:
        col_norm = np.linalg.norm(X, axis=0)
        keep = col_norm > (std_eps * np.sqrt(X.shape[0]))
        X = X[:, keep]

        if return_info:
            info = [t for t, k in zip(info, keep) if k]

        if couplings:
            old_to_new = -np.ones(len(keep), dtype=int)
            old_to_new[np.where(keep)[0]] = np.arange(int(np.sum(keep)))

            new_couplings = []
            for j1, j2, sign in couplings:
                nj1 = old_to_new[int(j1)]
                nj2 = old_to_new[int(j2)]
                if (nj1 >= 0) and (nj2 >= 0):
                    new_couplings.append((int(nj1), int(nj2), float(sign)))
            couplings = new_couplings

    if return_info:
        assert X.shape[1] == len(info), (X.shape, len(info))
        return y, X, info, couplings
    return y, X


def fit_with_beta(y, X, beta):
    """Score a fixed coefficient vector against a design matrix.

    No fitting is performed: ``beta`` is used as given. The prediction is
    ``X @ beta``.

    Returns:
        dict with keys ``beta``, ``yhat``, ``R2`` (clamped at 0; 0.0 when the
        target has no variance), ``BIC`` (``n*log(sse/n) + k*log(n)``, with
        ``sse`` floored at 1e-30 and ``n`` floored at 2 inside the penalty
        log), ``sse``, ``sst``, ``n`` (number of targets) and ``k`` (number of
        design columns).
    """
    target = np.asarray(y, float).ravel()
    design = np.asarray(X, float)
    coef = np.asarray(beta, float).ravel()

    yhat = design @ coef
    residual = target - yhat
    sse = float(np.sum(residual * residual))

    centered = target - target.mean()
    sst = float(np.sum(centered * centered))
    if sst <= 1e-30:
        r2 = 0.0
    else:
        r2 = max(0.0, 1.0 - sse / sst)

    n_obs = target.size
    n_params = design.shape[1]
    bic = n_obs * np.log(max(sse, 1e-30) / n_obs) + n_params * np.log(max(n_obs, 2))

    return {
        "beta": coef,
        "yhat": yhat,
        "R2": r2,
        "BIC": float(bic),
        "sse": sse,
        "sst": sst,
        "n": n_obs,
        "k": n_params,
    }


def fit_fullgrid_once_autoshift(
    grid,
    f_list,
    *,
    use_axes=True,
    use_pair=True,
    quad_bias=True,
):
    """Two-stage full-grid fit with an automatically estimated off-diagonal phase shift.

    Stage 1 ("probe") builds a ``mask_mode="quadfree"`` design with
    ``pair_coupling="none"``, fits it by OLS and passes the coefficients to
    ``estimate_phi_off_from_quadfree`` to obtain ``phi_off_map``.

    Stage 2 builds the ``mask_mode="family"`` design with
    ``pair_coupling="swap_by_quads_hard_shift"`` and that ``phi_off_map``, and
    fits it by OLS.

    Args:
        grid: the activation grid to fit (passed straight to
            ``design_fullgrid_coupled``).
        f_list: frequencies to include.
        use_axes, use_pair, quad_bias: forwarded to both design builds.

    Returns:
        ``(out, X, info)``: the stage-2 OLS result dict, its design matrix and
        per-column info tuples. ``out`` is augmented with:

        - ``phi_off_map``: the estimated shifts.
        - ``R2_quadfree_probe`` / ``BIC_quadfree_probe``: stage-1 scores.
        - ``n_couplings`` / ``coupling_rmse``: size of and RMS violation of
          the soft coupling list returned by the stage-2 design build. These
          keys are read by ``build_family_fitted_full``.
    """
    # Stage 1: quadrant-free probe fit, used only to estimate phase shifts.
    y_qf, X_qf, info_qf, _ = design_fullgrid_coupled(
        grid,
        f_list,
        use_axes=use_axes,
        use_pair=use_pair,
        quad_bias=quad_bias,
        mask_mode="quadfree",
        pair_coupling="none",
        return_info=True,
    )
    out_qf = ols_fit(y_qf, X_qf)

    phi_off_map = estimate_phi_off_from_quadfree(out_qf["beta"], info_qf, f_list)

    # Stage 2: family design with hard (tied-column) swap coupling and the estimated shift.
    y, X, info, couplings = design_fullgrid_coupled(
        grid,
        f_list,
        use_axes=use_axes,
        use_pair=use_pair,
        quad_bias=quad_bias,
        mask_mode="family",
        pair_coupling="swap_by_quads_hard_shift",
        phi_off_map=phi_off_map,
        return_info=True,
    )
    out = ols_fit(y, X)

    out["phi_off_map"] = phi_off_map
    out["R2_quadfree_probe"] = out_qf["R2"]
    out["BIC_quadfree_probe"] = out_qf["BIC"]
    # Previously these were never set, so downstream meta always recorded None.
    # The hard-shift design ties columns by construction and, as written, emits
    # no soft couplings, so these are normally 0 / 0.0.
    couplings = couplings or []
    out["n_couplings"] = int(len(couplings))
    out["coupling_rmse"] = coupling_rmse(out["beta"], couplings)
    return out, X, info


def coupling_rmse(beta, couplings):
    """Root-mean-square violation of linear coupling constraints on ``beta``.

    Each coupling ``(j1, j2, sign)`` asks for ``beta[j1] == sign * beta[j2]``;
    the residual of each is ``beta[j1] - sign * beta[j2]``. Returns 0.0 when
    ``couplings`` is ``None`` or empty.
    """
    if couplings is None or len(couplings) == 0:
        return 0.0
    coef = np.asarray(beta, float).ravel()
    gaps = np.asarray(
        [coef[first] - float(sgn) * coef[second] for first, second, sgn in couplings],
        float,
    )
    return float(np.sqrt(np.mean(gaps * gaps)))


def build_family_fitted_full(
    p: int,
    pre_full: np.ndarray,
    pre_grid_alive: np.ndarray,
    alive_ids: list,
    f_list,
    *,
    use_axes=True,
    use_pair=True,
    quad_bias=True,
    pair_coupling="swap_by_quads_soft",
    lam=0.0,
):
    """Replace each alive neuron's column of ``pre_full`` with its family fit.

    For local index ``j`` / global id ``nid``, the grid
    ``pre_grid_alive[:, :, j]`` is fitted with ``fit_fullgrid_once_autoshift``.
    The fitted ``p x p`` grid is flattened into column ``nid`` of a float copy
    of ``pre_full``. Columns of non-alive neurons are left as copied.

    Note: ``pair_coupling`` and ``lam`` are not forwarded to the fit (which
    always uses the hard-shift coupling); they are only recorded in the
    per-neuron metadata.

    Returns:
        ``(fitted_full, meta)`` where ``meta[nid]`` holds R2, BIC, the recorded
        ``pair_coupling``/``lam``, and ``coupling_rmse``/``n_couplings`` as
        reported by the fit (``None`` if absent).
    """
    fitted = np.asarray(pre_full, float).copy()
    per_neuron = {}

    for col, neuron_id in enumerate(alive_ids):
        neuron_grid = np.asarray(pre_grid_alive[:, :, col], float)

        result, design, _info = fit_fullgrid_once_autoshift(
            neuron_grid,
            f_list,
            use_axes=use_axes,
            use_pair=use_pair,
            quad_bias=quad_bias,
        )

        prediction = (design @ result["beta"]).reshape(p, p)
        fitted[:, neuron_id] = prediction.reshape(-1)

        per_neuron[neuron_id] = {
            "R2": result["R2"],
            "BIC": result["BIC"],
            "pair_coupling": pair_coupling,
            "lam": float(lam),
            "coupling_rmse": result.get("coupling_rmse", None),
            "n_couplings": result.get("n_couplings", None),
        }

    return fitted, per_neuron


def _sinefit_preact_layer_full(
    *,
    p: int,
    pre_full: np.ndarray,
    pre_grid_alive: np.ndarray,
    alive_ids: List[int],
    artifacts: Dict[str, Any],
    base_layer_artifacts: Dict[str, Any],
    freq_map: Dict[str, int],
    base_f_pool=None,
    use_axes_only: bool = True,
    use_pair_terms: bool = True,
    mode: str = "indicator_full",
) -> np.ndarray:
    """Fit sinusoidal models to every alive neuron's pre-activation grid.

    The frequency pool is ``base_f_pool`` if given. Otherwise it is derived
    from ``base_layer_artifacts`` via
    ``report.build_f_pool_from_layer1_artifacts``. A ``ValueError`` is raised
    when neither is provided.

    Modes:
        ``"indicator_full"``: delegates to ``build_family_fitted_full`` over
            the whole grid.
        anything else: classifies the neuron's dominant component from the
            ``F_full`` entries whose key index 2 equals the neuron's local
            index. It then fits each of the four ``p//2`` quadrants separately
            with ``fit_quadrant_sines``. A quadrant whose fit fails or returns
            nothing keeps its original values. Neurons with no ``F_full``
            entries, or whose classification fails, keep their original
            column.

    Side effect: (over)writes ``layerX_neuron_fits.html`` in the current
    working directory. It contains raw-vs-fit plots for the first 25 and last
    10 alive neurons, each neuron at most once.

    Returns:
        Copy of ``pre_full`` with alive-neuron columns replaced by their fits.

    Raises:
        ValueError: missing frequency-pool source, or
            ``pre_grid_alive.shape[-1] != len(alive_ids)``.
    """
    names = artifacts["names"]
    F_full = artifacts.get("F_full", {}) or {}

    if base_f_pool is not None:
        f_pool_L1 = list(base_f_pool)
    elif base_layer_artifacts is not None:
        f_pool_L1 = report.build_f_pool_from_layer1_artifacts(base_layer_artifacts, freq_map, p)
    else:
        raise ValueError("Need base_f_pool or base_layer_artifacts for layer>=1 residual fitting.")

    p_quad = p // 2
    fitted_full = np.asarray(pre_full).copy()

    N_alive = pre_grid_alive.shape[-1]
    if N_alive != len(alive_ids):
        raise ValueError(f"_sinefit_preact_layer_full: N_alive={N_alive} != len(alive_ids)={len(alive_ids)}")

    if mode == "indicator_full":
        # Per-neuron fit metadata is not needed by this caller.
        fitted_full, _family_meta = build_family_fitted_full(
            p=p,
            pre_full=pre_full,
            pre_grid_alive=pre_grid_alive,
            alive_ids=alive_ids,
            f_list=f_pool_L1,
            use_axes=use_axes_only,
            use_pair=use_pair_terms,
            quad_bias=True,
        )
    else:
        for j, nid in enumerate(alive_ids):
            grid = pre_grid_alive[:, :, j]

            try:
                # F_full entries for this neuron are keyed by its *local* index j.
                Fhat_n = {k: v for k, v in F_full.items() if k[2] == j}
                if not Fhat_n:
                    continue
                dom = report._classify_by_gft(Fhat_n, names, freq_map, strict=True)
            except Exception as e:
                print(f"[sinefit] neuron local={j}, global={nid}: classify dom failed: {e}; skip.")
                continue

            quads = [
                grid[:p_quad, :p_quad],
                grid[:p_quad, p_quad:],
                grid[p_quad:, :p_quad],
                grid[p_quad:, p_quad:],
            ]

            fit_quads = []
            for qi, q in enumerate(quads):
                try:
                    fit = fit_quadrant_sines(
                        q,
                        dom,
                        freq_map,
                        names,
                        f_pool=f_pool_L1,
                        max_iters=6,
                        tau_inc=0.0,
                        use_axes_only=use_axes_only,
                        use_pair_terms=use_pair_terms,
                        include_ab_resid=False,
                    )
                    fit_q = fit.get("fit_grid", None)
                    if fit_q is None:
                        yhat = fit.get("fit_yhat", None)
                        if yhat is not None:
                            fit_q = np.asarray(yhat).reshape(q.shape)
                    if fit_q is None:
                        fit_q = q
                except Exception as e:
                    # Best-effort: a failed quadrant keeps its raw values.
                    print(f"[sinefit] neuron local={j}, global={nid}, quad={qi}: fit failed: {e}; use original.")
                    fit_q = q
                fit_quads.append(np.asarray(fit_q, dtype=float))

            fit_grid = np.zeros_like(grid, dtype=float)
            fit_grid[:p_quad, :p_quad] = fit_quads[0]
            fit_grid[:p_quad, p_quad:] = fit_quads[1]
            fit_grid[p_quad:, :p_quad] = fit_quads[2]
            fit_grid[p_quad:, p_quad:] = fit_quads[3]

            fitted_full[:, nid] = fit_grid.reshape(-1)

    # First 25 and last 10 neurons; the two slices overlap when there are fewer
    # than 35 alive neurons, so deduplicate (order-preserving) to avoid plotting
    # the same neuron twice.
    plot_ids = list(dict.fromkeys(list(alive_ids[:25]) + list(alive_ids[-10:])))
    snips = [
        neuron_raw_fit_html_plotly(p, pre_full, fitted_full, nid, title_prefix="layerX")
        for nid in plot_ids
    ]

    full_html = """<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>LayerX neuron fits</title>
<script src="https://cdn.plot.ly/plotly-2.35.2.min.js"></script>
</head>
<body>
"""
    full_html += "\n<hr/>\n".join(snips)
    full_html += "\n</body>\n</html>\n"

    with open("layerX_neuron_fits.html", "w", encoding="utf-8") as f:
        f.write(full_html)

    return fitted_full


def _build_info_index_any(info, families=None):
    d = {}
    for j, (fam, base, ff, mask) in enumerate(info):
        if (families is not None) and (fam not in families):
            continue
        try:
            f_int = int(ff)
        except Exception:
            continue
        d[(fam, base, f_int, mask)] = j
    return d


def angle_wrap(phi):
    """Wrap angle(s) ``phi`` (radians) into [-pi, pi), up to float rounding.

    Works elementwise on numpy arrays and on scalars. Python lists are not
    converted.
    """
    period = 2 * np.pi
    shifted = phi + np.pi
    return shifted % period - np.pi


def circ_mean_R(phases, weights=None, eps=1e-12):
    """Weighted circular mean and mean resultant length of ``phases``.

    Negative weights are clipped to 0. Returns a dict with ``mean`` (wrapped
    to [-pi, pi)), ``R`` (``|sum w e^{i phi}| / sum w``, in [0, 1]), ``w_sum``
    and ``n``.

    Degenerate cases:
        - Empty input gives mean 0.0 and R 0.0.
        - Total weight <= ``eps`` gives the first phase as mean and R 0.0.
    """
    ang = np.asarray(phases, float)
    if ang.size == 0:
        return {"mean": 0.0, "R": 0.0, "w_sum": 0.0, "n": 0}
    count = int(ang.size)

    wts = np.ones_like(ang) if weights is None else np.asarray(weights, float)
    wts = np.clip(wts, 0.0, None)
    total = float(wts.sum())
    if total <= eps:
        return {"mean": float(ang[0]), "R": 0.0, "w_sum": total, "n": count}

    resultant = np.sum(wts * (np.cos(ang) + 1j * np.sin(ang)))
    raw_mean = float(np.angle(resultant))
    # Same wrap as angle_wrap(), inlined.
    wrapped = (raw_mean + np.pi) % (2 * np.pi) - np.pi
    return {
        "mean": float(wrapped),
        "R": float(np.abs(resultant) / total),
        "w_sum": total,
        "n": count,
    }


def extract_axis_phase_family(beta, info, f_list, *, phase_convention="atan2(-S,C)", eps=1e-12):
    """Convert axis-family cos/sin coefficients into amplitude/phase summaries.

    For each axis (``axis-a`` with indicator mask ``"Ia"``, ``axis-b`` with
    ``"Ib"``) and each frequency ``f`` in ``f_list``, reads the coefficients on
    the ``"1"`` mask and on the indicator mask. Missing columns count as 0.0.

    Output per axis:
        - ``basis``: amplitude/phase of each mask's own (cos, sin) pair.
        - ``half``: amplitude/phase of the summed (``+``) and differenced
          (``-``) coefficients, i.e. ``(C1 +/- CI, S1 +/- SI)``. These are
          presumably the effective coefficients on the two halves where the
          indicator is +1 / -1 (TODO confirm the indicator encoding).

    Amplitude/phase conversion is delegated to ``_amp_phase_from_cos_sin`` with
    ``phase_convention``. ``eps`` is accepted but unused.
    """
    beta = np.asarray(beta, float).ravel()
    lookup = _build_info_index_any(info, families=("axis-a", "axis-b"))

    def coef(fam, base, f, nm):
        col = lookup.get((fam, base, int(f), nm))
        return 0.0 if col is None else float(beta[col])

    def amp_phase(c, s):
        return _amp_phase_from_cos_sin(c, s, phase_convention=phase_convention)

    out = {
        "axis_a": {"basis": {"1": {}, "Ia": {}}, "half": {"A+": {}, "A-": {}}},
        "axis_b": {"basis": {"1": {}, "Ib": {}}, "half": {"B+": {}, "B-": {}}},
        "phase_convention": phase_convention,
    }

    # (output key, info family, indicator mask name, plus-half key, minus-half key)
    layout = (
        ("axis_a", "axis-a", "Ia", "A+", "A-"),
        ("axis_b", "axis-b", "Ib", "B+", "B-"),
    )

    for f in map(int, f_list):
        for key, fam, ind, plus, minus in layout:
            c1 = coef(fam, "cos", f, "1")
            s1 = coef(fam, "sin", f, "1")
            ci = coef(fam, "cos", f, ind)
            si = coef(fam, "sin", f, ind)

            slot = out[key]
            slot["basis"]["1"][f] = amp_phase(c1, s1)
            slot["basis"][ind][f] = amp_phase(ci, si)
            slot["half"][plus][f] = amp_phase(c1 + ci, s1 + si)
            slot["half"][minus][f] = amp_phase(c1 - ci, s1 - si)

    return out


def round_half_away_from_zero(x: float) -> int:
    """Round to the nearest integer, sending exact halves away from zero.

    A 1e-12 nudge makes values a hair below .5 (float noise) round outward too.
    """
    nudge = 1e-12
    if x < 0:
        return int(math.ceil(x - 0.5 - nudge))
    return int(math.floor(x + 0.5 + nudge))


def canon_mod_p(mg: float, p: int) -> float:
    """Reduce ``mg`` modulo ``p`` into [0, p) (floor-mod, as ``np.mod``)."""
    value = float(mg)
    modulus = float(p)
    return float(np.mod(value, modulus))


def phase_to_grid_units(phi, G):
    """Convert a phase in radians to steps on a cycle of length ``G``."""
    steps_per_radian = G / (2 * np.pi)
    return float(phi) * steps_per_radian


def R_to_sigma_grid(R, G, eps=1e-12):
    """Map a mean resultant length to a circular standard deviation.

    Uses ``sigma = sqrt(-2 ln R)`` with ``R`` clipped to [eps, 1].

    Returns:
        ``(sigma_phi, sigma_grid)``: the spread in radians and the same spread
        expressed in steps of a length-``G`` cycle.
    """
    r = float(np.clip(R, eps, 1.0))
    sigma_phi = float(np.sqrt(-2.0 * np.log(r)))
    return sigma_phi, sigma_phi * (G / (2 * np.pi))


def stat_to_grid_units(stat, p, f, *, round_to_1dp=True, boundary_risk=0.49):
    """Annotate a circular statistic with grid-unit and generator estimates.

    Returns ``None`` for ``None`` input. Otherwise it returns a shallow copy
    of ``stat``; the input is not mutated. When the copy has a ``mean``
    (radians), it adds:

    - ``mean_grid``: the mean in steps of a length-``p`` cycle.
    - ``h_gcd`` = gcd(p, f) and ``g`` = p // h_gcd. If ``h_gcd`` is 0, ``g`` is
      ``None`` and nothing further is added.
    - ``mean_grid_canon_mod_p``: ``mean_grid`` reduced mod ``p``.
    - ``generator``: ``round(mean_grid_canon_mod_p / h_gcd) mod g``, with
      ``generator_raw``, ``generator_x``, ``generator_margin_to_int`` and
      ``generator_boundary_risk`` (margin >= ``boundary_risk``) as
      diagnostics. The quotient is first rounded to one decimal when
      ``round_to_1dp`` is set.
    - ``sigma_phi`` / ``sigma_grid`` when an ``R`` is present.
    """
    if stat is None:
        return None
    res = dict(stat)
    p, f = int(p), int(f)
    if "mean" not in res:
        return res

    res["mean_grid"] = phase_to_grid_units(res["mean"], p)

    h = math.gcd(p, f)
    if h == 0:
        res["h_gcd"] = 0
        res["g"] = None
        return res

    g = p // h
    res["h_gcd"] = int(h)
    res["g"] = int(g)

    mg_canon = canon_mod_p(float(res["mean_grid"]), p)
    x = mg_canon / float(h)
    if round_to_1dp:
        x = float(np.round(x, 1))

    k0 = round_half_away_from_zero(x)
    k_hat = int(k0 % g)
    margin = abs(x - k0)

    res["mean_grid_canon_mod_p"] = float(mg_canon)
    res["generator"] = int(k_hat)
    res["generator_raw"] = int(k0)
    res["generator_x"] = float(x)
    res["generator_margin_to_int"] = float(margin)
    res["generator_boundary_risk"] = bool(margin >= float(boundary_risk))

    if res.get("R", None) is not None:
        sigma_phi, sigma_grid = R_to_sigma_grid(res["R"], p)
        res["sigma_phi"] = sigma_phi
        res["sigma_grid"] = sigma_grid

    return res


def phase_to_shift_steps(phi, f, G):
    """Convert a phase at frequency ``f`` into a shift on a cycle of length ``G // 2``.

    The sign is flipped: a phase of ``phi`` maps to ``-phi * n / (2*pi*f)``
    with ``n = G // 2``.
    """
    half = G // 2
    denom = 2 * np.pi * float(f)
    return float(-phi) * (half / denom)


def _get_phase_amp(axis_phase, axis, half, f):
    byf = axis_phase.get(axis, {}).get("half", {}).get(half, {})
    item = byf.get(int(f), None)
    if item is None:
        return None
    if ("phase" not in item) or ("amp" not in item):
        return None
    return float(item["phase"]), float(item["amp"])


def _neuron_across_f_stat(axis_phase, axis, half, f_pool, amp_eps=1e-6, topk=None):
    """Amplitude-weighted circular mean of one axis half's phase across frequencies.

    Frequencies without an entry, or whose amplitude is <= ``amp_eps``, are
    skipped. When ``topk`` is given and more frequencies survive, only the
    ``topk`` largest-amplitude ones are kept.

    Returns:
        dict with ``n_freq_used``, ``freqs_used``, ``mean_phase``, ``R`` and
        ``w_sum``.
    """
    triples = []
    for f in f_pool:
        got = _get_phase_amp(axis_phase, axis, half, f)
        if got is None:
            continue
        phase, amp = got
        if amp <= amp_eps:
            continue
        triples.append((int(f), phase, amp))

    if topk is not None:
        limit = int(topk)
        if len(triples) > limit:
            triples = sorted(triples, key=lambda t: t[2], reverse=True)[:limit]

    stat = circ_mean_R([t[1] for t in triples], [t[2] for t in triples])

    return {
        "n_freq_used": int(len(triples)),
        "freqs_used": [t[0] for t in triples],
        "mean_phase": stat["mean"],
        "R": stat["R"],
        "w_sum": stat["w_sum"],
    }


def _diff_wrap(phi1, phi2):
    return float(angle_wrap(phi2 - phi1))


def _sum_wrap(phi1, phi2):
    return float(angle_wrap(phi1 + phi2))


def _neuron_relations_at_f0(axis_phase, f0, amp_eps=1e-6):
    """Pairwise phase relations between the four axis halves at frequency ``f0``.

    A half is usable when it exists and its amplitude exceeds ``amp_eps``.
    For each usable pair the result stores a difference (``_diff_wrap``, i.e.
    second phase minus first) or a sum (``_sum_wrap``), plus a weight equal to
    the smaller amplitude. Keys:

    - Within-axis: ``dA``/``sA``/``wA`` (A+ vs A-) and ``dB``/``sB``/``wB``.
    - Cross-axis differences: ``dApBp``, ``dAmBm``, ``dApBm``, ``dAmBp``, each
      with ``w_<name>``.
    - Cross-axis sums: ``sApBp``, ``sAmBm``, ``sApBm``, ``sAmBp``, each with
      ``w_<name>``.
    """
    halves = {
        "Ap": _get_phase_amp(axis_phase, "axis_a", "A+", f0),
        "Am": _get_phase_amp(axis_phase, "axis_a", "A-", f0),
        "Bp": _get_phase_amp(axis_phase, "axis_b", "B+", f0),
        "Bm": _get_phase_amp(axis_phase, "axis_b", "B-", f0),
    }

    def usable(item):
        return (item is not None) and (item[1] > amp_eps)

    rel = {}

    for tag, plus_key, minus_key in (("A", "Ap", "Am"), ("B", "Bp", "Bm")):
        plus, minus = halves[plus_key], halves[minus_key]
        if usable(plus) and usable(minus):
            rel["d" + tag] = _diff_wrap(plus[0], minus[0])
            rel["s" + tag] = _sum_wrap(plus[0], minus[0])
            rel["w" + tag] = float(min(plus[1], minus[1]))

    cross_pairs = (
        ("ApBp", "Ap", "Bp"),
        ("AmBm", "Am", "Bm"),
        ("ApBm", "Ap", "Bm"),
        ("AmBp", "Am", "Bp"),
    )
    # All differences first, then all sums (matches the documented key order).
    for prefix, combine in (("d", _diff_wrap), ("s", _sum_wrap)):
        for suffix, left, right in cross_pairs:
            x, y = halves[left], halves[right]
            if usable(x) and usable(y):
                name = prefix + suffix
                rel[name] = combine(x[0], y[0])
                rel["w_" + name] = float(min(x[1], y[1]))

    return rel


def _cluster_circ_stat_from_neurons(values, weights):
    """Weighted circular mean of ``values``: keys ``mean``, ``R``, ``n``, ``w_sum``."""
    summary = circ_mean_R(values, weights)
    return {key: summary[key] for key in ("mean", "R", "n", "w_sum")}


def _collect_phases_for_pair(per_neuron_axis_phase, neuron_ids, *, f0, axis_a_half, axis_b_half, amp_eps=1e-6):
    """Gather paired (axis-a half, axis-b half) phases at ``f0`` across neurons.

    A neuron is skipped if it has no record, either half is missing, or either
    amplitude is <= ``amp_eps``.

    Returns:
        ``(phia, phib, weights, used_ids)``. The first three are float arrays;
        each weight is the smaller of the two amplitudes. ``used_ids`` are the
        int ids that contributed.
    """
    a_phases, b_phases, weights, kept_ids = [], [], [], []
    for nid in neuron_ids:
        key = int(nid)
        record = per_neuron_axis_phase.get(key, None)
        if record is None:
            continue
        got_a = _get_phase_amp(record, "axis_a", axis_a_half, f0)
        got_b = _get_phase_amp(record, "axis_b", axis_b_half, f0)
        if got_a is None or got_b is None:
            continue
        (pa, amp_a), (pb, amp_b) = got_a, got_b
        if (amp_a <= amp_eps) or (amp_b <= amp_eps):
            continue
        a_phases.append(pa)
        b_phases.append(pb)
        weights.append(min(amp_a, amp_b))
        kept_ids.append(key)
    return (
        np.asarray(a_phases, float),
        np.asarray(b_phases, float),
        np.asarray(weights, float),
        kept_ids,
    )


def _assign_to_nearest_center(r, centers):
    """Index of the circularly nearest center for each value of ``r``.

    Distance is ``|angle_wrap(r - center)|``. With no centers, every point
    gets ``-1``.
    """
    values = np.asarray(r, float).ravel()
    if len(centers) == 0:
        return -np.ones_like(values, dtype=int)

    c = np.asarray(centers, float).ravel()
    dist = np.abs(angle_wrap(values[:, None] - c[None, :]))
    return np.argmin(dist, axis=1).astype(int)


def _decorate_mod_pi(out_pi, slope_eps=1e-3):
    o = dict(out_pi)
    if int(o.get("n_pairs_used", 0)) <= 0:
        o["dir_u"] = [1.0, 0.0]
        o["slope_tan"] = 0.0
        o["slope_valid"] = False
        return o

    th = float(o["mean_theta"])
    ux = float(np.cos(th))
    uy = float(np.sin(th))
    o["dir_u"] = [ux, uy]

    if abs(ux) < float(slope_eps):
        o["slope_tan"] = float(np.sign(uy) * np.inf)
        o["slope_valid"] = False
    else:
        o["slope_tan"] = float(uy / ux)
        o["slope_valid"] = True
    return o


def collinearity_from_pairwise_diffs(
    phia,
    phib,
    w=None,
    *,
    use_mod_pi=True,
    min_norm=1e-3,
    max_pairs=20000,
    rng=None,
):
    """Estimate the dominant direction of pairwise phase differences in the (a, b) plane.

    For each point pair ``(i, j)``, the wrapped difference vector
    ``(da, db)`` has direction ``atan2(db, da)``. Pairs with ``hypot(da, db) <
    min_norm`` are ignored. Each remaining direction is weighted by
    ``min(w_i, w_j) * norm``, and the weighted circular mean and resultant
    length are returned. With ``use_mod_pi`` the directions are treated as
    axial (doubled before averaging, halved after), so theta and theta + pi
    agree.

    All pairs are enumerated when there are at most ``max_pairs`` of them.
    Otherwise ``max_pairs`` random distinct pairs are drawn, with replacement,
    from ``rng``, which defaults to ``default_rng(0)``.

    Returns:
        dict with ``R``, ``mean_theta``, ``n_pairs_used`` and ``use_mod_pi``.
        Returns zeros when fewer than two points or no usable pairs exist.
    """
    a = np.asarray(phia, float).ravel()
    b = np.asarray(phib, float).ravel()
    assert a.shape == b.shape
    n = int(a.size)

    if w is None:
        wt = np.ones_like(a, dtype=float)
    else:
        wt = np.asarray(w, float).ravel()
        assert wt.shape == a.shape
    wt = np.clip(wt, 0.0, None)

    if rng is None:
        rng = np.random.default_rng(0)

    def _no_pairs():
        return {"R": 0.0, "mean_theta": 0.0, "n_pairs_used": 0, "use_mod_pi": bool(use_mod_pi)}

    if n <= 1:
        return _no_pairs()

    if n * (n - 1) // 2 <= int(max_pairs):
        pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    else:
        pairs = []
        for _ in range(int(max_pairs)):
            # Draw j from n-1 slots and skip over i so that i != j.
            i = int(rng.integers(0, n))
            j = int(rng.integers(0, n - 1))
            if j >= i:
                j += 1
            pairs.append((i, j) if i < j else (j, i))

    angles = []
    pair_weights = []
    norm_floor = float(min_norm)
    for i, j in pairs:
        da = angle_wrap(a[i] - a[j])
        db = angle_wrap(b[i] - b[j])
        norm = float(np.hypot(da, db))
        if norm < norm_floor:
            continue
        angles.append(float(np.arctan2(db, da)))
        pair_weights.append(float(min(wt[i], wt[j]) * norm))

    if len(angles) == 0:
        return _no_pairs()

    angles = np.asarray(angles, float)
    pair_weights = np.asarray(pair_weights, float)

    if use_mod_pi:
        stat = circ_mean_R(2.0 * angles, pair_weights)
        mean_theta = float(angle_wrap(stat["mean"] / 2.0))
    else:
        stat = circ_mean_R(angles, pair_weights)
        mean_theta = float(stat["mean"])
    R = float(stat["R"])

    return {
        "R": float(R),
        "mean_theta": float(mean_theta),
        "n_pairs_used": int(len(angles)),
        "use_mod_pi": bool(use_mod_pi),
    }


def _collinearity_suite(
    phia,
    phib,
    ww,
    *,
    max_pairs=20000,
    min_norm=1e-3,
    compute_offsets=True,
    nbins=72,
    smooth=2,
    peak_prom=0.02,
    compute_within_peak=True,
    w_point_eps=0.0,
    slope_eps=1e-3,
):
    """Collinearity of (phase_a, phase_b) points, globally and within offset bands.

    Steps:
        1. Fit a global mod-pi direction ``theta`` to all points with
           ``collinearity_from_pairwise_diffs``, decorated via
           ``_decorate_mod_pi``.
        2. If ``compute_offsets`` (and the global fit used >= 1 pair, with at
           least 2 points), project each point onto the normal of ``theta``:
           ``r = wrap(sin(theta)*a - cos(theta)*b)``. Then build a smoothed
           circular weighted histogram of ``r`` over points with weight
           > ``w_point_eps`` and take its local maxima holding at least
           ``peak_prom`` of the total mass as offset peaks.
        3. If ``compute_within_peak``, assign points to the nearest peak and
           refit a direction inside each peak. Peaks with fewer than 2 points
           get a placeholder fit. A weighted mod-pi average of the
           non-degenerate per-peak directions is reported as
           ``within_peaks_summary``.

    Returns:
        ``{"mod_pi": global_fit}`` plus ``"offsets"`` when step 2 runs.
    """

    def _circular_hist_peaks(r_vals, w_vals):
        """Local maxima (centers, masses) of the smoothed circular histogram of r_vals, heaviest first."""
        edges = np.linspace(-np.pi, np.pi, int(nbins) + 1)
        hist, _ = np.histogram(r_vals, bins=edges, weights=w_vals)
        h = hist.astype(float)

        for _ in range(int(smooth)):
            # Circular 3-tap box filter: the first and last bins are neighbours.
            h = (np.roll(h, 1) + h + np.roll(h, -1)) / 3.0

        total = float(h.sum()) + 1e-12
        B = int(nbins)
        centers, masses = [], []
        for k in range(B):
            # h[k - 1] wraps to the last bin for k == 0.
            is_local_max = (h[k] > h[k - 1]) and (h[k] > h[(k + 1) % B])
            if is_local_max and (h[k] / total) >= float(peak_prom):
                centers.append(float(angle_wrap(0.5 * (edges[k] + edges[k + 1]))))
                masses.append(float(h[k]))

        if len(centers) > 0:
            order = np.argsort(-np.asarray(masses))
            centers = [centers[i] for i in order]
            masses = [masses[i] for i in order]
        return centers, masses

    phia_ = np.asarray(phia, float).ravel()
    phib_ = np.asarray(phib, float).ravel()
    assert phia_.shape == phib_.shape
    n = int(phia_.size)

    if ww is None:
        w_ = np.ones_like(phia_, dtype=float)
    else:
        w_ = np.asarray(ww, float).ravel()
        assert w_.shape == phia_.shape
        w_ = np.clip(w_, 0.0, None)

    out_pi = collinearity_from_pairwise_diffs(
        phia_, phib_, w_, use_mod_pi=True, max_pairs=max_pairs, min_norm=min_norm
    )
    out_pi = _decorate_mod_pi(out_pi, slope_eps=slope_eps)

    result = {"mod_pi": out_pi}

    if not (compute_offsets and (int(out_pi.get("n_pairs_used", 0)) > 0) and (n >= 2)):
        return result

    theta = float(out_pi["mean_theta"])
    # Signed offset of each point along the normal to the direction (cos theta, sin theta).
    r = angle_wrap(np.sin(theta) * phia_ - np.cos(theta) * phib_)

    use_mask = w_ > float(w_point_eps)
    r_use = r[use_mask]
    w_use = w_[use_mask]
    n_points_used = int(np.sum(use_mask))

    centers, masses = _circular_hist_peaks(r_use, w_use)

    offsets = {
        "theta_used": float(theta),
        "centers": centers,
        "masses": masses,
        "n_peaks": int(len(centers)),
        "n_points_used": int(n_points_used),
        "nbins": int(nbins),
        "smooth": int(smooth),
        "peak_prom": float(peak_prom),
        "w_point_eps": float(w_point_eps),
    }

    if compute_within_peak and (len(centers) > 0) and (n_points_used >= 2):
        idx_all = _assign_to_nearest_center(r_use, centers)
        phia_use = phia_[use_mask]
        phib_use = phib_[use_mask]

        per_peak = []
        agg_num_cos = 0.0
        agg_num_sin = 0.0
        agg_den = 0.0
        agg_pairs = 0

        for k in range(len(centers)):
            m = idx_all == int(k)
            nk = int(np.sum(m))
            if nk < 2:
                per_peak.append(
                    {
                        "peak_id": int(k),
                        "center_r": float(centers[k]),
                        "n_points": nk,
                        "w_sum": float(np.sum(w_use[m])) if nk > 0 else 0.0,
                        "mod_pi": {
                            "R": 0.0,
                            "mean_theta": 0.0,
                            "n_pairs_used": 0,
                            "use_mod_pi": True,
                            "dir_u": [1.0, 0.0],
                            "slope_tan": 0.0,
                            "slope_valid": False,
                        },
                    }
                )
                continue

            w_k = w_use[m]
            out_k = collinearity_from_pairwise_diffs(
                phia_use[m], phib_use[m], w_k, use_mod_pi=True, max_pairs=max_pairs, min_norm=min_norm
            )
            out_k = _decorate_mod_pi(out_k, slope_eps=slope_eps)

            n_pairs_k = int(out_k.get("n_pairs_used", 0))
            if n_pairs_k > 0:
                # A peak with no usable pair reports a placeholder mean_theta=0.0;
                # letting it vote would bias the summary toward theta=0 and inflate R.
                peak_w = float(np.sum(w_k)) * float(np.sqrt(n_pairs_k))
                th = float(out_k.get("mean_theta", 0.0))
                # Axial (mod-pi) average: accumulate doubled angles.
                agg_num_cos += peak_w * np.cos(2.0 * th)
                agg_num_sin += peak_w * np.sin(2.0 * th)
                agg_den += peak_w
                agg_pairs += n_pairs_k

            per_peak.append(
                {
                    "peak_id": int(k),
                    "center_r": float(centers[k]),
                    "n_points": nk,
                    "w_sum": float(np.sum(w_k)),
                    "mod_pi": out_k,
                }
            )

        offsets["per_peak"] = per_peak

        if agg_den > 1e-12:
            mean2 = float(np.arctan2(agg_num_sin, agg_num_cos))
            mean_theta = float(angle_wrap(mean2 / 2.0))
            R = float(np.hypot(agg_num_cos, agg_num_sin) / agg_den)
            offsets["within_peaks_summary"] = _decorate_mod_pi(
                # n_pairs_used counts the within-peak pairs that actually fed the summary.
                {"R": R, "mean_theta": mean_theta, "n_pairs_used": int(agg_pairs), "use_mod_pi": True},
                slope_eps=slope_eps,
            )
        else:
            offsets["within_peaks_summary"] = _decorate_mod_pi(
                {"R": 0.0, "mean_theta": 0.0, "n_pairs_used": 0, "use_mod_pi": True},
                slope_eps=slope_eps,
            )

    result["offsets"] = offsets
    return result


def analyze_irrep_cluster_axis_phase(
    per_neuron_axis_phase,
    neuron_ids,
    *,
    f0,
    f_pool,
    G,
    amp_eps=1e-6,
    topk_within=None,
    compute_spectrum=True,
):
    """Summarise half-axis phase structure for one cluster of neurons.

    Parameters
    ----------
    per_neuron_axis_phase : dict
        Mapping ``int(neuron_id) -> record``. Records are only read through
        ``_get_phase_amp``, ``_neuron_across_f_stat``,
        ``_neuron_relations_at_f0`` and ``_collect_phases_for_pair``; ids with
        no record are skipped.
    neuron_ids : iterable
        Ids of the neurons in the cluster. This is traversed many times, so a
        one-shot iterator (generator, ``map`` object, ...) is materialised
        into a list first. Re-iterable containers are used unchanged.
    f0 : int
        Frequency at which across-neuron statistics, per-neuron relations and
        the pairwise-difference collinearity suite are computed.
    f_pool : iterable of int
        Frequencies used for the within-neuron (across-f) statistics and for
        the optional spectrum.
    G : int
        ``p = G // 2`` is passed to ``stat_to_grid_units``.
    amp_eps : float
        Amplitude threshold; phases at or below it are not used in the
        across-neuron / spectrum / plus-minus difference statistics built here
        (also forwarded to the helpers).
    topk_within : int or None
        Forwarded as ``topk`` to ``_neuron_across_f_stat``.
    compute_spectrum : bool
        When False, ``"spectrum_across_neuron"`` is None.

    Returns
    -------
    dict
        Keys ``"meta"``, ``"within_neuron"``, ``"relations_per_neuron_at_f0"``,
        ``"grid_units"``, ``"collinearity_pairwise_diffs_at_f0"`` and
        ``"spectrum_across_neuron"``.
    """
    from collections.abc import Iterator

    f0 = int(f0)
    f_pool = [int(f) for f in f_pool]
    p = int(G // 2)

    # neuron_ids is walked many times below. A generator would be exhausted by
    # the first pass and silently leave every later statistic empty.
    # Containers are left as-is so helpers receive exactly what the caller gave.
    if isinstance(neuron_ids, Iterator):
        neuron_ids = list(neuron_ids)

    # (label, axis, half) for the four half-axes; this order fixes the key
    # order of every per-half dict in the result.
    halves = (
        ("Aplus", "axis_a", "A+"),
        ("Aminus", "axis_a", "A-"),
        ("Bplus", "axis_b", "B+"),
        ("Bminus", "axis_b", "B-"),
    )

    def records():
        """Yield ``(int id, record)`` for each listed neuron that has a record."""
        for nid in neuron_ids:
            ap = per_neuron_axis_phase.get(int(nid), None)
            if ap is not None:
                yield int(nid), ap

    def plus_minus_diffs(ap, axis, plus, minus):
        """Circular stat of wrapped (plus - minus) phase differences over f_pool.

        Each frequency contributes only when both halves exist and both
        amplitudes exceed amp_eps. It is weighted by the smaller amplitude.
        """
        phases, weights = [], []
        for f in f_pool:
            hp = _get_phase_amp(ap, axis, plus, f)
            hm = _get_phase_amp(ap, axis, minus, f)
            # Written as "both > amp_eps" (not "<= -> skip") so NaN amplitudes
            # are excluded, matching the original condition.
            if hp is not None and hm is not None and hp[1] > amp_eps and hm[1] > amp_eps:
                phases.append(_diff_wrap(hp[0], hm[0]))
                weights.append(min(hp[1], hm[1]))
        if not phases:
            return {"mean": 0.0, "R": 0.0, "n": 0, "w_sum": 0.0}
        return _cluster_circ_stat_from_neurons(phases, weights)

    def across_neurons(axis, half, f):
        """Amplitude-weighted circular stat of one half-axis phase at ``f``.

        Returns ``(stat, ids_used)``; neurons whose amplitude is <= amp_eps
        or whose phase is missing are left out.
        """
        phases, weights, used = [], [], []
        for nid, ap in records():
            got = _get_phase_amp(ap, axis, half, f)
            if got is None:
                continue
            phi, amp = got
            if amp <= amp_eps:
                continue
            phases.append(phi)
            weights.append(amp)
            used.append(nid)
        return _cluster_circ_stat_from_neurons(phases, weights), used

    # --- within-neuron statistics across the frequency pool ---------------
    within = {}
    for nid, ap in records():
        entry = {
            f"{label}_across_f": _neuron_across_f_stat(
                ap, axis, half, f_pool, amp_eps=amp_eps, topk=topk_within
            )
            for label, axis, half in halves
        }
        entry["dA_across_f"] = plus_minus_diffs(ap, "axis_a", "A+", "A-")
        entry["dB_across_f"] = plus_minus_diffs(ap, "axis_b", "B+", "B-")
        within[nid] = entry

    # --- across-neuron statistics at f0 ------------------------------------
    across_f0 = {}
    for label, axis, half in halves:
        stat, used = across_neurons(axis, half, f0)
        stat["neuron_ids_used"] = used
        across_f0[label] = stat

    # --- per-neuron relations at f0 and their across-neuron aggregates -----
    rel_per_neuron = {
        nid: _neuron_relations_at_f0(ap, f0, amp_eps=amp_eps) for nid, ap in records()
    }

    def across_rel(key, wkey):
        """Weighted circular stat of relation ``key`` over neurons that report it."""
        phases, weights = [], []
        for nid in neuron_ids:
            d = rel_per_neuron.get(int(nid), {})
            if key in d:
                phases.append(float(d[key]))
                weights.append(float(d.get(wkey, 1.0)))
        return _cluster_circ_stat_from_neurons(phases, weights)

    # relation key -> weight key reported by _neuron_relations_at_f0
    relation_weight_keys = {
        "dA": "wA",
        "sA": "wA",
        "dB": "wB",
        "sB": "wB",
        "dApBp": "w_dApBp",
        "dAmBm": "w_dAmBm",
        "dApBm": "w_dApBm",
        "dAmBp": "w_dAmBp",
        "sApBp": "w_sApBp",
        "sAmBm": "w_sAmBm",
        "sApBm": "w_sApBm",
        "sAmBp": "w_sAmBp",
    }
    rel_across = {key: across_rel(key, wkey) for key, wkey in relation_weight_keys.items()}

    # --- pairwise-difference collinearity between A and B halves at f0 ------
    collinearity = {}
    half_pairs = {
        "ApBp": ("A+", "B+"),
        "AmBm": ("A-", "B-"),
        "ApBm": ("A+", "B-"),
        "AmBp": ("A-", "B+"),
    }
    for name, (ha, hb) in half_pairs.items():
        phia, phib, ww, used = _collect_phases_for_pair(
            per_neuron_axis_phase, neuron_ids, f0=f0, axis_a_half=ha, axis_b_half=hb, amp_eps=amp_eps
        )
        if phia.size < 3:
            # Too few neurons for a pairwise-difference suite.
            collinearity[name] = {"n_neurons_used": int(phia.size), "neuron_ids_used": used, "mod_pi": None, "mod_2pi": None}
            continue

        suite = _collinearity_suite(phia, phib, ww, max_pairs=20000, min_norm=1e-3)
        suite["n_neurons_used"] = int(phia.size)
        suite["neuron_ids_used"] = used
        collinearity[name] = suite

    # --- optional across-neuron spectrum over the frequency pool -----------
    spectrum = None
    if compute_spectrum:
        spectrum = {label: {} for label, _, _ in halves}
        for f in f_pool:
            for label, axis, half in halves:
                spectrum[label][f], _ = across_neurons(axis, half, f)

    # --- conversion to grid units ------------------------------------------
    rel_across_grid = {key: stat_to_grid_units(stat, p, f0) for key, stat in rel_across.items()}
    across_f0_grid = {k: stat_to_grid_units(v, p, f0) for k, v in across_f0.items()}

    return {
        "meta": {"f0": int(f0), "f_pool": list(f_pool), "amp_eps": float(amp_eps), "topk_within": (None if topk_within is None else int(topk_within))},
        "within_neuron": within,
        "relations_per_neuron_at_f0": rel_per_neuron,
        "grid_units": {"across_neuron_at_f0": across_f0_grid, "relations_across_neuron_at_f0": rel_across_grid},
        "collinearity_pairwise_diffs_at_f0": collinearity,
        "spectrum_across_neuron": spectrum,
    }
