import numpy as np, jax.numpy as jnp, jax, plotly.express as px
from collections import defaultdict, OrderedDict
from itertools import islice
import io, base64
from PyPDF2 import PdfWriter, PdfReader
import tempfile, os
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from collections import Counter
from sklearn.cluster import KMeans
from functools import reduce
from typing import Tuple, List, Iterable, Dict, Any
import plotly.io as pio
pio.kaleido.scope.default_timeout = 60 * 5
from pca_diffusion_plots_w_helpers import generate_pdf_plots_for_matrix, compute_pca_coords
import pca_diffusion_plots_w_helpers
from analysis.plane_fit import plane_angle_per_cluster
import uuid, time
import math, json
import re
from math import gcd, pi, cos, sin
from pathlib import Path
import numpy as _np
import dihedral

def inverse_on_cosets(f: int, p: int) -> dict[int, int]:
    """Map every coset index to the inverse of the reduced factor.

    With ``g = gcd(f, p)`` the reduced factor ``f / g`` is inverted modulo
    ``p / g``; the same inverse is reported for each coset ``0 .. g - 1``.

    Raises:
        ValueError: if the reduced factor is not invertible modulo ``p / g``.
    """
    f_int = int(f)
    p_int = int(p)
    n_cosets = gcd(f_int, p_int)
    modulus = p_int // n_cosets
    reduced = f_int // n_cosets
    if gcd(reduced, modulus) != 1:
        raise ValueError(f"{reduced} not invertible mod {modulus}")
    inverse = pow(reduced, -1, modulus)
    return dict.fromkeys(range(n_cosets), inverse)

def step_size(f: int, p: int) -> int:
    """Return the inverse of ``f / gcd(f, p)`` modulo ``p / gcd(f, p)``."""
    common = math.gcd(f, p)
    reduced_f, reduced_p = f // common, p // common
    return pow(reduced_f, -1, reduced_p)

import numpy as onp

def _remap_block(block: onp.ndarray):
    """Permute ``block`` by its dominant frequencies; return ``(remapped, fb, fa)``.

    ``fb, fa`` come from ``dominant_freqs_ab``.  Entry ``(a, b)`` is moved to
    ``((fa * a) % p, (fb * b) % p)`` with ``p = block.shape[0]``.  The scatter runs
    in JAX and the result is converted back to a NumPy array.
    """
    size = block.shape[0]
    fb, fa = dominant_freqs_ab(block)
    positions = jnp.arange(size)
    target_rows = (fa * positions) % size
    target_cols = (fb * positions) % size
    # NOTE(review): if fa/fb are not coprime with p the targets collide, and JAX
    # does not specify which of the colliding values wins the scatter.
    remapped = jnp.zeros_like(block).at[target_rows[:, None], target_cols[None, :]].set(block)
    return onp.asarray(remapped), fb, fa

def _remap_block_by_freq(block: np.ndarray, fb: int, fa: int) -> np.ndarray:
    p = block.shape[0]
    A = (fa * np.arange(p)) % p
    B = (fb * np.arange(p)) % p
    out = np.zeros_like(block)
    out[A[:, None], B[None, :]] = block
    return out

def _remapped_quadrants_by_freq(tile: np.ndarray, fb: int, fa: int) -> tuple[np.ndarray, int, int]:
    """Remap every p×p quadrant of ``tile`` with the same absolute frequencies.

    The quadrants ``[:p, :p]``, ``[:p, p:]``, ``[p:, :p]`` and ``[p:, p:]``
    (``p = tile.shape[0] // 2``) are each passed to ``_remap_block_by_freq``.

    Returns:
        ``(stitched_tile, |fb|, |fa|)``.
    """
    half = tile.shape[0] // 2
    fb, fa = abs(fb), abs(fa)
    halves = (slice(None, half), slice(half, None))
    stitched = np.zeros_like(tile)
    for rows in halves:
        for cols in halves:
            stitched[rows, cols] = _remap_block_by_freq(tile[rows, cols], fb, fa)
    return stitched, fb, fa

import numpy as _np

def _r2_from_pred(y, yhat):
    y = _np.asarray(y, float); yhat = _np.asarray(yhat, float)
    m = _np.isfinite(y)
    if m.sum() < 3:
        return 0.0, 0.0, 0.0
    y = y[m]; yhat = yhat[m]
    y0 = y - y.mean()
    sst = float(_np.sum(y0 * y0))
    if sst <= 1e-12:
        return 0.0, 0.0, 0.0
    resid = y - yhat
    sse = float(_np.sum(resid * resid))
    r2 = max(0.0, 1.0 - sse / sst)
    n = y.size
    return r2, sse, n

def _fit_linear_design(y, X):
    beta, *_ = _np.linalg.lstsq(X, y, rcond=None)
    yhat = X @ beta
    return beta, yhat

def _design_SinA_plus_SinB_k(p, fa, fb):
    def _X(a, b):
        ta = 2 * _np.pi * fa * a / p
        tb = 2 * _np.pi * fb * b / p
        return _np.c_[_np.sin(ta), _np.cos(ta), _np.sin(tb), _np.cos(tb), _np.ones_like(a)]
    k_params = 5
    return _X, k_params

def _design_SinApmB_k(p, f, sign=+1):
    def _X(a, b):
        t = 2 * _np.pi * f * (a + sign * b) / p
        return _np.c_[_np.sin(t), _np.cos(t), _np.ones_like(a)]
    k_params = 3
    return _X, k_params

def _adjr2_bic_from_sse(sse, sst, n, k):
    if n <= k + 1 or sst <= 1e-30:
        return 0.0, 0.0
    r2_adj = 1.0 - ((sse / max(1e-30, (n - k - 1))) / (sst / max(1e-30, (n - 1))))
    bic = n * _np.log(max(sse, 1e-30) / n) + k * _np.log(max(n, 2))
    return float(r2_adj), float(bic)

def _design_SinAxis_no_bias(p, f, axis="a"):
    def Xa(a, b):
        t = 2 * np.pi * f * a / p
        return np.c_[np.sin(t), np.cos(t)]
    def Xb(a, b):
        t = 2 * np.pi * f * b / p
        return np.c_[np.sin(t), np.cos(t)]
    if axis == "a":
        return Xa, 2
    else:
        return Xb, 2

def _fit_quadrant_sine_models(q, fa_base, fb_base, f_pool, *, max_iters=6, tau_inc=0.0, use_axes_only=True, use_pair_terms=True, include_ab_resid=False):
    """Fit a base additive sine model to quadrant ``q``, then greedily add residual sine terms.

    Layer 1 fits ``[sin/cos(2π fa_base a / p), sin/cos(2π fb_base b / p), 1]`` by least
    squares over the flattened ``p×p`` grid (``a`` = row index, ``b`` = column index;
    ``q`` is assumed square).

    Layer 2 runs up to ``max_iters`` greedy rounds.  Each round fits every candidate
    term family to the current residual and picks the candidate with the lowest BIC
    (ties go to the higher R2).  It stops when that candidate's R2 gain is below
    ``tau_inc``.  Accepted terms are summed onto the running prediction; earlier terms
    are never refit jointly.

    Candidate families for each frequency ``f`` in ``f_pool``:
      * ``"a+b"`` / ``"a-b"``: ``sin/cos(2π f (a ± b) / p)`` plus an intercept
        (when ``use_pair_terms``);
      * ``"a-only"`` / ``"b-only"``: single-axis ``sin/cos`` without an intercept
        (when ``use_axes_only``);
      * ``"a|b_resid"``: the layer-1 design refit to the residual
        (when ``include_ab_resid``).

    Returns:
        dict with these keys:
          * ``R2_layer1``, ``adjR2_layer1``, ``BIC_layer1``: layer-1 metrics;
          * ``chosen_layer2``: the accepted terms (name, frequency, and metrics after
            adding each);
          * ``R2_final``, ``adjR2_final``, ``BIC_final``: final metrics;
          * ``best``: the last accepted term, or a ``"none"`` placeholder carrying the
            layer-1 metrics.
    """
    p = q.shape[0]
    a_idx, b_idx = _np.indices((p, p))
    y = q.reshape(-1)
    n = y.size
    # NOTE(review): sst is computed on all entries, while _r2_from_pred masks non-finite targets.
    sst = float(_np.sum((y - y.mean()) ** 2))
    Xab_fn, k_ab = _design_SinA_plus_SinB_k(p, int(fa_base), int(fb_base))
    Xab = Xab_fn(a_idx.ravel(), b_idx.ravel())
    beta_ab, yhat_ab = _fit_linear_design(y, Xab)
    R2_ab, sse_ab, _ = _r2_from_pred(y, yhat_ab)
    adj_ab, bic_ab = _adjr2_bic_from_sse(sse_ab, sst, n, k_ab)
    # Running prediction: layer-2 terms are fitted to y - yhat_current and added on top.
    yhat_current = yhat_ab.copy()
    # Total fitted columns so far (including each added term's own intercept); feeds the BIC penalty.
    k_current = k_ab
    chosen = []
    def _R2_from_yhat(yhat):
        # R2 of a full prediction against y, clipped at 0.
        sse = float(_np.sum((y - yhat) ** 2))
        R2 = max(0.0, 1.0 - sse / max(sst, 1e-30))
        return R2, sse
    def _try_X(Xf_fn, kf):
        # Fit one candidate design to the current residual and score the resulting total prediction.
        # Reads yhat_current / k_current from the enclosing scope, so it always sees the latest values.
        Xf = Xf_fn(a_idx.ravel(), b_idx.ravel())
        resid = y - yhat_current
        beta_f, yhat_f_resid = _fit_linear_design(resid, Xf)
        yhat_new = yhat_current + yhat_f_resid
        R2_new, sse_new = _R2_from_yhat(yhat_new)
        adj_new, bic_new = _adjr2_bic_from_sse(sse_new, sst, n, k_current + kf)
        R2_cur, _ = _R2_from_yhat(yhat_current)
        return {"beta": beta_f, "yhat_added": yhat_f_resid, "R2": R2_new, "adjR2": adj_new, "BIC": bic_new, "k_add": kf, "delta_R2": (R2_new - R2_cur)}, yhat_new, kf
    # Same design as layer 1; offered as a residual candidate only when include_ab_resid is set.
    Xab_resid_fn, k_ab_resid = _design_SinA_plus_SinB_k(p, int(fa_base), int(fb_base))
    for _ in range(int(max_iters)):
        # Each candidate tuple: (name, f, metrics_dict, yhat_new, k_add).
        cands = []
        for f in f_pool:
            if use_pair_terms:
                Xp_fn, kp = _design_SinApmB_k(p, int(f), sign=+1)
                res_p = _try_X(Xp_fn, kp); cands.append(("a+b", f, *res_p))
                Xm_fn, km = _design_SinApmB_k(p, int(f), sign=-1)
                res_m = _try_X(Xm_fn, km); cands.append(("a-b", f, *res_m))
            if use_axes_only:
                Xa_fn, ka = _design_SinAxis_no_bias(p, int(f), axis="a")
                res_a = _try_X(Xa_fn, ka); cands.append(("a-only", f, *res_a))
                Xb_fn, kb = _design_SinAxis_no_bias(p, int(f), axis="b")
                res_b = _try_X(Xb_fn, kb); cands.append(("b-only", f, *res_b))
        if include_ab_resid:
            res_ab = _try_X(Xab_resid_fn, k_ab_resid)
            cands.append(("a|b_resid", fa_base, *res_ab))
        if not cands:
            break
        # Lowest BIC wins; ties go to the higher R2, then to the earliest candidate (min keeps the first).
        name, f_sel, res_best, yhat_new, k_add = min(cands, key=lambda t: (t[2]["BIC"], -t[2]["R2"]))
        # Strict comparison: with the default tau_inc=0.0 a zero R2 gain is still accepted.
        if res_best["delta_R2"] < tau_inc:
            break
        yhat_current = yhat_new
        k_current += k_add
        log_item = {"name": name, "f": int(f_sel)}
        log_item.update({k: res_best[k] for k in ("R2", "adjR2", "BIC", "delta_R2")})
        chosen.append(log_item)
    R2_final, sse_final = _R2_from_yhat(yhat_current)
    adj_final, bic_final = _adjr2_bic_from_sse(sse_final, sst, n, k_current)
    return {"R2_layer1": R2_ab, "adjR2_layer1": adj_ab, "BIC_layer1": bic_ab, "chosen_layer2": chosen, "R2_final": R2_final, "adjR2_final": adj_final, "BIC_final": bic_final, "best": (chosen[-1] if chosen else {"name": "none", "f": int(fa_base), "R2": R2_ab, "adjR2": adj_ab, "BIC": bic_ab, "delta_R2": 0.0})}

def fit_quadrant_sines(q, dom, freq_map, names, *, f_pool, max_iters=6, tau_inc=0.0, use_axes_only=True, use_pair_terms=True, include_ab_resid=False):
    """Run ``_fit_quadrant_sine_models`` with base frequencies derived from ``dom``.

    Frequency sources:
      * ``dom["kind"] == "diag"``: both base frequencies are ``|freq_map[dom["r_star"]]|``;
      * otherwise ``fa`` comes from ``dom["s_star"]`` and ``fb`` from ``dom["r_star"]``.

    Base frequencies missing from ``f_pool`` are prepended (``fa`` first, then ``fb``,
    so ``fb`` ends up at the front).  ``names`` is accepted for interface compatibility
    and is not used.  All keyword options are forwarded unchanged.
    """
    if dom["kind"] == "diag":
        fa_base = fb_base = abs(int(freq_map[dom["r_star"]]))
    else:
        fa_base = abs(int(freq_map[dom["s_star"]]))
        fb_base = abs(int(freq_map[dom["r_star"]]))
    pool = f_pool
    for base in (fa_base, fb_base):
        if base not in pool:
            pool = [base] + list(pool)
    return _fit_quadrant_sine_models(q, fa_base, fb_base, pool, max_iters=max_iters, tau_inc=tau_inc, use_axes_only=use_axes_only, use_pair_terms=use_pair_terms, include_ab_resid=include_ab_resid)

def dominant_freqs_ab(grid: np.ndarray) -> Tuple[int, int]:
    """Return the dominant ``(fb, fa)`` frequency pair of a 2-D ``grid``.

    The power spectrum ``|FFT2(grid)|^2`` is searched, ignoring DC and only up to
    ``p // 2`` (``p = grid.shape[0]``), along three lines:
      * row 0 (the ``b`` axis);
      * column 0 (the ``a`` axis);
      * the main diagonal ``(k, k)``.

    If the strongest diagonal bin beats both axis peaks, ``(k, k)`` is returned.
    Otherwise the two axis peaks are returned independently.  A frequency of 0 is
    reported as 1.
    """
    size = grid.shape[0]
    power = np.abs(np.fft.fft2(grid)) ** 2
    power[0, 0] = -np.inf  # exclude DC from every search below
    limit = size // 2 + 1
    b_line = power[0, :limit].copy()
    a_line = power[:limit, 0].copy()
    fb_idx = int(np.argmax(b_line))
    fa_idx = int(np.argmax(a_line))
    axis_best = max(float(b_line[fb_idx]), float(a_line[fa_idx]))
    ks = np.arange(1, limit, dtype=int)
    if ks.size:
        diag = power[ks, ks]
        pos = int(np.argmax(diag))
        best_k_diag, diag_best = int(ks[pos]), float(diag[pos])
    else:
        best_k_diag, diag_best = 0, -np.inf
    if diag_best > axis_best:
        k = best_k_diag or 1
        return k, k
    return (fb_idx or 1), (fa_idx or 1)

def _remapped_quadrants(tile: np.ndarray) -> tuple[np.ndarray, int, int]:
    """Remap each quadrant of ``tile`` by its own dominant frequencies.

    Quadrants ``[:p, :p]``, ``[:p, p:]``, ``[p:, :p]`` and ``[p:, p:]``
    (``p = tile.shape[0] // 2``) are each passed to ``_remap_block`` and stitched
    back together.

    Returns:
        ``(stitched, fb, fa)``.  ``fb``/``fa`` are reported only when all four
        quadrants agree on them; otherwise a message is printed and both are 0.
    """
    half = tile.shape[0] // 2
    halves = (slice(None, half), slice(half, None))
    stitched = np.zeros_like(tile)
    seen_fb: set[int] = set()
    seen_fa: set[int] = set()
    for rows in halves:
        for cols in halves:
            block, fb_q, fa_q = _remap_block(tile[rows, cols])
            stitched[rows, cols] = block
            seen_fb.add(fb_q)
            seen_fa.add(fa_q)
    if len(seen_fb) == 1 and len(seen_fa) == 1:
        (fb,) = seen_fb
        (fa,) = seen_fa
    else:
        print("Quadrants are not mapped using the same fb/fa.")
        fb = fa = 0
    return stitched, fb, fa

def single_freq_phase_shifts_ab(mat: np.ndarray, p: int, fb: int, fa: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return per-sample phases of the ``fb`` (b-axis) and ``fa`` (a-axis) Fourier bins.

    Args:
        mat: ``p*p`` values per sample, laid out in one of two ways:
            * as columns, ``(p*p, n)``, or a flat 1-D vector of length ``p*p``;
            * as rows, ``(n, p*p)``.
            The column layout is checked first, so it wins when both dimensions equal
            ``p*p``.  Each sample is reshaped to a ``(p, p)`` grid indexed ``[a, b]``.
        p: grid side length.
        fb: frequency index along ``b``; must satisfy ``0 < fb < p``.  Only DC is
            excluded, so the Nyquist bin ``p/2`` is accepted.
        fa: frequency index along ``a``; same range as ``fb``.

    Returns:
        ``(phi_b, phi_a, amps)``:
          * ``phi_b``, ``phi_a``: phases in ``[0, 2π)`` of ``F[0, fb]`` and ``F[fa, 0]``,
            with the FFT normalised by ``p*p``;
          * ``amps``: ``2 * (|F[0, fb]| + |F[fa, 0]|)``.  The factor 2 assumes a
            conjugate-pair bin, so it overstates the amplitude at Nyquist.

    Raises:
        ValueError: if a frequency is out of range or ``mat`` does not hold ``p*p``
            values per sample.
    """
    if not (0 < fb < p) or not (0 < fa < p):
        raise ValueError("fb, fa must be in 1..p-1 (DC excluded)")
    mat = np.asarray(mat)
    if mat.shape[0] == p * p:
        grids = mat.T.reshape(-1, p, p)
    elif mat.ndim >= 2 and mat.shape[1] == p * p:
        grids = mat.reshape(-1, p, p)
    else:
        raise ValueError("mat must contain p*p pixels per sample")
    F = np.fft.fft2(grids, axes=(-2, -1)) / (p * p)
    c_b = F[:, 0, fb]
    c_a = F[:, fa, 0]
    amps = 2 * (np.abs(c_b) + np.abs(c_a))
    phi_b = np.mod(np.angle(c_b), 2 * np.pi)
    phi_a = np.mod(np.angle(c_a), 2 * np.pi)
    # np.mod of a tiny negative angle can round to exactly 2π; fold it back to 0.
    phi_b[np.isclose(phi_b, 2 * np.pi, atol=1e-8)] = 0.0
    phi_a[np.isclose(phi_a, 2 * np.pi, atol=1e-8)] = 0.0
    return phi_b, phi_a, amps

def _quadrant_ab_phases(grid: np.ndarray) -> List[Tuple[float, float]]:
    """Return ``(phi_b, phi_a)`` for each quadrant of a square ``grid``.

    Quadrants are visited in the order ``[:p, :p]``, ``[:p, p:]``, ``[p:, :p]``,
    ``[p:, p:]`` with ``p = grid.shape[0] // 2``.  In each quadrant the strongest
    non-DC bins up to ``p // 2`` along row 0 (b axis) and column 0 (a axis) are
    picked.  Their phases are then read with ``single_freq_phase_shifts_ab``.

    Raises:
        ValueError: if ``grid`` is not square.
    """
    side = grid.shape[0]
    if grid.shape[1] != side:
        raise ValueError("grid must be square")
    p = side // 2
    upper = p // 2 + 1
    halves = (slice(None, p), slice(p, None))
    phases: List[Tuple[float, float]] = []
    for rows in halves:
        for cols in halves:
            quad = grid[rows, cols]
            spectrum = np.abs(np.fft.fft2(quad) / (p * p))
            fb = 1 + int(np.argmax(spectrum[0, 1:upper]))
            fa = 1 + int(np.argmax(spectrum[1:upper, 0]))
            phi_b, phi_a, _ = single_freq_phase_shifts_ab(quad.reshape(1, -1), p, fb, fa)
            phases.append((phi_b[0], phi_a[0]))
    return phases

def plotly_to_pdf_bytes(fig, scale=2.0):
    """Render ``fig`` to PDF bytes at a fixed 600×450 layout size and the given scale."""
    export_opts = dict(format="pdf", scale=scale, width=600, height=450)
    return fig.to_image(**export_opts)

def dominant_irrep(Fhat_n, names):
    """Return ``(name_b, name_a)`` of the dominant irrep pair in one neuron's GFT.

    Each ``(r, s, _)`` key of ``Fhat_n`` writes the Frobenius norm of its
    coefficient matrix into ``P[index(r), index(s)]``, where indices follow ``names``.
    Entry ``P[0, 0]`` is excluded; index 0 is presumably the trivial irrep.

    Three candidates are then compared:
      * the peak of column 0 gives the ``a`` candidate;
      * the peak of row 0 gives the ``b`` candidate;
      * the peak of the diagonal gives the diagonal candidate.

    The diagonal wins only if it is strictly larger than both axis peaks; then both
    returned names are the diagonal irrep.
    """
    name2idx = {label: i for i, label in enumerate(names)}
    D = len(names)
    P = np.zeros((D, D))
    for (r, s, _), M in Fhat_n.items():
        P[name2idx[r], name2idx[s]] = np.linalg.norm(np.asarray(M))
    P[0, 0] = -np.inf
    a_vals = P[:, 0]
    b_vals = P[0, :]
    diag_vals = np.diagonal(P)
    ia, ib, idg = int(np.argmax(a_vals)), int(np.argmax(b_vals)), int(np.argmax(diag_vals))
    mx_a, mx_b, mx_d = a_vals[ia], b_vals[ib], diag_vals[idg]
    if mx_d > mx_a and mx_d > mx_b:
        fa = fb = idg
    else:
        fa, fb = ia, ib
    return (names[fb], names[fa])

import numpy as np
from collections import defaultdict

def _power_matrix_from_Fhat(Fhat_n, names):
    name2idx = {lab: i for i, lab in enumerate(names)}
    D = len(names)
    P = np.zeros((D, D))
    for (r, s, _), M in Fhat_n.items():
        P[name2idx[r], name2idx[s]] = np.linalg.norm(np.array(M))
    return P

def _classify_by_gft(Fhat_n, names, freq_map, *, strict=True):
    """Classify one neuron's dominant GFT component as ``"diag"`` or ``"axis"``.

    Uses the norm matrix from ``_power_matrix_from_Fhat`` with ``[0, 0]`` excluded.
    The diagonal peak wins only if it strictly exceeds both the column-0 (``a``) and
    row-0 (``b``) peaks.

    Returns:
        dict with these keys:
          * ``kind``;
          * ``r_star`` / ``s_star``: irrep names;
          * ``fa`` / ``fb``: frequencies from ``freq_map``, ``None`` when unmapped and
            ``strict`` is False;
          * ``mx_a``, ``mx_b``, ``mx_d``: the three peak values;
          * ``ia``, ``ib``, ``idg``: their indices.

    Raises:
        KeyError: in strict mode, when a needed irrep has no ``freq_map`` entry.
    """
    P = _power_matrix_from_Fhat(Fhat_n, names)
    P[0, 0] = -np.inf
    a_vals, b_vals, diag_vals = P[:, 0].copy(), P[0, :].copy(), np.diagonal(P).copy()
    ia, ib, idg = (int(np.argmax(vals)) for vals in (a_vals, b_vals, diag_vals))
    mx_a, mx_b, mx_d = a_vals[ia], b_vals[ib], diag_vals[idg]
    is_diag = mx_d > mx_a and mx_d > mx_b
    if is_diag:
        kind = "diag"
        r_star = s_star = names[idg]
        if r_star in freq_map:
            fa = fb = int(freq_map[r_star])
        elif strict:
            raise KeyError(f"freq_map no mapping for irrep '{r_star}'")
        else:
            fa = fb = None
    else:
        kind = "axis"
        r_star, s_star = names[ib], names[ia]
        fa = int(freq_map[s_star]) if s_star in freq_map else None
        fb = int(freq_map[r_star]) if r_star in freq_map else None
        if strict and (fa is None or fb is None):
            raise KeyError(f"freq_map missing for ({r_star},{s_star})")
    return {"kind": kind, "r_star": r_star, "s_star": s_star, "fa": fa, "fb": fb, "mx_a": float(mx_a), "mx_b": float(mx_b), "mx_d": float(mx_d), "ia": ia, "ib": ib, "idg": idg}

def subgroup_scores(vec, coset_masks, skip_trivial=True, top_k=3):
    """Rank subgroups by the variance of ``vec`` remaining inside their cosets.

    ``coset_masks`` maps ``(subgroup, coset_id)`` to a boolean mask over ``vec``.
    For each subgroup:
      * ``CH``: summed within-coset variance divided by the total variance;
      * ``Cbar``: the same quantity averaged over its ``K`` cosets.
    ``"C_1"`` is skipped when ``skip_trivial`` is set.  If the total variance is below
    ``1e-12`` every score is 0.

    Returns:
        An ``OrderedDict`` sorted by ascending ``Cbar``.  It is truncated to ``top_k``
        entries when ``top_k`` is an int.
    """
    masks_by_group = defaultdict(list)
    for (group, _cid), mask in coset_masks.items():
        if skip_trivial and group == "C_1":
            continue
        masks_by_group[group].append(mask)
    total_var = vec.var()
    if total_var < 1e-12:
        ranked = [(group, {"CH": 0.0, "Cbar": 0.0, "K": len(masks)}) for group, masks in masks_by_group.items()]
    else:
        ranked = []
        for group, masks in masks_by_group.items():
            count = len(masks)
            within = sum(vec[m].var() for m in masks)
            ranked.append((group, {"CH": within / total_var, "Cbar": (within / max(1, count)) / total_var, "K": count}))
    ranked.sort(key=lambda item: item[1]["Cbar"])
    if isinstance(top_k, int):
        ranked = ranked[:top_k]
    return OrderedDict(ranked)

def merge_topk_sources(tag_scores_list, top_k: int = 3, tag_key: str = "origin"):
    """Pool tagged subgroup-score mappings and keep the ``top_k`` lowest-``Cbar`` entries.

    ``tag_scores_list`` is a sequence of ``(tag, scores)`` pairs.  Each ``scores`` maps
    a subgroup to ``{"Cbar", "CH", "K"}``.  Every returned dict records its source
    tag under ``tag_key``.
    """
    pooled = [
        (score["Cbar"], group, tag, score["CH"], score["K"])
        for tag, scores in tag_scores_list
        for group, score in scores.items()
    ]
    pooled.sort(key=lambda row: row[0])
    return [{"H": group, tag_key: tag, "Cbar": cbar, "CH": ch, "K": k} for cbar, group, tag, ch, k in pooled[:top_k]]

def _quad_set_extrema_header(grid_2p: np.ndarray) -> tuple[str, bool]:
    """Build an HTML header with each quadrant's argmax/argmin labels reduced mod ``g``.

    Quadrants are ``BL = [:p, :p]``, ``BR = [:p, p:]``, ``TL = [p:, :p]`` and
    ``TR = [p:, p:]`` (rows ``a``, columns ``b``).  A quadrant is rejected, returning
    ``(skip message, False)``, when:
      * its dominant ``fb != fa``; or
      * ``p`` is not divisible by the shared frequency ``f``.

    Labels are ``(b_offset + b % g, a_offset + a % g)`` with ``g = p // gcd(p, f)``.
    ``ALL_TRUE`` is appended when the max labels line up across the quadrant pairs
    (BL/TL and BR/TR by ``b``; BL/BR and TL/TR by ``a``).
    """
    side = grid_2p.shape[0]
    assert grid_2p.shape[1] == side, "grid must be square"
    p = side // 2
    # (tag, quadrant view, (b_offset, a_offset))
    layout = (
        ("BL", grid_2p[:p, :p], (0, 0)),
        ("BR", grid_2p[:p, p:], (p, 0)),
        ("TL", grid_2p[p:, :p], (0, p)),
        ("TR", grid_2p[p:, p:], (p, p)),
    )
    pieces = []
    peak = {}
    for tag, quad, (b_off, a_off) in layout:
        fb, fa = dominant_freqs_ab(quad)
        if fb != fa:
            return f"<b>skip: {tag} fb({fb}) != fa({fa})</b>", False
        f = int(fb)
        if p % f:
            return f"<b>skip: {tag} p({p}) not divisible by f({f})</b>", False
        g = p // math.gcd(p, f)
        hi = np.unravel_index(int(np.argmax(quad)), quad.shape)
        lo = np.unravel_index(int(np.argmin(quad)), quad.shape)
        peak[tag] = (b_off + int(hi[1] % g), a_off + int(hi[0] % g))
        trough = (b_off + int(lo[1] % g), a_off + int(lo[0] % g))
        pieces.append(f"{tag} max:{peak[tag]} min:{trough}")
    aligned = (
        peak["BL"][0] == peak["TL"][0]
        and peak["BR"][0] == peak["TR"][0]
        and peak["BL"][1] == peak["BR"][1]
        and peak["TL"][1] == peak["TR"][1]
    )
    if aligned:
        pieces.append("ALL_TRUE")
    return f"<b>{' | '.join(pieces)}</b>", True

def quad_set_extrema_records_strict_f(grid_2p: np.ndarray, n: int, cayley: np.ndarray | None = None) -> dict:
    """Collect per-quadrant argmax residues (mod ``g``) for neuron ``n``.

    The square grid (rows ``a``, columns ``b``) is split into quadrants
    ``BL = [:p, :p]``, ``BR = [:p, p:]``, ``TL = [p:, :p]`` and ``TR = [p:, p:]``.

    Preconditions (otherwise ``ok=False`` is returned with a ``reason`` and no records):
      * every quadrant has ``fb == fa`` from ``dominant_freqs_ab``;
      * all quadrants share one ``f``;
      * ``f`` divides ``p``;
      * ``g = p // gcd(p, f)`` exceeds 1.

    Each quadrant record contains:
      * ``a_mod`` / ``b_mod``: the argmax position mod ``g``;
      * ``max_act``: the max activation;
      * ``sec_max_a_fx``: the best value along the argmax row in a *different*
        residue class;
      * ``sec_max_b_fx``: the same along the argmax column;
      * when ``cayley`` is given: ``c_idx = cayley[a_abs, b_abs]`` at the absolute
        argmax position (presumably a group-product index — confirm), its residue
        ``c_mod``, and ``c_upper = c_idx >= p``.

    Returns:
        ``dict(ok, reason, f, g, records)``.
    """
    grid_2p = np.asarray(grid_2p, dtype=float)
    G = grid_2p.shape[0]
    assert grid_2p.shape[1] == G, "grid must be square"
    p = G // 2
    quads = {"BL": grid_2p[:p, :p], "BR": grid_2p[:p, p:], "TL": grid_2p[p:, :p], "TR": grid_2p[p:, p:]}
    order = ["BL", "BR", "TL", "TR"]
    f_list = []
    for tag in order:
        fb, fa = dominant_freqs_ab(quads[tag])
        if fb != fa:
            return dict(ok=False, reason=f"{tag}: fb({fb}) != fa({fa})", f=None, g=None, records=[])
        f_list.append(int(fb))
    uniq = sorted(set(f_list))
    if len(uniq) != 1:
        return dict(ok=False, reason=f"f varies across quads: {f_list} (uniq={uniq})", f=None, g=None, records=[])
    f = uniq[0]
    if f == 0:
        # Defensive: dominant_freqs_ab reports 0 as 1, but keep the modulo below safe.
        return dict(ok=False, reason=f"dominant frequency is DC (f=0) for n={n}", f=0, g=None, records=[])
    if (p % f) != 0:
        return dict(ok=False, reason=f"p({p}) not divisible by f({f})", f=None, g=None, records=[])
    g = p // math.gcd(p, f)
    if g <= 1:
        return dict(ok=False, reason=f"g={g} (p={p}, f={f}) leaves no second-best residue", f=f, g=g, records=[])
    # (b_offset, a_offset) of each quadrant inside the full grid.
    bases = {"BL": (0, 0), "BR": (p, 0), "TL": (0, p), "TR": (p, p)}

    def second_best_mod(values: np.ndarray, primary_mod: int):
        """Return ``(residue, value)`` of the largest entry whose index mod ``g`` differs from ``primary_mod``.

        Residue classes with no members are skipped.  Returns ``(None, -inf)`` when no
        other class exists.  Ties keep the smallest residue.
        """
        mods = np.arange(values.size) % g
        best_val = -np.inf
        best_mod = None
        for r in range(g):
            if r == primary_mod:
                continue
            mask = mods == r
            if not np.any(mask):
                continue
            v = values[mask].max()
            if v > best_val:
                best_val, best_mod = v, r
        return best_mod, float(best_val)

    records = []
    for tag in order:
        q = quads[tag]
        a_idx, b_idx = np.unravel_index(int(np.argmax(q)), q.shape)
        a_mod0 = int(a_idx % g)
        b_mod0 = int(b_idx % g)
        rec = {"n": int(n), "quad": tag, "a_mod": a_mod0, "b_mod": b_mod0, "max_act": float(q[a_idx, b_idx])}
        if cayley is not None:
            b_base, a_base = bases[tag]
            c_idx = int(cayley[a_idx + a_base, b_idx + b_base])
            rec["c_idx"] = c_idx
            rec["c_mod"] = int(c_idx % g)
            rec["c_upper"] = bool(c_idx >= p)
        if q.shape[1] > 1:
            # Best value along the argmax row (varying b) in another b-residue class.
            b2_mod, val2 = second_best_mod(q[a_idx, :], b_mod0)
            if b2_mod is not None:
                rec["sec_max_a_fx"] = {"a_mod": a_mod0, "b_mod": int(b2_mod), "val": float(val2)}
        if q.shape[0] > 1:
            # Best value along the argmax column (varying a) in another a-residue class.
            a2_mod, val3 = second_best_mod(q[:, b_idx], a_mod0)
            if a2_mod is not None:
                rec["sec_max_b_fx"] = {"a_mod": int(a2_mod), "b_mod": b_mod0, "val": float(val3)}
        records.append(rec)
    return dict(ok=True, reason=None, f=f, g=g, records=records)

def pick_c_records(records):
    """Keep the records that carry ``c_mod``, projected onto the fields used by the c-plot.

    Defaults: ``c_upper`` is False and ``max_act`` is 0.0 when absent.
    """
    return [
        {
            "n": rec["n"],
            "quad": rec["quad"],
            "c_mod": int(rec["c_mod"]),
            "c_upper": bool(rec.get("c_upper", False)),
            "max_act": float(rec.get("max_act", 0.0)),
        }
        for rec in records
        if "c_mod" in rec
    ]

import numpy as np
import plotly.graph_objects as go
from math import gcd, pi, cos, sin
from collections import defaultdict

def _ngon_vertices(g: int, R: float, cx: float, cy: float, phi0: float):
    angs = [phi0 + 2 * pi * k / g for k in range(g)]
    xs = [cx + R * cos(a) for a in angs]
    ys = [cy + R * sin(a) for a in angs]
    return np.array(xs), np.array(ys)

def pick_records_for_mode(records, mode: str):
    """Project extrema records onto ``{n, quad, a_mod, b_mod, max_act}`` for one plot mode.

    Modes:
      * ``"max"``: the primary argmax residues and activation of every record;
      * ``"sec_a"``: ``sec_max_a_fx`` of records that have it (others skipped);
      * ``"sec_b"``: ``sec_max_b_fx`` of records that have it (others skipped).

    Raises:
        ValueError: if ``mode`` is not one of the three values above.  The check runs
            before the loop, so an invalid mode is reported even for empty ``records``.
    """
    if mode not in ("max", "sec_a", "sec_b"):
        raise ValueError("mode must be 'max' | 'sec_a' | 'sec_b'")
    out = []
    for r in records:
        if mode == "max":
            out.append({"n": r["n"], "quad": r["quad"], "a_mod": r["a_mod"], "b_mod": r["b_mod"], "max_act": r["max_act"]})
            continue
        s = r.get("sec_max_a_fx" if mode == "sec_a" else "sec_max_b_fx")
        if s is not None:
            out.append({"n": r["n"], "quad": r["quad"], "a_mod": s["a_mod"], "b_mod": s["b_mod"], "max_act": s["val"]})
    return out

def plot_coset_ngon_ring(p: int, f: int, records: list, neuron_main: set, title: str | None = None, R_outer: float = 1.0, r_inner: float = 0.26, ring_cap: int = 6, ring_step: float = 0.045, text_size: int = 12, show_labels: bool = True, rotate_random: bool = False, seed: int | None = None) -> go.Figure:
    """Place neuron ids on nested regular ``g``-gons indexed by ``(a mod g, b mod g)``.

    ``g = p // gcd(p, f)``.  The figure is built in three layers:
      * Outer a-polygons: two of them, both of radius ``R_outer``.  Red is the lower
        half; green is the upper half, rotated by ``π/g``.  Each carries one vertex
        per ``a`` residue.
      * Inner b-polygons: at every outer vertex, two of radius ``r_inner`` (lower and
        upper ``b`` half) carry the ``b`` residues.
      * Neuron ids: each record whose ``"n"`` is in ``neuron_main`` goes to the inner
        vertex chosen by its ``a_mod``, ``b_mod`` and ``quad``.  ``TL``/``TR`` select
        the upper ``a`` polygon; ``BR``/``TR`` select the upper ``b`` polygon.

    Ids sharing a vertex are sorted by descending ``max_act`` (default 0.0), then by
    id.  They are laid out on small rings of ``ring_cap`` ids, each ring ``ring_step``
    further out.  ``text_size`` applies only to the neuron ids; vertex labels use
    fixed sizes.  With ``rotate_random`` the starting angles are drawn from
    ``np.random.default_rng(seed)``.

    Raises:
        ValueError: if ``g < 3``.
    """
    g = p // gcd(p, f)
    if g < 3:
        raise ValueError(f"g must be >=3; got g={g} (p={p}, f={f})")
    # rng is only created (and used) when rotate_random is set.
    rng = np.random.default_rng(seed) if rotate_random else None
    if rotate_random:
        phi0_a_low = float(rng.uniform(0, 2 * pi))
        phi0_b_low = float(rng.uniform(0, 2 * pi))
    else:
        phi0_a_low = pi / 2
        phi0_b_low = 0.0
    # Upper-half polygons are rotated by half a vertex step (pi/g) relative to the lower ones.
    phi0_a_up = phi0_a_low + pi / g
    phi0_b_up = phi0_b_low + pi / g
    XaL, YaL = _ngon_vertices(g, R_outer, 0.0, 0.0, phi0_a_low)
    XaU, YaU = _ngon_vertices(g, R_outer, 0.0, 0.0, phi0_a_up)
    fig = go.Figure()
    # Outer polygons, closed by repeating the first vertex: red = lower a half, green = upper a half.
    fig.add_trace(go.Scatter(x=np.r_[XaL, XaL[:1]], y=np.r_[YaL, YaL[:1]], mode="lines", line=dict(width=2, color="red"), hoverinfo="skip", showlegend=False))
    fig.add_trace(go.Scatter(x=np.r_[XaU, XaU[:1]], y=np.r_[YaU, YaU[:1]], mode="lines", line=dict(width=2, color="green"), hoverinfo="skip", showlegend=False))
    if show_labels:
        for r in range(g):
            fig.add_trace(go.Scatter(x=[XaL[r]], y=[YaL[r]], mode="text", text=[f"a equiv {r} (mod {g})"], textfont=dict(size=12, color="red"), hoverinfo="skip", showlegend=False))
        for r in range(g):
            fig.add_trace(go.Scatter(x=[XaU[r]], y=[YaU[r]], mode="text", text=[f"a>={p}+{r} (equiv {r})"], textfont=dict(size=12, color="green"), hoverinfo="skip", showlegend=False))
    def _draw_inner_b_at(cx, cy):
        # Draw the lower/upper b-polygons centred at (cx, cy) and return their vertex coordinates.
        XbL, YbL = _ngon_vertices(g, r_inner, cx, cy, phi0_b_low)
        XbU, YbU = _ngon_vertices(g, r_inner, cx, cy, phi0_b_up)
        fig.add_trace(go.Scatter(x=np.r_[XbL, XbL[:1]], y=np.r_[YbL, YbL[:1]], mode="lines", line=dict(width=1, color="rgba(0,0,0,0.25)"), hoverinfo="skip", showlegend=False))
        fig.add_trace(go.Scatter(x=np.r_[XbU, XbU[:1]], y=np.r_[YbU, YbU[:1]], mode="lines", line=dict(width=1, color="rgba(0,0,0,0.25)"), hoverinfo="skip", showlegend=False))
        if show_labels:
            for r in range(g):
                fig.add_trace(go.Scatter(x=[XbL[r]], y=[YbL[r]], mode="text", text=[f"b equiv {r}"], textfont=dict(size=10), hoverinfo="skip", showlegend=False))
            for r in range(g):
                fig.add_trace(go.Scatter(x=[XbU[r]], y=[YbU[r]], mode="text", text=[f"b>={p}+{r} (equiv {r})"], textfont=dict(size=10), hoverinfo="skip", showlegend=False))
        return (XbL, YbL, XbU, YbU)
    # Inner polygon vertices keyed by ("aL"|"aU", a residue), reused when placing neuron ids.
    inner_cache = {}
    for r in range(g):
        inner_cache[("aL", r)] = _draw_inner_b_at(XaL[r], YaL[r])
    for r in range(g):
        inner_cache[("aU", r)] = _draw_inner_b_at(XaU[r], YaU[r])
    clusters = defaultdict(list)
    def halves_from_quad(quad):
        # TL/TR lie in the upper a half; BR/TR lie in the upper b half.
        quad = quad.upper()
        a_upper = quad in ("TL", "TR")
        b_upper = quad in ("BR", "TR")
        return a_upper, b_upper
    for rec in records:
        if rec["n"] not in neuron_main:
            continue
        n = int(rec["n"])
        a_mod = int(rec["a_mod"]) % g
        b_mod = int(rec["b_mod"]) % g
        max_act = float(rec.get("max_act", 0.0))
        a_upper, b_upper = halves_from_quad(rec["quad"])
        a_key = ("aU", a_mod) if a_upper else ("aL", a_mod)
        XbL, YbL, XbU, YbU = inner_cache[a_key]
        if b_upper:
            cx, cy = XbU[b_mod], YbU[b_mod]
        else:
            cx, cy = XbL[b_mod], YbL[b_mod]
        clusters[(a_key, ("bU" if b_upper else "bL"), b_mod)].append({"n": n, "max": max_act, "center": (cx, cy)})
    Xs, Ys, texts = [], [], []
    for key, items in clusters.items():
        # Strongest neurons first; each ring holds ring_cap ids, and rings grow outward by ring_step.
        items_sorted = sorted(items, key=lambda d: (-d["max"], d["n"]))
        for i, d in enumerate(items_sorted):
            ring = i // ring_cap
            pos = i % ring_cap
            ang = 2 * pi * pos / max(1, ring_cap)
            rad = ring_step * (ring + 1)
            cx, cy = d["center"]
            Xs.append(cx + rad * cos(ang))
            Ys.append(cy + rad * sin(ang))
            texts.append(str(d["n"]))
    if Xs:
        fig.add_trace(go.Scatter(x=Xs, y=Ys, mode="text", text=texts, textposition="top center", textfont=dict(size=text_size, color="#222"), hoverinfo="skip", showlegend=False))
    fig.update_xaxes(visible=False)
    # Equal x/y scaling keeps the polygons regular.
    fig.update_yaxes(visible=False, scaleanchor="x", scaleratio=1)
    fig.update_layout(title=title or f"Coset n-gon (g={g}) - p={p}, f={f}, g={g}", width=900, height=900, margin=dict(l=40, r=40, t=60, b=40), showlegend=False)
    return fig

def plot_coset_ngon_c_single(p: int, f: int, c_records: list, neuron_main: set, title: str | None = None, R_outer: float = 1.0, ring_cap: int = 6, ring_step: float = 0.05, text_size: int = 12, show_labels: bool = True, rotate_random: bool = False, seed: int | None = None) -> go.Figure:
    """Place neuron ids on two regular ``g``-gons indexed by ``c mod g``.

    ``g = p // gcd(p, f)``.  Two polygons of radius ``R_outer`` are drawn:
      * red: ``c_upper`` False;
      * green: ``c_upper`` True, rotated by ``π/g``.
    Each has one vertex per ``c`` residue.

    Each entry of ``c_records`` whose ``"n"`` is in ``neuron_main`` is placed at vertex
    ``c_mod % g`` of the polygon chosen by its ``c_upper`` flag.  Entries must provide
    ``"n"``, ``"c_mod"``, ``"c_upper"`` and ``"max_act"``.  Ids sharing a vertex are
    sorted by descending ``max_act``, then by id, and laid out on small rings of
    ``ring_cap`` ids spaced ``ring_step`` apart.  ``text_size`` applies only to the
    neuron ids.

    Raises:
        ValueError: if ``g < 3``.
    """
    g = p // gcd(p, f)
    if g < 3:
        raise ValueError(f"g must be >=3; got g={g} (p={p}, f={f})")
    # rng is only created (and used) when rotate_random is set.
    rng = np.random.default_rng(seed) if rotate_random else None
    phi0_low = (rng.uniform(0, 2 * np.pi) if rotate_random else np.pi / 2)
    # The upper polygon is rotated by half a vertex step so its vertices sit between the lower ones.
    phi0_high = phi0_low + np.pi / g
    Xlow, Ylow = _ngon_vertices(g, R_outer, 0.0, 0.0, phi0_low)
    Xhigh, Yhigh = _ngon_vertices(g, R_outer, 0.0, 0.0, phi0_high)
    fig = go.Figure()
    # Polygons closed by repeating the first vertex: red = lower c half, green = upper c half.
    fig.add_trace(go.Scatter(x=np.r_[Xlow, Xlow[:1]], y=np.r_[Ylow, Ylow[:1]], mode="lines", line=dict(width=2, color="red"), hoverinfo="skip", showlegend=False))
    fig.add_trace(go.Scatter(x=np.r_[Xhigh, Xhigh[:1]], y=np.r_[Yhigh, Yhigh[:1]], mode="lines", line=dict(width=2, color="green"), hoverinfo="skip", showlegend=False))
    if show_labels:
        for r in range(g):
            fig.add_trace(go.Scatter(x=[Xlow[r]], y=[Ylow[r]], mode="text", text=[f"c equiv {r} (mod {g})"], textfont=dict(size=12, color="red"), hoverinfo="skip", showlegend=False))
            fig.add_trace(go.Scatter(x=[Xhigh[r]], y=[Yhigh[r]], mode="text", text=[f"c>={p}+{r} (equiv {r})"], textfont=dict(size=12, color="green"), hoverinfo="skip", showlegend=False))
    clusters = defaultdict(list)
    for rec in c_records:
        if rec["n"] not in neuron_main:
            continue
        n = int(rec["n"])
        r = int(rec["c_mod"]) % g
        center = (Xhigh[r], Yhigh[r]) if rec["c_upper"] else (Xlow[r], Ylow[r])
        clusters[(r, rec["c_upper"])].append({"n": n, "max": float(rec["max_act"]), "center": center})
    Xs, Ys, texts = [], [], []
    for key, items in clusters.items():
        # Strongest neurons first; each ring holds ring_cap ids, and rings grow outward by ring_step.
        items_sorted = sorted(items, key=lambda d: (-d["max"], d["n"]))
        for i, d in enumerate(items_sorted):
            ring = i // ring_cap
            pos = i % ring_cap
            ang = 2 * np.pi * pos / max(1, ring_cap)
            rad = ring_step * (ring + 1)
            cx, cy = d["center"]
            Xs.append(cx + rad * np.cos(ang))
            Ys.append(cy + rad * np.sin(ang))
            texts.append(str(d["n"]))
    if Xs:
        fig.add_trace(go.Scatter(x=Xs, y=Ys, mode="text", text=texts, textposition="top center", textfont=dict(size=text_size, color="#222"), hoverinfo="skip", showlegend=False))
    fig.update_xaxes(visible=False)
    # Equal x/y scaling keeps the polygons regular.
    fig.update_yaxes(visible=False, scaleanchor="x", scaleratio=1)
    fig.update_layout(title=title or f"c n-gon (single layer) - p={p}, f={f}, g={g}", width=900, height=900, margin=dict(l=40, r=40, t=60, b=40), showlegend=False)
    return fig

def single_neuron_figure(n, pre_grid, left_vec, right_vec, F_full, F_L, F_R, names, subgroup_info, freq_map=None, strict=True, artifacts=None):
    """Build a 4×4 diagnostic Plotly figure for neuron ``n``.

    Rows, top to bottom:
      1. the raw grid ``pre_grid[:, :, n]``, the input vectors ``left_vec[:, n]`` and
         ``right_vec[:, n]``, and the quadrant-remapped grid;
      2. GFT norm matrices from ``F_full``, ``F_L`` and ``F_R``, restricted to neuron ``n``;
      3. per-quadrant ``(phi_b, phi_a)`` phase scatter plots;
      4. a table of subgroup coset scores from ``subgroup_info``.

    Args:
        n: neuron index.
        pre_grid: array indexed ``[a, b, neuron]``.
        left_vec: indexed ``[position, neuron]``.
        right_vec: indexed ``[position, neuron]``.
        F_full, F_L, F_R: mappings ``(r, s, neuron) -> coefficient matrix``.
        names: irrep labels used as GFT matrix axes.
        subgroup_info: ``{"L": side, "R": side}``.  Each ``side`` holds ``"Lcoset"`` and
            ``"Rcoset"`` mappings (subgroup -> dict with ``"Cbar"``) and optionally
            ``"mix3"``, a list of ``{"origin", "H", "Cbar"}`` dicts.
        freq_map: when given, the remap uses ``remap_with_special_sign_rule`` with the
            ``_classify_by_gft`` result.  Otherwise ``_remapped_quadrants`` is used.
        strict: forwarded to ``_classify_by_gft``.
        artifacts: optional mapping.  ``artifacts["secondary_per_neuron"].get(n)`` is
            passed to the remap when available.

    Returns:
        The assembled figure.
    """
    G, D = pre_grid.shape[0], len(names)
    fig = make_subplots(rows=4, cols=4, row_heights=[0.25, 0.25, 0.25, 0.25], specs=[[{"type": "heatmap"}, {"type": "xy"}, {"type": "xy"}, {"type": "xy"}], [{"type": "heatmap"}, {"type": "heatmap"}, {"type": "heatmap"}, None], [{"type": "xy"}, {"type": "xy"}, {"type": "xy"}, {"type": "xy"}], [{"type": "table", "colspan": 4}, None, None, None]], subplot_titles=[f"Whole-RAW (n{n})", f"a-RAW-n{n}, max={jnp.max(left_vec[:, n]):.2f}", f"b-RAW-n{n},max={jnp.max(right_vec[:, n]):.2f}", f"Quad-REMAP (n{n})", f"Whole-DFT (n{n})", f"a-DFT (n{n})", f"b-DFT (n{n})", "Quad-BL", "Quad-BR", "Quad-TL", "Quad-TR"], horizontal_spacing=0.02, vertical_spacing=0.08)
    header, ok = _quad_set_extrema_header(pre_grid[:, :, n])
    if ok:
        fig.add_annotation(text=header, xref="paper", yref="paper", x=0.5, y=1.15, showarrow=False, font=dict(size=12), align="center")
    fig.add_trace(go.Heatmap(z=pre_grid[:, :, n], coloraxis="coloraxis1"), row=1, col=1)
    fig.update_xaxes(title_text="Right input (b)", row=1, col=1)
    fig.update_yaxes(title_text="Left input (a)", row=1, col=1)
    x_ticks = list(range(G))
    fig.add_trace(go.Scatter(x=x_ticks, y=left_vec[:, n], mode="lines+markers", line=dict(width=1), marker=dict(size=4), showlegend=False), row=1, col=2)
    fig.add_trace(go.Scatter(x=x_ticks, y=right_vec[:, n], mode="lines+markers", line=dict(width=1), marker=dict(size=4), showlegend=False), row=1, col=3)
    for i in [2, 3]:
        fig.update_xaxes(showgrid=True, row=1, col=i)
        fig.update_yaxes(showgrid=True, row=1, col=i)
    if freq_map is not None:
        Fhat_n = {k: v for k, v in F_full.items() if k[2] == n}
        dom = _classify_by_gft(Fhat_n, names, freq_map, strict=strict)
        sec_info = None
        if artifacts is not None:
            try:
                sec_info = artifacts["secondary_per_neuron"].get(int(n), None)
            except (KeyError, TypeError, AttributeError):
                # Best-effort: a missing or malformed entry just means no secondary info.
                sec_info = None
        remap_img, fb, fa = remap_with_special_sign_rule(pre_grid[:, :, n], dom, sec_info, freq_map)
    else:
        remap_img, fb, fa = _remapped_quadrants(pre_grid[:, :, n])
    fig.add_trace(go.Heatmap(z=remap_img, colorscale="Viridis", showscale=False), row=1, col=4)
    fig.update_xaxes(title_text=f"b (remapped) by {fb}", row=1, col=4)
    fig.update_yaxes(title_text=f"a (remapped) by {fa}", row=1, col=4)
    for col, Fhat in enumerate([F_full, F_L, F_R], start=1):
        P = np.zeros((D, D))
        for (r, s, idx), M in Fhat.items():
            if idx == n:
                P[names.index(r), names.index(s)] = np.linalg.norm(np.array(M))
        fig.add_trace(go.Heatmap(z=P, x=names, y=names, coloraxis="coloraxis2"), row=2, col=col)
    quad_phase = _quadrant_ab_phases(pre_grid[:, :, n])
    axis_range = [0, 2 * np.pi]
    tick_vals = [0, np.pi, 2 * np.pi]
    tick_text = ["0", "pi", "2pi"]
    for idx, (phib, phia) in enumerate(quad_phase):
        fig.add_trace(go.Scatter(x=[phib], y=[phia], mode="markers+text", marker=dict(size=8, color="red"), text=[f"{phib:.2f},{phia:.2f}"], textposition="top center", textfont=dict(size=10), cliponaxis=False, showlegend=False), row=3, col=idx + 1)
        fig.update_xaxes(range=axis_range, tickvals=tick_vals, ticktext=tick_text, title_text="phi_b(rad)", row=3, col=idx + 1)
        fig.update_yaxes(range=axis_range, tickvals=tick_vals, ticktext=tick_text, title_text="phi_a (rad)", row=3, col=idx + 1)
    lineL_LC = ", ".join(f"{h}:{s['Cbar']:.2e}" for h, s in subgroup_info['L']['Lcoset'].items())
    lineL_RC = ", ".join(f"{h}:{s['Cbar']:.2e}" for h, s in subgroup_info['L']['Rcoset'].items())
    lineL_mix = ", ".join(f"{d['origin']}*{d['H']}:{d['Cbar']:.2e}" for d in subgroup_info['L'].get('mix3', []))
    lineR_LC = ", ".join(f"{h}:{s['Cbar']:.2e}" for h, s in subgroup_info['R']['Lcoset'].items())
    lineR_RC = ", ".join(f"{h}:{s['Cbar']:.2e}" for h, s in subgroup_info['R']['Rcoset'].items())
    lineR_mix = ", ".join(f"{d['origin']}*{d['H']}:{d['Cbar']:.2e}" for d in subgroup_info['R'].get('mix3', []))
    headers = ["L vec - L-coset", "L vec - R-coset", "L vec - mix-Top3", "R vec - L-coset", "R vec - R-coset", "R vec - mix-Top3"]
    cells = [[lineL_LC], [lineL_RC], [lineL_mix], [lineR_LC], [lineR_RC], [lineR_mix]]
    fig.add_trace(go.Table(header=dict(values=headers, align="center"), cells=dict(values=cells, align="left"), columnwidth=[0.16, 0.16, 0.18, 0.16, 0.16, 0.18]), row=4, col=1)
    # The caption shows the merged top-3 ("mix3") scores.  subgroup_info[side] itself is a
    # container of score mappings, not a mapping of scores, so it cannot be formatted directly.
    caption = (f"<b>Neuron {n} - top-3 subgroup scores</b><br>" f"<span style='color:#1f77b4'>L</span> -> {lineL_mix}<br>" f"<span style='color:#ff7f0e'>R</span> -> {lineR_mix}")
    fig.add_annotation(text=caption, showarrow=False, xref="paper", yref="paper", x=0.5, y=-0.14, font=dict(size=10), align="center")
    for r in range(1, 3):
        for c in range(1, 5):
            fig.update_xaxes(showgrid=False, zeroline=False, row=r, col=c)
            fig.update_yaxes(showgrid=False, zeroline=False, row=r, col=c)
    for (r, c) in [(1, 1), (1, 4), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3), (3, 4)]:
        # Axis ids are handed out only to non-empty xy cells, and row 2's fourth cell is
        # empty, so row-3 axes are x8..x11 rather than (r-1)*4+c.  Use the x-axis that
        # make_subplots anchored this subplot's y-axis to.
        axis_name = fig.get_subplot(r, c).yaxis.anchor
        fig.update_yaxes(scaleanchor=axis_name, scaleratio=1, row=r, col=c)
    fig.update_layout(margin=dict(t=30, b=30, l=80, r=50), width=900, height=900, autosize=False, coloraxis1=dict(colorscale="Viridis", colorbar=dict(title="RAW", len=0.25, y=0.88, yanchor="middle", x=1.02, xanchor="left")), coloraxis2=dict(colorscale="Viridis", colorbar=dict(title="GFT", len=0.25, y=0.53, yanchor="middle", x=1.02, xanchor="left")))
    fig.update_xaxes(automargin=True)
    fig.update_yaxes(automargin=True)
    return fig

def _two_stage_kmeans_prune(log_mp: np.ndarray, thresh1: float = 2.0, thresh2: float = 2.0, seed: int = 0) -> tuple[np.ndarray, np.ndarray]:
    assert log_mp.ndim == 1
    x = log_mp.reshape(-1, 1)
    n = x.shape[0]
    if n == 0:
        return np.array([], dtype=int), np.array([], dtype=int)
    if n == 1:
        return np.array([0], dtype=int), np.array([], dtype=int)
    n_unique = np.unique(x, axis=0).shape[0]
    if n_unique < 2:
        keep = np.arange(n, dtype=int)
        drop = np.array([], dtype=int)
        return keep, drop
    km1 = KMeans(n_clusters=2, n_init='auto', random_state=seed)
    lab1 = km1.fit_predict(x)
    centers1 = km1.cluster_centers_.ravel()
    hi1 = int(np.argmax(centers1))
    gap1 = float(abs(centers1[0] - centers1[1]))
    if gap1 >= float(thresh1):
        keep1 = np.flatnonzero(lab1 == hi1)
    else:
        keep1 = np.arange(x.size, dtype=int)
    keep2 = keep1
    if keep1.size >= 2:
        x2 = x[keep1]
        if np.unique(x2, axis=0).shape[0] >= 2:
            km2 = KMeans(n_clusters=2, n_init='auto', random_state=seed + 1)
            lab2 = km2.fit_predict(x2)
            centers2 = km2.cluster_centers_.ravel()
            hi2 = int(np.argmax(centers2))
            gap2 = float(abs(centers2[0] - centers2[1]))
            if gap2 >= float(thresh2):
                rel_keep2 = np.flatnonzero(lab2 == hi2)
                keep2 = keep1[rel_keep2]
    if keep2.size == 0:
        keep2 = np.array([int(np.argmax(x))], dtype=int)
    all_idx = np.arange(x.size, dtype=int)
    drop = np.setdiff1d(all_idx, keep2, assume_unique=False)
    return keep2, drop

import re
import numpy as np

def _coerce_freq(v):
    if isinstance(v, (int, np.integer)):
        return int(v)
    if isinstance(v, str):
        m = re.search(r'(\d+)$', v)
        if m:
            return int(m.group(1))
    raise TypeError(f"freq value {v!r} is not an int nor a string ending with digits.")

def _validate_freq_map(names, freq_map, p, require_all: bool = False):
    """Check that the frequencies of ``names`` in ``freq_map`` lie in ``1..p-1``.

    Each present value is normalized with ``_coerce_freq`` (ints, or strings ending in
    digits) before the range check.

    Args:
        names: labels to check.
        freq_map: label -> frequency value.
        p: exclusive upper bound for a valid frequency.
        require_all: when True, labels absent from ``freq_map`` are reported too. The
            default False keeps the historical behaviour of skipping them, since the
            original "missing" report was unreachable.

    Raises:
        AssertionError: listing up to 8 missing labels, or up to 8 invalid entries.
    """
    missing = []
    bad = []
    for lab in names:
        if lab not in freq_map:
            if require_all:
                missing.append(lab)
            continue
        try:
            f = _coerce_freq(freq_map[lab])
            # Explicit check instead of `assert`, so validation still runs under `python -O`.
            if not 0 < f < p:
                raise ValueError(f"f={f} out of range (1..{p-1}) for label {lab!r}")
        except (TypeError, ValueError) as e:
            bad.append((lab, freq_map[lab], str(e)))
    if missing:
        raise AssertionError(f"freq_map missing labels: {missing[:8]}{'...' if len(missing)>8 else ''}")
    if bad:
        lines = ", ".join([f"{lab}->{val} ({err})" for lab, val, err in bad[:8]])
        raise AssertionError(f"freq_map invalid entries: {lines}{' ...' if len(bad)>8 else ''}")

def remap_with_special_sign_rule(tile: np.ndarray, dom: dict, sec_info: dict, freq_map: dict[str, int]) -> tuple[np.ndarray, int, int]:
    """Remap a tile's quadrants, with a special case for neurons dominated by ("sign", "sign").

    Cases:
    * ``dom`` has ``r_star == s_star == "sign"``: the secondary label from ``sec_info``
      gives a single frequency, used for both axes. With no secondary, an unchanged
      copy of ``tile`` is returned with frequencies ``(0, 0)``.
    * ``dom`` carries both ``fb`` and ``fa`` and its ``kind`` is ``"diag"`` or ``"axis"``:
      those frequencies are used.
    * Otherwise: falls back to ``_remapped_quadrants(tile)``.

    Raises:
        KeyError: if the sign/sign secondary label has no entry in ``freq_map``.
    """
    if dom["r_star"] == "sign" and dom["s_star"] == "sign":
        secondary = (sec_info or {}).get("secondary", None)
        if not secondary:
            return tile.copy(), 0, 0
        if secondary not in freq_map:
            raise KeyError(f"freq_map missing '{secondary}'")
        shared_freq = int(freq_map[secondary])
        return _remapped_quadrants_by_freq(tile, shared_freq, shared_freq)

    fb, fa = dom.get("fb", None), dom.get("fa", None)
    has_both = fb is not None and fa is not None
    # `dom["kind"]` is only read once both frequencies are known, as in the original guard.
    if has_both and dom["kind"] in ("diag", "axis"):
        return _remapped_quadrants_by_freq(tile, int(fb), int(fa))
    return _remapped_quadrants(tile)

def _find_secondary_strong(Fhat_n, names: list[str], dom: dict, sec_ratio: float = 1/3) -> dict:
    """Look for a secondary irrep whose power is comparable to the neuron's primary irrep.

    The power matrix ``Pn = _power_matrix_from_Fhat(Fhat_n, names)`` is indexed by
    position in ``names``. Column 0 (``Pn[:, 0]``) and row 0 (``Pn[0, :]``) are used as
    the "a" and "b" axis power profiles. The diagonal ``np.diagonal(Pn)`` is used for the
    diagonal candidate.

    Args:
        Fhat_n: per-neuron Fourier coefficients (only forwarded to ``_power_matrix_from_Fhat``).
        names: irrep labels; ``names.index(label)`` is a label's matrix position.
        dom: dominant-irrep description; reads ``"kind"``, ``"r_star"`` and ``"s_star"``.
        sec_ratio: minimum candidate/primary power ratio for a label to qualify.

    Returns:
        dict with ``mode``, ``primary_pair``, ``secondary`` (label or None),
        ``secondary_source`` (``"diag"``, ``"axis"`` or None) and ``metrics``.
        Kinds other than ``"diag"``/``"axis"``, and ``"diag"`` with ``r_star != s_star``,
        return with ``secondary=None`` and empty metrics. The primary irrep is never
        reported as its own secondary.
    """
    Pn = _power_matrix_from_Fhat(Fhat_n, names)
    r_star, s_star = dom["r_star"], dom["s_star"]
    out = {"mode": dom["kind"], "primary_pair": (r_star, s_star), "secondary": None, "secondary_source": None, "metrics": {}}

    def _axis_profiles():
        """Return (a_vals, b_vals, max_a, max_b) taken from column 0 / row 0 of Pn."""
        a_vals = Pn[:, 0].astype(float)
        b_vals = Pn[0, :].astype(float)
        return a_vals, b_vals, float(a_vals[int(np.argmax(a_vals))]), float(b_vals[int(np.argmax(b_vals))])

    def _best_axis_candidate(a_vals, b_vals, prim_a, prim_b, exclude):
        """Best index (not in `exclude`) whose a- and b-ratios both reach sec_ratio."""
        best_k, best_score = None, (-np.inf, -np.inf)
        for k in range(len(names)):
            if k in exclude:
                continue
            ra = (a_vals[k] / prim_a) if prim_a > 0 else 0.0
            rb = (b_vals[k] / prim_b) if prim_b > 0 else 0.0
            if ra >= sec_ratio and rb >= sec_ratio and (a_vals[k] > 0 or b_vals[k] > 0):
                # Rank by the weaker ratio first, then by the combined ratio.
                score = (min(ra, rb), ra + rb)
                if score > best_score:
                    best_score, best_k = score, k
        return best_k, best_score

    if dom["kind"] == "diag" and r_star == s_star:
        diag_vals = np.diagonal(Pn).astype(float)
        prim_idx = names.index(r_star)
        prim_pow = float(diag_vals[prim_idx])
        diag_vals_wo = diag_vals.copy()
        diag_vals_wo[prim_idx] = -np.inf
        sec_idx_d = int(np.argmax(diag_vals_wo))
        sec_pow_d = float(diag_vals_wo[sec_idx_d])
        ratio_d = (sec_pow_d / prim_pow) if prim_pow > 0 else 0.0
        cand_diag = names[sec_idx_d] if (np.isfinite(sec_pow_d) and ratio_d >= sec_ratio) else None

        a_vals, b_vals, prim_a, prim_b = _axis_profiles()
        best_k, best_score = _best_axis_candidate(a_vals, b_vals, prim_a, prim_b, {prim_idx})
        cand_axis = names[best_k] if best_k is not None else None
        score_axis = best_score[0] if best_k is not None else 0.0
        score_diag = ratio_d if cand_diag is not None else 0.0

        if cand_axis is not None and score_axis > score_diag:
            out["secondary"] = cand_axis
            out["secondary_source"] = "axis"
        elif cand_diag is not None:
            out["secondary"] = cand_diag
            out["secondary_source"] = "diag"
        out["metrics"].update(ratio_axis_min=score_axis, ratio_diag=score_diag)
        return out

    if dom["kind"] == "axis":
        a_vals, b_vals, prim_a, prim_b = _axis_profiles()
        # Exclude the primary labels so the primary cannot be picked as its own secondary
        # (it trivially has ratio 1.0 on its own axis).
        primary_idx = {names.index(lab) for lab in (r_star, s_star) if lab in names}
        best_k, _ = _best_axis_candidate(a_vals, b_vals, prim_a, prim_b, primary_idx)
        if best_k is not None:
            out["secondary"] = names[best_k]
            out["secondary_source"] = "axis"
        sec_a = float(a_vals[best_k]) if best_k is not None else 0.0
        sec_b = float(b_vals[best_k]) if best_k is not None else 0.0
        ratio_a = (sec_a / prim_a) if (best_k is not None and prim_a > 0) else 0.0
        ratio_b = (sec_b / prim_b) if (best_k is not None and prim_b > 0) else 0.0
        out["metrics"].update(primary_power_a=prim_a, primary_power_b=prim_b, secondary_power_a=sec_a, secondary_power_b=sec_b, ratio_a=ratio_a, ratio_b=ratio_b)
        return out
    return out

def prepare_layer_artifacts(pre_grid, left, right, dft_2d, irreps, freq_map, strict=True, prune_cfg: dict | None = None, store_full_neuron_grids: bool = False, sec_ratio: float = 1/3):
    """Classify each neuron of a layer by its dominant irrep pair and gather cluster artifacts.

    Args:
        pre_grid: pre-activations indexed ``[a, b, neuron]`` with ``G = pre_grid.shape[0]``
            and ``N = pre_grid.shape[-1]``. They are flattened to ``(G*G, N)`` for ``dft_2d``.
        left, right: passed unchanged to ``dft_2d``; results are stored as ``F_L`` / ``F_R``.
        dft_2d: transform returning a dict whose tuple keys carry the neuron index at
            position 2. This is how per-neuron coefficients are selected here.
        irreps: iterable of 4-tuples whose first element is the irrep label.
        freq_map: irrep label -> frequency (values passed through ``int``).
        strict: forwarded to ``_classify_by_gft``. During pruning it also makes a diagonal
            cluster whose label is missing from ``freq_map`` raise ``KeyError``.
        prune_cfg: optional ``{"rel_tau": float, "abs_floor": float}``. When given, each
            diagonal cluster keeps neurons whose max |pre-activation| exceeds
            ``max(rel_tau * cluster_max, abs_floor)``; ``rel_tau`` is clamped to [0, 1].
        store_full_neuron_grids: also store the raw and ReLU'd grid of every neuron.
        sec_ratio: forwarded to ``_find_secondary_strong``.

    Returns:
        dict with ``F_full``, ``F_L``, ``F_R``, ``names``, ``irrep2neurons``,
        ``freq_cluster``, ``diag_labels``, ``neuron_data``, ``secondary_per_neuron`` and
        ``secondary_by_cluster``. When ``prune_cfg`` is given it also contains
        ``cluster_prune``.

    NOTE(review): ``freq_cluster`` and ``diag_labels`` are only filled inside the pruning
    step, so they stay empty when ``prune_cfg`` is None. Confirm this is intended.
    """
    G, N = pre_grid.shape[0], pre_grid.shape[-1]
    flat_all = pre_grid.reshape(G * G, N)
    F_full = dft_2d(flat_all)
    F_L = dft_2d(left)
    F_R = dft_2d(right)
    names = [lab for lab, _, _, _ in irreps]

    # Bucket coefficients by neuron once (key position 2) instead of rescanning F_full per
    # neuron. Within a bucket, keys keep their F_full iteration order.
    coeffs_by_neuron = defaultdict(dict)
    for key, val in F_full.items():
        coeffs_by_neuron[key[2]][key] = val

    irrep2neurons = defaultdict(list)
    freq_cluster = defaultdict(list)
    neuron_data = {}
    secondary_per_neuron = {}
    secondary_by_cluster = defaultdict(list)
    for n in range(N):
        Fhat_n = coeffs_by_neuron.get(n, {})
        dom = _classify_by_gft(Fhat_n, names, freq_map, strict=strict)
        r_star, s_star = dom["r_star"], dom["s_star"]
        # Each n is appended exactly once, so the neuron lists are duplicate-free.
        irrep2neurons[(r_star, s_star)].append(n)
        if r_star == s_star:
            sec_info = _find_secondary_strong(Fhat_n, names, dom, sec_ratio=sec_ratio)
        else:
            sec_info = {"mode": dom["kind"], "primary_pair": (r_star, s_star), "secondary": None, "secondary_source": None, "metrics": {}}
        secondary_per_neuron[int(n)] = sec_info
        secondary_by_cluster[(r_star, s_star)].append((int(n), sec_info["secondary"]))
        entry = {"a_values": np.arange(G, dtype=int), "b_values": np.arange(G, dtype=int), "dominant": dom}
        if store_full_neuron_grids:
            grid = pre_grid[:, :, n]
            entry["real_preactivations"] = grid
            entry["postactivations"] = np.maximum(grid, 0.0)
        neuron_data[int(n)] = entry

    diag_labels = set()
    cluster_prune = {}
    if prune_cfg is not None:
        rel_tau = float(prune_cfg.get("rel_tau", 0.05))
        rel_tau = max(0.0, min(1.0, rel_tau))
        abs_floor = float(prune_cfg.get("abs_floor", 0.0))
        for (r, s), neuron_list in irrep2neurons.items():
            if r != s or len(neuron_list) == 0:
                continue
            cluster_grid = pre_grid[:, :, neuron_list]
            max_preacts = np.max(np.abs(cluster_grid), axis=(0, 1))
            # Values at or below the absolute floor count as zero (log10 recorded as None).
            max_preacts = np.where(max_preacts > abs_floor, max_preacts, 0.0)
            cluster_max = float(max_preacts.max()) if max_preacts.size else 0.0
            thr = rel_tau * cluster_max
            keep_rel = np.flatnonzero(max_preacts > max(thr, abs_floor))
            drop_rel = np.setdiff1d(np.arange(max_preacts.size), keep_rel)
            main = [int(neuron_list[i]) for i in keep_rel]
            drop = [int(neuron_list[i]) for i in drop_rel]
            per_log = {int(neuron_list[i]): (float(np.log10(max_preacts[i])) if max_preacts[i] > 0.0 else None) for i in range(len(neuron_list))}
            cluster_prune[(r, s)] = {"main": main, "drop": drop, "per_neuron_log10": per_log}
            if not main:
                continue
            if r not in freq_map:
                if strict:
                    raise KeyError(f"freq_map has no mapping for irrep '{r}'")
            else:
                f = int(freq_map[r])
                freq_cluster[f].extend(main)
                for n in main:
                    kind_n = neuron_data[n]["dominant"]["kind"]
                    diag_labels.add((r, kind_n))
    freq_cluster = {f: list(dict.fromkeys(ids)) for f, ids in freq_cluster.items() if ids}
    artifacts = {"F_full": F_full, "F_L": F_L, "F_R": F_R, "names": names, "irrep2neurons": irrep2neurons, "freq_cluster": freq_cluster, "diag_labels": diag_labels, "neuron_data": neuron_data, "secondary_per_neuron": secondary_per_neuron, "secondary_by_cluster": dict(secondary_by_cluster)}
    if prune_cfg is not None:
        artifacts["cluster_prune"] = cluster_prune
    return artifacts

def build_f_pool_from_layer1_artifacts(layer1_artifacts, freq_map, p):
    """Collect the sorted, distinct frequencies used by the diagonal clusters of a base layer.

    A diagonal cluster ``(r, r)`` of ``layer1_artifacts["irrep2neurons"]`` contributes
    ``int(freq_map[r])`` when the value lies in ``(0, p)`` and the cluster still has
    neurons. If ``layer1_artifacts["cluster_prune"]`` has an entry for the cluster, its
    ``"main"`` list decides this. A cluster pruned to an empty ``"main"`` contributes
    nothing, rather than falling back to its unpruned neuron list.

    Args:
        layer1_artifacts: dict with ``"irrep2neurons"`` and optionally ``"cluster_prune"``.
        freq_map: irrep label -> frequency; labels without an entry are skipped.
        p: exclusive upper bound for a valid frequency.

    Returns:
        Sorted list of distinct integer frequencies.
    """
    ir2n = layer1_artifacts["irrep2neurons"]
    pruned = layer1_artifacts.get("cluster_prune", {}) or {}
    fset = set()
    for (r, s), ids in ir2n.items():
        if r != s:
            continue
        pack = pruned.get((r, s))
        used = pack.get("main") if pack is not None else None
        if used is None:
            # No pruning information for this cluster: use every neuron in it.
            used = ids
        if not used or r not in freq_map:
            continue
        f = int(freq_map[r])
        if 0 < f < p:
            fset.add(f)
    return sorted(fset)

def summarize_diag_labels(diag_labels: Iterable[Tuple[str, str]], p: int, names: List[str]) -> Dict[str, Any]:
    """Group diagonal-cluster labels by their coset structure relative to ``p``.

    Labels and kinds are stripped of surrounding whitespace and de-duplicated. Each
    distinct label is then placed in one group:

    * ``approx_coset``: a ``2D<n>`` / ``2D_<n>`` / ``2D-<n>`` label (case-insensitive) with ``gcd(p, n) == 1``;
    * ``coset_2d``: such a label with ``gcd(p, n) > 1``;
    * ``coset_1d``: ``sign``, ``rp`` or ``srp`` (case-insensitive);
    * ``others``: anything else.

    Returns:
        dict with ``p``, ``names``, per-group ``items``, ``counts`` (``total_diag`` is the
        number of distinct labels), a ``consistency`` check and ``approx_ratio_in_diag``.
    """
    normalized = {(str(lab).strip(), str(knd).strip()) for lab, knd in diag_labels}
    kinds_by_label = defaultdict(set)
    for lab, knd in normalized:
        kinds_by_label[lab].add(knd)

    one_dim_names = {"sign", "rp", "srp"}
    two_dim_re = re.compile(r"^2D[_\-]?(\d+)$", flags=re.IGNORECASE)
    groups = {"approx_coset": [], "coset_2d": [], "coset_1d": [], "others": []}
    for lab, kind_set in kinds_by_label.items():
        kind_list = sorted(kind_set)
        match = two_dim_re.match(lab)
        if match is not None:
            freq = int(match.group(1))
            common = math.gcd(p, freq)
            bucket = "approx_coset" if common == 1 else "coset_2d"
            groups[bucket].append({"label": lab, "kinds": kind_list, "f": freq, "gcd_pf": common})
        elif lab.lower() in one_dim_names:
            groups["coset_1d"].append({"label": lab, "kinds": kind_list})
        else:
            groups["others"].append({"label": lab, "kinds": kind_list})

    total = len(kinds_by_label)
    counts = {bucket: len(items) for bucket, items in groups.items()}
    counts["total_diag"] = total
    grouped_total = sum(len(items) for items in groups.values())
    all_placed = grouped_total == total

    seen = {item["label"] for items in groups.values() for item in items}
    unplaced = [] if all_placed else sorted(set(kinds_by_label) - seen)

    ratio = counts["approx_coset"] / total if total > 0 else 0.0
    return {
        "p": p,
        "names": names,
        "items": groups,
        "counts": counts,
        "consistency": {"ok": all_placed, "sum_all": grouped_total, "expected_total": total, "missing_when_not_ok": unplaced},
        "approx_ratio_in_diag": ratio,
    }

def make_layer_report(pre_grid, left, right, p, dft_2d, irreps, coset_masks_left, coset_masks_right, save_dir: str, cluster_tau: float = 1e-3, color_rule=None, artifacts=None, layer_idx: int = 0, freq_map=None):
    """Render one PDF per diagonal irrep cluster of a layer, plus embedding and angle plots.

    First, subgroup/coset scores of every neuron's left/right mean vectors are computed
    (and printed for the first 10 neurons). Then, for every cluster ``(r, r)`` in
    ``artifacts["irrep2neurons"]``, ``<save_dir>/cluster_<r>_<s>.pdf`` is written with:

    * a cover page;
    * a per-neuron quadrant phase scatter;
    * a merged phase scatter of the kept ("main") neurons;
    * a log10(max |pre-act|) scatter of main vs dropped neurons;
    * coset n-gon pages grouped by the frequency from ``quad_set_extrema_records_strict_f``;
    * one page per main neuron (``single_neuron_figure``);
    * if any neurons were dropped, a table of them.

    PCA/diffusion plots and plane-angle plots go under
    ``<save_dir>/cluster_<r>_<s>_embeds``. Off-diagonal clusters are only reported on
    stdout. A cluster whose pruned "main" list is empty is skipped.

    Args:
        pre_grid: pre-activations indexed ``[a, b, neuron]`` with ``G = pre_grid.shape[0]``.
            The dihedral Cayley table is built for ``G // 2``.
        left, right: reshaped to ``(G, G, N)``. They are averaged over axis 1 (left) and
            axis 0 (right) to get per-neuron vectors.
        p: forwarded to the plotting helpers and used in page titles.
        dft_2d, irreps: only used when ``artifacts`` is None, to build them.
        coset_masks_left, coset_masks_right: forwarded to ``subgroup_scores``.
        save_dir: output directory.
        cluster_tau: currently unused.
        color_rule: forwarded as ``colour_rule`` to ``generate_pdf_plots_for_matrix``.
        artifacts: output of ``prepare_layer_artifacts``; computed when None.
        layer_idx: 0 gives a/b plane-angle plots; any other value gives c plane-angle plots.
        freq_map: irrep label -> frequency. When choosing embedding frequencies, None is
            treated as an empty mapping instead of crashing on ``in``.
    """
    G = pre_grid.shape[0]
    N = pre_grid.shape[-1]
    p_rot = G // 2
    _group_elems, _elem_index, cayley_table = dihedral.build_cayley_table(p_rot)
    if artifacts is None:
        artifacts = prepare_layer_artifacts(pre_grid, left, right, dft_2d, irreps, freq_map)
    # freq_map defaults to None, but membership tests below need a mapping.
    freq_lookup = freq_map if freq_map is not None else {}
    F_full = artifacts["F_full"]
    F_L = artifacts["F_L"]
    F_R = artifacts["F_R"]
    names = artifacts["names"]
    irrep2neurons = artifacts["irrep2neurons"]
    cont_l = left.reshape(G, G, N)
    left_vec = cont_l.mean(axis=1)
    cont_r = right.reshape(G, G, N)
    right_vec = cont_r.mean(axis=0)

    coset_info = {}
    for n in range(N):
        v_left = left_vec[:, n]
        v_right = right_vec[:, n]
        L_LC_top = subgroup_scores(v_left, coset_masks_left, top_k=2)
        L_RC_top = subgroup_scores(v_left, coset_masks_right, top_k=2)
        L_LC_all = subgroup_scores(v_left, coset_masks_left, top_k=None)
        L_RC_all = subgroup_scores(v_left, coset_masks_right, top_k=None)
        L_mix3 = merge_topk_sources([("Lcoset", L_LC_all), ("Rcoset", L_RC_all)], top_k=3)
        R_LC_top = subgroup_scores(v_right, coset_masks_left, top_k=2)
        R_RC_top = subgroup_scores(v_right, coset_masks_right, top_k=2)
        R_LC_all = subgroup_scores(v_right, coset_masks_left, top_k=None)
        R_RC_all = subgroup_scores(v_right, coset_masks_right, top_k=None)
        R_mix3 = merge_topk_sources([("Lcoset", R_LC_all), ("Rcoset", R_RC_all)], top_k=3)
        coset_info[n] = {"L": {"Lcoset": L_LC_top, "Rcoset": L_RC_top, "mix3": L_mix3}, "R": {"Lcoset": R_LC_top, "Rcoset": R_RC_top, "mix3": R_mix3}}
        if n < 10:
            print(f"neuron {n:3d} | L[Lcoset]{L_LC_top} | L[Rcoset]{L_RC_top} | L-mix3 {L_mix3} | R[Lcoset]{R_LC_top} | R[Rcoset]{R_RC_top} | R-mix3 {R_mix3}")

    for (r, s), neuron_list in irrep2neurons.items():
        cluster_acts = pre_grid[:, :, neuron_list]
        max_act = cluster_acts.max()
        if r != s:
            print(f"Skip {r}_{s} irreps plot, max activation: {max_act}")
            continue
        writer = PdfWriter()

        # Main/drop split: from the pruning step when available, else keep every neuron.
        if "cluster_prune" in artifacts and (r, s) in artifacts["cluster_prune"]:
            _pack = artifacts["cluster_prune"][(r, s)]
            neuron_main = list(_pack["main"])
            neuron_drop = list(_pack["drop"])
            per_log = _pack["per_neuron_log10"]
            if len(neuron_main) == 0:
                print(f"[skip] Cluster ({r},{s}) is empty after pruning.")
                continue
        else:
            neuron_main = neuron_list
            neuron_drop = []
            cluster_grid = pre_grid[:, :, neuron_list]
            max_preacts = np.max(np.abs(cluster_grid), axis=(0, 1))
            per_log = {int(neuron_list[i]): float(np.log10(max_preacts[i] + 1e-20)) for i in range(len(neuron_list))}

        # Cover page.
        cover = make_subplots(rows=1, cols=1)
        cover.add_annotation(text=(f"<b>Cluster ({r},{s})</b><br>" f"size = {len(neuron_list)}<br>" f"max activation = {max_act:.2e}"), xref="paper", yref="paper", x=0.5, y=0.6, showarrow=False, font=dict(size=24), align="center")
        cover.update_xaxes(visible=False)
        cover.update_yaxes(visible=False)
        cover._uuid = uuid.uuid4().hex
        pdf_cover = cover.to_image(format="pdf", engine="kaleido")
        reader = PdfReader(io.BytesIO(pdf_cover))
        writer.add_page(reader.pages[0])

        # Per-neuron quadrant phases (all neurons of the cluster, one colour each).
        quad_labels = ["Quad-BL", "Quad-BR", "Quad-TL", "Quad-TR"]
        fig_quads = make_subplots(rows=2, cols=2, subplot_titles=quad_labels)
        palette = (px.colors.qualitative.Light24 + px.colors.qualitative.Dark24)
        num_colors = len(palette)
        axis_range = [0, 2 * np.pi]
        tick_vals = [0, np.pi, 2 * np.pi]
        tick_text = ["0", "pi", "2pi"]
        for n_idx, nn in enumerate(neuron_list):
            quad_phase = _quadrant_ab_phases(pre_grid[:, :, nn])
            color = palette[n_idx % num_colors]
            for q_idx, (phix, phiy) in enumerate(quad_phase):
                r_0, c_0 = divmod(q_idx, 2)
                show_legend = (q_idx == 0)
                fig_quads.add_trace(go.Scatter(x=[phix], y=[phiy], mode="markers", marker=dict(size=5, color=color), name=f"neuron {nn}", showlegend=show_legend), row=r_0 + 1, col=c_0 + 1)
                if n_idx == 0:
                    fig_quads.update_xaxes(title_text="phi_b (rad)", range=axis_range, tickvals=tick_vals, ticktext=tick_text, row=r_0 + 1, col=c_0 + 1)
                    fig_quads.update_yaxes(title_text="phi_a (rad)", range=axis_range, tickvals=tick_vals, ticktext=tick_text, row=r_0 + 1, col=c_0 + 1)
        fig_quads._uuid = uuid.uuid4().hex
        pdf_bytes = fig_quads.to_image(format="pdf", engine="kaleido")
        reader = PdfReader(io.BytesIO(pdf_bytes))
        writer.add_page(reader.pages[0])

        # Merged phases of main neurons: identical (rounded) phases share one marker, sized
        # by count and labelled with the summed max |pre-act|.
        fig_quads_mer = make_subplots(rows=2, cols=2, subplot_titles=quad_labels)
        color = "red"
        merged_quads = [defaultdict(list) for _ in range(4)]
        for nn in neuron_main:
            quad_phase = _quadrant_ab_phases(pre_grid[:, :, nn])
            max_amp = np.abs(pre_grid[:, :, nn]).max()
            for q_idx, (phix, phiy) in enumerate(quad_phase):
                key = (round(phix, 4), round(phiy, 4))
                merged_quads[q_idx][key].append(max_amp)
        for q_idx, phase_dict in enumerate(merged_quads):
            r_0, c_0 = divmod(q_idx, 2)
            for (mx, my), amps in phase_dict.items():
                sum_amp = sum(amps)
                count = len(amps)
                size = 6 + 3 * np.log1p(count)
                fig_quads_mer.add_trace(go.Scatter(x=[mx], y=[my], mode="markers+text", marker=dict(size=size, color=color), text=[f"{sum_amp:.2f}"], textposition="top center", showlegend=False), row=r_0 + 1, col=c_0 + 1)
            # Axis styling once per non-empty quadrant (was repeated for every marker).
            if phase_dict:
                fig_quads_mer.update_xaxes(title_text="phi_b (rad)", range=axis_range, tickvals=tick_vals, ticktext=tick_text, row=r_0 + 1, col=c_0 + 1)
                fig_quads_mer.update_yaxes(title_text="phi_a (rad)", range=axis_range, tickvals=tick_vals, ticktext=tick_text, row=r_0 + 1, col=c_0 + 1)
        fig_quads_mer._uuid = uuid.uuid4().hex
        pdf_bytes = fig_quads_mer.to_image(format="pdf", engine="kaleido")
        reader = PdfReader(io.BytesIO(pdf_bytes))
        writer.add_page(reader.pages[0])

        # log10(max |pre-act|) per neuron, coloured main vs drop.
        cluster_grid = pre_grid[:, :, neuron_list]
        log_mp_all = np.log10(np.max(np.abs(cluster_grid), axis=(0, 1)) + 1e-20)
        color_tag = np.array(["drop"] * len(neuron_list))
        pos_in_cluster = {nn: i for i, nn in enumerate(neuron_list)}
        for nn in neuron_main:
            if nn in pos_in_cluster:
                color_tag[pos_in_cluster[nn]] = "main"
        fig_mp = px.scatter(x=np.arange(len(neuron_list)), y=log_mp_all, color=color_tag, labels=dict(x="Neuron index in cluster", y="log10(max pre-act)"), title=f"Cluster ({r},{s}) - log10(max pre-act): main vs drop")
        fig_mp._uuid = uuid.uuid4().hex
        pdf_bytes = fig_mp.to_image(format="pdf", engine="kaleido")
        writer.add_page(PdfReader(io.BytesIO(pdf_bytes)).pages[0])

        # Coset n-gon pages, grouped by the frequency each main neuron resolves to.
        by_f = {}
        skipped = []
        for nn in neuron_main:
            res = quad_set_extrema_records_strict_f(pre_grid[:, :, nn], nn, cayley=cayley_table)
            if not res["ok"]:
                skipped.append((nn, res["reason"]))
                continue
            f = res["f"]
            by_f.setdefault(f, {"records": [], "neurons": set()})
            by_f[f]["records"].extend(res["records"])
            by_f[f]["neurons"].add(nn)
        pages_made = 0
        for f, pack in by_f.items():
            if not pack["records"]:
                continue
            rec_max = pick_records_for_mode(pack["records"], "max")
            rec_sec_a = pick_records_for_mode(pack["records"], "sec_a")
            rec_sec_b = pick_records_for_mode(pack["records"], "sec_b")
            rec_c = pick_c_records(pack["records"])
            for mode, recs in [("max", rec_max), ("sec_a", rec_sec_a), ("sec_b", rec_sec_b)]:
                if not recs:
                    continue
                try:
                    fig = plot_coset_ngon_ring(p=p, f=f, records=recs, neuron_main=pack["neurons"], title=f"coset n-gon - p={p}, f={f}, g={p//math.gcd(p,f)}  [{mode}]", r_inner=0.32, show_labels=True)
                    if skipped:
                        fig.add_annotation(text=f"Skipped {len(skipped)} neuron(s) (inconsistent f)", xref="paper", yref="paper", x=1.0, y=1.08, xanchor="right", yanchor="bottom", showarrow=False, font=dict(size=10))
                    fig._uuid = uuid.uuid4().hex
                    pdf_hex = fig.to_image(format="pdf", engine="kaleido")
                    writer.add_page(PdfReader(io.BytesIO(pdf_hex)).pages[0])
                    pages_made += 1
                except ValueError as e:
                    print(f"[skip] f={f}, mode={mode}: {e}")
                    continue
            if rec_c:
                try:
                    fig_c = plot_coset_ngon_c_single(p=p, f=f, c_records=rec_c, neuron_main=pack["neurons"], title=f"coset n-gon (c) - p={p}, f={f}, g={p//math.gcd(p,f)}")
                    fig_c._uuid = uuid.uuid4().hex
                    pdf_hex_c = fig_c.to_image(format="pdf", engine="kaleido")
                    writer.add_page(PdfReader(io.BytesIO(pdf_hex_c)).pages[0])
                    pages_made += 1
                except ValueError as e:
                    print(f"[skip c] f={f}: {e}")
        if pages_made == 0:
            print(f"[hexagram] no pages generated for Cluster ({r},{s}); skipped={len(skipped)}; by_f empty.")

        # Embedding plots of the full grid and of each quadrant.
        mat = pre_grid[:, :, neuron_main].reshape(G * G, -1).astype(float)
        embed_dir = os.path.join(save_dir, f"cluster_{r}_{s}_embeds")
        os.makedirs(embed_dir, exist_ok=True)
        p2 = G // 2
        quads = {"BL": pre_grid[:p2, :p2, neuron_main], "BR": pre_grid[:p2, p2:, neuron_main], "TL": pre_grid[p2:, :p2, neuron_main], "TR": pre_grid[p2:, p2:, neuron_main]}
        freqs = []
        if r in freq_lookup:
            f_main = int(freq_lookup[r])
            freqs.append(f_main)
        else:
            print(f"[warn] freq_map missing {r}; skip main frequency")
        # Secondary frequencies shared by at least 30% of the main neurons.
        sec_pairs = artifacts.get("secondary_by_cluster", {}).get((r, s), [])
        cnt = Counter()
        for _, lab in sec_pairs:
            if lab and (lab in freq_lookup):
                cnt[int(freq_lookup[lab])] += 1
        min_count = max(1, int(0.30 * len(neuron_main)))
        f2_list = sorted([ff for ff, c in cnt.items() if c >= min_count])
        freqs.extend(f2_list)
        freqs = sorted(dict.fromkeys(freqs))
        if not freqs:
            print(f"[warn] cluster ({r},{s}) no freq; empty list")
        quad_freq_lists = {k: freqs for k in ["BL", "BR", "TL", "TR", "full"]}
        color_rules = color_rule
        generate_pdf_plots_for_matrix(mat, p=p, save_dir=embed_dir, seed=f"{r}_{s}", freq_list=quad_freq_lists["full"], tag="full", tag_q="full", class_string=f"cluster_{r}_{s}", colour_rule=color_rules, num_principal_components=4)
        for tag_q, quad in quads.items():
            qmat = quad.reshape(p2 * p2, -1).astype(float)
            generate_pdf_plots_for_matrix(qmat, p=p2, save_dir=embed_dir, seed=f"{r}_{s}", freq_list=quad_freq_lists[tag_q], tag=tag_q, tag_q=tag_q, class_string=f"{tag_q}_{r}_{s}", colour_rule=color_rules, num_principal_components=4)

        # Plane-angle plots on the leading PCA coordinates.
        side = G
        indices = np.arange(side * side)
        a_vals = indices // side
        b_vals = indices % side
        pcs, _ = compute_pca_coords(mat, num_components=min(8, mat.shape[1], mat.shape[0] - 1))
        coords_for_angle = pcs[:, :4]
        angles_dir = os.path.join(embed_dir, "angles", "pca")
        os.makedirs(angles_dir, exist_ok=True)
        if layer_idx == 0:
            plane_angle_per_cluster(coords=coords_for_angle, a_vals=a_vals, b_vals=b_vals, p=p2, cluster_ids=np.zeros(side * side, dtype=int), mode="a", tag_q="full", save_dir=os.path.join(angles_dir, "layer0_a"), title=f"Layer-0: a-rot vs a-ref (cluster {r},{s})")
            plane_angle_per_cluster(coords=coords_for_angle, a_vals=a_vals, b_vals=b_vals, p=p2, cluster_ids=np.zeros(side * side, dtype=int), mode="b", tag_q="full", save_dir=os.path.join(angles_dir, "layer0_b"), title=f"Layer-0: b-rot vs b-ref (cluster {r},{s})")
        else:
            plane_angle_per_cluster(coords=coords_for_angle, a_vals=a_vals, b_vals=b_vals, p=p2, cluster_ids=np.zeros(side * side, dtype=int), mode="c", tag_q="full", save_dir=os.path.join(angles_dir, "layerK_c"), title=f"Layer-{layer_idx}: c-rot vs c-ref (cluster {r},{s})")

        # One page per main neuron.
        for nn in neuron_main:
            print(f"Rendering neuron {nn} in cluster ({r},{s})")
            fig_n = single_neuron_figure(nn, pre_grid, left_vec, right_vec, F_full, F_L, F_R, names, coset_info[nn], freq_map=freq_map, strict=True, artifacts=artifacts)
            fig_n._uuid = uuid.uuid4().hex
            pdf_bytes = fig_n.to_image(format="pdf", engine="kaleido")
            reader = PdfReader(io.BytesIO(pdf_bytes))
            page0 = reader.pages[0]
            writer.add_page(page0)

        # Table of dropped neurons, largest log10(max) first.
        if len(neuron_drop) > 0:
            rows = []
            for nn in sorted(neuron_drop):
                logv = per_log.get(int(nn), None)
                maxv = float(np.max(np.abs(pre_grid[:, :, nn])))
                rows.append((int(nn), maxv, (None if logv is None else float(logv))))
            rows.sort(key=lambda t: (t[2] if t[2] is not None else -1e9), reverse=True)
            tbl = go.Figure(data=[go.Table(header=dict(values=["neuron index", "max |preact|", "log10(max)"]), cells=dict(values=[[row[0] for row in rows], [f"{row[1]:.3g}" for row in rows], [("—" if row[2] is None else f"{row[2]:.3f}") for row in rows]]))])
            tbl.update_layout(title=f"Cluster ({r},{s}) - Dropped neurons")
            tbl._uuid = uuid.uuid4().hex
            pdf_bytes = tbl.to_image(format="pdf", engine="kaleido")
            writer.add_page(PdfReader(io.BytesIO(pdf_bytes)).pages[0])

        path = f"cluster_{r}_{s}.pdf"
        fin_path = os.path.join(save_dir, path)
        with open(fin_path, "wb") as fh:
            writer.write(fh)
        print("saved", fin_path)

import traceback
def make_R2_c_angle_report_no_plot(pre_grid, p, dft_2d, irreps, save_dir: str, artifacts=None, base_layer_artifacts=None, base_f_pool=None, layer_idx: int = 0, freq_map=None):
    """Write plane-angle plots and a quadrant sine-fit R^2 summary for each diagonal cluster.

    Each cluster ``(r, r)`` uses its pruned "main" neurons when ``cluster_prune`` info
    exists, and is skipped when that list is empty. PCA coordinates of the flattened
    ``(G*G, n_main)`` pre-activation matrix are passed to ``plane_angle_per_cluster``: a/b
    modes for ``layer_idx == 0``, c mode otherwise. When ``freq_map`` is given, each of a
    main neuron's four quadrants is fitted once with ``fit_quadrant_sines``. The averages
    are written to ``<save_dir>/cluster_<r>_<s>_embeds/r2/r2_summary.json``.

    Args:
        pre_grid: pre-activations indexed ``[a, b, neuron]`` with ``G = pre_grid.shape[0]``.
            Quadrants are ``G // 2`` on a side.
        p: forwarded to ``plane_angle_per_cluster`` and used as the frequency upper bound.
        dft_2d, irreps: unused; kept for signature compatibility.
        save_dir: output root directory.
        artifacts: required; output of ``prepare_layer_artifacts`` for this layer.
        base_layer_artifacts: base-layer artifacts, used to build the frequency pool when
            ``base_f_pool`` is not given.
        base_f_pool: explicit frequency pool; takes precedence over ``base_layer_artifacts``.
        layer_idx: layers >= 1 require ``base_f_pool`` or ``base_layer_artifacts``. For
            layer 0 with neither, the layer's own ``artifacts`` provide the pool.
        freq_map: irrep label -> frequency. R^2 fitting is skipped when None.

    Raises:
        ValueError: if a layer >= 1 has no frequency-pool source, or ``artifacts`` is None.
    """
    if layer_idx >= 1 and base_f_pool is None and base_layer_artifacts is None:
        raise ValueError("Need base_f_pool or base_layer_artifacts for layer>=1 residual fitting.")
    if artifacts is None:
        raise ValueError("artifacts is required (output of prepare_layer_artifacts).")
    G = pre_grid.shape[0]
    p_quad = G // 2
    names = artifacts["names"]
    irrep2neurons = artifacts["irrep2neurons"]
    cluster_prune = artifacts.get("cluster_prune") or {}

    # Resolve the frequency pool once and honour the base_f_pool precedence. Previously the
    # R^2 step always rebuilt it from base_layer_artifacts, which may be None.
    f_pool = None
    coeffs_by_neuron = {}
    if freq_map is not None:
        if base_f_pool is not None:
            f_pool = list(base_f_pool)
        else:
            pool_src = base_layer_artifacts if base_layer_artifacts is not None else artifacts
            f_pool = build_f_pool_from_layer1_artifacts(pool_src, freq_map, p)
        # Bucket Fourier coefficients by neuron index (key position 2) once.
        coeffs_by_neuron = defaultdict(dict)
        for key, val in artifacts.get("F_full", {}).items():
            coeffs_by_neuron[key[2]][key] = val

    for (r, s), neuron_list in irrep2neurons.items():
        if r != s:
            continue
        if (r, s) in cluster_prune:
            neuron_main = list(cluster_prune[(r, s)]["main"])
            if len(neuron_main) == 0:
                print(f"[skip] Cluster ({r},{s}) is empty after pruning.")
                continue
        else:
            neuron_main = neuron_list

        mat = pre_grid[:, :, neuron_main].reshape(G * G, -1).astype(float)
        embed_dir = os.path.join(save_dir, f"cluster_{r}_{s}_embeds")
        os.makedirs(embed_dir, exist_ok=True)
        side = G
        indices = np.arange(side * side)
        a_vals = indices // side
        b_vals = indices % side
        pcs, _ = compute_pca_coords(mat, num_components=min(8, mat.shape[1], mat.shape[0] - 1))
        coords_for_angle = pcs[:, :4]
        angles_dir = os.path.join(embed_dir, "angles", "pca")
        os.makedirs(angles_dir, exist_ok=True)
        if layer_idx == 0:
            plane_angle_per_cluster(coords=coords_for_angle, a_vals=a_vals, b_vals=b_vals, p=p, cluster_ids=np.zeros(side * side, dtype=int), mode="a", tag_q="full", save_dir=os.path.join(angles_dir, "layer0_a"), title=f"Layer-0: a-rot vs a-ref (cluster {r},{s})")
            plane_angle_per_cluster(coords=coords_for_angle, a_vals=a_vals, b_vals=b_vals, p=p, cluster_ids=np.zeros(side * side, dtype=int), mode="b", tag_q="full", save_dir=os.path.join(angles_dir, "layer0_b"), title=f"Layer-0: b-rot vs b-ref (cluster {r},{s})")
        else:
            plane_angle_per_cluster(coords=coords_for_angle, a_vals=a_vals, b_vals=b_vals, p=p, cluster_ids=np.zeros(side * side, dtype=int), mode="c", tag_q="full", save_dir=os.path.join(angles_dir, "layerK_c"), title=f"Layer-{layer_idx}: c-rot vs c-ref (cluster {r},{s})")

        if freq_map is None:
            continue
        try:
            r2_dir = os.path.join(embed_dir, "r2")
            os.makedirs(r2_dir, exist_ok=True)
            cluster_r2 = []
            per_neuron_avg = {}
            for nn in neuron_main:
                Fhat_n = coeffs_by_neuron.get(nn, {})
                dom = _classify_by_gft(Fhat_n, names, freq_map, strict=True)
                grid = pre_grid[:, :, nn]
                quads = [grid[:p_quad, :p_quad], grid[:p_quad, p_quad:], grid[p_quad:, :p_quad], grid[p_quad:, p_quad:]]
                # Exactly one fit per quadrant.
                r2_list = []
                for q in quads:
                    fit = fit_quadrant_sines(q, dom, freq_map, names, f_pool=f_pool, max_iters=6, tau_inc=0.0, use_axes_only=True, use_pair_terms=True, include_ab_resid=False)
                    r2_list.append(float(fit["R2_final"]))
                if r2_list:
                    per_neuron_avg[int(nn)] = float(np.mean(r2_list))
                    cluster_r2.extend(r2_list)
            summary = {"cluster": [(r, s)], "layer_idx": int(layer_idx), "p": int(p), "n_neurons": int(len(neuron_main)), "f_pool": list(map(int, f_pool)), "cluster_avg_R2_final": (float(np.mean(cluster_r2)) if cluster_r2 else None), "per_neuron_avg_R2_final": per_neuron_avg}
            with open(os.path.join(r2_dir, "r2_summary.json"), "w") as fh:
                json.dump(summary, fh, indent=2)
        except Exception as e:
            # Log with context, then propagate.
            print(f"[R2] cluster ({r},{s}) failed: {type(e).__name__}: {e}")
            traceback.print_exc()
            raise

def _phase_pdf_append_fig(writer, fig) -> None:
    """Render a plotly figure to PDF with kaleido and append its first page to ``writer``."""
    # A fresh _uuid is assigned before every render, as the original inline code did
    # (presumably so kaleido does not reuse a previous render -- TODO confirm).
    fig._uuid = uuid.uuid4().hex
    pdf_bytes = fig.to_image(format="pdf", engine="kaleido")
    writer.add_page(PdfReader(io.BytesIO(pdf_bytes)).pages[0])


def _phase_pdf_style_axes(fig, axis_range) -> None:
    """Apply the shared phi_b (x) / phi_a (y) axis styling to every cell of a 2x2 subplot grid."""
    tick_vals = [0, np.pi, 2 * np.pi]
    tick_text = ["0", "pi", "2pi"]
    for q_idx in range(4):
        row, col = divmod(q_idx, 2)
        fig.update_xaxes(title_text="phi_b (rad)", range=axis_range, tickvals=tick_vals,
                         ticktext=tick_text, row=row + 1, col=col + 1)
        fig.update_yaxes(title_text="phi_a (rad)", range=axis_range, tickvals=tick_vals,
                         ticktext=tick_text, row=row + 1, col=col + 1)


def _phase_pdf_torus_dist(x1, y1, x2, y2):
    """Distance between two angle pairs, each coordinate wrapped with period 2*pi.

    Assumes each coordinate difference lies within one period (|x1 - x2| <= 2*pi).
    """
    dx = np.abs(x1 - x2)
    dx = np.minimum(dx, 2 * np.pi - dx)
    dy = np.abs(y1 - y2)
    dy = np.minimum(dy, 2 * np.pi - dy)
    return np.hypot(dx, dy)


def _phase_pdf_circ_mean(angles, weights):
    """Weighted circular mean of ``angles`` (radians), returned in [0, 2*pi)."""
    s = np.sum(weights * np.sin(angles))
    c = np.sum(weights * np.cos(angles))
    return np.arctan2(s, c) % (2 * np.pi)


def _phase_pdf_merge_points(pts, merge_eps):
    """Greedily merge nearby ``(phi_x, phi_y, amp)`` points.

    Walking ``pts`` in order, each point not yet assigned seeds a group that absorbs
    every still-unassigned point within ``merge_eps`` torus distance of the *seed*
    (not of the group's running centre).

    Returns:
        One tuple per group: (amp-weighted circular mean of x, same for y,
        summed amp, number of points in the group).
    """
    used = [False] * len(pts)
    merged = []
    for i, (x0, y0, _) in enumerate(pts):
        if used[i]:
            continue
        used[i] = True
        group = [pts[i]]
        for j, (xj, yj, _) in enumerate(pts):
            if not used[j] and _phase_pdf_torus_dist(x0, y0, xj, yj) <= merge_eps:
                used[j] = True
                group.append(pts[j])
        xs = np.array([g[0] for g in group], dtype=float)
        ys = np.array([g[1] for g in group], dtype=float)
        amps = np.array([g[2] for g in group], dtype=float)
        merged.append((
            _phase_pdf_circ_mean(xs, amps),
            _phase_pdf_circ_mean(ys, amps),
            float(amps.sum()),
            len(group),
        ))
    return merged


def make_stripe_report_only_phase_pdf(pre_grid, p, dft_2d, irreps, save_dir: str, artifacts=None):
    """Write one phase-report PDF per diagonal (``r == s``) irrep cluster.

    For every ``(r, s), neuron_list`` in ``artifacts["irrep2neurons"]`` with ``r == s``,
    writes ``<save_dir>/cluster_{r}_{s}.pdf`` with four pages:

    1. a cover showing the cluster size and ``max()`` of its pre-activations;
    2. a 2x2 scatter of per-quadrant phases (from ``_quadrant_ab_phases``) for every
       neuron in ``neuron_list``, coloured per neuron (palette cycles);
    3. the same phases for the "main" neurons only, greedily merged on the torus
       (``merge_eps = 0.25``) and labelled with the summed max |pre-activation|;
    4. log10(max |pre-activation|) per neuron, coloured "main" vs "drop".

    It also always creates ``<save_dir>/cluster_{r}_{s}_embeds``; when ``r`` is a
    string ``"2D_<f>"`` it runs
    ``pca_diffusion_plots_w_helpers.run_pca_and_stripes_no_plots`` on the main
    neurons with ``freq_list=[f]``. Off-diagonal clusters are only reported with a
    printed max activation.

    Args:
        pre_grid: pre-activations indexed as ``pre_grid[:, :, neuron]``; assumed to be
            (G, G, n_neurons) -- the ``reshape(G * G, -1)`` relies on that.
        p: forwarded as ``p`` to ``run_pca_and_stripes_no_plots``.
        dft_2d: unused here; kept for call-site compatibility.
        irreps: unused here; kept for call-site compatibility.
        save_dir: output directory for the PDFs and embedding sub-directories.
        artifacts: mapping with ``"irrep2neurons"`` ((r, s) -> neuron indices) and,
            optionally, ``"cluster_prune"`` ((r, s) -> dict with a ``"main"`` neuron
            list). Clusters without a prune entry treat every neuron as main.

    Raises:
        TypeError: if ``artifacts`` is None (it is required despite the default).
    """
    if artifacts is None:
        raise TypeError(
            "make_stripe_report_only_phase_pdf requires `artifacts` "
            "(it must contain an 'irrep2neurons' entry)"
        )
    G = pre_grid.shape[0]
    irrep2neurons = artifacts["irrep2neurons"]
    cluster_prune = artifacts["cluster_prune"] if "cluster_prune" in artifacts else {}
    quad_labels = ["Quad-BL", "Quad-BR", "Quad-TL", "Quad-TR"]
    palette = px.colors.qualitative.Light24 + px.colors.qualitative.Dark24
    merge_eps = 0.25  # torus-distance threshold for merging phase points (page 3)
    pad = 0.30  # extra axis margin on page 3 so text labels near 0 / 2*pi stay visible

    for (r, s), neuron_list in irrep2neurons.items():
        cluster_acts = pre_grid[:, :, neuron_list]
        max_act = cluster_acts.max()
        if r != s:
            print(f"Skip {r}_{s} irreps plot, max activation: {max_act}")
            continue

        if (r, s) in cluster_prune:
            neuron_main = list(cluster_prune[(r, s)]["main"])
        else:
            neuron_main = neuron_list
        writer = PdfWriter()

        # Page 1: cover.
        cover = make_subplots(rows=1, cols=1)
        cover.add_annotation(
            text=(f"<b>Cluster ({r},{s})</b><br>" f"size = {len(neuron_list)}<br>" f"max activation = {max_act:.2e}"),
            xref="paper", yref="paper", x=0.5, y=0.6, showarrow=False,
            font=dict(size=24), align="center",
        )
        cover.update_xaxes(visible=False)
        cover.update_yaxes(visible=False)
        _phase_pdf_append_fig(writer, cover)

        # Page 2: raw per-quadrant phases of every neuron in the cluster.
        fig_quads = make_subplots(rows=2, cols=2, subplot_titles=quad_labels)
        for n_idx, nn in enumerate(neuron_list):
            color = palette[n_idx % len(palette)]
            for q_idx, (phix, phiy) in enumerate(_quadrant_ab_phases(pre_grid[:, :, nn])):
                row, col = divmod(q_idx, 2)
                fig_quads.add_trace(
                    go.Scatter(x=[phix], y=[phiy], mode="markers", marker=dict(size=5, color=color),
                               name=f"neuron {nn}", showlegend=(q_idx == 0)),  # one legend entry per neuron
                    row=row + 1, col=col + 1,
                )
        _phase_pdf_style_axes(fig_quads, [0, 2 * np.pi])
        _phase_pdf_append_fig(writer, fig_quads)

        # Page 3: main-neuron phases merged on the torus, weighted by max |pre-act|.
        raw_quads = [[] for _ in range(4)]
        for nn in neuron_main:
            quad_phase = _quadrant_ab_phases(pre_grid[:, :, nn])
            max_amp = float(np.abs(pre_grid[:, :, nn]).max())
            for q_idx, (phix, phiy) in enumerate(quad_phase):
                raw_quads[q_idx].append((float(phix), float(phiy), max_amp))
        fig_quads_mer = make_subplots(rows=2, cols=2, subplot_titles=quad_labels)
        for q_idx, pts in enumerate(raw_quads):
            row, col = divmod(q_idx, 2)
            for mx, my, sum_amp, count in _phase_pdf_merge_points(pts, merge_eps):
                fig_quads_mer.add_trace(
                    go.Scatter(x=[mx], y=[my], mode="markers+text",
                               marker=dict(size=8 + 4 * np.log1p(count), color="red"),
                               text=[f"{sum_amp:.2f}"], textposition="top center",
                               textfont=dict(size=10), cliponaxis=False, showlegend=False),
                    row=row + 1, col=col + 1,
                )
        _phase_pdf_style_axes(fig_quads_mer, [-pad, 2 * np.pi + pad])
        fig_quads_mer.update_layout(margin=dict(l=80, r=60, t=80, b=80))
        _phase_pdf_append_fig(writer, fig_quads_mer)

        # Page 4: log10 of each neuron's max |pre-activation|, main vs drop.
        log_mp_all = np.log10(np.max(np.abs(cluster_acts), axis=(0, 1)) + 1e-20)
        main_set = set(neuron_main)
        color_tag = np.array(["main" if nn in main_set else "drop" for nn in neuron_list])
        fig_mp = px.scatter(
            x=np.arange(len(neuron_list)), y=log_mp_all, color=color_tag,
            labels=dict(x="Neuron index in cluster", y="log10(max pre-act)"),
            title=f"Cluster ({r},{s}) - log10(max pre-act): main vs drop",
        )
        _phase_pdf_append_fig(writer, fig_mp)

        # Embedding directory is always created; the PCA/stripe run only for "2D_<f>" clusters.
        embed_dir = os.path.join(save_dir, f"cluster_{r}_{s}_embeds")
        os.makedirs(embed_dir, exist_ok=True)
        if isinstance(r, str) and r.startswith("2D_"):
            freq = int(r.split("_", 1)[1])
            mat = pre_grid[:, :, neuron_main].reshape(G * G, -1).astype(float)
            pca_diffusion_plots_w_helpers.run_pca_and_stripes_no_plots(
                mat, p=p, save_dir=embed_dir, freq_list=[freq], tag_q="full",
                num_principal_components=4, s_mode="anchor", model="auto",
                cluster_meta={"n_neurons_main": len(neuron_main), "n_neurons_total": len(neuron_list)},
            )

        fin_path = os.path.join(save_dir, f"cluster_{r}_{s}.pdf")
        with open(fin_path, "wb") as fh:
            writer.write(fh)
        print("Saved", fin_path)
