#!/usr/bin/env python3
"""
Six-panel summary + stand-alone ReLU’d circle graph

Outputs
    final_6plots_cosets.pdf      – composite 6-column figure
    relud_circle_graph.pdf       – single-panel ReLU circle graph
"""

# ──────────────────────────────────────────────
# Imports
# ──────────────────────────────────────────────
import numpy as np
import plotly.graph_objects as go
import plotly.subplots as sp
import plotly.express as px
from make_cosets_mod_66_mod_6 import get_coset_plot   # keeps your existing helper
                                                       # (must be on PYTHONPATH)

# ──────────────────────────────────────────────
# Constants & derived data
# ──────────────────────────────────────────────
# Moduli used below; 89 only sizes the unused neuron placeholder arrays.
modulus_66, modulus_67, modulus_89 = 66, 67, 89

# Residues 0…m-1 for the two moduli that are actually plotted.
a66, a67 = (np.arange(m) for m in (modulus_66, modulus_67))

# Placeholder neuron data — not referenced by any panel in this script.
x_data_latest = np.arange(modulus_89)
y_data_latest = np.array([])

# Vertical offset (data units) applied to text labels in several panels.
vertical_shift = 0.05

# Colours for the six residue classes mod 6 used by the mod-66 panels.
# Cosets at distance 0 or 1 get a sample of the reversed Viridis scale;
# cosets at distance 2 or 3 get `black_colour`, which is the last entry of
# Viridis_r (its dark end), not literal black.
viridis_r = px.colors.sequential.Viridis_r
black_colour = viridis_r[-1]

# residue mod 6 -> distance of that coset from coset 0
coset_distance = {0: 0, 1: 1, 5: 1, 2: 2, 4: 2, 3: 3}
# distance -> position on the colour scale (spread over the first 80 %)
t = {d: (d / 3) * 0.8 for d in range(4)}
coset_colors_66 = [
    black_colour if coset_distance[k] >= 2
    else px.colors.sample_colorscale('Viridis_r', [t[coset_distance[k]]])[0]
    for k in range(6)
]

# ─── Distance map modulo 67 (generator 11) ───
# Residues grouped by value mod 6; cosets 1, 2 and 3 are walked in
# descending order so that their labels grow in the right direction.
coset0 = np.arange(0, modulus_67, 6)
coset1 = np.arange(1, modulus_67, 6)[::-1]
coset5 = np.arange(5, modulus_67, 6)
coset2 = np.arange(2, modulus_67, 6)[::-1]
coset4 = np.arange(4, modulus_67, 6)
coset3 = np.arange(3, modulus_67, 6)[::-1]

# residue -> distance label. Each coset is numbered consecutively from its
# starting offset; the resulting label equals the circular distance of
# 11·a (mod 67) from 0, e.g. a = 6 -> 66 ≡ -1 -> label 1.
distance_map67 = {}
for members, first_label in ((coset0, 0), (coset1, 1),
                             (coset5, 12), (coset2, 12),
                             (coset4, 23), (coset3, 23)):
    distance_map67.update(
        zip(members, range(first_label, first_label + len(members)))
    )
max_dist67 = 33                                    # largest label, ⌊67/2⌋

# y-values for the “true” 67-periodic graph  cos(11·2π·a / 67)
y_true_67 = np.cos(2 * np.pi * 11 * a67 / modulus_67)

# colour helper for 67-vertex circle graphs
def colour_for_idx_67(idx: int) -> str:
    """Return the plot colour for residue ``idx`` (0 ≤ idx < 67).

    Residues whose true activation ``y_true_67[idx]`` is not positive are
    drawn in ``black_colour``. Every other residue is coloured by its distance
    label: ``distance_map67[idx] / max_dist67`` is scaled by 0.8, so the sample
    stays in the first 80 % of ``Viridis_r``. That keeps it away from the far
    end of the scale, where ``black_colour`` comes from.
    """
    if not y_true_67[idx] > 0:
        return black_colour
    fraction = (distance_map67[idx] / max_dist67) * 0.8
    return px.colors.sample_colorscale('Viridis_r', [fraction])[0]

# One colour per residue a = 0…66 for the 67-vertex panels; colour66 keeps
# the first 66 of them (not referenced elsewhere in this script).
colour67 = list(map(colour_for_idx_67, a67))
colour66 = colour67[:modulus_66]

# Residues with positive true activation cos(11·2π·a / 67), i.e. the vertices
# kept by the ReLU. These are exactly the residues with distance label ≤ 16:
# all of cosets 0 and 1 plus the low-distance members of cosets 5 and 2.
relu_mask_67 = y_true_67 > 0

# ──────────────────────────────────────────────
# Geometry for each panel
# ──────────────────────────────────────────────
# Panel 0 (column 1): Cayley coset plot mod 66, built by the external helper
coset_fig_mod66 = get_coset_plot()

# Panels 1–2 (columns 2–3): cos(11·2π·a / 66), first against a itself, then
# against the remapped residue 11·a mod 66. Because 11 divides 66, the
# remapped x only takes the six values 0, 11, …, 55.
x1 = a66
y1 = np.cos(2 * np.pi * 11 * x1 / modulus_66)
x2 = (11 * a66) % modulus_66
y2 = np.cos(2 * np.pi * 11 * x2 / modulus_66)

# Panel 3 (column 4): vertex a of the 67-vertex circle graph sits at the
# angle 11·2π·a / 67 on the unit circle.
theta67 = 11 * 2 * np.pi * a67 / modulus_67
x67, y67 = np.cos(theta67), np.sin(theta67)

# Panels 4–5 (columns 5–6): the true activation plotted against a, then
# against 11·a mod 67 (same y-values, permuted x-positions).
x3, y3 = a67, y_true_67
x4 = (11 * a67) % modulus_67
y4 = y3
# 67 is prime, so x4 is a permutation of 0…66; sorting by it gives a
# left-to-right guide line.
order = np.argsort(x4)
x4_line = x4[order]
y4_line = y4[order]

# ──────────────────────────────────────────────
# Helper for a Viridis-coloured line+marker panel (used for column 2)
# ──────────────────────────────────────────────
def add_coset_panel(fig, col, x_vals, y_vals, base_indices):
    """Draw a guide line plus coset-coloured markers in row 1, column ``col``.

    Two traces are appended to ``fig`` in this order: a translucent grey line
    through the points as given, then markers whose colour is
    ``coset_colors_66[b % 6]`` for each ``b`` in ``base_indices``.

    Parameters
    ----------
    fig : figure created with ``make_subplots`` (row 1 must exist)
    col : 1-based subplot column to draw into
    x_vals, y_vals : point coordinates
    base_indices : residues whose value mod 6 picks each marker's colour
    """
    marker_colours = [coset_colors_66[b % 6] for b in base_indices]
    guide_line = go.Scatter(x=x_vals, y=y_vals, mode='lines',
                            line=dict(color='rgba(0,0,0,0.5)', width=1),
                            showlegend=False)
    markers = go.Scatter(x=x_vals, y=y_vals, mode='markers',
                         marker=dict(color=marker_colours, size=9),
                         showlegend=False)
    for trace in (guide_line, markers):
        fig.add_trace(trace, row=1, col=col)

# ──────────────────────────────────────────────
# ❶  Build the 1 × 6 composite figure
# ──────────────────────────────────────────────
fig6 = sp.make_subplots(rows=1, cols=6, horizontal_spacing=0.029)

# Column titles are free-floating paper annotations, so each one's
# horizontal centre can be tuned by hand via x_centers.
titles = [
    "cosets of cos(11·2π·a / 66)",
    "cos(11·2π·a / 66)",
    "remapped cos(11·2π·a / 66)",
    "circle graph 11 a mod 67",
    "cos(11·2π·a / 67)",
    "remapped cos(11·2π·a / 67)"
]
x_centers = [0.012, 0.20, 0.42, 0.59, 0.80, 0.995]

# Shared look: placed above the plot row (y=1.13), no arrow, transparent box.
annotation_style = dict(
    xref="paper", yref="paper", y=1.13, showarrow=False,
    font=dict(size=22, color="black"),
    bgcolor='rgba(0,0,0,0)',
    align='center',
)
for title, x in zip(titles, x_centers):
    fig6.add_annotation(text=title, x=x, **annotation_style)

# Column 1 – Cayley coset circle: copy every trace of the helper's figure
for coset_trace in coset_fig_mod66.data:
    fig6.add_trace(coset_trace, row=1, col=1)
# 'y' is column 1's own y-axis, so this keeps the circle round
fig6.update_xaxes(row=1, col=1, range=[-1.5, 1.5],
                  scaleanchor='y', scaleratio=1,
                  title_text="cos(11(2π)a/66)", title_standoff=3)

# Column 2 – cos(11·2π·a / 66), markers coloured by the coset of a
add_coset_panel(fig6, 2, x1, y1, a66)
fig6.update_xaxes(row=1, col=2, range=[-1.8, 75], title_standoff=0)
fig6.update_yaxes(row=1, col=2, range=[-1.1, 1.28], title_standoff=0)

# Column 3 – remapped cos(11·2π·a / 66)
# Each run of six consecutive a (6m … 6m+5) maps to x2 = 0, 11, …, 55, and
# every run retraces the same six points. A None gap after each a ≡ 5 (mod 6)
# ends the run, so the guide line never draws the jump from x = 55 back to 0.
x2_line, y2_line = [], []
for base_i, xi, yi in zip(a66, x2, y2):
    gap = [None] if base_i % 6 == 5 else []
    x2_line += [xi] + gap
    y2_line += [yi] + gap

guide_trace = go.Scatter(x=x2_line, y=y2_line, mode='lines',
                         line=dict(color='rgba(0,0,0,0.5)', width=1),
                         showlegend=False)
marker_trace = go.Scatter(
    x=x2, y=y2, mode='markers',
    marker=dict(color=[coset_colors_66[i % 6] for i in a66], size=9),
    showlegend=False,
)
for trace in (guide_trace, marker_trace):
    fig6.add_trace(trace, row=1, col=3)

# Distance labels for columns 2 and 3: every member of coset k is tagged with
# coset_distance[k]. Cosets 4 and 5 are labelled below their markers and the
# rest above.
for panel_col, (x_vals, y_vals) in zip((2, 3), ((x1, y1), (x2, y2))):
    for k, dist in coset_distance.items():
        coset_k = np.arange(k, modulus_66, 6)
        below = k % 6 in (5, 4)
        y_offset = -vertical_shift if below else vertical_shift
        text_pos = 'bottom center' if below else 'top center'
        n_pts = len(coset_k)
        label_trace = go.Scatter(
            x=x_vals[coset_k],
            y=y_vals[coset_k] + y_offset,
            mode='text',
            text=[dist] * n_pts,
            textposition=text_pos,
            textfont=dict(color=[coset_colors_66[k]] * n_pts),
            showlegend=False,
        )
        fig6.add_trace(label_trace, row=1, col=panel_col)
fig6.update_xaxes(row=1, col=3, range=[-1.8, 75], title_standoff=0)
fig6.update_yaxes(row=1, col=3, range=[-1.1, 1.28], title_standoff=0)
# Column 4 – 67-vertex circle graph (all residues 0…66)
# Grey polyline joining a -> a+1 and closing 66 -> 0. Consecutive a are
# 11 positions apart on the circle, so this traces a star polygon.
fig6.add_trace(
    go.Scatter(x=np.append(x67, x67[0]), y=np.append(y67, y67[0]),
               mode='lines', line=dict(color='rgba(0,0,0,0.3)', width=1),
               showlegend=False),
    row=1, col=4,
)
# vertices, one colour per residue
fig6.add_trace(
    go.Scatter(x=x67, y=y67, mode='markers',
               marker=dict(color=colour67, size=9), showlegend=False),
    row=1, col=4,
)

# Labels sit on two concentric rings. The ring is chosen by the parity of
# each vertex's position 11·a mod 67 around the circle, so neighbouring
# labels (almost always) alternate rings. A dotted connector joins each
# vertex to its label.
r_outer, r_inner = 1.25, 1.15
remapped = (11 * a67) % modulus_67
r_lbl4 = np.where((remapped % 6) % 2 == 1, r_inner, r_outer)
x_lbl4 = r_lbl4 * np.cos(theta67)
y_lbl4 = r_lbl4 * np.sin(theta67)
for a in a67:
    connector = go.Scatter(x=[x67[a], x_lbl4[a]], y=[y67[a], y_lbl4[a]],
                           mode='lines',
                           line=dict(dash='dot', width=0.5, color=colour67[a]),
                           showlegend=False)
    fig6.add_trace(connector, row=1, col=4)
fig6.add_trace(
    go.Scatter(x=x_lbl4, y=y_lbl4, mode='text',
               text=[distance_map67[a] for a in a67],
               textfont=dict(color=colour67, size=14),
               showlegend=False),
    row=1, col=4,
)

# Column 5 – cos(11·2π·a / 67): grey guide line through a = 0…66, then
# markers coloured per residue
for trace in (
    go.Scatter(x=x3, y=y3, mode='lines',
               line=dict(color='rgba(0,0,0,0.5)', width=1),
               showlegend=False),
    go.Scatter(x=x3, y=y3, mode='markers',
               marker=dict(color=colour67, size=9),
               showlegend=False),
):
    fig6.add_trace(trace, row=1, col=5)

# Distance labels for cosets 0, 1, 5 and 2, numbered in the same order and
# direction used to build distance_map67. Only labels ≤ 16 are drawn; these
# are exactly the residues with positive activation. Labels in below_lbls go
# under their marker and the rest go above.
below_lbls = (6, 7, 8, 12, 13, 15)
for coset, lbl_start in ((coset0, 0), (coset1, 1), (coset5, 12), (coset2, 12)):
    for step, a in enumerate(coset):
        lbl = lbl_start + step
        if lbl > 16:
            continue
        if lbl in below_lbls:
            y_offset, text_pos = -vertical_shift, 'bottom center'
        else:
            y_offset, text_pos = vertical_shift, 'top center'
        fig6.add_trace(
            go.Scatter(
                x=[a],
                y=[y3[a] + y_offset],
                mode='text',
                text=[lbl],
                textposition=text_pos,
                textfont=dict(color=[colour67[a]]),
                showlegend=False,
            ),
            row=1, col=5,
        )

fig6.update_xaxes(row=1, col=5, range=[-2.8, 72], title_standoff=0)
fig6.update_yaxes(row=1, col=5, range=[-1.1, 1.25], title_standoff=0)
# Column 6 – remapped cos(11·2π·a / 67): guide line in x-sorted order, then
# markers coloured per residue
for trace in (
    go.Scatter(x=x4_line, y=y4_line, mode='lines',
               line=dict(color='rgba(0,0,0,0.5)', width=1),
               showlegend=False),
    go.Scatter(x=x4, y=y4, mode='markers',
               marker=dict(color=colour67, size=9),
               showlegend=False),
):
    fig6.add_trace(trace, row=1, col=6)

# Signed distance label -> residue. Non-negative keys come from cosets 0
# and 5 and negative keys from cosets 1 and 2; |key| is the distance label.
label_to_idx = {}
for members, start, sign in ((coset0, 0, 1), (coset1, 1, -1),
                             (coset5, 12, 1), (coset2, 12, -1)):
    for i, a in enumerate(members):
        label_to_idx[sign * (start + i)] = a

# Only every other distance 0…16 is labelled: even ones via the
# non-positive keys, odd ones via the positive keys.
neg_lbls = list(range(0, -18, -2))
pos_lbls = list(range(1, 17, 2))

# Even distances: positions alternate top-right / bottom-left.
neg_at = [label_to_idx[l] for l in neg_lbls]
neg_x = [x4[j] for j in neg_at]
neg_y = [y4[j] for j in neg_at]
neg_clr = [colour67[j] for j in neg_at]
neg_pos = [('top right', 'bottom left')[i % 2] for i in range(len(neg_lbls))]
neg_y_off = [y + vertical_shift if p.startswith('top') else y - vertical_shift
             for y, p in zip(neg_y, neg_pos)]
fig6.add_trace(
    go.Scatter(x=neg_x, y=neg_y_off, mode='text',
               text=[abs(l) for l in neg_lbls],
               textposition=neg_pos,
               textfont=dict(color=neg_clr),
               showlegend=False),
    row=1, col=6,
)

# Odd distances: positions alternate top-left / bottom-right.
pos_at = [label_to_idx[l] for l in pos_lbls]
pos_x = [x4[j] for j in pos_at]
pos_y = [y4[j] for j in pos_at]
pos_clr = [colour67[j] for j in pos_at]
pos_pos = [('top left', 'bottom right')[i % 2] for i in range(len(pos_lbls))]
pos_y_off = [y + vertical_shift if p.startswith('top') else y - vertical_shift
             for y, p in zip(pos_y, pos_pos)]
fig6.add_trace(
    go.Scatter(x=pos_x, y=pos_y_off, mode='text',
               text=pos_lbls,
               textposition=pos_pos,
               textfont=dict(color=pos_clr),
               showlegend=False),
    row=1, col=6,
)

# ──────────────────────────────────────────────
# Axes, grid & overall layout tweaks
# ──────────────────────────────────────────────
# dotted grid on columns 2–6; column 1 (coset circle) is not touched here
for c in range(2, 7):
    fig6.update_xaxes(showgrid=True, gridcolor='rgba(128,128,128,0.5)',
                      griddash='dot', zeroline=False, row=1, col=c)
    fig6.update_yaxes(showgrid=True, gridcolor='rgba(128,128,128,0.5)',
                      griddash='dot', zeroline=False, row=1, col=c)

# shared axis labels for the XY-style panels
for c in (2, 3, 5, 6):
    fig6.update_xaxes(title_text="a", title_standoff=3, row=1, col=c)
    fig6.update_yaxes(title_text="activation", title_standoff=0, row=1, col=c)

# Circle-graph axes (column 4), drawn with a 1:1 aspect ratio.
# make_subplots(rows=1, cols=6) names column 4's axes 'x4' / 'y4'.
# scaleanchor takes a literal axis id and is not resolved relative to the
# row/col being updated, so 'y' / 'x' here would link this panel's scale to
# column 1's axes. Anchor x4 to its own y-axis instead.
fig6.update_xaxes(title_text="cos(11(2π)a/67)", title_standoff=3, range=[-1.5, 1.5],
                  scaleanchor='y4', scaleratio=1, row=1, col=4)

fig6.update_yaxes(title_text="sin(11(2π)a/67)", title_standoff=0, range=[-1.5, 1.5],
                  row=1, col=4)

fig6.update_layout(
    font=dict(size=26),
    plot_bgcolor='white', paper_bgcolor='white',
    width=2400, height=425,
    margin=dict(l=3, r=0, b=10, t=50),
    showlegend=False
)
fig6.update_yaxes(range=[-1.1, 1.25], title_standoff=0, row=1, col=6)

# ──────────────────────────────────────────────
# ❷  Stand-alone ReLU circle graph
# ──────────────────────────────────────────────
# Same 67-vertex layout as column 4 of the composite figure, but only the
# vertices with positive true activation (relu_mask_67) are marked/labelled.
fig_relu = go.Figure()
# grey polyline through a = 0, 1, …, 66 and back to 0 (consecutive a are
# 11 positions apart on the circle, so this is a star polygon)
fig_relu.add_trace(
    go.Scatter(x=np.append(x67, x67[0]), y=np.append(y67, y67[0]),
               mode='lines',
               line=dict(color='rgba(0,0,0,0.3)', width=1),
               showlegend=False)
)
# subset of points with positive activation
fig_relu.add_trace(
    go.Scatter(x=x67[relu_mask_67], y=y67[relu_mask_67], mode='markers',
               marker=dict(color=[colour67[i] for i in a67[relu_mask_67]],
                           size=9),
               showlegend=False)
)
# Dotted connectors + distance labels on the same two rings as column 4.
# The ring must follow the parity of each vertex's position around the
# circle (11·a mod 67), not the parity of a itself. For most a, circle
# neighbours differ by a ≡ -6 (mod 67) and so share a's parity; choosing by
# a would stack adjacent labels on one ring.
idxs = np.where(relu_mask_67)[0]
circle_pos = (11 * idxs) % modulus_67
r_lbl = np.where(circle_pos % 2 == 1, r_inner, r_outer)
x_lbl = r_lbl * np.cos(theta67[idxs])
y_lbl = r_lbl * np.sin(theta67[idxs])

for j, idx in enumerate(idxs):
    fig_relu.add_trace(
        go.Scatter(x=[x67[idx], x_lbl[j]], y=[y67[idx], y_lbl[j]],
                   mode='lines',
                   line=dict(dash='dot', width=0.5, color=colour67[idx]),
                   showlegend=False)
    )
fig_relu.add_trace(
    go.Scatter(x=x_lbl, y=y_lbl, mode='text',
               text=[distance_map67[i] for i in idxs],
               textfont=dict(color=[colour67[i] for i in idxs], size=12),
               showlegend=False)
)

fig_relu.update_xaxes(title_text="cos(11(2pi)a/67)", title_standoff=3, range=[-1.5, 1.5],
                      scaleanchor='y', scaleratio=1,
                      showgrid=True, gridcolor='rgba(128,128,128,0.5)',
                      griddash='dot', zeroline=False)
fig_relu.update_yaxes(title_text="sin(11(2pi)a/67)", title_standoff=0, range=[-1.5, 1.5],
                      showgrid=True, gridcolor='rgba(128,128,128,0.5)',
                      griddash='dot', zeroline=False)

fig_relu.update_layout(
    title_text="ReLU’d circle graph mod 67",
    plot_bgcolor='white', paper_bgcolor='white',
    width=500, height=500,
    # Plotly centres the title vertically in the top margin; a 10 px margin
    # clips it at the top edge, so leave room for it.
    margin=dict(l=10, r=10, b=10, t=50)
)

# ──────────────────────────────────────────────
# ❸  Write the two PDF files
# ──────────────────────────────────────────────
# Final override pass: enables the dotted grid on every column (including
# the column-1 coset circle) and resets title standoffs (x: 3, y: 0). It then
# gives column 1's y-axis its own range, title and standoff.
grid_kw = dict(showgrid=True, gridcolor='rgba(128,128,128,0.5)',
               griddash='dot', zeroline=False)
for col in range(1, 7):
    fig6.update_xaxes(title_standoff=3, row=1, col=col, **grid_kw)
    fig6.update_yaxes(title_standoff=0, row=1, col=col, **grid_kw)
fig6.update_yaxes(range=[-1.5, 1.5], title_text="sin(11(2π)a/66)",
                  title_standoff=5, row=1, col=1)

for figure, path in ((fig6, "final_6plots_cosets.pdf"),
                     (fig_relu, "relud_circle_graph.pdf")):
    figure.write_image(path, format="pdf", engine="kaleido")
