import json
import math
import os
import re

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Base directories to process: one per (width multiplier nn, attention
# coefficient). attn-co=0.0 runs are labelled "Pizza" and attn-co=1.0 runs
# "Clock" in the figures below.
BASE_DIRS = [
    "/logn/neurips_2025/transformer_r2_heatmap_k=50_2000/59_50_nn_4_fits_attn-co=0.0_top-k",
    "/logn/neurips_2025/transformer_r2_heatmap_k=50_2000/59_50_nn_4_fits_attn-co=1.0_top-k",
    "/logn/neurips_2025/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=0.0_top-k",
    "/logn/neurips_2025/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=1.0_top-k",
    "/logn/neurips_2025/transformer_r2_heatmap_k=50_2000/59_50_nn_16_fits_attn-co=0.0_top-k",
    "/logn/neurips_2025/transformer_r2_heatmap_k=50_2000/59_50_nn_16_fits_attn-co=1.0_top-k"
]

# Clock (attn-co=1.0) runs with 1..4 layers, in depth order.
base_dirs_clock = [
    "/logn/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=1.0_top-k_layers=1",
    "/logn/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=1.0_top-k_layers=2",
    "/logn/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=1.0_top-k_layers=3",
    "/logn/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=1.0_top-k_layers=4"
]

# Pizza (attn-co=0.0) runs with 1..4 layers, in depth order. Every entry uses
# the same "/logn" root as base_dirs_clock.
# NOTE(review): confirm the layers=2..4 directories live under /logn on disk.
base_dirs_pizza = [
    "/logn/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=0.0_top-k_layers=1",
    "/logn/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=0.0_top-k_layers=2",
    "/logn/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=0.0_top-k_layers=3",
    "/logn/transformer_r2_heatmap_k=50_2000/59_50_nn_8_fits_attn-co=0.0_top-k_layers=4"
]


# Regex patterns used by extract_results().
# Learning rate and weight decay are parsed from each run *directory* name.
# The layer index is parsed from each summary JSON *file* name.
LR_PATTERN    = re.compile(r"lr=([^_]+)")
WD_PATTERN    = re.compile(r"wd=([^_]+)")
LAYER_PATTERN = re.compile(r"all_preactivations_L(\d+)_seed")

def collapse_duplicate_labels(vals, to_label_fn):
    """
    Build tick labels in which each run of consecutive identical labels is
    shown only once.

    vals: sequence of values (e.g. weight decays) to label.
    to_label_fn: callable mapping one value to its label.
    Returns: a list the same length as ``vals``. Inside every maximal run of
             equal consecutive labels, only the middle position (the left one
             of the two middles for even-length runs) keeps its label; all
             other positions hold "".
    """
    labels = [to_label_fn(v) for v in vals]
    ticktext = [""] * len(labels)
    run_start = 0
    for pos in range(1, len(labels) + 1):
        # A run ends at the end of the list or where the label changes.
        if pos == len(labels) or labels[pos] != labels[run_start]:
            middle = run_start + (pos - run_start - 1) // 2
            ticktext[middle] = labels[middle]
            run_start = pos
    return ticktext


def _nanmean_or_nan(values):
    """Return the NaN-ignoring mean of ``values`` as a float, or NaN if none remain.

    Same result as ``float(np.nanmean(values))``. Unlike that call, it does not
    emit numpy's "Mean of empty slice" RuntimeWarning when ``values`` is empty
    or all-NaN. ``None`` entries are treated as NaN.
    """
    arr = np.asarray(values, dtype=float)
    kept = arr[~np.isnan(arr)]
    return float(kept.mean()) if kept.size else np.nan


def extract_results(base_dir):
    """
    Aggregate per-layer summary statistics for every (lr, wd) run under base_dir.

    Run discovery:
        - Each immediate subdirectory of ``base_dir`` whose name contains
          ``lr=<value>`` and ``wd=<value>`` is one run.
        - Inside a run, every ``*.json`` file whose name matches
          ``all_preactivations_L<layer>_seed`` is loaded as a summary.

    Summary contents used:
        - Dict-valued entries carrying ``count`` and ``avg_r2`` are neuron-pair
          fits.
        - The top-level ``test_accuracy`` and ``injected_test_accuracy`` values
          are accuracies (NaN when absent).

    Aggregation:
        - Per file, pair R² values are averaged weighted by ``count``.
        - Per (layer, lr, wd), the file-level R² and the accuracies are averaged
          across files, ignoring NaN.

    Returns:
        dict mapping layer (int) -> list of dicts with keys
        'lr', 'wd', 'avg_r2', 'perc_sin_cos', 'avg_test_acc', 'avg_inj_acc'.
        'perc_sin_cos' is the 'sin_cos' fit count divided by the combined
        'sin_cos' + 'conjecture' count, or NaN when neither type occurs.
    """
    temp = {}  # (layer, lr, wd) -> accumulators

    # sorted() makes the order of the returned entries reproducible;
    # os.listdir order is arbitrary.
    for run_name in sorted(os.listdir(base_dir)):
        run_path = os.path.join(base_dir, run_name)
        if not os.path.isdir(run_path):
            continue

        # Hyperparameters are encoded in the run directory name.
        m_lr = LR_PATTERN.search(run_name)
        m_wd = WD_PATTERN.search(run_name)
        if not m_lr or not m_wd:
            continue
        lr = float(m_lr.group(1))
        wd = float(m_wd.group(1))

        for fname in sorted(os.listdir(run_path)):
            if not fname.endswith('.json'):
                continue
            m_layer = LAYER_PATTERN.search(fname)
            if not m_layer:
                continue
            layer = int(m_layer.group(1))

            with open(os.path.join(run_path, fname), 'r', encoding='utf-8') as f:
                summary = json.load(f)

            entry = temp.setdefault((layer, lr, wd), {
                'r2_list': [],
                'test_accs': [],
                'inj_accs': [],
                'sin_cos_count': 0,
                'total_count': 0,
            })

            # --- Count-weighted mean R² over the neuron-pair fits in this file ---
            valid = [v for v in summary.values()
                     if isinstance(v, dict) and 'count' in v and 'avg_r2' in v]
            total_cnt = sum(v['count'] for v in valid)
            if total_cnt > 0:
                avg_r2 = sum(v['avg_r2'] * v['count'] for v in valid) / total_cnt
            else:
                avg_r2 = np.nan
            entry['r2_list'].append(avg_r2)

            # --- Accuracies (NaN when the key is missing) ---
            entry['test_accs'].append(summary.get('test_accuracy', np.nan))
            entry['inj_accs'].append(summary.get('injected_test_accuracy', np.nan))

            # --- sin_cos vs conjecture counts for the % sin_cos heatmap ---
            # Fits without a 'fit_type' key count toward neither type instead
            # of raising KeyError.
            sin_cnt = sum(v['count'] for v in valid if v.get('fit_type') == 'sin_cos')
            conj_cnt = sum(v['count'] for v in valid if v.get('fit_type') == 'conjecture')
            entry['sin_cos_count'] += sin_cnt
            entry['total_count'] += sin_cnt + conj_cnt

    # Collapse the accumulators into the final per-layer records.
    results = {}
    for (layer, lr, wd), d in temp.items():
        perc_sin = (d['sin_cos_count'] / d['total_count']
                    if d['total_count'] > 0 else np.nan)
        results.setdefault(layer, []).append({
            'lr': lr,
            'wd': wd,
            'avg_r2': _nanmean_or_nan(d['r2_list']),
            'perc_sin_cos': perc_sin,
            'avg_test_acc': _nanmean_or_nan(d['test_accs']),
            'avg_inj_acc': _nanmean_or_nan(d['inj_accs']),
        })
    return results

# Aggregated results for every entry of BASE_DIRS, keyed by the directory's
# basename (e.g. "59_50_nn_4_fits_attn-co=0.0_top-k").
all_results = dict(zip(map(os.path.basename, BASE_DIRS),
                       map(extract_results, BASE_DIRS)))

# Overlay (symbol, colour) per accuracy outcome:
#   'both'      -> averaged test and injected accuracy both above threshold
#   'test_only' -> only the averaged test accuracy above threshold
#   'neither'   -> anything else
symbol_map = dict(
    both=('✓', 'green'),
    test_only=('●', 'magenta'),
    neither=('✗', 'red'),
)

# Averaged accuracy above this value is treated as "100%".
threshold = 0.999


def to_latex_exp(val):
    """Return an HTML power-of-ten tick label for a positive ``val``.

    The exponent is ``log10(val)`` rounded to the nearest integer, e.g.
    ``0.001 -> '10<sup>-3</sup>'``; values that are not exact powers of ten get
    the nearest power's label.
    """
    exponent = int(round(math.log10(val)))
    return f"10<sup>{exponent}</sup>"


# Plot and save, per base directory:
#   - one R² figure with every layer side by side
#   - one "% sin_cos" figure per layer
for name, layer_dict in all_results.items():
    layers = sorted(layer_dict.keys())
    if not layers:
        print(f"Warning: no layers found for {name}, skipping...")
        continue

    # Union of hyperparameter values over all layers -> shared heatmap grid.
    all_lrs = sorted({e['lr'] for entries in layer_dict.values() for e in entries})
    all_wds = sorted({e['wd'] for entries in layer_dict.values() for e in entries})

    # Tick labels depend only on the grid, so build them once per directory.
    lr_ticktext = [to_latex_exp(lr) for lr in all_lrs]
    wd_ticktext = [to_latex_exp(wd) for wd in all_wds]
    wd_ticktext_collapsed = collapse_duplicate_labels(all_wds, to_latex_exp)

    # --- R² heatmaps with overlaid accuracy symbols ---
    fig_r2 = make_subplots(
        rows=1, cols=len(layers),
        subplot_titles=[f"Layer {l} R²" for l in layers]
    )
    for idx, layer in enumerate(layers, start=1):
        # Rows are weight decays, columns are learning rates.
        mat = np.full((len(all_wds), len(all_lrs)), np.nan)
        for e in layer_dict[layer]:
            mat[all_wds.index(e['wd']), all_lrs.index(e['lr'])] = e['avg_r2']

        fig_r2.add_trace(
            go.Heatmap(
                x=all_lrs, y=all_wds, z=mat,
                coloraxis='coloraxis'
            ),
            row=1, col=idx
        )

        # One text marker per (lr, wd) cell; NaN accuracies compare False.
        for e in layer_dict[layer]:
            perfect_test = e['avg_test_acc'] > threshold
            perfect_inj = e['avg_inj_acc'] > threshold
            if perfect_test and perfect_inj:
                key = 'both'
            elif perfect_test:
                key = 'test_only'
            else:
                key = 'neither'
            sym, sym_color = symbol_map[key]

            fig_r2.add_trace(
                go.Scatter(
                    x=[e['lr']], y=[e['wd']],
                    mode='text',
                    text=[sym],
                    textfont=dict(color=sym_color, size=14),
                    showlegend=False
                ),
                row=1, col=idx
            )

        # Both axes log-scale with power-of-ten tick labels.
        fig_r2.update_xaxes(
            title_text='Learning rate (step size)',
            type='log',
            tickangle=90,
            tickvals=all_lrs,
            ticktext=lr_ticktext,
            tickfont=dict(family="Arial", size=11),
            row=1, col=idx
        )
        fig_r2.update_yaxes(
            title_text='L2 regularization',
            type='log',
            tickvals=all_wds,
            ticktext=wd_ticktext,
            tickfont=dict(family="Arial", size=11),
            row=1, col=idx
        )

    out_r2 = f"{name}_r2_heatmaps.html"
    fig_r2.write_html(out_r2)
    print(f"Saved R² heatmaps → {out_r2}")

    # --- Percent sin_cos heatmaps (one figure per layer) ---
    for layer in layers:
        percs = [e['perc_sin_cos'] for e in layer_dict[layer]]
        if all(np.isnan(percs)):
            continue

        mat = np.full((len(all_wds), len(all_lrs)), np.nan)
        for e in layer_dict[layer]:
            mat[all_wds.index(e['wd']), all_lrs.index(e['lr'])] = e['perc_sin_cos']

        fig_pct = go.Figure(
            go.Heatmap(
                x=all_lrs,
                y=all_wds,
                z=mat,
                colorbar=dict(title='% sin_cos')
            )
        )
        fig_pct.update_layout(
            title_text=f'{name} Layer {layer} % Neurons Fit by sin_cos',
            paper_bgcolor='white', plot_bgcolor='white'
        )

        # Log scales with power-of-ten ticks; repeated y labels are collapsed.
        fig_pct.update_xaxes(
            title_text='Learning rate (step size)',
            type='log',
            tickvals=all_lrs,
            ticktext=lr_ticktext,
            tickangle=90,
            tickfont=dict(family="Arial", size=11)
        )
        fig_pct.update_yaxes(
            title_text='L2 regularization',
            type='log',
            tickvals=all_wds,
            ticktext=wd_ticktext_collapsed,
            tickfont=dict(family="Arial", size=11)
        )

        out_pct = f"{name}_layer{layer}_perc_sin_cos.html"
        fig_pct.write_html(out_pct)
        print(f"Saved % sin_cos heatmap → {out_pct}")



# ─── Figure 5: R² heatmaps, pizza (row 1) vs clock (row 2) ───
# Group base dirs by attention coefficient
dirs_co0 = [d for d in BASE_DIRS if 'attn-co=0.0' in d]
dirs_co1 = [d for d in BASE_DIRS if 'attn-co=1.0' in d]

# Sort each group by width multiplier nn (4, 8, 16); the title shows nn * 128
# as the neuron count.
nn_pattern = re.compile(r'_nn_(\d+)_fits')
dirs_co0.sort(key=lambda d: int(nn_pattern.search(d).group(1)))
dirs_co1.sort(key=lambda d: int(nn_pattern.search(d).group(1)))

# First two layers of the first pizza run (assumed common across dirs)
example_name = os.path.basename(dirs_co0[0])
layers = sorted(all_results[example_name].keys())[:2]  # [layer1, layer2]

# Subplot titles in row-major order. Each row (pizza, then clock) lists every
# width for the first layer, then every width for the second layer.
titles = []
for co, dirs in zip([0.0, 1.0], [dirs_co0, dirs_co1]):
    prefix = "Pizza: " if co < 0.1 else "Clock: "
    for layer in layers:
        for d in dirs:
            nn = nn_pattern.search(d).group(1)
            neurons = str(int(nn) * 128)
            titles.append(f"{prefix}#neurons={neurons}, Layer {layer}")

# Create 2x6 subplot figure
grid_kwargs = dict(rows=2, cols=6, subplot_titles=titles,
                   horizontal_spacing=0.02, vertical_spacing=0.1)
fig5 = make_subplots(**grid_kwargs)

# Apply a common colorscale and white background
fig5.update_layout(
    coloraxis=dict(colorscale='Viridis'),
    paper_bgcolor='white',
    plot_bgcolor='white',
    title_text='Figure 5: R² Heatmaps Across Models and Layers'
)

# Accuracy symbols; drawn on the smallest model of each row (and on larger
# models where R² is unchanged or has no counterpart in the previous model).
sym_map = {'both': '✓', 'test_only': '●', 'neither': '✗'}
threshold = 0.999

# Populate subplots: within a row, the first len(dirs) columns hold layer 1 and
# the next len(dirs) hold layer 2.
for row, dirs in enumerate([dirs_co0, dirs_co1], start=1):
    for layer_idx, layer in enumerate(layers):
        for group_idx, d in enumerate(dirs):
            col = layer_idx * len(dirs) + group_idx + 1
            name = os.path.basename(d)
            entries = all_results[name][layer]

            # R² matrix: rows are weight decays, columns are learning rates.
            lrs = sorted({e['lr'] for e in entries})
            wds = sorted({e['wd'] for e in entries})
            mat = np.full((len(wds), len(lrs)), np.nan)
            for e in entries:
                mat[wds.index(e['wd']), lrs.index(e['lr'])] = e['avg_r2']

            fig5.add_trace(
                go.Heatmap(x=lrs, y=wds, z=mat, coloraxis='coloraxis'),
                row=row, col=col
            )

            # The first (smallest) model in each row is the baseline. Every
            # larger model is compared with the next smaller one in the same
            # row, looked up by (lr, wd). Deciding by position rather than by a
            # hard-coded neuron count keeps dirs[group_idx - 1] from wrapping
            # to the last (largest) model.
            if group_idx == 0:
                prev_by_hp = None
            else:
                prev_name = os.path.basename(dirs[group_idx - 1])
                prev_by_hp = {(pe['lr'], pe['wd']): pe
                              for pe in all_results[prev_name][layer]}

            for e in entries:
                at, ai = e['avg_test_acc'], e['avg_inj_acc']
                # Performance key and symbol colour (NaN compares False).
                if at > threshold and ai > threshold:
                    key, color = 'both', 'green'
                elif at > threshold:
                    key, color = 'test_only', 'purple'
                else:
                    key, color = 'neither', 'red'

                sym = sym_map[key]
                prev_e = (prev_by_hp.get((e['lr'], e['wd']))
                          if prev_by_hp is not None else None)
                if prev_e is not None:
                    # '>' / '<' mark higher / lower R² than the smaller model;
                    # equal or NaN comparisons keep the accuracy symbol.
                    if e['avg_r2'] > prev_e['avg_r2']:
                        sym = '>'
                    elif e['avg_r2'] < prev_e['avg_r2']:
                        sym = '<'

                fig5.add_trace(
                    go.Scatter(
                        x=[e['lr']], y=[e['wd']],
                        mode='text', text=[sym],
                        textfont=dict(color=color, size=14),
                        showlegend=False
                    ),
                    row=row, col=col
                )

            # x-axis: log scale with power-of-ten ticks
            fig5.update_xaxes(
                title_text='Learning rate', type='log',
                tickvals=lrs, ticktext=[to_latex_exp(v) for v in lrs],
                tickangle=90, tickfont=dict(family='Arial', size=11),
                row=row, col=col
            )

            # y-axis: same formatting everywhere; title only in the first column
            y_kwargs = dict(
                type='log',
                tickvals=wds, ticktext=[to_latex_exp(v) for v in wds],
                tickfont=dict(family='Arial', size=11),
            )
            if col == 1:
                y_kwargs['title_text'] = 'L2 regularization'
            fig5.update_yaxes(row=row, col=col, **y_kwargs)

# Write out the combined html
fig5.write_html('figure_5.html')
print('Saved Figure 5 → figure_5.html')

# --- 6) COMBINED FIGURE FOR MLPs (Pizza & Clock) WITH CUSTOM POSITIONS ---
from plotly.subplots import make_subplots

def make_figure6(base_dirs, output_filename, main_title):
    """
    Plot R² heatmaps for models of increasing depth in one 3×4 grid and save it.

    Args:
        base_dirs: up to four run directories, each passed to extract_results().
            The i-th directory (1-based) is drawn in the cells listed in
            ``layout_map[i]``, one cell per layer in ascending layer order. The
            list is therefore expected in depth order (1-layer model first).
        output_filename: path of the HTML file to write.
        main_title: figure title.

    All panels share one (wd, lr) grid (the union over all runs) and one colour
    axis. Each cell is overlaid with a ``symbol_map`` symbol:
        - 'both' when test and injected accuracy exceed 0.999
        - 'test_only' when only test accuracy does
        - 'neither' otherwise
    Row 1, column 2 stays empty and holds a hand-placed legend whose colours
    are read from ``symbol_map``, so they match the overlays.

    Raises:
        ValueError: if more than four directories are given, or a run has more
            layers than its grid row has cells for.
    """
    # Grid placement: run index -> (row, [one column per layer]).
    layout_map = {
        1: (1, [1]),
        2: (1, [3, 4]),
        3: (2, [1, 2, 3]),
        4: (3, [1, 2, 3, 4])
    }
    n_rows, n_cols = 3, 4
    acc_threshold = 0.999
    if len(base_dirs) > len(layout_map):
        raise ValueError(
            f"make_figure6 supports at most {len(layout_map)} base dirs, "
            f"got {len(base_dirs)}"
        )

    # 1) Load every run; collect the union of hyperparameters so all panels
    #    share the same grid.
    run_results = []
    global_lrs = set()
    global_wds = set()
    for bd in base_dirs:
        name = os.path.basename(bd.rstrip("/"))
        layer_dict = extract_results(bd)
        run_results.append((name, layer_dict))
        for entries in layer_dict.values():
            global_lrs |= {e['lr'] for e in entries}
            global_wds |= {e['wd'] for e in entries}
    global_lrs = sorted(global_lrs)
    global_wds = sorted(global_wds)

    # 2) Subplot titles in row-major order ("" for unused cells).
    titles = [""] * (n_rows * n_cols)
    for idx, (name, layer_dict) in enumerate(run_results, start=1):
        row, cols = layout_map[idx]
        layer_ids = sorted(layer_dict)
        if len(layer_ids) > len(cols):
            raise ValueError(
                f"{name}: {len(layer_ids)} layers but only {len(cols)} grid "
                f"cells in row {row}"
            )
        for L, col in zip(layer_ids, cols):
            titles[(row - 1) * n_cols + (col - 1)] = (
                f"Layer {L} out of {len(layer_ids)} R²"
            )

    # 3) Create the figure
    fig = make_subplots(
        rows=n_rows, cols=n_cols,
        subplot_titles=titles,
        horizontal_spacing=0.05,
        vertical_spacing=0.1125
    )
    fig.for_each_annotation(lambda ann: ann.update(font=dict(size=18)))

    # Power-of-ten tick labels; runs of identical labels are shown once.
    # They depend only on the shared grid, so compute them once.
    def to_latex_exp(val):
        exponent = int(round(math.log10(val)))
        return f"10<sup>{exponent}</sup>"

    ticktext_x = collapse_duplicate_labels(global_lrs, to_latex_exp)
    ticktext_y = collapse_duplicate_labels(global_wds, to_latex_exp)

    # 4) Fill in heatmaps + overlays
    for idx, (_, layer_dict) in enumerate(run_results, start=1):
        row, cols = layout_map[idx]
        for L, col in zip(sorted(layer_dict), cols):
            # R² matrix on the shared grid: rows = wd, columns = lr.
            Z = np.full((len(global_wds), len(global_lrs)), np.nan)
            for e in layer_dict[L]:
                Z[global_wds.index(e['wd']), global_lrs.index(e['lr'])] = e['avg_r2']

            fig.add_trace(
                go.Heatmap(
                    z=Z,
                    x=global_lrs,
                    y=global_wds,
                    coloraxis="coloraxis"
                ),
                row=row, col=col
            )

            # Overlay ✓/●/✗ (NaN accuracies compare False -> 'neither').
            for e in layer_dict[L]:
                at, ai = e['avg_test_acc'], e['avg_inj_acc']
                if at > acc_threshold and ai > acc_threshold:
                    sym, col_sym = symbol_map['both']
                elif at > acc_threshold:
                    sym, col_sym = symbol_map['test_only']
                else:
                    sym, col_sym = symbol_map['neither']
                fig.add_trace(
                    go.Scatter(
                        x=[e['lr']], y=[e['wd']],
                        mode="text", text=[sym],
                        textfont=dict(color=col_sym, size=18),
                        showlegend=False, cliponaxis=True
                    ),
                    row=row, col=col
                )

            # Log-scale axes with the precomputed tick labels.
            fig.update_xaxes(
                type="log", tickvals=global_lrs, ticktext=ticktext_x,
                tickangle=90, tickfont=dict(family="Arial", size=22),
                row=row, col=col
            )
            fig.update_yaxes(
                type="log", tickvals=global_wds, ticktext=ticktext_y,
                tickfont=dict(family="Arial", size=17),
                row=row, col=col
            )

    # 5) Manual legend in the empty slot (row=1, col=2)
    fig.update_xaxes(visible=False, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)
    legend_items = [
        ("Fit and test accuracy both 100% (✓)", symbol_map['both'][1]),
        ("Test 100% and fit accuracy < 100% (●)", symbol_map['test_only'][1]),
        ("Neither accuracy 100% (✗)", symbol_map['neither'][1])
    ]
    x0, y0, dy = 0.24, 0.9525, 0.05
    for i, (label, color) in enumerate(legend_items):
        fig.add_annotation(
            x=x0, y=y0 - i*dy,
            xref="paper", yref="paper",
            text=f"<span style='color:{color}; font-size:19px'>{label}</span>",
            showarrow=False, align="left"
        )

    # 6) Global axis titles + shared colorbar
    fig.update_xaxes(title_text="Learning rate", title_font=dict(size=17), title_standoff=0)
    fig.update_yaxes(title_text="L2 regularization", title_font=dict(size=17), title_standoff=2)
    fig.update_layout(
        width=1600, height=900,
        paper_bgcolor='white', plot_bgcolor='white',
        font=dict(size=18), title_font_size=24,
        title_text=main_title,
        coloraxis=dict(
            colorscale='plasma',
            colorbar=dict(
                title=dict(text="R²", font=dict(size=18)),
                tickfont=dict(size=18),
                x=0.87375, xanchor="left",
                y=0.6292,  yanchor="top",
                len=0.2583, thickness=40
            )
        )
    )

    fig.write_html(output_filename)
    print(f"✅ Saved combined figure → {output_filename}")


# Render the depth-comparison grid once for the pizza runs and once for the
# clock runs.
make_figure6(
    base_dirs=base_dirs_pizza,
    output_filename="figure_6_pizza.html",
    main_title="Pizza-transformer: R² vs depth (1–4 layers)",
)
make_figure6(
    base_dirs=base_dirs_clock,
    output_filename="figure_6_clock.html",
    main_title="Clock-transformer: R² vs depth (1–4 layers)",
)

# ─── Figure 6 Final: legend in 11th cell (row=2,col=4) ───

def ordinal_suffix(n):
    """Return the English ordinal suffix ('st', 'nd', 'rd' or 'th') for n.

    Numbers whose last two digits fall in 10..20 always take 'th'
    (11th, 12th, 13th); otherwise the last digit decides.
    """
    last_two = n % 100
    if 10 <= last_two <= 20:
        return 'th'
    return {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')

# ─── 1) Load every Pizza/Clock depth run and build the shared grid ───
all_runs = [("Pizza", base_dirs_pizza), ("Clock", base_dirs_clock)]

# (tag, run basename, per-layer results) in pizza-then-clock, depth order.
run_results = []
for tag, dirs in all_runs:
    run_results.extend(
        (tag, os.path.basename(bd.rstrip("/")), extract_results(bd))
        for bd in dirs
    )

# Union of hyperparameter values across every run and layer, sorted.
global_lrs, global_wds = set(), set()
for _, _, layer_dict in run_results:
    for entries in layer_dict.values():
        global_lrs.update(e['lr'] for e in entries)
        global_wds.update(e['wd'] for e in entries)
global_lrs, global_wds = sorted(global_lrs), sorted(global_wds)


def _pow10_ticks(values):
    """Return (tickvals, ticktext) at each power of ten nearest to some value."""
    exps = sorted({int(round(math.log10(v))) for v in values})
    return [10**e for e in exps], [f"10<sup>{e}</sup>" for e in exps]


# Ticks only at exact powers of ten.
tickvals_x, ticktext_x = _pow10_ticks(global_lrs)
tickvals_y, ticktext_y = _pow10_ticks(global_wds)

# ─── 2) One panel per (run, layer); 20 heatmaps for the current dirs ───
panels = [(tag, name, layer_dict, L)
          for tag, name, layer_dict in run_results
          for L in sorted(layer_dict)]

# ─── 3) Grid positions (skipping the legend cell) and subplot titles ───
rows, cols = 3, 7
legend_cell = (2, 4)
position_list = [(r, c)
                 for r in range(1, rows + 1)
                 for c in range(1, cols + 1)
                 if (r, c) != legend_cell]
subplot_titles = [""] * (rows * cols)
for (tag, _, layer_dict, L), (r, c) in zip(panels, position_list):
    # e.g. "Pizza: 2<sup>nd</sup>/3" = 2nd layer of a 3-layer model.
    subplot_titles[(r - 1) * cols + (c - 1)] = (
        f"{tag}: {L}<sup>{ordinal_suffix(L)}</sup>/{len(layer_dict)}"
    )

# ─── 4) Create the figure; enlarge titles and nudge them down slightly ───
fig = make_subplots(
    rows=rows, cols=cols,
    subplot_titles=subplot_titles,
    horizontal_spacing=0.005,
    vertical_spacing=0.05
)
for ann in fig.layout.annotations:
    ann.font.size = 24
    ann.y -= 0.01

# 5) Fill panels with heatmaps, accuracy overlays and axis formatting.
#    Tick labels and axis titles appear only on the bottom row (x) and the
#    first column (y); inner panels keep log axes with hidden tick labels.
for (tag, name, layer_dict, L), (row, col) in zip(panels, position_list):
    entries = layer_dict[L]
    Z = np.full((len(global_wds), len(global_lrs)), np.nan)
    for e in entries:
        i = global_wds.index(e['wd'])
        j = global_lrs.index(e['lr'])
        Z[i, j] = e['avg_r2']

    fig.add_trace(
        go.Heatmap(
            z=Z, x=global_lrs, y=global_wds,
            coloraxis="coloraxis", showscale=False
        ),
        row=row, col=col
    )

    # Uses the module-level `threshold` and `symbol_map` defined earlier in
    # this script; NaN accuracies compare False -> 'neither'.
    for e in entries:
        at, ai = e['avg_test_acc'], e['avg_inj_acc']
        if at > threshold and ai > threshold:
            sym, col_sym = symbol_map['both']
        elif at > threshold:
            sym, col_sym = symbol_map['test_only']
        else:
            sym, col_sym = symbol_map['neither']
        fig.add_trace(
            go.Scatter(
                x=[e['lr']], y=[e['wd']],
                mode="text", text=[sym],
                textfont=dict(color=col_sym, size=14),
                showlegend=False
            ),
            row=row, col=col
        )

    # Axis formatting: only outer row/col show tick labels
    if row == rows:
        fig.update_xaxes(
            title_text='Learning rate', title_font=dict(size=22), title_standoff=10,
            type='log', tickvals=tickvals_x, ticktext=ticktext_x,
            tickangle=90, tickfont=dict(size=24), showticklabels=True,
            row=row, col=col
        )
    else:
        fig.update_xaxes(
            type='log', tickvals=tickvals_x, ticktext=ticktext_x,
            tickangle=90, tickfont=dict(size=22), showticklabels=False,
            row=row, col=col
        )

    if col == 1:
        fig.update_yaxes(
            title_text='L2 regularization', title_font=dict(size=22), title_standoff=10,
            type='log', tickvals=tickvals_y, ticktext=ticktext_y,
            tickfont=dict(size=22), showticklabels=True,
            row=row, col=col
        )
    else:
        fig.update_yaxes(
            type='log', tickvals=tickvals_y, ticktext=ticktext_y,
            tickfont=dict(size=22), showticklabels=False,
            row=row, col=col
        )

# 6) Add visible ticks (black hyphens) on all axes
fig.update_xaxes(
    ticks='outside', ticklen=5, tickcolor='black', tickwidth=1
)
fig.update_yaxes(
    ticks='outside', ticklen=5, tickcolor='black', tickwidth=1
)

# 7) Hide axes in legend slot
fig.update_xaxes(visible=False, row=2, col=4)
fig.update_yaxes(visible=False, row=2, col=4)

# Manual legend. Colours are read from symbol_map so they always match the
# overlay symbols drawn above.
legend_items = [
    ("✓ both fit & test 100%", symbol_map['both'][1]),
    ("● test 100%, fit < 100%", symbol_map['test_only'][1]),
    ("✗ neither 100%", symbol_map['neither'][1]),
]
# Approximate paper coordinates of the centre of cell (row=2, col=4).
x0 = (4 - 0.5) / cols
y0 = (1 - (2 - 0.5) / rows) + 0.05
dy = 0.05
for i, (label, color) in enumerate(legend_items):
    fig.add_annotation(
        x=x0, y=y0 - i*dy,
        xref='paper', yref='paper',
        text=f"<span style='color:{color}; font-size:20px'>{label}</span>",
        showarrow=False, align='left'
    )

# 8) Final layout
fig.update_layout(
    width=1600, height=850,
    paper_bgcolor='white', plot_bgcolor='white',
    title_text='Pizza and Clock R² across training hyperparameters for 1, 2, 3, 4 layer models',
    title_y=0.95,
    font=dict(size=24), title_font_size=24,
    coloraxis=dict(
        colorscale='plasma',
        colorbar=dict(
            title='R²', x=0.995, xanchor='left',
            y=0.5, yanchor='middle', len=0.9, thickness=20,
            tickfont=dict(size=15), title_font=dict(size=22)
        )
    )
)

# 9) Write HTML
fig.write_html('figure_6_final.html')
print("✅ Saved figure_6_final.html with legend in cell 11")


def to_latex_exp(val):
    """Return an HTML power-of-ten tick label, e.g. 0.001 -> '10<sup>-3</sup>'."""
    return f"10<sup>{int(round(math.log10(val)))}</sup>"


# Overlay symbols for Figure 7 (test_only drawn in purple) and the accuracy
# threshold.
symbol_map = {'both': ('✓', 'green'), 'test_only': ('●', 'purple'), 'neither': ('✗', 'red')}
threshold = 0.999

# Figure 7 compares the nn_16 (16 * 128 = 2048 neurons) pizza (attn-co=0.0)
# and clock (attn-co=1.0) runs. They are taken from BASE_DIRS so the figure
# uses exactly the directories already loaded into all_results, instead of a
# second hand-maintained copy of the paths. Sorting puts pizza first, which
# the subplot titles below rely on.
base_dirs_7 = sorted(
    (d for d in BASE_DIRS if '_nn_16_' in d),
    key=lambda d: 'attn-co=1.0' in d,
)

# Reuse the already-extracted results rather than re-walking the directories.
results_7 = [all_results[os.path.basename(d)] for d in base_dirs_7]
# Determine the two layers (common across both)
layers = sorted(results_7[0].keys())[:2]

# Create 1×4 subplot figure: layer 1 (pizza, clock), then layer 2 (pizza, clock)
titles = [f"Layer {L} ({'pizza' if idx_dir==0 else 'clock'})"
          for L in layers for idx_dir in [0,1]]
fig7 = make_subplots(
    rows=1, cols=4,
    subplot_titles=titles,
    horizontal_spacing=0.05,
    vertical_spacing=0.145
)

# Fill subplots
for i, layer in enumerate(layers, start=1):
    for j, layer_dict in enumerate(results_7, start=1):
        # determine position in single row
        row = 1
        col = j + (i - 1) * 2

        entries = layer_dict[layer]
        # Collect hyperparameters
        lrs = sorted({e['lr'] for e in entries})
        wds = sorted({e['wd'] for e in entries})
        # Heatmap matrix of the sin_cos fraction: rows = wd, columns = lr
        mat = np.full((len(wds), len(lrs)), np.nan)
        for e in entries:
            ii = wds.index(e['wd'])
            jj = lrs.index(e['lr'])
            mat[ii, jj] = e['perc_sin_cos']

        # The effective colour range comes from the shared coloraxis
        # (cmin/cmax below); zmin/zmax are kept on the trace as well.
        fig7.add_trace(
            go.Heatmap(
                x=lrs,
                y=wds,
                z=mat,
                zmin=0,
                zmax=1.0,
                coloraxis='coloraxis'
            ),
            row=row, col=col
        )

        # Overlay symbols based on accuracy (NaN/missing -> 'neither')
        for e in entries:
            at, ai = e.get('avg_test_acc', np.nan), e.get('avg_inj_acc', np.nan)
            if at > threshold and ai > threshold:
                key = 'both'
            elif at > threshold:
                key = 'test_only'
            else:
                key = 'neither'
            sym, col_sym = symbol_map[key]
            fig7.add_trace(
                go.Scatter(
                    x=[e['lr']],
                    y=[e['wd']],
                    mode='text',
                    text=[sym],
                    textfont=dict(color=col_sym, size=24),
                    showlegend=False
                ),
                row=row, col=col
            )

        # Log-scale axes with power-of-ten ticks; repeated labels collapsed
        tickvals_x = lrs
        ticktext_x = collapse_duplicate_labels(lrs, to_latex_exp)
        fig7.update_xaxes(
            title_text='Learning rate',
            title_font=dict(size=18),
            type='log',
            tickvals=tickvals_x,
            ticktext=ticktext_x,
            tickangle=90,
            tickfont=dict(family='Arial', size=22),
            row=row, col=col
        )

        tickvals_y = wds
        ticktext_y = collapse_duplicate_labels(wds, to_latex_exp)

        fig7.update_yaxes(
            title_text='L2 regularization' if j == 1 else None,
            title_font=dict(size=18),
            type='log',
            tickvals=tickvals_y,
            ticktext=ticktext_y,
            tickfont=dict(family='Arial', size=17),
            row=row, col=col
        )

# Shared layout with % sin_cos colorbar and increased global font
fig7.update_layout(
    title_text='% Neurons with their best R² first order sinusoids (2048 neurons pizza vs clock)',
    font=dict(family='Arial', size=18),
    coloraxis=dict(
        colorscale='plasma',
        cmin=0,
        cmax=1.0,
        colorbar=dict(
            title='% fit w/ 1-order sines',
            title_font=dict(size=14),
            tickfont=dict(size=14)
        )
    ),
    paper_bgcolor='white',
    plot_bgcolor='white',
    width=1600,   # wide enough for 4 panels
    height=400
)
# Larger subplot titles
for ann in fig7.layout.annotations:
    ann.font.size = 24

# Write HTML
fig7.write_html('figure_7.html')
print('Saved Figure 7 → figure_7.html')


# Single-run directories used for the frequency histograms:
# pizza (attn-co=0.0) and clock (attn-co=1.0), each at 1 and 4 layers.
base_dir_histograms = [
    "/logn_histogram/transformer_r2_heatmap_k=50_2001/66_50_nn_8_fits_attn-co=0.0_top-k_layers=1/p=66_bs=66_k=50_nn=1024_lr=0.00075_wd=0.0001_epochs=2001_training_set_size=3300",
    "/logn_histogram/transformer_r2_heatmap_k=50_2001/66_50_nn_8_fits_attn-co=0.0_top-k_layers=4/p=66_bs=66_k=50_nn=1024_lr=0.00075_wd=0.0001_epochs=2001_training_set_size=3300",
    "/logn_histogram/transformer_r2_heatmap_k=50_2001/66_50_nn_8_fits_attn-co=1.0_top-k_layers=1/p=66_bs=66_k=50_nn=1024_lr=0.00075_wd=0.0001_epochs=2001_training_set_size=3300",
    "/logn_histogram/transformer_r2_heatmap_k=50_2001/66_50_nn_8_fits_attn-co=1.0_top-k_layers=4/p=66_bs=66_k=50_nn=1024_lr=0.00075_wd=0.0001_epochs=2001_training_set_size=3300",
]

# Layer index from summary file names such as "all_preactivations_L2_...".
layer_re = re.compile(r"all_preactivations_L(\d+)_")
def map_freq(f):
    """Fold an integer frequency index into 1..33: f % 33, with 0 mapped to 33.

    NOTE(review): this identifies f with f + 33. If indices range over 1..65
    for p=66, the usual k ~ p - k symmetry would be min(k, 66 - k) instead;
    confirm which folding is intended.
    """
    return (f % 33) or 33

# ─── Figure 9: per-directory frequency histograms (all layers pooled) ───────
# For every directory in base_dir_histograms, walk its *.json files whose
# names match layer_re and accumulate, per folded frequency bin 1..33
# (see map_freq):
#   * uniq_counts  – per file, each raw frequency appearing in an entry with
#                    count > thresh (3 for layer 1, 11 otherwise) adds 1 to
#                    its folded bin;
#   * total_counts – sum of entry counts touching the bin (i == j counted once);
#   * avg_counts   – sum_counts / event_counts, where layer-1 entries with
#                    count < 4 are excluded from both sums.
# Entries are assumed to look like {"i,j": {"count": N, ...}}; values that are
# not dicts with a "count" key are skipped.
# Output: figure_9.html with one 1×3 bar-chart row per directory.

# Bar positions for the 33 folded-frequency bins. It is defined before the
# loop so the name exists even if base_dir_histograms is empty; later sections
# of the script may read it.
x33 = list(range(1, 34))
figs9 = []
for bd in base_dir_histograms:
    uniq_counts  = np.zeros(33, dtype=int)
    total_counts = np.zeros(33, dtype=int)
    sum_counts   = np.zeros(33, dtype=int)
    event_counts = np.zeros(33, dtype=int)

    for root, _, files in os.walk(bd):
        for fn in files:
            if not fn.endswith(".json"):
                continue
            m = layer_re.search(fn)
            if not m:
                continue
            layer = int(m.group(1))
            thresh = 3 if layer == 1 else 11

            # Context manager: the handle is closed as soon as parsing ends.
            with open(os.path.join(root, fn), encoding="utf-8") as fh:
                data = json.load(fh)
            seen = set()
            for k, v in data.items():
                if not isinstance(v, dict) or "count" not in v:
                    continue
                cnt = v["count"]
                i, j = map(int, k.split(","))
                ii, jj = map_freq(i) - 1, map_freq(j) - 1

                # Remember raw (unfolded) frequencies of strong entries; they
                # are folded into bins once the whole file has been read.
                if cnt > thresh:
                    seen.update([i, j])

                # Total counts: every entry; the diagonal (i == j) counts once.
                total_counts[ii] += cnt
                if i != j:
                    total_counts[jj] += cnt

                # Average: skip weak layer-1 entries (count < 4).
                if not (layer == 1 and cnt < 4):
                    sum_counts[ii]   += cnt
                    event_counts[ii] += 1
                    if i != j:
                        sum_counts[jj]   += cnt
                        event_counts[jj] += 1

            for u in seen:
                uniq_counts[map_freq(u) - 1] += 1

    # Bins without any contributing entry get 0 instead of a 0/0 NaN.
    avg_counts = np.divide(
        sum_counts, event_counts,
        out=np.zeros_like(sum_counts, dtype=float),
        where=event_counts > 0
    )

    # build 1×3 bar chart
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=["Unique freqs (>thresh)", "Total raw counts", "Average count"]
    )
    fig.add_trace(go.Bar(x=x33, y=uniq_counts),  row=1, col=1)
    fig.add_trace(go.Bar(x=x33, y=total_counts), row=1, col=2)
    fig.add_trace(go.Bar(x=x33, y=avg_counts),   row=1, col=3)
    fig.update_layout(title_text=bd, showlegend=False, margin=dict(t=60))
    figs9.append(fig)

# Write all figures into one page. The first fragment carries a <script> tag
# for the plotly.js build that matches the installed plotly package. The
# generic "plotly-latest.min.js" CDN file is frozen at an old 1.x release.
with open("figure_9.html", "w", encoding="utf-8") as f:
    f.write("<html><head><meta charset='utf-8'></head><body>\n")
    for n_fig, fig in enumerate(figs9):
        f.write(fig.to_html(include_plotlyjs="cdn" if n_fig == 0 else False,
                            full_html=False))
    f.write("</body></html>")

print("Saved Figure 9 → figure_9.html")


# ─── Figure 9b: per-layer breakdown with averages ────────────────────────────
# For each directory in base_dir_histograms, build one column per layer found
# in the *.json file names. Each column has three rows of per-bin bars
# (folded frequency 1..33, see map_freq):
#   row 1 "unique" – per file, each raw frequency appearing in an entry with
#                    count > thresh (3 for layer 1, 11 otherwise) adds 1 to
#                    its bin;
#   row 2 "total"  – sum of entry counts touching the bin (i == j counted once);
#   row 3 "avg"    – total / number of entries touching the bin. Unlike the
#                    pooled Figure 9 average, no weak layer-1 entries are
#                    excluded here.
# Output: figure_9_b.html, with one heading and one figure per directory.

x33 = list(range(1, 34))  # bar positions for the 33 folded-frequency bins
figs9b = []

for bd in base_dir_histograms:
    # Walk the directory once and group *.json paths by layer number, so the
    # discovered layers are exactly the ones that get parsed below. A layer
    # seen only in non-JSON files no longer yields an empty column.
    json_files_by_layer = {}
    for root, _, files in os.walk(bd):
        for fn in files:
            if not fn.endswith(".json"):
                continue
            m = layer_re.search(fn)
            if m:
                json_files_by_layer.setdefault(int(m.group(1)), []).append(
                    os.path.join(root, fn))
    layers = sorted(json_files_by_layer)
    if not layers:
        print(f"⚠️ Skipping {bd}: no JSONs found.")
        continue

    # make 3×L subplots: unique / total / average
    fig = make_subplots(
        rows=3, cols=len(layers),
        subplot_titles=(
            [f"L{lyr} unique" for lyr in layers] +
            [f"L{lyr} total"  for lyr in layers] +
            [f"L{lyr} avg"    for lyr in layers]
        ),
        horizontal_spacing=0.04,
        vertical_spacing=0.10
    )

    for ci, layer in enumerate(layers, start=1):
        uniq   = np.zeros(33, dtype=int)
        tot    = np.zeros(33, dtype=int)
        events = np.zeros(33, dtype=int)
        thresh = 3 if layer == 1 else 11

        for path in json_files_by_layer[layer]:
            # Context manager: the handle is closed as soon as parsing ends.
            with open(path, encoding="utf-8") as fh:
                data = json.load(fh)
            seen = set()
            for k, v in data.items():
                if not isinstance(v, dict) or "count" not in v:
                    continue
                cnt = v["count"]
                i, j = map(int, k.split(","))
                ii = map_freq(i) - 1
                jj = map_freq(j) - 1

                if cnt > thresh:
                    seen.update([i, j])

                tot[ii]    += cnt
                events[ii] += 1
                if i != j:
                    tot[jj]    += cnt
                    events[jj] += 1

            for u in seen:
                uniq[map_freq(u) - 1] += 1

        # Bins without any entry get 0 instead of a 0/0 NaN.
        avg = np.divide(tot, events, out=np.zeros_like(tot, dtype=float), where=events > 0)

        # plot into column ci, rows 1..3
        fig.add_trace(go.Bar(x=x33, y=uniq),  row=1, col=ci)
        fig.add_trace(go.Bar(x=x33, y=tot),   row=2, col=ci)
        fig.add_trace(go.Bar(x=x33, y=avg),   row=3, col=ci)

    fig.update_layout(
        title_text=f"Figure 9b: {bd}",
        showlegend=False,
        width=300 * len(layers),
        height=900,
        margin=dict(t=80)
    )
    figs9b.append((bd, fig))

# One HTML page with a heading per directory. The first figure fragment loads
# the plotly.js build matching the installed plotly package. The generic
# "plotly-latest.min.js" CDN file is frozen at an old 1.x release.
with open("figure_9_b.html", "w", encoding="utf-8") as f:
    f.write("<html><head><meta charset='utf-8'></head><body>")
    for n_fig, (bd, fig) in enumerate(figs9b):
        f.write(f"<h2>{bd}</h2>")
        f.write(fig.to_html(include_plotlyjs="cdn" if n_fig == 0 else False,
                            full_html=False))
    f.write("</body></html>")
print("✅ Saved Figure 9b → figure_9_b.html")




# ─── Figure 9 final: unique-frequency histograms, Pizza row vs Clock row ─────
# Writes figure_9_final.html. There is one subplot row per model family (tag)
# and one column per (directory, layer) pair of that family. Each panel shows,
# per folded frequency bin 1..33, how many JSON files contain that frequency in
# an entry whose count exceeds the layer threshold (3 for layer 1, else 11).

pizza_dirs = base_dir_histograms[:2]   # the attn-co=0.0 directories
clock_dirs = base_dir_histograms[2:]   # the attn-co=1.0 directories

# Panel specs (tag, dir, layer, total_layers), grouped per tag so that each
# tag gets its own row however many layers its directories contain.
panels_by_tag = []
for tag, dirs in [("Pizza", pizza_dirs), ("Clock", clock_dirs)]:
    tag_panels = []
    for bd in dirs:
        # Only *.json files count, matching what is parsed below.
        layers = sorted({
            int(m.group(1))
            for m in (
                layer_re.search(fn)
                for _, _, files in os.walk(bd)
                for fn in files
                if fn.endswith(".json")
            )
            if m
        })
        L_max = len(layers)
        for L in layers:
            tag_panels.append((tag, bd, L, L_max))
    panels_by_tag.append(tag_panels)

# All panels, in tag order.
panels = [p for tag_panels in panels_by_tag for p in tag_panels]

# Grid: one row per tag, with as many columns as the largest tag needs
# (e.g. 2 × 5 when each family has five layer panels). Cells a shorter row
# does not use get an empty title and stay blank.
n_rows = len(panels_by_tag)
n_cols = max([len(tag_panels) for tag_panels in panels_by_tag] + [1])
subplot_titles_9f = []
for tag_panels in panels_by_tag:
    subplot_titles_9f += [f"{tag}: {L}/{L_max}" for tag, _, L, L_max in tag_panels]
    subplot_titles_9f += [""] * (n_cols - len(tag_panels))

fig9_final = make_subplots(
    rows=n_rows, cols=n_cols,
    subplot_titles=subplot_titles_9f,
    horizontal_spacing=0.01,
    vertical_spacing=0.08
)

for row, tag_panels in enumerate(panels_by_tag, start=1):
    is_bottom = row == n_rows
    for col, (tag, bd, L, L_max) in enumerate(tag_panels, start=1):
        # Per file, each raw frequency appearing in an entry with
        # count > thresh adds 1 to its folded bin.
        uniq_counts = np.zeros(33, dtype=int)
        thresh = 3 if L == 1 else 11
        for root, _, files in os.walk(bd):
            for fn in files:
                if not fn.endswith(".json"):
                    continue
                m = layer_re.search(fn)
                if not m or int(m.group(1)) != L:
                    continue
                # Context manager: the handle is closed as soon as parsing ends.
                with open(os.path.join(root, fn), encoding="utf-8") as fh:
                    data = json.load(fh)
                seen = set()
                for k, v in data.items():
                    if isinstance(v, dict) and "count" in v:
                        cnt = v["count"]
                        i, j = map(int, k.split(","))
                        if cnt > thresh:
                            seen.update([i, j])
                for u in seen:
                    uniq_counts[map_freq(u) - 1] += 1

        fig9_final.add_trace(
            go.Bar(x=list(range(1, 34)), y=uniq_counts, showlegend=False),
            row=row, col=col
        )

        # y-axis: title and tick labels only in the first column
        fig9_final.update_yaxes(
            title_text='count' if col == 1 else None,
            showticklabels=(col == 1),
            tickfont=dict(size=15),
            row=row, col=col
        )

        # x-axis: title and tick labels only in the bottom row
        fig9_final.update_xaxes(
            title_text='Frequency' if is_bottom else None,
            showticklabels=is_bottom,
            tickfont=dict(size=15),
            tickmode='array',
            tickvals=[1, 10, 20, 30],
            ticktext=['1', '10', '20', '30'] if is_bottom else ['', '', '', ''],
            ticks='outside',
            row=row, col=col
        )

# Subplot titles: match the global font size and nudge them down toward the
# panels. The 0.0851 offset is a fixed, apparently hand-tuned layout constant.
for ann in fig9_final.layout.annotations:
    ann.font.size = 16
    ann.y -= 0.0851

fig9_final.update_layout(
    width=160 * n_cols,   # 800 px for five columns
    height=320,
    paper_bgcolor='white',
    plot_bgcolor='white',
    font=dict(size=16),
    title_text='Frequency counts in 1 & 4 layer models mod 66 = (2x3x11)'
)

fig9_final.write_html('figure_9_final.html')
print("Saved figure_9_final.html")