{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Runtimes Experiment\n",
    "\n",
    "This notebook loads the raw timing results of the runtimes experiment and plots the median runtime of each library for every model and batch size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "pd.set_option('display.max_rows', 200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per plotted configuration:\n",
    "# (display name, private flag (1 = private run, 0 = non-private), result-file prefix).\n",
    "configurations = [\n",
    "    ('JAX',                0, 'jaxdp'),\n",
    "    ('TensorFlow 2',       0, 'tf2dp'),\n",
    "    ('TensorFlow 1',       0, 'tf1dp'),\n",
    "    ('PyTorch',            0, 'pytorch'),\n",
    "    ('JAX',                1, 'jaxdp'),\n",
    "    ('Custom TFP',         1, 'tf2dp'),\n",
    "    ('TFP',                1, 'tf1dp'),\n",
    "    ('Opacus',             1, 'opacusdp'),\n",
    "    ('BackPACK',           1, 'backpackdp'),\n",
    "    ('PyVacy',             1, 'pyvacydp'),\n",
    "    ('CRB',                1, 'owkindp'),\n",
    "    ('TensorFlow 2 (XLA)', 0, 'tf2dp'),\n",
    "    ('Custom TFP (XLA)',   1, 'tf2dp'),\n",
    "    ('TensorFlow 1 (XLA)', 0, 'tf1dp'),\n",
    "    ('TFP (XLA)',          1, 'tf1dp'),\n",
    "]\n",
    "\n",
    "# The rest of the notebook consumes the table as three parallel lists.\n",
    "names = [label for label, _flag, _prefix in configurations]\n",
    "private = [flag for _label, flag, _prefix in configurations]\n",
    "filenames = [prefix for _label, _flag, prefix in configurations]\n",
    "\n",
    "# Experiment keys and batch sizes that were timed for every configuration.\n",
    "expts = ['logreg', 'ffnn', 'mnist', 'embed', 'lstm', 'cifar10']\n",
    "batch_sizes = [16, 32, 64, 128, 256]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The three configuration lists are zipped together later, so they must line up.\n",
    "assert len(names) == len(private) and len(private) == len(filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def expt_iterator():\n",
    "    \"\"\"Yield one tuple per (experiment, batch size, library configuration).\n",
    "\n",
    "    Iterates over `expts`, then `batch_sizes`, then the zipped\n",
    "    `private` / `names` / `filenames` lists, in that nesting order.\n",
    "\n",
    "    Yields:\n",
    "        (expt, bs, name, filename, dpsgd) where `dpsgd` is the entry of\n",
    "        `private` converted to a bool.\n",
    "    \"\"\"\n",
    "    yield from (\n",
    "        (expt, bs, name, filename, bool(dpsgd))\n",
    "        for expt in expts\n",
    "        for bs in batch_sizes\n",
    "        for dpsgd, name, filename in zip(private, names, filenames)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): pickle.load can execute arbitrary code, so only point this at\n",
    "# result files you produced yourself.\n",
    "files = []\n",
    "success, errors = 0, 0\n",
    "for expt, bs, name, filename, dpsgd in expt_iterator():\n",
    "    pickle_name = f'./raw/{filename}_{expt}_bs_{bs}_priv_{dpsgd}'\n",
    "\n",
    "    # A row is flagged as XLA when its label mentions XLA or the library name starts with 'jax'.\n",
    "    use_xla = 'xla' in name.lower() or name.lower().startswith('jax')\n",
    "    # Only result files whose prefix starts with 'tf' carry the XLA flag in their name.\n",
    "    if filename.startswith('tf'):\n",
    "        pickle_name += f'_xla_{use_xla}'\n",
    "\n",
    "    try:\n",
    "        with open(pickle_name+'.pkl', 'rb') as fh:\n",
    "            d = pickle.load(fh)\n",
    "    except Exception as err:\n",
    "        # A missing or unreadable result is recorded as None instead of aborting the loop.\n",
    "        # Catching Exception (rather than a bare except) still lets kernel interrupts through.\n",
    "        print(f'Failed to load {pickle_name}.pkl ({type(err).__name__}: {err})')\n",
    "        d = None\n",
    "        errors += 1\n",
    "    else:\n",
    "        success += 1\n",
    "    files.append((filename, name, expt, bs, dpsgd, use_xla, d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "success, errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One record per run: the metadata stored in `files` followed by the median of its\n",
    "# 'timings' entry. Runs whose result failed to load (d is None) get 0. as a placeholder.\n",
    "records = [\n",
    "    row + [np.median(d['timings']) if d else 0.]\n",
    "    for *row, d in files\n",
    "]\n",
    "\n",
    "# Build the frame in one call instead of one Series per row + concat + transpose.\n",
    "# dtype=object keeps every column as generic Python objects, matching what the\n",
    "# per-row construction produced, so downstream cells see the same dtypes.\n",
    "df = pd.DataFrame(\n",
    "    records,\n",
    "    columns=['Filename', 'Library', 'Experiment', 'Batch Size', 'Private?', 'XLA', 'Median Runtime'],\n",
    "    dtype=object,\n",
    ")\n",
    "df['Median Runtime'] = df['Median Runtime'].astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank each library so rows can be sorted into a fixed library order within each batch size.\n",
    "library_order = [\n",
    "    'JAX', 'Custom TFP (XLA)', 'Custom TFP', 'TFP (XLA)', 'TFP',\n",
    "    'Opacus', 'BackPACK', 'CRB', 'PyVacy',\n",
    "    'TensorFlow 2 (XLA)', 'TensorFlow 2', 'TensorFlow 1 (XLA)', 'TensorFlow 1', 'PyTorch',\n",
    "]\n",
    "rank_of = {library: rank for rank, library in enumerate(library_order)}\n",
    "\n",
    "# Libraries missing from the list keep the sentinel -1, which the assertion rejects.\n",
    "df['Order'] = df['Library'].map(rank_of).fillna(-1).astype('int64')\n",
    "assert not (df['Order'] == -1).sum()\n",
    "df = df.sort_values(by=['Batch Size', 'Order'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Human-readable model name used in the figure titles, keyed by experiment.\n",
    "expt_to_title = {\n",
    "    'mnist': 'MNIST Convolutional Neural Network',\n",
    "    'lstm': 'LSTM Network',\n",
    "    'embed': 'Embedding Network',\n",
    "    'ffnn': 'Fully Connected Neural Network (FCNN)',\n",
    "    'logreg': 'Logistic Regression',\n",
    "    'cifar10': 'CIFAR10 Convolutional Neural Network'\n",
    "}\n",
    "\n",
    "\n",
    "def _format_runtime(seconds):\n",
    "    \"\"\"Return the bar label for a median runtime given in seconds.\n",
    "\n",
    "    Runtimes of 99.5 or more are shown as truncated integers, smaller positive\n",
    "    runtimes with two significant digits, and non-positive runtimes (0 is the\n",
    "    placeholder stored for a run without data) as an empty string.\n",
    "    \"\"\"\n",
    "    # '.2g' switches to scientific notation once the value rounds to 100\n",
    "    # (e.g. 99.7 -> '1e+02'), so everything from 99.5 upwards uses the integer form.\n",
    "    if seconds >= 99.5:\n",
    "        return f'{int(seconds)}'\n",
    "    if seconds > 0.:\n",
    "        return f'{seconds:.2g}'\n",
    "    return ''\n",
    "\n",
    "\n",
    "def get_runtime_plot(expt, ylim=None, figsize=(13, 6)):\n",
    "    \"\"\"Draw the private (top) and non-private (bottom) runtime bar charts for one experiment.\n",
    "\n",
    "    Reads the module-level DataFrame `df`.\n",
    "\n",
    "    Args:\n",
    "        expt: Experiment key; must be a key of `expt_to_title` and is matched\n",
    "            against df['Experiment'].\n",
    "        ylim: Optional upper limit of the y-axis. When given, bar heights are\n",
    "            capped at `ylim - 2` while the labels still show the uncapped median.\n",
    "        figsize: Figure size forwarded to `plt.subplots`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(f, ax)` of the figure and the array holding its two axes.\n",
    "    \"\"\"\n",
    "    f, ax = plt.subplots(2, 1, figsize=figsize, sharey=True)\n",
    "    plot_df = df[df['Experiment'] == expt].copy()\n",
    "\n",
    "    # Keep the real medians for the labels before the bar heights are capped.\n",
    "    true_runtimes = plot_df['Median Runtime'].copy()\n",
    "    if ylim:\n",
    "        plot_df['Median Runtime'] = np.minimum(plot_df['Median Runtime'], ylim-2)\n",
    "\n",
    "    # Row masks for the two panels: private runs on top, non-private below.\n",
    "    panel_masks = [plot_df['Private?'] == is_private for is_private in (True, False)]\n",
    "\n",
    "    for axis, mask in zip(ax, panel_masks):\n",
    "        sns.barplot(x='Library', y='Median Runtime', hue='Batch Size',\n",
    "                    data=plot_df[mask], ax=axis, palette='muted')\n",
    "\n",
    "    # NOTE(review): pairing patches with rows assumes the bars are created in the same\n",
    "    # order as the rows of `df` (sorted by batch size, then library order) - confirm\n",
    "    # when changing the sort or the seaborn version.\n",
    "    for axis, mask in zip(ax, panel_masks):\n",
    "        for rect, tim in zip(axis.patches, true_runtimes[mask]):\n",
    "            axis.annotate(_format_runtime(tim),\n",
    "                          xy=(rect.get_x() + rect.get_width() / 2 - 0.3*rect.get_width(),\n",
    "                              rect.get_height()),\n",
    "                          xytext=(0, 3),  # 3 points vertical offset\n",
    "                          textcoords='offset points',\n",
    "                          va='bottom', ha='left',\n",
    "                          fontsize=9, rotation=45)\n",
    "\n",
    "    # The top title is raised for these experiments, presumably so that it clears\n",
    "    # the rotated bar labels - TODO confirm.\n",
    "    title_y = {'lstm': 1.18, 'cifar10': 1.15}.get(expt, 1)\n",
    "    ax[0].set_title('Median Runtime for One Private Epoch - '+ expt_to_title[expt],\n",
    "                    y=title_y)\n",
    "    ax[1].set_title('Median Runtime for One Non-Private Epoch - '+ expt_to_title[expt])\n",
    "\n",
    "    for axis in ax:\n",
    "        axis.set_xlabel('Library')\n",
    "        axis.set_ylabel('Median Runtime (sec)')\n",
    "        if ylim:\n",
    "            axis.set_ylim(0, ylim)\n",
    "        # Drop the per-panel legends; a single legend is re-added below.\n",
    "        axis.get_legend().remove()\n",
    "\n",
    "    sns.despine(fig=f)\n",
    "    # One shared legend on the bottom panel (the axes that plt.legend() targeted).\n",
    "    ax[1].legend()\n",
    "    f.patch.set_facecolor('white')\n",
    "    f.tight_layout()\n",
    "    return f, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df.to_csv('median_runtimes_data.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Runtimes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('logreg', ylim=20, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig('../../paper/assets/logistic_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('ffnn', 20, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig('../../paper/assets/ffnn_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('mnist', 50, figsize=(11, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig('../../paper/assets/cnn_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('cifar10', 175, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig('../../paper/assets/cifar10_cnn_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('embed', 20, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig('../../paper/assets/embed_runtimes.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = get_runtime_plot('lstm', 250, figsize=(11, 5))\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f.savefig('../../paper/assets/lstm_runtimes.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
