import numpy as np
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio

# Module-level side effect (runs on import): disable MathJax in Kaleido's
# static-export scope. NOTE(review): presumably this keeps Kaleido from
# stamping a "Loading [MathJax]..." banner into the PDF written in __main__;
# setting it to False is reportedly equivalent — confirm. Also verify this
# attribute still exists on the installed Kaleido/Plotly version.
pio.kaleido.scope.mathjax = None

def _cyclic_distance(k, n):
    """Return how many steps residue class ``k`` is from class 0 on an ``n``-cycle.

    For ``n == 6`` this gives 0, 1, 2, 3, 2, 1 for ``k = 0..5``.
    """
    return min(k % n, (-k) % n)


def _coset_colors(n_cosets, distances):
    """Return one colour per coset, shaded by its cyclic distance from coset 0.

    Distances 0 and 1 are sampled from reversed Viridis at ``(d / 3) * 0.8``
    (0.0 and ~0.267). Distance 2 or more uses the darkest Viridis_r entry,
    so only three distinct colours appear for six cosets.
    """
    darkest = px.colors.sequential.Viridis_r[-1]
    colors = []
    for k in range(n_cosets):
        d = distances[k]
        if d >= 2:
            colors.append(darkest)
        else:
            colors.append(
                px.colors.sample_colorscale('Viridis_r', [(d / 3) * 0.8])[0]
            )
    return colors


def _add_unit_circle(fig):
    """Draw the large black unit circle used as the background."""
    theta = np.linspace(0, 2 * np.pi, 400)
    fig.add_trace(
        go.Scatter(
            x=np.cos(theta), y=np.sin(theta),
            mode='lines',
            line=dict(color='black', width=2),
            showlegend=False,
        )
    )


def _add_coset_circles(fig, centers, coset_colors, points_per_coset, small_radius):
    """Add, for every coset, its small circle, member points, centre dot and dashed radii.

    For each coset ``k`` the traces are added in this order:
      1. the outline of a circle of radius ``small_radius`` around ``centers[k]``;
      2. ``points_per_coset`` equally spaced markers on that outline;
      3. a larger marker at the centre (the point the whole coset maps to);
      4. one legend-visible trace of dashed segments from the centre to each marker.
    """
    # These angle arrays are the same for every coset, so build them once.
    outline_theta = np.linspace(0, 2 * np.pi, 200)
    point_angles = 2 * np.pi * np.arange(points_per_coset) / points_per_coset

    for k, (cx, cy) in enumerate(centers):
        color = coset_colors[k]

        fig.add_trace(
            go.Scatter(
                x=cx + small_radius * np.cos(outline_theta),
                y=cy + small_radius * np.sin(outline_theta),
                mode='lines',
                line=dict(color=color, width=1),
                showlegend=False,
            )
        )

        x_pts = cx + small_radius * np.cos(point_angles)
        y_pts = cy + small_radius * np.sin(point_angles)
        fig.add_trace(
            go.Scatter(
                x=x_pts, y=y_pts, mode='markers',
                marker=dict(color=color, size=9),
                showlegend=False,
            )
        )

        fig.add_trace(
            go.Scatter(
                x=[cx], y=[cy], mode='markers',
                marker=dict(color=color, size=12, symbol='circle'),
                showlegend=False,
            )
        )

        # A None between segments makes Plotly break the line. This lets
        # all radii of one coset share a single trace and one legend entry.
        x_lines = [v for x_pt in x_pts for v in (cx, x_pt, None)]
        y_lines = [v for y_pt in y_pts for v in (cy, y_pt, None)]
        fig.add_trace(
            go.Scatter(
                x=x_lines, y=y_lines, mode='lines',
                line=dict(color=color, width=2, dash='dash'),
                name=f"mod 6 = {k}",
                showlegend=True,
            )
        )


def _add_distance_edges(fig, centers, coset_colors, distances):
    """Draw dotted edges from coset 0's centre to every other centre, labelled by distance.

    Edges are added grouped by cyclic distance (1, then 2, then 3, ...). Within a
    group, targets are in ascending coset order. Each edge and its text label use
    the target coset's colour.
    """
    # Hand-tuned label positions (fraction of the way from coset 0 to the
    # target centre). Targets not listed use the midpoint.
    label_fraction = {1: 0.42, 5: 0.58, 2: 0.60}
    x0, y0 = centers[0]

    for dist in range(1, max(distances) + 1):
        targets = [k for k, d in enumerate(distances) if d == dist]
        for target in targets:
            xt, yt = centers[target]
            color = coset_colors[target]
            fig.add_trace(go.Scatter(
                x=[x0, xt], y=[y0, yt],
                mode='lines',
                line=dict(color=color, width=2, dash='dot'),
                opacity=0.8,
                showlegend=False,
            ))
            p = label_fraction.get(target, 0.50)
            fig.add_trace(go.Scatter(
                x=[x0 + p * (xt - x0)], y=[y0 + p * (yt - y0)],
                mode='text',
                text=[str(dist)],
                textposition='top center',
                textfont=dict(color=color, size=24),
                showlegend=False,
            ))


def _style_figure(fig, modulus, points_per_coset):
    """Apply the square, dotted-grid axes and the transparent top-right legend."""
    axis_style = dict(
        showgrid=True,
        gridcolor='rgba(128,128,128,0.8)',
        gridwidth=1,
        griddash='dot',
        zeroline=False,
    )
    # Titles describe the map a -> angle points_per_coset * 2π * a / modulus.
    fig.update_xaxes(
        **axis_style,
        range=[-1.45, 1.45],
        scaleanchor='y',  # lock the aspect ratio so circles stay round
        scaleratio=1,
        title_text=f"cos({points_per_coset}(2π)a/{modulus})",
    )
    fig.update_yaxes(
        **axis_style,
        range=[-1.45, 1.45],
        title_text=f"sin({points_per_coset}(2π)a/{modulus})",
    )
    fig.update_layout(
        plot_bgcolor='white',
        paper_bgcolor='white',
        width=600,
        height=600,
        margin=dict(l=10, r=10, t=30, b=10),
        # Legend anchored by its top-right corner at 98% / 94% of the paper.
        legend=dict(
            xref='paper',
            x=0.98,
            xanchor='right',
            y=0.94,
            yanchor='top',
            bgcolor='rgba(0,0,0,0)',
            bordercolor='rgba(0,0,0,0)',
        ),
        legend_title_text="Coset class",
    )


def _add_coset_zero_label(fig, color, angle_deg=20, radius=1.23):
    """Place a large '0' just outside the unit circle, near coset 0's centre."""
    angle = np.deg2rad(angle_deg)
    fig.add_trace(
        go.Scatter(
            x=[radius * np.cos(angle)], y=[radius * np.sin(angle)],
            mode='text',
            text=['0'],
            textfont=dict(color=color, size=24),
            textposition='middle center',
            showlegend=False,
        )
    )


def get_coset_plot():
    """Build a figure of the residues mod 66 grouped into their 6 cosets mod 6.

    Coset k's centre sits at angle 2πk/6 on the unit circle. Its 11 members
    are drawn on a small circle of radius 1/3 around that centre. Dashed radii
    join the centre to each member. Dotted edges join coset 0's centre to every
    other centre and are labelled with the cyclic distance.

    Colours come from reversed Viridis:
      - coset 0 is brightest;
      - cosets 1 and 5 are mid-bright;
      - cosets 2, 3 and 4 share the darkest shade.

    Returns:
        plotly.graph_objects.Figure: the styled 600x600 figure.
    """
    modulus = 66
    n_cosets = 6
    points_per_coset = modulus // n_cosets  # = 11
    small_radius = 1 / 3

    distances = [_cyclic_distance(k, n_cosets) for k in range(n_cosets)]
    coset_colors = _coset_colors(n_cosets, distances)
    centers = [
        (np.cos(2 * np.pi * k / n_cosets), np.sin(2 * np.pi * k / n_cosets))
        for k in range(n_cosets)
    ]

    fig = go.Figure()
    # Trace order matters: later traces are drawn on top and appear later in the legend.
    _add_unit_circle(fig)
    _add_coset_circles(fig, centers, coset_colors, points_per_coset, small_radius)
    _add_distance_edges(fig, centers, coset_colors, distances)
    _style_figure(fig, modulus, points_per_coset)
    _add_coset_zero_label(fig, coset_colors[0])
    return fig


if __name__ == "__main__":
    # Script entry point: build the coset figure and export it as a vector PDF.
    coset_figure = get_coset_plot()
    coset_figure.write_image(
        "cosets_on_circle.pdf",
        format="pdf",
        engine="kaleido",
    )
