{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Move the working directory two levels up; the relative paths used below are resolved from there.\n",
    "%cd ../..\n",
    "\n",
    "# Root directory under which this notebook writes its generated tables and figures.\n",
    "LATEX_PATH = \"../latex/Distillation-MI-ICLR\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os.path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from molDistill.utils.notebooks import *\n",
    "\n",
    "# Models compared in the tables below: literal model names followed by the model\n",
    "# constants (presumably provided by the star import above).\n",
    "MODELS_TO_EVAL = [\n",
    "    \"ChemBertMLM-10M\", \"ChemBertMTR-77M\", \"ChemGPT-1.2B\",\n",
    "    \"GraphMVP\", \"GROVER\", \"GraphLog\", \"GraphCL\", \"InfoGraph\",\n",
    "    \"FRAD_QM9\", \"MolR_gat\", \"ThreeDInfomax\",\n",
    "] + [STUDENT_MODEL, ZINC_MODEL, L2_MODEL, COS_MODEL]\n",
    "\n",
    "# All datasets from the metadata index except ToxCast (ValueError if it is absent).\n",
    "DATASETS = df_metadata.index.tolist()\n",
    "DATASETS.remove(\"ToxCast\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Mean Performances"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# ZINC_MODEL must be part of this table; append it only if MODELS_TO_EVAL does not\n",
    "# already contain it, so that no model is requested twice.\n",
    "models_for_table = list(MODELS_TO_EVAL)\n",
    "if ZINC_MODEL not in models_for_table:\n",
    "    models_for_table.append(ZINC_MODEL)\n",
    "\n",
    "df_base = get_all_results(models_for_table, \"downstream_results\", DATASETS)\n",
    "\n",
    "df, order = aggregate_results_with_ci(df_base)\n",
    "\n",
    "# Reverse the ordering of all other entries and pin these four at the end, in this order.\n",
    "pinned_last = [\"L2\", \"Cosine\", \"student-250k\", \"student-2M\"]\n",
    "for entry in pinned_last:\n",
    "    order.remove(entry)  # raises ValueError if an expected entry is missing\n",
    "order = order[::-1] + pinned_last"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Build the styled table together with its LaTeX source, then save the LaTeX under LATEX_PATH.\n",
    "style, latex = style_df_ci(df, order)\n",
    "\n",
    "table_path = f\"{LATEX_PATH}/tables/molecules/all_raw.tex\"\n",
    "with open(table_path, mode=\"w\") as tex_file:\n",
    "    tex_file.write(latex)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Display the `style` object produced by the previous cell.\n",
    "style"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Rankings"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the results for every evaluated model and renumber the rows 0..n-1.\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "df_base.reset_index(inplace=True, drop=True)\n",
    "\n",
    "# Tag each row with its position modulo the largest per-embedder row count.\n",
    "# NOTE(review): this gives matching ids across embedders only if rows are grouped by\n",
    "# embedder in equally sized blocks - confirm against get_all_results.\n",
    "rows_per_embedder = df_base[\"embedder\"].value_counts().max()\n",
    "df_base[\"id\"] = df_base.index % rows_per_embedder\n",
    "\n",
    "df_ranked = get_ranked_df(df_base)\n",
    "df_ranked"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Style the ranking table; `order` is the ordering list built in the Mean Performances section.\n",
    "style, latex = style_df_ranked(df_ranked, order)\n",
    "\n",
    "\n",
    "# Display the styled ranking table.\n",
    "style"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "over_cols = None\n",
    "\n",
    "# LaTeX column spec: a right-aligned first column followed by a rule, then one centred\n",
    "# column per table column; the \"Avg\" column gets an extra rule in front of it.\n",
    "col_format = \"r|\" + \"\".join(\"|c\" if col == \"Avg\" else \"c\" for col in style.columns)\n",
    "\n",
    "latex = style.to_latex(\n",
    "    siunitx=True,\n",
    "    multicol_align=\"|c|\",\n",
    "    column_format=col_format,\n",
    ")\n",
    "\n",
    "# Insert horizontal rules at these positions, one call per position, in this order.\n",
    "for hline_position in (1, -4, -3, -2):\n",
    "    latex = add_hline(latex, hline_position)\n",
    "\n",
    "table_path = f\"{LATEX_PATH}/tables/molecules/all_ranks.tex\"\n",
    "with open(table_path, mode=\"w\") as tex_file:\n",
    "    tex_file.write(latex)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Heatmap"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Model subset for the heatmap: the literal model names plus STUDENT_MODEL only.\n",
    "# This rebinds MODELS_TO_EVAL and DATASETS for every cell executed afterwards.\n",
    "MODELS_TO_EVAL = [\n",
    "    \"ChemBertMLM-10M\", \"ChemBertMTR-77M\", \"ChemGPT-1.2B\",\n",
    "    \"GraphMVP\", \"GROVER\", \"GraphLog\", \"GraphCL\", \"InfoGraph\",\n",
    "    \"FRAD_QM9\", \"MolR_gat\", \"ThreeDInfomax\",\n",
    "]\n",
    "MODELS_TO_EVAL.append(STUDENT_MODEL)\n",
    "\n",
    "# All datasets from the metadata index except ToxCast (ValueError if it is absent).\n",
    "DATASETS = df_metadata.index.tolist()\n",
    "DATASETS.remove(\"ToxCast\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the results for the heatmap models and aggregate them; `order` is reused by later cells.\n",
    "df_base = get_all_results(MODELS_TO_EVAL, \"downstream_results\", DATASETS)\n",
    "df, order = aggregate_results_with_ci(df_base)\n",
    "\n",
    "# Renumber the rows 0..n-1, then tag each row with its position modulo the largest\n",
    "# per-embedder row count.\n",
    "df_base.reset_index(inplace=True, drop=True)\n",
    "\n",
    "rows_per_embedder = df_base[\"embedder\"].value_counts().max()\n",
    "df_base[\"id\"] = df_base.index % rows_per_embedder"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Reverse the ordering of every entry except \"student-2M\", which is placed last.\n",
    "# NOTE: not idempotent - running this cell twice without re-creating `order` flips the\n",
    "# other entries back to their previous order.\n",
    "order.remove(\"student-2M\")\n",
    "order = [*reversed(order), \"student-2M\"]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "ranked_df = get_ranked_df(df_base)\n",
    "\n",
    "# Name the first column \"embedder\" and relabel the remaining columns with the\n",
    "# short_name stored in df_metadata for each of them.\n",
    "short_names = [df_metadata.loc[dataset, \"short_name\"] for dataset in ranked_df.columns[1:]]\n",
    "ranked_df.columns = [\"embedder\", *short_names]\n",
    "ranked_df"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from matplotlib.patches import Rectangle\n",
    "\n",
    "# Render figure text with LaTeX (usetex) and load lmodern in the LaTeX preamble.\n",
    "plt.rc('text', usetex=True)\n",
    "plt.rc('text.latex', preamble=r'\\usepackage{lmodern}')\n",
    "\n",
    "\n",
    "# Short names of the \"reg\" and \"cls\" datasets, restricted to the columns present in ranked_df.\n",
    "# Both lists are filtered under the name that is used below: selecting a label that is\n",
    "# missing from df_plot with .loc would raise KeyError.\n",
    "REG_DATASETS = df_metadata[df_metadata[\"task_type\"] == \"reg\"].short_name.tolist()\n",
    "REG_DATASETS = [d for d in REG_DATASETS if d in ranked_df.columns]\n",
    "REG_DATASET = REG_DATASETS  # alias kept so the singular name is still defined\n",
    "CLS_DATASET = df_metadata[df_metadata[\"task_type\"] == \"cls\"].short_name.tolist()\n",
    "CLS_DATASET = [d for d in CLS_DATASET if d in ranked_df.columns]\n",
    "\n",
    "\n",
    "# One row per dataset, one column per embedder (embedders taken in reversed `order`).\n",
    "df_plot = ranked_df.set_index(\"embedder\").loc[order[::-1]].transpose()\n",
    "\n",
    "# Append the per-embedder mean over each group of datasets as an extra row.\n",
    "df_plot.loc[\"Average (reg)\"] = df_plot.loc[REG_DATASETS].mean()\n",
    "df_plot.loc[\"Average (cls)\"] = df_plot.loc[CLS_DATASET].mean()\n",
    "\n",
    "# Row layout: each group of datasets is followed directly by its average row.\n",
    "order_dataset = REG_DATASETS + [\"Average (reg)\"] + CLS_DATASET + [\"Average (cls)\"]\n",
    "df_plot = df_plot.loc[order_dataset]\n",
    "\n",
    "# Column labels: underscores become spaces, everything from the first \"-\" on is dropped,\n",
    "# and \"student\" is replaced by the bold Student-2M label.\n",
    "df_plot.columns = [\n",
    "    x.replace(\"_\", \" \").split(\"-\")[0].replace(\"student\", \"\\\\textbf{Student-2M}\")\n",
    "    for x in df_plot.columns\n",
    "]\n",
    "\n",
    "def highlight_value(data):\n",
    "    if str(np.round(data,1))[-1] != \"0\":\n",
    "        return r'\\textbf{\\underline{' + str(np.round(data,1)) + '}}'\n",
    "    else:\n",
    "        return r'\\textbf{\\underline{' + str(int(data)) + '}}'\n",
    "\n",
    "# Boolean frame shaped like df_plot: True where a cell equals the minimum of its row.\n",
    "mask_min = df_plot.eq(df_plot.min(axis=1), axis=0)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# highlight_value applied to every cell of df_plot, arranged in the same 2-D shape.\n",
    "bold_values = np.array(\n",
    "    list(map(highlight_value, df_plot.to_numpy().ravel()))\n",
    ").reshape(df_plot.shape)\n",
    "\n",
    "# Keyword arguments passed to both sns.heatmap calls in create_heatmap.\n",
    "common_kwargs = dict(\n",
    "    cbar=False,\n",
    "    vmin=1.4,\n",
    "    vmax=9,\n",
    "    annot_kws=dict(color=\"white\", fontsize=9),\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def create_heatmap(ax, df_plot, mask_min, bold_values, cmap_name = \"viridis_r\", desat = 0.0):\n",
    "    \"\"\"Draw df_plot on ax as two overlaid seaborn heatmaps.\n",
    "\n",
    "    Cells where mask_min is False are drawn with the palette desaturated by desat and\n",
    "    annotated with their values; cells where mask_min is True are drawn with the plain\n",
    "    palette and annotated with the matching entries of bold_values. Options shared by\n",
    "    both layers are read from the notebook-level common_kwargs.\n",
    "    \"\"\"\n",
    "    palette_full = sns.color_palette(cmap_name, as_cmap=False)\n",
    "    palette_muted = sns.color_palette(cmap_name, as_cmap=False, desat=desat)\n",
    "\n",
    "    # Layer 1: hide the cells flagged in mask_min, draw everything else.\n",
    "    sns.heatmap(df_plot, mask=mask_min, annot=True, cmap=palette_muted, ax=ax, **common_kwargs)\n",
    "\n",
    "    # Layer 2: draw only the flagged cells; fmt='' because the annotations are strings.\n",
    "    sns.heatmap(\n",
    "        df_plot,\n",
    "        mask=~mask_min,\n",
    "        annot=bold_values,\n",
    "        fmt='',\n",
    "        cmap=palette_full,\n",
    "        ax=ax,\n",
    "        **common_kwargs\n",
    "    )\n",
    "\n",
    "    # Tick labels: x rotated 45 degrees, y horizontal, both right-aligned.\n",
    "    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha=\"right\")\n",
    "    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, ha=\"right\")\n",
    "\n",
    "    # Remove both axis titles and use the same tick label size on both axes.\n",
    "    ax.set(xlabel=\"\", ylabel=\"\")\n",
    "    ax.tick_params(axis=\"both\", labelsize=12)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Four stacked panels: the \"reg\" datasets, their average row, the \"cls\" datasets, their average row.\n",
    "DATASETS_TO_PLOT = [REG_DATASETS, [\"Average (reg)\"], CLS_DATASET, [\"Average (cls)\"]]\n",
    "\n",
    "# Panel heights are proportional to the number of rows each panel shows.\n",
    "fig, axes = plt.subplots(\n",
    "    len(DATASETS_TO_PLOT),\n",
    "    1,\n",
    "    figsize=(3.5, 8.5),\n",
    "    sharex=True,\n",
    "    gridspec_kw={'height_ratios': [len(d) for d in DATASETS_TO_PLOT]}\n",
    ")\n",
    "axes = axes.flatten()\n",
    "plt.subplots_adjust(hspace=0.02)\n",
    "\n",
    "for i, dataset_to_plot in enumerate(DATASETS_TO_PLOT):\n",
    "    # Integer row positions of this panel's rows, used to slice the bold_values array.\n",
    "    # Deliberately not named `filter`, which would shadow the builtin for the rest of the session.\n",
    "    row_positions = [df_plot.index.get_loc(c) for c in dataset_to_plot]\n",
    "    create_heatmap(\n",
    "        axes[i],\n",
    "        df_plot.loc[dataset_to_plot],\n",
    "        mask_min.loc[dataset_to_plot],\n",
    "        bold_values[row_positions],\n",
    "        cmap_name=\"flare\",\n",
    "        desat=0.7,\n",
    "    )\n",
    "\n",
    "\n",
    "plt.savefig(f\"{LATEX_PATH}/figures/molecules/hmp_rankings.pdf\", bbox_inches=\"tight\",)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
