from nesim.utils.json_stuff import load_json_as_dict, dict_to_json
import matplotlib.pyplot as plt
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing

# Apply the project's shared matplotlib styling before any figure is built.
apply_ratan_matplotlib_thing()

# Mapping from model name to a mapping of numeric smoothness scores
# (the inner values are averaged per model further down).
results = load_json_as_dict("results.json")
model_names = [*results]

# NOTE(review): not referenced anywhere in this script; kept only in case
# other code relies on the name.
extracted_smoothness_data = {}

# Hard-coded smoothness values for reference models. NOTE(review): where
# these numbers come from is not shown here (presumably an earlier run of
# this analysis) -- confirm. Computed models are added to this dict later.
plot_data = {
    "LLCNN-G": 0.7931034482758621,
    "LLCNN-MH": 0.9683908045977012,
    "LLCNN-S": 1.0258620689655173,
    "TDANN": 0.8160919540229886,
    "Resnet-18": 0.22126436781609188,
}

def _format_tau(tau_token):
    """Parse the tau field of a model name into a value for display.

    Whole numbers lose their trailing ``.0`` (``"5.0"`` -> ``5``); fractional
    values are kept exactly (``"0.5"`` -> ``0.5``, ``"1.5"`` -> ``1.5``).
    """
    tau = float(tau_token)
    # The previous rule (`if tau > 1: tau = int(tau)`) truncated fractional
    # values such as 1.5 down to 1, mislabelling the bar and letting distinct
    # tau values collide on the same label.
    return int(tau) if tau.is_integer() else tau


def _label_for_model(model_name):
    """Return the human-readable bar label for a ``results.json`` model name.

    Raises:
        ValueError: if ``model_name`` does not start with a known prefix.
    """
    if model_name.startswith("baseline"):
        return "baseline"
    for prefix, pretty_name in (("all_topo", "All topo"), ("end_topo", "End topo")):
        if model_name.startswith(prefix):
            # tau is read from the 4th "_"-separated field of the model name.
            tau = _format_tau(model_name.split("_")[3])
            return f"{pretty_name}\n$\\tau$ = {tau}"
    raise ValueError(f"invalid model name: {model_name!r}")


def _mean(values, model_name):
    """Return the arithmetic mean of one model's smoothness scores.

    Raises:
        ValueError: if there are no scores, instead of an opaque
            ZeroDivisionError.
    """
    values = list(values)
    if not values:
        raise ValueError(f"no smoothness values for model {model_name!r}")
    return sum(values) / len(values)


# Keys of `plot_data` that are drawn as bars: the hard-coded reference models
# plus one display label per computed model. Raw model names are also written
# to the JSON below but are NOT plotted; plotting every key made each computed
# model appear as two identical bars.
bar_labels = list(plot_data)

for model_name in model_names:
    mean_smoothness = _mean(results[model_name].values(), model_name)
    print(f"{model_name}: {mean_smoothness}")

    label = _label_for_model(model_name)
    # Stored under both keys so smoothness_values.json can be looked up either
    # by display label or by raw model name. NOTE(review): models mapping to
    # the same label (e.g. several "baseline*" runs) overwrite each other.
    plot_data[label] = mean_smoothness
    plot_data[model_name] = mean_smoothness
    if label not in bar_labels:
        bar_labels.append(label)

dict_to_json(
    dictionary=plot_data,
    filename="smoothness_values.json",
)

fig, ax = plt.subplots(figsize=(25, 5))
ax.bar(bar_labels, [plot_data[bar_label] for bar_label in bar_labels])

# Hide the top and right spines.
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# Only label 0 and 1 on the y axis.
ax.set_yticks([0, 1])

# Larger tick labels for readability.
ax.tick_params(axis="both", which="major", labelsize=16)

# Same call as the former plt.grid(), but on this axes explicitly. With no
# arguments grid() toggles visibility. NOTE(review): assumes the applied
# theme leaves the grid off -- confirm.
ax.grid()

# The two-line tick labels ("All topo\n$\tau$ = ...") at 16pt can extend past
# the bottom of a 5-inch-tall figure; "tight" keeps them inside the PDF.
fig.savefig("plot.pdf", bbox_inches="tight")
plt.close(fig)