# %%
from collections import defaultdict

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import scienceplots
import yaml

from ml_utils.proxies import set_proxies

# %% Setup env
# Model registry. Its `model_to_type` mapping gives each model's family, and its
# key order fixes the row order of the results table built further down.
with open("./scripts/paper/utils/models.yml", encoding="utf-8") as f:
    models_info = yaml.safe_load(f)

URI = "azureml://db17de6f-6cb8-4996-849f-3fdf0d10a4b9.workspace.westeurope.api.azureml.ms/mlflow/v2.0/subscriptions/dec849c6-3664-4372-9569-749eb6820434/resourceGroups/rg-VLProd/providers/Microsoft.MachineLearningServices/workspaces/mlw-deeplearning1h7w"
# MLflow run whose metrics hold the per-model wall times.
# Superseded earlier run, kept for provenance: "250801203355_aaai26_time"
JOB_NAME = "250803171753_aaai26_time"

# %% Retrieve eval data
# Proxies are only needed while talking to the remote MLflow tracking server;
# the fetched run object is unpacked after leaving the proxy context.
with set_proxies():
    mlflow.set_tracking_uri(URI)
    run = mlflow.get_run(run_id=JOB_NAME)

# Metric name -> value for this run. The table below indexes these by model
# name, so the metric keys are expected to be model names.
metrics = run.data.metrics

# %% Assemble the results table: one row per model, in models.yml order
df = (
    pd.DataFrame.from_dict(metrics, orient="index")  # one row per metric key
    .set_axis(["time"], axis=1)
    .loc[list(models_info["model_to_type"])]
)

df_f1 = pd.read_csv("./scripts/paper/utils/f1_scores.csv", index_col=0)

# Row-wise mean over the CSV's score columns, aligned to `df` by index
# (model name); models absent from the CSV end up with NaN.
df = df.assign(f1_score=df_f1.mean(axis=1))
df["time_inverse"] = df["time"].rdiv(1)  # 1 / wall time: higher is faster
# NOTE(review): drops the first three rows by position ("base models"). This
# relies on those models being listed first in models.yml; confirm.
df = df.iloc[3:]

# %% fig_04_performance_v_latency

# ---------- marker assignment ----------
# One marker per model family; a family missing here raises KeyError below.
marker_map = {
    "nli_cross_encoder": "X",
    "reranker": "p",
    "embedding_model": "*",
}

# ---------- figure ----------
plt.style.use(["science", "no-latex"])  # styles registered by `import scienceplots`
fig, ax = plt.subplots(figsize=(7, 5))

# ---------- scatter: one call per model family ----------
# Grouping by family yields exactly one legend handle per family, so the legend
# needs no de-duplication. `sort=False` keeps families in order of first
# appearance in `df`, i.e. the same legend order a per-row loop would give.
model_types = df.index.map(models_info["model_to_type"])
for mtype, group in df.groupby(model_types, sort=False):
    ax.scatter(
        group["time_inverse"],
        group["f1_score"],
        marker=marker_map[mtype],
        s=100,
        alpha=0.8,
        edgecolor="black",
        linewidth=0.4,
        label=mtype.replace("_", " "),
        zorder=3,
        color="black",
    )

# ---------- legend (one entry per family, above the axes) ----------
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles,
    labels,
    title="Model",
    loc="lower center",
    bbox_to_anchor=(0.5, 1.05),
    ncol=len(labels),
    fontsize="small",
    title_fontsize="small",
    frameon=False,
)

# ---------- quadrant lines ----------
# Medians split the models into faster/slower and better/worse halves.
split_x = df["time_inverse"].median()
split_y = df["f1_score"].median()
ax.axvline(split_x, linestyle="--", linewidth=1, color="gray", zorder=1)
ax.axhline(split_y, linestyle="--", linewidth=1, color="gray", zorder=1)

# ---------- highlight best quadrant ----------
plt.draw()  # make sure the autoscaled limits are up to date before reading them
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
rect = plt.Rectangle((split_x, split_y), xmax - split_x, ymax - split_y, color="tab:green", alpha=0.15, zorder=0)
ax.add_patch(rect)
# Freeze the view. add_patch extends the data limits and requests an autoscale,
# which (with the default margins) would grow the axes beyond the rectangle and
# leave an unshaded strip along the top and right edges.
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)

# ---------- cosmetics ----------
ax.set_xlabel("1 / wall time")
ax.set_ylabel("F1 score")
ax.grid(alpha=0.3, linestyle=":")

# Transparent backgrounds for embedding in the paper.
ax.set_facecolor("none")
fig.patch.set_facecolor("none")

fig.tight_layout()

fig.savefig("./paper/figs/fig_04_performance_v_latency.pdf", format="pdf", bbox_inches="tight")


# %% Report the models in the shaded quadrant (faster AND better than median)
# Strict comparisons: a model sitting exactly on a median line is not counted.
best_mask = df["time_inverse"].gt(split_x) & df["f1_score"].gt(split_y)
best_models = list(df.index[best_mask])

print("Models in the top-right (best) quadrant:")
for model_name in best_models:
    model_type = models_info["model_to_type"].get(model_name, "Unknown")
    print(f"{model_name}  |  type: {model_type}")
