import base64
import io
import json
import math
import os
import re
import tempfile
import time
import uuid
from collections import Counter, OrderedDict, defaultdict
from functools import reduce
from itertools import islice
from math import cos, gcd, pi, sin
from pathlib import Path
from typing import Any, Dict, Iterable, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import numpy as onp
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from PyPDF2 import PdfReader, PdfWriter
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import DFT
import dihedral
import dn_recon
import paper_plots
import pca_diffusion_plots_w_helpers
import analysis.R2 as R2
import analysis.cayley as cayley
from analysis.plane_fit import plane_angle_per_cluster
from pca_diffusion_plots_w_helpers import compute_pca_coords, generate_pdf_plots_for_matrix

# Raise Kaleido's default export timeout to 60 * 5 (presumably seconds, i.e.
# five minutes - confirm against the installed kaleido/plotly version).
pio.kaleido.scope.default_timeout = 60 * 5


def inverse_on_cosets(f: int, p: int) -> dict[int, int]:
    """Map each coset index to the inverse of the reduced frequency.

    With ``g = gcd(f, p)``, the reduced frequency ``f // g`` is inverted
    modulo ``p // g`` and that single inverse is stored under every key in
    ``range(g)``.

    Raises:
        ValueError: if ``f // g`` is not coprime to ``p // g``.
    """
    freq, modulus = int(f), int(p)
    common = gcd(freq, modulus)
    reduced_mod = modulus // common
    reduced_freq = freq // common
    if gcd(reduced_freq, reduced_mod) != 1:
        raise ValueError(f"{reduced_freq} not invertible mod {reduced_mod}")
    inverse = pow(reduced_freq, -1, reduced_mod)
    return dict.fromkeys(range(common), inverse)


def step_size(f: int, p: int) -> int:
    """Return the inverse of ``f // gcd(f, p)`` modulo ``p // gcd(f, p)``."""
    common = math.gcd(f, p)
    return pow(f // common, -1, p // common)


def _remap_block(block: onp.ndarray):
    """Permute a square block by its own dominant frequencies.

    The dominant ``(fb, fa)`` pair is taken from ``dominant_freqs_ab``; source
    entry ``(i, j)`` is written to ``((fa * i) % p, (fb * j) % p)`` where
    ``p = block.shape[0]``.

    Returns:
        ``(remapped, fb, fa)`` with ``remapped`` converted back to a NumPy array.

    NOTE(review): the scatter runs through ``jax.numpy``; with JAX's default
    32-bit mode a float64 ``block`` presumably comes back as float32 - confirm.
    """
    size = block.shape[0]
    fb, fa = dominant_freqs_ab(block)
    positions = jnp.arange(size)
    row_dest = (fa * positions) % size
    col_dest = (fb * positions) % size
    scattered = jnp.zeros_like(block).at[row_dest[:, None], col_dest[None, :]].set(block)
    return onp.asarray(scattered), fb, fa


def _remap_block_by_freq(block: np.ndarray, fb: int, fa: int) -> np.ndarray:
    """Permute a square block with caller-supplied frequencies.

    Source entry ``(i, j)`` is written to ``((fa * i) % p, (fb * j) % p)`` of a
    zero-initialised array shaped like ``block``, where ``p = block.shape[0]``.
    Destinations that are never hit keep the value 0.
    """
    size = block.shape[0]
    positions = np.arange(size)
    row_dest = (fa * positions) % size
    col_dest = (fb * positions) % size
    remapped = np.zeros_like(block)
    # np.ix_ builds the same (column-vector, row-vector) open mesh as
    # row_dest[:, None], col_dest[None, :].
    remapped[np.ix_(row_dest, col_dest)] = block
    return remapped


def _remapped_quadrants_by_freq(tile: np.ndarray, fb: int, fa: int) -> tuple[np.ndarray, int, int]:
    """Remap each quadrant of ``tile`` with the same ``(|fb|, |fa|)`` pair.

    The tile is split at ``tile.shape[0] // 2`` on both axes; every quadrant is
    passed through ``_remap_block_by_freq`` and written back in place.

    Returns:
        ``(stitched, fb, fa)`` where the frequencies are the absolute values
        that were actually used.
    """
    half = tile.shape[0] // 2
    fb, fa = abs(fb), abs(fa)
    lower, upper = slice(None, half), slice(half, None)

    stitched = np.zeros_like(tile)
    for rows in (lower, upper):
        for cols in (lower, upper):
            stitched[rows, cols] = _remap_block_by_freq(tile[rows, cols], fb, fa)
    return stitched, fb, fa


def dominant_freqs_ab(grid: np.ndarray) -> Tuple[int, int]:
    """Pick the dominant ``(fb, fa)`` frequency pair of a 2-D grid.

    The 2-D FFT power is inspected along three lines, all restricted to
    indices ``1 .. p // 2`` with ``p = grid.shape[0]`` (DC is excluded):

    * row 0, whose best index is the ``fb`` candidate,
    * column 0, whose best index is the ``fa`` candidate,
    * the main diagonal ``(k, k)``.

    If the best diagonal power strictly exceeds the best axis power, ``(k, k)``
    is returned. Otherwise the axis candidates are returned, with an index of
    0 replaced by 1.
    """
    size = grid.shape[0]
    power = np.abs(np.fft.fft2(grid)) ** 2
    power[0, 0] = -np.inf
    nyq = size // 2

    # Copies, so masking the DC slot does not touch `power` itself.
    along_b = power[0, : nyq + 1].copy()
    along_a = power[: nyq + 1, 0].copy()
    along_b[0] = -np.inf
    along_a[0] = -np.inf

    fb_idx = int(np.argmax(along_b))
    fa_idx = int(np.argmax(along_a))
    axis_best = max(float(along_b[fb_idx]), float(along_a[fa_idx]))

    if nyq >= 1:
        ks = np.arange(1, nyq + 1, dtype=int)
        diagonal = power[ks, ks]
        winner = int(np.argmax(diagonal))
        if float(np.max(diagonal)) > axis_best:
            k = int(ks[winner])  # always >= 1 because ks starts at 1
            return k, k

    # An argmax of 0 means only the masked DC slot was available.
    return (fb_idx or 1), (fa_idx or 1)


def _remapped_quadrants(tile: np.ndarray) -> tuple[np.ndarray, int, int]:
    """Remap each quadrant of ``tile`` by its own dominant frequencies.

    The tile is split at ``tile.shape[0] // 2`` on both axes and every
    quadrant goes through ``_remap_block``.

    Returns:
        ``(stitched, fb, fa)``. ``fb``/``fa`` are the shared frequencies when
        all four quadrants agree; otherwise a message is printed and both are 0.
    """
    half = tile.shape[0] // 2
    lower, upper = slice(None, half), slice(half, None)

    stitched = np.zeros_like(tile)
    seen_fb: set[int] = set()
    seen_fa: set[int] = set()
    for rows, cols in ((lower, lower), (lower, upper), (upper, lower), (upper, upper)):
        remapped, fb_q, fa_q = _remap_block(tile[rows, cols])
        stitched[rows, cols] = remapped
        seen_fb.add(fb_q)
        seen_fa.add(fa_q)

    if len(seen_fb) == 1 and len(seen_fa) == 1:
        return stitched, seen_fb.pop(), seen_fa.pop()

    print("Quadrants are not mapped using the same fb/fa.")
    return stitched, 0, 0


def single_freq_phase_shifts_ab(
    mat: np.ndarray,
    p: int,
    fb: int,
    fa: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read per-sample phases at one frequency on each axis of ``p x p`` grids.

    ``mat`` is reshaped to ``(-1, p, p)``; it is transposed first when its
    first dimension equals ``p * p``. The normalised 2-D FFT is sampled at
    ``[0, fb]`` and ``[fa, 0]``.

    Returns:
        ``(phi_b, phi_a, amps)``: phases wrapped into ``[0, 2*pi)`` (values
        that ``np.isclose`` treats as ``2*pi`` snap to 0) and
        ``2 * (|c_b| + |c_a|)`` per sample.

    Raises:
        ValueError: if ``fb`` or ``fa`` is outside ``1 .. p-1``, or neither
            dimension of ``mat`` equals ``p * p``.
    """
    if any(not (0 < freq < p) for freq in (fb, fa)):
        raise ValueError("fb, fa must be in 1..p-1 (DC & Nyquist excluded)")

    pixels = p * p
    if mat.shape[0] == pixels:
        grids = mat.T.reshape(-1, p, p)
    elif mat.shape[1] == pixels:
        grids = mat.reshape(-1, p, p)
    else:
        raise ValueError("mat must contain p*p pixels per sample")

    spectrum = np.fft.fft2(grids, axes=(-2, -1)) / pixels
    coef_b = spectrum[:, 0, fb]
    coef_a = spectrum[:, fa, 0]

    def _wrapped_phase(coef):
        # Wrap into [0, 2*pi) and fold the numerical "almost 2*pi" case to 0.
        phase = np.mod(np.angle(coef), 2 * np.pi)
        phase[np.isclose(phase, 2 * np.pi, atol=1e-8)] = 0.0
        return phase

    amps = 2 * (np.abs(coef_b) + np.abs(coef_a))
    return _wrapped_phase(coef_b), _wrapped_phase(coef_a), amps


def _quadrant_ab_phases(grid: np.ndarray) -> List[Tuple[float, float]]:
    """Return ``(phi_b, phi_a)`` for each quadrant of a square grid.

    The grid is split at ``grid.shape[0] // 2`` on both axes. For each
    quadrant the strongest FFT magnitude among indices ``1 .. p // 2`` of
    row 0 (for ``fb``) and column 0 (for ``fa``) selects the frequency, and
    the phases come from ``single_freq_phase_shifts_ab``.

    Raises:
        ValueError: if ``grid`` is not square.
    """
    side = grid.shape[0]
    if grid.shape[1] != side:
        raise ValueError("grid must be square")
    p = side // 2
    nyq = p // 2
    lower, upper = slice(None, p), slice(p, None)

    phases: List[Tuple[float, float]] = []
    for rows, cols in ((lower, lower), (lower, upper), (upper, lower), (upper, upper)):
        quad = grid[rows, cols]
        spectrum = np.fft.fft2(quad) / (p * p)
        # +1 because the searched slices start at frequency index 1.
        fb = 1 + int(np.argmax(np.abs(spectrum[0, 1 : nyq + 1])))
        fa = 1 + int(np.argmax(np.abs(spectrum[1 : nyq + 1, 0])))

        phi_b, phi_a, _ = single_freq_phase_shifts_ab(quad.reshape(1, -1), p, fb, fa)
        phases.append((phi_b[0], phi_a[0]))

    return phases


def plotly_to_pdf_bytes(fig, scale=2.0):
    """Return ``fig.to_image(...)`` rendered as PDF at a fixed 600x450 size."""
    render_options = {"format": "pdf", "scale": scale, "width": 600, "height": 450}
    return fig.to_image(**render_options)


def dominant_irrep(Fhat_n, names):
    """Return the dominant irrep label pair ``(b_label, a_label)``.

    A ``len(names) x len(names)`` matrix is filled with the norm of every
    ``Fhat_n[(r, s, _)]`` entry and its ``[0, 0]`` slot is masked out. The
    diagonal wins only if its maximum strictly exceeds both the column-0 and
    the row-0 maximum; then the same label is returned twice. Otherwise the
    row-0 winner is the first label and the column-0 winner the second.
    """
    index_of = {label: pos for pos, label in enumerate(names)}
    size = len(names)
    power = np.zeros((size, size))
    for key, coeff in Fhat_n.items():
        row_label, col_label, _ = key
        power[index_of[row_label], index_of[col_label]] = np.linalg.norm(coeff)

    power[0, 0] = -np.inf
    column0 = power[:, 0]
    row0 = power[0, :]
    diagonal = np.diagonal(power)

    best_a = int(column0.argmax())
    best_b = int(row0.argmax())
    best_d = int(diagonal.argmax())

    diagonal_wins = diagonal[best_d] > column0[best_a] and diagonal[best_d] > row0[best_b]
    if diagonal_wins:
        return (names[best_d], names[best_d])
    return (names[best_b], names[best_a])


def _power_matrix_from_Fhat(Fhat_n, names):
    """Build the ``len(names) x len(names)`` matrix of coefficient norms.

    Entry ``[i, j]`` holds ``np.linalg.norm`` of ``Fhat_n[(names[i], names[j], _)]``;
    label pairs that are absent stay 0.
    """
    index_of = dict(zip(names, range(len(names))))
    size = len(names)
    power = np.zeros((size, size))
    for key, coeff in Fhat_n.items():
        row_label, col_label, _ = key
        power[index_of[row_label], index_of[col_label]] = np.linalg.norm(np.array(coeff))
    return power


def _classify_by_gft(Fhat_n, names, freq_map, *, strict=True):
    """Classify a spectrum as diagonal- or axis-dominated and map it to frequencies.

    The power matrix from ``_power_matrix_from_Fhat`` has its ``[0, 0]`` slot
    masked. ``kind`` is ``"diag"`` when the diagonal maximum strictly exceeds
    both the column-0 and row-0 maxima, else ``"axis"``. The winning labels
    are translated through ``freq_map``; a label without a mapping gives
    ``None``.

    Returns:
        dict with keys ``kind, b_star, a_star, fa, fb, mx_a, mx_b, mx_d, ia,
        ib, idg``.

    Raises:
        KeyError: when ``strict`` is true and a winning label is missing from
            ``freq_map``.
    """
    power = _power_matrix_from_Fhat(Fhat_n, names)
    power[0, 0] = -np.inf
    column0, row0, diagonal = power[:, 0], power[0, :], np.diagonal(power)

    ia, ib, idg = (int(np.argmax(line)) for line in (column0, row0, diagonal))
    mx_a, mx_b, mx_d = column0[ia], row0[ib], diagonal[idg]

    def _freq_of(label):
        # None marks "no mapping"; the strict checks below decide whether to raise.
        return int(freq_map[label]) if label in freq_map else None

    if mx_d > mx_a and mx_d > mx_b:
        kind = "diag"
        b_star = a_star = names[idg]
        if strict and b_star not in freq_map:
            raise KeyError(f"freq_map no mapping for irrep '{b_star}'")
        fa = fb = _freq_of(b_star)
    else:
        kind = "axis"
        b_star, a_star = names[ib], names[ia]
        fa = _freq_of(a_star)
        fb = _freq_of(b_star)
        if strict and (fa is None or fb is None):
            raise KeyError(f"freq_map missing for ({b_star},{a_star})")

    return {
        "kind": kind,
        "b_star": b_star,
        "a_star": a_star,
        "fa": fa,
        "fb": fb,
        "mx_a": float(mx_a),
        "mx_b": float(mx_b),
        "mx_d": float(mx_d),
        "ia": ia,
        "ib": ib,
        "idg": idg,
    }


def subgroup_scores(vec, coset_masks, skip_trivial=True, top_k=3):
    """Rank subgroups by within-coset variance of ``vec``, lowest first.

    ``coset_masks`` maps ``(H, coset_id)`` to an index/mask into ``vec``. For
    each subgroup ``H`` with ``K`` cosets, ``CH`` is the summed coset variance
    divided by the total variance and ``Cbar`` is ``CH`` averaged over ``K``.
    When the total variance is below ``1e-12`` both scores are reported as 0.

    Args:
        skip_trivial: drop the subgroup named ``"C_1"``.
        top_k: keep only the first ``top_k`` entries when it is an ``int``.

    Returns:
        ``OrderedDict`` of ``H -> {"CH", "Cbar", "K"}`` sorted by ``Cbar``.
    """
    masks_by_group = defaultdict(list)
    for (group, _coset_id), mask in coset_masks.items():
        if not (skip_trivial and group == "C_1"):
            masks_by_group[group].append(mask)

    total_var = vec.var()
    is_flat = bool(total_var < 1e-12)

    ranked = []
    for group, masks in masks_by_group.items():
        n_cosets = len(masks)
        if is_flat:
            stats = {"CH": 0.0, "Cbar": 0.0, "K": n_cosets}
        else:
            within = sum(vec[mask].var() for mask in masks)
            stats = {
                "CH": within / total_var,
                "Cbar": (within / max(1, n_cosets)) / total_var,
                "K": n_cosets,
            }
        ranked.append((group, stats))

    # list.sort is stable, so ties keep their insertion order.
    ranked.sort(key=lambda pair: pair[1]["Cbar"])
    if isinstance(top_k, int):
        ranked = ranked[:top_k]
    return OrderedDict(ranked)


def merge_topk_sources(tag_scores_list, top_k: int = 3, tag_key: str = "origin"):
    """Merge several tagged score tables and keep the ``top_k`` lowest ``Cbar``.

    ``tag_scores_list`` is an iterable of ``(tag, scores)`` where ``scores``
    maps ``H`` to a dict with ``"Cbar"``, ``"CH"`` and ``"K"``.

    Returns:
        list of dicts ``{"H", tag_key, "Cbar", "CH", "K"}`` in ascending
        ``Cbar`` order; ties keep their input order.
    """
    candidates = [
        (stats["Cbar"], group, tag, stats["CH"], stats["K"])
        for tag, scores in tag_scores_list
        for group, stats in scores.items()
    ]
    best = sorted(candidates, key=lambda entry: entry[0])[:top_k]
    return [
        {"H": group, tag_key: tag, "Cbar": cbar, "CH": ch, "K": n_cosets}
        for (cbar, group, tag, ch, n_cosets) in best
    ]


def _quad_set_extrema_header(grid_2p: np.ndarray) -> tuple[str, bool]:
    """Summarise where each quadrant of a ``2p x 2p`` grid peaks, as an HTML header.

    For every quadrant (BL, BR, TL, TR) the dominant frequency pair from
    ``dominant_freqs_ab`` must satisfy ``fb == fa`` and divide ``p``; otherwise
    a bold "skip" message is returned with ``False``. Argmax/argmin positions
    are reduced modulo ``g = p // gcd(p, f)`` and shifted by the quadrant's
    ``(b, a)`` offset. ``ALL_TRUE`` is appended when the max labels line up
    across quadrants sharing a ``b`` or ``a`` half.

    Returns:
        ``(header_html, ok)``.
    """
    side = grid_2p.shape[0]
    assert grid_2p.shape[1] == side, "grid must be square"
    p = side // 2
    lower, upper = slice(None, p), slice(p, None)

    # (tag, row slice, column slice, (b offset, a offset))
    layout = (
        ("BL", lower, lower, (0, 0)),
        ("BR", lower, upper, (p, 0)),
        ("TL", upper, lower, (0, p)),
        ("TR", upper, upper, (p, p)),
    )

    pieces = []
    peak = {}
    for tag, rows, cols, (b_off, a_off) in layout:
        quad = grid_2p[rows, cols]
        fb, fa = dominant_freqs_ab(quad)

        if fb != fa:
            return f"<b>skip: {tag} fb({fb}) != fa({fa})</b>", False

        f = int(fb)
        if p % f:
            return f"<b>skip: {tag} p({p}) not divisible by f({f})</b>", False

        g = p // math.gcd(p, f)

        hi_pos = np.unravel_index(int(np.argmax(quad)), quad.shape)
        lo_pos = np.unravel_index(int(np.argmin(quad)), quad.shape)

        # Labels are (b, a): column residue first, row residue second.
        max_label = (b_off + int(hi_pos[1] % g), a_off + int(hi_pos[0] % g))
        min_label = (b_off + int(lo_pos[1] % g), a_off + int(lo_pos[0] % g))

        peak[tag] = max_label
        pieces.append(f"{tag} max:{max_label} min:{min_label}")

    b_aligned = peak["BL"][0] == peak["TL"][0] and peak["BR"][0] == peak["TR"][0]
    a_aligned = peak["BL"][1] == peak["BR"][1] and peak["TL"][1] == peak["TR"][1]
    if b_aligned and a_aligned:
        pieces.append("ALL_TRUE")

    header = " | ".join(pieces)
    return f"<b>{header}</b>", True


def quad_set_extrema_records_strict_f(
    grid_2p: np.ndarray,
    n: int,
    cayley: np.ndarray | None = None,
) -> dict:
    """Collect per-quadrant peak records, requiring one shared diagonal frequency.

    Every quadrant (BL, BR, TL, TR) of the ``2p x 2p`` grid must report
    ``fb == fa`` from ``dominant_freqs_ab``, all four must agree on that value
    ``f``, ``f`` must divide ``p``, and ``g = p // gcd(p, f)`` must exceed 1.
    Any violation returns ``ok=False`` with a ``reason`` and no records.

    Each record holds the neuron id, quadrant tag, the argmax position reduced
    modulo ``g`` and the peak value. When ``cayley`` is given, the table entry
    at the absolute peak position adds ``c_idx``, ``c_mod`` and ``c_upper``.
    ``sec_max_a_fx`` / ``sec_max_b_fx`` describe the best value in the peak's
    row / column among residues other than the peak's own.

    Returns:
        dict with keys ``ok, reason, f, g, records``.
    """

    def _reject(reason, f=None, g=None):
        return dict(ok=False, reason=reason, f=f, g=g, records=[])

    def _runner_up_residue(line: np.ndarray, modulus: int, skip_residue: int):
        # Best value along `line` among index residues other than `skip_residue`.
        residues = np.arange(line.size) % modulus
        best_value = -np.inf
        best_residue = None
        for residue in range(modulus):
            if residue == skip_residue:
                continue
            members = residues == residue
            if not np.any(members):
                continue
            candidate = line[members].max()
            if candidate > best_value:
                best_value = candidate
                best_residue = residue
        return best_residue, float(best_value)

    grid_2p = np.asarray(grid_2p, dtype=float)
    side = grid_2p.shape[0]
    assert grid_2p.shape[1] == side, "grid must be square"
    p = side // 2
    lower, upper = slice(None, p), slice(p, None)

    # tag -> (row slice, column slice, (b offset, a offset))
    layout = {
        "BL": (lower, lower, (0, 0)),
        "BR": (lower, upper, (p, 0)),
        "TL": (upper, lower, (0, p)),
        "TR": (upper, upper, (p, p)),
    }

    freqs = []
    for tag, (rows, cols, _offsets) in layout.items():
        fb, fa = dominant_freqs_ab(grid_2p[rows, cols])
        if fb != fa:
            return _reject(f"{tag}: fb({fb}) != fa({fa})")
        freqs.append(int(fb))

    distinct = sorted(set(freqs))
    if len(distinct) != 1:
        return _reject(f"f varies across quads: {freqs} (uniq={distinct})")

    f = distinct[0]
    if f == 0:
        return _reject(f"dominant frequency is DC (f=0) for n={n}", f=0)
    if p % f:
        return _reject(f"p({p}) not divisible by f({f})")

    g = p // math.gcd(p, f)
    if g <= 1:
        return _reject(f"g={g} (p={p}, f={f}) leaves no second-best residue", f=f, g=g)

    records = []
    for tag, (rows, cols, (b_off, a_off)) in layout.items():
        quad = grid_2p[rows, cols]

        a_idx, b_idx = np.unravel_index(int(np.argmax(quad)), quad.shape)
        peak_value = float(quad[a_idx, b_idx])
        a_res = int(a_idx % g)
        b_res = int(b_idx % g)

        rec = {
            "n": int(n),
            "quad": tag,
            "a_mod": a_res,
            "b_mod": b_res,
            "max_act": peak_value,
        }

        if cayley is not None:
            # Look the peak up at its absolute position in the full table.
            c_idx = int(cayley[a_idx + a_off, b_idx + b_off])
            rec["c_idx"] = c_idx
            rec["c_mod"] = int(c_idx % g)
            rec["c_upper"] = bool(c_idx >= p)

        if quad.shape[1] > 1:
            b_alt, b_alt_val = _runner_up_residue(quad[a_idx, :], g, b_res)
            if b_alt is not None:
                rec["sec_max_a_fx"] = {"a_mod": a_res, "b_mod": int(b_alt), "val": float(b_alt_val)}

        if quad.shape[0] > 1:
            a_alt, a_alt_val = _runner_up_residue(quad[:, b_idx], g, a_res)
            if a_alt is not None:
                rec["sec_max_b_fx"] = {"a_mod": int(a_alt), "b_mod": b_res, "val": float(a_alt_val)}

        records.append(rec)

    return dict(ok=True, reason=None, f=f, g=g, records=records)


def pick_c_records(records):
    """Project the records that carry a ``"c_mod"`` key onto the c-plot fields.

    ``c_upper`` defaults to ``False`` and ``max_act`` to ``0.0`` when absent;
    records without ``"c_mod"`` are dropped.
    """
    return [
        {
            "n": rec["n"],
            "quad": rec["quad"],
            "c_mod": int(rec["c_mod"]),
            "c_upper": bool(rec.get("c_upper", False)),
            "max_act": float(rec.get("max_act", 0.0)),
        }
        for rec in records
        if "c_mod" in rec
    ]


def _ngon_vertices(g: int, R: float, cx: float, cy: float, phi0: float):
    """Return x and y arrays of ``g`` points on a circle of radius ``R``.

    Point ``k`` sits at angle ``phi0 + 2*pi*k/g`` around the centre ``(cx, cy)``.
    """
    xs, ys = [], []
    for k in range(g):
        angle = phi0 + 2 * pi * k / g
        xs.append(cx + R * cos(angle))
        ys.append(cy + R * sin(angle))
    return np.array(xs), np.array(ys)


def pick_records_for_mode(records, mode: str):
    """Select the ``(a_mod, b_mod, max_act)`` view of each record for one mode.

    Args:
        records: iterable of record dicts carrying ``n``, ``quad``, ``a_mod``,
            ``b_mod``, ``max_act`` and optionally ``sec_max_a_fx`` /
            ``sec_max_b_fx`` sub-dicts with ``a_mod``, ``b_mod``, ``val``.
        mode: ``"max"`` uses the record's own peak; ``"sec_a"`` / ``"sec_b"``
            use the corresponding sub-dict and skip records that lack it.

    Returns:
        list of dicts with keys ``n, quad, a_mod, b_mod, max_act``.

    Raises:
        ValueError: if ``mode`` is not one of the three accepted values. This
            is checked before iterating, so a bad mode is reported even when
            ``records`` is empty.
    """
    if mode not in ("max", "sec_a", "sec_b"):
        raise ValueError("mode must be 'max' | 'sec_a' | 'sec_b'")

    if mode == "max":
        return [
            {"n": r["n"], "quad": r["quad"], "a_mod": r["a_mod"], "b_mod": r["b_mod"], "max_act": r["max_act"]}
            for r in records
        ]

    source_key = "sec_max_a_fx" if mode == "sec_a" else "sec_max_b_fx"
    out = []
    for r in records:
        s = r.get(source_key)
        if s is not None:
            out.append({"n": r["n"], "quad": r["quad"], "a_mod": s["a_mod"], "b_mod": s["b_mod"], "max_act": s["val"]})
    return out


def plot_coset_ngon_ring(
    p: int,
    f: int,
    records: list,
    neuron_main: set,
    title: str | None = None,
    R_outer: float = 1.0,
    r_inner: float = 0.26,
    ring_cap: int = 6,
    ring_step: float = 0.045,
    text_size: int = 12,
    show_labels: bool = True,
    rotate_random: bool = False,
    seed: int | None = None,
) -> go.Figure:
    """Draw neuron ids on nested n-gons indexed by ``a mod g`` and ``b mod g``.

    ``g = p // gcd(p, f)``. Two outer g-gons (red for the lower half of ``a``,
    green for the upper half, rotated by ``pi / g``) carry one small pair of
    inner g-gons per vertex for ``b``. Each record whose ``n`` is in
    ``neuron_main`` is placed at the inner vertex selected by its ``quad``
    (TL/TR -> upper ``a``, BR/TR -> upper ``b``), ``a_mod`` and ``b_mod``.
    Records sharing a vertex are sorted by descending ``max_act`` (then ``n``)
    and spread on concentric rings of ``ring_cap`` slots spaced ``ring_step``.

    Args:
        ring_cap: slots per ring; values below 1 are treated as 1.
        rotate_random: draw the base rotations from ``default_rng(seed)``
            instead of the fixed ``pi / 2`` (a) and ``0`` (b).

    Raises:
        ValueError: if ``g < 3``.
    """
    g = p // gcd(p, f)
    if g < 3:
        raise ValueError(f"g must be >=3; got g={g} (p={p}, f={f})")

    # Clamp once so the ring index, the slot index and the slot angle all use
    # the same divisor (ring_cap=0 used to raise ZeroDivisionError).
    slots_per_ring = max(1, ring_cap)

    rng = np.random.default_rng(seed) if rotate_random else None

    if rotate_random:
        phi0_a_low = float(rng.uniform(0, 2 * pi))
        phi0_b_low = float(rng.uniform(0, 2 * pi))
    else:
        phi0_a_low = pi / 2
        phi0_b_low = 0.0

    # The upper-half polygons are offset by half a vertex step.
    phi0_a_up = phi0_a_low + pi / g
    phi0_b_up = phi0_b_low + pi / g

    XaL, YaL = _ngon_vertices(g, R_outer, 0.0, 0.0, phi0_a_low)
    XaU, YaU = _ngon_vertices(g, R_outer, 0.0, 0.0, phi0_a_up)

    fig = go.Figure()

    fig.add_trace(go.Scatter(x=np.r_[XaL, XaL[:1]], y=np.r_[YaL, YaL[:1]], mode="lines", line=dict(width=2, color="red"), hoverinfo="skip", showlegend=False))
    fig.add_trace(go.Scatter(x=np.r_[XaU, XaU[:1]], y=np.r_[YaU, YaU[:1]], mode="lines", line=dict(width=2, color="green"), hoverinfo="skip", showlegend=False))

    if show_labels:
        for r in range(g):
            fig.add_trace(go.Scatter(x=[XaL[r]], y=[YaL[r]], mode="text", text=[f"a mod {g} = {r}"], textfont=dict(size=12, color="red"), hoverinfo="skip", showlegend=False))
        for r in range(g):
            fig.add_trace(go.Scatter(x=[XaU[r]], y=[YaU[r]], mode="text", text=[f"a in upper half, mod {g} = {r}"], textfont=dict(size=12, color="green"), hoverinfo="skip", showlegend=False))

    def _draw_inner_b_at(cx, cy):
        # Draw the lower/upper b-polygons centred on one outer vertex and
        # return their vertex coordinates for later marker placement.
        XbL, YbL = _ngon_vertices(g, r_inner, cx, cy, phi0_b_low)
        XbU, YbU = _ngon_vertices(g, r_inner, cx, cy, phi0_b_up)
        fig.add_trace(go.Scatter(x=np.r_[XbL, XbL[:1]], y=np.r_[YbL, YbL[:1]], mode="lines", line=dict(width=1, color="rgba(0,0,0,0.25)"), hoverinfo="skip", showlegend=False))
        fig.add_trace(go.Scatter(x=np.r_[XbU, XbU[:1]], y=np.r_[YbU, YbU[:1]], mode="lines", line=dict(width=1, color="rgba(0,0,0,0.25)"), hoverinfo="skip", showlegend=False))
        if show_labels:
            for r in range(g):
                fig.add_trace(go.Scatter(x=[XbL[r]], y=[YbL[r]], mode="text", text=[f"b mod {g} = {r}"], textfont=dict(size=10), hoverinfo="skip", showlegend=False))
            for r in range(g):
                fig.add_trace(go.Scatter(x=[XbU[r]], y=[YbU[r]], mode="text", text=[f"b in upper half, mod {g} = {r}"], textfont=dict(size=10), hoverinfo="skip", showlegend=False))
        return (XbL, YbL, XbU, YbU)

    inner_cache = {}
    for r in range(g):
        inner_cache[("aL", r)] = _draw_inner_b_at(XaL[r], YaL[r])
    for r in range(g):
        inner_cache[("aU", r)] = _draw_inner_b_at(XaU[r], YaU[r])

    clusters = defaultdict(list)

    def halves_from_quad(quad):
        quad = quad.upper()
        a_upper = quad in ("TL", "TR")
        b_upper = quad in ("BR", "TR")
        return a_upper, b_upper

    for rec in records:
        if rec["n"] not in neuron_main:
            continue
        n = int(rec["n"])
        a_mod = int(rec["a_mod"]) % g
        b_mod = int(rec["b_mod"]) % g
        max_act = float(rec.get("max_act", 0.0))
        a_upper, b_upper = halves_from_quad(rec["quad"])
        a_key = ("aU", a_mod) if a_upper else ("aL", a_mod)
        XbL, YbL, XbU, YbU = inner_cache[a_key]

        if b_upper:
            cx, cy = XbU[b_mod], YbU[b_mod]
        else:
            cx, cy = XbL[b_mod], YbL[b_mod]
        clusters[(a_key, ("bU" if b_upper else "bL"), b_mod)].append({"n": n, "max": max_act, "center": (cx, cy)})

    Xs, Ys, texts = [], [], []
    for items in clusters.values():
        items_sorted = sorted(items, key=lambda d: (-d["max"], d["n"]))
        for i, d in enumerate(items_sorted):
            ring, pos = divmod(i, slots_per_ring)
            ang = 2 * pi * pos / slots_per_ring
            rad = ring_step * (ring + 1)
            cx, cy = d["center"]
            Xs.append(cx + rad * cos(ang))
            Ys.append(cy + rad * sin(ang))
            texts.append(str(d["n"]))

    if Xs:
        fig.add_trace(go.Scatter(x=Xs, y=Ys, mode="text", text=texts, textposition="top center", textfont=dict(size=text_size, color="#222"), hoverinfo="skip", showlegend=False))

    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False, scaleanchor="x", scaleratio=1)
    fig.update_layout(title=title or f"Coset n-gon (g={g}) - p={p}, f={f}", width=900, height=900, margin=dict(l=40, r=40, t=60, b=40), showlegend=False)
    return fig


def plot_coset_ngon_c_single(
    p: int,
    f: int,
    c_records: list,
    neuron_main: set,
    title: str | None = None,
    R_outer: float = 1.0,
    ring_cap: int = 6,
    ring_step: float = 0.05,
    text_size: int = 12,
    show_labels: bool = True,
    rotate_random: bool = False,
    seed: int | None = None,
) -> go.Figure:
    """Draw neuron ids on a pair of g-gons indexed by ``c mod g``.

    ``g = p // gcd(p, f)``. The red polygon holds records with a falsy
    ``c_upper``, the green one (rotated by ``pi / g``) those with a truthy
    ``c_upper``. Only records whose ``n`` is in ``neuron_main`` are drawn.
    Records sharing a vertex are sorted by descending ``max_act`` (then ``n``)
    and spread on concentric rings of ``ring_cap`` slots spaced ``ring_step``.

    Args:
        ring_cap: slots per ring; values below 1 are treated as 1.
        rotate_random: draw the base rotation from ``default_rng(seed)``
            instead of the fixed ``pi / 2``.

    Raises:
        ValueError: if ``g < 3``.
    """
    g = p // gcd(p, f)
    if g < 3:
        raise ValueError(f"g must be >=3; got g={g} (p={p}, f={f})")

    # Clamp once so the ring index, the slot index and the slot angle all use
    # the same divisor (ring_cap=0 used to raise ZeroDivisionError).
    slots_per_ring = max(1, ring_cap)

    rng = np.random.default_rng(seed) if rotate_random else None
    phi0_low = float(rng.uniform(0, 2 * np.pi)) if rotate_random else np.pi / 2
    phi0_high = phi0_low + np.pi / g

    Xlow, Ylow = _ngon_vertices(g, R_outer, 0.0, 0.0, phi0_low)
    Xhigh, Yhigh = _ngon_vertices(g, R_outer, 0.0, 0.0, phi0_high)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=np.r_[Xlow, Xlow[:1]], y=np.r_[Ylow, Ylow[:1]], mode="lines", line=dict(width=2, color="red"), hoverinfo="skip", showlegend=False))
    fig.add_trace(go.Scatter(x=np.r_[Xhigh, Xhigh[:1]], y=np.r_[Yhigh, Yhigh[:1]], mode="lines", line=dict(width=2, color="green"), hoverinfo="skip", showlegend=False))

    if show_labels:
        for r in range(g):
            fig.add_trace(go.Scatter(x=[Xlow[r]], y=[Ylow[r]], mode="text", text=[f"c mod {g} = {r}"], textfont=dict(size=12, color="red"), hoverinfo="skip", showlegend=False))
            fig.add_trace(go.Scatter(x=[Xhigh[r]], y=[Yhigh[r]], mode="text", text=[f"c in upper half, mod {g} = {r}"], textfont=dict(size=12, color="green"), hoverinfo="skip", showlegend=False))

    clusters = defaultdict(list)
    for rec in c_records:
        if rec["n"] not in neuron_main:
            continue
        n = int(rec["n"])
        r = int(rec["c_mod"]) % g
        center = (Xhigh[r], Yhigh[r]) if rec["c_upper"] else (Xlow[r], Ylow[r])
        clusters[(r, rec["c_upper"])].append({"n": n, "max": float(rec["max_act"]), "center": center})

    Xs, Ys, texts = [], [], []
    for items in clusters.values():
        items_sorted = sorted(items, key=lambda d: (-d["max"], d["n"]))
        for i, d in enumerate(items_sorted):
            ring, pos = divmod(i, slots_per_ring)
            ang = 2 * np.pi * pos / slots_per_ring
            rad = ring_step * (ring + 1)
            cx, cy = d["center"]
            Xs.append(cx + rad * np.cos(ang))
            Ys.append(cy + rad * np.sin(ang))
            texts.append(str(d["n"]))

    if Xs:
        fig.add_trace(go.Scatter(x=Xs, y=Ys, mode="text", text=texts, textposition="top center", textfont=dict(size=text_size, color="#222"), hoverinfo="skip", showlegend=False))

    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False, scaleanchor="x", scaleratio=1)
    fig.update_layout(title=title or f"c n-gon (single layer) - p={p}, f={f}, g={g}", width=900, height=900, margin=dict(l=40, r=40, t=60, b=40), showlegend=False)
    return fig


def gft_energy_matrix_single_channel(X2: np.ndarray, *, dft_fn, irreps):
    """Return the channel-0 coefficient-norm matrix of a single 2-D map.

    ``X2`` gets a trailing singleton axis, is cast to float32 and handed to
    ``dft_fn``. For every returned ``(r, s, idx)`` key with ``int(idx) == 0``
    the norm of the coefficient is stored at ``[names.index(r), names.index(s)]``.

    Args:
        irreps: iterable of 4-tuples whose first item is the irrep name.

    Returns:
        ``(P, names)``.
    """
    names = [label for (label, _dim, _rep, _freq) in irreps]
    size = len(names)
    spectrum = dft_fn(X2[:, :, None].astype(np.float32))

    energy = np.zeros((size, size), dtype=float)
    for (row_label, col_label, channel), coeff in spectrum.items():
        if int(channel) != 0:
            continue
        energy[names.index(row_label), names.index(col_label)] = float(np.linalg.norm(np.asarray(coeff)))
    return energy, names


def single_head_attn_report_page_all_qk(
    *,
    h,
    attn_full,
    group_size,
    dft_fn,
    irreps,
    qk_list=None,
    zmin=None,
    zmax=None,
    title="",
    demean_for_gft=True,
    row_scale_mode="auto",
    q_lo=0.02,
    q_hi=0.98,
):
    """Build a two-column report for one attention head: map and GFT energy.

    Each ``(q, k)`` pair gets one row. The left heatmap is
    ``attn_full[:, h, q, k]`` reshaped to ``group_size x group_size``; the
    right heatmap is the matrix from ``gft_energy_matrix_single_channel`` for
    that map, mean-subtracted first when ``demean_for_gft`` is true.

    Args:
        attn_full: 4-D array; its third dimension supplies the default
            ``qk_list`` of all ``(q, k)`` pairs.
        zmin, zmax: explicit colour limits for the attention heatmaps; they
            override the quantile limits.
        row_scale_mode: ``"auto"`` or ``"quantile"`` (per-row limits from the
            ``q_lo`` / ``q_hi`` quantiles, applied to both columns).

    Raises:
        ValueError: if ``row_scale_mode`` is neither accepted value (raised
            while processing the first row).
    """
    _, _, T, _ = attn_full.shape
    if qk_list is None:
        qk_list = [(q, k) for q in range(T) for k in range(T)]

    nrows = len(qk_list)
    names = [label for (label, _dim, _rep, _freq) in irreps]

    titles = []
    for (q, k) in qk_list:
        titles.append(f"attn[q={q}->k={k}] (head {h})")
        titles.append("GFT energy")

    fig = make_subplots(
        rows=nrows,
        cols=2,
        subplot_titles=titles,
        specs=[[{"type": "heatmap"}, {"type": "heatmap"}] for _ in range(nrows)],
        horizontal_spacing=0.08,
        vertical_spacing=max(0.05, 0.18 / max(1, nrows)),
    )

    bar_len = min(0.9 / nrows, 0.18)

    for row, (q, k) in enumerate(qk_list, start=1):
        attn_map = attn_full[:, h, q, k].reshape(group_size, group_size).astype(np.float32)

        if row_scale_mode not in ("quantile", "auto"):
            raise ValueError("row_scale_mode must be 'auto' or 'quantile'")
        use_quantiles = row_scale_mode == "quantile"

        # Vertical centre of this row in paper coordinates, for the colorbars.
        bar_y = 1.0 - (row - 0.5) / nrows

        attn_opts = dict(z=attn_map, colorscale="Blues", showscale=True)
        if use_quantiles:
            low = float(np.quantile(attn_map, q_lo))
            high = float(np.quantile(attn_map, q_hi))
            if high <= low:
                high = low + 1e-6
            attn_opts["zmin"] = low
            attn_opts["zmax"] = high

        # Explicit limits win over the quantile-derived ones.
        if zmin is not None:
            attn_opts["zmin"] = float(zmin)
        if zmax is not None:
            attn_opts["zmax"] = float(zmax)

        attn_opts["colorbar"] = dict(title="attn", x=0.46, y=bar_y, len=bar_len, yanchor="middle")
        fig.add_trace(go.Heatmap(**attn_opts), row=row, col=1)
        fig.update_xaxes(title_text="b", row=row, col=1)
        fig.update_yaxes(title_text="a", row=row, col=1)

        gft_input = attn_map - float(attn_map.mean()) if demean_for_gft else attn_map
        energy, _ = gft_energy_matrix_single_channel(gft_input, dft_fn=dft_fn, irreps=irreps)

        gft_opts = dict(z=energy, x=names, y=names, colorscale="Viridis", showscale=True)
        if use_quantiles:
            low = float(np.quantile(energy, q_lo))
            high = float(np.quantile(energy, q_hi))
            if high <= low:
                high = low + 1e-12
            gft_opts["zmin"] = low
            gft_opts["zmax"] = high

        gft_opts["colorbar"] = dict(title="||F||", x=1.02, y=bar_y, len=bar_len, yanchor="middle")
        fig.add_trace(go.Heatmap(**gft_opts), row=row, col=2)

    fig.update_layout(
        title=title or f"Attention head {h} | all (q,k) | group_size={group_size}",
        width=1200,
        height=260 * nrows + 140,
        margin=dict(t=70, b=40, l=60, r=120),
    )
    return fig


def single_neuron_figure(
    n,
    pre_grid,
    left_vec,
    right_vec,
    F_full,
    F_L,
    F_R,
    names,
    subgroup_info,
    freq_map=None,
    freqs=None,
    strict=True,
    artifacts=None,
):
    """Build the 4x4 diagnostic figure for neuron ``n``.

    Layout:
        * row 1: raw heatmap ``pre_grid[:, :, n]``, line plots of
          ``left_vec[:, n]`` and ``right_vec[:, n]``, and the quadrant-remapped
          heatmap;
        * row 2: coefficient-norm heatmaps for ``F_full``, ``F_L``, ``F_R``
          (entries whose third key item equals ``n``);
        * row 3: one ``(phi_b, phi_a)`` point per quadrant from
          ``_quadrant_ab_phases``;
        * row 4: a table of the ``Cbar`` scores in ``subgroup_info``.

    Args:
        subgroup_info: mapping with ``"L"`` and ``"R"`` entries, each holding
            ``"Lcoset"`` / ``"Rcoset"`` score tables and an optional
            ``"mix3"`` list.
        freq_map, freqs: when ``freq_map`` is not ``None`` the remap uses
            ``freqs[0]`` for both axes; otherwise the frequencies are detected
            per quadrant.
        strict, artifacts: accepted but not used by this function.

    Returns:
        The assembled ``plotly`` figure.
    """
    G, D = pre_grid.shape[0], len(names)

    specs = [
        [{"type": "heatmap"}, {"type": "xy"}, {"type": "xy"}, {"type": "xy"}],
        [{"type": "heatmap"}, {"type": "heatmap"}, {"type": "heatmap"}, None],
        [{"type": "xy"}, {"type": "xy"}, {"type": "xy"}, {"type": "xy"}],
        [{"type": "table", "colspan": 4}, None, None, None],
    ]

    # make_subplots numbers cartesian axes in row-major order over the cells
    # that actually own axes: None cells and the table are skipped. Record the
    # real axis number per cell so that scale anchors below refer to each
    # subplot's own x-axis (the None at (2, 4) shifts all of row 3 by one).
    axis_number = {}
    next_axis = 0
    for r, spec_row in enumerate(specs, start=1):
        for c, spec in enumerate(spec_row, start=1):
            if spec is None or spec.get("type", "xy") == "table":
                continue
            next_axis += 1
            axis_number[(r, c)] = next_axis

    fig = make_subplots(
        rows=4,
        cols=4,
        row_heights=[0.25, 0.25, 0.25, 0.25],
        specs=specs,
        subplot_titles=[
            f"Whole-RAW (n{n})",
            f"a-RAW-n{n}, max={jnp.max(left_vec[:, n]):.2f}",
            f"b-RAW-n{n}, max={jnp.max(right_vec[:, n]):.2f}",
            f"Quad-REMAP (n{n})",
            f"Whole-DFT (n{n})",
            f"a-DFT (n{n})",
            f"b-DFT (n{n})",
            "Quad-BL",
            "Quad-BR",
            "Quad-TL",
            "Quad-TR",
        ],
        horizontal_spacing=0.02,
        vertical_spacing=0.08,
    )

    # The extrema header is only shown when every quadrant passed its checks.
    header, ok = _quad_set_extrema_header(pre_grid[:, :, n])
    if ok:
        fig.add_annotation(text=header, xref="paper", yref="paper", x=0.5, y=1.15, showarrow=False, font=dict(size=12), align="center")

    fig.add_trace(go.Heatmap(z=pre_grid[:, :, n], coloraxis="coloraxis1"), row=1, col=1)
    fig.update_xaxes(title_text="Right input (b)", row=1, col=1)
    fig.update_yaxes(title_text="Left input (a)", row=1, col=1)

    x_ticks = list(range(G))
    fig.add_trace(go.Scatter(x=x_ticks, y=left_vec[:, n], mode="lines+markers", line=dict(width=1), marker=dict(size=4), showlegend=False), row=1, col=2)
    fig.add_trace(go.Scatter(x=x_ticks, y=right_vec[:, n], mode="lines+markers", line=dict(width=1), marker=dict(size=4), showlegend=False), row=1, col=3)

    for i in [2, 3]:
        fig.update_xaxes(showgrid=True, row=1, col=i)
        fig.update_yaxes(showgrid=True, row=1, col=i)

    # NOTE(review): the branch is chosen by `freq_map` but reads `freqs`, and
    # uses freqs[0] for both axes - confirm callers always pass them together.
    if freq_map is not None:
        remap_img, fb, fa = _remapped_quadrants_by_freq(pre_grid[:, :, n], fa=freqs[0], fb=freqs[0])
    else:
        remap_img, fb, fa = _remapped_quadrants(pre_grid[:, :, n])

    fig.add_trace(go.Heatmap(z=remap_img, colorscale="Viridis", showscale=False), row=1, col=4)
    fig.update_xaxes(title_text=f"b (remapped) by {fb}", row=1, col=4)
    fig.update_yaxes(title_text=f"a (remapped) by {fa}", row=1, col=4)

    for col, Fhat in enumerate([F_full, F_L, F_R], start=1):
        P = np.zeros((D, D))
        for (r, s, idx), M in Fhat.items():
            if idx == n:
                P[names.index(r), names.index(s)] = np.linalg.norm(np.array(M))
        fig.add_trace(go.Heatmap(z=P, x=names, y=names, coloraxis="coloraxis2"), row=2, col=col)

    quad_phase = _quadrant_ab_phases(pre_grid[:, :, n])
    axis_range = [0, 2 * np.pi]
    tick_vals = [0, np.pi, 2 * np.pi]
    tick_text = ["0", "pi", "2pi"]
    for idx, (phib, phia) in enumerate(quad_phase):
        fig.add_trace(
            go.Scatter(
                x=[phib],
                y=[phia],
                mode="markers+text",
                marker=dict(size=8, color="red"),
                text=[f"{phib:.2f},{phia:.2f}"],
                textposition="top center",
                textfont=dict(size=10),
                cliponaxis=False,
                showlegend=False,
            ),
            row=3,
            col=idx + 1,
        )
        fig.update_xaxes(range=axis_range, tickvals=tick_vals, ticktext=tick_text, title_text="phi_b (rad)", row=3, col=idx + 1)
        fig.update_yaxes(range=axis_range, tickvals=tick_vals, ticktext=tick_text, title_text="phi_a (rad)", row=3, col=idx + 1)

    lineL_LC = ", ".join(f"{h}:{s['Cbar']:.2e}" for h, s in subgroup_info["L"]["Lcoset"].items())
    lineL_RC = ", ".join(f"{h}:{s['Cbar']:.2e}" for h, s in subgroup_info["L"]["Rcoset"].items())
    lineL_mix = ", ".join(f"{d['origin']}*{d['H']}:{d['Cbar']:.2e}" for d in subgroup_info["L"].get("mix3", []))

    lineR_LC = ", ".join(f"{h}:{s['Cbar']:.2e}" for h, s in subgroup_info["R"]["Lcoset"].items())
    lineR_RC = ", ".join(f"{h}:{s['Cbar']:.2e}" for h, s in subgroup_info["R"]["Rcoset"].items())
    lineR_mix = ", ".join(f"{d['origin']}*{d['H']}:{d['Cbar']:.2e}" for d in subgroup_info["R"].get("mix3", []))

    headers = ["L vec - L-coset", "L vec - R-coset", "L vec - mix-Top3", "R vec - L-coset", "R vec - R-coset", "R vec - mix-Top3"]
    cells = [[lineL_LC], [lineL_RC], [lineL_mix], [lineR_LC], [lineR_RC], [lineR_mix]]

    fig.add_trace(go.Table(header=dict(values=headers, align="center"), cells=dict(values=cells, align="left"), columnwidth=[0.16, 0.16, 0.18, 0.16, 0.16, 0.18]), row=4, col=1)

    # This runs after the showgrid=True calls above, so it also switches the
    # grid off again for the two line plots in row 1.
    for r in range(1, 3):
        for c in range(1, 5):
            fig.update_xaxes(showgrid=False, zeroline=False, row=r, col=c)
            fig.update_yaxes(showgrid=False, zeroline=False, row=r, col=c)

    # Force a 1:1 aspect by anchoring each y-axis to its own subplot's x-axis.
    for (r, c) in [(1, 1), (1, 4), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3), (3, 4)]:
        axis_name = f"x{axis_number[(r, c)]}"
        fig.update_yaxes(scaleanchor=axis_name, scaleratio=1, row=r, col=c)

    fig.update_layout(
        margin=dict(t=30, b=30, l=80, r=50),
        width=900,
        height=900,
        autosize=False,
        coloraxis1=dict(colorscale="Viridis", colorbar=dict(title="RAW", len=0.25, y=0.88, yanchor="middle", x=1.02, xanchor="left")),
        coloraxis2=dict(colorscale="Viridis", colorbar=dict(title="GFT", len=0.25, y=0.53, yanchor="middle", x=1.02, xanchor="left")),
    )
    fig.update_xaxes(automargin=True)
    fig.update_yaxes(automargin=True)

    return fig


def _two_stage_kmeans_prune(log_mp: np.ndarray, thresh1: float = 2.0, thresh2: float = 2.0, seed: int = 0) -> tuple[np.ndarray, np.ndarray]:
    assert log_mp.ndim == 1
    x = log_mp.reshape(-1, 1)
    n = x.shape[0]

    if n == 0:
        return np.array([], dtype=int), np.array([], dtype=int)
    if n == 1:
        return np.array([0], dtype=int), np.array([], dtype=int)

    n_unique = np.unique(x, axis=0).shape[0]
    if n_unique < 2:
        keep = np.arange(n, dtype=int)
        drop = np.array([], dtype=int)
        return keep, drop

    km1 = KMeans(n_clusters=2, n_init="auto", random_state=seed)
    lab1 = km1.fit_predict(x)
    centers1 = km1.cluster_centers_.ravel()
    hi1 = int(np.argmax(centers1))
    gap1 = float(abs(centers1[0] - centers1[1]))

    if gap1 >= float(thresh1):
        keep1 = np.flatnonzero(lab1 == hi1)
    else:
        keep1 = np.arange(x.size, dtype=int)

    keep2 = keep1
    if keep1.size >= 2:
        x2 = x[keep1]
        if np.unique(x2, axis=0).shape[0] >= 2:
            km2 = KMeans(n_clusters=2, n_init="auto", random_state=seed + 1)
            lab2 = km2.fit_predict(x2)
            centers2 = km2.cluster_centers_.ravel()
            hi2 = int(np.argmax(centers2))
            gap2 = float(abs(centers2[0] - centers2[1]))
            if gap2 >= float(thresh2):
                rel_keep2 = np.flatnonzero(lab2 == hi2)
                keep2 = keep1[rel_keep2]

    if keep2.size == 0:
        keep2 = np.array([int(np.argmax(x))], dtype=int)

    all_idx = np.arange(x.size, dtype=int)
    drop = np.setdiff1d(all_idx, keep2, assume_unique=False)
    return keep2, drop


def _coerce_freq(v):
    if isinstance(v, (int, np.integer)):
        return int(v)
    if isinstance(v, str):
        m = re.search(r"(\d+)$", v)
        if m:
            return int(m.group(1))
    raise TypeError(f"freq value {v!r} is not an int nor a string ending with digits.")


def _validate_freq_map(names, freq_map, p):
    missing = []
    bad = []
    for lab in names:
        if lab not in freq_map:
            continue
        try:
            f = _coerce_freq(freq_map[lab])
            assert 0 < f < p, f"f={f} out of range (1..{p-1}) for label {lab!r}"
        except Exception as e:
            bad.append((lab, freq_map[lab], str(e)))
    if missing:
        raise AssertionError(f"freq_map missing labels: {missing[:8]}{'...' if len(missing) > 8 else ''}")
    if bad:
        lines = ", ".join([f"{lab}->{val} ({err})" for lab, val, err in bad[:8]])
        raise AssertionError(f"freq_map invalid entries: {lines}{' ...' if len(bad) > 8 else ''}")


def remap_with_special_sign_rule(tile: np.ndarray, dom: dict, sec_info: dict, freq_map: dict[str, int]) -> tuple[np.ndarray, int, int]:
    """Remap a tile, with a dedicated rule for the ("sign", "sign") dominant pair.

    * Dominant pair is ("sign", "sign"): use the secondary label from
      ``sec_info`` and remap both axes by its frequency from ``freq_map``.
      Without a secondary label the tile is returned as a copy with
      frequencies ``(0, 0)``.
    * Otherwise: remap by ``dom["fb"]`` / ``dom["fa"]`` when both are present
      and ``dom["kind"]`` is "diag" or "axis"; else defer to
      ``_remapped_quadrants``.

    Raises:
        KeyError: if the secondary label has no entry in ``freq_map``.
    """
    sign_pair = dom["b_star"] == "sign" and dom["a_star"] == "sign"

    if not sign_pair:
        fb = dom.get("fb", None)
        fa = dom.get("fa", None)
        have_freqs = fb is not None and fa is not None
        if have_freqs and dom["kind"] in ("diag", "axis"):
            return _remapped_quadrants_by_freq(tile, int(fb), int(fa))
        return _remapped_quadrants(tile)

    secondary = (sec_info or {}).get("secondary", None)
    if not secondary:
        return tile.copy(), 0, 0
    if secondary not in freq_map:
        raise KeyError(f"freq_map has no frequency for '{secondary}'")
    shared_freq = int(freq_map[secondary])
    return _remapped_quadrants_by_freq(tile, shared_freq, shared_freq)


def _find_secondary_strong(Fhat_n, names: list[str], dom: dict, sec_ratio: float = 1 / 3) -> dict:
    """Look for a secondary label that is strong relative to the primary one.

    The power matrix comes from ``_power_matrix_from_Fhat(Fhat_n, names)``.
    Two candidate sources are considered:

    * "axis": an index ``k`` whose power in column 0 and in row 0 are both at
      least ``sec_ratio`` of the respective maxima; the best is chosen by
      ``(min(ra, rb), ra + rb)``.
    * "diag" (only for a diagonal primary): the largest diagonal entry other
      than the primary, if it reaches ``sec_ratio`` of the primary power.

    Args:
        Fhat_n: Coefficients for a single neuron (passed through to
            ``_power_matrix_from_Fhat``).
        names: Labels; indices into the power matrix are mapped through it.
        dom: Dominant-classification dict; reads ``kind``, ``b_star``, ``a_star``.
        sec_ratio: Minimum secondary/primary power ratio.

    Returns:
        Dict with keys ``mode``, ``primary_pair``, ``secondary``,
        ``secondary_source`` and ``metrics``. For kinds other than a
        same-label "diag" or "axis", ``secondary`` stays ``None`` and
        ``metrics`` stays empty.
    """
    Pn = _power_matrix_from_Fhat(Fhat_n, names)
    b_star, a_star = dom["b_star"], dom["a_star"]
    kind = dom["kind"]
    result = {"mode": kind, "primary_pair": (b_star, a_star), "secondary": None, "secondary_source": None, "metrics": {}}

    def _scan_axis(skip):
        """Find the best index whose column-0 and row-0 ratios both pass ``sec_ratio``."""
        col = Pn[:, 0].astype(float)
        row = Pn[0, :].astype(float)
        peak_a = float(col[int(np.argmax(col))])
        peak_b = float(row[int(np.argmax(row))])

        winner, winner_score = None, (-np.inf, -np.inf)
        for k in range(len(names)):
            if k == skip:
                continue
            ra = (col[k] / peak_a) if peak_a > 0 else 0.0
            rb = (row[k] / peak_b) if peak_b > 0 else 0.0
            if ra >= sec_ratio and rb >= sec_ratio and (col[k] > 0 or row[k] > 0):
                candidate = (min(ra, rb), ra + rb)
                if candidate > winner_score:
                    winner, winner_score = k, candidate
        return winner, winner_score, col, row, peak_a, peak_b

    if kind == "diag" and b_star == a_star:
        # Diagonal candidate: strongest diagonal entry apart from the primary.
        diag_power = np.diagonal(Pn).astype(float)
        prim_idx = names.index(b_star)
        prim_pow = float(diag_power[prim_idx])
        masked = diag_power.copy()
        masked[prim_idx] = -np.inf
        runner_up = int(np.argmax(masked))
        runner_pow = float(masked[runner_up])
        ratio_d = (runner_pow / prim_pow) if prim_pow > 0 else 0.0
        cand_diag = names[runner_up] if (np.isfinite(runner_pow) and ratio_d >= sec_ratio) else None

        # Axis candidate, excluding the primary index.
        best_k, best_score, _, _, _, _ = _scan_axis(prim_idx)
        cand_axis = names[best_k] if best_k is not None else None
        score_axis = best_score[0] if best_k is not None else 0.0
        score_diag = ratio_d if cand_diag is not None else 0.0

        # The axis candidate wins only when strictly stronger than the diagonal one.
        if cand_axis is not None and score_axis > score_diag:
            result["secondary"] = cand_axis
            result["secondary_source"] = "axis"
        elif cand_diag is not None:
            result["secondary"] = cand_diag
            result["secondary_source"] = "diag"
        result["metrics"].update(ratio_axis_min=score_axis, ratio_diag=score_diag)
        return result

    if kind == "axis":
        best_k, _, a_vals, b_vals, prim_a, prim_b = _scan_axis(None)
        found = best_k is not None
        if found:
            result["secondary"] = names[best_k]

        result["metrics"].update(
            primary_power_a=prim_a,
            primary_power_b=prim_b,
            secondary_power_a=(a_vals[best_k] if found else 0.0),
            secondary_power_b=(b_vals[best_k] if found else 0.0),
            ratio_a=((a_vals[best_k] / prim_a) if (found and prim_a > 0) else 0.0),
            ratio_b=((b_vals[best_k] / prim_b) if (found and prim_b > 0) else 0.0),
        )

    return result


def prepare_layer_artifacts(
    pre_grid,
    left,
    right,
    dft_2d,
    irreps,
    freq_map,
    strict=True,
    prune_cfg: dict | None = None,
    store_full_neuron_grids: bool = False,
    sec_ratio: float = 1 / 3,
):
    G, N = pre_grid.shape[0], pre_grid.shape[-1]

    flat_all = pre_grid.reshape(G * G, N)
    F_full = dft_2d(flat_all)
    F_L = dft_2d(left)
    F_R = dft_2d(right)

    names = [lab for lab, _, _, _ in irreps]
    irrep2neurons = defaultdict(list)
    freq_cluster = defaultdict(list)

    neuron_data = {}
    secondary_per_neuron = {}
    secondary_by_cluster = defaultdict(list)
    for n in range(N):
        Fhat_n = {k: v for k, v in F_full.items() if k[2] == n}
        dom = _classify_by_gft(Fhat_n, names, freq_map, strict=strict)
        b_star, a_star = dom["b_star"], dom["a_star"]
        irrep2neurons[(b_star, a_star)].append(n)

        if b_star == a_star:
            sec_info = _find_secondary_strong(Fhat_n, names, dom, sec_ratio=sec_ratio)
            secondary_per_neuron[int(n)] = sec_info
            secondary_by_cluster[(b_star, a_star)].append((int(n), sec_info["secondary"]))
        else:
            sec_info = {"mode": dom["kind"], "primary_pair": (b_star, a_star), "secondary": None, "metrics": {}}
            secondary_per_neuron[int(n)] = sec_info
            secondary_by_cluster[(b_star, a_star)].append((int(n), None))

        entry = {"a_values": np.arange(G, dtype=int), "b_values": np.arange(G, dtype=int), "dominant": dom}
        if store_full_neuron_grids:
            grid = pre_grid[:, :, n]
            post = np.maximum(grid, 0.0)
            entry["real_preactivations"] = grid
            entry["postactivations"] = post
        neuron_data[int(n)] = entry

    for k, v in irrep2neurons.items():
        irrep2neurons[k] = list(dict.fromkeys(v))

    diag_labels = set()
    cluster_prune = {}

    if prune_cfg is not None:
        rel_tau = float(prune_cfg.get("rel_tau", 0.05))
        rel_tau = max(0.0, min(1.0, rel_tau))

        for (r, s), neuron_list in irrep2neurons.items():
            if r != s or len(neuron_list) == 0:
                continue

            cluster_grid = pre_grid[:, :, neuron_list]
            max_preacts = np.max(np.abs(cluster_grid), axis=(0, 1))
            cluster_max = float(max_preacts.max()) if max_preacts.size else 0.0
            thr = rel_tau * cluster_max

            keep_rel = np.where(max_preacts > thr)[0]
            if keep_rel.size == 0 and max_preacts.size:
                keep_rel = np.array([int(np.argmax(max_preacts))], dtype=int)
            drop_rel = np.setdiff1d(np.arange(max_preacts.size), keep_rel)

            main = [int(neuron_list[i]) for i in keep_rel]
            drop = [int(neuron_list[i]) for i in drop_rel]
            per_log = {int(neuron_list[i]): float(np.log10(max_preacts[i] + 1e-20)) for i in range(len(neuron_list))}
            cluster_prune[(r, s)] = {"main": main, "drop": drop, "per_neuron_log10": per_log}

            if r not in freq_map:
                if strict:
                    raise KeyError(f"freq_map has no mapping for irrep '{r}'")
            else:
                f = int(freq_map[r])
                freq_cluster[f].extend(main)
                for n in main:
                    kind_n = neuron_data[n]["dominant"]["kind"]
                    diag_labels.add((r, kind_n))
    else:
        for (r, s), neuron_list in irrep2neurons.items():
            if r != s or not neuron_list:
                continue
            if r not in freq_map:
                if strict:
                    raise KeyError(f"freq_map has no mapping for irrep '{r}'")
                continue
            f = int(freq_map[r])
            freq_cluster[f].extend(neuron_list)
            for n in neuron_list:
                kind_n = neuron_data[int(n)]["dominant"]["kind"]
                diag_labels.add((r, kind_n))

    freq_cluster = {f: list(dict.fromkeys(ids)) for f, ids in freq_cluster.items() if ids}

    artifacts = {
        "F_full": F_full,
        "F_L": F_L,
        "F_R": F_R,
        "names": names,
        "irrep2neurons": irrep2neurons,
        "freq_cluster": freq_cluster,
        "diag_labels": diag_labels,
        "neuron_data": neuron_data,
        "secondary_per_neuron": secondary_per_neuron,
        "secondary_by_cluster": dict(secondary_by_cluster),
    }
    if prune_cfg is not None:
        artifacts["cluster_prune"] = cluster_prune
    return artifacts


def build_f_pool_from_layer1_artifacts(layer1_artifacts, freq_map, p):
    """Collect the sorted set of ``|frequency|`` values of non-empty diagonal clusters.

    A cluster ``(r, s)`` counts when ``r == s``, it has neurons (the pruned
    "main" list if present and non-empty, otherwise the full list) and ``r``
    has an entry in ``freq_map``. ``p`` is accepted but not used.
    """
    clusters = layer1_artifacts["irrep2neurons"]
    prune_info = layer1_artifacts.get("cluster_prune", {}) or {}

    pool = set()
    for (r, s), members in clusters.items():
        if r != s:
            continue
        # Prefer the pruned main list; fall back to the full membership.
        active = prune_info.get((r, s), {}).get("main", None) or members
        if active and r in freq_map:
            pool.add(abs(int(freq_map[r])))
    return sorted(pool)


def summarize_diag_labels(diag_labels: Iterable[Tuple[str, str]], p: int, names: List[str]) -> Dict[str, Any]:
    """Group diagonal cluster labels into categories and count them.

    Each ``(label, kind)`` pair is stringified and stripped; kinds are then
    collected per label. Labels are categorised as:

    * ``approx_coset``: label matches ``2D<sep?><f>`` (case-insensitive) and
      ``gcd(p, f) == 1``.
    * ``coset_2d``: same pattern with ``gcd(p, f) != 1``.
    * ``coset_1d``: label (lower-cased) is one of "sign", "rp", "srp".
    * ``others``: everything else.

    Labels are processed in sorted order so the item lists are deterministic
    across runs (iterating the intermediate set directly would depend on
    string hash randomization).

    Args:
        diag_labels: Iterable of ``(label, kind)`` pairs.
        p: Modulus used for the gcd test.
        names: Stored verbatim in the summary.

    Returns:
        Dict with ``p``, ``names``, ``items``, ``counts``, ``consistency`` and
        ``approx_ratio_in_diag`` (share of ``approx_coset`` among all labels,
        0.0 when there are none).
    """
    pairs = {(str(label).strip(), str(kind).strip()) for label, kind in diag_labels}
    label_kind_map = defaultdict(set)
    for label, kind in pairs:
        label_kind_map[label].add(kind)

    names_1d = {"sign", "rp", "srp"}
    pat_2d = re.compile(r"^2D[_\-]?(\d+)$", flags=re.IGNORECASE)

    approx_coset = []
    coset_2d = []
    coset_1d = []
    others = []

    for label in sorted(label_kind_map):
        kinds_list = sorted(label_kind_map[label])
        m = pat_2d.match(label)

        if m:
            f = int(m.group(1))
            gcd_pf = math.gcd(p, f)
            item = {"label": label, "kinds": kinds_list, "f": f, "gcd_pf": gcd_pf}
            (approx_coset if gcd_pf == 1 else coset_2d).append(item)
        elif label.lower() in names_1d:
            coset_1d.append({"label": label, "kinds": kinds_list})
        else:
            others.append({"label": label, "kinds": kinds_list})

    groups = {"approx_coset": approx_coset, "coset_2d": coset_2d, "coset_1d": coset_1d, "others": others}

    total_diag = len(label_kind_map)
    counts = {group_name: len(group) for group_name, group in groups.items()}
    counts["total_diag"] = total_diag

    # Sanity check: every label must have landed in exactly one group.
    sum_all = counts["approx_coset"] + counts["coset_2d"] + counts["coset_1d"] + counts["others"]
    consistency_ok = sum_all == total_diag

    missing_labels = []
    if not consistency_ok:
        processed_labels = {item["label"] for group in groups.values() for item in group}
        missing_labels = sorted(set(label_kind_map) - processed_labels)

    approx_ratio = (counts["approx_coset"] / total_diag) if total_diag > 0 else 0.0

    return {
        "p": p,
        "names": names,
        "items": groups,
        "counts": counts,
        "consistency": {"ok": consistency_ok, "sum_all": sum_all, "expected_total": total_diag, "missing_when_not_ok": missing_labels},
        "approx_ratio_in_diag": approx_ratio,
    }


def make_attn_report_pdf(
    *,
    save_dir: str,
    attn_full: np.ndarray,
    group_size: int,
    dft_fn,
    irreps,
    tag: str,
    k_a: int = 0,
    k_b: int = 1,
):
    """Write a PDF with one attention-report page per head.

    ``attn_full`` must have four axes; the first must equal
    ``group_size ** 2`` (asserted) and the second is iterated as the head
    index. Each page is built by ``single_head_attn_report_page_all_qk`` and
    rendered through kaleido. The result is saved as ``<save_dir>/<tag>.pdf``
    (the directory is created if needed). ``k_a`` and ``k_b`` are accepted
    but not used here.
    """
    os.makedirs(save_dir, exist_ok=True)
    n_batch, n_heads, _, _ = attn_full.shape
    assert n_batch == group_size * group_size

    pdf = PdfWriter()
    for head in range(n_heads):
        page_fig = single_head_attn_report_page_all_qk(
            h=head,
            attn_full=attn_full,
            group_size=group_size,
            dft_fn=dft_fn,
            irreps=irreps,
            row_scale_mode="auto",
            demean_for_gft=True,
        )

        page_fig._uuid = uuid.uuid4().hex
        rendered = page_fig.to_image(format="pdf", engine="kaleido")
        first_page = PdfReader(io.BytesIO(rendered)).pages[0]
        pdf.add_page(first_page)

    out_path = os.path.join(save_dir, f"{tag}.pdf")
    with open(out_path, "wb") as handle:
        pdf.write(handle)
    print(f"[attn-report] saved {out_path}")


def make_layer_report(
    pre_grid,
    left,
    right,
    p,
    dft_2d,
    irreps,
    coset_masks_left,
    coset_masks_right,
    save_dir: str,
    cluster_tau: float = 1e-3,
    color_rule=None,
    artifacts=None,
    layer_idx: int = 0,
    freq_map=None,
):
    """Render one PDF report per diagonal cluster ``(r, r)`` of a layer.

    For every diagonal cluster the report contains a cover page, per-quadrant
    phase scatter plots (per neuron and merged), a "main vs drop" plot of
    log10(max |pre-activation|), one page per kept neuron, and a table of
    dropped neurons when there are any. Embedding and plane-angle plots are
    written under ``<save_dir>/cluster_<r>_<s>_embeds``. Off-diagonal
    clusters only print a "Skip" line.

    Args:
        pre_grid: Array indexed as ``[a, b, neuron]`` with side ``G`` taken
            from axis 0.
        left, right: Arrays reshaped to ``(G, G, N)`` and averaged over axis 1
            and axis 0 respectively to form per-neuron vectors.
        p: Accepted but not used in this function.
        dft_2d, irreps: Forwarded to ``prepare_layer_artifacts`` when
            ``artifacts`` is not supplied; ``irreps`` also feeds
            ``DFT.build_rho_cache``.
        coset_masks_left, coset_masks_right: Forwarded to ``subgroup_scores``.
        save_dir: Output directory for the cluster PDFs and embed folders.
        cluster_tau: Accepted but not used in this function.
        color_rule: Forwarded to ``generate_pdf_plots_for_matrix``.
        artifacts: Precomputed result of ``prepare_layer_artifacts``; built
            here when ``None``.
        layer_idx: 0 selects the "a"/"b" plane-angle plots, anything else the
            "c" plot.
        freq_map: Mapping from label to frequency. ``None`` is treated as an
            empty mapping for the lookups done in this function.
    """
    G = pre_grid.shape[0]
    N = pre_grid.shape[-1]
    p_rot = G // 2
    # NOTE(review): the results of these two calls are not used below; kept
    # in case the calls matter for validation or caching. Confirm and remove.
    G_list, idx_map, cayley_table = dihedral.build_cayley_table(p_rot)
    rho_cache = DFT.build_rho_cache(G_list, irreps)

    # `freq_map` defaults to None, but the membership tests below need a
    # container.
    freq_lookup = {} if freq_map is None else freq_map

    if artifacts is None:
        # prepare_layer_artifacts requires freq_map as its sixth positional
        # argument; the call used to omit it and always raised TypeError.
        artifacts = prepare_layer_artifacts(pre_grid, left, right, dft_2d, irreps, freq_lookup)
    F_full = artifacts["F_full"]
    F_L = artifacts["F_L"]
    F_R = artifacts["F_R"]
    names = artifacts["names"]
    irrep2neurons = artifacts["irrep2neurons"]

    cont_l = left.reshape(G, G, N)
    left_vec = cont_l.mean(axis=1)
    cont_r = right.reshape(G, G, N)
    right_vec = cont_r.mean(axis=0)
    coset_info = {}

    # Per-neuron subgroup/coset scores for the left and right vectors.
    for n in range(N):
        v_left = left_vec[:, n]
        v_right = right_vec[:, n]

        L_LC_top = subgroup_scores(v_left, coset_masks_left, top_k=2)
        L_RC_top = subgroup_scores(v_left, coset_masks_right, top_k=2)
        L_LC_all = subgroup_scores(v_left, coset_masks_left, top_k=None)
        L_RC_all = subgroup_scores(v_left, coset_masks_right, top_k=None)
        L_mix3 = merge_topk_sources([("Lcoset", L_LC_all), ("Rcoset", L_RC_all)], top_k=3)

        R_LC_top = subgroup_scores(v_right, coset_masks_left, top_k=2)
        R_RC_top = subgroup_scores(v_right, coset_masks_right, top_k=2)
        R_LC_all = subgroup_scores(v_right, coset_masks_left, top_k=None)
        R_RC_all = subgroup_scores(v_right, coset_masks_right, top_k=None)
        R_mix3 = merge_topk_sources([("Lcoset", R_LC_all), ("Rcoset", R_RC_all)], top_k=3)

        coset_info[n] = {"L": {"Lcoset": L_LC_top, "Rcoset": L_RC_top, "mix3": L_mix3}, "R": {"Lcoset": R_LC_top, "Rcoset": R_RC_top, "mix3": R_mix3}}

    for (r, s), neuron_list in irrep2neurons.items():
        if r == s:
            writer = PdfWriter()

            cluster_acts = pre_grid[:, :, neuron_list]
            max_act = cluster_acts.max()

            # Use the pruning result when the artifacts carry one for this
            # cluster; otherwise every neuron is "main".
            if artifacts is not None and "cluster_prune" in artifacts and (r, s) in artifacts["cluster_prune"]:
                _pack = artifacts["cluster_prune"][(r, s)]
                neuron_main = list(_pack["main"])
                neuron_drop = list(_pack["drop"])
                per_log = _pack["per_neuron_log10"]
                if len(neuron_main) == 0:
                    print(f"[skip] Cluster ({r},{s}) is empty after pruning.")
                    continue
            else:
                neuron_main = neuron_list
                neuron_drop = []

                cluster_grid = pre_grid[:, :, neuron_list]
                max_preacts = np.max(np.abs(cluster_grid), axis=(0, 1))
                per_log = {int(neuron_list[i]): float(np.log10(max_preacts[i] + 1e-20)) for i in range(len(neuron_list))}

            # --- Page 1: cover ---
            cover = make_subplots(rows=1, cols=1)
            cover.add_annotation(
                text=(f"<b>Cluster ({r},{s})</b><br>" f"size = {len(neuron_list)}<br>" f"max activation = {max_act:.2e}"),
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.6,
                showarrow=False,
                font=dict(size=24),
                align="center",
            )
            cover.update_xaxes(visible=False)
            cover.update_yaxes(visible=False)

            cover._uuid = uuid.uuid4().hex
            pdf_cover = cover.to_image(format="pdf", engine="kaleido")
            reader = PdfReader(io.BytesIO(pdf_cover))
            writer.add_page(reader.pages[0])

            # --- Page 2: per-neuron quadrant phases ---
            quad_labels = ["Quad-BL", "Quad-BR", "Quad-TL", "Quad-TR"]
            fig_quads = make_subplots(rows=2, cols=2, subplot_titles=quad_labels)

            palette = px.colors.qualitative.Light24 + px.colors.qualitative.Dark24
            num_colors = len(palette)

            axis_range = [0, 2 * np.pi]
            tick_vals = [0, np.pi, 2 * np.pi]
            tick_text = ["0", "pi", "2pi"]

            for n_idx, n in enumerate(neuron_main):
                quad_phase = _quadrant_ab_phases(pre_grid[:, :, n])
                color = palette[n_idx % num_colors]

                for q_idx, (phix, phiy) in enumerate(quad_phase):
                    r_0, c_0 = divmod(q_idx, 2)
                    # One legend entry per neuron (from its first quadrant).
                    show_legend = q_idx == 0
                    fig_quads.add_trace(
                        go.Scatter(x=[phix], y=[phiy], mode="markers", marker=dict(size=5, color=color), name=f"neuron {n}", showlegend=show_legend),
                        row=r_0 + 1,
                        col=c_0 + 1,
                    )

                    if n_idx == 0:
                        fig_quads.update_xaxes(title_text="phi_b (rad)", range=axis_range, tickvals=tick_vals, ticktext=tick_text, row=r_0 + 1, col=c_0 + 1)
                        fig_quads.update_yaxes(title_text="phi_a (rad)", range=axis_range, tickvals=tick_vals, ticktext=tick_text, row=r_0 + 1, col=c_0 + 1)

            fig_quads._uuid = uuid.uuid4().hex
            pdf_bytes = fig_quads.to_image(format="pdf", engine="kaleido")
            writer.add_page(PdfReader(io.BytesIO(pdf_bytes)).pages[0])

            # --- Page 3: phases merged by rounded (phi_b, phi_a) ---
            fig_quads_mer = make_subplots(rows=2, cols=2, subplot_titles=quad_labels)

            merged_quads = [defaultdict(list) for _ in range(4)]
            for n in neuron_main:
                quad_phase = _quadrant_ab_phases(pre_grid[:, :, n])
                max_amp = np.abs(pre_grid[:, :, n]).max()
                for q_idx, (phix, phiy) in enumerate(quad_phase):
                    key = (round(phix, 4), round(phiy, 4))
                    merged_quads[q_idx][key].append(max_amp)

            for q_idx, phase_dict in enumerate(merged_quads):
                r_0, c_0 = divmod(q_idx, 2)
                for (phix, phiy), amps in phase_dict.items():
                    sum_amp = sum(amps)
                    count = len(amps)
                    # Marker grows with the number of neurons sharing the phase.
                    size = 6 + 3 * np.log1p(count)
                    fig_quads_mer.add_trace(
                        go.Scatter(x=[phix], y=[phiy], mode="markers+text", marker=dict(size=size, color="red"), text=[f"{sum_amp:.2f}"], textposition="top center", showlegend=False),
                        row=r_0 + 1,
                        col=c_0 + 1,
                    )
                    fig_quads_mer.update_xaxes(title_text="phi_b (rad)", range=axis_range, tickvals=tick_vals, ticktext=tick_text, row=r_0 + 1, col=c_0 + 1)
                    fig_quads_mer.update_yaxes(title_text="phi_a (rad)", range=axis_range, tickvals=tick_vals, ticktext=tick_text, row=r_0 + 1, col=c_0 + 1)

            fig_quads_mer._uuid = uuid.uuid4().hex
            pdf_bytes = fig_quads_mer.to_image(format="pdf", engine="kaleido")
            writer.add_page(PdfReader(io.BytesIO(pdf_bytes)).pages[0])

            # --- Page 4: log10(max |pre-act|), main vs drop ---
            cluster_grid = pre_grid[:, :, neuron_list]
            log_mp_all = np.log10(np.max(np.abs(cluster_grid), axis=(0, 1)) + 1e-20)
            color_tag = np.array(["drop"] * len(neuron_list))
            idx_map_local = {n: i for i, n in enumerate(neuron_list)}
            for n in neuron_main:
                if n in idx_map_local:
                    color_tag[idx_map_local[n]] = "main"

            fig_mp = px.scatter(
                x=np.arange(len(neuron_list)),
                y=log_mp_all,
                color=color_tag,
                labels=dict(x="Neuron index in cluster", y="log10(max pre-act)"),
                title=f"Cluster ({r},{s}) - log10(max pre-act): main vs drop",
            )
            fig_mp._uuid = uuid.uuid4().hex
            pdf_bytes = fig_mp.to_image(format="pdf", engine="kaleido")
            writer.add_page(PdfReader(io.BytesIO(pdf_bytes)).pages[0])

            # --- Embedding plots (written to the embed directory) ---
            mat = pre_grid[:, :, neuron_main].reshape(G * G, -1).astype(float)
            embed_dir = os.path.join(save_dir, f"cluster_{r}_{s}_embeds")
            os.makedirs(embed_dir, exist_ok=True)

            p2 = G // 2
            quads = {"BL": pre_grid[:p2, :p2, neuron_main], "BR": pre_grid[:p2, p2:, neuron_main], "TL": pre_grid[p2:, :p2, neuron_main], "TR": pre_grid[p2:, p2:, neuron_main]}

            freqs = []
            if r in freq_lookup:
                f_main = int(freq_lookup[r])
                freqs.append(f_main)
            else:
                print(f"[warn] freq_map missing {r}; skipping main frequency")

            freqs = sorted(dict.fromkeys(freqs))
            if not freqs:
                print(f"[warn] cluster ({r},{s}) no freq; return empty list")

            quad_freq_lists = {k: freqs for k in ["BL", "BR", "TL", "TR", "full"]}
            color_rules = color_rule

            generate_pdf_plots_for_matrix(
                mat,
                p=p2,
                save_dir=embed_dir,
                seed=f"{r}_{s}",
                freq_list=quad_freq_lists["full"],
                tag="full",
                tag_q="full",
                class_string=f"cluster_{r}_{s}",
                colour_rule=color_rules,
                num_principal_components=4,
            )

            for tag_q, quad in quads.items():
                qmat = quad.reshape(p2 * p2, -1).astype(float)
                generate_pdf_plots_for_matrix(
                    qmat,
                    p=p2,
                    save_dir=embed_dir,
                    seed=f"{r}_{s}",
                    freq_list=quad_freq_lists[tag_q],
                    tag=tag_q,
                    tag_q=tag_q,
                    class_string=f"{tag_q}_{r}_{s}",
                    colour_rule=color_rules,
                    num_principal_components=4,
                )

            # --- Plane-angle plots on the leading PCA coordinates ---
            side = G
            indices = np.arange(side * side)
            a_vals = indices // side
            b_vals = indices % side

            pcs, _ = compute_pca_coords(mat, num_components=min(8, mat.shape[1], mat.shape[0] - 1))
            coords_for_angle = pcs[:, :4]
            angles_dir = os.path.join(embed_dir, "angles", "pca")
            os.makedirs(angles_dir, exist_ok=True)

            if layer_idx == 0:
                plane_angle_per_cluster(
                    coords=coords_for_angle,
                    a_vals=a_vals,
                    b_vals=b_vals,
                    p=p2,
                    cluster_ids=np.zeros(side * side, dtype=int),
                    mode="a",
                    tag_q="full",
                    save_dir=os.path.join(angles_dir, "layer0_a"),
                    title=f"Layer-0: a-rot vs a-ref (cluster {r},{s})",
                )
                plane_angle_per_cluster(
                    coords=coords_for_angle,
                    a_vals=a_vals,
                    b_vals=b_vals,
                    p=p2,
                    cluster_ids=np.zeros(side * side, dtype=int),
                    mode="b",
                    tag_q="full",
                    save_dir=os.path.join(angles_dir, "layer0_b"),
                    title=f"Layer-0: b-rot vs b-ref (cluster {r},{s})",
                )
            else:
                plane_angle_per_cluster(
                    coords=coords_for_angle,
                    a_vals=a_vals,
                    b_vals=b_vals,
                    p=p2,
                    cluster_ids=np.zeros(side * side, dtype=int),
                    mode="c",
                    tag_q="full",
                    save_dir=os.path.join(angles_dir, "layerK_c"),
                    title=f"Layer-{layer_idx}: c-rot vs c-ref (cluster {r},{s})",
                )

            # --- One page per kept neuron ---
            for n in neuron_main:
                print(f"Rendering neuron {n} in cluster ({r},{s})")
                fig_n = single_neuron_figure(n, pre_grid, left_vec, right_vec, F_full, F_L, F_R, names, coset_info[n], freq_map=freq_map, freqs=freqs, strict=True, artifacts=artifacts)
                fig_n._uuid = uuid.uuid4().hex
                pdf_bytes = fig_n.to_image(format="pdf", engine="kaleido")
                reader = PdfReader(io.BytesIO(pdf_bytes))
                writer.add_page(reader.pages[0])

            # --- Final page: table of dropped neurons, strongest first ---
            if len(neuron_drop) > 0:
                rows = []
                for n in sorted(neuron_drop):
                    logv = per_log.get(int(n), None)
                    maxv = float(np.max(np.abs(pre_grid[:, :, n])))
                    rows.append((int(n), maxv, (None if logv is None else float(logv))))
                rows.sort(key=lambda t: (t[2] if t[2] is not None else -1e9), reverse=True)

                tbl = go.Figure(
                    data=[
                        go.Table(
                            header=dict(values=["neuron index", "max |preact|", "log10(max)"]),
                            cells=dict(
                                values=[
                                    [r_[0] for r_ in rows],
                                    [f"{r_[1]:.3g}" for r_ in rows],
                                    [("—" if r_[2] is None else f"{r_[2]:.3f}") for r_ in rows],
                                ]
                            ),
                        )
                    ]
                )
                tbl.update_layout(title=f"Cluster ({r},{s}) - Dropped neurons")
                tbl._uuid = uuid.uuid4().hex
                pdf_bytes = tbl.to_image(format="pdf", engine="kaleido")
                writer.add_page(PdfReader(io.BytesIO(pdf_bytes)).pages[0])

            path = f"cluster_{r}_{s}.pdf"
            fin_path = os.path.join(save_dir, path)
            with open(fin_path, "wb") as f:
                writer.write(f)
            print("saved", fin_path)
        else:
            cluster_acts = pre_grid[:, :, neuron_list]
            max_act = cluster_acts.max()
            print(f"Skip {r}_{s} irreps plot, max activation: {max_act}")


def _trimmed_mean(x, trim=0.1):
    x = np.asarray(x, float)
    if x.size == 0:
        return float("nan")
    xs = np.sort(x)
    k = int(np.floor(trim * xs.size))
    if xs.size - 2 * k <= 0:
        return float(np.median(xs))
    return float(xs[k : xs.size - k].mean())


def _cluster_robust_summary(per_neuron_vals: dict[int, float], method: str = "mad", k: float = 3.5, min_keep: int = 5, trim: float = 0.1):
    ids = np.array(sorted(per_neuron_vals.keys()), dtype=int)
    vals = np.array([per_neuron_vals[i] for i in ids], dtype=float)
    mfin = np.isfinite(vals)
    ids, vals = ids[mfin], vals[mfin]
    if vals.size == 0:
        return {
            "count_all": 0,
            "count_kept": 0,
            "dropped_ids": [],
            "mean": None,
            "median": None,
            "trimmed_mean": None,
            "std": None,
            "quantiles": {"p10": None, "p25": None, "p50": None, "p75": None, "p90": None},
        }

    if method == "std":
        mu = float(np.mean(vals))
        sd = float(np.std(vals, ddof=1)) if vals.size > 1 else 0.0
        keep = np.ones_like(vals, bool) if sd <= 1e-12 else (np.abs((vals - mu) / sd) <= k)
    else:
        med = float(np.median(vals))
        mad = float(np.median(np.abs(vals - med)))
        scale = 1.4826 * mad
        if scale <= 1e-12:
            sd = float(np.std(vals, ddof=1)) if vals.size > 1 else 0.0
            keep = np.ones_like(vals, bool) if sd <= 1e-12 else (np.abs(vals - med) <= k * sd)
        else:
            keep = np.abs((vals - med) / scale) <= k

    if keep.sum() < max(3, min_keep):
        keep[:] = True

    kept, dropped_ids = vals[keep], ids[~keep]
    q10, q25, q50, q75, q90 = np.percentile(kept, [10, 25, 50, 75, 90]) if kept.size else [None] * 5

    return {
        "count_all": int(vals.size),
        "count_kept": int(kept.size),
        "dropped_ids": dropped_ids.tolist(),
        "mean": float(np.mean(kept)) if kept.size else None,
        "median": float(np.median(kept)) if kept.size else None,
        "trimmed_mean": _trimmed_mean(kept, trim) if kept.size else None,
        "std": float(np.std(kept, ddof=1)) if kept.size > 1 else (0.0 if kept.size == 1 else None),
        "quantiles": {"p10": float(q10), "p25": float(q25), "p50": float(q50), "p75": float(q75), "p90": float(q90)} if kept.size else {"p10": None, "p25": None, "p50": None, "p75": None, "p90": None},
    }


def make_R2_c_angle_report_no_plot(
    pre_grid,
    p,
    save_dir: str,
    artifacts=None,
    base_layer_artifacts=None,
    base_f_pool=None,
    layer_idx: int = 0,
    freq_map=None,
    use_pair_terms=True,
):
    """Write PCA plane-angle and R2 fit reports for every diagonal cluster.

    For each ``(r, s)`` key of ``artifacts["irrep2neurons"]`` with ``r == s``
    the cluster's pre-activations are flattened to ``(G * G, n_neurons)``,
    passed to ``run_pca_core``, and the first four principal coordinates are
    handed to ``plane_angle_per_cluster``.  When ``freq_map`` is given, every
    kept neuron is additionally fitted with the ``analysis.R2`` helpers and two
    JSON files (``axis_phase_analysis.json``, ``r2_summary_cluster_robust.json``)
    are written under ``<save_dir>/cluster_{r}_{s}_embeds/r2``.

    Args:
        pre_grid: Array indexed ``[a, b, neuron]``; ``shape[0]`` is used as the
            grid side ``G`` and the first two axes are reshaped to ``G * G``.
        p: Modulus forwarded to the PCA / plane-angle helpers and to
            ``step_size``.
        save_dir: Root directory for all outputs.
        artifacts: Mapping with at least ``"names"`` and ``"irrep2neurons"``;
            ``"cluster_prune"`` and ``"F_full"`` are used when present.
        base_layer_artifacts: Source for the frequency pool when
            ``base_f_pool`` is not supplied.
        base_f_pool: Explicit frequency pool; takes precedence for
            ``layer_idx >= 1``.
        layer_idx: ``0`` selects the "a"/"b" plane-angle modes, anything else
            the "c" mode.
        freq_map: Lookup from dominant-component names to frequencies; the R2
            stage is skipped when it is ``None``.
        use_pair_terms: Forwarded to the R2 design/fit helpers.

    Raises:
        ValueError: If ``artifacts`` is ``None``, or if ``layer_idx >= 1`` and
            neither ``base_f_pool`` nor ``base_layer_artifacts`` is given.
        Exception: Any failure inside the R2 stage is logged and re-raised.
    """
    G = pre_grid.shape[0]

    # Resolve the frequency pool once.  For layer >= 1 an explicit
    # ``base_f_pool`` wins; previously this value was computed and then
    # unconditionally overwritten inside the cluster loop.
    f_pool_L1 = None
    if layer_idx >= 1:
        if base_f_pool is not None:
            f_pool_L1 = list(base_f_pool)
        elif base_layer_artifacts is not None:
            f_pool_L1 = build_f_pool_from_layer1_artifacts(base_layer_artifacts, freq_map, p)
        else:
            raise ValueError("Need base_f_pool or base_layer_artifacts for layer>=1 residual fitting.")

    if artifacts is None:
        raise ValueError("artifacts is required: it must provide 'names' and 'irrep2neurons'.")
    names = artifacts["names"]
    irrep2neurons = artifacts["irrep2neurons"]

    # Row-major (a, b) coordinates of the flattened G x G grid; cluster independent.
    flat_idx = np.arange(G * G)
    a_vals = flat_idx // G
    b_vals = flat_idx % G

    for (r, s), neuron_list in irrep2neurons.items():
        if r != s:
            continue

        if "cluster_prune" in artifacts and (r, s) in artifacts["cluster_prune"]:
            neuron_main = list(artifacts["cluster_prune"][(r, s)]["main"])
            if len(neuron_main) == 0:
                print(f"[skip] Cluster ({r},{s}) is empty after pruning.")
                continue
        else:
            neuron_main = neuron_list

        mat = pre_grid[:, :, neuron_main].reshape(G * G, -1).astype(float)

        embed_dir = os.path.join(save_dir, f"cluster_{r}_{s}_embeds")
        os.makedirs(embed_dir, exist_ok=True)

        core = pca_diffusion_plots_w_helpers.run_pca_core(mat=mat, p=p, save_dir=embed_dir, seed=f"layer{layer_idx}_{r}_{s}", tag="full", tag_q="full", max_components=8)
        coords_for_angle = core["pcs"][:, :4]
        angles_dir = os.path.join(embed_dir, "angles", "pca")
        os.makedirs(angles_dir, exist_ok=True)

        if layer_idx == 0:
            angle_jobs = [
                ("a", "layer0_a", f"Layer-0: a-rot vs a-ref (cluster {r},{s})"),
                ("b", "layer0_b", f"Layer-0: b-rot vs b-ref (cluster {r},{s})"),
            ]
        else:
            angle_jobs = [("c", "layerK_c", f"Layer-{layer_idx}: c-rot vs c-ref (cluster {r},{s})")]
        for mode, subdir, title in angle_jobs:
            plane_angle_per_cluster(
                coords=coords_for_angle,
                a_vals=a_vals,
                b_vals=b_vals,
                p=p,
                cluster_ids=np.zeros(G * G, dtype=int),
                mode=mode,
                tag_q="full",
                save_dir=os.path.join(angles_dir, subdir),
                title=title,
            )

        if freq_map is None:
            continue

        try:
            F_full = artifacts.get("F_full", {})
            r2_dir = os.path.join(embed_dir, "r2")
            os.makedirs(r2_dir, exist_ok=True)

            # Layer 0 has no pre-resolved pool: build it on first use and reuse
            # it for the remaining clusters (the inputs do not change).
            if f_pool_L1 is None:
                f_pool_L1 = build_f_pool_from_layer1_artifacts(base_layer_artifacts, freq_map, p)

            p_quad = G // 2
            per_neuron_fullgrid = {}

            per_neuron_R2_mean = {}
            per_neuron_dR2_mean = {}
            per_neuron_axis_phase = {}
            per_neuron_fbase = {}

            for n in neuron_main:
                nid_int = int(n)
                # Keys of F_full are indexed so that position 2 is the neuron id.
                Fhat_n = {k: v for k, v in F_full.items() if k[2] == n}
                dom = _classify_by_gft(Fhat_n, names, freq_map, strict=True)
                if dom["kind"] == "diag":
                    f0 = abs(int(freq_map[dom["b_star"]]))
                    fa_base = fb_base = f0
                else:
                    fa_base = abs(int(freq_map[dom["a_star"]]))
                    fb_base = abs(int(freq_map[dom["b_star"]]))
                    f0 = int(fa_base)
                per_neuron_fbase[nid_int] = int(f0)

                grid = pre_grid[:, :, n]
                quads = [grid[:p_quad, :p_quad], grid[:p_quad, p_quad:], grid[p_quad:, :p_quad], grid[p_quad:, p_quad:]]

                r2_final_list, dR2_list = [], []
                selected_freqs = {int(fa_base), int(fb_base)}
                for q in quads:
                    fit_q = R2.fit_quadrant_sines(
                        q,
                        dom,
                        freq_map,
                        names,
                        f_pool=f_pool_L1,
                        max_iters=6,
                        tau_inc=0.0,
                        use_axes_only=True,
                        use_pair_terms=use_pair_terms,
                        include_ab_resid=False,
                    )
                    R2_final = float(fit_q.get("R2_final", 0.0))
                    R2_base = float(fit_q.get("R2_stage1", fit_q.get("R2_layer1", 0.0)))
                    r2_final_list.append(R2_final)
                    dR2_list.append(R2_final - R2_base)
                    for item in fit_q.get("chosen_stage2", []):
                        selected_freqs.add(int(item["f"]))

                if r2_final_list:
                    per_neuron_R2_mean[nid_int] = float(np.mean(r2_final_list))
                    per_neuron_dR2_mean[nid_int] = float(np.mean(dR2_list))

                f_use = sorted(selected_freqs)

                # Keep the quadfree design matrix under its own name: it used to
                # be overwritten by the hadamard4 one before its shape was logged
                # into the "quadfree" record below.
                y_qf, X_qf = R2.design_fullgrid_coupled(grid, f_use, use_axes=True, use_pair=use_pair_terms, quad_bias=True, mask_mode="quadfree")
                fit_qf = R2.ols_fit(y_qf, X_qf)
                print("quadfree", fit_qf["R2"], X_qf.shape)

                # The hadamard4 fit is only reported on stdout.
                y_h4, X_h4 = R2.design_fullgrid_coupled(grid, f_use, use_axes=True, use_pair=use_pair_terms, quad_bias=True, mask_mode="hadamard4")
                fit_h4 = R2.ols_fit(y_h4, X_h4)
                print("hadamard4", fit_h4["R2"], X_h4.shape)

                out_best, X_best, info_best = R2.fit_fullgrid_once_autoshift(grid, f_use, use_axes=True, use_pair=use_pair_terms, quad_bias=True)
                print("family(best)", out_best["R2"], X_best.shape)
                beta = out_best["beta"]
                per_neuron_axis_phase[nid_int] = R2.extract_axis_phase_family(beta, info_best, f_use)

                per_neuron_fullgrid[nid_int] = {
                    "f_use": list(map(int, f_use)),
                    "quadfree": {
                        "R2": float(fit_qf["R2"]),
                        "BIC": float(fit_qf["BIC"]),
                        "sse": float(fit_qf["sse"]),
                        "sst": float(fit_qf["sst"]),
                        "n": int(fit_qf["n"]),
                        "k": int(fit_qf["k"]),
                        "X_shape": [int(X_qf.shape[0]), int(X_qf.shape[1])],
                    },
                    "family_best": {
                        "R2": float(out_best["R2"]),
                        "BIC": float(out_best["BIC"]),
                        "sse": float(out_best["sse"]),
                        "sst": float(out_best["sst"]),
                        "n": int(out_best["n"]),
                        "k": int(out_best["k"]),
                        "X_shape": [int(X_best.shape[0]), int(X_best.shape[1])],
                        "pair_sign_mode": out_best.get("pair_sign_mode", None),
                        "pair_s": out_best.get("pair_s", None),
                        "pair_keep": out_best.get("pair_keep", None),
                        "phi_off": out_best.get("phi_off", None),
                        "phi_off_map": out_best.get("phi_off_map", None),
                    },
                }

            # Cluster base frequency = most common per-neuron base frequency
            # (np.unique sorts, so ties resolve to the smallest value).
            f0_list = [per_neuron_fbase.get(int(n), None) for n in neuron_main]
            f0_list = [int(x) for x in f0_list if x is not None]
            if len(f0_list) == 0:
                cluster_f0 = None
            else:
                vals, cnts = np.unique(np.asarray(f0_list, int), return_counts=True)
                cluster_f0 = int(vals[np.argmax(cnts)])

            phase_analysis = None
            if cluster_f0 is not None and len(per_neuron_axis_phase) > 0:
                phase_analysis = R2.analyze_irrep_cluster_axis_phase(
                    per_neuron_axis_phase=per_neuron_axis_phase,
                    neuron_ids=[int(n) for n in neuron_main if int(n) in per_neuron_axis_phase],
                    f0=cluster_f0,
                    f_pool=f_pool_L1,
                    G=G,
                    amp_eps=1e-6,
                    topk_within=5,
                    compute_spectrum=True,
                )
                phase_analysis["meta"].update({"cluster": [r, s], "layer_idx": int(layer_idx)})
                with open(os.path.join(r2_dir, "axis_phase_analysis.json"), "w") as f:
                    json.dump(phase_analysis, f, indent=2)

            try:
                if cluster_f0 is not None and phase_analysis is not None:
                    rel = phase_analysis["grid_units"]["relations_across_neuron_at_f0"]
                    k_a = rel.get("sA", {}).get("generator", None)
                    k_b = rel.get("dB", {}).get("generator", None)

                    if k_a is None or k_b is None:
                        print(f"[rot/ref report] skip cluster ({r},{s}): missing generator (k_a={k_a}, k_b={k_b})")
                    else:
                        d = step_size(int(cluster_f0), int(p))
                        grid_ids = np.arange(G * G, dtype=int)
                        out_report_dir = os.path.join(r2_dir, "rot_ref_report")
                        os.makedirs(out_report_dir, exist_ok=True)

                        cayley.run_rot_then_ref_report(
                            embedding_weights=mat,
                            neuron_ids=grid_ids,
                            G=G,
                            f0=int(cluster_f0),
                            d=int(d),
                            a_k=int(k_a),
                            a_model="k_minus_x",
                            b_k=int(k_b),
                            b_model="shift",
                            num_pca_dims=8,
                            pca_dims=None,
                            pca_cumvar_tau=0.90,
                            ball_q=1.0,
                            out_dir=out_report_dir,
                            label=f"cluster_{r}_{s}_L{layer_idx}",
                            do_bfs=True,
                            start_a_state=0,
                            start_b_state=0,
                        )

            except Exception as e:
                print(f"[rot/ref report] cluster ({r},{s}) failed: {type(e).__name__}: {e}")
                raise

            robust_kwargs = dict(method="mad", k=3.5, min_keep=5, trim=0.1)
            stats_R2 = _cluster_robust_summary(per_neuron_R2_mean, **robust_kwargs)
            stats_dR2 = _cluster_robust_summary(per_neuron_dR2_mean, **robust_kwargs)

            per_neuron_quadfree_R2 = {k: v["quadfree"]["R2"] for k, v in per_neuron_fullgrid.items()}
            per_neuron_family_R2 = {k: v["family_best"]["R2"] for k, v in per_neuron_fullgrid.items()}

            stats_qf = _cluster_robust_summary(per_neuron_quadfree_R2, **robust_kwargs)
            stats_fm = _cluster_robust_summary(per_neuron_family_R2, **robust_kwargs)

            summary = {
                "cluster": [(r, s)],
                "layer_idx": int(layer_idx),
                "p": int(p),
                "n_neurons_total": int(len(neuron_main)),
                "f_pool": list(map(int, f_pool_L1)),
                "per_neuron_R2_mean": per_neuron_R2_mean,
                "per_neuron_dR2_mean": per_neuron_dR2_mean,
                "per_neuron_fullgrid_fits": per_neuron_fullgrid,
                "cluster_stats": {
                    "R2_final_mean_across_quads": stats_R2,
                    "dR2_mean_across_quads": stats_dR2,
                    "R2_quadfree_fullgrid": stats_qf,
                    "R2_family_fullgrid": stats_fm,
                    "filter_method": "mad",
                    "filter_k": 3.5,
                },
            }
            with open(os.path.join(r2_dir, "r2_summary_cluster_robust.json"), "w") as f:
                json.dump(summary, f, indent=2)

        except Exception as e:
            import traceback

            # Log with context, then propagate: a failed cluster aborts the run.
            print(f"[R2] cluster ({r},{s}) failed: {type(e).__name__}: {e}")
            traceback.print_exc()
            raise



def make_stripe_report_only_phase_pdf(pre_grid, p, save_dir: str, artifacts=None):
    """Write a per-cluster phase report (PDF + JSON) for diagonal clusters.

    For each ``(r, s)`` key of ``artifacts["irrep2neurons"]`` with ``r == s``
    a directory ``<save_dir>/cluster_{r}_{s}_stripe`` receives:

    * ``cluster_{r}_{s}.pdf`` with four pages: a cover page, the per-neuron
      quadrant phases, the amplitude-weighted merged phases of the kept
      neurons, and a log10(max |pre-activation|) scatter (main vs drop);
    * ``cluster_{r}_{s}_phase_points.json`` holding the same phase points.

    Clusters whose ``r`` is a string starting with ``"2D_"`` are additionally
    passed to ``run_pca_and_stripes_no_plots``.  Off-diagonal clusters are only
    reported on stdout.

    Args:
        pre_grid: Array indexed ``[a, b, neuron]``; ``shape[0]`` is the grid
            side ``G`` and the first two axes are reshaped to ``G * G``.
        p: Modulus, stored in the JSON and forwarded to the PCA helper.
        save_dir: Root output directory.
        artifacts: Mapping with ``"irrep2neurons"``; ``"cluster_prune"`` is
            used when present to split neurons into main/drop.

    Raises:
        ValueError: If ``artifacts`` is ``None``.
    """
    if artifacts is None:
        raise ValueError("artifacts is required: it must provide 'irrep2neurons'.")

    G = pre_grid.shape[0]
    p_rot = G // 2
    irrep2neurons = artifacts["irrep2neurons"]

    # NOTE(review): label order is assumed to match the order in which
    # _quadrant_ab_phases yields its quadrants - confirm against that helper.
    quad_labels = ["Quad-BL", "Quad-BR", "Quad-TL", "Quad-TR"]
    two_pi = 2 * np.pi
    merge_eps = 0.25  # torus distance below which phase points are merged
    pad = 0.30  # axis padding so merged markers at 0 / 2*pi are not clipped

    def _torus_dist(x1, y1, x2, y2):
        # Euclidean distance on the [0, 2*pi) x [0, 2*pi) torus.
        dx = np.abs(x1 - x2)
        dx = np.minimum(dx, two_pi - dx)
        dy = np.abs(y1 - y2)
        dy = np.minimum(dy, two_pi - dy)
        return np.hypot(dx, dy)

    def _circ_mean(angles, weights):
        # Weighted circular mean, wrapped into [0, 2*pi).
        sin_sum = np.sum(weights * np.sin(angles))
        cos_sum = np.sum(weights * np.cos(angles))
        return float(np.arctan2(sin_sum, cos_sum) % two_pi)

    def _append_page(pdf_writer, fig):
        # Render a plotly figure to a one-page PDF and append it to the report.
        fig._uuid = uuid.uuid4().hex
        page_bytes = fig.to_image(format="pdf", engine="kaleido")
        pdf_writer.add_page(PdfReader(io.BytesIO(page_bytes)).pages[0])

    for (r, s), neuron_list in irrep2neurons.items():
        cluster_grid = pre_grid[:, :, neuron_list]
        max_act = float(cluster_grid.max())

        if r != s:
            print(f"Skip {r}_{s} irreps plot, max activation: {max_act}")
            continue

        writer = PdfWriter()

        embed_dir = os.path.join(save_dir, f"cluster_{r}_{s}_stripe")
        os.makedirs(embed_dir, exist_ok=True)

        if "cluster_prune" in artifacts and (r, s) in artifacts["cluster_prune"]:
            _pack = artifacts["cluster_prune"][(r, s)]
            neuron_main = list(_pack["main"])
            neuron_drop = list(_pack["drop"])
        else:
            neuron_main = neuron_list
            neuron_drop = []
        # Built once per cluster instead of once per neuron.
        main_set = set(neuron_main)
        drop_set = set(neuron_drop)

        phase_json = {
            "meta": {
                "cluster": [str(r), str(s)],
                "G": int(G),
                "p": int(p),
                "p_rot": int(p_rot),
                "n_neurons_total": int(len(neuron_list)),
                "n_neurons_main": int(len(neuron_main)),
                "max_activation": float(max_act),
                "axes": {"x": "phi_b", "y": "phi_a"},
                "quad_order": list(quad_labels),
            },
            "per_neuron": {},
            # This sub-dict was missing its opening key, which left the
            # literal unbalanced; it is filled in further below.
            "merged": {
                "params": {},
                "raw_quads": {},
                "merged_quads": {},
            },
        }

        # ---- page 1: cover -------------------------------------------------
        cover = make_subplots(rows=1, cols=1)
        cover.add_annotation(
            text=(
                f"<b>Cluster ({r},{s})</b><br>"
                f"size = {len(neuron_list)}<br>"
                f"max activation = {max_act:.2e}"
            ),
            xref="paper", yref="paper",
            x=0.5, y=0.6, showarrow=False,
            font=dict(size=24), align="center"
        )
        cover.update_xaxes(visible=False)
        cover.update_yaxes(visible=False)
        _append_page(writer, cover)

        # ---- page 2: raw per-neuron phases, one colour per neuron ----------
        fig_quads = make_subplots(rows=2, cols=2, subplot_titles=quad_labels)

        palette = (px.colors.qualitative.Light24 + px.colors.qualitative.Dark24)
        num_colors = len(palette)

        raw_axis_range = [0, 2 * np.pi]
        raw_tick_vals = [0, np.pi, 2 * np.pi]
        raw_tick_text = ["0", "π", "2π"]

        # neuron id -> (quadrant phases, max amplitude); reused for the merged page.
        phase_cache = {}

        for n_idx, n in enumerate(neuron_list):
            n_int = int(n)
            quad_phase = list(_quadrant_ab_phases(pre_grid[:, :, n_int]))
            max_amp = float(np.abs(pre_grid[:, :, n_int]).max())
            phase_cache[n_int] = (quad_phase, max_amp)
            color = palette[n_idx % num_colors]

            entry = {
                "max_amp": float(max_amp),
                "is_main": bool(n_int in main_set),
                "is_drop": bool(n_int in drop_set),
                "quads": []
            }
            phase_json["per_neuron"][str(n_int)] = entry

            for q_idx, (phi_b, phi_a) in enumerate(quad_phase):
                r_0, c_0 = divmod(q_idx, 2)

                fig_quads.add_trace(
                    go.Scatter(
                        x=[float(phi_b)], y=[float(phi_a)],
                        mode="markers",
                        marker=dict(size=5, color=color),
                        name=f"neuron {n_int}",
                        showlegend=(q_idx == 0),  # one legend entry per neuron
                    ),
                    row=r_0 + 1, col=c_0 + 1
                )

                entry["quads"].append({
                    "quad_id": int(q_idx),
                    "quad": quad_labels[q_idx],
                    "phi_b": float(phi_b),
                    "phi_a": float(phi_a),
                })

                # Axes only need configuring once per subplot.
                if n_idx == 0:
                    fig_quads.update_xaxes(
                        title_text="phi_b (rad)",
                        range=raw_axis_range, tickvals=raw_tick_vals, ticktext=raw_tick_text,
                        row=r_0 + 1, col=c_0 + 1
                    )
                    fig_quads.update_yaxes(
                        title_text="phi_a (rad)",
                        range=raw_axis_range, tickvals=raw_tick_vals, ticktext=raw_tick_text,
                        row=r_0 + 1, col=c_0 + 1
                    )

        _append_page(writer, fig_quads)

        # ---- page 3: merged phases of the kept neurons ---------------------
        fig_quads_mer = make_subplots(rows=2, cols=2, subplot_titles=quad_labels)

        axis_range = [-pad, 2 * np.pi + pad]
        tick_vals = [0, np.pi, 2 * np.pi]
        tick_text = ["0", "pi", "2pi"]

        color = "red"

        phase_json["merged"]["params"] = {
            "merge_eps": float(merge_eps),
            "pad": float(pad),
            "x_range": [float(axis_range[0]), float(axis_range[1])],
            "y_range": [float(axis_range[0]), float(axis_range[1])],
        }

        raw_quads = [[] for _ in range(4)]
        for n in neuron_main:
            n_int = int(n)
            cached = phase_cache.get(n_int)
            if cached is None:
                # Main neuron that is not part of neuron_list: compute directly.
                cached = (
                    list(_quadrant_ab_phases(pre_grid[:, :, n_int])),
                    float(np.abs(pre_grid[:, :, n_int]).max()),
                )
            quad_phase, max_amp = cached
            for q_idx, (phi_b, phi_a) in enumerate(quad_phase):
                raw_quads[q_idx].append((float(phi_b), float(phi_a), float(max_amp), n_int))

        for q_idx in range(4):
            phase_json["merged"]["raw_quads"][quad_labels[q_idx]] = [
                {"phi_b": x, "phi_a": y, "amp": a, "nid": int(nid)}
                for (x, y, a, nid) in raw_quads[q_idx]
            ]

        for q_idx, pts in enumerate(raw_quads):
            r_0, c_0 = divmod(q_idx, 2)

            merged_export = []
            if pts:
                used = [False] * len(pts)
                merged = []

                # Greedy grouping: each unused point seeds a group and absorbs
                # every still-unused point within merge_eps of the seed.
                for i, (x0, y0, _a0, _nid0) in enumerate(pts):
                    if used[i]:
                        continue
                    group_idx = [i]
                    used[i] = True
                    for j, (xj, yj, _aj, _nidj) in enumerate(pts):
                        if not used[j] and _torus_dist(x0, y0, xj, yj) <= merge_eps:
                            used[j] = True
                            group_idx.append(j)

                    A = np.array([pts[k][2] for k in group_idx], dtype=float)
                    X = np.array([pts[k][0] for k in group_idx], dtype=float)
                    Y = np.array([pts[k][1] for k in group_idx], dtype=float)

                    mx = _circ_mean(X, A)
                    my = _circ_mean(Y, A)
                    sum_amp = float(A.sum())
                    count = int(len(group_idx))
                    merged.append((mx, my, sum_amp, count, group_idx))

                for (mx, my, sum_amp, count, group_idx) in merged:
                    size = 8 + 4 * np.log1p(count)  # marker grows with group size
                    fig_quads_mer.add_trace(
                        go.Scatter(
                            x=[mx], y=[my],
                            mode="markers+text",
                            marker=dict(size=float(size), color=color),
                            text=[f"{sum_amp:.2f}"],
                            textposition="top center",
                            textfont=dict(size=10),
                            cliponaxis=False,
                            showlegend=False,
                        ),
                        row=r_0 + 1, col=c_0 + 1
                    )

                    members = [int(pts[k][3]) for k in group_idx]
                    merged_export.append({
                        "phi_b": float(mx),
                        "phi_a": float(my),
                        "sum_amp": float(sum_amp),
                        "count": int(count),
                        "members": members,
                    })

            phase_json["merged"]["merged_quads"][quad_labels[q_idx]] = merged_export

            fig_quads_mer.update_xaxes(
                title_text="phi_b (rad)",
                range=axis_range, tickvals=tick_vals, ticktext=tick_text,
                row=r_0 + 1, col=c_0 + 1
            )
            fig_quads_mer.update_yaxes(
                title_text="phi_a (rad)",
                range=axis_range, tickvals=tick_vals, ticktext=tick_text,
                row=r_0 + 1, col=c_0 + 1
            )

        fig_quads_mer.update_layout(margin=dict(l=80, r=60, t=80, b=80))
        _append_page(writer, fig_quads_mer)

        json_path = os.path.join(embed_dir, f"cluster_{r}_{s}_phase_points.json")
        with open(json_path, "w") as jf:
            json.dump(phase_json, jf, indent=2)
        print("saved phase points json:", json_path)

        # ---- page 4: log10(max |pre-act|) per neuron, main vs drop ---------
        log_mp_all = np.log10(np.max(np.abs(cluster_grid), axis=(0, 1)) + 1e-20)
        color_tag = np.array(["drop"] * len(neuron_list))
        idx_map = {int(n): i for i, n in enumerate(neuron_list)}
        for n in neuron_main:
            n = int(n)
            if n in idx_map:
                color_tag[idx_map[n]] = "main"

        fig_mp = px.scatter(
            x=np.arange(len(neuron_list)),
            y=log_mp_all,
            color=color_tag,
            labels=dict(x="Neuron index in cluster", y="log10(max pre-act)"),
            title=f"Cluster ({r},{s}) - log10(max pre-act): main vs drop"
        )
        _append_page(writer, fig_mp)

        mat = pre_grid[:, :, neuron_main].reshape(G*G, -1).astype(float)

        if isinstance(r, str) and r.startswith("2D_"):
            f = int(r.split("_", 1)[1])
            pca_diffusion_plots_w_helpers.run_pca_and_stripes_no_plots(
                mat,
                p=p,
                save_dir=embed_dir,
                freq_list=[f],
                cluster_meta={"n_neurons_main": len(neuron_main), "n_neurons_total": len(neuron_list)},
                tag_q="full",
                num_principal_components=8,
                s_mode="anchor",
                model="auto",
            )

        path = f"cluster_{r}_{s}.pdf"
        fin_path = os.path.join(embed_dir, path)
        with open(fin_path, "wb") as fpdf:
            writer.write(fpdf)
        print("saved", fin_path)

def epsilon_analysis(pre_grid,
                     p,
                     save_dir: str,
                     artifacts=None,
                     seed: int | str = ""
                     ):
    G = pre_grid.shape[0]
    N = pre_grid.shape[-1]

    names  = artifacts["names"]
    irrep2neurons = artifacts["irrep2neurons"]

    for (r, s), neuron_list in irrep2neurons.items():
        if r != s:
            cluster_acts = pre_grid[:, :, neuron_list]
            max_act = cluster_acts.max()
            print(f"Skip {r}_{s} irreps plot, max activation: {max_act}")
            continue

        cluster_acts = pre_grid[:, :, neuron_list]
        max_act = cluster_acts.max()
        if "cluster_prune" in (artifacts or {}) and (r, s) in artifacts["cluster_prune"]:
            _pack = artifacts["cluster_prune"][(r, s)]
            neuron_main = list(_pack["main"])
        else:
            neuron_main = neuron_list

        mat = pre_grid[:, :, neuron_main].reshape(G*G, -1).astype(float)

        embed_dir = os.path.join(save_dir, f"cluster_{r}_{s}_compgeo")
        os.makedirs(embed_dir, exist_ok=True)

        if isinstance(r, str) and r.startswith("2D_"):
            f = int(r.split("_", 1)[1])
            pca_diffusion_plots_w_helpers.run_comp_geo_coset_pipeline(
                mat=mat,
                p=p,
                save_dir=embed_dir,
                freq_list=[f],
                tag_q="full",
                tag=f"cluster_{r}_{s}",
                seed=seed,
                label="PCA",
                num_pca_dims=4,
            )

        else:
            cluster_acts = pre_grid[:, :, neuron_list]
            max_act = cluster_acts.max()
            print(f"Skip {r}_{s} irreps plot, max activation: {max_act}")
