#!/usr/bin/env python
"""
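Creates a 4 × 4 grid with a lower-triangular block of heat-maps
(row 1→1 plot, row 2→2, …, row 4→4) sharing a single colour-bar,
and saves it to refactored/heatmap_plots_here/layer_freqs_grid.pdf.
Also saves a bar chart of the final-layer average frequency count
versus number of layers to
refactored/heatmap_plots_here/bar_chart_of_adding_layers.pdf.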
"""
import os, json, numpy as np
from glob import glob
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio        # needs kaleido ≥ 0.2.1 for PDF export
# Turn MathJax off in the Kaleido export scope before any figure is written.
# NOTE(review): presumably a workaround for the "Loading [MathJax]…" box that
# Kaleido can stamp onto exported PDFs — confirm it is still required.
pio.kaleido.scope.mathjax = None
# ------------------------------------------------------------------
# 1. directories & helper ------------------------------------------
# One results directory per network; entry i (1-based) is used for grid row i
# and is queried for layers 1..i by the plotting code below.
BASE_DIRS = [
    "multilayer_heatmaps_2/one_embed_1-freqs_59-",
    "multilayer_heatmaps_2/one_embed_2-freqs_59-",
    "multilayer_heatmaps_2/one_embed_3-freqs_59-",
    "multilayer_heatmaps_2/one_embed_4-freqs_59-",
]
# File-name pattern inside each run directory: ``{L}`` is replaced by the
# layer index and ``*`` matches every seed.
FREQ_FILE = "freq_distribution_layer_{L}_top-k_1_seed_*.json"


def average_freq_matrix(layer_dir: str, layer_idx: int):
    """Return a 2-D 'weight-decay × lr' matrix of avg. key counts for one layer.

    Every sub-directory of *layer_dir* is treated as one run whose name ends
    in ``_<key>=<weight decay>_<key>=<learning rate>``.  For each run, all
    JSON files matching ``FREQ_FILE`` for *layer_idx* are loaded and the
    number of top-level entries in each file is averaged over the files.

    Args:
        layer_dir: Directory that holds one sub-directory per run.
        layer_idx: Layer number substituted for ``{L}`` in ``FREQ_FILE``.

    Returns:
        ``(Z, lrs, wds)`` where ``lrs`` and ``wds`` are the ascending lists of
        learning rates / weight decays found, and ``Z`` is a
        ``len(wds) × len(lrs)`` float array (rows = weight decay, columns =
        learning rate).  Cells without any matching JSON file are ``NaN``.

    Notes:
        Plain files are ignored.  Sub-directories whose names cannot be parsed
        (for example ``.ipynb_checkpoints``) are skipped with a warning
        instead of aborting the whole run.  If two run directories map to the
        same (weight decay, learning rate) pair, the one listed last wins.
    """
    import warnings  # local import: keeps this helper self-contained

    lr_set, wd_set, cell_means = set(), set(), {}
    for run in os.listdir(layer_dir):
        full = os.path.join(layer_dir, run)
        if not os.path.isdir(full):
            continue
        try:
            # the last two "_"-separated tokens carry "<key>=<value>"
            *_, wd_str, lr_str = run.split("_")
            wd, lr = float(wd_str.split("=")[1]), float(lr_str.split("=")[1])
        except (ValueError, IndexError):
            warnings.warn(
                f"skipping {full!r}: directory name does not end in "
                "'_<key>=<wd>_<key>=<lr>'"
            )
            continue
        lr_set.add(lr)
        wd_set.add(wd)

        hits = []
        for jf in glob(os.path.join(full, FREQ_FILE.format(L=layer_idx))):
            with open(jf, encoding="utf-8") as f:
                hits.append(len(json.load(f)))
        if hits:
            cell_means[(wd, lr)] = np.mean(hits)

    lrs, wds = sorted(lr_set), sorted(wd_set)
    # value -> position maps avoid a linear list.index() scan per cell
    col_of = {lr: j for j, lr in enumerate(lrs)}
    row_of = {wd: i for i, wd in enumerate(wds)}
    Z = np.full((len(wds), len(lrs)), np.nan)
    for (wd, lr), v in cell_means.items():
        Z[row_of[wd], col_of[lr]] = v
    return Z, lrs, wds

# ------------------------------------------------------------------
# 2. build triangular specs & subplot titles -----------------------
# The grid is square with one row per network, so its size follows BASE_DIRS
# instead of a hard-coded 4 (adding/removing a directory keeps this in sync).
n_networks = len(BASE_DIRS)

specs, titles = [], []
for r in range(1, n_networks + 1):          # rows 1‥n
    # row r holds r heat-maps followed by empty cells (lower triangle)
    specs.append([{"type": "heatmap"} for _ in range(r)] + [None] * (n_networks - r))
    # one placeholder title per real subplot; the text is replaced once the
    # corresponding trace is added
    titles += ["—"] * r

fig = make_subplots(
    rows=n_networks, cols=n_networks,
    specs=specs,
    horizontal_spacing=0.07,
    vertical_spacing=0.08,
    subplot_titles=titles
)
# subplot titles are layout annotations; enlarge them all at once
fig.for_each_annotation(lambda a: a.update(font=dict(size=18)))

# ------------------------------------------------------------------
# 3. add heat-maps --------------------------------------------------
def _drop_smallest_lr_and_wd(Z, lrs, wds):
    """Trim the first column (smallest LR) and first row (smallest WD) of *Z*.

    An axis that has only a single entry is left untouched.  Returns the
    trimmed ``(Z, lrs, wds)`` with the label lists shortened to match.
    """
    if Z.shape[1] > 1:
        Z, lrs = Z[:, 1:], lrs[1:]
    if Z.shape[0] > 1:
        Z, wds = Z[1:, :], wds[1:]
    return Z, lrs, wds


# Load every panel exactly once; each entry is (row, col, Z, lrs, wds).
# The smallest learning rate and smallest weight decay are dropped so that the
# colour scale can resolve the region where the network actually learns
# (author's intent: there, deep nets find slightly fewer frequencies than
# shallow ones, while the hyper-parameter boundary cells do not learn).
panels = []
for row_idx, base in enumerate(BASE_DIRS, start=1):
    for layer in range(1, row_idx + 1):
        Z, xs, ys = _drop_smallest_lr_and_wd(*average_freq_matrix(base, layer))
        panels.append((row_idx, layer, Z, xs, ys))   # triangular: col = layer number

all_Zs = [panel[2] for panel in panels]     # kept for the global vmin/vmax

# Global colour range over all non-NaN cells.  Collecting the values first
# avoids "All-NaN slice" warnings for panels that have no data at all.
observed = np.concatenate([z[~np.isnan(z)] for z in all_Zs])
if observed.size:
    vmin, vmax = observed.min(), observed.max()
else:
    vmin = vmax = None                      # nothing to plot: let Plotly auto-range

# The traces below are attached to the shared colour axis, so the colour scale
# and range have to be configured on layout.coloraxis to take effect; the
# equivalent trace-level settings do not drive a shared axis.
fig.update_layout(coloraxis=dict(colorscale="viridis", cmin=vmin, cmax=vmax))

for panel_no, (row_idx, col_idx, Z, xs, ys) in enumerate(panels):
    fig.add_trace(
        go.Heatmap(
            z=Z,
            x=[f"{v:.0e}" for v in xs],
            y=[f"{v:.0e}" for v in ys],
            coloraxis="coloraxis"
        ),
        row=row_idx, col=col_idx
    )
    # pretty labels: panels are added in the same row-major order in which the
    # subplot-title annotations were created, so panel_no indexes the right one
    title = f"layer {col_idx} average frequencies"
    fig.layout.annotations[panel_no].update(text=title)
    fig.update_xaxes(title_text="Learning rate",  tickangle=90,
                     row=row_idx, col=col_idx)
    fig.update_yaxes(title_text="Weight decay",   row=row_idx, col=col_idx)

# ------------------------------------------------------------------
# 4. overall aesthetics & save -------------------------------------
fig.update_layout(
    height=1700, width=1600,
    title_text="Layer-wise average frequency heat-maps (1 → 4 layers)",
    font=dict(size=18),
    coloraxis=dict(
        colorbar=dict(
            # title text and side go in the nested ``title`` object; the flat
            # ``titleside`` alias is deprecated and rejected by newer Plotly
            title=dict(text="Average number frequencies", side="right"),
            len=0.82, y=0.5, thickness=25
        )
    ),
    showlegend=False,
    plot_bgcolor="white",
    margin=dict(l=40, r=80, t=80, b=40)
)

# Both PDFs produced by this script are written below this directory.
out_dir = "refactored/heatmap_plots_here"
os.makedirs(out_dir, exist_ok=True)
out_pdf = os.path.join(out_dir, "layer_freqs_grid.pdf")
pio.write_image(fig, out_pdf)   # format is taken from the .pdf extension
print(f"✓ saved → {out_pdf}")

# ------------------------------------------------------------------
# 5. bar chart of adding layers ------------------------------------

# x‑axis: number of layers in the network
layer_counts = list(range(1, len(BASE_DIRS) + 1))
y_vals = []
chosen_lrs = []
chosen_wds = []

for row_idx, base in enumerate(BASE_DIRS, start=1):
    # take the "final" layer of each network (layer index == number of layers)
    Z, lrs, wds = average_freq_matrix(base, row_idx)

    # drop smallest LR (first column) and smallest WD (first row), exactly as
    # done for the heat-maps, so positions refer to the same trimmed grid
    if Z.shape[1] > 1:
        Z = Z[:, 1:]
        lrs = lrs[1:]
    if Z.shape[0] > 1:
        Z = Z[1:, :]
        wds = wds[1:]

    # pick the entry at (second-largest LR, fourth-largest WD)
    if len(lrs) < 2 or len(wds) < 4:
        raise ValueError(f"Not enough hyperparameter values: got {len(lrs)} LRs, {len(wds)} WDs")
    i_lr = -2   # second-largest LR (lrs is sorted ascending)
    i_wd = -4   # fourth-largest WD (wds is sorted ascending)

    chosen_lrs.append(lrs[i_lr])
    chosen_wds.append(wds[i_wd])
    y_vals.append(Z[i_wd, i_lr])    # rows = weight decay, columns = learning rate

# PRINT the raw hyperparams + values
print("At (2nd‑largest LR, 4th‑largest WD) for each network:")
for n_layers, lr, wd, val in zip(layer_counts, chosen_lrs, chosen_wds, y_vals):
    print(f"  {n_layers} layer(s): LR = {lr:.0e}, WD = {wd:.0e} → avg freqs = {val:.3f}")

# The cell is selected by position, so the actual hyper-parameter values come
# from the data.  Name them in the title only when every network agrees;
# otherwise fall back to describing the selection rule, so the title can never
# contradict what is plotted.
if len(set(chosen_lrs)) == 1 and len(set(chosen_wds)) == 1:
    hp_label = f"LR = {chosen_lrs[0]:.0e}, WD = {chosen_wds[0]:.0e}"
else:
    hp_label = "2nd-largest LR, 4th-largest WD"

# build and save a Plotly bar chart
bar_fig = go.Figure()
bar_fig.add_trace(go.Bar(x=layer_counts, y=y_vals))
bar_fig.update_xaxes(tick0=0, dtick=1)      # one tick per integer layer count
bar_fig.update_layout(
    height=1000, width=1000,
    title_text=f"(a+b) mod59: average freqs at {hp_label} vs. # layers",
    xaxis_title="# layers",
    yaxis_title="Average number frequencies",
    font=dict(size=18),
    plot_bgcolor="white",
    margin=dict(l=40, r=40, t=60, b=40)
)

# out_dir is the output directory created when the heat-map grid was saved
bar_out = os.path.join(out_dir, "bar_chart_of_adding_layers.pdf")
pio.write_image(bar_fig, bar_out)
print(f"✓ saved → {bar_out}")