from nesim.utils.json_stuff import load_json_as_dict
import matplotlib.pyplot as plt
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing

# Per-model results loaded from results.json; passed below to plot_tau_acc as
# models_dict, which reads the 'tau' and 'acc' entries of each model.
# NOTE(review): the relative path is presumably resolved against the current
# working directory — confirm against load_json_as_dict.
results = load_json_as_dict("results.json")

def plot_tau_acc(models_dict, other_models_dict=None, fontsize=17):
    """Plot accuracy against tau for each model and save it as ``result.pdf``.

    One line (with circle markers) is drawn per entry of ``models_dict``;
    each entry of ``other_models_dict`` is drawn as a dashed horizontal
    reference line. The figure is written to ``result.pdf`` and then shown.

    Args:
        models_dict: Mapping of model name to a mapping with the keys
            ``'tau'`` and ``'acc'`` (x values and the matching accuracies).
            A tau equal to ``"Baseline"`` is labelled verbatim; any other
            value is rendered as a mathtext "tau = <value>" label.
        other_models_dict: Optional mapping of model name to a mapping with
            the keys ``"acc"`` (height of the reference line) and
            ``"color"``. ``None`` (the default) draws no reference lines.
        fontsize: Font size used for tick labels, axis labels and the title.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if other_models_dict is None:
        other_models_dict = {}

    apply_ratan_matplotlib_thing()
    fig, ax = plt.subplots(figsize=(8, 6))

    for model_name, data in models_dict.items():
        # String x values make matplotlib treat the axis as categorical, so
        # "Baseline" and the numeric taus can share it.
        tau_labels = [
            str(t) if t == "Baseline" else f"$\\tau = {t}$"
            for t in data['tau']
        ]
        ax.plot(tau_labels, data['acc'], marker='o', label=model_name)

    for model_name, style in other_models_dict.items():
        ax.axhline(
            y=style["acc"],
            label=model_name,
            c=style["color"],
            linestyle="--",
        )

    yvals = [0.25, 0.5, 0.75, 1]
    ax.set_yticks(yvals)
    ax.set_yticklabels(yvals, fontsize=fontsize)
    # Only resize the x tick labels; the categorical axis already carries the
    # right text for every model. Overwriting it with the last model's labels
    # would mislabel ticks when models use different taus, and would fail
    # outright when models_dict is empty.
    ax.tick_params(axis='x', labelsize=fontsize)
    ax.set_ylim(0.20, 0.85)
    ax.set_xlabel('Tau', fontsize=fontsize)
    ax.set_ylabel('Accuracy', fontsize=fontsize)
    ax.set_title('Tau vs Accuracy for Different Models', fontsize=fontsize)

    # Disable top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Add grid with custom styling
    ax.grid(True, linestyle='--', linewidth=0.7, alpha=0.7)

    ax.legend(loc="best")
    plt.tight_layout()
    # Save before showing so the file is written even if the interactive
    # window is closed or interrupted.
    fig.savefig("result.pdf")
    plt.show()

# Reference models, each drawn by plot_tau_acc as a dashed horizontal line at
# its "acc" value in the given "color".
other_models_dict = {
    "TDANN": dict(acc=0.439, color="red"),
    "LLCNN-G": dict(acc=0.53, color="purple"),
}

# Script entry: plot the loaded results together with the reference models.
plot_tau_acc(
    models_dict=results,
    other_models_dict=other_models_dict,
)