{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b90200f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import wasserstein_distance\n",
    "import random\n",
    "import matplotlib as mpl\n",
    "\n",
    "# Serif typography. Times New Roman is often missing (e.g. on Linux); with it as the\n",
    "# only serif entry Matplotlib warns and falls back to its default sans-serif font, so\n",
    "# list further serif fallbacks (DejaVu Serif ships with Matplotlib).\n",
    "mpl.rcParams['font.family'] = 'serif'\n",
    "mpl.rcParams['font.serif'] = ['Times New Roman', 'Times', 'DejaVu Serif']\n",
    "\n",
    "# CSV layout: field separator, model-name column and the metric column analysed.\n",
    "SEP = \",\"\n",
    "COL_MODEL  = \"model_name\"\n",
    "COL_METRIC = \"final_val_accuracy\"\n",
    "\n",
    "# Figure rows (feature source) and, per row, the two comparisons as (label, CSV path).\n",
    "# NOTE(review): 'embedding_gpt_vs_random.csv' is singular while its sibling is\n",
    "# 'embeddings_gpt_vs_groundtruth.csv' - confirm both match the files on disk.\n",
    "GROUPS = [\n",
    "    (\"Graph Properties\", [\n",
    "        (\"GPT vs Random\",      \"graphproperties_gpt_vs_random.csv\"),\n",
    "        (\"GPT vs GroundTruth\", \"graphproperties_gpt_vs_groundtruth.csv\"),\n",
    "    ]),\n",
    "    (\"Embeddings\", [\n",
    "        (\"GPT vs Random\",      \"embedding_gpt_vs_random.csv\"),\n",
    "        (\"GPT vs GroundTruth\", \"embeddings_gpt_vs_groundtruth.csv\"),\n",
    "    ]),\n",
    "]\n",
    "\n",
    "# Percent-of-runs grid (10..100 in steps of 10, plus 33 and 66), number of random\n",
    "# permutations per model, and the seed that makes the shuffles reproducible.\n",
    "P_GRID = sorted(set(list(range(10, 101, 10)) + [33, 66, 100]))\n",
    "R = 500\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)  # the shuffles draw from NumPy's global RNG\n",
    "\n",
    "# Per-axes legend on the top-left Wasserstein panel only.\n",
    "SHOW_LEGEND_TOP_LEFT_ONLY = False\n",
    "\n",
    "# Figure size/resolution and the 2x2 grid's margins and gaps (figure fractions).\n",
    "FIGSIZE = (13, 12)\n",
    "DPI = 100\n",
    "LEFT, RIGHT, BOTTOM, TOP = 0.07, 0.98, 0.08, 0.96\n",
    "COL_GAP, ROW_GAP = 0.06, 0.06\n",
    "\n",
    "# Relative heights of the Wasserstein (top) and violin (bottom) panels inside a\n",
    "# cell, and the gap between them as a fraction of the cell height.\n",
    "WASSERSTEIN_HEIGHT = 6.0\n",
    "VIOLIN_HEIGHT      = 6.0\n",
    "INTRA_GAP_FRAC     = 0.11\n",
    "# NOTE(review): not referenced anywhere in this notebook - row captions use a fixed x.\n",
    "CAPTION_LEFT_OFFSET = 0.02\n",
    "\n",
    "# Fixed colours for the known architectures (other models get tab10 fallbacks).\n",
    "palette = {\n",
    "    \"GIN\": \"#4F91C6\",\n",
    "    \"GAT\": \"#FF9E4D\",\n",
    "    \"GraphSAGE\": \"#5EBF5E\",\n",
    "    \"GCN\": \"#E15B5B\",\n",
    "}\n",
    "# Per-cell text keyed by (row, col): a corner title and, in column 0, a rotated row caption.\n",
    "cell_text = {\n",
    "    (0, 0): {\"title\": \"GPT\\nvs\\nRandom\",       \"caption_left\": \"Graph Properties\"},\n",
    "    (0, 1): {\"title\": \"GPT\\nvs\\nGround Truth\"},\n",
    "    (1, 0): {\"title\": \"GPT\\nvs\\nRandom\",       \"caption_left\": \"Embeddings\"},\n",
    "    (1, 1): {\"title\": \"GPT\\nvs\\nGround Truth\"},\n",
    "}\n",
    "\n",
    "# Cells left out of the shared y-limits.\n",
    "EXCLUDE_FROM_SHARE = {(0, 1)}\n",
    "\n",
    "def canonicalize_model(name: str) -> str:\n",
    "    if not isinstance(name, str):\n",
    "        return str(name)\n",
    "    key = name.strip().lower()\n",
    "    mapping = {\n",
    "        \"gcnnet\": \"GCN\", \"gcn\": \"GCN\",\n",
    "        \"ginnet\": \"GIN\", \"gin\": \"GIN\",\n",
    "        \"gatnet\": \"GAT\", \"gat\": \"GAT\",\n",
    "        \"graphsagenet\": \"GraphSAGE\", \n",
    "    }\n",
    "    return mapping.get(key, name.strip())\n",
    "\n",
    "def cumulative_stats_over_shuffles(values, p_grid, R=500):\n",
    "    n = len(values)\n",
    "    k_grid = [max(1, math.ceil(p/100 * n)) for p in p_grid]\n",
    "    means_all = {p: [] for p in p_grid}\n",
    "    stds_all  = {p: [] for p in p_grid}\n",
    "    wass_all  = {p: [] for p in p_grid}\n",
    "    for _ in range(R):\n",
    "        y = values[np.random.permutation(n)]\n",
    "        prev = None\n",
    "        for i, p in enumerate(p_grid):\n",
    "            k = k_grid[i]\n",
    "            sub = y[:k]\n",
    "            means_all[p].append(float(np.mean(sub)))\n",
    "            stds_all[p].append(float(np.std(sub, ddof=1)) if k > 1 else 0.0)\n",
    "            if i == 0:\n",
    "                wass_all[p].append(np.nan)\n",
    "            else:\n",
    "                wass_all[p].append(float(wasserstein_distance(prev, sub)))\n",
    "            prev = sub\n",
    "    def agg(dct):\n",
    "        out = {}\n",
    "        for p, arr in dct.items():\n",
    "            a = np.array(arr, float)\n",
    "            out[p] = {\n",
    "                \"avg\": float(np.nanmean(a)),\n",
    "                \"p05\": float(np.nanpercentile(a, 5)),\n",
    "                \"p95\": float(np.nanpercentile(a, 95)),\n",
    "            }\n",
    "        return out\n",
    "    return {\"means\": agg(means_all), \"stds\": agg(stds_all), \"wass\": agg(wass_all)}\n",
    "\n",
    "def read_clean_csv(path):\n",
    "    \"\"\"Load one results CSV and keep only clean (model, metric) rows.\n",
    "\n",
    "    Reads ``path`` with the separator ``SEP``, requires the ``COL_MODEL`` and\n",
    "    ``COL_METRIC`` columns, coerces the metric to numeric, drops rows whose\n",
    "    metric or model name is missing and canonicalizes model names.\n",
    "\n",
    "    Returns:\n",
    "        (df, models): the cleaned two-column DataFrame (fresh RangeIndex) and the\n",
    "        distinct canonical model names in order of first appearance.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a required column is missing or no usable rows remain.\n",
    "    \"\"\"\n",
    "    raw = pd.read_csv(path, sep=SEP)\n",
    "    missing = {COL_MODEL, COL_METRIC} - set(raw.columns)\n",
    "    if missing:\n",
    "        raise ValueError(f\"Required columns not found in {path}: {missing}\")\n",
    "    df = raw[[COL_MODEL, COL_METRIC]].copy()\n",
    "    df[COL_METRIC] = pd.to_numeric(df[COL_METRIC], errors=\"coerce\")\n",
    "    # Drop missing model names *before* canonicalizing: canonicalize_model turns\n",
    "    # NaN into the string \"nan\", which a later dropna() would no longer remove.\n",
    "    df = df.dropna(subset=[COL_MODEL, COL_METRIC]).reset_index(drop=True)\n",
    "    df[COL_MODEL] = df[COL_MODEL].apply(canonicalize_model)\n",
    "    models = list(df[COL_MODEL].unique())\n",
    "    if not models:\n",
    "        raise ValueError(f\"No models found in {path}.\")\n",
    "    return df, models\n",
    "\n",
    "# Collect panels: one entry per CSV, in GROUPS order (row-major), each a tuple\n",
    "# (row_title, col_title, cleaned df, model list, per-model shuffle statistics).\n",
    "# all_models_union accumulates every model name in order of first appearance.\n",
    "PANELS = [] \n",
    "all_models_union, seen = [], set()\n",
    "for row_title, pair in GROUPS:\n",
    "    for col_title, path in pair:\n",
    "        df, models = read_clean_csv(path)\n",
    "        results = {}\n",
    "        # sort=False keeps the groups in order of first appearance in the CSV.\n",
    "        for m, g in df.groupby(COL_MODEL, sort=False):\n",
    "            vals = g[COL_METRIC].values.astype(float)\n",
    "            results[m] = cumulative_stats_over_shuffles(vals, P_GRID, R=R)\n",
    "        PANELS.append((row_title, col_title, df, models, results))\n",
    "        for m in models:\n",
    "            if m not in seen:\n",
    "                all_models_union.append(m)\n",
    "                seen.add(m)\n",
    "\n",
    "# Figure-wide model order for the violins: the known architectures first in a\n",
    "# fixed order, then any other model alphabetically.\n",
    "PREFERRED_ORDER = [\"GCN\", \"GIN\", \"GAT\", \"GraphSAGE\"]\n",
    "GLOBAL_VIOLIN_ORDER = (\n",
    "    [m for m in PREFERRED_ORDER if m in all_models_union] +\n",
    "    sorted([m for m in all_models_union if m not in PREFERRED_ORDER])\n",
    ")\n",
    "\n",
    "# Colour per model: the fixed palette entry where one exists, otherwise the next\n",
    "# tab10 colour, assigned in first-appearance order (wrapping after N colours).\n",
    "fallback_cmap = plt.get_cmap(\"tab10\")\n",
    "COLOR_MAP = {}\n",
    "fallback_idx = 0\n",
    "for m in all_models_union:\n",
    "    if m in palette:\n",
    "        COLOR_MAP[m] = palette[m]\n",
    "    else:\n",
    "        COLOR_MAP[m] = fallback_cmap(fallback_idx % fallback_cmap.N)\n",
    "        fallback_idx += 1\n",
    "\n",
    "# Manual 2x2 layout in figure-fraction coordinates: each cell gets half of the\n",
    "# space left after the outer margins and the single inter-cell gap.\n",
    "avail_w = (RIGHT - LEFT) - COL_GAP\n",
    "cell_w  = avail_w / 2.0\n",
    "avail_h = (TOP - BOTTOM) - ROW_GAP\n",
    "cell_h  = avail_h / 2.0\n",
    "\n",
    "x_left  = LEFT\n",
    "x_right = LEFT + cell_w + COL_GAP\n",
    "\n",
    "# Row 1 is the top row (its bottom edge sits above row 2 plus the row gap).\n",
    "y_bottom_row2 = BOTTOM\n",
    "y_bottom_row1 = BOTTOM + cell_h + ROW_GAP\n",
    "\n",
    "# Cell rectangles [left, bottom, width, height], row-major; comments give (row, col).\n",
    "SLOTS = [\n",
    "    (\"Row1-Left\",  [x_left,  y_bottom_row1, cell_w, cell_h]),  # (0,0)\n",
    "    (\"Row1-Right\", [x_right, y_bottom_row1, cell_w, cell_h]),  # (0,1)\n",
    "    (\"Row2-Left\",  [x_left,  y_bottom_row2, cell_w, cell_h]),  # (1,0)\n",
    "    (\"Row2-Right\", [x_right, y_bottom_row2, cell_w, cell_h]),  # (1,1)\n",
    "]\n",
    "\n",
    "# ASSIGNMENT[slot] is the index into PANELS_ORDERED drawn in that slot; the\n",
    "# identity mapping draws the panels in GROUPS order.\n",
    "PANELS_ORDERED = PANELS\n",
    "ASSIGNMENT = [0, 1, 2, 3]\n",
    "\n",
    "def plot_wasserstein(ax, results, color_map, show_legend=False):\n",
    "    x = P_GRID\n",
    "    ymin, ymax = np.inf, -np.inf\n",
    "    for m, res in results.items():\n",
    "        y  = np.array([res[\"wass\"][p][\"avg\"] for p in x], float)\n",
    "        lo = np.array([res[\"wass\"][p][\"p05\"] for p in x], float)\n",
    "        hi = np.array([res[\"wass\"][p][\"p95\"] for p in x], float)\n",
    "        ax.plot(x, y, marker=\"o\", label=m, color=color_map[m], linewidth=1.6, markersize=4.5)\n",
    "        ax.fill_between(x, lo, hi, alpha=0.15, color=color_map[m])\n",
    "        ymin = min(ymin, np.nanmin(np.minimum(y, lo)))\n",
    "        ymax = max(ymax, np.nanmax(np.maximum(y, hi)))\n",
    "    ax.set_xlabel(\"% of runs\")\n",
    "    ax.set_ylabel(\"Density\")\n",
    "    ax.grid(False)\n",
    "    if show_legend:\n",
    "        ax.legend(ncol=2, fontsize=9.5, frameon=False)\n",
    "    ax.set_xlim(20, 100)\n",
    "    return ymin, ymax\n",
    "\n",
    "def plot_violin(ax, df, models, global_order, color_map=None, ylim=(0.4, 1.0)):\n",
    "    \"\"\"Draw one violin per model showing the distribution of ``COL_METRIC``.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib Axes to draw on.\n",
    "        df: DataFrame with ``COL_MODEL`` and ``COL_METRIC`` columns.\n",
    "        models: models present in this panel.\n",
    "        global_order: figure-wide model order; only models that are in ``models``\n",
    "            and have at least one row in ``df`` are drawn, in this order.\n",
    "        color_map: ``{model: colour}``; defaults to the global ``COLOR_MAP``\n",
    "            (same role as ``plot_wasserstein``'s ``color_map``). Models without\n",
    "            an entry are drawn in grey.\n",
    "        ylim: y-axis limits; values outside this range fall out of view.\n",
    "\n",
    "    Returns:\n",
    "        The Axes' y-limits after drawing. If no model has data, the Axes is\n",
    "        cleared, a \"No models found\" note is shown and its y-limits returned.\n",
    "    \"\"\"\n",
    "    if color_map is None:\n",
    "        color_map = COLOR_MAP\n",
    "\n",
    "    # Consistent figure-wide order, restricted to models with data in this panel.\n",
    "    order, data = [], []\n",
    "    for m in global_order:\n",
    "        if m not in models:\n",
    "            continue\n",
    "        vals = df.loc[df[COL_MODEL] == m, COL_METRIC].values\n",
    "        if len(vals) > 0:\n",
    "            order.append(m)\n",
    "            data.append(vals)\n",
    "\n",
    "    if not order:\n",
    "        ax.cla()\n",
    "        ax.text(0.5, 0.5, \"No models found\", ha=\"center\", va=\"center\", transform=ax.transAxes)\n",
    "        return ax.get_ylim()\n",
    "\n",
    "    positions = np.arange(1, len(order) + 1)\n",
    "    # Vertical is the default orientation; `vert=` is not passed because recent\n",
    "    # Matplotlib releases replace that keyword with `orientation=`.\n",
    "    parts = ax.violinplot(\n",
    "        data,\n",
    "        positions=positions,\n",
    "        showmeans=True,\n",
    "        showmedians=True,\n",
    "        widths=0.9,\n",
    "    )\n",
    "\n",
    "    for m, body in zip(order, parts[\"bodies\"]):\n",
    "        color = color_map.get(m, \"#888888\")\n",
    "        body.set_facecolor(color)\n",
    "        body.set_edgecolor(color)\n",
    "        body.set_alpha(0.9)\n",
    "        body.set_linewidth(1.0)\n",
    "\n",
    "    # Mean, median and extrema lines in dark grey (keys absent from `parts` are skipped).\n",
    "    for key in (\"cmeans\", \"cmedians\", \"cmins\", \"cmaxes\", \"cbars\"):\n",
    "        artist = parts.get(key)\n",
    "        if artist is None:\n",
    "            continue\n",
    "        try:\n",
    "            artist.set_color(\"#333333\")\n",
    "            artist.set_linewidth(1.0)\n",
    "        except AttributeError:\n",
    "            # Best-effort styling: only skip artists that lack these setters.\n",
    "            pass\n",
    "\n",
    "    ax.set_xticks(positions)\n",
    "    ax.set_xticklabels(order, rotation=25, ha=\"right\", fontsize=9.5)\n",
    "    ax.set_ylabel(\"Accuracy\")\n",
    "    ax.grid(False)\n",
    "    ax.set_ylim(*ylim)\n",
    "\n",
    "    return ax.get_ylim()\n",
    "\n",
    "\n",
    "def annotate_wasserstein_corner(ax, text, fontsize=9.5):\n",
    "    ax.text(\n",
    "        0.98, 0.98, text,\n",
    "        transform=ax.transAxes,\n",
    "        ha=\"right\", va=\"top\",\n",
    "        fontsize=fontsize,\n",
    "        bbox=dict(boxstyle=\"round,pad=0.2\", facecolor=\"white\", alpha=0.65, edgecolor=\"none\")\n",
    "    )\n",
    "\n",
    "fig = plt.figure(figsize=FIGSIZE, dpi=DPI)\n",
    "\n",
    "# Split each cell vertically: Wasserstein panel on top, violin panel below, in the\n",
    "# ratio WASSERSTEIN_HEIGHT : VIOLIN_HEIGHT with a gap of INTRA_GAP_FRAC * cell_h\n",
    "# between them (each panel is at least 0.02 of the figure height).\n",
    "ratio_sum = WASSERSTEIN_HEIGHT + VIOLIN_HEIGHT\n",
    "cell_intra_gap = INTRA_GAP_FRAC * cell_h\n",
    "h_w = max(cell_h * (WASSERSTEIN_HEIGHT / ratio_sum) - (cell_intra_gap / 2.0), 0.02)\n",
    "h_v = max(cell_h * (VIOLIN_HEIGHT      / ratio_sum) - (cell_intra_gap / 2.0), 0.02)\n",
    "\n",
    "# Figure-x of the right edge of the rotated row captions.\n",
    "# NOTE(review): CAPTION_LEFT_OFFSET from the config is unused; confirm whether the\n",
    "# caption was meant to sit at LEFT - CAPTION_LEFT_OFFSET instead.\n",
    "CAPTION_X = 0.01\n",
    "\n",
    "# Axes and y-extents per row, each tagged with its cell's (row_idx, col_idx) so that\n",
    "# cells listed in EXCLUDE_FROM_SHARE can be left out of the shared y-limits.\n",
    "wasser_axes_by_row = {0: [], 1: []}\n",
    "wasser_y_extents_by_row = {0: [], 1: []}\n",
    "violin_axes_by_row = {0: [], 1: []}\n",
    "violin_y_extents_by_row = {0: [], 1: []}\n",
    "\n",
    "for slot_idx, (slot_name, rect) in enumerate(SLOTS):\n",
    "    panel_idx = ASSIGNMENT[slot_idx]\n",
    "    row_title, col_title, df, models, results = PANELS_ORDERED[panel_idx]\n",
    "\n",
    "    # Slots are row-major: slots 0-1 form the top row, even slots the left column.\n",
    "    row_idx = 0 if slot_idx < 2 else 1\n",
    "    col_idx = 0 if (slot_idx % 2 == 0) else 1\n",
    "\n",
    "    left, bottom, width, height = rect\n",
    "    ax_v = fig.add_axes([left, bottom, width, h_v])                         # violin (bottom)\n",
    "    ax_w = fig.add_axes([left, bottom + h_v + cell_intra_gap, width, h_w])  # Wasserstein (top)\n",
    "\n",
    "    # Wasserstein convergence curves\n",
    "    ymin, ymax = plot_wasserstein(\n",
    "        ax=ax_w,\n",
    "        results=results,\n",
    "        color_map=COLOR_MAP,\n",
    "        show_legend=(slot_idx == 0 and SHOW_LEGEND_TOP_LEFT_ONLY),\n",
    "    )\n",
    "    if col_idx == 0:\n",
    "        ax_w.set_xlim(20, max(P_GRID))\n",
    "\n",
    "    wasser_axes_by_row[row_idx].append((ax_w, (row_idx, col_idx)))\n",
    "    wasser_y_extents_by_row[row_idx].append(((ymin, ymax), (row_idx, col_idx)))\n",
    "\n",
    "    ctext = cell_text.get((row_idx, col_idx), {})\n",
    "    if \"title\" in ctext:\n",
    "        annotate_wasserstein_corner(ax_w, ctext[\"title\"], fontsize=9.5)\n",
    "\n",
    "    if col_idx == 0 and \"caption_left\" in ctext:\n",
    "        # Rotated row caption at 80% of the cell height, i.e. beside the Wasserstein panel.\n",
    "        fig.text(\n",
    "            CAPTION_X,\n",
    "            bottom + height * 0.8,\n",
    "            ctext[\"caption_left\"],\n",
    "            ha=\"right\", va=\"center\",\n",
    "            rotation=90,\n",
    "            fontsize=10,\n",
    "        )\n",
    "\n",
    "    # Accuracy distributions\n",
    "    v_ylim = plot_violin(ax_v, df, models=models, global_order=GLOBAL_VIOLIN_ORDER)\n",
    "    violin_axes_by_row[row_idx].append((ax_v, (row_idx, col_idx)))\n",
    "    violin_y_extents_by_row[row_idx].append((v_ylim, (row_idx, col_idx)))\n",
    "\n",
    "# One y-range shared by every Wasserstein panel not in EXCLUDE_FROM_SHARE (pooled\n",
    "# over both rows), spanning the finite extents those panels reported.\n",
    "kept_axes, kept_exts = [], []\n",
    "for r in (0, 1):\n",
    "    for (ax, pos), (ext, _) in zip(wasser_axes_by_row[r], wasser_y_extents_by_row[r]):\n",
    "        if pos not in EXCLUDE_FROM_SHARE:\n",
    "            kept_axes.append(ax)\n",
    "            kept_exts.append(ext)\n",
    "\n",
    "ymins = [lo for lo, _ in kept_exts if np.isfinite(lo)]\n",
    "ymaxs = [hi for _, hi in kept_exts if np.isfinite(hi)]\n",
    "if kept_axes and ymins and ymaxs:\n",
    "    common_ylim = (min(ymins), max(ymaxs))\n",
    "    for ax in kept_axes:\n",
    "        ax.set_ylim(common_ylim)\n",
    "\n",
    "# Fixed accuracy range for the violin panels not in EXCLUDE_FROM_SHARE.\n",
    "for r in (0, 1):\n",
    "    for ax, pos in violin_axes_by_row[r]:\n",
    "        if pos not in EXCLUDE_FROM_SHARE:\n",
    "            ax.set_ylim(0.4, 1.0)\n",
    "\n",
    "import matplotlib.patches as mpatches\n",
    "\n",
    "# One legend above all panels, built from the colours actually used: palette models\n",
    "# first (in palette order), then any model coloured from the tab10 fallback - so the\n",
    "# legend neither lists models absent from the data nor omits fallback-coloured ones.\n",
    "legend_models = [m for m in palette if m in COLOR_MAP] + [m for m in COLOR_MAP if m not in palette]\n",
    "legend_handles = [mpatches.Patch(color=COLOR_MAP[m], label=m) for m in legend_models]\n",
    "fig.legend(\n",
    "    handles=legend_handles,\n",
    "    loc=\"upper center\",\n",
    "    bbox_to_anchor=(0.53, 1),\n",
    "    ncol=4,\n",
    "    frameon=False,\n",
    "    fontsize=10,\n",
    ")\n",
    "\n",
    "# Save this figure explicitly (not pyplot's 'current' figure) at the configured DPI.\n",
    "out_path = \"manual_layout_pairs_2x2.png\"\n",
    "fig.savefig(out_path, dpi=DPI, bbox_inches=\"tight\")\n",
    "print(f\"Saved: {out_path}\")\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af03de31",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
