import matplotlib.pyplot as plt
from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing

# Line colours keyed by a numeric setting value, plus a "baseline" entry.
# The numeric entries are red/orange shades; the baseline is black.
# NOTE(review): what the numeric keys stand for is not visible here — confirm
# against the code that looks colours up.
colors = {
    50.0: "#FDD0C7",
    10: "#FEE0D2",
    5: "#FF9966",
    1: "#FB6A4A",
    0.5: "#CB181D",
    "baseline": "#000000",
}
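# Alternative blue/green palette for `colors`, currently disabled; uncomment
# this block (and comment out the one above) to switch palettes.
# colors = {
#     50.0: "#66B2A6",  # Stronger pastel green
#     10: "#4DA6D8",    # Rich teal
#     5: "#0099CC",     # Vibrant blue
#     1: "#007ACC",     # Deep pastel blue
#     0.5: "#005DA2",   # Darker blue
#     "baseline": "#888888"  # Neutral grey for baseline
# }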

# Layer names are read from JSON once, at import time.
# NOTE(review): the path is absolute and machine-specific, so importing this
# module fails wherever that file does not exist. Judging by the path it
# belongs to the ImageNet / ResNet-18 training setup — confirm.
_LAYER_NAMES_JSON = (
    "/home/XXXX-4/repos/nesim/training/imagenet/resnet18/layer_names.json"
)
all_layer_names = load_json_as_dict(_LAYER_NAMES_JSON)
def plot_effective_dimensionality(data, filename="result.png"):
    """Plot one robustness curve per model and save the figure to disk.

    For every model in ``data``, the ``"robustness"`` values of its ``"l1"``
    entries are drawn as a line against the entry index (0, 1, 2, ...). The
    x tick labels are set from the matching ``"parameters"`` values.

    NOTE(review): despite the function name, the plotted quantity is the
    ``"robustness"`` field, not an effective-dimensionality measure.

    Args:
        data: Mapping of model name to a dict holding an ``"l1"`` list. Each
            list item is a dict with ``"robustness"`` and ``"parameters"``
            keys.
        filename: Output path handed to ``Figure.savefig``.

    Raises:
        KeyError: If a model lacks ``"l1"`` or an entry lacks one of the
            expected keys.
    """
    apply_ratan_matplotlib_thing()

    # Draw on an explicit figure/axes pair rather than pyplot's implicit
    # "current axes", so the artists always land on the figure that is saved.
    fig, ax = plt.subplots()
    try:
        for model_name in data:
            entries = data[model_name]["l1"]
            adversarial_accuracies = [entry["robustness"] for entry in entries]
            parameter_counts = [entry["parameters"] for entry in entries]

            ax.plot(adversarial_accuracies, label=model_name, marker="o")

            # Tick labels are reset on every iteration, so the labels of the
            # last model in ``data`` are the ones that end up on the axis.
            ax.set_xticks(range(len(parameter_counts)))
            ax.set_xticklabels(parameter_counts)

        ax.grid()
        ax.legend()
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)

# Script body (runs at import time): read the results JSON and render the
# combined plot. Both paths are relative to the current working directory.
data = load_json_as_dict("results.json")
plot_effective_dimensionality(
    data=data,
    filename="combined_topo.png",
)
