import os
import json
import math
import uuid
import tempfile
import itertools
from pathlib import Path
from typing import List
from types import SimpleNamespace

import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import eigh
from sklearn.decomposition import PCA

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from persistent_homology_gpu import run_ph_for_point_cloud
from analysis.stripe_analysis import run_and_save_stripe_analysis

from color_rules import (
    colour_quad_mul_f,
    colour_quad_mod_g,
    colour_quad_mod_g_no_fb,
    colour_quad_a_only,
    colour_quad_b_only,
    colour_c_mod_p,
    step_size,
    lines_a_mod_g_step,
    lines_b_mod_g_step,
    lines_c_mod_g_step,
    build_ro_scale,
    build_vi_scale,
)

import analysis.plane_fit
import analysis.comp_geo

try:
    from PyPDF2 import PdfMerger
except ImportError as e:
    PdfMerger = None
    _pdf2_err = e

# Shared typography (plotly font sizes, in px) used by the figures in this module.
FONT_SIZE = 18
CBAR_TICK_SIZE = 18
CBAR_TITLE_SIZE = 18
TICK_SIZE = 16

# Legend placement: anchored at its top-left corner at x=1.12, i.e. to the right of
# the plot area (plotly legend coordinates are normalised to the plot area by default),
# drawn on a translucent white box with a thin border.
LEGEND_POS = {
    "x": 1.12,
    "y": 1.02,
    "xanchor": "left",
    "yanchor": "top",
    "orientation": "v",
    "bgcolor": "rgba(255,255,255,0.65)",
    "bordercolor": "rgba(0,0,0,0.2)",
    "borderwidth": 1,
    "font": {"size": FONT_SIZE},
}


def _jitter_if_constant(arr: np.ndarray, eps: float = 1e-9) -> np.ndarray:
    if arr.size and np.allclose(arr, arr[0]):
        out = arr.astype(float).copy()
        out[0] = out[0] + eps
        return out
    return arr


def _sanitize_matrix(X: np.ndarray) -> np.ndarray:
    X = np.asarray(X, dtype=float)
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
    col_var = X.var(axis=0)
    keep = col_var > 0
    if keep.any():
        X = X[:, keep]
    else:
        X = np.zeros((X.shape[0], 1), dtype=float)
    X = X - X.mean(axis=0, keepdims=True)
    return X


def _safe_pca_coords(X: np.ndarray, want_components: int):
    """Fit PCA on a sanitised copy of ``X``, degrading gracefully on degenerate input.

    ``X`` is first passed through ``_sanitize_matrix``. The number of components is
    ``min(want_components, numerical_rank, n_rows - 1)``, where the numerical rank
    counts singular values above ``eps * max(n_rows, n_cols) * s_max``.

    If there are fewer than two rows, the SVD fails, or the component count
    comes out as 0, no PCA is fitted. Instead a zero coordinate matrix is
    returned together with a ``SimpleNamespace`` whose
    ``explained_variance_ratio_`` is all zeros.

    Returns:
        ``(coords, model)``: ``coords`` has one row per input row, and ``model``
        exposes ``explained_variance_ratio_``. A non-finite ratio vector on a
        fitted PCA is replaced by zeros.
    """
    data = _sanitize_matrix(X)
    n_rows, n_cols = data.shape

    def _degenerate(width):
        # Placeholder result mimicking the (coords, fitted-model) pair.
        return (
            np.zeros((n_rows, width), float),
            SimpleNamespace(explained_variance_ratio_=np.zeros(width, float)),
        )

    if n_rows < 2 or n_cols == 0:
        return _degenerate(1 if n_cols == 0 else min(1, want_components))

    try:
        singular = np.linalg.svd(data, full_matrices=False, compute_uv=False)
    except np.linalg.LinAlgError:
        return _degenerate(1)

    largest = singular[0] if singular.size else 0.0
    threshold = np.finfo(float).eps * max(n_rows, n_cols) * largest
    numerical_rank = int(np.count_nonzero(singular > threshold))

    n_keep = max(0, min(want_components, numerical_rank, n_rows - 1))
    if n_keep == 0:
        return _degenerate(1)

    pca = PCA(n_components=n_keep, svd_solver="full")
    projected = pca.fit_transform(data)
    ratios = getattr(pca, "explained_variance_ratio_", np.array([1.0]))
    if not np.all(np.isfinite(ratios)):
        pca.explained_variance_ratio_ = np.zeros(n_keep, float)
    return projected, pca


def compute_pca_coords(embedding_weights, num_components=17):
    """Return ``(coords, model)`` from a PCA of ``embedding_weights``.

    Thin wrapper that coerces the input to a NumPy array and delegates to
    ``_safe_pca_coords``. The number of returned components may be smaller than
    ``num_components`` when the data is rank-deficient or has few rows.
    """
    matrix = np.asarray(embedding_weights)
    return _safe_pca_coords(matrix, num_components)


def compute_diffusion_coords(
    embedding_weights: np.ndarray,
    num_coords: int = 5,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute diffusion-map coordinates for the rows of ``embedding_weights``.

    A Gaussian kernel ``A = exp(-d^2 / eps)`` is built from pairwise squared
    Euclidean distances. ``eps`` is the median squared distance, falling back to
    the mean positive squared distance, or ``1e-12`` if all points coincide.
    ``A`` is row-normalised into a Markov matrix ``M = D^-1 A``.

    ``M`` is not symmetric, so ``scipy.linalg.eigh`` cannot be applied to it
    directly: it reads only one triangle of its input. Instead the similar
    symmetric matrix ``S = D^-1/2 A D^-1/2`` is diagonalised. It has the same
    eigenvalues, and right eigenvectors of ``M`` are recovered as
    ``psi = D^-1/2 phi``.

    Args:
        embedding_weights: point cloud, one point per row.
        num_coords: number of non-trivial diffusion coordinates to return.

    Returns:
        ``(coords, eigenvalues)``. ``coords`` has shape ``(N, num_coords)`` and
        holds the right eigenvectors 1..num_coords of ``M``, each scaled by its
        eigenvalue. The trivial constant eigenvector 0 is skipped.
        ``eigenvalues`` holds all ``N`` eigenvalues of ``M`` in descending
        order; the first is 1 up to rounding.

    Raises:
        ValueError: if there are fewer than ``num_coords + 1`` points.
    """
    points = np.asarray(embedding_weights, dtype=float)
    n_coords = int(num_coords)

    d2 = squareform(pdist(points, metric="euclidean")) ** 2
    if d2.shape[0] < n_coords + 1:
        # Checked before the O(N^3) eigendecomposition; N eigenvalues are available.
        raise ValueError("Not enough eigenvalues to compute the requested diffusion coordinates.")

    eps = float(np.median(d2))
    if not np.isfinite(eps) or eps <= 0:
        pos = d2[d2 > 0]
        eps = float(pos.mean()) if pos.size else 1e-12

    A = np.exp(-d2 / eps)
    degree = A.sum(axis=1)  # >= 1 since the kernel diagonal is exp(0) = 1
    inv_sqrt_deg = 1.0 / np.sqrt(degree)
    S = A * inv_sqrt_deg[:, None] * inv_sqrt_deg[None, :]
    S = 0.5 * (S + S.T)  # remove rounding asymmetry before the symmetric solver

    eigenvalues, eigenvectors = eigh(S)
    # eigh returns ascending order; diffusion maps want the largest first.
    eigenvalues = eigenvalues[::-1]
    eigenvectors = eigenvectors[:, ::-1]

    right_vecs = eigenvectors * inv_sqrt_deg[:, None]  # eigenvectors of M itself
    coords = right_vecs[:, 1: n_coords + 1] * eigenvalues[1: n_coords + 1]
    return coords, eigenvalues


def make_json(
    freq_list: list[int] | None,
    var_ratio: list[float],
    cum_ratio: list[float],
    save_dir: str,
    extra: dict | None = None,
) -> None:
    os.makedirs(save_dir, exist_ok=True)
    data = {
        "freq_list": freq_list,
        "variance_ratio": var_ratio,
        "cumulative_variance_ratio": cum_ratio,
        "extra": extra,
    }
    out_path = os.path.join(save_dir, "variance_explained.json")
    with open(out_path, "w") as fh:
        json.dump(data, fh, indent=4)


def _make_hover(a_vals: np.ndarray, b_vals: np.ndarray) -> dict:
    custom = np.stack([a_vals, b_vals], axis=1)
    return dict(
        customdata=custom,
        hovertemplate="a=%{customdata[0]}<br>b=%{customdata[1]}<extra></extra>",
    )


def generate_new_diffusion_plot(embedding_weights, output_file, p):
    """Write a 4x4 grid of consecutive diffusion-coordinate scatter plots to HTML.

    Panel k (0-based, row-major) plots diffusion coordinate k+1 against k+2
    (1-based), so 17 coordinates are computed. Point i is coloured by ``i mod p``
    on a cyclic blue-red-blue scale over ``[0, p-1]``, so residues 0 and p-1
    share a hue. The HTML loads plotly.js from the CDN.
    """
    coords, _ = compute_diffusion_coords(embedding_weights, num_coords=17)
    grid = 4
    n_panels = grid * grid
    fig = make_subplots(
        rows=grid,
        cols=grid,
        subplot_titles=[f"Coord {k+1} vs {k+2}" for k in range(n_panels)],
    )

    residues = np.arange(coords.shape[0]) % p
    marker_style = {
        "color": residues,
        "colorscale": [(0.0, "blue"), (0.5, "red"), (1.0, "blue")],
        "cmin": 0,
        "cmax": p - 1,
        "size": 6,
    }

    for panel in range(n_panels):
        r, c = divmod(panel, grid)
        scatter = go.Scatter(
            x=coords[:, panel],
            y=coords[:, panel + 1],
            mode="markers",
            marker=marker_style,
        )
        fig.add_trace(scatter, row=r + 1, col=c + 1)

    fig.update_layout(
        height=1000,
        width=1000,
        title_text="New Diffusion Plot (16 coordinate pair plots)",
        showlegend=False,
    )
    fig.write_html(output_file, include_plotlyjs="cdn")
    print(f"New diffusion plot saved to {output_file}")


def create_2d_diffusion_figure(embedding_weights, color_values, title_text, p):
    """Return a 4x4 grid of 2-D scatter plots of consecutive diffusion coordinates.

    Panel k (0-based, row-major) plots diffusion coordinate k+1 against k+2
    (1-based), so 17 coordinates are computed. Markers are coloured by
    ``color_values`` on a blue-to-red scale over ``[0, p-1]``.

    Hover text shows ``a = index // p``, ``b = index % p`` and
    ``y = (a + b) mod p``. This presumes the points are ordered as
    ``index = a*p + b``; TODO confirm with callers.
    """
    coords, _ = compute_diffusion_coords(embedding_weights, num_coords=17)
    grid = 4
    n_panels = grid * grid
    fig = make_subplots(
        rows=grid,
        cols=grid,
        subplot_titles=[f"Coord {k+1} vs {k+2}" for k in range(n_panels)],
    )

    flat_idx = np.arange(coords.shape[0])
    a_vals, b_vals = np.divmod(flat_idx, p)
    y_vals = (a_vals + b_vals) % p
    hover_texts = [f"a={a}, b={b}, y={y}" for a, b, y in zip(a_vals, b_vals, y_vals)]

    marker_style = {
        "color": color_values,
        "colorscale": [(0.0, "blue"), (1.0, "red")],
        "cmin": 0,
        "cmax": p - 1,
        "size": 6,
    }

    for panel in range(n_panels):
        r, c = divmod(panel, grid)
        scatter = go.Scatter(
            x=coords[:, panel],
            y=coords[:, panel + 1],
            mode="markers",
            marker=marker_style,
            hovertext=hover_texts,
            hovertemplate="%{hovertext}<extra></extra>",
        )
        fig.add_trace(scatter, row=r + 1, col=c + 1)

    fig.update_layout(height=1000, width=1000, title_text=title_text, showlegend=False)
    return fig


def create_3d_diffusion_figure(embedding_weights, color_values, title_text, p):
    """Return a 3x5 grid of 3-D scatter plots of consecutive diffusion-coordinate triples.

    Panel k (0-based, row-major, k = 0..14) plots diffusion coordinates k+1,
    k+2 and k+3 (1-based), so 17 coordinates are computed. Each scene's axes are
    titled with the coordinate numbers. Markers are coloured by ``color_values``
    on a blue-to-red scale over ``[0, p-1]``.

    Hover text shows ``a = index // p``, ``b = index % p`` and
    ``y = (a + b) mod p``. This presumes ``index = a*p + b`` ordering; TODO
    confirm with callers.
    """
    coords, _ = compute_diffusion_coords(embedding_weights, num_coords=17)
    n_rows, n_cols = 3, 5
    n_panels = 15
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[f"Coords {k+1}-{k+3}" for k in range(n_panels)],
        specs=[[{"type": "scene"} for _ in range(n_cols)] for _ in range(n_rows)],
        horizontal_spacing=0.05,
        vertical_spacing=0.1,
    )

    flat_idx = np.arange(coords.shape[0])
    a_vals, b_vals = np.divmod(flat_idx, p)
    y_vals = (a_vals + b_vals) % p
    hover_texts = [f"a={a}, b={b}, y={y}" for a, b, y in zip(a_vals, b_vals, y_vals)]

    marker_style = {
        "size": 4,
        "color": color_values,
        "colorscale": [(0.0, "blue"), (1.0, "red")],
        "cmin": 0,
        "cmax": p - 1,
    }

    for panel in range(n_panels):
        r, c = divmod(panel, n_cols)
        scatter = go.Scatter3d(
            x=coords[:, panel],
            y=coords[:, panel + 1],
            z=coords[:, panel + 2],
            mode="markers",
            marker=marker_style,
            hovertext=hover_texts,
            hovertemplate="%{hovertext}<extra></extra>",
        )
        fig.add_trace(scatter, row=r + 1, col=c + 1)
        # Scenes are numbered row-major: "scene", "scene2", "scene3", ...
        scene = fig.layout["scene" if panel == 0 else f"scene{panel + 1}"]
        scene.xaxis.title = f"diff coord {panel + 1}"
        scene.yaxis.title = f"diff coord {panel + 2}"
        scene.zaxis.title = f"diff coord {panel + 3}"

    fig.update_layout(height=1200, width=1800, title_text=title_text, showlegend=False)
    return fig


def create_3d_pca_figure(embedding_weights, color_values, title_text, p):
    """Return a grid of 3-D scatter plots of consecutive principal-component triples.

    Up to 17 components are requested. With C available components there are
    C-2 panels; panel k shows PCs k+1, k+2 and k+3 (1-based). Panels are laid
    out row-major on at most 5 columns. Markers are coloured by ``color_values``
    on a blue-to-red scale over ``[0, p-1]``.

    Hover text shows ``a = index // p``, ``b = index % p`` and
    ``y = (a + b) mod p``. This presumes ``index = a*p + b`` ordering; TODO
    confirm with callers.

    Raises:
        ValueError: if fewer than 3 PCA components are available.
    """
    pcs, _ = compute_pca_coords(embedding_weights, num_components=17)
    n_components = pcs.shape[1]
    if n_components < 3:
        raise ValueError("Not enough PCA components to create a 3D plot.")

    n_panels = n_components - 2
    n_cols = min(5, n_panels)
    n_rows = -(-n_panels // n_cols)  # ceiling division

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[f"PCs {k+1}-{k+3}" for k in range(n_panels)],
        specs=[[{"type": "scene"} for _ in range(n_cols)] for _ in range(n_rows)],
        horizontal_spacing=0.05,
        vertical_spacing=0.1,
    )

    flat_idx = np.arange(pcs.shape[0])
    a_vals, b_vals = np.divmod(flat_idx, p)
    y_vals = (a_vals + b_vals) % p
    hover_texts = [f"a={a}, b={b}, y={y}" for a, b, y in zip(a_vals, b_vals, y_vals)]

    marker_style = {
        "size": 4,
        "color": color_values,
        "colorscale": [(0.0, "blue"), (1.0, "red")],
        "cmin": 0,
        "cmax": p - 1,
    }

    # Only the first n_panels cells are filled; any trailing grid cells stay empty.
    for panel in range(n_panels):
        r, c = divmod(panel, n_cols)
        scatter = go.Scatter3d(
            x=pcs[:, panel],
            y=pcs[:, panel + 1],
            z=pcs[:, panel + 2],
            mode="markers",
            marker=marker_style,
            hovertext=hover_texts,
            hovertemplate="%{hovertext}<extra></extra>",
        )
        fig.add_trace(scatter, row=r + 1, col=c + 1)
        scene = fig.layout["scene" if panel == 0 else f"scene{panel + 1}"]
        scene.xaxis.title = f"PCA coord {panel + 1}"
        scene.yaxis.title = f"PCA coord {panel + 2}"
        scene.zaxis.title = f"PCA coord {panel + 3}"

    fig.update_layout(height=1200, width=1800, title_text=title_text, showlegend=False)
    return fig


def generate_diffusion_map_figure(embedding_weights, epoch, p, f_multiplier=1, diffusion_coords=None):
    """Build a 2x2 grid of consecutive diffusion-coordinate scatter plots for one epoch.

    Panels show coordinate pairs (1,2), (2,3), (3,4) and (4,5) (1-based), so
    ``diffusion_coords`` needs at least five columns. That is the default output
    of ``compute_diffusion_coords``, which is called when ``diffusion_coords`` is
    None.

    Points are coloured on a cyclic blue-red-blue scale over ``[0, p-1]`` by:
      * ``(f_multiplier * index) mod p`` when there are exactly ``p`` points;
      * ``(a + b) mod p`` with ``a = index // p``, ``b = index % p`` when there
        are ``p*p`` points;
      * 0 otherwise.

    A single colour bar, attached to the first panel, is titled with the rule
    actually used.
    """
    if diffusion_coords is None:
        diffusion_coords, _ = compute_diffusion_coords(embedding_weights)

    num_points = diffusion_coords.shape[0]
    indices = np.arange(num_points)
    if num_points == p:
        labels = (f_multiplier * indices) % p
        colorbar_title = "(f * index) mod p"
    elif num_points == p * p:
        labels = (indices // p + indices % p) % p
        colorbar_title = "(a + b) mod p"
    else:
        labels = np.zeros(num_points)
        colorbar_title = "no label"

    pairs = ((0, 1), (1, 2), (2, 3), (3, 4))
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=tuple(f"Coordinate {i + 1} vs {j + 1}" for i, j in pairs),
    )

    marker_args = dict(
        color=labels,
        colorscale=[(0.0, "blue"), (0.5, "red"), (1.0, "blue")],
        cmin=0,
        cmax=p - 1,
        size=8,
        colorbar=dict(title=colorbar_title),
    )

    for panel, (i, j) in enumerate(pairs):
        row, col = divmod(panel, 2)
        fig.add_trace(
            go.Scatter(
                x=diffusion_coords[:, i],
                y=diffusion_coords[:, j],
                mode="markers",
                # scatter marker.showscale defaults to False, which hid the colour bar.
                # All panels share one mapping, so draw it once.
                marker=dict(marker_args, showscale=(panel == 0)),
            ),
            row=row + 1,
            col=col + 1,
        )
        fig.update_xaxes(title_text=f"Diffusion Coordinate {i + 1}", row=row + 1, col=col + 1)
        fig.update_yaxes(title_text=f"Diffusion Coordinate {j + 1}", row=row + 1, col=col + 1)

    fig.update_layout(
        height=800,
        width=800,
        title_text=f"Diffusion Map (Epoch {epoch}, f_multiplier={f_multiplier})",
        showlegend=False,
    )
    return fig


def generate_interactive_diffusion_map_html(epoch_embedding_log, output_file, p, f_multiplier=1):
    """Write an animated HTML diffusion map with one frame per logged epoch.

    Each epoch's embedding is turned into a 2x2 diffusion-map figure via
    ``generate_diffusion_map_figure``. The earliest epoch is the initial view,
    and a Play button plus an epoch slider animate through the rest. Each frame
    carries its own title, so the displayed epoch follows the animation.

    Args:
        epoch_embedding_log: mapping ``epoch -> embedding weights`` (anything
            ``np.array`` accepts). Epochs are played in sorted key order.
        output_file: destination HTML path; plotly.js is loaded from the CDN.
        p: modulus forwarded to ``generate_diffusion_map_figure``.
        f_multiplier: forwarded to ``generate_diffusion_map_figure``.

    Raises:
        ValueError: if ``epoch_embedding_log`` is empty.
    """
    sorted_epochs = sorted(epoch_embedding_log.keys())
    if not sorted_epochs:
        raise ValueError("epoch_embedding_log is empty; no epochs to plot.")

    frames = []
    base_fig = None
    for epoch in sorted_epochs:
        emb_weights = np.array(epoch_embedding_log[epoch])
        diff_coords, _ = compute_diffusion_coords(emb_weights)
        fig_epoch = generate_diffusion_map_figure(
            emb_weights, epoch, p, f_multiplier=f_multiplier, diffusion_coords=diff_coords
        )
        if base_fig is None:
            # The earliest epoch is also the initial view: reuse its figure instead of
            # recomputing the (O(N^3)) diffusion map a second time.
            base_fig = fig_epoch
        frames.append(
            go.Frame(
                data=fig_epoch.data,
                name=str(epoch),
                # Without a layout the title stayed on the first epoch during playback.
                layout={"title": {"text": fig_epoch.layout.title.text}},
            )
        )
        print(f"Made diffusion plot for epoch {epoch} (f_multiplier={f_multiplier}).")

    slider_steps = [
        dict(
            label=str(epoch),
            method="animate",
            args=[
                [str(epoch)],
                {"mode": "immediate", "frame": {"duration": 300, "redraw": True}, "transition": {"duration": 200}},
            ],
        )
        for epoch in sorted_epochs
    ]

    base_fig.update_layout(
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                y=1,
                x=1.1,
                xanchor="right",
                yanchor="top",
                pad={"t": 0, "r": 10},
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[
                            None,
                            {
                                "frame": {"duration": 300, "redraw": True},
                                "fromcurrent": True,
                                "transition": {"duration": 200},
                            },
                        ],
                    )
                ],
            )
        ],
        sliders=[dict(active=0, currentvalue={"prefix": "Epoch: "}, pad={"t": 50}, steps=slider_steps)],
    )

    base_fig.frames = frames
    base_fig.write_html(output_file, include_plotlyjs="cdn")
    print(f"Interactive diffusion map saved to {output_file}")


def _split_dualscale_views(coords: np.ndarray, colour: np.ndarray, p_cbar: int):
    """Split points into two colour views around the midpoint ``h = p_cbar // 2``.

    * Orange view: points with ``colour >= h``. Their colour is shifted down by
      ``h`` and they use ``build_ro_scale(h)``.
    * Viridis view: all other points, keeping their colour and using
      ``build_vi_scale(h)``. Because ``NaN >= h`` is False, NaN colours land here.

    In each view, points belonging to the other view have their coordinates and
    colour replaced by NaN. The input arrays are not modified.

    Returns:
        ``((coords_orange, colour_orange, h, scale_orange),
        (coords_viridis, colour_viridis, h, scale_viridis))``.
    """
    half = int(p_cbar // 2)
    upper = colour >= half
    lower = ~upper

    def _masked(mask):
        # Float copy so NaN can be written into rows outside the view.
        pts = coords.copy().astype(float)
        pts[~mask, :] = np.nan
        return pts

    orange_view = (_masked(upper), np.where(upper, colour - half, np.nan), half, build_ro_scale(half))
    viridis_view = (_masked(lower), np.where(lower, colour, np.nan), half, build_vi_scale(half))
    return orange_view, viridis_view


def _write_multiplot_2d(
    coords: np.ndarray,
    colour: np.ndarray,
    ctitle: str,
    out_path: str,
    p: int,
    p_cbar: int,
    colorscale: str,
    seed,
    label: str,
    tag: str,
) -> None:
    """Render every pairwise 2-D projection of ``coords`` into a multi-page PDF.

    Column pairs ``(i, j)`` with ``i < j`` are laid out 32 per page on a
    4-column grid. Each page is written to a temporary PDF with
    ``fig.write_image`` and also saved as ``<stem>_page<N>.html`` next to
    ``out_path``. The temporary PDFs are concatenated into ``out_path`` with
    PyPDF2's ``PdfMerger`` and always removed afterwards, even on failure.

    Hover data shows ``a = index // side`` and ``b = index % side`` with
    ``side = isqrt(n_points)``. This presumes the points lie row-major on a
    side x side grid; TODO confirm with callers.

    Args:
        coords: point coordinates. 1-D input, a single row, or a single column
            is reshaped and zero-padded to two columns.
        colour: per-point values mapped onto ``colorscale`` over ``[0, p_cbar - 1]``.
        ctitle: colour-bar title.
        out_path: destination PDF path.
        p: not used by this function; kept for signature compatibility.
        p_cbar: size of the colour range (ticks every ``p_cbar // 10``, plus the top value).
        colorscale: plotly colour scale for the markers.
        seed, label, tag: used only in titles and axis labels.

    Raises:
        ValueError: if ``coords`` cannot be brought to 2-D with >= 2 columns.
        ImportError: if PyPDF2 could not be imported.
    """
    if coords.ndim == 1:
        coords = coords.reshape(-1, 1)
    if coords.ndim == 2 and coords.shape[0] == 1 and coords.shape[1] > 1:
        coords = coords.T
    if coords.ndim == 2 and coords.shape[1] == 1:
        coords = np.hstack([coords, np.zeros((coords.shape[0], 1), dtype=coords.dtype)])
    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError("coords must be 2-D with at least two columns after padding.")
    if PdfMerger is None:
        raise ImportError("PyPDF2 is required for PDF concatenation but could not be imported.") from _pdf2_err

    pairs = list(itertools.combinations(range(coords.shape[1]), 2))
    per_page = 32
    n_cols = 4
    n_pages = math.ceil(len(pairs) / per_page)

    # Loop-invariant hover data and colour-bar ticks, computed once.
    n_pts = coords.shape[0]
    side = int(math.isqrt(n_pts))
    indices = np.arange(n_pts)
    hover_kw = _make_hover(indices // side, indices % side)

    step = max(1, p_cbar // 10)
    tickvals = list(range(0, p_cbar, step))
    if not tickvals or tickvals[-1] != p_cbar - 1:
        tickvals.append(p_cbar - 1)
    ticktext = [str(v) for v in tickvals]

    # splitext (not str.replace(".pdf", ...)): otherwise a path without a lowercase
    # ".pdf" made every page's HTML overwrite out_path itself.
    stem = os.path.splitext(os.path.basename(out_path))[0]
    out_dir = os.path.dirname(out_path)
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    tmp_files: List[str] = []
    try:
        for page in range(n_pages):
            page_pairs = pairs[page * per_page: (page + 1) * per_page]
            n_rows = max(1, math.ceil(len(page_pairs) / n_cols))

            fig = make_subplots(
                rows=n_rows,
                cols=n_cols,
                subplot_titles=[f"{label}{i} vs {label}{j}" for i, j in page_pairs],
                horizontal_spacing=0.04,
                vertical_spacing=0.06,
            )

            for k, (i, j) in enumerate(page_pairs, 1):
                r, c = 1 + (k - 1) // n_cols, 1 + (k - 1) % n_cols
                fig.add_trace(
                    go.Scatter(
                        x=_jitter_if_constant(coords[:, i]),
                        y=_jitter_if_constant(coords[:, j]),
                        mode="markers",
                        name="",
                        showlegend=False,
                        marker=dict(
                            size=4,
                            color=colour,
                            colorscale=colorscale,
                            cmin=0,
                            cmax=p_cbar - 1,
                            line=dict(width=0),
                            showscale=(k == 1),  # one shared colour bar per page
                            colorbar=dict(
                                title=dict(text=ctitle, side="right", font=dict(size=CBAR_TITLE_SIZE)),
                                tickvals=tickvals,
                                ticktext=ticktext,
                                tickfont=dict(size=CBAR_TICK_SIZE),
                                len=0.90,
                            ),
                        ),
                        **hover_kw,
                    ),
                    row=r,
                    col=c,
                )
                fig.update_xaxes(title_text=f"{label}{i}", row=r, col=c)
                fig.update_yaxes(title_text=f"{label}{j}", row=r, col=c)

            fig.update_layout(
                width=1400,
                height=250 * n_rows + 100,
                title=f"{label} 2-D - seed {seed} - page {page + 1}/{n_pages} - {tag}",
                margin=dict(l=40, r=40, t=80, b=40),
                font=dict(size=FONT_SIZE),
            )
            # Subplot titles are annotations; give them the same font size.
            fig.update_annotations(font=dict(size=FONT_SIZE))
            fig.update_xaxes(title_font=dict(size=FONT_SIZE), tickfont=dict(size=TICK_SIZE))
            fig.update_yaxes(title_font=dict(size=FONT_SIZE), tickfont=dict(size=TICK_SIZE))

            tmp_pdf = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.pdf")
            # Register before writing so a failed or partial write is still cleaned up.
            tmp_files.append(tmp_pdf)
            fig.write_image(tmp_pdf, format="pdf")
            fig.write_html(os.path.join(out_dir, f"{stem}_page{page + 1}.html"), include_plotlyjs="cdn")

        merger = PdfMerger()
        try:
            for pdf in tmp_files:
                merger.append(pdf)
            merger.write(out_path)
        finally:
            merger.close()
    finally:
        for pdf in tmp_files:
            try:
                os.remove(pdf)
            except OSError:
                pass  # best-effort cleanup; the file may never have been written

    print(f"[{label} 2-D] -> {out_path}")


def _build_value_filter_js() -> str:
    return r"""
(function(){
  function makeGlobalFilter(gd){
    const data = gd.data;

    function isPointTrace(tr){
      return tr && (tr.type==='scatter' || tr.type==='scatter3d')
             && tr.marker && Array.isArray(tr.marker.color);
    }

    const pointIdx = [];
    for(let i=0;i<data.length;i++) if(isPointTrace(data[i])) pointIdx.push(i);
    if(!pointIdx.length) return;

    const N = data[pointIdx[0]].marker.color.length;

    const views = {};
    function pushView(key, idx){
      if(!views[key]){
        const vals = data[idx].marker.color.slice();
        const uniq = Array.from(new Set(vals)).sort((a,b)=>a-b);
        views[key] = {idxs:[idx], vals, uniq};
      }else{
        views[key].idxs.push(idx);
      }
    }
    pointIdx.forEach(i=>{
      const tr = data[i];
      const lg = (tr.legendgroup || '').toLowerCase();
      const nm = (tr.name || '').toLowerCase();
      if(['colour by a','colour by b','colour by c','colour by d'].includes(lg)){
        pushView(tr.legendgroup, i);
      }else if(nm.endsWith('mod g') || nm.startsWith('colour by ')){
        if(nm.startsWith('colour by ')) pushView(tr.name, i);
        else if(nm.startsWith('a '))     pushView('colour by a', i);
        else if(nm.startsWith('b '))     pushView('colour by b', i);
        else if(nm.startsWith('c '))     pushView('colour by c', i);
        else if(nm.startsWith('d '))     pushView('colour by d', i);
      }
    });
    if(!Object.keys(views).length){
      const idxs = pointIdx.slice();
      const vals = data[idxs[0]].marker.color.slice();
      const uniq = Array.from(new Set(vals)).sort((a,b)=>a-b);
      views['_ALL_'] = {idxs, vals, uniq};
    }

    const keep = new Array(N).fill(true);

    function ensureOrig(tr){
        if(!tr._orig){
            tr._orig = {
            x:(tr.x||[]).slice(),
            y:(tr.y||[]).slice(),
            z: tr.z ? tr.z.slice() : null,
            c: tr.marker.color.slice(),
            cd: tr.customdata ? tr.customdata.slice() : null
            };
        }
    }

    pointIdx.forEach(i=>ensureOrig(data[i]));

    function applyMask(){
        const keepIdx = [];
        for(let i=0;i<N;i++) if(keep[i]) keepIdx.push(i);
        pointIdx.forEach(idx=>{
            const tr = data[idx], o = tr._orig;
            const upd = {
            x:[keepIdx.map(i=>o.x[i])],
            y:[keepIdx.map(i=>o.y[i])],
            'marker.color':[keepIdx.map(i=>o.c[i])]
            };
            if(o.z)  upd.z  = [keepIdx.map(i=>o.z[i])];
            if(o.cd) upd.customdata = [keepIdx.map(i=>o.cd[i])];

            Plotly.restyle(gd, upd, [idx]);
        });
    }

    function currentView(){
      for(const [name, view] of Object.entries(views)){
        for(const idx of view.idxs){
          const v = gd.data[idx].visible;
          if(v===true || v===undefined) return name;
        }
      }
      return Object.keys(views)[0];
    }

    function ensurePanel(){
      if(gd._vf_panel) return gd._vf_panel;
      const host = gd.parentNode; host.style.position='relative';
      const panel = document.createElement('div');
      panel.style.cssText = 'position:absolute;bottom:8px;right:20px;'
        +'background:rgba(255,255,255,0.94);border:1px solid #ccc;border-radius:8px;'
        +'padding:8px 10px;max-width:320px;max-height:220px;overflow:auto;'
        +'font:12px system-ui,sans-serif;box-shadow:0 2px 6px rgba(0,0,0,0.12);z-index:10;';
      const title = document.createElement('div'); title.style.cssText='font-weight:600;margin-bottom:6px';
      const toolbar = document.createElement('div'); toolbar.style.cssText='display:flex;gap:6px;margin-bottom:6px';
      const btnReset = document.createElement('span');
      btnReset.textContent='Reset';
      btnReset.style.cssText='padding:2px 8px;border:1px solid #999;border-radius:10px;cursor:pointer;background:#f5f5f5';
      btnReset.onclick = ()=>{ for(let i=0;i<N;i++) keep[i]=true; applyMask(); refresh(); };
      const btnHideAll = document.createElement('span');
      btnHideAll.textContent='Hide all';
      btnHideAll.style.cssText='padding:2px 8px;border:1px solid #999;border-radius:10px;cursor:pointer;background:#f5f5f5';
      btnHideAll.onclick = ()=>{ for(let i=0;i<N;i++) keep[i]=false; applyMask(); refresh(); };
      toolbar.appendChild(btnReset); toolbar.appendChild(btnHideAll);
      const row = document.createElement('div'); row.style.cssText='display:flex;flex-wrap:wrap;gap:6px';
      panel.appendChild(title); panel.appendChild(toolbar); panel.appendChild(row);
      host.appendChild(panel);
      gd._vf_panel = {panel,row,title};
      return gd._vf_panel;
    }

    function countsForValue(view, v){
      let tot=0, vis=0;
      for(let i=0;i<N;i++){
        if(view.vals[i]===v){ tot++; if(keep[i]) vis++; }
      }
      return {vis, tot};
    }

    function chipsFor(name){
      const view = views[name];
      const {row, title} = ensurePanel();
      row.innerHTML=''; title.textContent = `Filter by value - ${name}`;
      view.uniq.forEach(val=>{
        const {vis, tot} = countsForValue(view, val);
        const chip = document.createElement('span');
        chip.textContent = `${val} (${vis}/${tot})`;
        chip.style.cssText='padding:2px 8px;border:1px solid #999;border-radius:12px;cursor:pointer;user-select:none;';
        if(vis===0) chip.style.background='#ddd';
        else if(vis<tot) chip.style.background='#fff6cc';
        chip.onclick = ()=>{
          const hide = (vis>0);
          for(let i=0;i<N;i++) if(view.vals[i]===val) keep[i] = !hide;
          applyMask();
          const c2 = countsForValue(view, val);
          chip.textContent = `${val} (${c2.vis}/${c2.tot})`;
          chip.style.background = (c2.vis===0) ? '#ddd' : (c2.vis<c2.tot ? '#fff6cc' : '');
        };
        row.appendChild(chip);
      });
    }
    function refresh(){ chipsFor(currentView()); }

    applyMask(); refresh();
    gd.on('plotly_restyle', d=>{
      if(d && d[0] && ('visible' in d[0])){ applyMask(); refresh(); }
    });
    gd.on('plotly_relayout', ()=>{ refresh(); });
  }
  var gd = document.getElementsByClassName('plotly-graph-div')[0];
  if(gd) makeGlobalFilter(gd);
})();
"""


def _mp3d_colorbar(ctitle: str) -> dict:
    """Return the colour-bar settings shared by every marker trace of the 3-D multiplot."""
    return dict(
        title=dict(text=ctitle, side="right", font=dict(size=CBAR_TITLE_SIZE)),
        tickfont=dict(size=CBAR_TICK_SIZE),
        len=0.90,
    )


def _mp3d_closed_loop(idx_ordered) -> np.ndarray:
    """Close an ordered index path (more than two points) by re-appending its first index."""
    idx_plot = np.asarray(idx_ordered)
    if idx_plot.size > 2:
        idx_plot = np.concatenate([idx_plot, idx_plot[:1]])
    return idx_plot


def _mp3d_plane_traces(coords, rot_mask, ref_mask, dims, opacity, legend_seen):
    """Build the hidden fitted-plane Surface traces for one subplot.

    A "plane (rot)" trace (blue) and a "plane (ref)" trace (orange) are
    produced, each only when its mask selects at least 3 points and
    ``analysis.plane_fit._plane_mesh_from_points`` returns a mesh.
    ``legend_seen`` (mutated) makes each plane name appear in the legend once.
    """
    traces = []
    for mask, name, rgb, fill in (
        (rot_mask, "plane (rot)", "rgb(0,120,255)", np.zeros_like),
        (ref_mask, "plane (ref)", "rgb(255,120,0)", np.ones_like),
    ):
        if np.count_nonzero(mask) < 3:
            continue
        mesh = analysis.plane_fit._plane_mesh_from_points(coords[mask], dims)
        if mesh is None:
            continue
        X, Y, Z = mesh
        traces.append(
            go.Surface(
                x=X,
                y=Y,
                z=Z,
                surfacecolor=fill(X),
                cmin=0,
                cmax=1,
                colorscale=[(0.0, rgb), (1.0, rgb)],
                showscale=False,
                opacity=opacity,
                name=name,
                legendgroup="planes",
                showlegend=name not in legend_seen,
                visible=False,
            )
        )
        legend_seen.add(name)
    return traces


def _mp3d_plane_menu(plane_idx, plane_names) -> dict:
    """Build a button bar that restyles ONLY the plane traces at ``plane_idx``.

    Targeting trace indices leaves the visibility of every other trace (e.g. the
    currently selected colour mode) untouched.
    """

    def _button(btn_label, show):
        return dict(
            label=btn_label,
            method="restyle",
            args=[{"visible": [show(name) for name in plane_names]}, list(plane_idx)],
        )

    return dict(
        type="buttons",
        direction="right",
        x=0.01,
        y=1.08,
        xanchor="left",
        yanchor="top",
        buttons=[
            _button("Plane: Off", lambda name: False),
            _button("Plane: rot", lambda name: name == "plane (rot)"),
            _button("Plane: ref", lambda name: name == "plane (ref)"),
            _button("Plane: both", lambda name: True),
        ],
        pad={"t": 0, "r": 6},
    )


def _write_multiplot_3d(
    coords: np.ndarray,
    colour: np.ndarray,
    ctitle: str,
    out_path: str,
    p: int,
    p_cbar: int,
    colorscale: str,
    seed,
    label: str,
    tag: str,
    f: int,
    mult: bool,
    write_pdf: bool = True,
    show_fit_plane: bool = False,
    plane_split_mode: str = "c",
    plane_opacity: float = 0.35,
):
    """Write a grid of 3-D scatter plots of ``coords`` to HTML (and optionally PDF).

    One subplot is drawn per coordinate triplet taken from the first (at most)
    four columns: a single (0, 1, 2) view for 3 columns, otherwise a 2x2 grid.
    Inputs with fewer than 3 columns are zero-padded.

    Point ``n`` is labelled with ``a = n // side`` and ``b = n % side`` where
    ``side = isqrt(n_points)``.

    Multi-view mode (``mult`` truthy and ``side == 2 * p``):
        * Three marker layers are drawn per subplot: "a mod g", "b mod g" and
          "c mod g", with ``g = p // gcd(p, f)``.
        * A dropdown switches between them, together with their guide lines
          (lines are only drawn when ``g != p``).
    Otherwise the points are coloured by ``colour`` using ``colorscale`` and
    ``[0, p_cbar - 1]``.

    When ``show_fit_plane`` is set and ``side == 2 * p``, fitted-plane surfaces
    are added hidden per subplot. A button bar then toggles only those surfaces.

    Args:
        coords: point coordinates, indexed as ``coords[:, column]``.
        colour: per-point colour values for the single-view mode.
        ctitle: colour-bar title.
        out_path: PDF path. The HTML is written next to it with a ``.html``
            extension.
        p, f: modulus and frequency used for ``g`` and the multi-view colouring.
        p_cbar: single-view colour range upper bound (``cmax = p_cbar - 1``).
        colorscale: single-view Plotly colour scale.
        seed, tag: included in the figure title.
        label: axis / subplot title prefix (e.g. "PC", "DM").
        mult: request the multi-view mode.
        write_pdf: also export a static PDF to ``out_path``.
        show_fit_plane, plane_split_mode, plane_opacity: fitted-plane options.
    """
    n_pts = coords.shape[0]
    n_dim = coords.shape[1]
    if n_dim < 3:
        # Pad with zero columns so at least the (0, 1, 2) triplet exists.
        coords = np.pad(coords, ((0, 0), (0, 3 - n_dim)), mode="constant")
        n_dim = 3

    g = p // math.gcd(p, f) or p
    side = int(math.isqrt(n_pts))
    multi_view = (side == 2 * p) and bool(mult)
    use_dims = min(n_dim, 4)
    if use_dims == 3:
        triplets = [(0, 1, 2)]
        nrows, ncols = 1, 1
    else:
        triplets = list(itertools.combinations(range(use_dims), 3))[:4]
        nrows, ncols = 2, 2

    fig = make_subplots(
        rows=nrows,
        cols=ncols,
        specs=[[{"type": "scene"} for _ in range(ncols)] for _ in range(nrows)],
        subplot_titles=[f"{label}{i} vs {label}{j} vs {label}{k}" for i, j, k in triplets],
        horizontal_spacing=0.03,
        vertical_spacing=0.03,
    )

    idxs = np.arange(n_pts)
    a_vals = idxs // side
    b_vals = idxs % side
    hover_kw = _make_hover(a_vals, b_vals)

    can_planes = False
    rot_mask = ref_mask = None
    if show_fit_plane and side == 2 * p:
        try:
            rot_mask, ref_mask = analysis.plane_fit._split_masks(a_vals, b_vals, p, plane_split_mode, tag_q="full")
            can_planes = True
        except Exception as exc:  # planes are optional decoration: warn, keep plotting
            print(f"[{label} 3-D] plane split failed ({exc!r}); fitted planes disabled")

    # roles[t] describes fig.data[t], so the menus can address traces by index:
    # "points" (single view), "a"/"b"/"c" (multi-view markers),
    # "line_a"/"line_b"/"line_c" (multi-view guide lines) or "plane".
    roles: list[str] = []

    def _add(trace, role, row, col):
        fig.add_trace(trace, row=row, col=col)
        roles.append(role)

    marker_specs = []
    line_specs = {"a": [], "b": [], "c": []}
    line_legend_seen = {"a": set(), "b": set(), "c": set()}
    if multi_view:
        col_a, _, _, cs_a = colour_quad_a_only(a_vals, b_vals, p, f, "full")
        col_b, _, _, cs_b = colour_quad_b_only(a_vals, b_vals, p, f, "full")
        col_c, _, pcbar_c, cs_c = colour_quad_mod_g(a_vals, b_vals, p, f, "full")
        # (mode, colours, colour scale, cmax, trace name, initially visible)
        marker_specs = [
            ("a", np.asarray(col_a), cs_a, 2 * g - 1, "a mod g", True),
            ("b", np.asarray(col_b), cs_b, 2 * g - 1, "b mod g", False),
            ("c", np.asarray(col_c), cs_c, pcbar_c - 1, "c mod g", False),
        ]
        d = step_size(f, p)
        if g != p:
            line_specs = {
                "a": lines_a_mod_g_step(a_vals, b_vals, p, g, d),
                "b": lines_b_mod_g_step(a_vals, b_vals, p, g, d),
                "c": lines_c_mod_g_step(a_vals, b_vals, p, g, d),
            }

    plane_legend_seen: set = set()
    for s_idx, (i, j, k) in enumerate(triplets, 1):
        row, col = (1, s_idx) if s_idx <= 2 else (2, s_idx - 2)
        x = _jitter_if_constant(coords[:, i])
        y = _jitter_if_constant(coords[:, j])
        z = coords[:, k]

        if not multi_view:
            _add(
                go.Scatter3d(
                    x=x,
                    y=y,
                    z=z,
                    mode="markers",
                    marker=dict(
                        size=3,
                        color=colour,
                        colorscale=colorscale,
                        cmin=0,
                        cmax=p_cbar - 1,
                        showscale=(s_idx == 1),
                        colorbar=_mp3d_colorbar(ctitle),
                    ),
                    **hover_kw,
                ),
                "points",
                row,
                col,
            )
        else:
            for mode, m_colour, m_scale, m_cmax, m_name, m_visible in marker_specs:
                _add(
                    go.Scatter3d(
                        x=x,
                        y=y,
                        z=z,
                        mode="markers",
                        marker=dict(
                            size=3,
                            color=m_colour,
                            colorscale=m_scale,
                            cmin=0,
                            cmax=m_cmax,
                            showscale=(s_idx == 1),
                            colorbar=_mp3d_colorbar(ctitle),
                        ),
                        name=m_name,
                        legendgroup=mode,
                        showlegend=(s_idx == 1),  # one legend entry per layer, not per subplot
                        visible=m_visible,
                        **hover_kw,
                    ),
                    mode,
                    row,
                    col,
                )
            for mode in ("a", "b", "c"):
                seen = line_legend_seen[mode]
                for idx_ordered, dash, color, gid in line_specs[mode]:
                    idx_plot = _mp3d_closed_loop(idx_ordered)
                    _add(
                        go.Scatter3d(
                            x=coords[idx_plot, i],
                            y=coords[idx_plot, j],
                            z=coords[idx_plot, k],
                            mode="lines",
                            name=gid,
                            legendgroup=gid,
                            showlegend=gid not in seen,
                            line=dict(color=color, dash=dash, width=1.2),
                            hoverinfo="skip",
                            visible=(mode == "a"),
                        ),
                        f"line_{mode}",
                        row,
                        col,
                    )
                    seen.add(gid)

        scene = fig.layout[f"scene{s_idx if s_idx > 1 else ''}"]
        scene.xaxis.title.text = f"{label}{i}"
        scene.yaxis.title.text = f"{label}{j}"
        scene.zaxis.title.text = f"{label}{k}"

        if can_planes:
            for trace in _mp3d_plane_traces(coords, rot_mask, ref_mask, (i, j, k), plane_opacity, plane_legend_seen):
                _add(trace, "plane", row, col)

    menus = []
    if multi_view:
        # Restrict the colour dropdown to non-plane traces so it never hides or
        # reveals planes, and plane toggles never reset the colour mode.
        body_idx = [t for t, role in enumerate(roles) if role != "plane"]

        def _mode_visibility(mode):
            return [roles[t] in (mode, f"line_{mode}") for t in body_idx]

        menus.append(
            dict(
                buttons=[
                    dict(
                        label="a mod g",
                        method="update",
                        args=[{"visible": _mode_visibility("a")}, {"title": "colour = a mod g"}, body_idx],
                    ),
                    dict(
                        label="b mod g",
                        method="update",
                        args=[{"visible": _mode_visibility("b")}, {"title": "colour = b mod g"}, body_idx],
                    ),
                    dict(
                        label="c mod g",
                        method="update",
                        args=[{"visible": _mode_visibility("c")}, {"title": "colour = c mod g"}, body_idx],
                    ),
                ],
                direction="down",
                x=0.99,
                y=1.05,
                xanchor="left",
                pad={"t": 0, "r": 6},
            )
        )

    plane_idx = [t for t, role in enumerate(roles) if role == "plane"]
    if plane_idx:
        # Added exactly once, in both single- and multi-view modes.
        menus.append(_mp3d_plane_menu(plane_idx, [fig.data[t].name for t in plane_idx]))
    if menus:
        fig.update_layout(updatemenus=menus)
    if multi_view:
        fig.update_layout(legend=LEGEND_POS)

    for layout_key in fig.layout:
        if str(layout_key).startswith("scene"):
            scene = fig.layout[layout_key]
            scene.xaxis.title.font = dict(size=FONT_SIZE)
            scene.yaxis.title.font = dict(size=FONT_SIZE)
            scene.zaxis.title.font = dict(size=FONT_SIZE)
            scene.xaxis.tickfont = dict(size=TICK_SIZE)
            scene.yaxis.tickfont = dict(size=TICK_SIZE)
            scene.zaxis.tickfont = dict(size=TICK_SIZE)

    fig.update_layout(
        width=1600,
        height=1200,
        title=f"{label} 3-D (first 4) - seed {seed} - {tag}",
        margin=dict(l=40, r=40, t=80, b=40),
        font=dict(size=FONT_SIZE),
    )

    p_out = Path(out_path)
    p_out.parent.mkdir(parents=True, exist_ok=True)
    if write_pdf:
        fig.write_image(out_path, format="pdf")

    fig.write_html(
        os.path.splitext(out_path)[0] + ".html",
        include_plotlyjs="cdn",
        full_html=True,
        post_script=_build_value_filter_js(),
    )
    print(f"[{label} 3-D] -> {out_path}")


def save_homology_artifacts(
    coords: np.ndarray,
    root_dir: str,
    tag: str,
    seed,
    label: str,
    num_dims: int | None = 2,
) -> None:
    subdir = os.path.join(root_dir, "homology", tag)
    stem = f"{label.lower()}_seed_{seed}"

    if coords.ndim == 1:
        coords = coords.reshape(-1, 1)
    d_avail = coords.shape[1]

    if num_dims is None:
        keep_dims = d_avail
    else:
        if num_dims < 1:
            raise ValueError("num_dims must be >= 1")
        keep_dims = min(num_dims, d_avail)
        if num_dims > d_avail:
            print(f"[PH] requested num_dims={num_dims} > available={d_avail}; falling back to {keep_dims}")

    coords_to_use = coords[:, :keep_dims]
    n_nbrs = 300 if keep_dims == 2 else 150

    run_ph_for_point_cloud(
        coords_to_use,
        maxdim=2,
        ph_sparse=True,
        n_nbrs=n_nbrs,
        save_dir=subdir,
        filename_stem=stem,
        title=f"{label} (seed={seed})",
    )


def _make_single_freq_phase_plots(
    mat: np.ndarray,
    p: int,
    f: int,
    save_dir: str,
    *,
    seed: int | str = "",
    tag: str = "",
    colour_scale: str = "Viridis",
    eps: float = 0.16,
) -> None:
    f = int(f) % p
    if f == 0:
        print("[phase-plots] f == 0 (mod p) - skipped.")
        return

    n_neurons = mat.shape[1]
    amps = np.empty(n_neurons)
    phi_a = np.empty(n_neurons)
    phi_b = np.empty(n_neurons)

    for n in range(n_neurons):
        grid = mat[:, n].reshape(p, p).T
        F = np.fft.fft2(grid) / (p * p)
        ca, cb = F[f, 0], F[0, f]
        amps[n] = np.hypot(2 * np.abs(ca), 2 * np.abs(cb))
        phi_a[n] = (-np.angle(ca)) % (2 * np.pi)
        phi_b[n] = (-np.angle(cb)) % (2 * np.pi)

    unpicked = set(range(n_neurons))
    m_phi_a, m_phi_b, m_amp = [], [], []

    def torus_dist(x1, y1, x2, y2):
        dx = np.abs(x1 - x2)
        dx = np.minimum(dx, 2 * np.pi - dx)
        dy = np.abs(y1 - y2)
        dy = np.minimum(dy, 2 * np.pi - dy)
        return np.sqrt(dx * dx + dy * dy)

    while unpicked:
        i = unpicked.pop()
        group = [i]
        for j in list(unpicked):
            if torus_dist(phi_a[i], phi_b[i], phi_a[j], phi_b[j]) <= eps:
                unpicked.remove(j)
                group.append(j)

        A = amps[group]
        ang_ax = np.arctan2((A * np.sin(phi_a[group])).sum(), (A * np.cos(phi_a[group])).sum()) % (2 * np.pi)
        ang_bx = np.arctan2((A * np.sin(phi_b[group])).sum(), (A * np.cos(phi_b[group])).sum()) % (2 * np.pi)

        m_phi_a.append(ang_ax)
        m_phi_b.append(ang_bx)
        m_amp.append(A.sum())

    m_phi_a = np.asarray(m_phi_a)
    m_phi_b = np.asarray(m_phi_b)
    m_amp = np.asarray(m_amp)

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=["raw scatter", "raw vectors", "merged scatter", "merged vectors"],
        horizontal_spacing=0.12,
        vertical_spacing=0.15,
    )

    fig.add_trace(
        go.Scatter(
            x=phi_a,
            y=phi_b,
            mode="markers",
            marker=dict(size=6, color=amps, colorscale=colour_scale, colorbar=dict(title="amplitude")),
            hovertemplate="phi_a=%{x:.2f}<br>phi_b=%{y:.2f}<br>amp=%{marker.color:.3f}<extra></extra>",
        ),
        row=1,
        col=1,
    )

    for pa, pb in zip(phi_a, phi_b):
        fig.add_trace(
            go.Scatter(x=[0, pa], y=[0, pb], mode="lines", line=dict(width=1.5, color="rgba(0,0,0,0.5)"), hoverinfo="skip"),
            row=1,
            col=2,
        )
    fig.add_trace(
        go.Scatter(
            x=phi_a,
            y=phi_b,
            mode="markers",
            marker=dict(size=6, color=amps, colorscale=colour_scale, showscale=False),
            hovertemplate="phi_a=%{x:.2f}<br>phi_b=%{y:.2f}<br>amp=%{marker.color:.3f}<extra></extra>",
        ),
        row=1,
        col=2,
    )

    fig.add_trace(
        go.Scatter(
            x=m_phi_a,
            y=m_phi_b,
            mode="markers+text",
            marker=dict(size=12, color=m_amp, colorscale=colour_scale, showscale=False, line=dict(width=1, color="black")),
            text=[f"{a:.1f}" for a in m_amp],
            textposition="top center",
            hovertemplate="[merged]<br>phi_a=%{x:.2f}<br>phi_b=%{y:.2f}<br>amp=%{marker.color:.3f}<extra></extra>",
        ),
        row=2,
        col=1,
    )

    for pa, pb in zip(m_phi_a, m_phi_b):
        fig.add_trace(
            go.Scatter(x=[0, pa], y=[0, pb], mode="lines", line=dict(width=2, color="rgba(0,0,0,0.6)"), hoverinfo="skip"),
            row=2,
            col=2,
        )
    fig.add_trace(
        go.Scatter(
            x=m_phi_a,
            y=m_phi_b,
            mode="markers+text",
            marker=dict(size=12, color=m_amp, colorscale=colour_scale, showscale=False, line=dict(width=1, color="black")),
            text=[f"{a:.1f}" for a in m_amp],
            textposition="top center",
            hovertemplate="[merged]<br>phi_a=%{x:.2f}<br>phi_b=%{y:.2f}<br>amp=%{marker.color:.3f}<extra></extra>",
        ),
        row=2,
        col=2,
    )

    for r in (1, 2):
        for c in (1, 2):
            fig.update_xaxes(title_text="phi_a (rad)", row=r, col=c)
            fig.update_yaxes(title_text="phi_b (rad)", row=r, col=c)

    fig.update_layout(
        width=1100,
        height=900,
        title=f"Seed {seed} - f = {f} - {tag}",
        margin=dict(l=60, r=60, t=80, b=60),
        showlegend=False,
    )

    phase_dir = os.path.join(save_dir, "phase_plots")
    os.makedirs(phase_dir, exist_ok=True)
    fname_pdf = f"seed_{seed}_f{f}{'_'+tag if tag else ''}.pdf"
    out_path = os.path.join(phase_dir, fname_pdf)
    fig.write_image(out_path, format="pdf")
    print(f"[phase-plots] wrote {out_path}")


def first_component_exceeding_thresholds(pca_cum_ratio, thresholds=(0.5, 0.6, 0.7, 0.8, 0.9)):
    """Map each threshold to the 1-based count of components needed to reach it.

    ``pca_cum_ratio`` is expected to be non-decreasing (a cumulative explained
    variance curve). For every threshold ``t`` the result holds, under key
    ``str(t)``:
        * the 1-based position of the first entry ``>= t``; or
        * ``None`` if no entry reaches ``t``.
    """
    cumulative = np.asarray(pca_cum_ratio, dtype=float)
    limits = list(thresholds)
    # side="left" -> first position whose value is >= the threshold.
    positions = np.searchsorted(cumulative, np.asarray(limits, dtype=float), side="left")
    total = cumulative.size
    return {str(t): (int(pos) + 1 if pos < total else None) for t, pos in zip(limits, positions)}


def _generate_pdf_plots_for_matrix(
    mat: np.ndarray,
    p: int,
    save_dir: str,
    *,
    seed: int | str = "",
    freq_list: list[int] | None = None,
    tag=None,
    tag_q: str = "",
    class_string: str = "",
    colour_rule=None,
    num_principal_components=2,
) -> None:
    n_samples, n_features = mat.shape
    if n_samples < 2:
        raise ValueError("Need at least 2 samples to compute diffusion coordinates.")
    num_components = min(n_features, 8, n_samples - 1)

    if colour_rule in (
        colour_quad_a_only,
        colour_quad_b_only,
        colour_quad_mul_f,
        colour_quad_mod_g,
        colour_quad_mod_g_no_fb,
        colour_c_mod_p,
    ):
        mult = True
    else:
        raise ValueError(f"Unsupported colour_rule: {colour_rule!r}")

    if num_components >= 4:
        append_to_title = f"{tag} & {class_string}"
        freq_list = sorted(freq_list or [])
        os.makedirs(save_dir, exist_ok=True)

        is_grid_pp = (n_samples == p * p)
        is_grid_2pp = (n_samples == (2 * p) * (2 * p))

        if tag_q == "full" and is_grid_2pp:
            side = 2 * p
            indices = np.arange(side * side)
            a_vals = indices // side
            b_vals = indices % side
        elif is_grid_pp:
            indices = np.arange(p * p)
            a_vals = indices // p
            b_vals = indices % p
        else:
            a_vals = None
            b_vals = None

        pca_root = os.path.join(save_dir, "pca_pdf_plots")
        dif_root = os.path.join(save_dir, "diffusion_pdf_plots")
        for root in (pca_root, dif_root):
            for sub in ("2d", "3d"):
                os.makedirs(os.path.join(root, sub, tag), exist_ok=True)

        print("computing PCA")
        pcs, pca = compute_pca_coords(mat, num_components=num_components)

        base_2d_dir = os.path.join(pca_root, "2d", tag)
        pca_var_ratio = pca.explained_variance_ratio_.tolist()
        pca_cum_ratio = np.cumsum(pca.explained_variance_ratio_).tolist()
        cum_ratio_first_ge = first_component_exceeding_thresholds(
            pca_cum_ratio,
            thresholds=(0.5, 0.6, 0.7, 0.8, 0.9),
        )
        make_json(
            freq_list,
            pca_var_ratio,
            pca_cum_ratio,
            base_2d_dir,
            extra=dict(cum_ratio_first_ge=cum_ratio_first_ge, num_components=int(num_components)),
        )

        base_3d_dir = os.path.join(pca_root, "3d", tag)
        make_json(freq_list, pca_var_ratio, pca_cum_ratio, base_3d_dir)

        if (a_vals is not None) and (b_vals is not None):
            for f in freq_list:
                f_abs = abs(int(f))
                if f_abs % p == 0:
                    continue
                if colour_rule is None:
                    raise ValueError("Color rule empty.")
                colour, caption, p_cbar, colorscale = colour_rule(a_vals, b_vals, p, f_abs, tag_q)

                name_stub = f"pca_seed_{seed}_freq_{f}.pdf"
                _write_multiplot_2d(
                    pcs,
                    colour,
                    caption,
                    os.path.join(pca_root, "2d", tag, tag_q, name_stub.replace(".pdf", "_2d.pdf")),
                    p,
                    p_cbar,
                    colorscale,
                    seed,
                    "PC",
                    append_to_title,
                )
                _write_multiplot_3d(
                    pcs,
                    colour,
                    caption,
                    os.path.join(pca_root, "3d", tag, tag_q, name_stub.replace(".pdf", "_3d.pdf")),
                    p,
                    p_cbar,
                    colorscale,
                    seed,
                    "PC",
                    append_to_title,
                    f=f,
                    mult=mult,
                    show_fit_plane=False,
                )

                if colour_rule in (colour_quad_mod_g, colour_quad_mul_f, colour_quad_a_only, colour_quad_b_only) and tag_q == "full":
                    (pcs_orange, col_orange, h, scale_orange), (pcs_viridis, col_viridis, h2, scale_viridis) = _split_dualscale_views(
                        pcs, colour, p_cbar
                    )

                    _write_multiplot_2d(
                        pcs_orange,
                        col_orange,
                        f"{caption} - TL/BR (Orange only)",
                        os.path.join(pca_root, "2d", tag, tag_q, name_stub.replace(".pdf", "_2d_orange.pdf")),
                        p,
                        h,
                        scale_orange,
                        seed,
                        "PC",
                        append_to_title,
                    )
                    _write_multiplot_3d(
                        pcs_orange,
                        col_orange,
                        f"{caption} - TL/BR (Orange only)",
                        os.path.join(pca_root, "3d", tag, tag_q, name_stub.replace(".pdf", "_3d_orange.pdf")),
                        p,
                        h,
                        scale_orange,
                        seed,
                        "PC",
                        append_to_title,
                        f=f,
                        mult=False,
                    )

                    _write_multiplot_2d(
                        pcs_viridis,
                        col_viridis,
                        f"{caption} - BL/TR (Viridis only)",
                        os.path.join(pca_root, "2d", tag, tag_q, name_stub.replace(".pdf", "_2d_viridis.pdf")),
                        p,
                        h2,
                        scale_viridis,
                        seed,
                        "PC",
                        append_to_title,
                    )
                    _write_multiplot_3d(
                        pcs_viridis,
                        col_viridis,
                        f"{caption} - BL/TR (Viridis only)",
                        os.path.join(pca_root, "3d", tag, tag_q, name_stub.replace(".pdf", "_3d_viridis.pdf")),
                        p,
                        h2,
                        scale_viridis,
                        seed,
                        "PC",
                        append_to_title,
                        f=f,
                        mult=False,
                    )

        save_homology_artifacts(
            pcs,
            root_dir=pca_root,
            tag=tag_q,
            seed=seed,
            label=f"PCA--{class_string}",
            num_dims=num_principal_components,
        )

        dmap, eigenvalues = compute_diffusion_coords(mat, num_coords=num_components)

        nontriv = np.abs(eigenvalues[1:17])
        total = nontriv.sum()
        if total > 0:
            diff_var_ratio = (nontriv / total).tolist()
            diff_cum_ratio = np.cumsum(nontriv / total).tolist()
        else:
            diff_var_ratio = [0.0] * 16
            diff_cum_ratio = [0.0] * 16

        base_2d_d_dir = os.path.join(dif_root, "2d", tag)
        make_json(freq_list, diff_var_ratio, diff_cum_ratio, base_2d_d_dir)
        base_3d_d_dir = os.path.join(dif_root, "3d", tag)
        make_json(freq_list, diff_var_ratio, diff_cum_ratio, base_3d_d_dir)

        if (a_vals is not None) and (b_vals is not None):
            for f in freq_list:
                f_abs = abs(int(f))
                if f_abs % p == 0:
                    continue
                if colour_rule is None:
                    raise ValueError("Color rule empty.")
                colour, caption, p_cbar, colorscale = colour_rule(a_vals, b_vals, p, f_abs, tag_q)

                name_stub = f"diff_seed_{seed}_freq_{f}.pdf"
                _write_multiplot_2d(
                    dmap,
                    colour,
                    caption,
                    os.path.join(dif_root, "2d", tag, tag_q, name_stub.replace(".pdf", "_2d.pdf")),
                    p,
                    p_cbar,
                    colorscale,
                    seed,
                    "DM",
                    append_to_title,
                )
                _write_multiplot_3d(
                    dmap,
                    colour,
                    caption,
                    os.path.join(dif_root, "3d", tag, tag_q, name_stub.replace(".pdf", "_3d.pdf")),
                    p,
                    p_cbar,
                    colorscale,
                    seed,
                    "DM",
                    append_to_title,
                    f=f,
                    mult=mult,
                )

                if colour_rule in (colour_quad_mod_g, colour_quad_mul_f, colour_quad_a_only, colour_quad_b_only) and tag_q == "full":
                    (dmap_orange, col_orange, h, scale_orange), (dmap_viridis, col_viridis, h2, scale_viridis) = _split_dualscale_views(
                        dmap, colour, p_cbar
                    )

                    _write_multiplot_2d(
                        dmap_orange,
                        col_orange,
                        f"{caption} - TL/BR (Orange only)",
                        os.path.join(dif_root, "2d", tag, tag_q, name_stub.replace(".pdf", "_2d_orange.pdf")),
                        p,
                        h,
                        scale_orange,
                        seed,
                        "DM",
                        append_to_title,
                    )
                    _write_multiplot_3d(
                        dmap_orange,
                        col_orange,
                        f"{caption} - TL/BR (Orange only)",
                        os.path.join(dif_root, "3d", tag, tag_q, name_stub.replace(".pdf", "_3d_orange.pdf")),
                        p,
                        h,
                        scale_orange,
                        seed,
                        "DM",
                        append_to_title,
                        f=f,
                        mult=False,
                    )

                    _write_multiplot_2d(
                        dmap_viridis,
                        col_viridis,
                        f"{caption} - BL/TR (Viridis only)",
                        os.path.join(dif_root, "2d", tag, tag_q, name_stub.replace(".pdf", "_2d_viridis.pdf")),
                        p,
                        h2,
                        scale_viridis,
                        seed,
                        "DM",
                        append_to_title,
                    )
                    _write_multiplot_3d(
                        dmap_viridis,
                        col_viridis,
                        f"{caption} - BL/TR (Viridis only)",
                        os.path.join(dif_root, "3d", tag, tag_q, name_stub.replace(".pdf", "_3d_viridis.pdf")),
                        p,
                        h2,
                        scale_viridis,
                        seed,
                        "DM",
                        append_to_title,
                        f=f,
                        mult=False,
                    )

        save_homology_artifacts(
            dmap,
            root_dir=dif_root,
            tag=tag_q,
            seed=seed,
            label=f"Dif--{class_string}",
            num_dims=num_principal_components,
        )

        print("All PCA / diffusion PDF plots written.")

        bundle_dir = os.path.join(save_dir, "bundles", tag)
        dump_embedding_bundle_json(
            bundle_dir,
            seed=seed,
            p=p,
            tag=tag,
            tag_q=tag_q,
            class_string=class_string,
            freq_list=freq_list,
            colour_rule_name=_rule_obj_to_name(colour_rule),
            pcs=pcs,
            pca_var_ratio=pca_var_ratio,
            dmap=dmap,
            diff_eigvals=eigenvalues,
            a_vals=(a_vals if a_vals is not None else []),
            b_vals=(b_vals if b_vals is not None else []),
            store_colour_vectors=False,
        )

        print("PCA/Diffusion json written.")
        if len(freq_list) == 1 and (mat.shape[0] == p ** 2):
            _make_single_freq_phase_plots(mat, p, freq_list[0], save_dir, seed=seed, tag=tag)


def generate_pdf_plots_for_matrix(
    mat: np.ndarray,
    p: int,
    save_dir: str,
    *,
    seed: int | str = "",
    freq_list: list[int] | None = None,
    tag: str = "",
    tag_q: str = "",
    class_string: str = "",
    colour_rule=None,
    num_principal_components: int = 2,
    do_transposed: bool = False,
) -> None:
    """Generate the PCA / diffusion-map plots and artifacts for ``mat``.

    Thin public wrapper around ``_generate_pdf_plots_for_matrix``. When
    ``do_transposed`` is true the same pipeline runs a second time on
    ``mat.T`` under the tag ``"<tag>_transposed"`` (or ``"transposed"`` when
    ``tag`` is empty), so the transposed outputs are stored under their own tag.

    Args:
        mat: Matrix whose rows are embedded and plotted.
        p: Modulus forwarded to the plotting pipeline.
        save_dir: Root output directory.
        seed: Seed label forwarded to the pipeline (used in file names/metadata).
        freq_list: Frequencies to plot, forwarded unchanged.
        tag: Output tag for the primary pass.
        tag_q: Sub-tag forwarded to both passes.
        class_string: Label string forwarded to the pipeline.
        colour_rule: Colour-rule callable forwarded to the pipeline.
        num_principal_components: Forwarded to the pipeline.
        do_transposed: Also run the pipeline on ``mat.T``.
    """
    # Shared by both passes. The transposed pass used to omit tag_q, which made
    # it silently fall back to the private function's default sub-tag.
    shared_kwargs = dict(
        seed=seed,
        freq_list=freq_list,
        tag_q=tag_q,
        class_string=class_string,
        colour_rule=colour_rule,
        num_principal_components=num_principal_components,
    )

    _generate_pdf_plots_for_matrix(mat, p, save_dir, tag=tag, **shared_kwargs)

    if do_transposed:
        transposed_tag = f"{tag}_transposed" if tag else "transposed"
        _generate_pdf_plots_for_matrix(mat.T, p, save_dir, tag=transposed_tag, **shared_kwargs)


def _has_degenerate_or_tie(rows):
    """Return True if any stripe-summary row is degenerate or tied.

    A row counts when its ``"reason"`` equals ``"degenerate"``, or when any of
    its ``deg_rot`` / ``deg_ref`` / ``tie_rot`` / ``tie_ref`` entries is
    truthy. Rows are accessed via ``.get``, so each must be dict-like.
    """
    flag_keys = ("deg_rot", "deg_ref", "tie_rot", "tie_ref")
    return any(
        row.get("reason") == "degenerate" or any(row.get(key) for key in flag_keys)
        for row in rows
    )


def _any_family_degenerate(
    out_dir: str,
    f_abs: int,
    *,
    label: str = "pca",
    methods: tuple[str, ...] | None = ("longest", "shortest", "random"),
    seed: int | None = None,
) -> bool:
    def _paths_for(fam: str) -> list[str]:
        base = os.path.join(out_dir, f"{label.lower()}_stripe_summary_f{int(f_abs)}")
        cand = []
        if methods:
            for m in methods:
                if m == "random" and (seed is not None):
                    cand.append(f"{base}_{m}_seed{seed}_by_{fam}.json")
                cand.append(f"{base}_{m}_by_{fam}.json")
        cand.append(f"{base}_by_{fam}.json")
        seen, out = set(), []
        for pth in cand:
            if pth not in seen:
                seen.add(pth)
                out.append(pth)
        return out

    for fam in ("a", "b"):
        for path in _paths_for(fam):
            if not os.path.isfile(path):
                continue
            try:
                with open(path, "r") as fh:
                    data = json.load(fh)
                rows = data.get("rows", [])
                if "_has_degenerate_or_tie" in globals():
                    if _has_degenerate_or_tie(rows):
                        return True
                else:
                    for r in rows:
                        if r.get("deg_rot") or r.get("deg_ref") or r.get("tie_rot") or r.get("tie_ref"):
                            return True
            except Exception:
                continue
    return False


def run_pca_core(
    mat: np.ndarray,
    p: int,
    save_dir: str,
    *,
    seed: int | str = "",
    tag: str = "",
    tag_q: str = "full",
    max_components: int = 8,
) -> dict:
    """Fit PCA on the rows of ``mat`` and record the explained-variance profile.

    Each row of ``mat`` is treated as one grid cell. The grid is ``(2p) x (2p)``
    when ``tag_q == "full"`` and the row count matches; otherwise it is
    ``p x p``. Any other row count is rejected.

    Writes ``<save_dir>/pca_pdf_plots/2d/<tag>/pca_variance.json``.

    Returns:
        dict with keys ``pcs`` (PCA scores), ``pcs2`` (their first two
        columns), ``var_ratio`` / ``cum_ratio`` (lists), ``cumvar3_raw``
        (sum of the first three variance ratios), ``a_vals`` / ``b_vals``
        (grid coordinates per row), ``pca_root`` and ``base_2d_dir``.

    Raises:
        ValueError: if the row count matches neither supported grid.
    """
    n_rows = mat.shape[0]
    big_side = 2 * p
    if tag_q == "full" and n_rows == big_side * big_side:
        grid_side = big_side
    elif n_rows == p * p:
        grid_side = p
    else:
        raise ValueError(
            f"run_pca_core assumes p*p or (2p)*(2p) grid; got n={n_rows}, p={p}, tag_q={tag_q}"
        )
    # Row-major grid: row i corresponds to (i // side, i % side).
    a_vals, b_vals = np.divmod(np.arange(grid_side * grid_side), grid_side)

    # Cap components by the feature count, the caller's limit, and n_rows - 1.
    n_keep = min(mat.shape[1], max_components, n_rows - 1)
    model = PCA(n_components=n_keep, svd_solver="full")
    pcs = model.fit_transform(mat)
    ratios = model.explained_variance_ratio_
    var_ratio = ratios.tolist()
    cum_ratio = np.cumsum(ratios).tolist()
    cumvar3_raw = float(np.sum(ratios[:3]))

    pca_root = os.path.join(save_dir, "pca_pdf_plots")
    base_2d_dir = os.path.join(pca_root, "2d", tag)
    os.makedirs(base_2d_dir, exist_ok=True)
    with open(os.path.join(base_2d_dir, "pca_variance.json"), "w") as fh:
        payload = {
            "seed": seed,
            "p": int(p),
            "tag": tag,
            "tag_q": tag_q,
            "var_ratio": var_ratio,
            "cum_ratio": cum_ratio,
        }
        json.dump(payload, fh, indent=2)

    return {
        "pcs": pcs,
        "pcs2": pcs[:, :2],
        "var_ratio": var_ratio,
        "cum_ratio": cum_ratio,
        "cumvar3_raw": cumvar3_raw,
        "a_vals": a_vals,
        "b_vals": b_vals,
        "pca_root": pca_root,
        "base_2d_dir": base_2d_dir,
    }


def run_pca_and_stripes_no_plots(
    mat: np.ndarray,
    p: int,
    save_dir: str,
    *,
    seed: int | str = "",
    freq_list: list[int] | None = None,
    cluster_meta: dict = None,
    tag: str = "",
    tag_q: str = "full",
    num_principal_components: int = 2,
    s_mode: str = "anchor",
    model: str = "auto",
) -> None:
    """Run PCA and the stripe analysis without the full plotting pipeline.

    1. ``run_pca_core`` fits PCA (asking for at least 4 components) and
       writes the variance JSON.
    2. ``run_and_save_stripe_analysis`` runs on the first two PCs and writes
       into ``<save_dir>/analysis/<tag>/<tag_q>``.
    3. The first frequency, in sorted order with ``|f| % p != 0``, whose stripe
       summaries report a degenerate/tied row triggers one 3-D preview
       coloured by ``a`` (``write_pdf=False``). This happens only when at least
       three PCs are available.
    """
    core = run_pca_core(
        mat=mat,
        p=p,
        save_dir=save_dir,
        seed=seed,
        tag=tag,
        tag_q=tag_q,
        max_components=max(4, num_principal_components),
    )
    pcs = core["pcs"]
    a_vals, b_vals = core["a_vals"], core["b_vals"]

    analysis_dir = os.path.join(save_dir, "analysis", tag, tag_q)
    os.makedirs(analysis_dir, exist_ok=True)
    # NOTE(review): the stripe analysis gets seed=0 rather than the caller's
    # ``seed``; confirm this is intentional.
    run_and_save_stripe_analysis(
        XY=core["pcs2"],
        a_vals=a_vals,
        b_vals=b_vals,
        p=p,
        f_list=sorted(freq_list or []),
        out_dir=analysis_dir,
        label="PCA",
        tag_q=tag_q,
        s_mode=s_mode,
        model=model,
        seed=0,
        cluster_meta=cluster_meta,
        anchor_methods=("longest", "shortest", "random"),
        cumvar4_tau=0.90,
        cumvar3_for_dimcheck=core["cumvar3_raw"],
    )
    print("PCA+Stripe (no plots) done.")

    # First usable |f| whose stripe summaries are flagged, or None.
    abs_freqs = (abs(int(f)) for f in sorted(freq_list or []))
    preview_f = next(
        (fa for fa in abs_freqs if fa % p != 0 and _any_family_degenerate(analysis_dir, fa)),
        None,
    )
    if preview_f is None or pcs.shape[1] < 3:
        return

    preview_dir = os.path.join(analysis_dir, "quick_3d_preview")
    os.makedirs(preview_dir, exist_ok=True)

    col_a, caption_a, p_cbar_a, cs_a = colour_quad_a_only(a_vals, b_vals, p, preview_f, tag_q)
    _write_multiplot_3d(
        coords=pcs[:, :4],
        colour=np.asarray(col_a),
        ctitle=caption_a,
        out_path=os.path.join(preview_dir, "pca3_preview_by_a.pdf"),
        p=p,
        p_cbar=p_cbar_a,
        colorscale=cs_a,
        seed=seed,
        label="PC",
        tag=tag_q,
        f=preview_f,
        mult=True,
        write_pdf=False,
    )


def run_comp_geo_coset_pipeline(
    mat: np.ndarray,
    p: int,
    save_dir: str,
    *,
    seed: int | str = "",
    freq_list: list[int] | None = None,
    tag: str = "",
    tag_q: str = "full",
    label: str = "PCA",
    num_pca_dims: int = 4,
) -> None:
    n_samples = mat.shape[0]
    if not (n_samples == (2 * p) * (2 * p)):
        raise ValueError(f"run_comp_geo_coset_pipeline assumes (2p)*(2p) grid; got n={n_samples}, p={p}, tag_q={tag_q}")

    if tag_q != "full":
        print(f"[comp-geo] tag_q={tag_q}: skip coset collapse")
        return

    pca_root = os.path.join(save_dir, "pca_pdf_plots")
    coset_dir = os.path.join(pca_root, "coset_collapse", tag, tag_q)
    os.makedirs(coset_dir, exist_ok=True)

    for f in sorted(freq_list or []):
        f_abs = abs(int(f))
        if f_abs % p == 0:
            continue
        analysis.comp_geo.run_and_save_coset_collapse(
            embedding_weights=mat,
            p=p,
            f=f_abs,
            out_dir=coset_dir,
            label=label,
            tag_q=tag_q,
            num_pca_dims=num_pca_dims,
            seed=int(seed) if seed != "" else None,
        )
    print("Comp-geo coset collapse pipeline done.")


def generate_pca_information_scaling_experiment(
    mat: np.ndarray,
    p: int,
    save_dir: str,
    *,
    seed: int | str = "",
    freq_list: list[int] | None = None,
    tag: str = "",
) -> None:
    """Record how much PCA variance and diffusion spectrum the leading components capture.

    Uses up to 4 components. Writes ``pca_info_seed_<seed>[_<tag>].json`` into
    ``save_dir`` with:

    * the cumulative PCA explained-variance ratio, and
    * the cumulative share of the leading non-trivial eigenvalues of the
      Gaussian-kernel diffusion (Markov) matrix.

    ``freq_list`` is accepted for signature compatibility but is not used.
    """
    os.makedirs(save_dir, exist_ok=True)

    X = _sanitize_matrix(mat)
    n, d = X.shape
    n_comp = min(4, d, max(1, n - 1))
    _, pca = _safe_pca_coords(X, n_comp)
    # getattr fallback in case the returned model lacks explained_variance_ratio_.
    var_ratio = getattr(pca, "explained_variance_ratio_", np.zeros(n_comp))
    cum_var_ratio = np.cumsum(var_ratio).tolist()

    # Gaussian affinities; the bandwidth is the median squared distance, with a
    # fallback when that is zero or non-finite.
    d2 = squareform(pdist(X, metric="euclidean")) ** 2
    eps = float(np.median(d2))
    if not np.isfinite(eps) or eps <= 0:
        pos = d2[d2 > 0]
        eps = float(pos.mean()) if pos.size else 1e-12
    A = np.exp(-d2 / eps)

    # Why not eigh(M) directly:
    # - The Markov matrix M = D^-1 A is row-stochastic but not symmetric.
    #   eigh() assumes symmetry and reads only one triangle, so eigh(M) returns
    #   wrong eigenvalues.
    # - M = D^-1/2 S D^1/2 with S = D^-1/2 A D^-1/2 symmetric, so S has exactly
    #   M's spectrum.
    # Degrees are >= 1 because the diagonal of A is exp(0) = 1, so the square
    # root never divides by zero.
    inv_sqrt_deg = 1.0 / np.sqrt(A.sum(axis=1))
    S = A * np.outer(inv_sqrt_deg, inv_sqrt_deg)
    # Descending order; eigvals[0] is the trivial eigenvalue 1 of a Markov matrix.
    eigvals = eigh(S, eigvals_only=True)[::-1]

    nontrivial = eigvals[1: 1 + n_comp]
    total = float(np.sum(nontrivial))
    diff_ratios = (nontrivial / total) if total > 0 else np.zeros_like(nontrivial)
    cum_diff_ratio = np.cumsum(diff_ratios).tolist()

    info = {
        "seed": seed,
        "p": p,
        "num_pca_components": int(len(var_ratio)),
        "cumulative_pca_variance_ratio": cum_var_ratio,
        "num_diffusion_components": int(len(nontrivial)),
        "cumulative_diffusion_eigenvalue_ratio": cum_diff_ratio,
    }
    fname = f"pca_info_seed_{seed}" + (f"_{tag}" if tag else "") + ".json"
    out_path = os.path.join(save_dir, fname)
    with open(out_path, "w") as fh:
        json.dump(info, fh, indent=4)
    print(f"PCA & diffusion scaling info saved to {out_path}")


def _rule_obj_to_name(rule_fn) -> str:
    """Map a colour-rule callable to the short name stored in JSON bundles.

    ``None`` maps to ``"none"``. The known ``colour_quad_*`` rules map to their
    short aliases. Any other object keeps its ``__name__``, or becomes
    ``"custom"`` when it has no (or an empty) ``__name__``.
    """
    if rule_fn is None:
        return "none"
    short_names = {
        "colour_quad_mul_f": "mul_f",
        "colour_quad_mod_g": "mod_g",
        "colour_quad_a_only": "a_only",
        "colour_quad_b_only": "b_only",
        "colour_quad_mod_g_no_fb": "mod_g_no_fb",
    }
    fn_name = getattr(rule_fn, "__name__", "")
    if fn_name in short_names:
        return short_names[fn_name]
    return fn_name if fn_name else "custom"


def _rule_name_to_obj(name: str):
    """Resolve a stored colour-rule name back to its callable.

    Accepts the short aliases (``"mul_f"``, ``"a_only"``, ...) and the full
    function names (``"colour_quad_mul_f"``, ...). ``None``, ``""``, ``"none"``
    and any unrecognised name all yield ``None``.
    """
    if name in (None, "", "none"):
        return None
    by_alias = {
        "mul_f": colour_quad_mul_f,
        "mod_g": colour_quad_mod_g,
        "mod_g_no_fb": colour_quad_mod_g_no_fb,
        "a_only": colour_quad_a_only,
        "b_only": colour_quad_b_only,
    }
    # Each full function name is its alias with the "colour_quad_" prefix.
    by_full_name = {f"colour_quad_{alias}": fn for alias, fn in by_alias.items()}
    return {**by_alias, **by_full_name}.get(name)

