{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Move the working directory two levels up from where the kernel started, so the\n",
    "# relative paths used below resolve against that directory\n",
    "# (presumably the repository root -- TODO confirm).\n",
    "%cd ../..\n",
    "\n",
    "# Base directory under which the plotting cells save their PDF figures; it is a\n",
    "# relative path, so it is resolved against the working directory set by %cd above.\n",
    "LATEX_PATH = \"../latex/Distillation-MI-ICLR\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Third-party dependencies, alphabetical: plotting, data frames, W&B client.\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import wandb\n",
    "\n",
    "# Single W&B API client, reused by the data-loading cells to fetch run histories.\n",
    "api = wandb.Api()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "RUNS = [\"/mol-distill/56rx297n\", \"/mol-distill/jxk2dqi1\", \"/mol-distill/33msn9sy\"]\n",
    "NAMES = [\"5-layers-kernel\", \"2-layers-kernel\", \"3-layers-kernel\"]\n",
    "KEYS = [\"Sum\", \"GraphMVP\", \"ChemBertMTR-77M\", \"FRAD_QM9\", \"ThreeDInfomax\", \"GraphCL\"]\n",
    "\n",
    "MAX_EPOCH = 450\n",
    "\n",
    "# Build one long-format frame with columns (loss, split, teacher, epoch, name) holding\n",
    "# the train and eval loss curves of every run in RUNS for every teacher in KEYS.\n",
    "curve_frames = []\n",
    "for run_id, name in zip(RUNS, NAMES):\n",
    "    # Fetch the logged history once per run: each run.history() call is a separate\n",
    "    # request to the W&B backend.\n",
    "    history = api.run(run_id).history()\n",
    "    history = history[history.epoch <= MAX_EPOCH]\n",
    "\n",
    "    for key in KEYS:\n",
    "        # \"Sum\" maps to the un-suffixed train_loss / eval_loss columns; every other\n",
    "        # teacher has its own train_loss_<key> / eval_loss_<key> columns.\n",
    "        suffix = \"\" if key == \"Sum\" else f\"_{key}\"\n",
    "        for split, prefix in ((\"train\", \"train_loss\"), (\"val\", \"eval_loss\")):\n",
    "            # Building a fresh frame (rather than assigning columns on the filtered\n",
    "            # slice) avoids SettingWithCopyWarning; scalar values are broadcast\n",
    "            # over the index of the loss column.\n",
    "            curve_frames.append(\n",
    "                pd.DataFrame(\n",
    "                    {\n",
    "                        \"loss\": history[f\"{prefix}{suffix}\"],\n",
    "                        \"split\": split,\n",
    "                        \"teacher\": key,\n",
    "                        \"epoch\": history[\"epoch\"],\n",
    "                        \"name\": name,\n",
    "                    }\n",
    "                )\n",
    "            )\n",
    "\n",
    "# Concatenate once at the end instead of re-copying a growing frame on every iteration.\n",
    "df = pd.concat(curve_frames)\n",
    "df"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# One panel per teacher in KEYS, laid out on a two-row grid with a shared x-axis.\n",
    "fig, axes = plt.subplots(2, len(KEYS) // 2, figsize=(len(KEYS) / 2 * 2.7, 4.3), sharex=True)\n",
    "axes = axes.flatten()\n",
    "\n",
    "for i, (ax, key) in enumerate(zip(axes, KEYS)):\n",
    "    teacher_df = df[df.teacher == key]\n",
    "\n",
    "    # Colour encodes the run name, line style encodes the split (train / val).\n",
    "    # Only the last panel asks seaborn for legend entries; that legend is then\n",
    "    # repositioned outside the grid below.\n",
    "    sns.lineplot(\n",
    "        data=teacher_df.dropna(),\n",
    "        x=\"epoch\",\n",
    "        y=\"loss\",\n",
    "        hue=\"name\",\n",
    "        style=\"split\",\n",
    "        ax=ax,\n",
    "        palette=\"husl\",\n",
    "        legend=i == len(KEYS) - 1,\n",
    "        alpha=0.8,\n",
    "    )\n",
    "\n",
    "    # Cap the y-range at the 99th percentile of the loss so that the largest values\n",
    "    # do not compress the rest of the curve.\n",
    "    ax.set_ylim(teacher_df[\"loss\"].min(), teacher_df[\"loss\"].quantile(0.99))\n",
    "\n",
    "    ax.set_ylabel(\"\")\n",
    "    ax.set_xlabel(\"Epoch\")\n",
    "    ax.set_title(key)\n",
    "    ax.set_xlim(0, 400)\n",
    "\n",
    "axes[0].set_ylabel(\"Train Loss\")\n",
    "\n",
    "# Place the legend outside the grid, to the right of the last panel. The axes is named\n",
    "# explicitly instead of relying on pyplot's implicit \"current axes\".\n",
    "axes[-1].legend(loc=\"center left\", bbox_to_anchor=(1.05, 0.8))\n",
    "\n",
    "fig.savefig(f\"{LATEX_PATH}/figures/molecules/kernel_train_curve.pdf\", bbox_inches=\"tight\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "RUNS = [\"/mol-distill/33msn9sy\", \"/mol-distill/wow4guql\", \"/mol-distill/27pq9iwq\", \"/mol-distill/wm0onriy\", \"/mol-distill/r099cum7\", \"/mol-distill/he9vr7df\"]\n",
    "NAMES = [\"GINE-student\", \"GAT-student\", \"GCN-student\", \"TAG-student\", \"SAGE-student\", \"GIN-student\"]\n",
    "KEYS = [\"Sum\", \"ChemBertMTR-77M\", \"FRAD_QM9\"]\n",
    "\n",
    "MAX_EPOCH = 450\n",
    "\n",
    "# Build one long-format frame with columns (loss, split, teacher, epoch, name) holding\n",
    "# the train and eval loss curves of every run in RUNS for every teacher in KEYS.\n",
    "curve_frames = []\n",
    "for run_id, name in zip(RUNS, NAMES):\n",
    "    # Fetch the logged history once per run: each run.history() call is a separate\n",
    "    # request to the W&B backend.\n",
    "    history = api.run(run_id).history()\n",
    "    history = history[history.epoch <= MAX_EPOCH]\n",
    "\n",
    "    for key in KEYS:\n",
    "        # \"Sum\" maps to the un-suffixed train_loss / eval_loss columns; every other\n",
    "        # teacher has its own train_loss_<key> / eval_loss_<key> columns.\n",
    "        suffix = \"\" if key == \"Sum\" else f\"_{key}\"\n",
    "        for split, prefix in ((\"train\", \"train_loss\"), (\"val\", \"eval_loss\")):\n",
    "            # Building a fresh frame (rather than assigning columns on the filtered\n",
    "            # slice) avoids SettingWithCopyWarning; scalar values are broadcast\n",
    "            # over the index of the loss column.\n",
    "            curve_frames.append(\n",
    "                pd.DataFrame(\n",
    "                    {\n",
    "                        \"loss\": history[f\"{prefix}{suffix}\"],\n",
    "                        \"split\": split,\n",
    "                        \"teacher\": key,\n",
    "                        \"epoch\": history[\"epoch\"],\n",
    "                        \"name\": name,\n",
    "                    }\n",
    "                )\n",
    "            )\n",
    "\n",
    "# Concatenate once at the end instead of re-copying a growing frame on every iteration.\n",
    "df = pd.concat(curve_frames)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# The architecture is the part of the run name before the first \"-\"\n",
    "# (e.g. \"GINE-student\" -> \"GINE\").\n",
    "df[\"archi\"] = df[\"name\"].str.split(\"-\").str[0]\n",
    "df"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Two rows of panels (one per architecture group), one column per teacher in KEYS.\n",
    "fig, axes = plt.subplots(2, len(KEYS), figsize=(len(KEYS) * 2.7, 4.3), sharex=True)\n",
    "\n",
    "# NOTE(review): this module-level value is never read; get_color() is always called\n",
    "# with its own default (cmap_offset=2). Confirm which offset is intended.\n",
    "cmap_offset = 4\n",
    "# Architectures plotted on the top row; all remaining ones go on the bottom row.\n",
    "to_iso = [\"GINE\", \"GAT\"]\n",
    "not_to_iso = [n.split(\"-\")[0] for n in NAMES if n.split(\"-\")[0] not in to_iso]\n",
    "\n",
    "\n",
    "def get_color(name, desat=0.7, cmap_offset=2, cmap=\"icefire\"):\n",
    "    \"\"\"Return the line colour for the run called ``name``.\n",
    "\n",
    "    The architecture is the part of ``name`` before the first \"-\". Architectures\n",
    "    listed in ``to_iso`` get a full-saturation colour indexed from the end of the\n",
    "    palette; every other architecture gets a desaturated colour indexed from the\n",
    "    start of the palette.\n",
    "\n",
    "    Args:\n",
    "        name: Run name of the form \"<architecture>-...\".\n",
    "        desat: Desaturation factor used for architectures not in ``to_iso``.\n",
    "        cmap_offset: Number of palette entries requested in addition to\n",
    "            ``len(NAMES)``; half of it also shifts the index used for ``to_iso``\n",
    "            architectures.\n",
    "        cmap: Name of the seaborn palette to sample from.\n",
    "\n",
    "    Returns:\n",
    "        One colour taken from ``sns.color_palette``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the architecture is in neither ``to_iso`` nor ``not_to_iso``.\n",
    "    \"\"\"\n",
    "    arch = name.split(\"-\")[0]\n",
    "    n_colors = len(NAMES) + cmap_offset\n",
    "    if arch in to_iso:\n",
    "        idx = to_iso.index(arch) + 1 + cmap_offset // 2\n",
    "        return sns.color_palette(cmap, n_colors)[-idx]\n",
    "    idx = not_to_iso.index(arch)\n",
    "    return sns.color_palette(cmap, n_colors, desat=desat)[idx]\n",
    "\n",
    "\n",
    "# Mapping from run name to line colour, passed to seaborn as the hue palette.\n",
    "cmap = {run_name: get_color(run_name) for run_name in NAMES}\n",
    "\n",
    "for i, models in enumerate([to_iso, not_to_iso]):\n",
    "    for j, key in enumerate(KEYS):\n",
    "        ax = axes[i, j]\n",
    "        teacher_df = df[df.teacher == key]\n",
    "\n",
    "        # Colour encodes the run name, line style encodes the split (train / val).\n",
    "        # Only the last column of each row asks seaborn for legend entries.\n",
    "        sns.lineplot(\n",
    "            data=teacher_df[teacher_df.archi.isin(models)].dropna(),\n",
    "            x=\"epoch\",\n",
    "            y=\"loss\",\n",
    "            hue=\"name\",\n",
    "            style=\"split\",\n",
    "            ax=ax,\n",
    "            palette=cmap,\n",
    "            legend=j == len(KEYS) - 1,\n",
    "            alpha=1,\n",
    "        )\n",
    "\n",
    "        # The y-range is computed over every architecture of this teacher, so both\n",
    "        # rows of a column share the same limits; it is capped at the 99th percentile.\n",
    "        ax.set_ylim(teacher_df[\"loss\"].min(), teacher_df[\"loss\"].quantile(0.99))\n",
    "\n",
    "        ax.set_ylabel(\"\")\n",
    "        ax.set_xlabel(\"Epoch\")\n",
    "        # Title each panel with the teacher plotted in it.\n",
    "        ax.set_title(key)\n",
    "        ax.set_xlim(0, 400)\n",
    "fig.supylabel(\"Train Loss\")\n",
    "\n",
    "# Move the legend of each row outside the grid, to the right of its last panel.\n",
    "axes[-1, -1].legend(loc=\"center left\", bbox_to_anchor=(1.05, 0.5))\n",
    "axes[0, -1].legend(loc=\"center left\", bbox_to_anchor=(1.05, 0.5))\n",
    "\n",
    "fig.savefig(f\"{LATEX_PATH}/figures/molecules/archi_train_curve.pdf\", bbox_inches=\"tight\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
