{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Move two directories up so the relative paths below resolve from there\n",
    "# (presumably the repository root -- TODO confirm). NOTE: `%cd` is relative, so\n",
    "# re-running this cell without restarting the kernel moves up two more levels.\n",
    "%cd ../..\n",
    "\n",
    "# Base directory under which the generated tables and figures are written.\n",
    "LATEX_PATH = \"../latex/Distillation-MI-ICLR\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from molDistill.utils.notebooks import *\n",
    "\n",
    "# Models whose downstream results are loaded by the cells below.\n",
    "MODELS_TO_EVAL = [ZINC_MODEL, SMALL_KERNEL, LARGE_KERNEL]\n",
    "\n",
    "# Every dataset in the metadata index except its last three entries, with\n",
    "# ToxCast removed (list.remove raises ValueError if ToxCast is not present).\n",
    "DATASETS = df_metadata.index[:-3].tolist()\n",
    "DATASETS.remove(\"ToxCast\")\n",
    "\n",
    "len(DATASETS)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Mean performances (classification)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Restrict the evaluation to the datasets tagged as classification tasks.\n",
    "DATASETS = df_metadata[df_metadata.task_type == \"cls\"].index.tolist()\n",
    "\n",
    "# One LaTeX table is written per group. A column of the aggregated frame is\n",
    "# assigned to a group when the first element of its label is listed there\n",
    "# (assumes tuple / MultiIndex column labels -- TODO confirm).\n",
    "DATASET_GROUP = [[\"Distribution\", \"HTS\", \"Absorption\", \" \"], [\"Metabolism\"], [\"Tox\"]]\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "df, order = aggregate_results_with_ci(df_base)\n",
    "\n",
    "for i, datasets in enumerate(DATASET_GROUP):\n",
    "    df_group = df[[col for col in df.columns if col[0] in datasets]]\n",
    "    # The row order returned by the aggregation is reversed before styling.\n",
    "    style, latex = style_df_ci(df_group, order[::-1])\n",
    "    # add_hline is called for positions 1 and -1 (presumably a rule after the\n",
    "    # header and one before the last row -- confirm against the helper).\n",
    "    latex = add_hline(latex, 1)\n",
    "    latex = add_hline(latex, -1)\n",
    "    table_path = f\"{LATEX_PATH}/tables/molecules/kern_cls_{i}.tex\"\n",
    "    # Explicit encoding: the default used by open() depends on the locale, which\n",
    "    # would make the generated .tex file differ between machines.\n",
    "    with open(table_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(latex)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Mean performances (regression)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Restrict the evaluation to the datasets tagged as regression tasks.\n",
    "DATASETS = df_metadata[df_metadata.task_type == \"reg\"].index.tolist()\n",
    "\n",
    "# One LaTeX table is written per group. A column of the aggregated frame is\n",
    "# assigned to a group when the first element of its label is listed there\n",
    "# (assumes tuple / MultiIndex column labels -- TODO confirm).\n",
    "DATASET_GROUP = [[\" \", \"Absorption\", \"Tox\"], [\"Distribution\", \"Excretion\"]]\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "df, order = aggregate_results_with_ci(df_base)\n",
    "\n",
    "for i, datasets in enumerate(DATASET_GROUP):\n",
    "    df_group = df[[col for col in df.columns if col[0] in datasets]]\n",
    "    # The row order returned by the aggregation is reversed before styling.\n",
    "    style, latex = style_df_ci(df_group, order[::-1])\n",
    "    # add_hline is called for positions 1 and -1 (presumably a rule after the\n",
    "    # header and one before the last row -- confirm against the helper).\n",
    "    latex = add_hline(latex, 1)\n",
    "    latex = add_hline(latex, -1)\n",
    "    table_path = f\"{LATEX_PATH}/tables/molecules/kern_reg_{i}.tex\"\n",
    "    # Explicit encoding: the default used by open() depends on the locale, which\n",
    "    # would make the generated .tex file differ between machines.\n",
    "    with open(table_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(latex)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Display the aggregated results currently bound to `df`; in a top-to-bottom\n",
    "# run this is the regression aggregate computed in the previous cell.\n",
    "df"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Plots for all datasets"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Display the frame currently bound to `df_base`; in a top-to-bottom run this\n",
    "# is the get_all_results output for the regression datasets loaded above.\n",
    "df_base"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Classification datasets, without the last three entries and without ToxCast\n",
    "# (list.remove raises ValueError if ToxCast is not in the list).\n",
    "DATASETS = df_metadata[df_metadata.task_type == \"cls\"].index.tolist()[:-3]\n",
    "DATASETS.remove(\"ToxCast\")\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "\n",
    "# Rebind rather than reset in place, so re-running the cell never depends on\n",
    "# the object returned by the helper having been mutated earlier.\n",
    "df_base = df_base.reset_index(drop=True)\n",
    "# Running row index that wraps at the row count of the most frequent embedder.\n",
    "step = df_base.embedder.value_counts().max()\n",
    "df_base[\"id\"] = df_base.index % step\n",
    "\n",
    "# Short display name of each dataset, looked up in the metadata table.\n",
    "df_base[\"short_dataset\"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)\n",
    "\n",
    "# x position of every embedder on the plots; an embedder missing from this\n",
    "# mapping raises KeyError. (The values look like layer counts despite the\n",
    "# column being called n_cluster -- TODO confirm.)\n",
    "n_clusters = {\n",
    "    \"2-layers-kernel\": 2,\n",
    "    \"5-layers-kernel\": 5,\n",
    "    \"student-250k\": 3,\n",
    "}\n",
    "df_base[\"n_cluster\"] = df_base.embedder.apply(lambda x: n_clusters[x])\n",
    "\n",
    "# One small panel per dataset: translucent points with confidence intervals.\n",
    "g = sns.catplot(\n",
    "    data=df_base.dropna(),\n",
    "    col=\"short_dataset\",\n",
    "    y=\"metric_test\",\n",
    "    x=\"n_cluster\",\n",
    "    hue=\"embedder\",\n",
    "    kind=\"point\",\n",
    "    palette=\"husl\",\n",
    "    height=1.1,\n",
    "    aspect=1.5,\n",
    "    col_wrap=5,\n",
    "    sharey=False,\n",
    "    alpha=0.7,\n",
    "    legend=False,\n",
    "    errorbar=\"ci\",\n",
    "    native_scale=True,\n",
    ")\n",
    "# Overlay a faint black line through the values, then redraw opaque\n",
    "# per-embedder points without error bars on top of it.\n",
    "g.map(sns.lineplot, \"n_cluster\", \"metric_test\", errorbar=None, color=\"black\", alpha=0.3, linewidth=2.5)\n",
    "g.map(sns.pointplot, \"n_cluster\", \"metric_test\", \"embedder\", palette=\"husl\", errorbar=None, alpha=1, legend=False,\n",
    "    native_scale=True)\n",
    "\n",
    "g.set_titles(col_template=\"{col_name}\", row_template=\"Test performance\")\n",
    "g.set_ylabels(\"AUROC\")\n",
    "# Use a smaller font for the y tick labels.\n",
    "g.tick_params(axis=\"y\", labelsize=8)\n",
    "\n",
    "g.figure.subplots_adjust(wspace=0.8, hspace=0.4)\n",
    "\n",
    "# Save through the grid's own figure instead of pyplot's implicit current figure.\n",
    "g.figure.savefig(f\"{LATEX_PATH}/figures/molecules/kernel_point_cls.pdf\", bbox_inches=\"tight\")\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Regression datasets only.\n",
    "DATASETS = df_metadata[df_metadata.task_type == \"reg\"].index.tolist()\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "\n",
    "# Rebind rather than reset in place, so re-running the cell never depends on\n",
    "# the object returned by the helper having been mutated earlier.\n",
    "df_base = df_base.reset_index(drop=True)\n",
    "# Running row index that wraps at the row count of the most frequent embedder.\n",
    "step = df_base.embedder.value_counts().max()\n",
    "df_base[\"id\"] = df_base.index % step\n",
    "\n",
    "# Short display name of each dataset, looked up in the metadata table.\n",
    "df_base[\"short_dataset\"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)\n",
    "\n",
    "# x position of every embedder on the plots; an embedder missing from this\n",
    "# mapping raises KeyError. (The values look like layer counts despite the\n",
    "# column being called n_cluster -- TODO confirm.)\n",
    "n_clusters = {\n",
    "    \"2-layers-kernel\": 2,\n",
    "    \"5-layers-kernel\": 5,\n",
    "    \"student-250k\": 3,\n",
    "}\n",
    "df_base[\"n_cluster\"] = df_base.embedder.apply(lambda x: n_clusters[x])\n",
    "\n",
    "# One small panel per dataset: translucent points with confidence intervals.\n",
    "g = sns.catplot(\n",
    "    data=df_base.dropna(),\n",
    "    col=\"short_dataset\",\n",
    "    y=\"metric_test\",\n",
    "    x=\"n_cluster\",\n",
    "    hue=\"embedder\",\n",
    "    kind=\"point\",\n",
    "    palette=\"husl\",\n",
    "    height=1.1,\n",
    "    aspect=1.5,\n",
    "    col_wrap=5,\n",
    "    sharey=False,\n",
    "    alpha=0.7,\n",
    "    legend=False,\n",
    "    errorbar=\"ci\",\n",
    "    native_scale=True,\n",
    ")\n",
    "# Overlay a faint black line through the values, then redraw opaque\n",
    "# per-embedder points without error bars on top of it.\n",
    "g.map(sns.lineplot, \"n_cluster\", \"metric_test\", errorbar=None, color=\"black\", alpha=0.3, linewidth=2.5)\n",
    "g.map(sns.pointplot, \"n_cluster\", \"metric_test\", \"embedder\", palette=\"husl\", errorbar=None, alpha=1, legend=False,\n",
    "    native_scale=True)\n",
    "\n",
    "g.set_titles(col_template=\"{col_name}\", row_template=\"Test performance\")\n",
    "g.set_ylabels(\"$R^2$\")\n",
    "# Use a smaller font for the y tick labels.\n",
    "g.tick_params(axis=\"y\", labelsize=8)\n",
    "\n",
    "g.figure.subplots_adjust(wspace=0.8, hspace=0.4)\n",
    "\n",
    "# Save through the grid's own figure instead of pyplot's implicit current figure.\n",
    "g.figure.savefig(f\"{LATEX_PATH}/figures/molecules/kernel_point_reg.pdf\", bbox_inches=\"tight\")\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Every dataset in the metadata index except ToxCast\n",
    "# (list.remove raises ValueError if ToxCast is not in the list).\n",
    "DATASETS = df_metadata.index.tolist()\n",
    "DATASETS.remove(\"ToxCast\")\n",
    "\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "\n",
    "# Rebind rather than reset in place, so re-running the cell never depends on\n",
    "# the object returned by the helper having been mutated earlier.\n",
    "df_base = df_base.reset_index(drop=True)\n",
    "# Running row index that wraps at the row count of the most frequent embedder.\n",
    "step = df_base.embedder.value_counts().max()\n",
    "df_base[\"id\"] = df_base.index % step\n",
    "\n",
    "# Short display name of each dataset, looked up in the metadata table.\n",
    "df_base[\"short_dataset\"] = df_base.dataset.apply(lambda x: df_metadata.loc[x].short_name)\n",
    "\n",
    "# x position of every embedder on the plots; an embedder missing from this\n",
    "# mapping raises KeyError. (The values look like layer counts despite the\n",
    "# column being called n_cluster -- TODO confirm.)\n",
    "n_clusters = {\n",
    "    \"2-layers-kernel\": 2,\n",
    "    \"5-layers-kernel\": 5,\n",
    "    \"student-250k\": 3,\n",
    "}\n",
    "df_base[\"n_cluster\"] = df_base.embedder.apply(lambda x: n_clusters[x])\n",
    "\n",
    "# One small panel per dataset: translucent points with confidence intervals.\n",
    "g = sns.catplot(\n",
    "    data=df_base.dropna(),\n",
    "    col=\"short_dataset\",\n",
    "    y=\"metric_test\",\n",
    "    x=\"n_cluster\",\n",
    "    hue=\"embedder\",\n",
    "    kind=\"point\",\n",
    "    palette=\"husl\",\n",
    "    height=1.1,\n",
    "    aspect=1.5,\n",
    "    col_wrap=4,\n",
    "    sharey=False,\n",
    "    alpha=0.7,\n",
    "    legend=False,\n",
    "    errorbar=\"ci\",\n",
    "    native_scale=True,\n",
    ")\n",
    "# Overlay a faint black line through the values, then redraw opaque\n",
    "# per-embedder points without error bars on top of it.\n",
    "g.map(sns.lineplot, \"n_cluster\", \"metric_test\", errorbar=None, color=\"black\", alpha=0.3, linewidth=2.5)\n",
    "g.map(sns.pointplot, \"n_cluster\", \"metric_test\", \"embedder\", palette=\"husl\", errorbar=None, alpha=1, legend=False,\n",
    "    native_scale=True)\n",
    "\n",
    "g.set_titles(col_template=\"{col_name}\", row_template=\"Test performance\")\n",
    "\n",
    "# Per-panel y labels are blanked; a single shared label is set on the figure below.\n",
    "g.set_ylabels(\"\")\n",
    "# Use a smaller font for the y tick labels.\n",
    "g.tick_params(axis=\"y\", labelsize=8)\n",
    "\n",
    "g.figure.subplots_adjust(wspace=0.8, hspace=0.4)\n",
    "g.figure.supylabel(\"AUROC/$R^2$\")\n",
    "\n",
    "# Save through the grid's own figure instead of pyplot's implicit current figure.\n",
    "g.figure.savefig(f\"{LATEX_PATH}/figures/molecules/kernel_point.pdf\", bbox_inches=\"tight\")\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
