# %% Setup libs
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import scienceplots
import seaborn as sns
import yaml
from scipy.stats import t

# %% Setup env
# Load the YAML metadata used below:
#   ds_info     -> the "dataset_info" section, keyed by dataset name (each entry has a "task" field)
#   models_info -> model groupings used to select rows of the score tables
with (
    open("./scripts/paper/utils/dss.yml", encoding="utf-8") as dss_yml,
    open("./scripts/paper/utils/models.yml", encoding="utf-8") as models_yml,
):
    ds_info = yaml.safe_load(dss_yml)["dataset_info"]
    models_info = yaml.safe_load(models_yml)


# %% Retrieve eval data
# Score tables: one row per model (the first CSV column becomes the index),
# one column per evaluation dataset.
df_f1, df_acc = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        "./scripts/paper/utils/f1_scores.csv",
        "./scripts/paper/utils/acc_scores.csv",
    )
)


# %% tbl_03_main_results
# Prints one LaTeX table row per model: "mean (std)" F1 for every task, followed by F1 and
# accuracy averaged over all datasets. Undefined values (e.g. the std of a task backed by a
# single dataset, or a model missing from acc_scores.csv) are rendered as "--", not "nan".


# --- Group columns by task ---
# Rename dataset columns to their task. Columns not listed in ds_info keep their dataset name
# and therefore form their own group. NOTE: this rebinds df_f1 for the rest of the script, so
# later cells see task names (possibly duplicated) as column labels.
task_map = {k: v["task"] for k, v in ds_info.items() if k in df_f1.columns}
df_f1 = df_f1.rename(columns=task_map)

# --- Aggregate by task ---
# Group the transposed dataset columns by their (task) label; sort=False keeps first-appearance order.
by_task = df_f1.T.groupby(level=0, sort=False)
df_mu = by_task.mean().T
df_std = by_task.std().T  # sample std (ddof=1): NaN for a task with a single dataset

# Overall F1 is averaged over individual datasets, not over the per-task means above.
df_mu["avg_f1"] = df_f1.mean(axis=1)
df_std["avg_f1"] = df_f1.std(axis=1)


# --- Acc scores ---
# Assignment aligns on the model index; models absent from acc_scores.csv get NaN.
df_mu["avg_acc"] = df_acc.mean(axis=1)
df_std["avg_acc"] = df_acc.std(axis=1)


# --- LaTeX formatting ---
def _fmt_mean_std(mu: float, std: float) -> str:
    """Format one table cell as ``"mean (std)"`` with two decimals; NaN parts become ``--``."""
    mu_txt = "--" if pd.isna(mu) else f"{mu:.2f}"
    std_txt = "--" if pd.isna(std) else f"{std:.2f}"
    return f"{mu_txt} ({std_txt})"


latex_body = df_mu.copy()
for col in latex_body.columns:
    # df_mu and df_std are built identically, so rows pair up positionally; strict=True turns
    # any future misalignment into an error instead of a silent truncation.
    latex_body[col] = [_fmt_mean_std(mu, std) for mu, std in zip(df_mu[col], df_std[col], strict=True)]

# LaTeX comment listing the column order, so the rows can be checked against the table header.
print("% columns: model & " + " & ".join(map(str, latex_body.columns)))
for model, row in latex_body.iterrows():
    cells = " & ".join(row)
    print(f"{model:<26} & {cells} \\\\")


# %% fig_02_nli_performance_comparison
# Strip plot of per-dataset F1 for the in-house NLI models, grouped by model size on the x-axis
# and coloured by task, with diamond markers for the per-(model type, task) median.

# prepare data
nli_models = models_info["in_house_nli_models"]
df_plot = (
    df_f1.loc[nli_models]
    # Assumes the models are listed in repeating (base, large, large (triplet)) order;
    # np.resize cycles these labels to cover every selected row.
    .assign(model_type=np.resize(["base", "large", "large (triplet)"], len(nli_models)))
    # df_f1's columns were renamed from dataset names to task names in the table cell,
    # so "task" carries the task label of each dataset's score.
    .melt(var_name="task", value_name="f1", id_vars="model_type")
)

# consistent hue mapping shared by the strip plot and the median markers
task_order = df_plot["task"].unique()
palette = dict(zip(task_order, sns.color_palette("Set2", len(task_order)), strict=False))

plt.style.use(["science", "no-latex"])  # styles registered by the scienceplots import
fig, ax = plt.subplots(figsize=(6, 4))

# raw points
sns.stripplot(
    data=df_plot,
    x="model_type",
    y="f1",
    hue="task",
    hue_order=task_order,
    jitter=0.25,
    dodge=False,
    alpha=0.7,
    linewidth=0.5,
    edgecolor="gray",
    palette=palette,
    ax=ax,
)

# group-wise medians, drawn above the raw points (zorder=3) without a second legend
med = df_plot.groupby(["model_type", "task"], sort=False)["f1"].median().reset_index()

sns.scatterplot(
    data=med,
    x="model_type",
    y="f1",
    hue="task",
    hue_order=task_order,
    palette=palette,
    marker="D",
    s=120,
    edgecolor="black",
    legend=False,
    zorder=3,
    ax=ax,
)

ax.set_ylabel("F1 score")
ax.set_xlabel("")
ax.legend(title="Task", bbox_to_anchor=(1.02, 1), loc="upper left", frameon=False)
# transparent axes and figure backgrounds
ax.set_facecolor("none")
fig.patch.set_facecolor("none")
sns.despine(ax=ax, trim=True)
fig.tight_layout()

out_path = Path("./paper/figs/fig_02_nli_performance_comparison.pdf")
out_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories
fig.savefig(out_path, format="pdf", bbox_inches="tight")
