# %% Setup libs
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from scipy.stats import t

# %% Setup env
# F1 score table. The first CSV column becomes the row index; later cells
# match these row labels against the model names listed in models_info.
df = pd.read_csv(
    "./scripts/paper/utils/f1_scores.csv",
    index_col=0,
)

# Model metadata; later cells read its "reranker_params" and
# "embedding_params" tables. yaml.safe_load only builds plain Python objects.
with open("./scripts/paper/utils/models.yml", encoding="utf-8") as f:
    models_info = yaml.safe_load(f)


# %% fig_03_size_comparison


# ------------------------- 1. params & filter ----------------------
# Pool both families into one lookup: model name -> parameter count.
# NOTE: if a name appears in both tables, the "embedding_params" value wins
# because it is merged last.
params = dict(models_info["reranker_params"])
params.update(models_info["embedding_params"])

# Drop score rows for models that have no known parameter count.
df = df.loc[df.index.intersection(list(params))]

# ------------------------- 2. metadata ----------------------------
# Attach each model's parameter count and its family label. Any model not
# listed under "reranker_params" is labelled as an embedding model.
df = df.assign(
    params=df.index.map(params),
    type=np.where(
        df.index.isin(list(models_info["reranker_params"])),
        "rerankers",
        "embeddings",
    ),
)

# ------------------------- 3. bucketize ---------------------------
# Right-closed size classes:
#   small (0, 100M]   medium (100M, 400M]   large (400M, 1B]   xl (1B, inf)
# Counts outside (0, inf) or missing get a NaN bucket, which the later
# groupby drops.
labels = ["small", "medium", "large", "xl"]
edges = [0, 100e6, 400e6, 1e9, np.inf]
df["bucket"] = pd.cut(x=df["params"], bins=edges, right=True, labels=labels)

# ------------------------- 4. long format -------------------------
# Every non-metadata column is treated as an F1 score column (name-sorted,
# because Index.difference sorts). Reshape to one row per
# (model, score column) with the score in "f1" and the source column name
# in "variable".
f1_cols = df.columns.difference(["params", "type", "bucket"])
scores = df.melt(
    id_vars=["type", "bucket", "params"],
    value_vars=list(f1_cols),
    value_name="f1",
)

# ------------------------- 5. aggregate per bucket ----------------
# Per (type, bucket), pooled over all (model, score column) rows:
#   mean_f1 / std_f1  - mean and sample std (ddof=1) of F1
#   n                 - number of non-missing F1 values
#   params_mean       - mean parameter count
#   sem               - standard error of the mean
#   ci95              - half-width of a two-sided 95% Student-t interval;
#                       NaN whenever n < 2
g = (
    scores.groupby(["type", "bucket"], observed=True)
    .agg(
        mean_f1=("f1", "mean"),
        std_f1=("f1", "std"),
        n=("f1", "count"),
        params_mean=("params", "mean"),
    )
    .assign(sem=lambda frame: frame["std_f1"] / np.sqrt(frame["n"]))
    .assign(ci95=lambda frame: t.ppf(0.975, frame["n"] - 1) * frame["sem"])
)

# ------------------------- 6. plotting ----------------------------
# One errorbar series per model family:
#   x = bucket's mean parameter count (millions, log axis)
#   y = bucket's mean F1, error bars = 95% t-interval half-width (ci95)
# All drawing and saving goes through the explicit fig/ax handles rather
# than pyplot's implicit "current figure".
fig, ax = plt.subplots(figsize=(8, 5))
for tname, style, legend_label in (
    ("rerankers", "-o", "rerankers"),
    ("embeddings", "-s", "embedding models"),
):
    # groupby sorts the categorical bucket level in category order
    # (small -> xl), so x increases along each connecting line.
    sub = g.loc[tname].reset_index()
    ax.errorbar(
        sub["params_mean"] / 1e6,  # millions
        sub["mean_f1"],
        yerr=sub["ci95"],
        fmt=style,
        capsize=3,
        label=legend_label,
    )

ax.set_xscale("log")
ax.set_xlabel("parameters [millions, log scale]")
ax.set_ylabel("F1 score")
ax.grid(alpha=0.3, which="both")
# Transparent axes and figure backgrounds in the exported PDF.
ax.set_facecolor("none")
fig.patch.set_facecolor("none")
ax.legend()
fig.tight_layout()

# Create the output folder if needed instead of failing inside savefig.
out_path = Path("./paper/figs/fig_03_size_comparison.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, format="pdf", bbox_inches="tight")
