{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Work from the directory two levels above the notebook's folder; the\n",
    "# relative LATEX_PATH below is resolved against it. The starting directory\n",
    "# is remembered so that re-running this cell does not climb two more levels.\n",
    "if \"_NOTEBOOK_DIR\" not in globals():\n",
    "    _NOTEBOOK_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_DIR, os.pardir, os.pardir))\n",
    "print(os.getcwd())\n",
    "\n",
    "# Figures are saved below this directory by the plotting cell.\n",
    "LATEX_PATH = \"../latex/Distillation-MI-ICLR\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# NOTE(review): the wildcard import is expected to provide STUDENT_MODEL,\n",
    "# ZINC_MODEL, df_metadata, get_all_results and aggregate_results_with_ci;\n",
    "# none of them is defined in this notebook. numpy is imported explicitly\n",
    "# above because the plotting cell calls np.round.\n",
    "from molDistill.utils.notebooks import *\n",
    "\n",
    "# Embedders whose downstream results are compared in this notebook.\n",
    "MODELS_TO_EVAL = [\n",
    "    \"ChemBertMTR-77M\",\n",
    "    \"GraphMVP\",\n",
    "    \"GraphLog\",\n",
    "    \"GraphCL\",\n",
    "    \"FRAD_QM9\",\n",
    "    \"ThreeDInfomax\",\n",
    "    STUDENT_MODEL,\n",
    "    ZINC_MODEL,\n",
    "]\n",
    "\n",
    "# Index labels of the metadata rows whose task_type equals `reg`.\n",
    "DATASETS = df_metadata.loc[df_metadata[\"task_type\"] == \"reg\"].index.tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Rankings"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Raw results table for the selected models and datasets; the plotting cell\n",
    "# reads its `embedder`, `dataset` and `metric_test` columns.\n",
    "df_base = get_all_results(\n",
    "    MODELS_TO_EVAL,\n",
    "    \"downstream_results\",\n",
    "    DATASETS,\n",
    ")\n",
    "\n",
    "# `order` is the embedder sequence reused to sort and colour the plots.\n",
    "# NOTE(review): `df` is presumably the aggregated table with confidence\n",
    "# intervals; confirm against aggregate_results_with_ci.\n",
    "df, order = aggregate_results_with_ci(df_base)\n",
    "\n",
    "df_base"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "FIG_SIZE = 1\n",
    "n_datasets = len(DATASETS)\n",
    "\n",
    "# One panel per dataset. squeeze=False makes plt.subplots return a 2-D array\n",
    "# of Axes even for a single dataset, so flatten() never hits a bare Axes.\n",
    "fig, axes = plt.subplots(\n",
    "    1,\n",
    "    n_datasets,\n",
    "    figsize=(FIG_SIZE * n_datasets * 2.7, FIG_SIZE * 6.0),\n",
    "    sharey=True,\n",
    "    squeeze=False,\n",
    ")\n",
    "axes = axes.flatten()\n",
    "\n",
    "# NOTE(review): presumably teacher entries of `order` carry a `{(t)}` marker;\n",
    "# confirm against aggregate_results_with_ci. Not used by the plots below.\n",
    "TEACHERS = [t for t in order if \"{(t)}\" in t]\n",
    "\n",
    "# Build each palette once instead of once per embedder. Embedders whose name\n",
    "# contains `student` keep the saturated colour, all others the desaturated one.\n",
    "n_colors = df_base.embedder.nunique()\n",
    "palette_vivid = sns.color_palette(\"husl\", n_colors)\n",
    "palette_muted = sns.color_palette(\"husl\", n_colors, desat=0.15)\n",
    "cmap = {\n",
    "    emb: (palette_vivid if \"student\" in emb else palette_muted)[i]\n",
    "    for i, emb in enumerate(order)\n",
    "}\n",
    "\n",
    "for i, (ax, dataset) in enumerate(zip(axes, DATASETS)):\n",
    "    # Rows of this dataset, arranged in reversed `order`.\n",
    "    df_plt = (\n",
    "        df_base[df_base.dataset == dataset]\n",
    "        .set_index(\"embedder\")\n",
    "        .loc[order[::-1]]\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "    # Median bars with the box plot of the same values drawn on top.\n",
    "    sns.barplot(data=df_plt, x=\"metric_test\", y=\"embedder\", ax=ax, hue=\"embedder\", palette=cmap, hue_order=order, errorbar=None, fill=True, estimator=\"median\", alpha=.7)\n",
    "    sns.boxplot(data=df_plt, x=\"metric_test\", y=\"embedder\", ax=ax, hue=\"embedder\", palette=cmap, hue_order=order, fill=False, fliersize=0, width=.5, linewidth=3.)\n",
    "\n",
    "    # Panel title: metadata short name for the Clearance/Half datasets,\n",
    "    # otherwise the dataset name up to the first underscore, abbreviated.\n",
    "    if dataset.startswith((\"Clearance\", \"Half\")):\n",
    "        title = df_metadata.loc[dataset].short_name.replace(\"Clearance\", \"Clear.\")\n",
    "    else:\n",
    "        title = dataset.split(\"_\")[0].replace(\"HydrationFreeEnergy\", \"FreeSolv\").replace(\"Lipophilicity\", \"Lipo.\")\n",
    "    ax.set_title(title, size=24)\n",
    "\n",
    "    # Only the middle panel carries the x-axis label.\n",
    "    ax.set_xlabel(\"Test $R^2$\" if i == n_datasets // 2 else \"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "\n",
    "    # Left limit: 10% quantile of the metric, floored at 0 and rounded to one\n",
    "    # decimal. The right limit is left to matplotlib.\n",
    "    ax.set_xlim(left=np.round(max(df_plt.metric_test.quantile(0.1), 0), 1))\n",
    "\n",
    "    # Smaller x tick labels, larger y tick labels.\n",
    "    ax.tick_params(axis=\"x\", labelsize=14)\n",
    "    ax.tick_params(axis=\"y\", labelsize=24)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(f\"{LATEX_PATH}/figures/molecules/reg_boxplot.pdf\", bbox_inches=\"tight\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
