{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72d526b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import seaborn as sns\n",
    "\n",
    "# -------------------- Settings and base data\n",
    "model_order = [\"GraphSAGE\", \"GIN\", \"GAT\", \"GCN\"]\n",
    "model_order2 = [\"GCN\", \"GAT\", \"GIN\", \"GraphSAGE\"]\n",
    "\n",
    "cell_text = {\n",
    "    (0, 0): {\"title\": \"GPT\\nvs\\nRandom\",       \"caption_left\": \"Graph Properties\"},\n",
    "    (0, 1): {\"title\": \"GPT\\nvs\\nGround Truth\"},\n",
    "    (1, 0): {\"title\": \"GPT\\nvs\\nRandom\",       \"caption_left\": \"Embeddings\"},\n",
    "    (1, 1): {\"title\": \"GPT\\nvs\\nGround Truth\"},\n",
    "}\n",
    "\n",
    "title_pos_right   = (0.986, 0.95)\n",
    "title_pos_default = (0.02, 0.95)\n",
    "\n",
    "title_pos_by_cell = {(0, 1): title_pos_right}\n",
    "title_align_by_cell = {(0, 1): \"right\"}\n",
    "\n",
    "TITLE_FONTSIZE = 3\n",
    "CAPTION_FONTSIZE = 3\n",
    "TITLE_WEIGHT = \"normal\"\n",
    "CAPTION_WEIGHT = \"normal\"\n",
    "TITLE_BBOX = dict(boxstyle=\"round,pad=0.2\", fc=\"white\", ec=\"none\", alpha=0.8)\n",
    "CAPTION_BBOX = None\n",
    "\n",
    "plt.rcParams.update({\n",
    "    \"font.family\": \"serif\",\n",
    "    \"font.size\": 3,\n",
    "    \"axes.labelsize\": 3,\n",
    "    \"xtick.labelsize\": 3,\n",
    "    \"ytick.labelsize\": 3,\n",
    "    \"legend.fontsize\": 3,\n",
    "    \"legend.title_fontsize\": 3\n",
    "})\n",
    "\n",
    "palette = {\n",
    "    \"GIN\": \"#4F91C6\",\n",
    "    \"GAT\": \"#FF9E4D\",\n",
    "    \"GraphSAGE\": \"#5EBF5E\",\n",
    "    \"GCN\": \"#E15B5B\",\n",
    "}\n",
    "row1_files = [\n",
    "    (base_path / \"graphproperties_gpt_vs_random.csv\", \"\"),\n",
    "    (base_path / \"graphproperties_gpt_vs_groundtruth.csv\", \"\"),\n",
    "]\n",
    "\n",
    "row2_files = [\n",
    "    (base_path / \"embedding_gpt_vs_random.csv\", \"GPT VS Random\"),\n",
    "    (base_path / \"embeddings_gpt_vs_groundtruth.csv\", \"GPT VS Ground truth\"),\n",
    "]\n",
    "# Inset rectangle inside KDE axes: [left, bottom, width, height] (0–1 coords)\n",
    "inset_rect = [0.36, 0.67, 0.25, 0.3]\n",
    "\n",
    "# Limits for inset axes depending on subplot position\n",
    "inset_limits_by_position = {\n",
    "    (0, 0): dict(ylim=(0, 0.08), yticks=[0.01, 0.04, 0.07]),\n",
    "    (0, 1): dict(ylim=(0, 0.008), yticks=[0.001, 0.004, 0.008]),\n",
    "    (1, 0): dict(ylim=(0, 0.08), yticks=[0.01, 0.04, 0.07]),\n",
    "    (1, 1): dict(ylim=(0, 0.08), yticks=[0.01, 0.04, 0.07]),\n",
    "}\n",
    "\n",
    "# -------------------- Helper: add text annotations (titles, etc.)\n",
    "def add_texts(ax_kde, ax_box, row_idx, col_idx):\n",
    "    info = cell_text.get((row_idx, col_idx), {})\n",
    "    title_txt = info.get(\"title\", \"\")\n",
    "    title_xy = title_pos_by_cell.get((row_idx, col_idx), title_pos_default)\n",
    "    ha_align = title_align_by_cell.get((row_idx, col_idx), \"left\")\n",
    "\n",
    "    if title_txt:\n",
    "        ax_kde.text(\n",
    "            title_xy[0], title_xy[1], title_txt,\n",
    "            transform=ax_kde.transAxes,\n",
    "            ha=ha_align, va=\"top\",\n",
    "            fontsize=TITLE_FONTSIZE, fontweight=TITLE_WEIGHT,\n",
    "            bbox=TITLE_BBOX\n",
    "        )\n",
    "def draw_left_break_at_top(ax, d=0.018, lw=0.8, color='k', y=0.985, gap=0.02):\n",
    "    \"\"\"\n",
    "    Draw two small diagonal slashes near the top-left of the axis (axes coords),\n",
    "    to visually indicate a y-axis break.\n",
    "    \"\"\"\n",
    "    kwargs = dict(color=color, clip_on=False, linewidth=lw, transform=ax.transAxes)\n",
    "    # two parallel slashes near the top-left\n",
    "    ax.plot((-d, +d), (y - d, y + d), **kwargs)\n",
    "    ax.plot((-d, +d), (y - d - gap, y + d - gap), **kwargs)\n",
    "\n",
    "def add_extra_top_tick_label(ax, label=\"240\", ticklen=0.015, x=0.0, y=1.0, lw=0.6):\n",
    "    \"\"\"\n",
    "    Add a cosmetic extra tick and label at the very top of the left y-axis.\n",
    "    This does NOT change data limits; it's just drawn in axes coordinates.\n",
    "    \"\"\"\n",
    "    # small horizontal tick at top-left (extends outward a bit)\n",
    "    ax.plot((x - ticklen, x), (y, y), transform=ax.transAxes, color='k', linewidth=lw, clip_on=False)\n",
    "    # label slightly to the left of the tick\n",
    "    ax.text(x - ticklen - 0.012, y, label, transform=ax.transAxes,\n",
    "            va='center', ha='right', fontsize=ax.yaxis.get_ticklabels()[0].get_fontsize())\n",
    "\n",
    "# -------------------- Helper: draw visual break only on the left y-axis spine (single axis)\n",
    "def draw_yaxis_visual_break(ax, low=(0, 30), high=(200, 240),d=0.018, lw=0.8, color='k', n_lines=2, gap=0.018,\n",
    "                            pos=1):\n",
    "\n",
    "    y0, y1 = ax.get_ylim()\n",
    "    # weighted position instead of strict midpoint\n",
    "    y_break = low[1] + pos * (high[0] - low[1])   #  pos=0.5 was the old midpoint\n",
    "    y_frac  = (y_break - y0) / (y1 - y0)\n",
    "\n",
    "    kwargs = dict(color=color, clip_on=False, linewidth=lw, transform=ax.transAxes)\n",
    "    for i in range(n_lines):\n",
    "        offs = (i - (n_lines - 1) / 2.0) * gap\n",
    "        ax.plot((-d, +d), (y_frac - d + offs, y_frac + d + offs), **kwargs)\n",
    "\n",
    "\n",
    "# -------------------- Normal column (no y-axis visual break)\n",
    "def plot_column(ax_kde, ax_box, csv_path, xlabel, row_idx, col_idx):\n",
    "    df = pd.read_csv(csv_path)\n",
    "    df.columns = df.columns.str.strip()\n",
    "    df[\"model_name\"] = df[\"Name\"].str.extract(r\"^(GCN|GraphSAGE|GAT|GIN)\")\n",
    "\n",
    "    df_kde = df[[\"model_name\", \"final_val_accuracy\"]].dropna()\n",
    "    df_bar = df[[\"model_name\", \"final_val_accuracy_std\"]].dropna()\n",
    "    mean_std = df_bar.groupby(\"model_name\")[\"final_val_accuracy_std\"].mean().reset_index()\n",
    "\n",
    "    # KDE\n",
    "    for model, group in df_kde.groupby(\"model_name\"):\n",
    "        sns.kdeplot(\n",
    "            data=group, x=\"final_val_accuracy\",\n",
    "            label=model, fill=False, alpha=0.9, linewidth=1, bw_adjust=0.3,\n",
    "            color=palette.get(model, \"gray\"), common_norm=True, ax=ax_kde\n",
    "        )\n",
    "    ax_kde.set_xlabel(\"\")\n",
    "    if col_idx == 0:\n",
    "        ax_kde.set_ylabel(\"Density\")\n",
    "    else:\n",
    "        ax_kde.set_ylabel(\"\")\n",
    "    ax_kde.tick_params(axis=\"y\", length=1, pad=1)\n",
    "    ax_kde.tick_params(axis=\"x\", which=\"both\", bottom=False, top=False, labelbottom=False)\n",
    "    ax_kde.grid(False)\n",
    "    ax_kde.tick_params(axis=\"y\", which=\"both\", labelleft=True)\n",
    "    for spine in ax_kde.spines.values():\n",
    "        spine.set_linewidth(0.3)\n",
    "\n",
    "    # Boxplot\n",
    "    sns.boxplot(\n",
    "        data=df_kde, y=\"model_name\", x=\"final_val_accuracy\",\n",
    "        orient=\"h\", palette=palette, width=0.5, linewidth=0.5,\n",
    "        order=model_order, ax=ax_box,\n",
    "        flierprops=dict(marker='o', markersize=3, alpha=0.6)\n",
    "    )\n",
    "    ax_box.set_xlabel(\"\")\n",
    "    ax_box.tick_params(axis=\"y\", length=1, pad=1)\n",
    "    ax_box.tick_params(axis=\"x\", length=1, pad=1)\n",
    "    ax_box.set_ylabel(\"\")\n",
    "    ax_box.invert_yaxis()\n",
    "    ax_box.grid(False)\n",
    "    ax_box.set_yticklabels(ax_box.get_yticklabels(), rotation=30)\n",
    "    for spine in ax_box.spines.values():\n",
    "        spine.set_linewidth(0.3)\n",
    "\n",
    "    # Inset bar inside KDE\n",
    "    ax_inset = ax_kde.inset_axes(inset_rect)\n",
    "    sns.barplot(\n",
    "        data=mean_std, x=\"model_name\", y=\"final_val_accuracy_std\",\n",
    "        palette=palette, order=model_order2, width=0.5, ax=ax_inset\n",
    "    )\n",
    "    ax_inset.set_xlabel(\"Models\")\n",
    "    ax_inset.set_ylabel(\"\")\n",
    "    ax_inset.set_title(\"\")\n",
    "    for spine in ax_inset.spines.values():\n",
    "        spine.set_linewidth(0.3)\n",
    "    ax_inset.set_xticklabels([])\n",
    "    ax_inset.tick_params(axis=\"y\", length=1, pad=1)\n",
    "    ax_inset.tick_params(axis=\"x\", length=1, pad=1)\n",
    "    ax_inset.set_yticklabels(ax_inset.get_yticklabels(), rotation=45)\n",
    "    for spine in ax_inset.spines.values():\n",
    "        spine.set_linewidth(0.3)\n",
    "\n",
    "    # Apply manual y-axis limits for inset if needed\n",
    "    lims = inset_limits_by_position.get((row_idx, col_idx))\n",
    "    if lims:\n",
    "        ax_inset.set_ylim(*lims[\"ylim\"])\n",
    "        ax_inset.set_yticks(lims[\"yticks\"])\n",
    "\n",
    "    add_texts(ax_kde, ax_box, row_idx, col_idx)\n",
    "\n",
    "# -------------------- Column with y-axis *visual* break (top-right only, curve intact)\n",
    "def plot_column_with_visual_break(ax_kde, ax_box, csv_path, xlabel, row_idx, col_idx,\n",
    "                                  low=(0,30), high=(200,240)):\n",
    "    df = pd.read_csv(csv_path)\n",
    "    df.columns = df.columns.str.strip()\n",
    "    df[\"model_name\"] = df[\"Name\"].str.extract(r\"^(GCN|GraphSAGE|GAT|GIN)\")\n",
    "\n",
    "    df_kde = df[[\"model_name\", \"final_val_accuracy\"]].dropna()\n",
    "    df_bar = df[[\"model_name\", \"final_val_accuracy_std\"]].dropna()\n",
    "    mean_std = df_bar.groupby(\"model_name\")[\"final_val_accuracy_std\"].mean().reset_index()\n",
    "\n",
    "    # KDE (single axis; do NOT split/clip the plot)\n",
    "    for model, group in df_kde.groupby(\"model_name\"):\n",
    "        sns.kdeplot(\n",
    "            data=group, x=\"final_val_accuracy\",\n",
    "            label=model, fill=False, alpha=0.9, linewidth=1, bw_adjust=0.3,\n",
    "            color=palette.get(model, \"gray\"), common_norm=True, ax=ax_kde\n",
    "        )\n",
    "\n",
    "    # Keep the curve intact; just show ticks for two ranges and draw a visual break\n",
    "    ax_kde.set_ylim(0, 240)  # full span so the curve isn't cut\n",
    "    ax_kde.set_yticks([0, 10, 20, 30, 240])\n",
    "    ax_kde.set_ylabel(\"\")\n",
    "    ax_kde.tick_params(axis=\"x\", which=\"both\", bottom=False, top=False, labelbottom=False)\n",
    "    ax_kde.grid(False)\n",
    "    for s in ax_kde.spines.values():\n",
    "        s.set_linewidth(0.3)\n",
    "\n",
    "    draw_yaxis_visual_break(ax_kde, low=(0,30), high=(200,240), d=0.018, lw=0.8, color='k', n_lines=2, gap=0.02, pos=1)\n",
    "\n",
    "    # Boxplot (shared x)\n",
    "    sns.boxplot(\n",
    "        data=df_kde, y=\"model_name\", x=\"final_val_accuracy\",\n",
    "        orient=\"h\", palette=palette, width=0.5, linewidth=0.5,\n",
    "        order=model_order, ax=ax_box,\n",
    "        flierprops=dict(marker='o', markersize=3, alpha=0.6)\n",
    "    )\n",
    "    ax_box.set_xlabel(\"\")\n",
    "    ax_box.set_ylabel(\"\")\n",
    "    ax_box.invert_yaxis()\n",
    "    ax_box.grid(False)\n",
    "    ax_box.set_yticklabels(ax_box.get_yticklabels(), rotation=30)\n",
    "    for s in ax_box.spines.values():\n",
    "        s.set_linewidth(0.3)\n",
    "\n",
    "    # Inset inside KDE\n",
    "    ax_inset = ax_kde.inset_axes(inset_rect)\n",
    "    sns.barplot(\n",
    "        data=mean_std, x=\"model_name\", y=\"final_val_accuracy_std\",\n",
    "        palette=palette, order=model_order2, width=0.5, ax=ax_inset\n",
    "    )\n",
    "    ax_inset.set_xlabel(\"Models\"); ax_inset.set_ylabel(\"\"); ax_inset.set_title(\"\")\n",
    "    for s in ax_inset.spines.values():\n",
    "        s.set_linewidth(0.4)\n",
    "    ax_inset.set_xticklabels([])\n",
    "    ax_inset.tick_params(axis=\"y\", length=1, pad=1)\n",
    "    ax_inset.tick_params(axis=\"x\", length=1, pad=1)\n",
    "    ax_inset.set_yticklabels(ax_inset.get_yticklabels(), rotation=45)\n",
    "    for s in ax_inset.spines.values():\n",
    "        s.set_linewidth(0.3)\n",
    "\n",
    "    lims = inset_limits_by_position.get((row_idx, col_idx))\n",
    "    if lims:\n",
    "        ax_inset.set_ylim(*lims[\"ylim\"])\n",
    "        ax_inset.set_yticks(lims[\"yticks\"])\n",
    "\n",
    "    add_texts(ax_kde, ax_box, row_idx, col_idx)\n",
    "\n",
    "# -------------------- Figure and layout\n",
    "fig = plt.figure(figsize=(5.5, 3.8), dpi=600)\n",
    "outer = gridspec.GridSpec(nrows=2, ncols=2, height_ratios=[1, 1], wspace=0.19, hspace=0.15)\n",
    "\n",
    "ref_ax_x = None   # shared x reference for KDEs/boxes\n",
    "ref_ax_y = None   # shared y reference for *normal* KDEs (not for the visual-break cell)\n",
    "last_kde_for_legend = None\n",
    "\n",
    "for row_idx, files in enumerate([row1_files, row2_files]):\n",
    "    for col_idx, (csv_path, xlabel) in enumerate(files):\n",
    "\n",
    "        inner = gridspec.GridSpecFromSubplotSpec(\n",
    "            2, 1, subplot_spec=outer[row_idx, col_idx],\n",
    "            height_ratios=[1.3, 1], hspace=0.001\n",
    "        )\n",
    "\n",
    "        # Top axis (KDE)\n",
    "        if ref_ax_x is None:\n",
    "            ax_kde = plt.Subplot(fig, inner[0])\n",
    "            ref_ax_x = ax_kde  # share x across all\n",
    "            ref_ax_y = ax_kde  # share y for normal cells\n",
    "        else:\n",
    "            # For the special top-right cell, DO NOT share y (we need 0–200 + visual break)\n",
    "            if (row_idx, col_idx) == (0, 1):\n",
    "                ax_kde = plt.Subplot(fig, inner[0], sharex=ref_ax_x)  # no sharey\n",
    "            else:\n",
    "                ax_kde = plt.Subplot(fig, inner[0], sharex=ref_ax_x, sharey=ref_ax_y)\n",
    "\n",
    "        # Bottom axis (box)\n",
    "        ax_box = plt.Subplot(fig, inner[1], sharex=ax_kde)\n",
    "\n",
    "        fig.add_subplot(ax_kde)\n",
    "        fig.add_subplot(ax_box)\n",
    "\n",
    "        # Plot either normal or the visual-break variant (only top-right)\n",
    "        if (row_idx, col_idx) == (0, 1):\n",
    "            plot_column_with_visual_break(ax_kde, ax_box, csv_path, xlabel, row_idx, col_idx,\n",
    "                                          low=(0,30), high=(200,240))\n",
    "        else:\n",
    "            plot_column(ax_kde, ax_box, csv_path, xlabel, row_idx, col_idx)\n",
    "\n",
    "        last_kde_for_legend = ax_kde\n",
    "\n",
    "# -------------------- Final shared settings, legend, save\n",
    "# Shared x-axis for all KDEs/boxes\n",
    "ref_ax_x.set_xlim(0.4, 1.04)\n",
    "ref_ax_x.set_xticks([0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n",
    "\n",
    "# Shared y for normal KDEs (kept at 0–20)\n",
    "ref_ax_y.set_ylim(0, 30)\n",
    "ref_ax_y.set_yticks([0, 10, 20, 30])\n",
    "\n",
    "# Global legend\n",
    "handles, labels = last_kde_for_legend.get_legend_handles_labels()\n",
    "fig.legend(\n",
    "    handles, labels,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.54, 1.02),\n",
    "    ncol=len(labels),\n",
    "    fontsize=4,\n",
    "    title=\"Model                \",\n",
    "    title_fontsize=4,\n",
    "    frameon=False,\n",
    "    handlelength=3,\n",
    "    handletextpad=2,\n",
    "    borderpad=6,\n",
    "    labelspacing=1\n",
    ")\n",
    "\n",
    "plt.subplots_adjust(top=0.88)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"combined_2x2_gnn_comparison.png\", bbox_inches=\"tight\", dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f8da06a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import seaborn as sns\n",
    "\n",
    "cell_text = {\n",
    "    (0, 0): {\"title\": \"GPT\\nvs\\nRandom\",       \"caption_left\": \"Graph Properties\"},\n",
    "    (0, 1): {\"title\": \"GPT\\nvs\\nGround Truth\"},\n",
    "    (1, 0): {\"title\": \"GPT\\nvs\\nRandom\",       \"caption_left\": \"Embeddings\"},\n",
    "    (1, 1): {\"title\": \"GPT\\nvs\\nGround Truth\"},\n",
    "}\n",
    "\n",
    "title_pos_right   = (0.986, 0.95)\n",
    "title_pos_default = (0.02, 0.95)\n",
    "\n",
    "title_pos_by_cell = {(0, 1): title_pos_right}\n",
    "title_align_by_cell = {(0, 1): \"right\"}\n",
    "\n",
    "TITLE_FONTSIZE = 3\n",
    "CAPTION_FONTSIZE = 3\n",
    "TITLE_WEIGHT = \"normal\"\n",
    "CAPTION_WEIGHT = \"normal\"\n",
    "TITLE_BBOX = dict(boxstyle=\"round,pad=0.2\", fc=\"white\", ec=\"none\", alpha=0.8)\n",
    "CAPTION_BBOX = None\n",
    "\n",
    "plt.rcParams.update({\n",
    "    \"font.family\": \"serif\",\n",
    "    \"font.size\": 3,\n",
    "    \"axes.labelsize\": 3,\n",
    "    \"xtick.labelsize\": 3,\n",
    "    \"ytick.labelsize\": 3,\n",
    "    \"legend.fontsize\": 3,\n",
    "    \"legend.title_fontsize\": 3\n",
    "})\n",
    "\n",
    "palette = {\n",
    "    \"GIN\": \"#4F91C6\",\n",
    "    \"GAT\": \"#FF9E4D\",\n",
    "    \"GraphSAGE\": \"#5EBF5E\",\n",
    "    \"GCN\": \"#E15B5B\",\n",
    "}\n",
    "\n",
    "model_order  = [\"GIN\", \"GAT\", \"GraphSAGE\", \"GCN\"]\n",
    "model_order2 = [\"GIN\", \"GAT\", \"GraphSAGE\", \"GCN\"]\n",
    "\n",
    "row1_files = [\n",
    "    (base_path / \"graphproperties_gpt_vs_random.csv\", \"\"),\n",
    "    (base_path / \"graphproperties_gpt_vs_groundtruth.csv\", \"\"),\n",
    "]\n",
    "\n",
    "row2_files = [\n",
    "    (base_path / \"embedding_gpt_vs_random.csv\", \"GPT VS Random\"),\n",
    "    (base_path / \"embeddings_gpt_vs_groundtruth.csv\", \"GPT VS Ground truth\"),\n",
    "]\n",
    "# Inset rectangle inside KDE axes: [left, bottom, width, height] (0–1 coords)\n",
    "inset_rect = [0.36, 0.67, 0.25, 0.3]\n",
    "\n",
    "inset_limits_by_position = {\n",
    "    (0, 0): dict(ylim=(0, 0.08),  yticks=[0.01, 0.04, 0.07]),\n",
    "    (0, 1): dict(ylim=(0, 0.008), yticks=[0.001, 0.004, 0.008]),\n",
    "    (1, 0): dict(ylim=(0, 0.08),  yticks=[0.01, 0.04, 0.07]),\n",
    "    (1, 1): dict(ylim=(0, 0.08),  yticks=[0.01, 0.04, 0.07]),\n",
    "}\n",
    "\n",
    "shared_density_ylim   = (0, 14)\n",
    "shared_density_yticks = [0, 2, 4, 6, 8, 10, 12]\n",
    "\n",
    "def add_texts(ax_kde, ax_box, row_idx, col_idx):\n",
    "    info = cell_text.get((row_idx, col_idx), {})\n",
    "    title_txt = info.get(\"title\", \"\")\n",
    "    caption_left = info.get(\"caption_left\", \"\")\n",
    "\n",
    "    title_xy = title_pos_by_cell.get((row_idx, col_idx), title_pos_default)\n",
    "    ha_align = title_align_by_cell.get((row_idx, col_idx), \"left\")\n",
    "\n",
    "    if title_txt:\n",
    "        ax_kde.text(\n",
    "            title_xy[0], title_xy[1], title_txt,\n",
    "            transform=ax_kde.transAxes,\n",
    "            ha=ha_align, va=\"top\",\n",
    "            fontsize=TITLE_FONTSIZE, fontweight=TITLE_WEIGHT,\n",
    "            bbox=TITLE_BBOX\n",
    "        )\n",
    "\n",
    "    # Draw left caption\n",
    "    if caption_left:\n",
    "        ax_kde.text(\n",
    "            -0.12, 0.5, caption_left,       \n",
    "            transform=ax_kde.transAxes,\n",
    "            ha=\"right\", va=\"center\",\n",
    "            rotation=90,\n",
    "            fontsize=CAPTION_FONTSIZE, fontweight=CAPTION_WEIGHT,\n",
    "            bbox=CAPTION_BBOX\n",
    "        )\n",
    "\n",
    "\n",
    "def plot_column(ax_kde, ax_box, csv_path, xlabel, row_idx, col_idx):\n",
    "    df = pd.read_csv(csv_path)\n",
    "    df.columns = df.columns.str.strip()\n",
    "    df[\"model_name\"] = df[\"Name\"].str.extract(r\"^(GCN|GraphSAGE|GAT|GIN)\")\n",
    "\n",
    "    df_kde = df[[\"model_name\", \"final_val_accuracy\"]].dropna()\n",
    "    df_bar = df[[\"model_name\", \"final_val_accuracy_std\"]].dropna()\n",
    "    mean_std = df_bar.groupby(\"model_name\")[\"final_val_accuracy_std\"].mean().reset_index()\n",
    "\n",
    "    # KDE (linear y = density)\n",
    "    for model, group in df_kde.groupby(\"model_name\"):\n",
    "        sns.kdeplot(\n",
    "            data=group,\n",
    "            x=\"final_val_accuracy\",\n",
    "            label=model,\n",
    "            fill=False,\n",
    "            alpha=0.9,\n",
    "            linewidth=1,\n",
    "            bw_adjust=0.9,\n",
    "            color=palette.get(model, \"gray\"),\n",
    "            common_norm=True,\n",
    "            ax=ax_kde\n",
    "        )\n",
    "\n",
    "    # Y-axis settings (share for all except top-right (0,1))\n",
    "    ax_kde.set_ylabel(\"Density\")\n",
    "    ax_kde.set_xlabel(\"\")\n",
    "    # Y-axis settings (share for all except top-right (0,1))\n",
    "    if col_idx == 0:\n",
    "        ax_kde.set_ylabel(\"Density\")\n",
    "    else:\n",
    "        ax_kde.set_ylabel(\"\")\n",
    " \n",
    "    ax_kde.set_xlabel(\"\")  \n",
    "    if not (row_idx == 0 and col_idx == 1):\n",
    "        ax_kde.set_ylim(*shared_density_ylim)\n",
    "        ax_kde.set_yticks(shared_density_yticks)\n",
    "    else:\n",
    "        ax_kde.set_ylim(bottom=0)\n",
    "\n",
    "\n",
    "    ax_kde.tick_params(axis=\"y\", length=1, pad=1)\n",
    "    ax_kde.tick_params(axis=\"x\", which=\"both\", bottom=False, top=False, labelbottom=False)\n",
    "    ax_kde.grid(False)\n",
    "    ax_kde.tick_params(axis=\"y\", which=\"both\", labelleft=True)\n",
    "    for spine in ax_kde.spines.values():\n",
    "        spine.set_linewidth(0.3)\n",
    "\n",
    "    # Boxplot (share x)\n",
    "    sns.boxplot(\n",
    "        data=df_kde,\n",
    "        y=\"model_name\",\n",
    "        x=\"final_val_accuracy\",\n",
    "        orient=\"h\",\n",
    "        palette=palette,\n",
    "        width=0.5,\n",
    "        linewidth=0.5,\n",
    "        order=model_order,\n",
    "        ax=ax_box,\n",
    "        flierprops=dict(marker='o', markersize=3, alpha=0.6)\n",
    "    )\n",
    "    if row_idx == 1:   \n",
    "        ax_box.set_xlabel(\"Accuracy\")\n",
    "    else:\n",
    "        ax_box.set_xlabel(\"\")\n",
    "    ax_box.tick_params(axis=\"y\", length=1, pad=1)\n",
    "    ax_box.tick_params(axis=\"x\", length=1, pad=1)\n",
    "    ax_box.set_ylabel(\"\")\n",
    "    ax_box.invert_yaxis()\n",
    "    ax_box.grid(False)\n",
    "    ax_box.set_yticklabels(ax_box.get_yticklabels(), rotation=30)\n",
    "    for spine in ax_box.spines.values():\n",
    "        spine.set_linewidth(0.3)\n",
    "\n",
    "    # Inset bar inside KDE\n",
    "    ax_inset = ax_kde.inset_axes(inset_rect)\n",
    "    sns.barplot(\n",
    "        data=mean_std,\n",
    "        x=\"model_name\",\n",
    "        y=\"final_val_accuracy_std\",\n",
    "        palette=palette,\n",
    "        order=model_order2,\n",
    "        width=0.5,\n",
    "        ax=ax_inset\n",
    "    )\n",
    "    ax_inset.set_xlabel(\"Models\")\n",
    "    ax_inset.set_ylabel(\"\")\n",
    "    ax_inset.set_title(\"\")\n",
    "    for spine in ax_inset.spines.values():\n",
    "        spine.set_linewidth(0.4)\n",
    "    ax_inset.set_xticklabels([])\n",
    "    ax_inset.tick_params(axis=\"y\", length=1, pad=1)\n",
    "    ax_inset.tick_params(axis=\"x\", length=1, pad=1)\n",
    "    ax_inset.set_yticklabels(ax_inset.get_yticklabels(), rotation=45)\n",
    "    for spine in ax_inset.spines.values():\n",
    "        spine.set_linewidth(0.3)\n",
    "\n",
    "    lims = inset_limits_by_position.get((row_idx, col_idx))\n",
    "    if lims:\n",
    "        ax_inset.set_ylim(*lims[\"ylim\"])\n",
    "        ax_inset.set_yticks(lims[\"yticks\"])\n",
    "\n",
    "    add_texts(ax_kde, ax_box, row_idx, col_idx)\n",
    "\n",
    "fig = plt.figure(figsize=(5.5, 4), dpi=600)\n",
    "outer = gridspec.GridSpec(nrows=2, ncols=2, height_ratios=[1, 1], wspace=0.19, hspace=0.15)\n",
    "\n",
    "# Shared x-axis for all KDEs/boxes\n",
    "ref_ax_x.set_xlim(0.4, 1.04)\n",
    "ref_ax_x.set_xticks([0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n",
    "import string\n",
    "\n",
    "N_COLS = 2 \n",
    "\n",
    "def add_panel_label(ax, text, x=-0.1, y=0.995):\n",
    "    ax.text(x, y, f'({text})',\n",
    "            transform=ax.transAxes, ha='left', va='top',\n",
    "            fontsize=4, fontweight='bold', zorder=30,\n",
    "            bbox=dict(facecolor='white', edgecolor='none', alpha=0.9, pad=0.5))\n",
    "\n",
    "    \n",
    "for row_idx, files in enumerate([row1_files, row2_files]):\n",
    "    for col_idx, (csv_path, xlabel) in enumerate(files):\n",
    "\n",
    "        inner = gridspec.GridSpecFromSubplotSpec(\n",
    "            2, 1, subplot_spec=outer[row_idx, col_idx],\n",
    "            height_ratios=[1.3, 1], hspace=0.001\n",
    "        )\n",
    "\n",
    "        # Top axis (KDE)\n",
    "        if ref_ax_x is None:\n",
    "            ax_kde = plt.Subplot(fig, inner[0])\n",
    "            ref_ax_x = ax_kde  # share x across all\n",
    "            ref_ax_y = ax_kde  # share y for normal cells\n",
    "        else:\n",
    "            if (row_idx, col_idx) == (0, 1):\n",
    "                ax_kde = plt.Subplot(fig, inner[0], sharex=ref_ax_x)  # no sharey\n",
    "            else:\n",
    "                ax_kde = plt.Subplot(fig, inner[0], sharex=ref_ax_x, sharey=ref_ax_y)\n",
    "\n",
    "        # Bottom axis (box)\n",
    "        ax_box = plt.Subplot(fig, inner[1], sharex=ax_kde)\n",
    "        fig.add_subplot(ax_kde)\n",
    "        fig.add_subplot(ax_box)\n",
    "\n",
    "        plot_column(ax_kde, ax_box, csv_path, xlabel, row_idx, col_idx)\n",
    "        if col_idx == 0:\n",
    "            add_panel_label(ax_kde, 'a' if row_idx == 0 else 'b')\n",
    "        \n",
    "# Get handles and labels from one of your axes\n",
    "handles, labels = ax_kde.get_legend_handles_labels()\n",
    "\n",
    "fig.legend(\n",
    "    handles, labels,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.54, 1.02),\n",
    "    ncol=len(labels),\n",
    "    fontsize=4,\n",
    "    title=\"Model                \",\n",
    "    title_fontsize=4,\n",
    "    frameon=False,\n",
    "    handlelength=3,\n",
    "    handletextpad=2,\n",
    "    borderpad=6,\n",
    "    labelspacing=1\n",
    ")\n",
    "\n",
    "plt.subplots_adjust(top=0.88)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"combined_2x2_gnn_comparison.png\", bbox_inches=\"tight\", dpi=1800)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d45b2894",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
