import os
import re
import json
from collections import defaultdict
import numpy as np
import plotly.graph_objs as go
import plotly.io as pio
# Clear kaleido's MathJax setting before any image export below.
# NOTE(review): this is commonly done to keep a "Loading [MathJax]..." banner
# out of exported PDFs — confirm it is still needed (and still supported) by
# the installed plotly/kaleido version.
pio.kaleido.scope.mathjax = None
# The one-hot runs differ only in model depth (no_embed_<depth>) and in p;
# every other hyper-parameter is fixed, and bs always equals p.
_ONE_HOT_DEPTHS = range(1, 5)
_P_VALUES = range(59, 67)

base_dirs = [
    (
        f"/multilayer_figure_1_one_hot/no_embed_{depth}-freqs_{p}-"
        f"/freq_distribution_mlp=no_embed_{depth}_p={p}_bs={p}"
        "_k=59_nn=1024_wd=1e-05_lr=0.00075"
    )
    for depth in _ONE_HOT_DEPTHS
    for p in _P_VALUES
]

# Single-embedding runs live in relative directories, one per value of p.
more_base_dirs = [f"multilayer-one_embed_1-freqs_{p}-" for p in _P_VALUES]

# Every directory that will be scanned for results.
all_dirs = base_dirs + more_base_dirs

# Legend labels keyed by the model name extracted from each directory path;
# unknown keys fall back to the raw model name when plotting.
pretty_model_names = {
    "no_embed_1": "one-hot 1 layer",
    "no_embed_2": "one-hot 2 layers",
    "no_embed_3": "one-hot 3 layers",
    "no_embed_4": "one-hot 4 layers",
    "one_embed_1": "1 embed 1 layer",
}

# Model name = the path component immediately before "-freqs_<n>-", with an
# optional "multilayer-" prefix stripped. The component may start the string
# (relative dirs such as "multilayer-one_embed_1-freqs_59-") or follow a "/"
# (absolute dirs); the previous pattern demanded a leading "/" and therefore
# rejected every relative directory.
model_pattern = re.compile(r"(?:^|/)(?:multilayer-)?([^/]+)-freqs_\d+-")
# "p=<int>" field searched for in a directory path. The lookbehind stops it
# from matching the tail of the "mlp=" field (e.g. "mlp=1_p=59" -> 59, not 1).
p_pattern = re.compile(r"(?<![A-Za-z])p=(\d+)")
# "layer_<int>" index searched for in a result file name.
layer_pattern = re.compile(r"layer_(\d+)")

# data[layer][model][p] -> list of frequency counts, one per matching JSON file
# (the count is len() of the loaded JSON value).
data = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

# Parse paths
for base_dir in all_dirs:
    model_match = model_pattern.search(base_dir)
    if not model_match:
        print(f"⚠️ Skipping invalid path (no model match): {base_dir}")
        continue
    model = model_match.group(1)

    # os.walk silently yields nothing for a missing directory, which would
    # otherwise leave an empty plot with no hint why; say so explicitly.
    if not os.path.isdir(base_dir):
        print(f"⚠️ Skipping missing directory: {base_dir}")
        continue

    for root, dirs, files in os.walk(base_dir):
        # p comes from the directory path, so it is the same for every file
        # in this root; look it up once.
        p_match = p_pattern.search(root)
        if not p_match:
            continue
        p = int(p_match.group(1))

        for fname in files:
            if not (fname.startswith("freq_distribution") and fname.endswith(".json")):
                continue

            layer_match = layer_pattern.search(fname)
            if not layer_match:
                continue
            layer = int(layer_match.group(1))

            fpath = os.path.join(root, fname)
            # Best-effort: an unreadable, malformed, or unsized JSON file is
            # reported and skipped rather than aborting the whole scan.
            try:
                with open(fpath, "r", encoding="utf-8") as f:
                    json_data = json.load(f)
                freq_count = len(json_data)
            except (OSError, ValueError, TypeError) as e:
                print(f" Failed to read {fpath}: {e}")
                continue
            data[layer][model][p].append(freq_count)

# Collapse repeated runs into one mean frequency count per (layer, model, p).
# averages[layer][model] is a pair of parallel lists (p values, mean counts),
# both in ascending p order so they can be used directly as x/y series.
averages = defaultdict(lambda: defaultdict(lambda: ([], [])))
for layer_idx, per_model in sorted(data.items()):
    for model_name, per_p in sorted(per_model.items()):
        for p_val, run_counts in sorted(per_p.items()):
            if not run_counts:
                continue
            xs, ys = averages[layer_idx][model_name]
            xs.append(p_val)
            ys.append(np.mean(run_counts))

# Dotted grey grid applied identically to both axes.
_GRID_STYLE = dict(showgrid=True, gridcolor='grey', gridwidth=1, griddash='dot')


def _make_layer_figure(per_model):
    """Return a line plot of average frequency count vs. n for one layer.

    ``per_model`` maps a model key to a ``(p_values, avg_counts)`` pair. One
    trace is drawn per model in sorted key order, labelled through
    ``pretty_model_names`` (falling back to the raw key).
    """
    figure = go.Figure()
    for key in sorted(per_model):
        xs, ys = per_model[key]
        label = pretty_model_names.get(key, key)
        figure.add_trace(go.Scatter(x=xs, y=ys, mode="lines+markers", name=label))

    figure.update_layout(
        plot_bgcolor='white',
        paper_bgcolor='white',
        title_text="Average # of frequencies in trained model",
        title_font_size=32,
        xaxis_title="n",
        xaxis_title_font_size=32,
        xaxis_tickfont_size=24,
        yaxis_title="avg # frequencies",
        yaxis_title_font_size=32,
        yaxis_tickfont_size=32,
        legend_font_size=18,
        width=800,
        height=400,
        margin=dict(t=48, l=0, r=0, b=0),
    )
    figure.update_xaxes(**_GRID_STYLE)
    figure.update_yaxes(**_GRID_STYLE)
    return figure


# One figure per layer (1 and 2); only the layer-1 figure, figs[0], is
# exported below — the layer-2 figure is built but not written anywhere.
figs = [_make_layer_figure(averages[layer]) for layer in (1, 2)]

figs[0].write_image("figure1.pdf", format="pdf", engine="kaleido")
print("made figure1.pdf")
