# %% Setup libs
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots
import yaml
from scipy.stats import t

# %% Setup env
# Inputs shared by the figure cells below:
#   models_info -- parsed models.yml; its "model_to_type" mapping is used to
#                  look up each model's type.
#   df_f1       -- F1 score table, one row per model (first CSV column = index).
#   df_roc      -- ROC score table, one row per model (first CSV column = index).
# NOTE(review): the score-table columns are presumably evaluation datasets — confirm.
models_yml = "./scripts/paper/utils/models.yml"
with open(models_yml, encoding="utf-8") as yml_handle:
    models_info = yaml.safe_load(yml_handle)

score_csv_opts = {"index_col": 0}
df_f1 = pd.read_csv("./scripts/paper/utils/f1_scores.csv", **score_csv_opts)
df_roc = pd.read_csv("./scripts/paper/utils/roc_scores.csv", **score_csv_opts)


# %% fig_05_nli_vs_clf
# One row per model: the mean of its row in the ROC table ("NLI score") and
# the mean of its row in the F1 table ("CLF score").
# df_roc is aligned to df_f1's index; models missing from df_roc become NaN.
# NOTE(review): the averaged columns are presumably individual datasets — confirm.
df = pd.DataFrame(index=df_f1.index, columns=["nli_score", "clf_score"])
df["nli_score"] = df_roc.mean(axis=1)
df["clf_score"] = df_f1.mean(axis=1)

# The base models are the first rows of the score tables and are left out of
# this plot.
# NOTE(review): this relies on the row order of f1_scores.csv — confirm it is stable.
N_BASE_MODELS = 3
df = df.iloc[N_BASE_MODELS:, :]

# Scatter marker for each model type found in models_info["model_to_type"];
# any other type falls back to "x" below.
marker_styles = {
    "nli_cross_encoder": "X",
    "reranker": "p",
    "embedding_model": "*",
}

# The "science" / "no-latex" styles are provided by the scienceplots package.
plt.style.use(["science", "no-latex"])
fig, ax = plt.subplots(figsize=(7, 5))

for model, (nli, f1) in df[["nli_score", "clf_score"]].iterrows():
    mtype = models_info["model_to_type"].get(model, "other")
    # `color` sets both the face and the edge colour. An explicit `edgecolors`
    # is deliberately not passed: for the unfilled "x" fallback marker,
    # Matplotlib would warn and ignore it.
    ax.scatter(
        nli,
        f1,
        marker=marker_styles.get(mtype, "x"),
        s=100,
        alpha=0.8,
        label=mtype.replace("_", " "),
        color="black",
    )

# Each scatter call above adds its own legend entry, so each model-type label
# appears once per model. Keep only the first handle seen for each label.
unique = {}
for handle, label in zip(*ax.get_legend_handles_labels()):
    unique.setdefault(label, handle)
ax.legend(
    unique.values(),
    unique.keys(),
    title="Model",
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),  # anchor the legend just above the axes
    ncol=len(unique),  # all entries on a single row
    fontsize="small",
    title_fontsize="small",
)

ax.set_xlabel("NLI score")
ax.set_ylabel("CLF score")
# Transparent axes and figure backgrounds.
ax.set_facecolor("none")
fig.patch.set_facecolor("none")

fig.tight_layout()

# savefig does not create missing directories, so ensure the target exists.
out_path = Path("./paper/figs/fig_05_nli_vs_clf.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, format="pdf", bbox_inches="tight")
