import matplotlib.pyplot as plt
import numpy as np
from nesim.utils.json_stuff import load_json_as_dict


def generate_bar_plot(data, filename=None):
    """Draw a grouped bar chart of per-layer scores for each model variant.

    The layers, and their order on the x-axis, come from ``data["ours"]``.
    The "pretrained" and "eshed" series are optional. Each series that is
    present is looked up by layer name, so its mapping does not need to share
    the key order of ``data["ours"]``. A layer missing from one of those
    series is plotted as NaN (no bar) rather than shifting the later bars.

    Args:
        data: Mapping from series name ("ours", and optionally "pretrained"
            and "eshed") to a ``{layer_name: score}`` mapping.
        filename: If truthy, the figure is saved to this path and then
            closed; otherwise it is shown with ``plt.show()``.

    Raises:
        KeyError: If ``data`` has no "ours" entry.
    """
    # Legend label and bar colour for each series, in left-to-right bar order.
    series_styles = (
        ("ours", "orange"),
        ("pretrained", "green"),
        ("eshed", "blue"),
    )

    layer_names = list(data["ours"].keys())
    # Skip absent optional series instead of passing an empty height list,
    # which would not match the non-empty x positions.
    series = [(label, color) for label, color in series_styles if label in data]

    bar_width = 0.25
    index = np.arange(len(layer_names))

    # Use a fresh figure so repeated calls don't stack bars on the same axes.
    fig, ax = plt.subplots()
    for offset, (label, color) in enumerate(series):
        scores = data[label]
        # Align each series by layer name, not by its own dict iteration order.
        heights = [scores.get(name, np.nan) for name in layer_names]
        ax.bar(
            index + offset * bar_width,
            heights,
            bar_width,
            label=label,
            color=color,
            alpha=0.7,
        )

    ax.set_xlabel("Layers")
    ax.set_ylabel("Effective Dimensionality")
    ax.set_title("Post-Maxpool effective dimensionality\n of intermediate encodings")
    # Centre each tick under its group of bars (the middle bar for 3 series).
    ax.set_xticks(index + bar_width * (len(series) - 1) / 2)
    ax.set_xticklabels(layer_names, rotation=45, fontsize=8)
    ax.legend()
    ax.grid()
    # Lay out before saving as well, so the rotated tick labels aren't clipped.
    fig.tight_layout()

    if filename:
        fig.savefig(filename)
        # Release the figure; saved figures are otherwise kept alive by pyplot.
        plt.close(fig)
    else:
        plt.show()


# Result file for each model variant, keyed by the series label used in the plot.
_RESULT_FILES = {
    "ours": "results_ours.json",
    "pretrained": "results_pretrained.json",
    "eshed": "results_eshed.json",
}

# Load every variant's per-layer scores (in the order listed above) and plot them.
data = {
    label: load_json_as_dict(filename=path)
    for label, path in _RESULT_FILES.items()
}

generate_bar_plot(data=data, filename="result.pdf")
