{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Resolve the target directory (two levels above the kernel's start directory)\n",
    "# only once per kernel. A bare relative `%cd ../..` climbs two more levels every\n",
    "# time the cell is re-run, which silently breaks the relative paths used below.\n",
    "if \"_REPO_ROOT\" not in globals():\n",
    "    _REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), \"..\", \"..\"))\n",
    "%cd {_REPO_ROOT}\n",
    "\n",
    "# Base directory for the generated LaTeX files; the path is relative to the\n",
    "# working directory selected above.\n",
    "LATEX_PATH = \"../latex/Distillation-MI-ICLR\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from molDistill.utils.notebooks import *\n",
    "\n",
    "# Models referenced by a literal name.\n",
    "named_models = [\n",
    "    \"ChemBertMLM-10M\",\n",
    "    \"ChemBertMTR-77M\",\n",
    "    \"ChemGPT-1.2B\",\n",
    "    \"GraphMVP\",\n",
    "    \"GROVER\",\n",
    "    \"GraphLog\",\n",
    "    \"GraphCL\",\n",
    "    \"InfoGraph\",\n",
    "    \"FRAD_QM9\",\n",
    "    \"MolR_gat\",\n",
    "    \"ThreeDInfomax\",\n",
    "]\n",
    "\n",
    "# Models referenced through constants that this notebook does not define itself\n",
    "# (they are expected to come from the wildcard import above).\n",
    "constant_models = [STUDENT_MODEL, ZINC_MODEL, L2_MODEL, COS_MODEL]\n",
    "\n",
    "# Full evaluation list: literal names first, then the constants.\n",
    "MODELS_TO_EVAL = named_models + constant_models\n",
    "\n",
    "# Index labels of the metadata rows whose task_type equals \"reg\".\n",
    "reg_mask = df_metadata.task_type == \"reg\"\n",
    "DATASETS = df_metadata.loc[reg_mask].index.tolist()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Mean Performances"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# ZINC_MODEL is already listed in MODELS_TO_EVAL, so it must not be appended a\n",
    "# second time: a duplicated entry would request the same model's results twice.\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "\n",
    "# `df` is the aggregated table displayed below; `order` is the model ordering\n",
    "# returned alongside it and reused by the following cells.\n",
    "df, order = aggregate_results_with_ci(df_base)\n",
    "df"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Entries that are pulled out of `order` and re-appended at the very end.\n",
    "trailing_models = [\"L2\", \"Cosine\", \"student-250k\", \"student-2M\"]\n",
    "\n",
    "# Drop them in place; list.remove raises ValueError if a name is missing.\n",
    "for model_name in (\"student-250k\", \"student-2M\", \"L2\", \"Cosine\"):\n",
    "    order.remove(model_name)\n",
    "\n",
    "# Reverse the remaining entries, then append the pulled-out ones in fixed order.\n",
    "order = list(reversed(order)) + trailing_models"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Column groups (first element of each column label) that go into the first table;\n",
    "# every other column goes into the second table.\n",
    "LIST_SEP = [\" \", \"Absorption\"]\n",
    "\n",
    "# Positions handed to add_hline, applied in this order to each table.\n",
    "HLINE_POSITIONS = (1, -4, -3, -2)\n",
    "\n",
    "\n",
    "def export_ci_table(columns, file_name):\n",
    "    \"\"\"Style the selected columns of `df` and write the resulting LaTeX to disk.\n",
    "\n",
    "    Args:\n",
    "        columns: column labels of `df` to include in the table.\n",
    "        file_name: name of the file created under {LATEX_PATH}/tables/molecules.\n",
    "\n",
    "    Reads the notebook-level `df`, `order` and `LATEX_PATH`; returns nothing.\n",
    "    \"\"\"\n",
    "    # style_df_ci returns a pair; only the LaTeX string is needed here.\n",
    "    _, latex = style_df_ci(df[columns], order, rotate=\"-\")\n",
    "    for position in HLINE_POSITIONS:\n",
    "        latex = add_hline(latex, position)\n",
    "\n",
    "    table_path = f\"{LATEX_PATH}/tables/molecules/{file_name}\"\n",
    "    with open(table_path, \"w\") as f:\n",
    "        f.write(latex)\n",
    "\n",
    "\n",
    "first_table_columns = [col for col in df.columns if col[0] in LIST_SEP]\n",
    "second_table_columns = [col for col in df.columns if col[0] not in LIST_SEP]\n",
    "\n",
    "export_ci_table(first_table_columns, \"all_reg1.tex\")\n",
    "export_ci_table(second_table_columns, \"all_reg2.tex\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
