{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Runtimes Experiment - Confidence Intervals\n",
    "\n",
    "This notebook is nearly identical to `runtimes.ipynb`, except that the data it consumes comes from runs of 100 epochs, and the bar plots also show confidence intervals along with the bars (drawn with `ci='sd'`, i.e. one standard deviation of the recorded timings)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_rows', 200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per plotted configuration: (display name, private flag, raw-file stem).\n",
    "# Passing the rows to zip() transposes them back into the three parallel lists\n",
    "# that the rest of the notebook iterates over.\n",
    "names, private, filenames = map(list, zip(\n",
    "    ('JAX',                0, 'jaxdp'),\n",
    "    ('TensorFlow 2',       0, 'tf2dp'),\n",
    "    ('TensorFlow 1',       0, 'tf1dp'),\n",
    "    ('PyTorch',            0, 'pytorch'),\n",
    "    ('JAX',                1, 'jaxdp'),\n",
    "    ('Custom TFP',         1, 'tf2dp'),\n",
    "    ('TFP',                1, 'tf1dp'),\n",
    "    ('Opacus',             1, 'opacusdp'),\n",
    "    ('BackPACK',           1, 'backpackdp'),\n",
    "    ('PyVacy',             1, 'pyvacydp'),\n",
    "    ('CRB',                1, 'owkindp'),\n",
    "    ('TensorFlow 2 (XLA)', 0, 'tf2dp'),\n",
    "    ('Custom TFP (XLA)',   1, 'tf2dp'),\n",
    "    ('TensorFlow 1 (XLA)', 0, 'tf1dp'),\n",
    "    ('TFP (XLA)',          1, 'tf1dp'),\n",
    "))\n",
    "# Experiment keys and batch sizes that appear in the raw result file names.\n",
    "expts = ['logreg', 'ffnn', 'mnist', 'embed', 'lstm']\n",
    "batch_sizes = [16, 32, 64, 128, 256]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(names), len(private), len(filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def expt_iterator():\n",
    "    \"\"\"Yield (expt, batch size, library name, file stem, is-private) for every run.\n",
    "\n",
    "    Experiments vary slowest, then batch sizes, then the rows formed by zipping the\n",
    "    module-level `private`, `names` and `filenames` lists.  The private flag is\n",
    "    converted to a bool before it is yielded.\n",
    "    \"\"\"\n",
    "    for expt in expts:\n",
    "        for bs in batch_sizes:\n",
    "            yield from (\n",
    "                (expt, bs, label, stem, bool(flag))\n",
    "                for flag, label, stem in zip(private, names, filenames)\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every raw timing pickle listed by expt_iterator().  A run whose file cannot be\n",
    "# read is still recorded, with d=None, so later cells can see that it is missing.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only point this at trusted files.\n",
    "files = []\n",
    "success, errors = 0, 0\n",
    "for expt, bs, name, filename, dpsgd in expt_iterator():\n",
    "    pickle_name = f'./raw/{filename}_{expt}_bs_{bs}_priv_{dpsgd}'\n",
    "\n",
    "    # Names containing 'xla' or starting with 'jax' are flagged as XLA runs; only file\n",
    "    # stems starting with 'tf' carry that flag in the file name itself.\n",
    "    use_xla = 'xla' in name.lower() or name.lower().startswith('jax')\n",
    "    if filename.startswith('tf'):\n",
    "        pickle_name += f'_xla_{use_xla}'\n",
    "\n",
    "    try:\n",
    "        with open(pickle_name+'.pkl', 'rb') as f:\n",
    "            d = pickle.load(f)\n",
    "    except Exception as err:\n",
    "        # Catch Exception rather than using a bare except, so that interrupting the\n",
    "        # kernel (KeyboardInterrupt) still stops the loop; report why the load failed.\n",
    "        print(f'Failed to load {pickle_name}.pkl ({type(err).__name__}: {err})')\n",
    "        d = None\n",
    "        errors += 1\n",
    "    else:\n",
    "        success += 1\n",
    "    files.append((filename, name, expt, bs, dpsgd, use_xla, d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "success, errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df_list = []\n",
    "# for *row, d in files:\n",
    "#     # d = [np.median(d['timings'])] if d else [0.]\n",
    "#     d = [np.mean(d['timings'][1:])] if d else [0.]\n",
    "#     df_list.append(pd.Series(row + d))\n",
    "\n",
    "# df = pd.concat(df_list, axis=1).transpose()\n",
    "# df.columns = ['Filename', 'Library', 'Experiment', 'Batch Size', 'Private?', 'XLA', 'Runtime']\n",
    "# df['Runtime'] = df['Runtime'].astype(float)\n",
    "# old_df = df.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one long-format row per recorded timing.  Rows are collected as plain lists and\n",
    "# converted to a DataFrame in a single call: creating one pd.Series per timing and then\n",
    "# concatenating/transposing them is much slower and leaves every column as dtype object.\n",
    "df_list = []\n",
    "for *row, d in files:\n",
    "    if d:\n",
    "        assert len(d['timings']) == 102\n",
    "        # The first timing is skipped - presumably a warm-up/compilation pass; TODO confirm.\n",
    "        df_list.extend(row + [timing] for timing in d['timings'][1:])\n",
    "    else:\n",
    "        # Runs that failed to load are kept as a single 0-second placeholder row.\n",
    "        df_list.append(row + [0.])\n",
    "\n",
    "df = pd.DataFrame(\n",
    "    df_list,\n",
    "    columns=['Filename', 'Library', 'Experiment', 'Batch Size', 'Private?', 'XLA', 'Runtime'],\n",
    ")\n",
    "df['Runtime'] = df['Runtime'].astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank each library by the position it should take on the x-axis; any library that\n",
    "# is not listed maps to -1 and trips the assertion below.\n",
    "df['Order'] = df['Library'].map({\n",
    "    library: rank\n",
    "    for rank, library in enumerate([\n",
    "        'JAX', 'Custom TFP (XLA)', 'Custom TFP', 'TFP (XLA)', 'TFP',\n",
    "        'Opacus', 'BackPACK', 'CRB', 'PyVacy',\n",
    "        'TensorFlow 2', 'TensorFlow 2 (XLA)', 'TensorFlow 1', 'TensorFlow 1 (XLA)', 'PyTorch',\n",
    "    ])\n",
    "}).fillna(-1).astype('int64')\n",
    "assert (df['Order'] != -1).all()\n",
    "df = df.sort_values(by=['Batch Size', 'Order'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean runtime per (library, experiment, batch size, ...) group; the plotting function\n",
    "# reads these values for the numeric labels above the bars.\n",
    "group_keys = ['Filename', 'Library', 'Experiment', 'Batch Size', 'Private?', 'XLA', 'Order']\n",
    "means = df.groupby(group_keys).agg('mean').reset_index()\n",
    "means.columns = group_keys + ['Runtime']\n",
    "means = means.sort_values(by=['Batch Size', 'Order'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expt_to_title = {\n",
    "    'mnist': 'Convolutional Neural Network (CNN)',\n",
    "    'lstm': 'LSTM Network',\n",
    "    'embed': 'Embedding Network',\n",
    "    'ffnn': 'Fully Connected Neural Network (FCNN)',\n",
    "    'logreg': 'Logistic Regression',\n",
    "}\n",
    "\n",
    "def get_runtime_plot(expt, ylim=None, figsize=(13, 6)):\n",
    "    \"\"\"Draw the two-panel runtime bar chart for one experiment.\n",
    "\n",
    "    The upper panel plots the rows of the global `df` flagged private, the lower\n",
    "    panel the remaining rows.  Bars are grouped by library and coloured by batch\n",
    "    size, with error bars of one standard deviation (ci='sd').  The bars of each\n",
    "    panel are paired, in order, with the runtimes in the global `means` frame and\n",
    "    labelled with those values.\n",
    "\n",
    "    Args:\n",
    "        expt: experiment key; must be a key of `expt_to_title`.\n",
    "        ylim: optional upper y-axis limit.  When truthy, runtimes are capped at\n",
    "            ylim - 2 before plotting; the text labels still come from `means`.\n",
    "        figsize: figure size forwarded to plt.subplots.\n",
    "\n",
    "    Returns:\n",
    "        The figure and the array holding the two axes.\n",
    "    \"\"\"\n",
    "    f, ax = plt.subplots(2, 1, figsize=figsize, sharey=True)\n",
    "    plot_df = df[df['Experiment'] == expt].copy()\n",
    "    if ylim:\n",
    "        plot_df['Runtime'] = np.minimum(plot_df['Runtime'], ylim-2)\n",
    "\n",
    "    # Upper axis: private runs; lower axis: everything else.\n",
    "    panel_frames = (plot_df[plot_df['Private?']], plot_df[plot_df['Private?'] != True])\n",
    "    for axis, panel_df in zip(ax, panel_frames):\n",
    "        sns.barplot(x='Library', y='Runtime', hue='Batch Size', ci='sd',\n",
    "                    data=panel_df, ax=axis, palette='muted')\n",
    "\n",
    "    def bar_label(tim):\n",
    "        # Whole numbers above 100, two significant digits otherwise; values that are\n",
    "        # not positive (such as a 0-second runtime) get an empty label.\n",
    "        if tim > 100.:\n",
    "            return f'{int(tim)}'\n",
    "        if tim > 0.:\n",
    "            return f'{tim:.2g}'\n",
    "        return ''\n",
    "\n",
    "    for axis, private in zip(ax, (True, False)):\n",
    "        in_panel = (means['Experiment'] == expt) & (means['Private?'] == private)\n",
    "        for rect, tim in zip(axis.patches, means.loc[in_panel, 'Runtime']):\n",
    "            bar_width = rect.get_width()\n",
    "            axis.annotate(bar_label(tim),\n",
    "                          xy=(rect.get_x() + bar_width / 2 - 0.3*bar_width, rect.get_height()),\n",
    "                          xytext=(0, 3),  # 3 points vertical offset\n",
    "                          textcoords='offset points',\n",
    "                          va='bottom', ha='left',\n",
    "                          fontsize=9, rotation=45)\n",
    "\n",
    "    plt.title('')\n",
    "    # The LSTM figure lifts the upper title - presumably to clear the rotated bar labels.\n",
    "    title_y = 1.18 if expt == 'lstm' else 1\n",
    "    ax[0].set_title('Mean Runtime for One Private Epoch - '+ expt_to_title[expt],\n",
    "                    y=title_y)\n",
    "    ax[1].set_title('Mean Runtime for One Non-Private Epoch - '+ expt_to_title[expt])\n",
    "    for axis in ax:\n",
    "        axis.set_xlabel('Library')\n",
    "        axis.set_ylabel('Runtime (sec)')\n",
    "        if ylim:\n",
    "            axis.set_ylim(0, ylim)\n",
    "        axis.get_legend().remove()\n",
    "    sns.despine()\n",
    "    # Both seaborn legends were removed above; plt.legend() adds one back on the current axes.\n",
    "    plt.legend()\n",
    "    f.patch.set_facecolor('white')\n",
    "    f.tight_layout()\n",
    "    return f, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# table with one batch size\n",
    "\n",
    "# Check x-axis\n",
    "f, ax = get_runtime_plot('logreg', ylim=20, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# f.savefig('../../mlsys/assets/logistic_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('ffnn', 20, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# f.savefig('../../mlsys/assets/ffnn_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('mnist', 50, figsize=(11, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# f.savefig('../../mlsys/assets/cnn_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('embed', 20, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# f.savefig('../../mlsys/assets/embed_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('lstm', 250, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# f.savefig('../../mlsys/assets/lstm_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv('arxiv_paper_data.csv')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
