from nesim.utils.json_stuff import load_json_as_dict, dict_to_json
import matplotlib.pyplot as plt
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing
import numpy as np

apply_ratan_matplotlib_thing()


def _mean_ignoring_nan(model_name, values):
    """Return the mean of *values* after dropping NaN entries.

    ``dtype=float`` also turns JSON ``null`` entries (loaded as ``None``) into
    NaN so they are filtered out rather than breaking ``np.isnan``.

    Raises:
        ValueError: if no non-NaN value remains. Previously this surfaced as a
            bare ZeroDivisionError with no hint which model was at fault.
    """
    arr = np.asarray(list(values), dtype=float)
    arr = arr[~np.isnan(arr)]
    if arr.size == 0:
        raise ValueError(f"{model_name}: no non-NaN smoothness values to average")
    return float(np.mean(arr))


def _parse_tau(model_name):
    """Return the tau value stored as the 4th '_'-separated field of *model_name*.

    Whole-number taus above 1 are shown as ints ("5" rather than "5.0").
    Fractional taus are kept as floats; the old unconditional ``int(tau)``
    truncated e.g. 2.5 to 2 and produced a wrong label.
    """
    tau = float(model_name.split("_")[3])
    if tau > 1 and tau.is_integer():
        return int(tau)
    return tau


# Prefix -> human-readable label stem, checked in order after "baseline".
_TOPO_LABELS = (
    ("all_topo", "All topo"),
    ("end_topo", "End topo"),
)


def _make_label(model_name):
    """Map a model name from results.json to its (multi-line) plot label.

    Raises:
        ValueError: if the name does not start with a known prefix.
    """
    if model_name.startswith("baseline"):
        return "baseline"
    for prefix, stem in _TOPO_LABELS:
        if model_name.startswith(prefix):
            return f"{stem}\n$\\tau$ = {_parse_tau(model_name)}"
    raise ValueError(f"invalid model name! {model_name!r}")


results = load_json_as_dict("results.json")

model_names = list(results.keys())

plot_data = {
    # "LLCNN-G": 0.7931034482758621,
    # "LLCNN-MH": 0.9683908045977012,
    # "LLCNN-S": 1.0258620689655173,
    # "TDANN": 0.8160919540229886,
    # "Resnet-18": 0.22126436781609188
}

for model_name in model_names:
    mean_smoothness = _mean_ignoring_nan(model_name, results[model_name].values())
    print(f"{model_name}: {mean_smoothness}")

    label = _make_label(model_name)

    # Each mean is stored under both the plot label and the raw model name.
    # NOTE(review): several models mapping to the same label (e.g. multiple
    # "baseline*" runs) silently overwrite each other under that label key.
    plot_data[label] = mean_smoothness
    plot_data[model_name] = mean_smoothness

dict_to_json(
    dictionary=plot_data,
    filename="smoothness_values.json"
)