import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import networkx as nx

# ──────────────────────────────────────────────────────────────────────────────
# helpers & palette
# ──────────────────────────────────────────────────────────────────────────────
def circular_pos(n: int):
    """Map each vertex index 0..n-1 to its point on the unit circle.

    Vertex k sits at angle 2*pi*k/n, measured anticlockwise from (1, 0).
    Returns a dict ``{k: (cos, sin)}``; the dict is empty when n <= 0.
    """
    positions = {}
    for k in range(n):
        angle = 2 * np.pi * k / n
        positions[k] = (np.cos(angle), np.sin(angle))
    return positions

def circular_dist(a: int, b: int, n: int) -> int:
    """Return the number of steps between *a* and *b* around a cycle of length *n*.

    The shorter of the two directions around the cycle is taken.
    """
    forward = abs(a - b) % n
    backward = n - forward
    return forward if forward <= backward else backward

def cosets(n: int, s: int):
    """Return the cosets of the subgroup generated by *s* in Z_n.

    There are gcd(n, s) cosets.  Entry r of the returned list is the set
    r + <s> (mod n), for r = 0, 1, ..., gcd(n, s) - 1, so entry 0 is the
    subgroup itself and always contains 0.
    """
    num_cosets = np.gcd(n, s)
    subgroup = set()
    for k in range(n // num_cosets):
        subgroup.add((k * s) % n)
    return [{(h + r) % n for h in subgroup} for r in range(num_cosets)]

# Colour palette from the reversed Viridis scale:
#   T           – sample position in the colourscale for each distance d = 0, 1, 2
#   COL         – the colour sampled at T[d]
#   UNREACHABLE – the last (darkest) entry of the reversed scale
VIRIDIS_R   = px.colors.sequential.Viridis_r
T           = dict((d, d / 3 * 0.8) for d in range(3))
COL         = {d: px.colors.sample_colorscale('Viridis_r', [t])[0]
               for d, t in T.items()}
UNREACHABLE = VIRIDIS_R[-1]

# ──────────────────────────────────────────────────────────────────────────────
# panel 1 (frequency 1) – Cayley cycle coloured by distance from vertex 0; d=3 uses the darkest colour
# ──────────────────────────────────────────────────────────────────────────────
def panel_frequency_1(fig, row, col, *, n=6, s=1):
    """Draw the Cayley graph of C_n for generator *s* into subplot (row, col).

    Traces are added in this order:

    * the unit circle;
    * one grey edge k -> (k + s) mod n per vertex;
    * for every vertex k:
      - a marker coloured by its shortest-path distance from vertex 0;
      - a black index label pushed radially outward;
      - for k != 0, a dotted spoke from vertex 0 with the distance written
        at its midpoint.

    Distances 0, 1 and 2 are coloured from ``COL``. Any other vertex is drawn
    in ``UNREACHABLE``: one at a larger distance, or one not reachable from 0
    because gcd(n, s) > 1. Unreachable vertices are labelled "∞".

    Args:
        fig: figure created with ``make_subplots``; modified in place.
        row, col: 1-based position of the target subplot.
        n: number of vertices on the circle.
        s: generator step; vertex k is joined to (k + s) mod n.
    """
    pos = circular_pos(n)
    theta = np.linspace(0, 2 * np.pi, 400)
    label_offset = 0.15  # radial gap between a vertex and its index label

    # unit circle
    fig.add_trace(go.Scatter(x=np.cos(theta), y=np.sin(theta), mode='lines',
                             line=dict(color='black', width=2),
                             showlegend=False), row=row, col=col)

    # Cayley edges
    for k in range(n):
        (xa, ya), (xb, yb) = pos[k], pos[(k + s) % n]
        fig.add_trace(go.Scatter(x=[xa, xb], y=[ya, yb], mode='lines',
                                 line=dict(color='grey', width=2),
                                 showlegend=False), row=row, col=col)

    # shortest-path distances from vertex 0; vertices in another connected
    # component are simply absent from `dist`
    G = nx.Graph([(k, (k + s) % n) for k in range(n)])
    dist = nx.single_source_shortest_path_length(G, 0)
    x0, y0 = pos[0]

    for k, (xk, yk) in pos.items():
        dk = dist.get(k)                    # None when unreachable from 0
        colour = COL.get(dk, UNREACHABLE)   # d >= 3 or unreachable -> darkest

        # vertex
        fig.add_trace(go.Scatter(x=[xk], y=[yk], mode='markers',
                                 marker=dict(size=11, color=colour),
                                 showlegend=False), row=row, col=col)

        # black index label, pushed outwards along the vertex's own angle
        angle = 2 * np.pi * k / n
        fig.add_trace(go.Scatter(
            x=[xk + label_offset * np.cos(angle)],
            y=[yk + label_offset * np.sin(angle)],
            mode="text",
            text=[str(k)],
            textfont=dict(size=24, color="black"),
            showlegend=False
        ), row=row, col=col)

        # dotted spoke from vertex 0 with the distance at its midpoint
        if k != 0:
            fig.add_trace(go.Scatter(x=[x0, xk], y=[y0, yk], mode='lines',
                                     line=dict(color=colour, width=2, dash='dot'),
                                     showlegend=False), row=row, col=col)
            mx, my = x0 + 0.5*(xk-x0), y0 + 0.5*(yk-y0)
            fig.add_trace(go.Scatter(x=[mx], y=[my], mode='text',
                                     text=[str(dk) if dk is not None else "∞"],
                                     textfont=dict(size=20, color=colour),
                                     showlegend=False), row=row, col=col)

    # Lock the aspect ratio to *this* subplot's y axis. make_subplots numbers
    # axis ids across the whole grid, so the id cannot be built from `col`
    # alone; each x axis it creates is already anchored to its partner y axis.
    y_anchor = fig.get_subplot(row, col).xaxis.anchor
    fig.update_xaxes(showgrid=False, zeroline=False,
                     range=[-1.35, 1.35], scaleanchor=y_anchor, scaleratio=1,
                     row=row, col=col)
    fig.update_yaxes(showgrid=False, zeroline=False,
                     range=[-1.35, 1.35], row=row, col=col)

# ──────────────────────────────────────────────────────────────────────────────
# first-row panels for frequencies 2 & 3 (“plain” style)
# ──────────────────────────────────────────────────────────────────────────────
def panel_plain(fig, row, col, *, n=6, s=2):
    """Draw the cosets of <s> in C_n as closed polygons on the unit circle.

    Traces are added in this order:

    * the unit circle;
    * for each coset returned by ``cosets(n, s)``, the edges joining its
      members cyclically in sorted order, then one marker per member.
      Coset 0 is drawn in ``COL[0]``; every other coset in ``UNREACHABLE``;
    * a black index label pushed radially outward from every vertex;
    * for every vertex outside coset 0, a dotted spoke from vertex 0 with
      the label "1" at its midpoint.

    Args:
        fig: figure created with ``make_subplots``; modified in place.
        row, col: 1-based position of the target subplot.
        n: number of vertices on the circle.
        s: generator whose cosets are drawn.
    """
    pos = circular_pos(n)
    coset_ls = cosets(n, s)
    coset0 = coset_ls[0]
    label_offset = 0.15  # radial gap between a vertex and its index label

    theta = np.linspace(0, 2*np.pi, 400)
    fig.add_trace(go.Scatter(x=np.cos(theta), y=np.sin(theta), mode='lines',
                             line=dict(color='black', width=2),
                             showlegend=False), row=row, col=col)

    # ── solid polygons for each coset + vertices ────────────────────────────
    for idx, C in enumerate(coset_ls):
        colour = COL[0] if idx == 0 else UNREACHABLE
        vertices = sorted(C)
        m = len(vertices)

        # intra-coset edges (closing the polygon back to the first member)
        for i, a in enumerate(vertices):
            b = vertices[(i + 1) % m]
            (xa, ya), (xb, yb) = pos[a], pos[b]
            fig.add_trace(go.Scatter(x=[xa, xb], y=[ya, yb], mode='lines',
                                     line=dict(color=colour, width=2),
                                     showlegend=False), row=row, col=col)

        # vertices
        for v in vertices:
            xv, yv = pos[v]
            fig.add_trace(go.Scatter(x=[xv], y=[yv], mode='markers',
                                     marker=dict(size=11, color=colour),
                                     showlegend=False), row=row, col=col)

    # ── black index labels, pushed outwards along each vertex's angle ──────
    for k in range(n):
        xk, yk = pos[k]
        angle = 2 * np.pi * k / n
        fig.add_trace(go.Scatter(
            x=[xk + label_offset * np.cos(angle)],
            y=[yk + label_offset * np.sin(angle)],
            mode="text",
            text=[str(k)],
            textfont=dict(size=24, color="black"),
            showlegend=False
        ), row=row, col=col)

    # ── dotted spokes & "1" labels to *every* vertex outside coset₀ ─────────
    x0, y0 = pos[0]
    for k in range(1, n):
        if k in coset0:            # same coset ⇒ no spoke
            continue
        xk, yk = pos[k]
        fig.add_trace(go.Scatter(x=[x0, xk], y=[y0, yk], mode='lines',
                                 line=dict(color=UNREACHABLE, width=2,
                                           dash='dot'),
                                 showlegend=False), row=row, col=col)
        mx, my = x0 + 0.5*(xk-x0), y0 + 0.5*(yk-y0)
        fig.add_trace(go.Scatter(x=[mx], y=[my], mode='text',
                                 text=['1'],
                                 textfont=dict(size=20, color=UNREACHABLE),
                                 showlegend=False), row=row, col=col)

    # Lock the aspect ratio to *this* subplot's y axis. make_subplots numbers
    # axis ids across the whole grid, so the id cannot be built from `col`
    # alone; each x axis it creates is already anchored to its partner y axis.
    y_anchor = fig.get_subplot(row, col).xaxis.anchor
    fig.update_xaxes(showgrid=False, zeroline=False,
                     range=[-1.35, 1.35], scaleanchor=y_anchor, scaleratio=1,
                     row=row, col=col)
    fig.update_yaxes(showgrid=False, zeroline=False,
                     range=[-1.35, 1.35], row=row, col=col)

# ──────────────────────────────────────────────────────────────────────────────
# second-row panels – “sub-circle” style
# ──────────────────────────────────────────────────────────────────────────────
def choose_rep(C, *, idx, pos):
    """Return the vertex of coset *C* on which its sub-circle is centred.

    Coset 0 (``idx == 0``) is always centred on vertex 0, i.e. at (1, 0).
    Any other coset is centred on the member with the smallest x-coordinate
    among members with x <= 0. The first such member in iteration order wins
    ties. If no member has x <= 0, the smallest member of *C* is used.
    """
    if idx != 0:
        absent = object()
        leftmost = min((v for v in C if pos[v][0] <= 0),
                       key=lambda v: pos[v][0], default=absent)
        return min(C) if leftmost is absent else leftmost
    return 0

def panel_subcircles(fig, row, col, *, n=6, s=2, y_offset=0.1):
    """Draw every coset of <s> in C_n as a small circle centred on a vertex.

    Each coset is centred on the vertex chosen by ``choose_rep``: vertex 0
    for coset 0, a left-hand vertex for the others. Coset 0 is drawn in
    ``COL[0]`` and all others in ``UNREACHABLE``. For each coset the panel
    draws:

    * an outline of radius 1/3;
    * ``len(coset)`` markers evenly spaced on that outline;
    * a centre marker, with dashed radii to those markers;
    * the coset's members as a ``{a,b,...}`` label placed *y_offset* above
      the top of the small circle.

    Finally, a dotted line joins the centre of coset 0 to every other centre,
    with the label "1" at its midpoint.

    Args:
        fig: figure created with ``make_subplots``; modified in place.
        row, col: 1-based position of the target subplot.
        n: number of vertices on the large circle.
        s: generator whose cosets are drawn.
        y_offset: vertical gap between a small circle's top and its label.
    """
    pos = circular_pos(n)
    coset_ls = cosets(n, s)
    reps = [choose_rep(C, idx=i, pos=pos) for i, C in enumerate(coset_ls)]
    small_r = 1/3
    th = np.linspace(0, 2*np.pi, 200)   # shared parametrisation of small outlines

    # large circle
    t_big = np.linspace(0, 2*np.pi, 400)
    fig.add_trace(go.Scatter(x=np.cos(t_big), y=np.sin(t_big), mode='lines',
                             line=dict(color='black', width=2),
                             showlegend=False), row=row, col=col)

    # draw every coset as a small circle
    for idx, (C, rep) in enumerate(zip(coset_ls, reps)):
        cx, cy = pos[rep]
        colour = COL[0] if idx == 0 else UNREACHABLE

        # outline small circle
        fig.add_trace(go.Scatter(x=cx + small_r * np.cos(th),
                                 y=cy + small_r * np.sin(th), mode='lines',
                                 line=dict(color=colour, width=1),
                                 showlegend=False), row=row, col=col)

        # coset points, evenly spaced on the outline
        m = len(C)
        ang = 2*np.pi*np.arange(m)/m
        xp = cx + small_r*np.cos(ang)
        yp = cy + small_r*np.sin(ang)
        fig.add_trace(go.Scatter(x=xp, y=yp, mode='markers',
                                 marker=dict(color=colour, size=9),
                                 showlegend=False), row=row, col=col)

        # centre dot
        fig.add_trace(go.Scatter(x=[cx], y=[cy], mode='markers',
                                 marker=dict(color=colour, size=12),
                                 showlegend=False), row=row, col=col)

        # dashed radii as one trace; None breaks the line between segments
        xl, yl = [], []
        for xq, yq in zip(xp, yp):
            xl += [cx, xq, None]
            yl += [cy, yq, None]
        fig.add_trace(go.Scatter(x=xl, y=yl, mode='lines',
                                 line=dict(color=colour, width=2, dash='dash'),
                                 showlegend=False), row=row, col=col)

        # coset members as a label above the small circle
        label = "{" + ",".join(str(v) for v in sorted(C)) + "}"
        fig.add_trace(go.Scatter(
            x=[cx],
            y=[cy + small_r + y_offset],
            mode="text",
            text=[label],
            textfont=dict(size=28, color=colour),
            showlegend=False
        ), row=row, col=col)

    # dotted centre-to-centre edges & "1" labels
    cx0, cy0 = pos[0]
    for rep in reps[1:]:
        cx, cy = pos[rep]
        fig.add_trace(go.Scatter(x=[cx0, cx], y=[cy0, cy], mode='lines',
                                 line=dict(color=UNREACHABLE, width=2, dash='dot'),
                                 showlegend=False), row=row, col=col)
        mx, my = cx0 + 0.5*(cx-cx0), cy0 + 0.5*(cy-cy0)
        fig.add_trace(go.Scatter(x=[mx], y=[my], mode='text',
                                 text=['1'],
                                 textfont=dict(size=20, color=UNREACHABLE),
                                 showlegend=False), row=row, col=col)

    # Wider range leaves room for the coset labels. Lock the aspect ratio to
    # this subplot's own y axis so the circles stay round. make_subplots
    # numbers axis ids across the grid; each x axis it creates is already
    # anchored to its partner y axis.
    y_anchor = fig.get_subplot(row, col).xaxis.anchor
    fig.update_xaxes(showgrid=False, zeroline=False,
                     range=[-1.45, 1.45], scaleanchor=y_anchor, scaleratio=1,
                     row=row, col=col)
    fig.update_yaxes(showgrid=False, zeroline=False,
                     range=[-1.45, 1.45], row=row, col=col)


# ──────────────────────────────────────────────────────────────────────────────
# build the 2 × 3 figure
# ──────────────────────────────────────────────────────────────────────────────
def make_full_plot():
    """Build the 2 × 3 Cayley-graph / coset figure for C₆.

    Row 1: ``panel_frequency_1`` (s=1), then ``panel_plain`` for s=2 and s=3.
    Row 2: the s=1 panel again, then ``panel_subcircles`` for s=2 and s=3.
    The subplot titles label the s=2 column as frequency f=3 and the s=3
    column as f=2. Fonts are set to size 20 throughout, and ticks and tick
    labels are hidden on every axis.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    fig = make_subplots(
        rows=2, cols=3,
        subplot_titles=[
            "f=1: cos(1(2(pi(a)/6))): {0}, {1}, {2}, {3}, {4}, {5}",
            "f=3: cos(3(2(pi(a)/6))) {0,2,4}, {1,3,5}",
            "f=2: cos(2(2(pi(a)/6))) {0,3}, {1,4}, {2,5}",
            "f=1: cos(1(2(pi(a)/6))) (sub-circles)",
            "f=3: cos(3(2(pi(a)/6))) (sub-circles)",
            "f=2: cos(2(2(pi(a)/6))) (sub-circles)",
        ],
        horizontal_spacing=0.09, vertical_spacing=0.14,
    )

    # row 1
    panel_frequency_1(fig, 1, 1, n=6, s=1)
    panel_plain     (fig, 1, 2, n=6, s=2)
    panel_plain     (fig, 1, 3, n=6, s=3)

    # row 2
    panel_frequency_1(fig, 2, 1, n=6, s=1)
    panel_subcircles(fig, 2, 2, n=6, s=2)
    panel_subcircles(fig, 2, 3, n=6, s=3)

    fig.update_layout(
        height=1200, width=1800,
        margin=dict(l=50, r=50, t=80, b=20),
        plot_bgcolor='white', paper_bgcolor='white',
        title=dict(
            text="C₆ – generators f=1, f=3, f=2; step size d=1 in all three cases",
            font=dict(size=20),
        ),
        font=dict(size=20),        # default for everything else
    )
    # subplot titles are layout annotations; size them explicitly as well
    fig.update_annotations(font_size=20)

    # hide ticks and tick labels on every axis of the grid
    fig.update_xaxes(showticklabels=False, ticks='')
    fig.update_yaxes(showticklabels=False, ticks='')

    return fig

# ──────────────────────────────────────────────────────────────────────────────
# create & save
# ──────────────────────────────────────────────────────────────────────────────
import plotly.io as pio
# Turn off MathJax for kaleido's static export. With it off, `$...$` LaTeX
# strings are presumably left untypeset in exported images; verify against
# the installed plotly/kaleido versions.
# NOTE(review): newer plotly releases deprecate `pio.kaleido.scope` in favour
# of `pio.defaults`; confirm this assignment still takes effect.
pio.kaleido.scope.mathjax = None
# pio.renderers.default = "kaleido"

# Module-level script: build the 2 × 3 coset figure and write it to a PDF
# in the current working directory.
fig = make_full_plot()
file_path = "appendix_coset_plots.pdf"
fig.write_image(file_path, format="pdf")    # format passed explicitly (the ".pdf" extension alone would also select PDF)

print(f"Saved static pdf to {file_path}")

def cosine_plot_with_custom_colors():
    """Plot cos(2*pi*f*a/6) for f = 1, 2, 3 in three side-by-side subplots.

    Each subplot shows the continuous curve for a in [0, 6] in black, plus
    markers at the integer samples a = 0..5. The markers are coloured per
    panel:

    * f=1: six distinct shades sampled from the Plasma_r colourscale;
    * f=2: by the classes {0,3}, {1,4}, {2,5} (purple / teal / orange);
    * f=3: by the classes {0,2,4}, {1,3,5} (teal / orange).

    Titles are plain text. This script disables MathJax before static export,
    so ``$...$`` LaTeX titles would not be typeset in the PDF.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    freqs = [1, 2, 3]
    fig = make_subplots(
        rows=1, cols=len(freqs),
        subplot_titles=[f"f={f}" for f in freqs],
        horizontal_spacing=0.1
    )

    a_vals = np.linspace(0, 6, 300)   # dense grid for the continuous curve
    a_ints = np.arange(0, 6)          # integer sample points 0..5

    panel2_colors = {
        0: "purple", 3: "purple",
        1: "teal",   4: "teal",
        2: "orange", 5: "orange"
    }
    panel3_colors = {
        0: "teal", 2: "teal", 4: "teal",
        1: "orange", 3: "orange", 5: "orange"
    }
    # marker colours keyed by 1-based panel index; one colour per a in a_ints
    point_colors_by_panel = {
        1: px.colors.sample_colorscale("Plasma_r", [i/5 for i in range(6)]),
        2: [panel2_colors[a] for a in a_ints],
        3: [panel3_colors[a] for a in a_ints],
    }

    for i, f in enumerate(freqs, start=1):
        y_vals = np.cos((f * 2 * np.pi * a_vals) / 6)
        y_ints = np.cos((f * 2 * np.pi * a_ints) / 6)

        # continuous cosine curve
        fig.add_trace(go.Scatter(x=a_vals, y=y_vals, mode="lines",
                                 line=dict(color="black", width=2),
                                 showlegend=False),
                      row=1, col=i)

        # coloured integer points
        fig.add_trace(go.Scatter(
            x=a_ints,
            y=y_ints,
            mode="markers",
            marker=dict(size=10, color=point_colors_by_panel[i]),
            showlegend=False
        ), row=1, col=i)

        fig.update_xaxes(title_text="a", dtick=1, range=[-0.2, 6.2], row=1, col=i)
        fig.update_yaxes(title_text="cos value", range=[-1.1, 1.1], row=1, col=i)

    fig.update_layout(
        title="Cosines for f=1,2,3 on C₆",
        height=600, width=1200,
        plot_bgcolor="white", paper_bgcolor="white",
        margin=dict(t=50, l=40, r=40, b=50),
        font=dict(size=20),
        title_font_size=20
    )
    # subplot titles are layout annotations; size them explicitly as well
    fig.update_annotations(font_size=20)

    return fig

# Export to PDF
# Module-level script: build the three-panel cosine figure and write it to
# "appendix_cosines.pdf" in the current working directory.
cos_fig = cosine_plot_with_custom_colors()
cos_fig.write_image("appendix_cosines.pdf", format="pdf")
print("Saved cosine plot with custom coloring to 'appendix_cosines.pdf'")
