{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Move to the project root, two levels above the kernel's starting directory.\n",
    "# The resolved path is cached so that re-running this cell does not climb two\n",
    "# more directories each time, as a bare `%cd ../..` would.\n",
    "if \"_PROJECT_ROOT\" not in globals():\n",
    "    _PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), \"..\", \"..\"))\n",
    "os.chdir(_PROJECT_ROOT)\n",
    "\n",
    "# Root of the LaTeX sources; the table and figure paths in this notebook are built from it.\n",
    "LATEX_PATH = \"../latex/Distillation-MI-ICLR\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from molDistill.utils.notebooks import *\n",
    "\n",
    "# Models compared in this notebook (names presumably provided by the\n",
    "# wildcard import from molDistill.utils.notebooks).\n",
    "MODELS_TO_EVAL = [STUDENT_MODEL, SINGLE_TEACHER_BERT, SINGLE_TEACHER_TDINFO, TWO_TEACHER]\n",
    "\n",
    "# All index entries of df_metadata except the last three, minus \"ToxCast\".\n",
    "DATASETS = df_metadata.index[:-3].tolist()\n",
    "DATASETS.remove(\"ToxCast\")\n",
    "\n",
    "# Cell output: how many datasets remain.\n",
    "len(DATASETS)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Mean Performances: Classification\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Keep only the datasets whose task_type is \"cls\".\n",
    "DATASETS = df_metadata.loc[df_metadata[\"task_type\"] == \"cls\"].index.tolist()\n",
    "\n",
    "# One inner list per output table: a column is kept for a table when the\n",
    "# first element of its label appears in that list.\n",
    "DATASET_GROUP = [\n",
    "    [\"Distribution\", \"HTS\", \"Absorption\", \" \"],\n",
    "    [\"Metabolism\"],\n",
    "    [\"Tox\"],\n",
    "]\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "df, order = aggregate_results_with_ci(df_base)\n",
    "\n",
    "# Write one LaTeX table per group of column categories.\n",
    "for table_idx, categories in enumerate(DATASET_GROUP):\n",
    "    kept_columns = [column for column in df.columns if column[0] in categories]\n",
    "    # Only the LaTeX string is used; the first returned value is discarded.\n",
    "    _styler, latex = style_df_ci(df[kept_columns], order[::-1])\n",
    "    latex = add_hline(add_hline(latex, 1), -1)\n",
    "\n",
    "    table_path = f\"{LATEX_PATH}/tables/molecules/sgl_cls_{table_idx}.tex\"\n",
    "    with open(table_path, \"w\") as tex_file:\n",
    "        tex_file.write(latex)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Mean Performances: Regression"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Keep only the datasets whose task_type is \"reg\".\n",
    "DATASETS = df_metadata.loc[df_metadata[\"task_type\"] == \"reg\"].index.tolist()\n",
    "\n",
    "# One inner list per output table: a column is kept for a table when the\n",
    "# first element of its label appears in that list.\n",
    "DATASET_GROUP = [\n",
    "    [\" \", \"Absorption\", \"Tox\"],\n",
    "    [\"Distribution\", \"Excretion\"],\n",
    "]\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "df, order = aggregate_results_with_ci(df_base)\n",
    "\n",
    "# Write one LaTeX table per group of column categories.\n",
    "for table_idx, categories in enumerate(DATASET_GROUP):\n",
    "    kept_columns = [column for column in df.columns if column[0] in categories]\n",
    "    # Only the LaTeX string is used; the first returned value is discarded.\n",
    "    _styler, latex = style_df_ci(df[kept_columns], order[::-1])\n",
    "    latex = add_hline(add_hline(latex, 1), -1)\n",
    "\n",
    "    table_path = f\"{LATEX_PATH}/tables/molecules/sgl_reg_{table_idx}.tex\"\n",
    "    with open(table_path, \"w\") as tex_file:\n",
    "        tex_file.write(latex)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Figures"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# All index entries of df_metadata except the last three, minus \"ToxCast\".\n",
    "DATASETS = df_metadata.index.tolist()[:-3]\n",
    "DATASETS.remove(\"ToxCast\")\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS).reset_index(drop=True)\n",
    "# Short dataset names looked up in df_metadata; they become the facet titles.\n",
    "df_base[\"short_dataset\"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)\n",
    "\n",
    "# Plot and compute the y-limits from the same NaN-free rows, so the limits\n",
    "# always describe the points that are actually drawn.\n",
    "df_plot = df_base.dropna()\n",
    "\n",
    "# NOTE: `order` is not defined in this cell; it is the value returned by the most\n",
    "# recently executed aggregate_results_with_ci(...) call in an earlier cell.\n",
    "# The catplot layer is fully transparent (alpha=0) and only builds the grid;\n",
    "# the visible line and points are added by the g.map calls below.\n",
    "g = sns.catplot(\n",
    "    data=df_plot,\n",
    "    col=\"short_dataset\",\n",
    "    y=\"metric_test\",\n",
    "    x=\"embedder\",\n",
    "    hue=\"embedder\",\n",
    "    kind=\"point\",\n",
    "    palette=\"husl\",\n",
    "    height=1.3,\n",
    "    aspect=0.9,\n",
    "    col_wrap=8,\n",
    "    sharey=False,\n",
    "    alpha=0.,\n",
    "    legend=False,\n",
    "    errorbar=None,\n",
    "    order=order,\n",
    "    hue_order=order,\n",
    ")\n",
    "g.map(sns.lineplot, \"embedder\", \"metric_test\", errorbar=None, color=\"black\", alpha=0.3, linewidth=2.5)\n",
    "g.map(sns.pointplot, \"embedder\", \"metric_test\", \"embedder\", order=order, palette=\"husl\", errorbar=None, alpha=1, legend=False, hue_order=order)\n",
    "\n",
    "g.set_titles(col_template=\"{col_name}\", row_template=\"Test performance\")\n",
    "g.set_xlabels(\"\")\n",
    "g.set_ylabels(\"Test perf.\")\n",
    "# Rotate x-ticks\n",
    "g.tick_params(axis='x', rotation=90)\n",
    "g.tick_params(axis='y', labelsize=8)\n",
    "\n",
    "g.figure.subplots_adjust(wspace=0.8, hspace=0.5)\n",
    "\n",
    "# Zoom each facet onto the range spanned by its per-embedder means.\n",
    "# axes_dict pairs every axis with its own facet name, so the limits cannot be\n",
    "# applied to the wrong panel; selecting metric_test before averaging keeps the\n",
    "# string columns out of the mean (which raises on pandas >= 2.0).\n",
    "for dataset, ax in g.axes_dict.items():\n",
    "    embedder_means = (\n",
    "        df_plot.loc[df_plot.short_dataset == dataset]\n",
    "        .groupby(\"embedder\")[\"metric_test\"]\n",
    "        .mean()\n",
    "    )\n",
    "    ax.set_ylim(embedder_means.min() - 0.02, embedder_means.max() + 0.02)\n",
    "\n",
    "plt.savefig(f\"{LATEX_PATH}/figures/molecules/multi_vs_single_facetgrid.pdf\", bbox_inches=\"tight\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Keep only the datasets whose task_type is \"reg\".\n",
    "DATASETS = df_metadata[df_metadata.task_type == \"reg\"].index.tolist()\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS).reset_index(drop=True)\n",
    "# Short dataset names looked up in df_metadata; they become the facet titles.\n",
    "df_base[\"short_dataset\"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)\n",
    "\n",
    "# Plot and compute the y-limits from the same NaN-free rows, so the limits\n",
    "# always describe the points that are actually drawn.\n",
    "df_plot = df_base.dropna()\n",
    "\n",
    "# NOTE: `order` is not defined in this cell; it is the value returned by the most\n",
    "# recently executed aggregate_results_with_ci(...) call in an earlier cell.\n",
    "# The catplot layer is fully transparent (alpha=0) and only builds the grid;\n",
    "# the visible line and points are added by the g.map calls below.\n",
    "g = sns.catplot(\n",
    "    data=df_plot,\n",
    "    col=\"short_dataset\",\n",
    "    y=\"metric_test\",\n",
    "    x=\"embedder\",\n",
    "    hue=\"embedder\",\n",
    "    kind=\"point\",\n",
    "    palette=\"husl\",\n",
    "    height=1.1,\n",
    "    aspect=1.5,\n",
    "    col_wrap=5,\n",
    "    sharey=False,\n",
    "    alpha=0.,\n",
    "    legend=False,\n",
    "    errorbar=None,\n",
    "    order=order,\n",
    "    hue_order=order,\n",
    ")\n",
    "g.map(sns.lineplot, \"embedder\", \"metric_test\", errorbar=None, color=\"black\", alpha=0.3, linewidth=2.5)\n",
    "g.map(sns.pointplot, \"embedder\", \"metric_test\", \"embedder\", order=order, palette=\"husl\", errorbar=None, alpha=1, legend=False, hue_order=order)\n",
    "\n",
    "g.set_titles(col_template=\"{col_name}\", row_template=\"Test performance\")\n",
    "g.set_xlabels(\"\")\n",
    "g.set_ylabels(\"$R^2$\")\n",
    "# Rotate x-ticks\n",
    "g.tick_params(axis='x', rotation=90)\n",
    "g.tick_params(axis='y', labelsize=8)\n",
    "\n",
    "g.figure.subplots_adjust(wspace=0.8, hspace=0.4)\n",
    "\n",
    "# Zoom each facet onto the range spanned by its per-embedder means.\n",
    "# axes_dict pairs every axis with its own facet name, so the limits cannot be\n",
    "# applied to the wrong panel; selecting metric_test before averaging keeps the\n",
    "# string columns out of the mean (which raises on pandas >= 2.0).\n",
    "for dataset, ax in g.axes_dict.items():\n",
    "    embedder_means = (\n",
    "        df_plot.loc[df_plot.short_dataset == dataset]\n",
    "        .groupby(\"embedder\")[\"metric_test\"]\n",
    "        .mean()\n",
    "    )\n",
    "    ax.set_ylim(embedder_means.min() - 0.02, embedder_means.max() + 0.02)\n",
    "\n",
    "plt.savefig(f\"{LATEX_PATH}/figures/molecules/multi_vs_single_facetgrid_reg.pdf\", bbox_inches=\"tight\")\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
