import matplotlib.pyplot as plt
import numpy as np
from nesim.utils.json_stuff import load_json_as_dict
import numpy as np
import matplotlib.pyplot as plt
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing

# Bar colour for every run, keyed by the same run names used in `labels`.
colors = dict(
    baseline="#FF5733",  # red-orange
    topo_1="#28A745",  # green
    topo_5="#D98EFF",  # light purple
    topo_10="#007BFF",  # blue
    topo_50="#8B4513",  # brown
)

# Legend text for every run. Topographic runs are named "topo_<tau>" and are
# labelled with their tau value rendered as LaTeX math.
labels = {"baseline": "Baseline"}
labels.update(
    {f"topo_{tau}": f"Topographic ($\\tau = {tau}$)" for tau in (1, 5, 10, 50)}
)

def generate_bar_plot(
    data, filename=None, width=15, height=10, fontsize=22, ticks_fontsize=18
):
    """Draw a grouped bar chart of per-layer scores, one bar series per run.

    Args:
        data: Mapping ``{run_name: {layer_name: score}}``. The layer order of
            the first run defines the x-axis groups; every other run's scores
            are looked up by layer name, and layers a run lacks are left blank
            (plotted as NaN).
        filename: If truthy, the figure is saved to this path and then closed;
            otherwise it is displayed with ``plt.show()``.
        width: Figure width in inches.
        height: Figure height in inches.
        fontsize: Font size of the y-axis label and the legend.
        ticks_fontsize: Font size of the x and y tick labels.

    Raises:
        ValueError: If ``data`` contains no runs.
    """
    run_names = list(data)
    if not run_names:
        raise ValueError("generate_bar_plot() needs at least one run in `data`")

    # The first run's layer order defines the x-axis groups.
    layer_names = list(data[run_names[0]])
    index = np.arange(len(layer_names))

    # All bars of one group share 80% of a unit slot on the x-axis.
    n_runs = len(run_names)
    bar_width = 0.8 / n_runs

    apply_ratan_matplotlib_thing()
    fig = plt.figure(figsize=(width, height))

    for i, run in enumerate(run_names):
        # Look scores up by layer name so runs whose dicts are ordered
        # differently still line up with the right tick.
        scores = [data[run].get(name, np.nan) for name in layer_names]
        plt.bar(
            index + i * bar_width,
            scores,
            bar_width,
            # Fall back to the raw run name / matplotlib's colour cycle for
            # runs that have no entry in the module-level tables.
            label=labels.get(run, run),
            color=colors.get(run),
            alpha=0.7,
        )

    plt.ylabel("Effective Dimensionality", fontsize=fontsize)
    # Bars are centre-aligned at offsets 0 .. (n_runs - 1) * bar_width from
    # each group's index, so the group's midpoint is half of that span.
    plt.xticks(
        index + (n_runs - 1) * bar_width / 2,
        layer_names,
        rotation=90,
        fontsize=ticks_fontsize,
    )
    plt.yticks(fontsize=ticks_fontsize)
    plt.legend(fontsize=fontsize)
    plt.grid()
    plt.tight_layout()

    # Show the plot or save it to the specified filename.
    if filename:
        fig.savefig(filename)
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    else:
        plt.show()


# Runs to compare, in bar / legend order. Each run's per-layer results are
# read from results/<run>.json.
run_names = [
    "baseline",
    "topo_1",
    "topo_5",
    "topo_10",
    "topo_50",
]
data = {
    run: load_json_as_dict(filename=f"results/{run}.json") for run in run_names
}

# Rename every layer key to "Block: <i>", where <i> is the second
# dot-separated component after all "transformer." substrings are removed.
# The layer-name scheme itself is not visible here (presumably something like
# "transformer.h.<i>..." — TODO confirm).
# NOTE(review): if two layers map to the same block index, the later one
# silently replaces the earlier one.
for run, per_layer in data.items():
    renamed = {}
    for layer, score in per_layer.items():
        block_index = layer.replace("transformer.", "").split(".")[1]
        renamed[f"Block: {block_index}"] = score
    # Swap the contents in place so the loaded dict object is reused.
    per_layer.clear()
    per_layer.update(renamed)

generate_bar_plot(
    data=data,
    filename="effective_dimensionality.pdf",
    width=20,
    height=14,
    ticks_fontsize=24,
    fontsize=28,
)
