{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import plot_settings as plot_settings\n",
    "from PIL import Image\n",
    "from IPython.display import clear_output\n",
    "clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_settings.set_latex_settings()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Efficiency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import plot_settings as plot_settings\n",
    "from PIL import Image\n",
    "\n",
    "plot_settings.set_latex_settings()\n",
    "\n",
    "shapes = Image.open(r\"images/Picture1.png\")\n",
    "mpi = Image.open(r\"images/Picture2.png\")\n",
    "cars = Image.open(r\"images/Picture3.png\")\n",
    "dsprites = Image.open(r\"images/Picture4.png\")\n",
    "raven = Image.open(r\"images/Picture5.png\")\n",
    "clevr = Image.open(r\"images/Picture6.png\")\n",
    "cub = Image.open(r\"images/Picture7.jpg\")\n",
    "\n",
    "# Define x values (number of attributes)\n",
    "x = np.arange(2, 31)\n",
    "\n",
    "# Compute y values (number of combinations: 2^n - 1)\n",
    "y_comb = [math.comb(n, 2) for n in x]\n",
    "y_const = [1 for _ in x]\n",
    "\n",
    "# Create the plot\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "ax1.plot(x, y_comb, label=r'Pair-wise evaluation', color=\"#a00000\",)\n",
    "ax1.plot(x, y_const, label=r'Orthotopic evaluation', color=\"#298c8c\", )\n",
    "\n",
    "plt.axvline(x = 3, ymin = 0, ymax = 400, linestyle=\":\", color=(0, 0, 0, 0.4))\n",
    "plt.axvline(x = 4, ymin = 0, ymax = 400, linestyle=\":\", color=(0, 0, 0, 0.4))\n",
    "plt.axvline(x = 5, ymin = 0, ymax = 400, linestyle=\":\", color=(0, 0, 0, 0.4))\n",
    "plt.axvline(x = 6, ymin = 0, ymax = 400, linestyle=\":\", color=(0, 0, 0, 0.4))\n",
    "plt.axvline(x = 24, ymin = 0, ymax = 400, linestyle=\":\", color=(0, 0, 0, 0.4))\n",
    "\n",
    "# Log scale for y-axis\n",
    "# plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "ax1.set_xticks([2,3,4,5,6,7,8,9,10,20,30])\n",
    "ax1.set_xticklabels([2,3,4,5,6,7,8,9,10,20,30])\n",
    "ax1.set_yticklabels([0, 1, 100, 200, 300, 400])\n",
    "ax1.tick_params(axis=\"x\", direction=\"in\")\n",
    "ax1.tick_params(axis=\"y\", direction=\"in\")\n",
    "\n",
    "# Labels and title\n",
    "plt.xlabel(\"$P$\", fontsize=16)\n",
    "plt.ylabel(\"Complexity\", fontsize=16)\n",
    "# plt.title(\"Efficiency of the proposed evaluation scheme\", fontsize=14)\n",
    "plt.legend(bbox_to_anchor=(0.5, -0.37), loc=\"lower center\", fontsize=14)\n",
    "\n",
    "\n",
    "ax_image = figure1.add_axes([0.30,0.76,0.08,0.08])\n",
    "ax_image.imshow(clevr)\n",
    "ax_image.set_title(\"CLEVR\", y=-.8, fontsize=14)\n",
    "ax_image.axis('off')\n",
    "\n",
    "ax_image = figure1.add_axes([0.36,0.6,0.08,0.08])\n",
    "ax_image.imshow(shapes)\n",
    "ax_image.set_title(\"Shapes3D\", y=-.8, fontsize=14)\n",
    "ax_image.axis('off')\n",
    "\n",
    "ax_image = figure1.add_axes([0.23,0.52,0.08,0.08])\n",
    "ax_image.imshow(cars)\n",
    "ax_image.set_title(\"Cars3D\", y=-.8, fontsize=14)\n",
    "ax_image.axis('off')\n",
    "\n",
    "\n",
    "\n",
    "ax_image = figure1.add_axes([0.36,0.37,0.08,0.08])\n",
    "ax_image.imshow(mpi)\n",
    "ax_image.set_title(\"MPI3D\", y=-.8, fontsize=14)\n",
    "ax_image.axis('off')\n",
    "\n",
    "\n",
    "ax_image = figure1.add_axes([0.23,0.36,0.08,0.08])\n",
    "ax_image.imshow(raven)\n",
    "ax_image.set_title(\"Raven\", y=-.8, fontsize=14)\n",
    "ax_image.axis('off')\n",
    "\n",
    "ax_image = figure1.add_axes([0.23,0.22,0.08,0.08])\n",
    "ax_image.imshow(dsprites)\n",
    "ax_image.set_title(\"dSprites\", y=-.8, fontsize=14)\n",
    "ax_image.axis('off')\n",
    "\n",
    "ax_image = figure1.add_axes([0.77,0.3,0.08,0.08])\n",
    "ax_image.imshow(cub)\n",
    "ax_image.set_title(\"CUB\", y=-.8, fontsize=14)\n",
    "ax_image.axis('off')\n",
    "\n",
    "# Grid for readability\n",
    "# plt.grid(True, which=\"both\", linestyle=\"--\", linewidth=0.5)\n",
    "\n",
    "# Show the plot\n",
    "figure1.savefig(\"results/efficiency.pgf\", bbox_inches=\"tight\")\n",
    "!./pgf_compiler.sh efficiency"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import math\n",
    "from numpy import dot\n",
    "from numpy.linalg import norm\n",
    "\n",
    "attributes = 6  # Dimension of the tensors\n",
    "num_samples = 100  # Number of random trials per c\n",
    "c_values = np.arange(0, attributes + 1)  # Values of c\n",
    "values = 8\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "\n",
    "colors = {\n",
    "    0: \"#a00000\",              # Deep Purple\n",
    "    1: \"#5B2C6F\",              # Deep Purple\n",
    "    2: \"#2874A6\",  # Steel Blue\n",
    "    3: \"#148F77\",   # Dark Cyan\n",
    "    4: \"#D4AC0D\",            # Amber Gold\n",
    "}\n",
    "codebook_comp = torch.randn(values, 2048)\n",
    "codebook = torch.randn(int(math.pow(values, attributes)), 1024)\n",
    "for r in range(0,5):\n",
    "    if r == 1:\n",
    "        continue\n",
    "    mean_similarities = []\n",
    "    variance_similarities = []\n",
    "    \n",
    "    for c in c_values:\n",
    "        similarities = []\n",
    "        for _ in range(num_samples):\n",
    "            a = np.random.randint(0, values, attributes)\n",
    "            b = np.random.randint(0, values, attributes)\n",
    "            # shared_indices = np.random.choice(attributes, c, replace=False)\n",
    "            for i in range(len(a)):\n",
    "                if i < c:\n",
    "                    b[i] = a[i]\n",
    "                else:\n",
    "                    b[i] = (a[i]+1) % values\n",
    "            def encode(l):\n",
    "                return int(sum([l[i]*math.pow(values, i) for i in range(len(l)-r)]))\n",
    "            if r > 0:\n",
    "                a_hol = codebook[encode(a)]\n",
    "                b_hol = codebook[encode(b)]\n",
    "                a_hol = torch.cat([a_hol] + [codebook_comp[a[-tmp]] for tmp in range(1,r+1)])\n",
    "                b_hol = torch.cat([b_hol] + [codebook_comp[b[-tmp]] for tmp in range(1,r+1)])\n",
    "            else:\n",
    "                a_hol = codebook[encode(a)]\n",
    "                b_hol = codebook[encode(b)]\n",
    "            sim = dot(a_hol, b_hol)/(norm(a_hol)*norm(b_hol))\n",
    "            similarities.append(sim)\n",
    "        mean_similarities.append(np.clip(np.mean(similarities), a_min=0, a_max=1))\n",
    "        variance_similarities.append(np.clip(np.std(similarities), a_min=0, a_max=1))\n",
    "    mean_similarities = np.array(mean_similarities)\n",
    "    variance_similarities = np.array(variance_similarities)\n",
    "    # ax1.plot(c_values, mean_similarities, color=\"#a00000\", label=\"Holistic \")\n",
    "    # lo = np.clip(mean_similarities - variance_similarities, a_min=0, a_max=1) \n",
    "    # up = np.clip(mean_similarities + variance_similarities, a_min=0, a_max=1) \n",
    "    # ax1.fill_between(c_values, lo, up, color=\"#a00000\")\n",
    "    ax1.plot(c_values,mean_similarities, \"d\", ls=\"-\", color=colors[r])\n",
    "\n",
    "ax1.text(4.6, 0.15, \"$n=P-1$\", rotation=78, fontsize=12, color=colors[0], va='center', ha='left')\n",
    "# ax1.text(3.7, 0.15, \"$n=P-2$\", rotation=63, fontsize=12, color=colors[1], va='center', ha='left')\n",
    "ax1.text(2.7, 0.10, \"$n=4$\", rotation=50, fontsize=12, color=colors[2], va='center', ha='left')\n",
    "ax1.text(1.8, 0.08, \"$n=3$\", rotation=40, fontsize=12, color=colors[3], va='center', ha='left')\n",
    "ax1.text(.9, 0.08, \"$n=2$\", rotation=30, fontsize=12, color=colors[4], va='center', ha='left')\n",
    "hol = plt.Line2D([], [], color='k', marker='d', linestyle='-', label='Holistic representation')\n",
    "comp = plt.Line2D([], [], color='k', marker='o', linestyle='-', label='Compositional representation')\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# concat compositional representations\n",
    "mean_similarities = []\n",
    "variance_similarities = []\n",
    "for c in c_values:\n",
    "    similarities = []\n",
    "\n",
    "    for _ in range(num_samples):\n",
    "        a = np.random.randint(0, values, attributes)\n",
    "        b = np.random.randint(0, values, attributes)\n",
    "        shared_indices = np.random.choice(attributes, c, replace=False)\n",
    "        for i in range(len(a)):\n",
    "            if i in shared_indices:\n",
    "                b[i] = a[i]\n",
    "            else:\n",
    "                b[i] = (a[i]+2) % values\n",
    "        codebook = torch.randn(values, 64)\n",
    "        a_conc_rep = torch.cat([codebook[i,:] for i in a])\n",
    "        b_conc_rep = torch.cat([codebook[i,:] for i in b])\n",
    "        sim = dot(a_conc_rep, b_conc_rep)/(norm(a_conc_rep)*norm(b_conc_rep))\n",
    "        similarities.append(sim)\n",
    "    mean_similarities.append(np.mean(similarities))\n",
    "    variance_similarities.append(np.std(similarities))\n",
    "\n",
    "mean_similarities = np.array(mean_similarities)\n",
    "variance_similarities = np.array(variance_similarities)\n",
    "# comp = ax1.plot(c_values, mean_similarities, \"c\", color=\"#298c8c\", label=\"Concatenative compositional\")\n",
    "# lo = np.clip(mean_similarities - variance_similarities, a_min=0, a_max=1) \n",
    "# up = np.clip(mean_similarities + variance_similarities, a_min=0, a_max=1) \n",
    "# ax1.fill_between(c_values, lo, up, color=\"#298c8c\")\n",
    "ax1.plot(c_values,mean_similarities, linestyle='-', marker=\"o\", color=\"k\", label=\"Compositional representation\")\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "ax1.set_xticks([0,1,2,3,4,5,6])\n",
    "ax1.set_xticklabels([0,1,2,3,\"$\\dots\",\"$P-1$\",\"$P$\"])\n",
    "\n",
    "ax1.axvline(x=0.5, linestyle='--', color='#298c8c', linewidth=0.8)\n",
    "ax1.axvline(x=1.5, linestyle='--', color='#298c8c', linewidth=0.8)\n",
    "ax1.axvline(x=5.5, linestyle='--', color='#298c8c', linewidth=0.8)\n",
    "\n",
    "ax1.text(0.07, 0.5, \"extrapolation\", rotation=90, ha='center', va='center', fontsize=16,                 color='#298c8c', transform=ax1.transAxes)\n",
    "ax1.text(0.21, 0.65, \"comp. generalization\", rotation=90, ha='center', va='center', fontsize=16, color='#298c8c', transform=ax1.transAxes)\n",
    "ax1.text(0.49, 0.7, \"weak comp. \\ngeneralization\", rotation=45, ha='center', va='center', fontsize=16,  color='#298c8c', transform=ax1.transAxes)\n",
    "ax1.text(0.93, 0.5, \"in-distribution\", rotation=90, ha='center', va='center', fontsize=16,               color='#298c8c', transform=ax1.transAxes)\n",
    "\n",
    "# style\n",
    "plt.legend(handles=[hol,comp], bbox_to_anchor=(0.5, -0.35), loc=\"lower center\", fontsize=14)\n",
    "plt.xlim(-0.5,6.5)\n",
    "ax1.tick_params(axis=\"x\", direction=\"in\")\n",
    "ax1.tick_params(axis=\"y\", direction=\"in\")\n",
    "plt.xlabel(\"$c$\", fontsize=16)\n",
    "plt.ylabel(\"Cosine similarity\", fontsize=16)\n",
    "figure1.savefig(\"results/similarity.pgf\", bbox_inches=\"tight\")\n",
    "plt.close()\n",
    "!./pgf_compiler.sh similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grokking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "splits = {\n",
    "    \"or_el\": \"Orientation-Elevation\",\n",
    "    \"el_ty\": \"Elevation-Type\",\n",
    "    \"or_ty\": \"Orientation-Type\"\n",
    "}\n",
    "\n",
    "\n",
    "for code, name in splits.items():\n",
    "    file_path = f'grokking_data/cars_{code}.csv'\n",
    "    train_acc = []\n",
    "    wio_acc = []\n",
    "    val_acc = []\n",
    "    test_acc = []\n",
    "    df = pd.read_csv(file_path)\n",
    "    df = df[[\"Step\",\"Grouped runs - train_acc__MAX\", \"Grouped runs - val_acc__MAX\", \"Grouped runs - test_acc__MAX\", \"Grouped runs - wio_acc__MAX\"]]\n",
    "    df.columns = df.columns.str.replace(r'^Grouped runs - ', '', regex=True)\n",
    "    df = df.iloc[::10].reset_index(drop=True)\n",
    "    df1 = df.rolling(window=100, min_periods=1).mean()\n",
    "    df2 = df.rolling(window=1000, min_periods=1).mean()\n",
    "\n",
    "\n",
    "\n",
    "    figure1, ax1 = plt.subplots(1, 1, figsize=(16,5))#, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "    colors = {\n",
    "        'train': '#5F8B4C',\n",
    "        'val':   '#FFDDAB',\n",
    "        'wio':   '#FF9A9A',\n",
    "        'test':  '#945034' \n",
    "    }\n",
    "    ax1.scatter(df1[\"Step\"], df1[\"train_acc__MAX\"], label='Train Accuracy', color=colors['train'], marker='o', s=1.2)\n",
    "    ax1.scatter(df1[\"Step\"], df1[\"val_acc__MAX\"], label='Validation Accuracy', color=colors['val'], marker='o', s=1.2)\n",
    "    ax1.scatter(df1[\"Step\"], df1[\"wio_acc__MAX\"], label='WIO Accuracy', color=colors['wio'], marker='o', s=1.2)\n",
    "    ax1.scatter(df1[\"Step\"], df1[\"test_acc__MAX\"], label='Test Accuracy', color=colors['test'], marker='o', s=1.2)\n",
    "    ax1.tick_params(axis='both', which='major', labelsize=16)\n",
    "    ax1.tick_params(axis='both', which='minor', labelsize=16)\n",
    "\n",
    "    plt.xlabel('Epoch', fontsize=16)\n",
    "    plt.ylabel('Accuracy', fontsize=16)\n",
    "    plt.legend(markerscale=5, fontsize=16)\n",
    "    plt.savefig(f\"results/grokking_{code}.pdf\", bbox_inches=\"tight\")\n",
    "# figure1.savefig(\"results/grokking.pgf\", bbox_inches=\"tight\")\n",
    "# !./pgf_compiler.sh grokking\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "\n",
    "file_path = f'grokking_data/dsprites.csv'\n",
    "train_acc = []\n",
    "wio_acc = []\n",
    "val_acc = []\n",
    "test_acc = []\n",
    "df = pd.read_csv(file_path)\n",
    "df = df[[\"Step\",\"Grouped runs - train_acc__MAX\", \"Grouped runs - val_acc__MAX\", \"Grouped runs - test_acc__MAX\", \"Grouped runs - wio_acc__MAX\"]]\n",
    "df.columns = df.columns.str.replace(r'^Grouped runs - ', '', regex=True)\n",
    "df1 = df.rolling(window=100, min_periods=1).mean()\n",
    "df2 = df.rolling(window=1000, min_periods=1).mean()\n",
    "\n",
    "\n",
    "\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(16,5))#, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "colors = {\n",
    "    'train': '#5F8B4C',\n",
    "    'val':   '#FFDDAB',\n",
    "    'wio':   '#FF9A9A',\n",
    "    'test':  '#945034' \n",
    "}\n",
    "ax1.scatter(df1[\"Step\"], df1[\"train_acc__MAX\"], label='Train Accuracy', color=colors['train'], marker='o', s=1.2)\n",
    "ax1.scatter(df1[\"Step\"], df1[\"val_acc__MAX\"], label='Validation Accuracy', color=colors['val'], marker='o', s=1.2)\n",
    "ax1.scatter(df1[\"Step\"], df1[\"wio_acc__MAX\"], label='WIO Accuracy', color=colors['wio'], marker='o', s=1.2)\n",
    "ax1.scatter(df1[\"Step\"], df1[\"test_acc__MAX\"], label='Test Accuracy', color=colors['test'], marker='o', s=1.2)\n",
    "ax1.tick_params(axis='both', which='major', labelsize=16)\n",
    "ax1.tick_params(axis='both', which='minor', labelsize=16)\n",
    "\n",
    "plt.xlabel('Epoch', fontsize=16)\n",
    "plt.ylabel('Accuracy', fontsize=16)\n",
    "plt.legend(markerscale=5, fontsize=16)\n",
    "plt.savefig(f\"results/grokking_dsprites.pdf\", bbox_inches=\"tight\")\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "file_path = f'grokking_data/iraven.csv'\n",
    "train_acc = []\n",
    "wio_acc = []\n",
    "val_acc = []\n",
    "test_acc = []\n",
    "df = pd.read_csv(file_path)\n",
    "df = df[[\"Step\",\"Grouped runs - train_acc__MAX\", \"Grouped runs - val_acc__MAX\", \"Grouped runs - test_acc__MAX\", \"Grouped runs - wio_acc__MAX\"]]\n",
    "df.columns = df.columns.str.replace(r'^Grouped runs - ', '', regex=True)\n",
    "df1 = df.rolling(window=100, min_periods=1).mean()\n",
    "df2 = df.rolling(window=1000, min_periods=1).mean()\n",
    "\n",
    "\n",
    "\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(16,5))#, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "colors = {\n",
    "    'train': '#5F8B4C',\n",
    "    'val':   '#FFDDAB',\n",
    "    'wio':   '#FF9A9A',\n",
    "    'test':  '#945034' \n",
    "}\n",
    "ax1.scatter(df1[\"Step\"], df1[\"train_acc__MAX\"], label='Train Accuracy', color=colors['train'], marker='o', s=1.2)\n",
    "ax1.scatter(df1[\"Step\"], df1[\"val_acc__MAX\"], label='Validation Accuracy', color=colors['val'], marker='o', s=1.2)\n",
    "ax1.scatter(df1[\"Step\"], df1[\"wio_acc__MAX\"], label='WIO Accuracy', color=colors['wio'], marker='o', s=1.2)\n",
    "ax1.scatter(df1[\"Step\"], df1[\"test_acc__MAX\"], label='Test Accuracy', color=colors['test'], marker='o', s=1.2)\n",
    "ax1.tick_params(axis='both', which='major', labelsize=16)\n",
    "ax1.tick_params(axis='both', which='minor', labelsize=16)\n",
    "\n",
    "plt.xlabel('Epoch', fontsize=16)\n",
    "plt.ylabel('Accuracy', fontsize=16)\n",
    "plt.legend(markerscale=5, fontsize=16)\n",
    "plt.savefig(f\"results/grokking_iraven.pdf\", bbox_inches=\"tight\")\n",
    " "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Selection metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_sel=np.load(\"selection/id_sel.npy\" )\n",
    "wio_sel=np.load(\"selection/wio_sel.npy\")\n",
    "ood_sel=np.load(\"selection/ood_sel.npy\")\n",
    "oracle=np.load(\"selection/oracle.npy\" )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = np.mean(oracle-id_sel)\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "ax1.hist(oracle-id_sel, bins=80, alpha=0.7, color=\"#FF9A9A\")\n",
    "plt.axvline(x = mean, ymin = 0, ymax = 400, linestyle=\":\", color=\"r\")\n",
    "plt.yscale(\"log\")\n",
    "plt.xlabel(\"Accuracy $\\Delta$\", fontsize=16)\n",
    "plt.ylabel(\"Number of experiments\", fontsize=16)\n",
    "figure1.savefig(\"results/selection_idvswio.pgf\", bbox_inches=\"tight\")\n",
    "ax1.tick_params(axis='both', which='major', labelsize=14)\n",
    "ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "ax1.text(0.35, 0.75, f\"$\\mu={mean:.2f}$\\%\", ha='center', va='center', fontsize=16, transform=ax1.transAxes)\n",
    "figure1.savefig(\"results/selection_idvswio.pgf\", bbox_inches=\"tight\")\n",
    "!./pgf_compiler.sh selection_idvswio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = np.mean(oracle-ood_sel)\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "ax1.hist(oracle-ood_sel, bins=80, alpha=0.7, color=\"#FFDDAB\")\n",
    "plt.axvline(x = mean, ymin = 0, ymax = 400, linestyle=\":\", color=\"r\")\n",
    "plt.yscale(\"log\")\n",
    "plt.xlabel(\"Accuracy $\\Delta$\", fontsize=16)\n",
    "plt.ylabel(\"Number of experiments\", fontsize=16)\n",
    "figure1.savefig(\"results/selection_idvswio.pgf\", bbox_inches=\"tight\")\n",
    "ax1.tick_params(axis='both', which='major', labelsize=14)\n",
    "ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "ax1.text(0.4, 0.75, f\"$\\mu={mean:.2f}$\\%\", ha='center', va='center', fontsize=16, transform=ax1.transAxes)\n",
    "figure1.savefig(\"results/selection_oodvswio.pgf\", bbox_inches=\"tight\")\n",
    "!./pgf_compiler.sh selection_oodvswio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = np.mean(oracle-wio_sel)\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "ax1.hist(oracle-wio_sel, bins=80, alpha=0.7, color=\"#945034\")\n",
    "plt.axvline(x = mean, ymin = 0, ymax = 400, linestyle=\":\", color=\"r\")\n",
    "plt.yscale(\"log\")\n",
    "plt.xlabel(\"Accuracy $\\Delta$\", fontsize=16)\n",
    "plt.ylabel(\"Number of experiments\", fontsize=16)\n",
    "figure1.savefig(\"results/selection_wiovsoracle.pgf\", bbox_inches=\"tight\")\n",
    "ax1.text(0.35, 0.75, f\"$\\mu={mean:.2f}$\\%\", ha='center', va='center', fontsize=16, transform=ax1.transAxes)\n",
    "# ax1.set_xlim(-110, 110)\n",
    "ax1.tick_params(axis='both', which='major', labelsize=14)\n",
    "ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "figure1.savefig(\"results/selection_oraclevswio.pgf\", bbox_inches=\"tight\")\n",
    "!./pgf_compiler.sh selection_oraclevswio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pairwise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "METRICS = [\n",
    "    \"train_acc\",\n",
    "    \"val_acc\",\n",
    "    \"ood_val_0_acc\",\n",
    "    \"test_acc\",\n",
    "]\n",
    "rdf = pd.read_pickle(\"pairwise/mpi3d.pkl\")\n",
    "\n",
    "group_columns = [\"arch\"]\n",
    "res = rdf.groupby(group_columns)[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "\n",
    "res[[(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]] = (\n",
    "    res[[(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]].round(2)\n",
    ")\n",
    "print(res.to_latex(index=False,\n",
    "    formatters={\"name\": str.upper},\n",
    "    float_format=\"{:.2f}\".format,\n",
    "))  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(pd.read_pickle(\"pairwise/mpi3d.pkl\"))+len(pd.read_pickle(\"pairwise/shapes3d.pkl\"))+len(pd.read_pickle(\"pairwise/cars3d.pkl\"))+len(pd.read_pickle(\"pairwise/dsprites.pkl\"))+len(pd.read_pickle(\"pairwise/.pkl\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "METRICS = [\n",
    "    \"train_acc\",\n",
    "    \"val_acc\",\n",
    "    \"ood_val_0_acc\",\n",
    "    \"test_acc\",\n",
    "]\n",
    "for dataset in [\"mpi3d\", \"shapes3d\", \"cars3d\", \"dsprites\", \"iraven\"]:\n",
    "    rdf = pd.read_pickle(f\"pairwise/{dataset}.pkl\")\n",
    "    rdf = rdf[~rdf['arch'].str.contains('prelu', case=False, na=False)]\n",
    "    stems = ['convnext', 'resnet', 'vit', \"swin\", \"densenet\", \"mlp\", \"ed\"]\n",
    "    def assign_stem(model_name):\n",
    "        for stem in stems:\n",
    "            if stem in model_name:\n",
    "                return stem\n",
    "        return 'other'\n",
    "    rdf['stem'] = rdf['arch'].apply(assign_stem)\n",
    "    group_columns = [\"combination\", \"stem\"]\n",
    "    res = rdf.groupby(group_columns)[METRICS].agg(['mean']).reset_index()\n",
    "    res.columns = res.columns.droplevel(1)\n",
    "    res['combination'] = res['combination'].apply(lambda x: f\"({x.replace('_', ', ')})\")\n",
    "\n",
    "\n",
    "    plt.figure(figsize=(9, 6))\n",
    "    sns.stripplot(\n",
    "        x=\"combination\", \n",
    "        y=\"test_acc\", \n",
    "        data=res, \n",
    "        palette=\"muted\",\n",
    "        hue=\"stem\",\n",
    "        size=5,\n",
    "        marker=\"o\",\n",
    "        edgecolor=\"black\",alpha=.75, s=9,linewidth=1.0\n",
    "    )\n",
    "    ax = plt.gca()\n",
    "    ax.tick_params(axis='both', which='major', labelsize=14)\n",
    "    ax.tick_params(axis='both', which='minor', labelsize=14)\n",
    "    # Customizing the plot\n",
    "    plt.xlabel(\"Generative Factors Combination\", fontsize=16)\n",
    "    plt.xticks(rotation=30)\n",
    "    plt.ylabel(\"Test Accuracy (%)\", fontsize=16)\n",
    "    plt.ylim([-10, 110])\n",
    "    plt.legend(loc='best', fontsize=16)\n",
    "    plt.savefig(f\"results/attributewise_{dataset}.pdf\", bbox_inches=\"tight\", )\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import plot_settings as plot_settings\n",
    "plot_settings.set_latex_settings()\n",
    "\n",
    "METRICS = [\n",
    "    \"train_acc\",\n",
    "    \"val_acc\",\n",
    "    \"ood_val_0_acc\",\n",
    "    \"test_acc\",\n",
    "]\n",
    "models = [\n",
    " 'resnet152',\n",
    " 'resnet101',\n",
    " 'resnet34',\n",
    " 'resnet18',\n",
    " 'densenet161',\n",
    " 'convnext',\n",
    " 'densenet201',\n",
    " 'densenet121',\n",
    " 'convnext',\n",
    " 'resnet50'\n",
    "]\n",
    "\n",
    "for dataset in [\"mpi3d\", \"shapes3d\", \"cars3d\", \"iraven\"]:\n",
    "    figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "    rdf = pd.read_pickle(f\"pairwise/{dataset}.pkl\")\n",
    "    rdf = rdf[~rdf['arch'].str.contains(\"pretrained\", case=False, na=False)]\n",
    "    rdf = rdf[rdf['arch'].str.contains('|'.join(models), case=False, na=False)]\n",
    "    stems = ['resnet', \"densenet\", 'convnext']\n",
    "    def assign_stem(model_name):\n",
    "        for stem in stems:\n",
    "            if stem in model_name:\n",
    "                return stems.index(stem)\n",
    "        return 'other'\n",
    "    rdf['stem'] = rdf['arch'].apply(assign_stem)\n",
    "    rdf[\"prelu\"] = rdf['arch'].str.contains(\"prelu\", case=False, na=False)\n",
    "    res = rdf.groupby(['prelu', 'stem'])[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "    x = list(range(3))\n",
    "    y = [res[np.logical_and(res[\"stem\"] == idx, res[\"prelu\"] == False)][\"test_acc\"][\"mean\"].item() for idx in x]\n",
    "    y_err = [res[np.logical_and(res[\"stem\"] == idx, res[\"prelu\"] == False)][\"test_acc\"][\"sem\"].item() for idx in x]\n",
    "    ax1.errorbar(\n",
    "        x,\n",
    "        y,\n",
    "        yerr=y_err,\n",
    "        fmt='.',\n",
    "        label=\"Standard\",\n",
    "        markersize=8,\n",
    "        capsize=5,\n",
    "        color=\"#FF9A9A\"\n",
    "    )\n",
    "    y = [res[np.logical_and(res[\"stem\"] == idx, res[\"prelu\"] == True)][\"test_acc\"][\"mean\"].item() for idx in x]\n",
    "    y_err = [res[np.logical_and(res[\"stem\"] == idx, res[\"prelu\"] == True)][\"test_acc\"][\"sem\"].item() for idx in x]\n",
    "\n",
    "    ax1.errorbar(\n",
    "        x,\n",
    "        y,\n",
    "        yerr=y_err,\n",
    "        fmt='.',\n",
    "        label=\"PReLU\",\n",
    "        markersize=8,\n",
    "        capsize=5,\n",
    "        color=\"#945034\"\n",
    "    )\n",
    "\n",
    "    ax1.tick_params(axis='both', which='major', labelsize=14)\n",
    "    ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "    plt.xticks([0, 1, 2], [\"ResNets\", \"DenseNets\", \"ConvNeXts\"])\n",
    "    plt.ylabel(\"Test Accuracy (\\%)\", fontsize=16)\n",
    "    plt.ylim([-10, 110])\n",
    "    plt.xlim([-.5, 2.5])\n",
    "    plt.legend(loc='best', fontsize=16)\n",
    "\n",
    "    figure1.savefig(f\"results/prelu_{dataset}.pgf\", bbox_inches=\"tight\")\n",
    "\n",
    "!./pgf_compiler.sh prelu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rdf = pd.concat([pd.read_pickle(f\"pairwise/shapes3d.pkl\"), pd.read_pickle(f\"pairwise/mpi3d.pkl\"), pd.read_pickle(f\"pairwise/dsprites.pkl\"), pd.read_pickle(f\"pairwise/iraven.pkl\")])\n",
    "\n",
    "result = pd.DataFrame()\n",
    "for dataset in [\"iraven\", \"cars3d\", \"shapes3d\", \"mpi3d\",]:\n",
    "    rdf = pd.read_pickle(f\"pairwise/{dataset}.pkl\")\n",
    "    rdf = rdf[~rdf['arch'].str.contains(\"pretrained\", case=False, na=False)]\n",
    "    rdf = rdf[rdf['arch'].str.contains('|'.join(models), case=False, na=False)]\n",
    "    df = rdf.groupby(['arch'])[METRICS].agg(['mean']).reset_index()\n",
    "    df.columns = df.columns.droplevel(1)\n",
    "    df['base_arch'] = df['arch'].str.replace('_prelu', '', regex=False)\n",
    "    pivot = df.pivot(index='base_arch', columns='arch', values='test_acc')\n",
    "    pivot['delta_test_acc'] = pivot.apply(\n",
    "        lambda row: row.get(f\"{row.name}_prelu\", float('nan')) - row.get(row.name, float('nan')),\n",
    "        axis=1\n",
    "    )\n",
    "    result[dataset] = pivot['delta_test_acc']\n",
    "result['AVG'] = result[['iraven', 'cars3d', 'shapes3d', 'mpi3d']].mean(axis=1)\n",
    "print(result.to_latex(index=True, float_format=\"{:.2f}\".format,))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import plot_settings\n",
    "plot_settings.set_latex_settings()\n",
    "\n",
    "METRICS = [\n",
    "    \"train_acc\",\n",
    "    \"val_acc\",\n",
    "    \"ood_val_0_acc\",\n",
    "    \"test_acc\",\n",
    "]\n",
    "pw = pd.concat([pd.read_pickle(f\"pairwise/dsprites.pkl\"), pd.read_pickle(f\"pairwise/iraven.pkl\"), pd.read_pickle(f\"pairwise/shapes3d.pkl\"), pd.read_pickle(f\"pairwise/mpi3d.pkl\"), pd.read_pickle(f\"pairwise/cars3d.pkl\")])\n",
    "ot = pd.concat([pd.read_pickle(f\"main/dsprites.pkl\"), pd.read_pickle(f\"main/iraven.pkl\"), pd.read_pickle(f\"main/shapes3d.pkl\"), pd.read_pickle(f\"main/cars3d.pkl\")])\n",
    "\n",
    "for name, pairwise in zip([\"pairwise\", \"orthotopic\"], [pw, ot]):\n",
    "\n",
    "    if name == \"orthotopic\":\n",
    "        pairwise = pairwise[pairwise[\"c\"]==\"1\"]\n",
    "    pairwise = pairwise[~pairwise['arch'].str.contains('prelu|convnext_tiny', case=False, na=False)]\n",
    "    res = pairwise.groupby([\"arch\"])[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "\n",
    "    def modtoname(arch):\n",
    "        mp = {\n",
    "            'convnext_base': \"CN-base\", \n",
    "            'convnext_small': \"CN-small\",\n",
    "            'densenet121': \"DN-121\",\n",
    "            'densenet121_pretrained': \"DN-121-PT\",\n",
    "            'densenet161': \"DN-161\",\n",
    "            'densenet201': \"DN-201\",\n",
    "            'ed': \"ED\",\n",
    "            'mlp': \"MLP\",\n",
    "            'resnet101': \"RN-101\",\n",
    "            'resnet101_pretrained': \"RN-101-PT\",\n",
    "            'resnet152': \"RN-152\",\n",
    "            'resnet152_pretrained': \"RN-151-PT\",\n",
    "            'resnet18': \"RN-18\",\n",
    "            'resnet34': \"RN-34\",\n",
    "            'resnet50': \"RN-50\",\n",
    "            'swin_base': \"ST-base\",\n",
    "            'swin_tiny': \"ST-tiny\",\n",
    "            'vit': \"ViT\",\n",
    "            'wideresnet': \"WRN\",\n",
    "        }\n",
    "        return mp[arch]\n",
    "\n",
    "    color_dict = {\n",
    "        'convnext': '#FF9A9A',\n",
    "        'resnet':   '#FFDDAB',\n",
    "        'vit':      '#945034',\n",
    "        'densenet': '#7CA982',\n",
    "        'mlp':      '#769ECB',\n",
    "        'ed':       '#C287E8',\n",
    "    }\n",
    "    model_sizes = {\n",
    "        \"mlp\":                406850,\n",
    "        'densenet':       6965131,\n",
    "        'densenet121':       6965131,\n",
    "        'densenet121_pretrained': 6965131,\n",
    "        'densenet161':      26486891,\n",
    "        'densenet201':      18107787,\n",
    "        \"resnet18\":         11175883,\n",
    "        \"resnet34\":         21284043,\n",
    "        \"resnet50\":         24556491,\n",
    "        \"resnet101_pretrained\":        43548619,\n",
    "        \"resnet152_pretrained\":        59192267,\n",
    "        \"resnet101\":        43548619,\n",
    "        \"resnet152\":        59192267,\n",
    "        \"wideresnet\":       67882699,\n",
    "        \"ed\":               22347136,\n",
    "        \"vit\":              86576115,\n",
    "        \"convnext_base\":    87573632,\n",
    "        \"convnext_small\":   49460064,\n",
    "        \"swin_tiny\":        27532469,\n",
    "        \"swin_base\":        86771459,\n",
    "    }\n",
    "    families = ['convnext', 'resnet', 'vit', 'densenet', 'mlp', 'ed']\n",
    "    families_caps = ['ConvNeXt', 'ResNet', 'ViT', 'DenseNet', 'MLP', 'ED']\n",
    "    def get_family(arch):\n",
    "        for fam in families:\n",
    "            if arch.startswith(fam):\n",
    "                return fam\n",
    "            elif arch.startswith(\"wideresnet\"):\n",
    "                return \"resnet\"\n",
    "            elif arch.startswith(\"swin\"):\n",
    "                return \"vit\"\n",
    "        return \"other\"\n",
    "\n",
    "    res[\"family\"] = res[\"arch\"].apply(get_family)\n",
    "    res[\"model_size\"] = res[\"arch\"].map(model_sizes)\n",
    "\n",
    "    figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/2, plot_settings.column_width*1/2))\n",
    "    res[\"arch\"] = res[\"arch\"].apply(modtoname)\n",
    "    res.sort_values(by=('test_acc', \"mean\"), inplace=True)\n",
    "    res.reset_index(inplace=True)\n",
    "    res['index1'] = res.index\n",
    "    texts = []\n",
    "    stds = []\n",
    "    for _, row in res.iterrows():\n",
    "        ax1.errorbar(\n",
    "            row[\"index1\"],\n",
    "            row[\"test_acc\"][\"mean\"],\n",
    "            yerr=row[\"test_acc\"][\"sem\"],\n",
    "            fmt='.',\n",
    "            color=color_dict[row[\"family\"].item()],\n",
    "            capsize=3,\n",
    "            markersize=8,\n",
    "            alpha=0.8\n",
    "        )\n",
    "        if name == \"orthotopic\":\n",
    "            stds.append((row[\"arch\"], row[\"test_acc\"][\"sem\"]))\n",
    "        texts.append(ax1.text(\n",
    "            row[\"index1\"].item(),\n",
    "            row[\"test_acc\"][\"mean\"]+10,\n",
    "            f\"{row['test_acc']['mean']:.1f}\\%\",\n",
    "            fontsize=8,\n",
    "            ha='center',\n",
    "            va='center'\n",
    "        ))\n",
    "\n",
    "    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_dict[f], label=fc, markersize=6)\n",
    "            for f, fc in zip(families, families_caps)]\n",
    "    ax1.legend(handles=handles, loc='upper left', ncol=6, fontsize=8)\n",
    "\n",
    "    # Axes\n",
    "    # plt.xscale(\"log\")\n",
    "    # plt.xlabel(\"Number of Parameters\", fontsize=16)\n",
    "    plt.ylabel(\"Test Accuracy (\\%)\", fontsize=16)\n",
    "    ax1.tick_params(axis='both', which='major', labelsize=10, rotation=25)\n",
    "    ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "    ax1.set_xticks(list(range(len(res['arch'].tolist()))), res['arch'].tolist(), ha='right')\n",
    "    plt.ylim([-5, 105])\n",
    "    figure1.savefig(f\"results/overall_{name}_alt.pgf\", bbox_inches=\"tight\")\n",
    "    if name == \"pairwise\": res1 = res\n",
    "!./pgf_compiler.sh overall_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean([b for a, b in stds if a.item()!=\"ED\"]) -   5.57544752412038"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "single_level_cols = [col for col, col2 in res.columns if col2 == '']\n",
    "multi_level_means = res.xs('mean', axis=1, level=1)\n",
    "res_sanitized = pd.concat([res[single_level_cols], multi_level_means], axis=1)\n",
    "for _, row in res_sanitized.iterrows():\n",
    "    if \"PT\" in row[(\"arch\",\"\")]:\n",
    "        row[(\"family\",\"\")] = \"pretrained\"\n",
    "        res_sanitized[res_sanitized[(\"arch\",'')] == row[(\"arch\",\"\")]] = row\n",
    "res_ortho = res_sanitized.groupby([(\"family\", '')])[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "res_ortho.sort_values(by=('test_acc', \"mean\"), inplace=True)\n",
    "\n",
    "single_level_cols = [col for col, col2 in res1.columns if col2 == '']\n",
    "multi_level_means = res1.xs('mean', axis=1, level=1)\n",
    "res_sanitized = pd.concat([res1[single_level_cols], multi_level_means], axis=1)\n",
    "for _, row in res_sanitized.iterrows():\n",
    "    if \"PT\" in row[(\"arch\",\"\")]:\n",
    "        row[(\"family\",\"\")] = \"pretrained\"\n",
    "        res_sanitized[res_sanitized[(\"arch\",'')] == row[(\"arch\",\"\")]] = row\n",
    "res_pair = res_sanitized.groupby([(\"family\", '')])[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "res_pair.sort_values(by=('test_acc', \"mean\"), inplace=True)\n",
    "\n",
    "families_caps = {\n",
    "    \"convnext\":'ConvNeXt', \n",
    "    \"resnet\":'ResNet',\n",
    "    \"vit\":'ViT',\n",
    "    \"densenet\":'DenseNet',\n",
    "    \"mlp\":'MLP',\n",
    "    \"ed\": 'ED',\n",
    "    \"pretrained\": \"Pre-trained\"\n",
    "}\n",
    "\n",
    "\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*0.6, plot_settings.column_width*1/2))\n",
    "ax1.bar([families_caps[f] for f in res_ortho[\"family\"]], res_ortho[('test_acc', 'mean')], color='#945034')\n",
    "plt.ylabel(\"Test Accuracy (\\%)\", fontsize=16)\n",
    "ax1.tick_params(axis='both', which='major', labelsize=10, rotation=25)\n",
    "ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "plt.ylim([-5, 105])\n",
    "ax1.set_xticks(list(range(len(res_ortho[\"family\"]))), [families_caps[f] for f in res_ortho[\"family\"]], ha='right')\n",
    "figure1.savefig(f\"results/comp_barplot_ortho.pgf\", bbox_inches=\"tight\")\n",
    "\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*0.6, plot_settings.column_width*1/2))\n",
    "ax1.bar([families_caps[f] for f in res_pair[\"family\"]], res_pair[('test_acc', 'mean')], color='#FFDDAB')\n",
    "plt.ylabel(\"Test Accuracy (\\%)\", fontsize=16)\n",
    "ax1.tick_params(axis='both', which='major', labelsize=10, rotation=25)\n",
    "ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "ax1.set_xticks(list(range(len(res_pair[\"family\"]))), [families_caps[f] for f in res_pair[\"family\"]], ha='right')\n",
    "plt.ylim([-5, 105])\n",
    "figure1.savefig(f\"results/comp_barplot_pair.pgf\", bbox_inches=\"tight\")\n",
    "!./pgf_compiler.sh comp_barplot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import plot_settings\n",
    "plot_settings.set_latex_settings()\n",
    "\n",
    "METRICS = [\n",
    "    \"train_acc\",\n",
    "    \"val_acc\",\n",
    "    \"ood_val_0_acc\",\n",
    "    \"test_acc\",\n",
    "]\n",
    "families = ['convnext', 'resnet', 'vit', 'densenet', 'mlp', 'ed']\n",
    "color_dict = {\n",
    "    'convnext': '#FF9A9A',\n",
    "    'resnet':   '#FFDDAB',\n",
    "    'vit':      '#945034',\n",
    "    'densenet': '#7CA982',\n",
    "    'mlp':      '#769ECB',\n",
    "    'ed':       '#C287E8',\n",
    "}\n",
    "families_caps = ['ConvNeXt', 'ResNet', 'ViT', 'DenseNet', 'MLP', 'ED']\n",
    "def get_family(arch):\n",
    "    for fam in families:\n",
    "        if arch.startswith(fam):\n",
    "            return fam\n",
    "        elif arch.startswith(\"wideresnet\"):\n",
    "            return \"resnet\"\n",
    "        elif arch.startswith(\"swin\"):\n",
    "            return \"vit\"\n",
    "    return \"other\"\n",
    "\n",
    "# for dataset in [\"dsprites\", \"iraven\", \"cars3d\", \"shapes3d\"]:\n",
    "for dataset in [\"clevr\"]:\n",
    "    data = pd.read_pickle(f\"main/{dataset}.pkl\")\n",
    "    data[\"family\"] = data[\"arch\"].apply(get_family)\n",
    "    data = data.groupby([\"family\", \"c\"])[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "    figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "    for fam in families:\n",
    "        res = data[data[\"family\"] == fam]\n",
    "        if res.empty: continue\n",
    "        # --- add in-dist result by taking val accuracy from last c\n",
    "        a = res.sort_values(by='c').iloc[-1]\n",
    "        a[\"c\"] = int(a[\"c\"].item()) + 1\n",
    "        a[\"test_acc\"] = res.iloc[0][\"val_acc\"]\n",
    "        res = res.append(a).reset_index()\n",
    "        # ---\n",
    "        ax1.plot([int(cc) for cc in res['c']], list(res['test_acc'][\"mean\"]), label=fam, color=color_dict[fam], marker=\"o\")\n",
    "        ax1.fill_between(\n",
    "            [int(cc) for cc in res['c']],\n",
    "            res['test_acc'][\"mean\"] - res['test_acc'][\"sem\"],\n",
    "            res['test_acc'][\"mean\"] + res['test_acc'][\"sem\"],\n",
    "            alpha=0.3,\n",
    "            color=color_dict[fam]\n",
    "        )\n",
    "    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_dict[f], label=fc, markersize=12)\n",
    "            for f, fc in zip(families, families_caps)]\n",
    "    if dataset == \"dsprites\":\n",
    "        ax1.legend(handles=handles, loc='best',fontsize=16, bbox_to_anchor=(1.5, -0.2), ncol=6)\n",
    "    ax1.set_xticks([int(x) for x in data[\"c\"].unique()]+[int(max(data[\"c\"].unique()))+1])\n",
    "    plt.xlabel(\"$c$\", fontsize=16)\n",
    "    plt.ylabel(\"Test Accuracy (\\%)\", fontsize=16)\n",
    "    ax1.tick_params(axis='both', which='major', labelsize=14)\n",
    "    ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "    plt.ylim([-5, 105])\n",
    "    figure1.savefig(f\"results/neurips_main_{dataset}.pgf\", bbox_inches=\"tight\")\n",
    "\n",
    "!./pgf_compiler.sh neurips_main"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extended main results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "METRICS = [\n",
    "    \"train_acc\",\n",
    "    \"val_acc\",\n",
    "    \"ood_val_0_acc\",\n",
    "    \"test_acc\",\n",
    "]\n",
    "\n",
    "rdf = pd.read_pickle(\"main/clevr.pkl\")\n",
    "print(rdf[\"c\"].unique())\n",
    "rdf = rdf[rdf[\"c\"]==\"3\"]\n",
    "group_columns = [\"arch\"]\n",
    "res = rdf.groupby(group_columns)[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "\n",
    "res[[(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]] = (\n",
    "    res[[(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]].round(2)\n",
    ")\n",
    "# print(res.to_latex(index=False,\n",
    "#     formatters={\"name\": str.upper},\n",
    "#     float_format=\"{:.2f}\".format,\n",
    "# ))  \n",
    "\n",
    "def format_model_name(arch):\n",
    "    name = arch.replace(\"_pretrained\", \"\")\n",
    "    pretrained = r\"\\cmark\" if \"pretrained\" in arch else r\"\\xmark\"\n",
    "    pretty_map = {\n",
    "        \"resnet18\": \"ResNet-18\",\n",
    "        \"resnet50\": \"ResNet-50\",\n",
    "        \"resnet101\": \"ResNet-101\",\n",
    "        \"resnet152\": \"ResNet-152\",\n",
    "        \"densenet121\": \"DenseNet-121\",\n",
    "        \"densenet161\": \"DenseNet-161\",\n",
    "        \"densenet201\": \"DenseNet-201\",\n",
    "        \"convnext_tiny\": \"ConvNeXt-Tiny\",\n",
    "        \"convnext_small\": \"ConvNeXt-Small\",\n",
    "        \"convnext_base\": \"ConvNeXt-Base\",\n",
    "        \"swin_tiny\": \"Swin-Tiny\",\n",
    "        \"swin_base\": \"Swin-Base\",\n",
    "        \"wideresnet\": \"WideResNet\",\n",
    "        \"ed\": \"ED\",\n",
    "        \"mlp\": \"MLP\"\n",
    "    }\n",
    "    pretty_name = pretty_map.get(name, name)\n",
    "    return pretty_name, pretrained\n",
    "\n",
    "# Generate LaTeX rows\n",
    "latex_rows = []\n",
    "for _, row in res.iterrows():\n",
    "    model, pretrained = format_model_name(row[\"arch\"].item())\n",
    "    values = \" & \".join([f\"{row[col]:.2f}\" for col in res.columns[1:]])\n",
    "    latex_rows.append(f\"{model} & {pretrained} & {values} \\\\\\\\\")\n",
    "\n",
    "# Join all rows\n",
    "latex_table_body = \"\\n\".join(latex_rows)\n",
    "print(latex_table_body)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FPE vs Linear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import plot_settings\n",
    "plot_settings.set_latex_settings()\n",
    "\n",
    "for dataset in [\"dsprites\", \"iraven\", \"shapes3d\", \"cars3d\", \"mpi3d\"]:\n",
    "    # FPE\n",
    "    df = pd.read_csv(f\"fpe_linear/{dataset}_fpe.csv\")\n",
    "    df.drop(list(df.filter(regex='MAX|MIN|Step')), axis=1, inplace=True)\n",
    "    figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*6/4, plot_settings.column_width*3/4))\n",
    "    data = np.array([df[name] for name in list(df)])\n",
    "    mean = np.mean(data, axis=0)\n",
    "    std = np.std(data, axis=0)\n",
    "    ax1.plot(list(range(data.shape[1])), mean, color='#FF9A9A', label=\"FPE\")\n",
    "    ax1.fill_between(\n",
    "        list(range(data.shape[1])),\n",
    "        np.clip(mean - std, 0, 100),\n",
    "        np.clip(mean + std, 0, 100),\n",
    "        alpha=0.3,\n",
    "        color='#FF9A9A'\n",
    "    )\n",
    "    # linear\n",
    "    df = pd.read_csv(f\"fpe_linear/{dataset}_linear.csv\")\n",
    "    df.drop(list(df.filter(regex='MAX|MIN|Step')), axis=1, inplace=True)\n",
    "    clean = []\n",
    "    for c in df.columns:\n",
    "        clean.append(df[c].dropna())\n",
    "    data = np.array(clean)\n",
    "    mean = np.mean(data, axis=0)\n",
    "    std = np.std(data, axis=0)\n",
    "    ax1.plot(list(range(data.shape[1])), mean, color='#945034', label=\"Linear\")\n",
    "    ax1.fill_between(\n",
    "        list(range(data.shape[1])),\n",
    "        np.clip(mean - std, 0, 100),\n",
    "        np.clip(mean + std, 0, 100),\n",
    "        alpha=0.3,\n",
    "        color='#945034'\n",
    "    )\n",
    "    plt.xlabel(\"Epoch\", fontsize=16)\n",
    "    plt.ylabel(\"Test Accuracy (\\%)\", fontsize=16)\n",
    "    ax1.tick_params(axis='both', which='major', labelsize=14)\n",
    "    ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "    plt.ylim([-5, 105])\n",
    "    ax1.legend(loc='best', fontsize=16)\n",
    "    figure1.savefig(f\"results/fpe_linear_{dataset}.pgf\", bbox_inches=\"tight\")\n",
    "!./pgf_compiler.sh fpe_linear_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "METRICS = [\n",
    "    \"train_acc\",\n",
    "    \"val_acc\",\n",
    "    \"ood_val_0_acc\",\n",
    "    \"test_acc\",\n",
    "]\n",
    "for data in [\"shapes3d\", \"mpi3d\", \"dsprites\", \"iraven\", \"cars3d\", \"clevr\"]:\n",
    "\n",
    "    ain = pd.read_pickle(f\"ain/{data}.pkl\")\n",
    "    oth = pd.read_pickle(f\"main/{data}.pkl\")\n",
    "\n",
    "    oth = oth[oth[\"c\"] == \"1\"]\n",
    "\n",
    "    group_columns = [\"arch\"]\n",
    "\n",
    "    res_ain = ain.groupby(group_columns)[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "    res_ain[[(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]] = (\n",
    "        res_ain[[(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]].round(2)\n",
    "    )\n",
    "    res_oth = oth.groupby(group_columns)[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "    res_oth[[(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]] = (\n",
    "        res_oth[[(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]].round(2)\n",
    "    )\n",
    "    res_oth = res_oth[np.logical_or(res_oth[\"arch\"] == \"resnet18\", res_oth[\"arch\"] == \"ed\")]\n",
    "\n",
    "    print(res_ain)\n",
    "    print(res_oth)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors\n",
    "import numpy as np\n",
    "from adjustText import adjust_text\n",
    "import plot_settings\n",
    "plot_settings.set_latex_settings()\n",
    "\n",
    "METRICS = [\n",
    "    \"train_acc\",\n",
    "    \"val_acc\",\n",
    "    \"ood_val_0_acc\",\n",
    "    \"test_acc\",\n",
    "]\n",
    "\n",
    "\n",
    "ain = pd.concat([pd.read_pickle(f\"ain/{dataset}.pkl\") for dataset in [\"shapes3d\", \"mpi3d\", \"dsprites\", \"iraven\", \"cars3d\"]])\n",
    "oth = pd.concat([pd.read_pickle(f\"main/{dataset}.pkl\") for dataset in [\"shapes3d\", \"mpi3d\", \"dsprites\", \"iraven\", \"cars3d\"]])\n",
    "oth = oth[oth[\"c\"] == \"1\"]\n",
    "clevr = pd.read_pickle(f\"main/clevr.pkl\")\n",
    "clevr = clevr[clevr[\"c\"] == \"1\"]\n",
    "data = pd.concat([ain, oth, clevr])\n",
    "res = data.groupby([\"arch\"])[METRICS].agg(['mean', 'sem']).reset_index()\n",
    "\n",
    "\n",
    "def modtoname(arch):\n",
    "    mp = {\n",
    "        'convnext_base': \"CN-base\", \n",
    "        'convnext_small': \"CN-small\",\n",
    "        'densenet121': \"DN-121\",\n",
    "        'convnext_tiny': \"CN-tiny\",\n",
    "        'densenet121_pretrained': \"DN-121-PT\",\n",
    "        'densenet161': \"DN-161\",\n",
    "        'densenet201': \"DN-201\",\n",
    "        'ed': \"ED\",\n",
    "        'mlp': \"MLP\",\n",
    "        'resnet101': \"RN-101\",\n",
    "        'resnet101_pretrained': \"RN-101-PT\",\n",
    "        'resnet152': \"RN-152\",\n",
    "        'resnet152_pretrained': \"RN-151-PT\",\n",
    "        'resnet18': \"RN-18\",\n",
    "        'resnet34': \"RN-34\",\n",
    "        'resnet50': \"RN-50\",\n",
    "        'swin_base': \"ST-base\",\n",
    "        'swin_tiny': \"ST-tiny\",\n",
    "        'vit': \"ViT\",\n",
    "        'wideresnet': \"WRN\",\n",
    "        \"split\": \"AIN\"\n",
    "    }\n",
    "    return mp[arch]\n",
    "\n",
    "color_dict = {\n",
    "    'convnext': '#FF9A9A',\n",
    "    'resnet':   '#FFDDAB',\n",
    "    'vit':      '#945034',\n",
    "    'densenet': '#7CA982',\n",
    "    'mlp':      '#769ECB',\n",
    "    'split': \"#E63946\",\n",
    "    'ed':       '#C287E8',\n",
    "}\n",
    "model_sizes = {\n",
    "    \"mlp\":                406850,\n",
    "    'densenet':       6965131,\n",
    "    'densenet121':       6965131,\n",
    "    'densenet121_pretrained': 6965131,\n",
    "    'densenet161':      26486891,\n",
    "    'densenet201':      18107787,\n",
    "    \"resnet18\":         11175883,\n",
    "    \"resnet34\":         21284043,\n",
    "    \"resnet50\":         24556491,\n",
    "    \"resnet101_pretrained\":        43548619,\n",
    "    \"resnet152_pretrained\":        59192267,\n",
    "    \"resnet101\":        43548619,\n",
    "    \"resnet152\":        59192267,\n",
    "    \"split\":            11175883 +  11175883*0.032*4.16,\n",
    "    \"wideresnet\":       67882699,\n",
    "    \"ed\":               11175883 +  11175883*4.16,\n",
    "    \"vit\":              86576115,\n",
    "    \"convnext_base\":    87573632,\n",
    "    \"convnext_small\":   49460064,\n",
    "    \"convnext_tiny\":    28600064,\n",
    "    \"swin_tiny\":        27532469,\n",
    "    \"swin_base\":        86771459,\n",
    "}\n",
    "families = ['convnext', 'resnet', 'vit', 'densenet', 'mlp', 'ed', \"split\"]\n",
    "families_caps = ['ConvNeXt', 'ResNet', 'ViT', 'DenseNet', 'MLP', 'ED', \"AIN\"]\n",
    "def get_family(arch):\n",
    "    for fam in families:\n",
    "        if arch.startswith(fam):\n",
    "            return fam\n",
    "        elif arch.startswith(\"wideresnet\"):\n",
    "            return \"resnet\"\n",
    "        elif arch.startswith(\"swin\"):\n",
    "            return \"vit\"\n",
    "    return \"other\"\n",
    "\n",
    "res[\"family\"] = res[\"arch\"].apply(get_family)\n",
    "res[\"model_size\"] = res[\"arch\"].map(model_sizes)\n",
    "\n",
    "figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))\n",
    "res[\"arch\"] = res[\"arch\"].apply(modtoname)\n",
    "texts = []\n",
    "for _, row in res.iterrows():\n",
    "    ax1.errorbar(\n",
    "        row[\"model_size\"],\n",
    "        row[\"test_acc\"][\"mean\"],\n",
    "        yerr=row[\"test_acc\"][\"sem\"],\n",
    "        fmt='.',\n",
    "        color=color_dict[row[\"family\"].item()],\n",
    "        capsize=3,\n",
    "        markersize=8,\n",
    "        alpha=0.8\n",
    "    )\n",
    "#     texts.append(ax1.text(\n",
    "#         row[\"model_size\"].item() * 1.01,\n",
    "#         row[\"test_acc\"][\"mean\"],\n",
    "#         row[\"arch\"].item(),\n",
    "#         fontsize=9,\n",
    "#         verticalalignment='center'\n",
    "#     ))\n",
    "\n",
    "# adjust_text(texts, expand=(1,1), arrowprops=dict(arrowstyle='-', color='k', lw=1))\n",
    "# Legend\n",
    "handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_dict[f], label=fc, markersize=6)\n",
    "           for f, fc in zip(families, families_caps)]\n",
    "ax1.legend(handles=handles, loc='best', ncol=3)\n",
    "\n",
    "# Axes\n",
    "# plt.xscale(\"log\")\n",
    "plt.xlabel(\"Number of Parameters\", fontsize=16)\n",
    "plt.ylabel(\"Test Accuracy (\\%)\", fontsize=16)\n",
    "ax1.tick_params(axis='both', which='major', labelsize=14)\n",
    "ax1.tick_params(axis='both', which='minor', labelsize=14)\n",
    "plt.ylim([-5, 105])\n",
    "figure1.savefig(\"results/pareto.pgf\", bbox_inches=\"tight\")\n",
    "!./pgf_compiler.sh pareto"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "visual",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
