import os
import re
import json
import numpy as np
from glob import glob
from collections import defaultdict
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

# Experiment roots to scan. The final path component has the form
# "<embed>_<layers>-freqs_59-"; section 1 parses it to recover the MLP class
# and the expected number of layers.
base_dirs = [
    "multilayer_heatmaps_logn/one_embed_1-freqs_59-",
    "multilayer_heatmaps_2/one_embed_2-freqs_59-",
    "multilayer_heatmaps_2/one_embed_3-freqs_59-",
    "multilayer_heatmaps_2/one_embed_4-freqs_59-",
    "multilayer_heatmaps_logn/one_embed_5-freqs_59-"
]

# --- Regex Patterns ---
# Per-run folder name; groups are (mlp, p, bs, k, nn, wd, lr) in that order.
folder_pattern = re.compile(
    r"freq_distribution_mlp=(\w+)_p=(\d+)_bs=(\d+)_k=(\d+)_nn=(\d+)"
    r"_wd=([0-9.e-]+)_lr=([0-9.e-]+)"
)
# Per-layer frequency dump; groups are (layer, top-k, seed).
freq_file_regex = re.compile(r"freq_distribution_layer_(\d+)_top-k[_=](\d+)_seed_(\d+)\.json")
# Reconstruction-metrics dump; groups are (top-k, seed).
recon_file_regex = re.compile(r"reconstruction_metrics_top-k[=_](\d+)_seed_(\d+)\.json")


# --- Data Containers ---
def _cell_samples():
    """(lr, wd) -> list of per-run samples."""
    return defaultdict(list)


def _metric_tables():
    """metric name -> (lr, wd) -> list of per-run samples."""
    return defaultdict(_cell_samples)


def _cell_counts():
    """(lr, wd) -> integer counter."""
    return defaultdict(int)


# config_key -> metric name -> (lr, wd) -> [values]
metrics_by_config = defaultdict(_metric_tables)
# config_key -> summary dict (mlp, p, k, expected/available layers)
detected_configs = {}
# config_key -> (lr, wd) -> number of runs above threshold_acc / with an accuracy
high_acc_counts = defaultdict(_cell_counts)
total_run_counts = defaultdict(_cell_counts)

# Compared directly against "stored_fits.accuracy" in section 1
# (presumably percent-scale there — TODO confirm).
threshold_acc = 99.99925
# Mean stored-fit accuracy (0–1 scale) at or below which a cell whose model
# accuracy is exactly 1.0 gets the ● overlay marker.
overlay_threshold = 0.9998
# Overlay category -> (marker text, marker colour)
symbol_map = dict(
    both=('✓', 'green'),
    test_only=('●', 'magenta'),
    neither=('✗', 'red'),
)

# --- 1) Process Each Base Directory ---
# Walk every run folder, collect per-layer frequency counts, layer-1 vs layer-2
# shared/extra statistics and reconstruction metrics, keyed by
# config_key -> metric name -> (lr, wd).
for base_dir in base_dirs:
    base_name = os.path.basename(base_dir.rstrip("/"))
    # The base name encodes the MLP class and the expected number of layers.
    m_base = re.search(r"(no_embed|one_embed|two_embed)_(\d+)-freqs_59-", base_name)
    if not m_base:
        print(f"Skipping unmatched base_dir: {base_dir}")
        continue
    mlp_class_lower = m_base.group(1)
    expected_layers = int(m_base.group(2))

    for folder in os.listdir(base_dir):
        full_path = os.path.join(base_dir, folder)
        if not os.path.isdir(full_path):
            continue
        m = folder_pattern.match(folder)
        if not m:
            continue

        # --- Extract Hyperparameters ---
        folder_mlp, p, bs, k, nn, wd_str, lr_str = m.groups()
        p, bs, k, nn = int(p), int(bs), int(k), int(nn)
        wd, lr = float(wd_str), float(lr_str)
        key = (lr, wd)  # heatmap cell: columns are lr, rows are wd
        config_key = (bs, nn, k, mlp_class_lower, expected_layers)

        # Register config (the first folder seen for a config supplies "p").
        detected_configs.setdefault(config_key, {
            "mlp": mlp_class_lower,
            "p": p,
            "k": k,
            "expected_layers": expected_layers,
            "available_layers": set()
        })

        # --- A) Frequency Distributions per Layer ---
        # Glob with "top-k*" so both spellings accepted by freq_file_regex
        # ("top-k_1" and "top-k=1") are picked up; "top-k_*" dropped the latter.
        layer_freqs_by_seed = defaultdict(dict)  # layer -> seed -> parsed JSON
        for jf in glob(os.path.join(full_path, "freq_distribution_layer_*_top-k*_seed_*.json")):
            m2 = freq_file_regex.search(jf)
            if not m2 or int(m2.group(2)) != 1:
                continue
            layer_idx = int(m2.group(1))
            seed      = int(m2.group(3))
            try:
                with open(jf) as fh:
                    layer_freqs_by_seed[layer_idx][seed] = json.load(fh)
            except (OSError, ValueError) as e:
                print(f"Error reading {jf}: {e}")

        detected_configs[config_key]["available_layers"].update(layer_freqs_by_seed.keys())

        # Compute avg freq count per layer (number of entries per seed's JSON, averaged).
        for layer_idx, seed_dict in layer_freqs_by_seed.items():
            counts = [len(d) for d in seed_dict.values()]
            if counts:
                metrics_by_config[config_key][f"layer_{layer_idx}_freqs"][key].append(np.mean(counts))

        # --- B) Compare Layer 1 vs 2 Shared/Extra Frequencies ---
        # For each seed present in both layers: "shared" = frequency keys layer 2
        # has in common with layer 1, "extra" = keys only layer 2 has.
        shared_vals = []; extra_vals = []
        shared_max_vals = []; extra_max_vals = []
        if 1 in layer_freqs_by_seed and 2 in layer_freqs_by_seed:
            for sd in set(layer_freqs_by_seed[1]) & set(layer_freqs_by_seed[2]):
                d1 = layer_freqs_by_seed[1][sd]
                d2 = layer_freqs_by_seed[2][sd]
                k1 = set(map(int, d1.keys()))
                k2 = set(map(int, d2.keys()))
                shared = k1 & k2
                extra  = k2 - k1
                # NOTE(review): int(d2[f][0] / 59) + 1 is later plotted as a
                # "neuron count"; what the stored first element means is not
                # visible here — confirm against the producer of these files.
                shared_counts = [int(d2[str(f)][0]/59)+1 for f in shared]
                extra_counts  = [int(d2[str(f)][0]/59)+1 for f in extra]
                if shared_counts:
                    shared_vals.append(np.mean(shared_counts))
                    shared_max_vals.append(np.max(shared_counts))
                if extra_counts:
                    extra_vals.append(np.mean(extra_counts))
                    extra_max_vals.append(np.max(extra_counts))
            if shared_vals:
                metrics_by_config[config_key]["layer_2_shared_freqs"][key].append(np.mean(shared_vals))
                metrics_by_config[config_key]["layer_2_shared_freqs_max"][key].append(np.max(shared_max_vals))
            if extra_vals:
                metrics_by_config[config_key]["layer_2_extra_freqs"][key].append(np.mean(extra_vals))
                metrics_by_config[config_key]["layer_2_extra_freqs_max"][key].append(np.max(extra_max_vals))

        # --- C) Reconstruction Metrics & Per-Layer R² ---
        count_h, count_t = 0, 0  # runs above threshold_acc / runs with a stored-fit accuracy
        for jf in glob(os.path.join(full_path, "reconstruction_metrics_top-k*seed_*.json")):
            m3 = recon_file_regex.search(jf)
            if not m3 or int(m3.group(1)) != 1:
                continue
            try:
                with open(jf) as fh:
                    data = json.load(fh)
            except (OSError, ValueError) as e:
                print(f"Error reading {jf}: {e}")
                continue

            # Default to None (not NaN): the `is not None` guards below are meant
            # to skip a missing accuracy; with a NaN default such a file was
            # still counted as a run and appended as NaN.
            sa = data.get("stored_fits", {}).get("accuracy")
            ma = data.get("model",      {}).get("accuracy")
            if sa is not None:
                count_t += 1
                # NOTE(review): compared without the /100 rescaling used in
                # sections 5–6, i.e. assumes percent-scale values — confirm.
                if sa > threshold_acc:
                    count_h += 1
                metrics_by_config[config_key]["stored_acc"][key].append(sa)
            if ma is not None:
                metrics_by_config[config_key]["model_acc"][key].append(ma)

            metrics_by_config[config_key]["stored_loss"][key].append(
                data.get("stored_fits", {}).get("cross_entropy_loss")
            )
            metrics_by_config[config_key]["model_loss"][key].append(
                data.get("model",      {}).get("cross_entropy_loss")
            )

            # Pure-digit top-level keys are recorded as per-layer R² ("layer_<i>_r2").
            for k2, v2 in data.items():
                if k2.isdigit():
                    layer_i = int(k2)
                    metrics_by_config[config_key][f"layer_{layer_i}_r2"][key].append(v2)

        high_acc_counts[config_key][key]  += count_h
        total_run_counts[config_key][key] += count_t

# --- 2) Heatmap Helper & Axes ---
# Shared heatmap axes: every (lr, wd) cell that appears under any metric of
# any config, with each coordinate list sorted ascending.
_seen_cells = {
    cell
    for per_metric in metrics_by_config.values()
    for cells in per_metric.values()
    for cell in cells
}
lr_vals = sorted({lr for lr, _ in _seen_cells})
wd_vals = sorted({wd for _, wd in _seen_cells})

def build_heatmap(data_dict, *, wd_axis=None, lr_axis=None):
    """Arrange per-(lr, wd) values into a weight-decay x learning-rate matrix.

    Args:
        data_dict: mapping ``(lr, wd) -> value``. A list value is averaged
            after dropping ``None`` and NaN entries; any other non-None value
            is used as-is. Lookups use ``.get`` so defaultdicts are not grown.
        wd_axis: row coordinates; defaults to the module-level ``wd_vals``.
        lr_axis: column coordinates; defaults to the module-level ``lr_vals``.

    Returns:
        float ndarray of shape ``(len(wd_axis), len(lr_axis))``; cells with no
        usable data are NaN.
    """
    rows = wd_vals if wd_axis is None else wd_axis
    cols = lr_vals if lr_axis is None else lr_axis
    mat = np.full((len(rows), len(cols)), np.nan)
    for i, wd in enumerate(rows):
        for j, lr in enumerate(cols):
            val = data_dict.get((lr, wd))
            if isinstance(val, list):
                # Drop None (missing JSON fields) and NaN so a single bad
                # sample does not blank the whole cell.
                usable = [
                    v for v in val
                    if v is not None and not (isinstance(v, float) and np.isnan(v))
                ]
                if usable:
                    mat[i, j] = np.mean(usable)
            elif val is not None:
                mat[i, j] = val
    return mat

# --- 3) Collect & Render All Complete Plots ---
# Flat list of (title, panel) entries for the overview page. A panel is either
# a bare matrix (automatic colour range) or a (matrix, zmin, zmax) tuple when
# two panels must share one colour range.
titles_and_data = []
for config_key, metric_dict in metrics_by_config.items():
    bs, nn, k, mlp, expected_layers = config_key
    info = detected_configs[config_key]
    tag = f"(bs={bs}, nn={nn}, k={k}, mlp={mlp})"

    # Run counters come first: high-accuracy count, then total runs.
    titles_and_data += [
        (f"High Acc Count >{threshold_acc}% {tag}", build_heatmap(high_acc_counts[config_key])),
        (f"Total Runs {tag}", build_heatmap(total_run_counts[config_key])),
    ]

    # One average-frequency panel per layer that had frequency files.
    titles_and_data.extend(
        (f"Layer {L} Avg. Freqs {tag}", build_heatmap(metric_dict[f"layer_{L}_freqs"]))
        for L in sorted(info["available_layers"])
        if f"layer_{L}_freqs" in metric_dict
    )

    # Reconstruction accuracy and loss, for whichever metrics were recorded.
    recon_panels = (
        ("stored_acc",  "Stored-Fits Accuracy"),
        ("model_acc",   "Model Accuracy"),
        ("stored_loss", "Stored-Fits Loss"),
        ("model_loss",  "Model Loss"),
    )
    for metric_name, label in recon_panels:
        if metric_name in metric_dict:
            titles_and_data.append((f"{label} {tag}", build_heatmap(metric_dict[metric_name])))

    # Layer-2 shared vs extra frequency panels (average, then max), each pair
    # on a common colour range that starts at zero.
    for stat_suffix, stat_label in (("", "Avg"), ("_max", "MAX")):
        shared_key = f"layer_2_shared_freqs{stat_suffix}"
        extra_key = f"layer_2_extra_freqs{stat_suffix}"
        if shared_key not in metric_dict or extra_key not in metric_dict:
            continue
        shared_mat = build_heatmap(metric_dict[shared_key])
        extra_mat = build_heatmap(metric_dict[extra_key])
        vmax, vmin = np.nanmax([shared_mat, extra_mat]), 0
        titles_and_data.append((
            f"Layer 2 Shared Key Freqs {stat_label} Neuron Count {tag}",
            (shared_mat, vmin, vmax),
        ))
        titles_and_data.append((
            f"Layer 2 Extra Key Freqs {stat_label} Neuron Count {tag}",
            (extra_mat, vmin, vmax),
        ))

# Render the collected panels two per row and write them into one HTML page.
html_chunks = []
for i in range(0, len(titles_and_data), 2):
    pair = titles_and_data[i:i+2]
    fig = make_subplots(rows=1, cols=len(pair), subplot_titles=[t[0] for t in pair])
    for j, (_, Z) in enumerate(pair, start=1):
        # A panel is either a bare matrix (automatic colour range) or a
        # (matrix, zmin, zmax) tuple that shares its range with its partner.
        if isinstance(Z, tuple):
            mat, vmin, vmax = Z
        else:
            mat, vmin, vmax = Z, None, None
        fig.add_trace(
            go.Heatmap(
                z=mat,
                x=[f"{v:.2e}" for v in lr_vals],
                y=[f"{v:.2e}" for v in wd_vals],
                colorscale="viridis",
                zmin=vmin,
                zmax=vmax,
                colorbar=dict(
                    title=dict(text="Value", font=dict(size=24)),
                    tickfont=dict(size=24),
                    # NOTE(review): every trace of the pair places its colorbar
                    # at this same x, so in two-panel rows the bars are drawn
                    # on top of each other even when the scales differ.
                    x=0.97,           # paper x of the colorbar's left edge
                    xanchor="left",   # interpret x as the colorbar's left edge
                    len=0.95          # slightly shorter than the full plot height
                )
            ),
            row=1, col=j
        )
        fig.update_xaxes(title_text="Learning Rate", row=1, col=j, title_font=dict(size=24), tickfont=dict(size=24))
        fig.update_yaxes(title_text="Weight Decay", row=1, col=j, title_font=dict(size=24), tickfont=dict(size=24))
    fig.update_layout(
        width=1200,
        height=600,
        showlegend=False,
        font=dict(size=24),
        title_font_size=24
    )
    html_chunks.append(pio.to_html(fig, full_html=False, include_plotlyjs='cdn'))

# Single source of truth for the output name so the log message cannot drift
# from the file actually written (it previously said "heatmaps_analyse.html").
overview_path = "heatmaps_analyses.html"
with open(overview_path, "w", encoding="utf-8") as f:
    f.write("<html><head><title>All Heatmaps</title></head><body>\n"
            + "\n<hr style='margin:40px 0;'>\n".join(html_chunks)
            + "\n</body></html>")
print(f"  Complete heatmaps generated → {overview_path}")

# --- 4) Specialized Plots by Expected Layer Count ---
# Only configs whose base dir declares 2, 3, 4 or 5 layers get a specialised
# page; each page lists (title, matrix) panels in insertion order.
special_plots = {n_layers: [] for n_layers in (2, 3, 4, 5)}
for config_key, metric_dict in metrics_by_config.items():
    bs, nn, k, mlp, expected_layers = config_key
    panels = special_plots.get(expected_layers)
    if panels is None:
        continue
    tag = f"(bs={bs}, nn={nn}, k={k}, mlp={mlp})"

    # High Acc Count
    panels.append((
        f"High Acc Count >{threshold_acc}% {tag}",
        build_heatmap(high_acc_counts[config_key]),
    ))

    # Only the first two layers' average frequency counts appear here.
    for L in (1, 2):
        freq_key = f"layer_{L}_freqs"
        if freq_key in metric_dict:
            panels.append((f"Layer {L} Avg. Freqs {tag}", build_heatmap(metric_dict[freq_key])))

    # Recon Metrics: Stored/Model Acc & Loss
    recon_panels = (
        ("stored_acc",  "Stored-Fits Accuracy"),
        ("model_acc",   "Model Accuracy"),
        ("stored_loss", "Stored-Fits Loss"),
        ("model_loss",  "Model Loss"),
    )
    for metric_name, label in recon_panels:
        if metric_name in metric_dict:
            panels.append((f"{label} {tag}", build_heatmap(metric_dict[metric_name])))

# Specialised pages: one HTML file per expected layer count, panels two per row.
os.makedirs("heatmap_plots_here", exist_ok=True)

# Tick labels are identical for every panel, so format them once.
lr_labels = [f"{v:.2e}" for v in lr_vals]
wd_labels = [f"{v:.2e}" for v in wd_vals]

for n_layers, items in special_plots.items():
    html_chunks = []
    for start in range(0, len(items), 2):
        pair = items[start:start + 2]
        fig = make_subplots(rows=1, cols=len(pair), subplot_titles=[title for title, _ in pair])
        for col_idx, (_, Z) in enumerate(pair, start=1):
            value_bar = dict(
                title=dict(text="Value", font=dict(size=24)),
                tickfont=dict(size=24),
                x=0.97,            # paper x of the colorbar's left edge
                xanchor="left",
                len=0.95,
            )
            heatmap = go.Heatmap(
                z=Z,
                x=lr_labels,
                y=wd_labels,
                colorscale="viridis",
                colorbar=value_bar,
            )
            fig.add_trace(heatmap, row=1, col=col_idx)
            fig.update_xaxes(title_text="Learning Rate", row=1, col=col_idx,
                             title_font=dict(size=24), tickfont=dict(size=24))
            fig.update_yaxes(title_text="L2 Regularization", title_standoff=2, row=1, col=col_idx,
                             title_font=dict(size=24), tickfont=dict(size=24))
        fig.update_layout(width=1200, height=600, showlegend=False,
                          font=dict(size=24), title_font_size=24)
        html_chunks.append(pio.to_html(fig, full_html=False, include_plotlyjs='cdn'))

    separator = "\n<hr style='margin:40px 0;'>\n"
    page = ("<html><head><title>Specialized</title></head><body>\n"
            + separator.join(html_chunks)
            + "\n</body></html>")
    out = os.path.join("heatmap_plots_here", f"{n_layers}_layer_heatmaps_logn.html")
    with open(out, "w") as f:
        f.write(page)
    print(f"  Specialized heatmaps for {n_layers} layers → {out}")

# --- 5) Per-Base-Dir R² Heatmaps with ✓/●/✗ Overlays ---
# One figure per base dir: for every layer, a log-log lr x wd heatmap of mean
# R², overlaid with a marker summarising stored-fit vs model accuracy per cell.
for base_dir in base_dirs:
    base_name = os.path.basename(base_dir.rstrip("/"))
    layer_results = {}  # layer -> {(lr, wd): {"r2": [...], "sa": [...], "ma": [...]}}
    lr_set, wd_set = set(), set()

    for folder in os.listdir(base_dir):
        full_path = os.path.join(base_dir, folder)
        m = folder_pattern.match(folder)
        if not m or not os.path.isdir(full_path):
            continue
        wd, lr = float(m.group(6)), float(m.group(7))
        lr_set.add(lr); wd_set.add(wd)

        for jf in os.listdir(full_path):
            m2 = recon_file_regex.match(jf)
            if not m2 or int(m2.group(1)) != 1:
                continue
            json_path = os.path.join(full_path, jf)
            try:
                with open(json_path) as fh:
                    data = json.load(fh)
            except (OSError, ValueError) as err:
                print(f"Error reading {json_path}: {err}")
                continue
            # Missing or null sections/accuracies become NaN (ignored by the
            # nanmean below) instead of raising on the `> 1.0` comparison.
            sa = (data.get('stored_fits') or {}).get('accuracy')
            ma = (data.get('model') or {}).get('accuracy')
            sa = np.nan if sa is None else sa
            ma = np.nan if ma is None else ma
            # Values above 1 are treated as percentages and rescaled to [0, 1].
            if sa > 1.0: sa /= 100.0
            if ma > 1.0: ma /= 100.0

            # Pure-digit top-level keys hold the per-layer R² values.
            for k2, v2 in data.items():
                if k2.isdigit():
                    L = int(k2)
                    d = layer_results.setdefault(L, {})
                    e = d.setdefault((lr, wd), {"r2": [], "sa": [], "ma": []})
                    e["r2"].append(v2)
                    e["sa"].append(sa)
                    e["ma"].append(ma)

    all_lrs = sorted(lr_set)
    all_wds = sorted(wd_set)
    layers  = sorted(layer_results.keys())
    if not layers:
        continue

    fig = make_subplots(rows=1, cols=len(layers), subplot_titles=[f"Layer {L} out of {len(layers)} R²" for L in layers], vertical_spacing=0.01, horizontal_spacing=0.06)
    fig.update_annotations(font=dict(size=30))
    for idx, L in enumerate(layers, start=1):
        Z = np.full((len(all_wds), len(all_lrs)), np.nan)
        for (lr, wd), e in layer_results[L].items():
            i, j = all_wds.index(wd), all_lrs.index(lr)
            Z[i, j] = np.nanmean(e["r2"])
        # All layers share the layout-level colour axis, hence one colorbar.
        fig.add_trace(go.Heatmap(x=all_lrs, y=all_wds, z=Z, colorscale="viridis", coloraxis='coloraxis'),
                      row=1, col=idx)
        fig.update_xaxes(title_text="Learning rate", type="log", row=1, col=idx, title_font=dict(size=24), tickfont=dict(size=24), tickformat=".0e", tickangle=90)
        fig.update_yaxes(title_text="L2 regularization", title_standoff=2, type="log", row=1, col=idx, title_font=dict(size=24), tickfont=dict(size=24), tickformat=".0e")

        # ✓ both means exactly 1.0; ● model at 1.0 and stored-fit mean
        # <= overlay_threshold; ✗ everything else (including NaN means).
        for (lr, wd), e in layer_results[L].items():
            sa_m = np.nanmean(e["sa"]); ma_m = np.nanmean(e["ma"])
            if sa_m == 1.0 and ma_m == 1.0:
                key = 'both'
            elif ma_m == 1.0 and sa_m <= overlay_threshold:
                key = 'test_only'
            else:
                key = 'neither'
            sym, col = symbol_map[key]
            fig.add_trace(
                go.Scatter(x=[lr], y=[wd], mode='text', text=[sym],
                           textfont=dict(color=col, size=18), showlegend=False, cliponaxis=True),
                row=1, col=idx
            )

    fig.update_layout(
        title_text=f"{base_name} Layers R² w/ Overlays",
        height=700,
        font=dict(size=24),
        title_font_size=24,
        paper_bgcolor='white',
        plot_bgcolor='white',
        coloraxis=dict(
            colorbar=dict(
                title=dict(text="R²", font=dict(size=24)),
                tickfont=dict(size=24),
                # Paper x of the bar's left edge, just inside the right margin.
                x=0.996,
                xanchor="left",
                # Slightly shorter than the full height so it sits neatly.
                len=0.8
            )
        ),
    )
    outpath = os.path.join("heatmap_plots_here", f"{base_name}_layers_r2_logn.html")
    fig.write_html(outpath)
    print(f"  Saved R² overlays → {outpath}")

# --- 6) COMBINED FIGURE FOR one_embed_1…one_embed_4 WITH CUSTOM POSITIONS ---
from plotly.subplots import make_subplots

# 6.A) gather the first 4 embeds, selecting files the same way as section 5
combined_dirs = base_dirs[:4]
all_results   = []    # [(base name, layers)], in combined_dirs order
global_lrs    = set()
global_wds    = set()

for base_dir in combined_dirs:
    name    = os.path.basename(base_dir.rstrip("/"))
    layers  = {}    # layers[L] = { (lr,wd): {"r2":[], "sa":[], "ma":[]} }
    lr_set, wd_set = set(), set()

    for folder in os.listdir(base_dir):
        full_path = os.path.join(base_dir, folder)
        m = folder_pattern.match(folder)
        # Require a directory (as section 5 does): a stray file whose name
        # matches folder_pattern would otherwise crash os.listdir below.
        if not m or not os.path.isdir(full_path):
            continue
        wd_val, lr_val = float(m.group(6)), float(m.group(7))
        wd_set.add(wd_val); lr_set.add(lr_val)

        for jf in os.listdir(full_path):
            m2 = recon_file_regex.match(jf)
            if not m2 or int(m2.group(1)) != 1:
                continue
            json_path = os.path.join(full_path, jf)
            try:
                with open(json_path) as fh:
                    data = json.load(fh)
            except (OSError, ValueError) as err:
                print(f"Error reading {json_path}: {err}")
                continue
            # Missing/null sections or accuracies become NaN (ignored by the
            # nanmean in 6.E) instead of raising KeyError/TypeError.
            sa = (data.get("stored_fits") or {}).get("accuracy")
            ma = (data.get("model") or {}).get("accuracy")
            sa = np.nan if sa is None else sa
            ma = np.nan if ma is None else ma
            # Values above 1 are treated as percentages and rescaled to [0, 1].
            if sa > 1: sa /= 100
            if ma > 1: ma /= 100

            # Pure-digit top-level keys hold the per-layer R² values.
            for k2, v2 in data.items():
                if k2.isdigit():
                    L = int(k2)
                    slot = layers.setdefault(L, {}).setdefault(
                        (lr_val, wd_val), {"r2": [], "sa": [], "ma": []}
                    )
                    slot["r2"].append(v2)
                    slot["sa"].append(sa)
                    slot["ma"].append(ma)

    all_results.append((name, layers))
    global_lrs |= lr_set
    global_wds |= wd_set

# Sorted lists: .index() then gives a stable heatmap row/column position.
global_lrs = sorted(global_lrs)
global_wds = sorted(global_wds)

# 6.B) where each embed’s layers should live:
# 1-based position in all_results -> (grid row, grid columns used by its layers)
layout_map = {
    1: (1, [1]),            # one_embed_1 → row 1, col 1
    2: (1, [3, 4]),         # one_embed_2 → row 1, cols 3–4
    3: (2, [1, 2, 3]),      # one_embed_3 → row 2, cols 1–3
    4: (3, [1, 2, 3, 4]),   # one_embed_4 → row 3, cols 1–4
}

# 6.C) build a flat subplot_titles list of length 12
# Titles are indexed row-major over the 3×4 grid, (row - 1) * 4 + (col - 1);
# cells that hold no layer keep an empty title.
titles = [""] * 12
for embed_idx, (_, layers) in enumerate(all_results, start=1):
    row, cols = layout_map[embed_idx]
    layer_ids = sorted(layers)
    layer_total = len(layer_ids)
    for j, L in enumerate(layer_ids):
        col = cols[j]
        titles[(row - 1) * 4 + (col - 1)] = f"Layer {L} out of {layer_total} R²"

# 6.D) create the 3×4 grid with those titles
fig = make_subplots(
    rows=3, cols=4,
    subplot_titles=titles,
    horizontal_spacing=0.05,
    vertical_spacing=0.1125,
)
fig.for_each_annotation(lambda ann: ann.update(font=dict(size=18)))


def _overlay_marker(stored_mean, model_mean):
    """Return the ✓/●/✗ (symbol, colour) pair for one (lr, wd) cell.

    ✓ when both mean accuracies equal 1.0; ● when the model mean equals 1.0
    and the stored-fit mean is at or below ``overlay_threshold``; ✗ otherwise
    (NaN means included, since every comparison with NaN is False).
    """
    if stored_mean == 1.0 and model_mean == 1.0:
        return symbol_map['both']
    if model_mean == 1.0 and stored_mean <= overlay_threshold:
        return symbol_map['test_only']
    return symbol_map['neither']


# 6.E) fill each subplot with a heatmap + overlays
for embed_idx, (_, layers) in enumerate(all_results, start=1):
    row, cols = layout_map[embed_idx]
    for j, L in enumerate(sorted(layers)):
        col = cols[j]
        cells = layers[L]

        # Mean R² per (wd, lr) cell; negative means are clipped to 0.
        Z = np.full((len(global_wds), len(global_lrs)), np.nan)
        for (lr_val, wd_val), e in cells.items():
            Z[global_wds.index(wd_val), global_lrs.index(lr_val)] = max(np.nanmean(e["r2"]), 0)

        # Every heatmap uses the layout-level "coloraxis", so they share one scale.
        fig.add_trace(
            go.Heatmap(
                z=Z,
                x=global_lrs,
                y=global_wds,
                colorscale="viridis",
                coloraxis="coloraxis",
            ),
            row=row, col=col,
        )

        # One text marker per populated cell.
        for (lr_val, wd_val), e in cells.items():
            sym, col_sym = _overlay_marker(np.nanmean(e["sa"]), np.nanmean(e["ma"]))
            fig.add_trace(
                go.Scatter(
                    x=[lr_val], y=[wd_val],
                    mode="text", text=[sym],
                    textfont=dict(color=col_sym, size=18),
                    showlegend=False, cliponaxis=True,
                ),
                row=row, col=col,
            )

        # log axes with vertical x tick labels
        fig.update_xaxes(type="log", tickformat=".0e", tickangle=90, row=row, col=col)
        fig.update_yaxes(type="log", tickformat=".0e", row=row, col=col)
# 6.F) Manual legend in the empty slot at row=1, col=2
# — 1) hide the axes of both empty grid cells: row=1,col=2 holds the legend
#       and row=2,col=4 holds the colorbar. Previously only the first was
#       hidden, so row=2,col=4 kept visible default axes (also labelled by the
#       blanket axis-title updates below) behind the colorbar. —
fig.update_xaxes(visible=False, row=1, col=2)
fig.update_yaxes(visible=False, row=1, col=2)
fig.update_xaxes(visible=False, row=2, col=4)
fig.update_yaxes(visible=False, row=2, col=4)

# — 2) draw a manual legend in paper‐coords near col2,row1 —
# NOTE(review): the ● text says "fit accuracy < 100%", but ● is assigned only
# when the mean stored-fit accuracy is <= overlay_threshold; means strictly
# between overlay_threshold and 1.0 (with test at 100%) are drawn as ✗.
legend_items = [
    ("Fit and test accuracy 100% (✓)", "green"),
    ("Test 100% and fit accuracy < 100% (●)",   "magenta"),
    ("Neither accuracy 100% (✗)",     "red")
]

# Hand-tuned paper coordinates of the first legend line; each further line
# is placed dy lower.
x0 = 0.24
y0 = 0.9525
dy = 0.05

for i, (label, color) in enumerate(legend_items):
    fig.add_annotation(
        x=x0, y=y0 - i*dy,
        xref="paper", yref="paper",
        text=f"<span style='color:{color}; font-size:19px'>{label}</span>",
        showarrow=False,
        align="left"
    )

# — give every x‐axis the same title and tick size —
fig.update_xaxes(
    title_text="Learning rate",
    title_font=dict(size=17),
    title_standoff=0,
    tickfont=dict(size=17))

# — give every y‐axis the same title —
fig.update_yaxes(
    title_text="L2 regularization",
    title_standoff=2,
    title_font=dict(size=17))

# 6.G) final layout tweaks & save
fig.update_layout(
    width=1600, height=900,
    paper_bgcolor='white', plot_bgcolor='white',
    font=dict(size=18), title_font_size=24,
    title_text="1, 2, 3 and 4 layer MLPs - showing R² of our fit as a function of depth",
    # One colour axis shared by every heatmap trace (coloraxis="coloraxis").
    coloraxis=dict(
        colorbar=dict(
            title=dict(text="R²", font=dict(size=18)),
            tickfont=dict(size=18),
            # Place the bar inside the otherwise-empty row=2, col=4 cell.
            x=0.87375,       # hand-tuned paper x, inside col 4 (not its left edge)
            xanchor="left",
            y=0.6292,        # top edge of row 2
            yanchor="top",
            len=0.2583,      # height = one row
            thickness=40     # bar width in pixels
        )
    )
)

fig.write_html("figure_4.html")
print("  Saved combined figure → figure_4.html")
