from nesim.utils.json_stuff import load_json_as_dict, dict_to_json
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing
import matplotlib.pyplot as plt

def plot_accuracy_vs_smoothness_with_tau(data, colors, filename):
    """Scatter-plot accuracy against smoothness, one point per model, and save it.

    Args:
        data: Mapping of model name to a dict with ``'smoothness'`` and
            ``'accuracy'`` entries and an optional ``'tau'`` entry. Models
            without a ``'tau'`` entry are treated as ``'baseline'``.
        colors: Mapping of tau value (or ``'baseline'``) to a matplotlib
            color. Tau values missing from the mapping are drawn in black.
        filename: Destination passed to ``Figure.savefig``.

    Raises:
        KeyError: If an entry of ``data`` lacks ``'smoothness'`` or
            ``'accuracy'``.
    """
    apply_ratan_matplotlib_thing()

    fig, ax = plt.subplots(figsize=(10, 6))

    for model, entry in data.items():
        tau = entry.get('tau', 'baseline')
        point = (entry['smoothness'], entry['accuracy'])

        # Only baseline points carry a legend label; the rest are identified
        # by the text annotation placed just above each marker.
        ax.scatter(
            *point,
            color=colors.get(tau, "#000000"),
            label=model if tau == 'baseline' else None,
            s=100,
        )
        ax.annotate(
            model,
            point,
            textcoords="offset points",
            xytext=(0, 5),
            ha='center',
            fontsize=9,
        )

    ax.set_xlabel('Smoothness', fontsize=14)
    ax.set_ylabel('Accuracy', fontsize=14)
    ax.set_title('Accuracy vs. Smoothness with Tau Values', fontsize=16)

    fig.tight_layout()
    # Save before showing: with an interactive backend plt.show() blocks until
    # the window is closed, so saving afterwards would delay the file write
    # and operate on a figure whose window has already been torn down.
    fig.savefig(filename)
    plt.show()


# Marker color per tau value; looked up with `colors.get(tau, "#000000")` in
# plot_accuracy_vs_smoothness_with_tau. 50.0, 20.0 and 10 share one color.
colors = { 
    50.0: "#FEE0D2",
    20.0: "#FEE0D2",
    10: "#FEE0D2", 
    5: "#FF9966", 
    1: "#FB6A4A", 
    0.5: "#CB181D", 
    "baseline": "#000000" 
}
# Per-run results; the loop below reads parallel "acc" and "tau" lists from
# each entry. NOTE(review): absolute, machine-specific path.
acc_topo_resnet18 = load_json_as_dict(
    filename="/home/XXXX-4/repos/nesim/experiments/imagenet/resnet18/validation_acc/results.json"
)
# Smoothness values, indexed below by the same label strings used as keys of
# `data`. NOTE(review): relative path — presumably resolved against the
# current working directory; confirm.
smoothness_data = load_json_as_dict("../../resnet18/pouya_smoothness/smoothness_values.json")
# Hard-coded reference points. They have no "tau" entry, so the plotting
# function treats them as 'baseline' when choosing a color.
data = {
    "TDANN":   {
        "smoothness": 0.8160919540229886,
        "accuracy": 0.439
    },
    "LLCNN-G": {
        "smoothness": 0.7931034482758621,
        "accuracy": 0.53
    },
}
# NOTE(review): not used anywhere in the visible part of this file.
our_model_accuracies = {}

# Merge every (accuracy, tau) result of the loaded runs into `data`, keyed by
# a human-readable label that doubles as the lookup key into `smoothness_data`.
for name in acc_topo_resnet18:
    # Runs whose name ends with "end_topo" are excluded from this figure.
    if name.endswith("end_topo"):
        continue

    results = acc_topo_resnet18[name]
    for acc, tau in zip(results["acc"], results["tau"]):
        if tau == "Baseline":
            # Store the lower-case spelling: it is the key used in `colors`
            # and the value the plotting function compares against.
            label = tau = "baseline"
        else:
            # Taus above 1 are shown as integers (50.0 -> 50); the resulting
            # label text must match the keys of `smoothness_data` exactly.
            if tau > 1:
                tau = int(tau)
            if not name.endswith("all_topo"):
                raise ValueError(f"Invalid name: {name}, it should be one of: {acc_topo_resnet18.keys()}")
            label = f"All topo\n$\\tau$ = {tau}"

        data[label] = {
            "smoothness": smoothness_data[label],
            "accuracy": acc,
            "tau": tau,
        }

# NOTE(review): loaded but not plotted. plot_accuracy_vs_smoothness_with_tau
# takes only (data, colors, filename), so passing this as a keyword raised
# TypeError; extend the function first if these results should be drawn.
resnet50_data = load_json_as_dict("results.json")
plot_accuracy_vs_smoothness_with_tau(
    data=data,
    colors=colors,
    filename="acc_vs_smoothness.pdf",
)

