{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad59dd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.transforms import offset_copy\n",
    "\n",
    "import HALL_lib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf87ecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Typeset all figure text with LaTeX (requires a working LaTeX installation).\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'text.latex.preamble': r'''\n",
    "\\usepackage{mathtools}\n",
    "% more packages here\n",
    "''',\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b600274",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration. Exactly one dataset section below is active;\n",
    "# the commented-out sections are kept so earlier runs can be reproduced.\n",
    "\n",
    "# ============== NEWDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-newdataset\"\n",
    "# maxResponsesPerPrompt = 500\n",
    "\n",
    "# model_names = {\n",
    "#     0 : \"Llama-8B\",\n",
    "#     1 : \"Phi-14B\",\n",
    "#     2 : \"Mistral-7B\",\n",
    "#     3 : \"Solar-11B\",\n",
    "#     4 : \"Gemma-9B\",\n",
    "#     5 : \"Qwen-15B\",\n",
    "#     6 : \"DeepSeek-7B\",\n",
    "#     7 : \"Yi-9B\"\n",
    "# }\n",
    "\n",
    "# ============== OLDDATASET ==============\n",
    "\n",
    "# cache_dir = \"cache-olddataset-20-22\"\n",
    "# maxResponsesPerPrompt=100\n",
    "\n",
    "# model_names = {\n",
    "#     0: 'Mistral-7B',\n",
    "#     1: 'Gemma-9B',\n",
    "#     2: 'Solar-11B',\n",
    "#     3: 'Phi-14B',\n",
    "#     4: 'Qwen-32B',\n",
    "#     5: 'Gemma-27B',\n",
    "#     6: 'Qwen-15B'\n",
    "# }\n",
    "\n",
    "# best_reg_lambda = 1.9\n",
    "\n",
    "# ============== ICML DATASET (active) ==============\n",
    "\n",
    "# Passed as embedding_label to HALL_lib.run_full_label_propagation_study.\n",
    "embedding_label=\"response_embeddings\"\n",
    "# Root of the result cache (the LP study writes to <cache_dir>/LP-trainingsize).\n",
    "cache_dir = \"cache-icml\"\n",
    "maxResponsesPerPrompt=150\n",
    "\n",
    "# model_id -> display name used in every plot below.\n",
    "model_names = {\n",
    "    0: 'Mistral-7B',\n",
    "    1: 'Gemma-9B',\n",
    "    2: 'Solar-11B',\n",
    "    3: 'Phi-14B',\n",
    "    4: 'Qwen-15B',\n",
    "    5: 'Gemma-27B',\n",
    "    6: 'Qwen-32B',\n",
    "    7: \"DeepSeek-7B\",\n",
    "    8: \"Llama-8B\",\n",
    "    9: \"Yi-9B\"\n",
    "}\n",
    "\n",
    "# Regularisation strength for HALL_lib.FisherProjection (passed as lambda_reg).\n",
    "best_reg_lambda = 1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a810ae8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the ICML response dataset; unifyYears=True is forwarded to HALL_lib.loadParquet.\n",
    "dataset_fn = \"../llm-hallucinations/dataset_icml.parquet\"\n",
    "df = HALL_lib.loadParquet(\n",
    "    dataset_fn,\n",
    "    unifyYears=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f99912",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ordering of the model ids plus a rank lookup (presumably by model size,\n",
    "# judging by the helper's name - confirm in HALL_lib).\n",
    "model_order, model_rank = HALL_lib.build_model_size_order(model_names)\n",
    "\n",
    "# prompt_id -> false_premise value, or None when the dataset has no such column.\n",
    "if \"false_premise\" in df.columns:\n",
    "    false_premise = dict(df[[\"prompt_id\", \"false_premise\"]].value_counts().keys())\n",
    "else:\n",
    "    false_premise = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dd570d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training fractions swept by the study: 0.05, 0.10, ..., 1.00 (20 values).\n",
    "train_fractions = [k / 100 for k in range(5, 101, 5)]\n",
    "\n",
    "model_ids = list(model_names.keys())\n",
    "\n",
    "# Sorted unique prompt ids present for each model.\n",
    "prompt_ids_by_model = {\n",
    "    mid: sorted(df.loc[df[\"model_id\"] == mid, \"prompt_id\"].unique())\n",
    "    for mid in model_ids\n",
    "}\n",
    "\n",
    "# Label-propagation study on FisherProjection-projected embeddings for every\n",
    "# model and training fraction. Results are cached under\n",
    "# <cache_dir>/LP-trainingsize (overwrite_cache presumably forces a recompute).\n",
    "results_lp = HALL_lib.run_full_label_propagation_study(\n",
    "    df=df,\n",
    "    model_ids=model_ids,\n",
    "    prompt_ids_by_model=prompt_ids_by_model,\n",
    "    projector_class=HALL_lib.FisherProjection,\n",
    "    projector_kwargs={\n",
    "        \"lambda_reg\": best_reg_lambda,\n",
    "        \"normalise\": True,\n",
    "        \"normalise_by_trace\": True,\n",
    "    },\n",
    "    train_fractions=train_fractions,\n",
    "    n_iter=10,\n",
    "    test_fraction=1/3,\n",
    "    n_splits=5,\n",
    "    ref_lambda_reg=None,\n",
    "    use_cache=True,\n",
    "    cache_dir=f\"{cache_dir}/LP-trainingsize\",\n",
    "    overwrite_cache=False,\n",
    "    logskip=True,\n",
    "    embedding_label=embedding_label\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bacd9ecb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate F1 over prompts while keeping models and training fractions\n",
    "# separate; the plots below read metric_mean / metric_std from agg_lc.\n",
    "agg_lc = HALL_lib.aggregate_metric_over_prompts(\n",
    "    results_lp,\n",
    "    metric=\"f1\",\n",
    "    agg_prompts=True,\n",
    "    agg_models=False,\n",
    "    agg_train_frac=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57a07096",
   "metadata": {},
   "source": [
    "## Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8358648e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_learning_curves(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metric=\"f1\",\n",
    "    ratio=(3, 2),\n",
    "    scale=3\n",
    "):\n",
    "    \"\"\"Plot one learning curve per model: mean metric vs. mean training-set size.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : DataFrame\n",
    "        Aggregated results with columns ``model_id``, ``mean_n_train``,\n",
    "        ``metric_mean`` and ``metric_std``.\n",
    "    model_names : dict\n",
    "        Model id -> display name; models absent from ``agg_df`` are skipped.\n",
    "    metric : str\n",
    "        Metric name, used only for the y label and the title.\n",
    "    ratio, scale :\n",
    "        Figure size is ``scale * ratio``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (Figure, Axes)\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    for mid, name in model_names.items():\n",
    "        # Sort by training size so the line and the +/- 1 std band are drawn\n",
    "        # left to right whatever the row order of agg_df.\n",
    "        df_m = agg_df[agg_df[\"model_id\"] == mid].sort_values(\"mean_n_train\")\n",
    "        if df_m.empty:\n",
    "            continue\n",
    "\n",
    "        ax.plot(\n",
    "            df_m[\"mean_n_train\"],\n",
    "            df_m[\"metric_mean\"],\n",
    "            marker=\"o\",\n",
    "            label=name\n",
    "        )\n",
    "\n",
    "        ax.fill_between(\n",
    "            df_m[\"mean_n_train\"],\n",
    "            df_m[\"metric_mean\"] - df_m[\"metric_std\"],\n",
    "            df_m[\"metric_mean\"] + df_m[\"metric_std\"],\n",
    "            alpha=0.2\n",
    "        )\n",
    "\n",
    "    ax.set_xlabel(\"Training set size (mean)\")\n",
    "    ax.set_ylabel(metric.upper())\n",
    "    ax.grid(True, alpha=0.4)\n",
    "    ax.legend()\n",
    "    ax.set_title(f\"{metric.upper()} vs training size\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "fig, ax = plot_learning_curves(\n",
    "    agg_lc,\n",
    "    model_names=model_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2369a3ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the training-set sizes actually used: one faint line per\n",
    "# training fraction, n_train of each result row against the row index.\n",
    "fig, ax = plt.subplots()\n",
    "results_lp.groupby('train_fraction')['n_train'].plot(ax=ax, color='k', alpha=.4)\n",
    "ax.set_xlabel(\"Row index of the results table\")\n",
    "ax.set_ylabel(\"Training set size\")\n",
    "ax.set_title(\"Training set size per result row, one line per training fraction\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd704284",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metric_boxplots(\n",
    "    results_lp,\n",
    "    model_names,\n",
    "    train_fraction,\n",
    "    metric=\"f1\",\n",
    "    ratio=(4, 3),\n",
    "    scale=2,\n",
    "    cmap_name=\"tab10\"\n",
    "):\n",
    "    \"\"\"Box plots of prompt-level ``metric`` for each model at one training fraction.\n",
    "\n",
    "    ``metric`` is first averaged per (model, prompt, train_fraction); each box\n",
    "    then shows the distribution of those prompt-level means for one model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_lp : DataFrame\n",
    "        Per-run results with ``model_id``, ``prompt_id``, ``train_fraction``\n",
    "        and a ``metric`` column.\n",
    "    model_names : dict\n",
    "        Model id -> display name; models without rows are left out.\n",
    "    train_fraction : float\n",
    "        Training fraction to plot (compared with ``==``).\n",
    "    metric : str\n",
    "        Column to summarise; also used for the y label and the title.\n",
    "    ratio, scale :\n",
    "        Figure size is ``scale * ratio``.\n",
    "    cmap_name : str\n",
    "        Matplotlib colormap; models are coloured by position in ``model_names``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (Figure, Axes)\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "\n",
    "    # Mean metric per (model, prompt, train_fraction), restricted to the requested fraction.\n",
    "    df_prompt = (\n",
    "        results_lp\n",
    "        .groupby([\"model_id\", \"prompt_id\", \"train_fraction\"])\n",
    "        .agg(metric_mean=(metric, \"mean\"))\n",
    "        .reset_index()\n",
    "    )\n",
    "    df_frac = df_prompt[df_prompt[\"train_fraction\"] == train_fraction]\n",
    "\n",
    "    data = []\n",
    "    labels = []\n",
    "    colors = []\n",
    "    for i, (mid, name) in enumerate(model_names.items()):\n",
    "        vals = df_frac.loc[df_frac[\"model_id\"] == mid, \"metric_mean\"].values\n",
    "        if len(vals) > 0:\n",
    "            data.append(vals)\n",
    "            labels.append(name)\n",
    "            colors.append(cmap(i % cmap.N))\n",
    "\n",
    "    bp = ax.boxplot(\n",
    "        data,\n",
    "        patch_artist=True,\n",
    "        showfliers=True,\n",
    "        medianprops=dict(color=\"black\", linewidth=1.5),\n",
    "        boxprops=dict(linewidth=1.2),\n",
    "        whiskerprops=dict(linewidth=1.0),\n",
    "        capprops=dict(linewidth=1.0),\n",
    "    )\n",
    "\n",
    "    for box, c in zip(bp[\"boxes\"], colors):\n",
    "        box.set_facecolor(c)\n",
    "        box.set_alpha(0.6)\n",
    "\n",
    "    ax.set_xticklabels(labels, rotation=30, ha=\"right\")\n",
    "    ax.set_ylabel(metric.upper())\n",
    "    # No bare underscore in the title: with text.usetex enabled LaTeX rejects\n",
    "    # '_' outside math mode and the figure fails to render.\n",
    "    ax.set_title(\n",
    "        f\"{metric.upper()} distribution across prompts \"\n",
    "        f\"(training fraction = {train_fraction})\"\n",
    "    )\n",
    "\n",
    "    ax.grid(True, axis=\"y\", alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "plot_metric_boxplots(results_lp, model_names, train_fractions[-1]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c5b30a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metric_boxplots_two_panels(\n",
    "    results_lp,\n",
    "    model_names,\n",
    "    model_order,\n",
    "    train_fraction,\n",
    "    metrics=(\"accuracy\", \"f1\"),\n",
    "    ratio=(6, 3),\n",
    "    scale=2,\n",
    "    width_ratios=None,\n",
    "    xlims=None,\n",
    "    cmap_name=\"tab10\"\n",
    "):\n",
    "    \"\"\"Side-by-side horizontal violin + box plots of prompt-level scores.\n",
    "\n",
    "    One panel per entry of ``metrics``. Each metric is first averaged per\n",
    "    (model, prompt, train_fraction); a panel then shows, for every model, the\n",
    "    distribution of those prompt-level means at ``train_fraction``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_lp : DataFrame\n",
    "        Per-run results with ``model_id``, ``prompt_id``, ``train_fraction``\n",
    "        and one column per entry of ``metrics``.\n",
    "    model_names : dict\n",
    "        Model id -> display name (y tick labels).\n",
    "    model_order : sequence\n",
    "        Model ids to draw; the first id ends up in the top row.\n",
    "    train_fraction : float\n",
    "        Training fraction to plot (compared with ``==``).\n",
    "    metrics : sequence of str\n",
    "        Metric columns, one panel each, left to right.\n",
    "    ratio, scale :\n",
    "        Figure size is ``scale * ratio``.\n",
    "    width_ratios : sequence of float, optional\n",
    "        Relative panel widths; equal widths when None.\n",
    "    xlims : sequence of (lo, hi), optional\n",
    "        Per-panel x limits; None (or a None entry) keeps autoscaling.\n",
    "    cmap_name : str\n",
    "        Matplotlib colormap used to colour the models.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (Figure, ndarray of Axes)\n",
    "    \"\"\"\n",
    "    n_panels = len(metrics)\n",
    "    if width_ratios is None:\n",
    "        width_ratios = [1] * n_panels  # equal panel widths\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        1, n_panels,\n",
    "        figsize=[scale * x for x in ratio],\n",
    "        sharey=True,\n",
    "        width_ratios=width_ratios,\n",
    "        squeeze=False,\n",
    "    )\n",
    "    axes = axes[0]  # 1-D array of panel Axes, also when n_panels == 1\n",
    "\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "\n",
    "    # Average every requested metric per (model, prompt, train_fraction).\n",
    "    df_prompt = (\n",
    "        results_lp\n",
    "        .groupby([\"model_id\", \"prompt_id\", \"train_fraction\"])\n",
    "        .agg(**{f\"{m}_mean\": (m, \"mean\") for m in metrics})\n",
    "        .reset_index()\n",
    "    )\n",
    "    df_frac = df_prompt[df_prompt[\"train_fraction\"] == train_fraction]\n",
    "\n",
    "    for it, metric in enumerate(metrics):\n",
    "        ax = axes[it]\n",
    "        data = []\n",
    "        labels = []\n",
    "        colors = []\n",
    "\n",
    "        # Reversed: y positions start at 1 at the bottom, so the first id of\n",
    "        # model_order is drawn in the top row.\n",
    "        for i, mid in enumerate(reversed(model_order)):\n",
    "            vals = df_frac.loc[df_frac[\"model_id\"] == mid, f\"{metric}_mean\"].values\n",
    "\n",
    "            if len(vals) > 0:\n",
    "                data.append(vals)\n",
    "                labels.append(model_names[mid])\n",
    "                colors.append(cmap(i % cmap.N))\n",
    "\n",
    "        positions = np.arange(1, len(data) + 1)\n",
    "\n",
    "        # ---- violins (horizontal)\n",
    "        vp = ax.violinplot(\n",
    "            data,\n",
    "            widths=.75,\n",
    "            positions=positions,\n",
    "            vert=False,\n",
    "            showmeans=False,\n",
    "            showextrema=False,\n",
    "        )\n",
    "\n",
    "        for body, c in zip(vp[\"bodies\"], colors):\n",
    "            body.set_facecolor(c)\n",
    "            body.set_alpha(0.35)\n",
    "            body.set_edgecolor(\"none\")\n",
    "\n",
    "        # ---- boxes (horizontal), drawn on top of the violins\n",
    "        bp = ax.boxplot(\n",
    "            data,\n",
    "            positions=positions,\n",
    "            vert=False,\n",
    "            widths=0.25,\n",
    "            showfliers=False,\n",
    "            patch_artist=True,\n",
    "            medianprops=dict(color=\"black\", linewidth=1.5),\n",
    "        )\n",
    "\n",
    "        for box, c in zip(bp[\"boxes\"], colors):\n",
    "            box.set_facecolor(c)\n",
    "            box.set_alpha(0.6)\n",
    "\n",
    "        ax.set_yticks(positions)\n",
    "        ax.set_yticklabels(labels)\n",
    "        ax.set_title(metric.capitalize())\n",
    "        ax.grid(True, axis=\"x\", alpha=0.4)\n",
    "        # Only fix the x range when limits were supplied for this panel.\n",
    "        if xlims is not None and xlims[it] is not None:\n",
    "            ax.set_xlim(xlims[it])\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig, axes\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "fig, _ = plot_metric_boxplots_two_panels(\n",
    "    results_lp,\n",
    "    model_names=model_names,\n",
    "    model_order=model_order,\n",
    "    train_fraction=train_fractions[-1],\n",
    "    width_ratios=[5, 9],\n",
    "    xlims=[(0.5, 1.0), (.1, 1.0)],\n",
    "    ratio=(3,1),\n",
    "    scale=3\n",
    ")\n",
    "\n",
    "fig.savefig(\"img/LP_statistics.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd094226",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metric_heatmap(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metric=\"f1\",\n",
    "    ratio=(5, 3),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"Heatmap of ``metric_mean``: one row per model, one column per training fraction.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : DataFrame\n",
    "        Aggregated results with ``model_id``, ``train_fraction`` and\n",
    "        ``metric_mean``, at most one row per (model, fraction) pair\n",
    "        (``DataFrame.pivot`` raises on duplicates).\n",
    "    model_names : dict\n",
    "        Model id -> display name used for the row labels.\n",
    "    metric : str\n",
    "        Metric name, used only for the title and the colorbar label.\n",
    "    ratio, scale :\n",
    "        Figure size is ``scale * ratio``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (Figure, Axes)\n",
    "    \"\"\"\n",
    "    heat_df = agg_df.pivot(\n",
    "        index=\"model_id\",\n",
    "        columns=\"train_fraction\",\n",
    "        values=\"metric_mean\"\n",
    "    )\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    im = ax.imshow(\n",
    "        heat_df.values,\n",
    "        aspect=\"auto\",\n",
    "        origin=\"lower\"\n",
    "    )\n",
    "\n",
    "    ax.set_yticks(range(len(heat_df.index)))\n",
    "    ax.set_yticklabels([model_names[mid] for mid in heat_df.index])\n",
    "\n",
    "    # '%' starts a comment in LaTeX, so it must be escaped when text.usetex is on.\n",
    "    pct = r\"\\%\" if plt.rcParams[\"text.usetex\"] else \"%\"\n",
    "    ax.set_xticks(range(len(heat_df.columns)))\n",
    "    # round() rather than int(): int() truncates products like 0.29 * 100 = 28.999...\n",
    "    ax.set_xticklabels(\n",
    "        [f\"{int(round(100 * c))}{pct}\" for c in heat_df.columns]\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel(\"Training fraction\")\n",
    "    ax.set_title(metric.upper())\n",
    "\n",
    "    fig.colorbar(im, ax=ax, label=metric.upper())\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "plot_metric_heatmap(agg_lc, model_names);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78de03f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_prompt_heatmap_per_model(\n",
    "    results_lp,\n",
    "    model_id,\n",
    "    model_name,\n",
    "    metric='f1',\n",
    "    ratio=(5, 4),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"Heatmap of one model's per-prompt mean ``metric`` across training fractions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    results_lp : DataFrame\n",
    "        Per-run results; aggregated here with\n",
    "        ``HALL_lib.aggregate_metric_over_prompts`` without pooling prompts.\n",
    "    model_id : hashable\n",
    "        Model to plot.\n",
    "    model_name : str\n",
    "        Display name used in the title.\n",
    "    metric : str\n",
    "        Metric to aggregate; also used in the title and colorbar label.\n",
    "    ratio, scale :\n",
    "        Figure size is ``scale * ratio``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (Figure, Axes)\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the aggregated results contain no rows for ``model_id``.\n",
    "    \"\"\"\n",
    "    agg_df = HALL_lib.aggregate_metric_over_prompts(\n",
    "        results_lp,\n",
    "        metric=metric,\n",
    "        agg_prompts=False,\n",
    "        agg_models=False,\n",
    "        agg_train_frac=False\n",
    "    )\n",
    "\n",
    "    df_m = agg_df[agg_df[\"model_id\"] == model_id]\n",
    "    if df_m.empty:\n",
    "        raise ValueError(f\"no results for model_id={model_id!r}\")\n",
    "\n",
    "    # x tick labels approximate each column's training-set size as\n",
    "    # train_fraction * (largest mean n_train observed for this model).\n",
    "    maxn = df_m['mean_n_train'].max()\n",
    "\n",
    "    heat = df_m.pivot(\n",
    "        index=\"prompt_id\",\n",
    "        columns=\"train_fraction\",\n",
    "        values=\"metric_mean\"\n",
    "    )\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "\n",
    "    im = ax.imshow(\n",
    "        heat.values,\n",
    "        aspect=\"auto\",\n",
    "        origin=\"lower\"\n",
    "    )\n",
    "\n",
    "    ax.set_yticks(range(len(heat.index)))\n",
    "    ax.set_yticklabels(heat.index)\n",
    "\n",
    "    ax.set_xticks(range(len(heat.columns)))\n",
    "    # round() rather than int() so float products are not truncated downwards.\n",
    "    ax.set_xticklabels([f\"{int(round(n * maxn))}\" for n in heat.columns])\n",
    "\n",
    "    ax.set_xlabel(\"Training set size\")\n",
    "    ax.set_title(f\"{model_name}\\nPrompt-level {metric.upper()} vs training size\")\n",
    "\n",
    "    fig.colorbar(im, ax=ax, label=f\"Mean {metric.upper()}\")\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "plot_prompt_heatmap_per_model(results_lp, 0, model_names[0], metric='f1');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3564d36",
   "metadata": {},
   "source": [
    "## Mean score vs. across-prompt standard deviation (F1 and accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eabc1929",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mean_vs_variability_trajectories(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metric_label='metric',\n",
    "    metric_name='F1',\n",
    "    ratio=(4, 3),\n",
    "    scale=2\n",
    "):\n",
    "    \"\"\"Trace each model's (mean, std) across prompts as training size grows.\n",
    "\n",
    "    For every model the points are ordered by ``mean_n_train`` and joined, and\n",
    "    the model name is written next to the last (largest training size) point.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : DataFrame\n",
    "        Aggregated results with ``model_id``, ``mean_n_train`` and the columns\n",
    "        ``{metric_label}_mean`` / ``{metric_label}_std``.\n",
    "    model_names : dict\n",
    "        Model id -> display name; models absent from ``agg_df`` are skipped.\n",
    "    metric_label : str\n",
    "        Column prefix to read.\n",
    "    metric_name : str\n",
    "        Human-readable metric name for the axis labels.\n",
    "    ratio, scale :\n",
    "        Figure size is ``scale * ratio``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (Figure, Axes)\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=[scale * x for x in ratio])\n",
    "    mean_col = f\"{metric_label}_mean\"\n",
    "    std_col = f\"{metric_label}_std\"\n",
    "\n",
    "    for mid, name in model_names.items():\n",
    "        df_m = agg_df[agg_df[\"model_id\"] == mid].sort_values(\"mean_n_train\")\n",
    "        # Skip models without results: .iloc[-1] below would raise IndexError.\n",
    "        if df_m.empty:\n",
    "            continue\n",
    "\n",
    "        ax.plot(\n",
    "            df_m[mean_col],\n",
    "            df_m[std_col],\n",
    "            marker=\"o\",\n",
    "            label=name\n",
    "        )\n",
    "\n",
    "        # Name the curve at its final (largest training size) point.\n",
    "        ax.annotate(\n",
    "            name,\n",
    "            (df_m[mean_col].iloc[-1], df_m[std_col].iloc[-1]),\n",
    "            fontsize=8\n",
    "        )\n",
    "\n",
    "    ax.set_xlabel(f\"Mean {metric_name} across prompts\")\n",
    "    ax.set_ylabel(f\"{metric_name} standard deviation across prompts\")\n",
    "    ax.grid(True, alpha=0.4)\n",
    "\n",
    "    return fig, ax\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "plot_mean_vs_variability_trajectories(agg_lc, model_names)\n",
    "# NOTE(review): agg_lc was aggregated with metric=\"f1\"; this call assumes it also\n",
    "# carries score_mean / score_std columns for accuracy - confirm in HALL_lib.\n",
    "plot_mean_vs_variability_trajectories(agg_lc, model_names, metric_name='Accuracy', metric_label=\"score\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0b4faa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mean_vs_variability_trajectories_two_panels(\n",
    "    agg_df,\n",
    "    model_names,\n",
    "    metrics=(\n",
    "        (\"metric\", \"F1\"),\n",
    "        (\"score\", \"Accuracy\"),\n",
    "    ),\n",
    "    label_angles_top=None,\n",
    "    label_angles_bottom=None,\n",
    "    label_radius=10,\n",
    "    ratio=(4, 6),\n",
    "    scale=2,\n",
    "    cmap_name=\"tab10\",\n",
    "):\n",
    "    \"\"\"Two stacked panels of (mean, std)-across-prompts trajectories, one per metric.\n",
    "\n",
    "    In each panel, every model gets:\n",
    "    - a faint line through its points, ordered by ``mean_n_train``;\n",
    "    - points whose opacity grows with training size;\n",
    "    - a star at the point with the highest mean;\n",
    "    - its name placed next to that star.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agg_df : DataFrame\n",
    "        Aggregated results with ``model_id``, ``mean_n_train`` and, for each\n",
    "        ``(label, name)`` in ``metrics``, columns ``{label}_mean`` / ``{label}_std``.\n",
    "    model_names : dict\n",
    "        Model id -> display name; models absent from ``agg_df`` are skipped.\n",
    "    metrics : sequence of (column prefix, display name)\n",
    "        First entry is drawn in the top panel, second in the bottom panel.\n",
    "    label_angles_top, label_angles_bottom : dict, optional\n",
    "        Display name -> direction in degrees (counter-clockwise from +x) in\n",
    "        which the label is offset in the top / bottom panel. Names missing from\n",
    "        the dict, or a None dict, use 0 degrees.\n",
    "    label_radius : float\n",
    "        Label offset distance in points.\n",
    "    ratio, scale :\n",
    "        Figure size is ``scale * ratio``.\n",
    "    cmap_name : str\n",
    "        Matplotlib colormap; models are coloured by position in ``model_names``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (Figure, ndarray of Axes)\n",
    "    \"\"\"\n",
    "\n",
    "    def alpha_from_fraction(tf, min_alpha=0.15, max_alpha=0.95):\n",
    "        # Linear map of tf in [0, 1] onto [min_alpha, max_alpha].\n",
    "        return min_alpha + (max_alpha - min_alpha) * tf\n",
    "\n",
    "    def text_with_display_offset(ax, x, y, text, angle_deg, radius_pts,\n",
    "                                color, fontsize=9):\n",
    "        # Place text radius_pts points away from data point (x, y) in the\n",
    "        # direction angle_deg, so the offset is independent of axis scaling.\n",
    "        angle_rad = np.deg2rad(angle_deg)\n",
    "        dx = radius_pts * np.cos(angle_rad)\n",
    "        dy = radius_pts * np.sin(angle_rad)\n",
    "\n",
    "        ha = \"left\" if -90 <= angle_deg <= 90 else \"right\"\n",
    "\n",
    "        text_transform = offset_copy(\n",
    "            ax.transData,\n",
    "            fig=ax.figure,\n",
    "            x=dx,\n",
    "            y=dy,\n",
    "            units=\"points\"\n",
    "        )\n",
    "\n",
    "        ax.text(\n",
    "            x, y, text,\n",
    "            transform=text_transform,\n",
    "            fontsize=fontsize,\n",
    "            ha=ha,\n",
    "            va=\"center\",\n",
    "            color=color,\n",
    "        )\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        2, 1,\n",
    "        figsize=[scale * x for x in ratio],\n",
    "        sharex=True,\n",
    "        gridspec_kw={\"hspace\": 0.15},\n",
    "    )\n",
    "\n",
    "    cmap = plt.get_cmap(cmap_name)\n",
    "\n",
    "    # Label directions per panel (top first); None means no custom angles.\n",
    "    panel_label_angles = (label_angles_top or {}, label_angles_bottom or {})\n",
    "\n",
    "    for ax, (metric_label, metric_name), label_angles in zip(\n",
    "        axes, metrics, panel_label_angles\n",
    "    ):\n",
    "        for i, (mid, name) in enumerate(model_names.items()):\n",
    "            df_m = (\n",
    "                agg_df[agg_df[\"model_id\"] == mid]\n",
    "                .sort_values(\"mean_n_train\")\n",
    "            )\n",
    "            if df_m.empty:\n",
    "                continue\n",
    "\n",
    "            color = cmap(i % cmap.N)\n",
    "\n",
    "            x = df_m[f\"{metric_label}_mean\"].values\n",
    "            y = df_m[f\"{metric_label}_std\"].values\n",
    "            tfs = df_m[\"mean_n_train\"].values\n",
    "            tfs = tfs / tfs.max()  # normalise to [0, 1]\n",
    "\n",
    "            # ---- faint trajectory line\n",
    "            ax.plot(\n",
    "                x, y,\n",
    "                color=color,\n",
    "                lw=1,\n",
    "                alpha=0.35,\n",
    "            )\n",
    "\n",
    "            # ---- points, more opaque as training size grows\n",
    "            for xi, yi, tf in zip(x, y, tfs):\n",
    "                ax.scatter(\n",
    "                    xi, yi,\n",
    "                    color=color,\n",
    "                    s=35,\n",
    "                    alpha=alpha_from_fraction(tf),\n",
    "                    zorder=3,\n",
    "                )\n",
    "\n",
    "            # ---- best point (highest mean) as a star\n",
    "            best_idx = np.argmax(x)\n",
    "            ax.scatter(\n",
    "                x[best_idx],\n",
    "                y[best_idx],\n",
    "                color=color,\n",
    "                s=140,\n",
    "                marker=\"*\",\n",
    "                edgecolor=\"black\",\n",
    "                linewidth=0.8,\n",
    "                zorder=5,\n",
    "            )\n",
    "\n",
    "            # ---- model name next to the best point\n",
    "            angle = label_angles.get(name, 0.0)\n",
    "\n",
    "            text_with_display_offset(\n",
    "                ax,\n",
    "                x[best_idx],\n",
    "                y[best_idx],\n",
    "                name,\n",
    "                angle_deg=angle,\n",
    "                radius_pts=label_radius,\n",
    "                color=color,\n",
    "                fontsize=12,\n",
    "            )\n",
    "\n",
    "        ax.set_ylabel(f\"{metric_name} std across prompts\", size=16)\n",
    "        ax.grid(True, alpha=0.4)\n",
    "\n",
    "        ax.tick_params(axis='both', labelsize=16)\n",
    "\n",
    "    axes[-1].set_xlabel(\"Mean performance across prompts\", size=16)\n",
    "\n",
    "    return fig, axes\n",
    "\n",
    "# ===== ===== ===== ===== ===== ===== ===== =====\n",
    "\n",
    "label_angles_f1 = {\n",
    "    'Mistral-7B' : -150,\n",
    "    'Gemma-9B' : -150,\n",
    "    'Solar-11B' : -120,\n",
    "    'Phi-14B' : 0,\n",
    "    'Qwen-32B' : +30,\n",
    "    'Gemma-27B' : -91,\n",
    "    'Qwen-15B' : -45,\n",
    "    'Llama-8B' : +45,\n",
    "    'DeepSeek-7B': +30\n",
    "}\n",
    "\n",
    "label_angles_acc = {\n",
    "    'Mistral-7B' : -95,\n",
    "    'Gemma-9B' : +60,\n",
    "    'Solar-11B' : -165,\n",
    "    'Phi-14B' : 0,\n",
    "    'Qwen-32B' : 0,\n",
    "    'Gemma-27B' : +20,\n",
    "    'Llama-8B' : -95,\n",
    "    'Qwen-15B' : -90,\n",
    "}\n",
    "\n",
    "fig, axes = plot_mean_vs_variability_trajectories_two_panels(\n",
    "    agg_lc,\n",
    "    model_names,\n",
    "    label_angles_top=label_angles_f1,\n",
    "    label_angles_bottom=label_angles_acc,\n",
    "    ratio=(5, 6),\n",
    "    scale=1.5,\n",
    "    label_radius=12\n",
    ")\n",
    "\n",
    "fig.savefig(\"img/LP_performances.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b7ff846",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d613da28",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
