{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablation Study\n",
    "\n",
    "This collects and shows the data from the ablation study on vectorization and compilation through JAX and Custom TFP."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_rows', 200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## JAX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expts = ['logreg', 'ffnn', 'mnist', 'embed', 'cifar10', 'lstm']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expt_dict = {}\n",
    "for expt in expts:\n",
    "    expt_dict[expt] = {}\n",
    "    pickle_name = f'./raw/jaxdp_{expt}_bs_128_priv_True'\n",
    "    try:\n",
    "        with open(pickle_name+'.pkl', 'rb') as f:\n",
    "            expt_dict[expt]['base'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} xla')\n",
    "    try:\n",
    "        with open(pickle_name+'_novmap.pkl', 'rb') as f:\n",
    "            expt_dict[expt]['nv'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} xla')\n",
    "    try:\n",
    "        with open(pickle_name+'_nojit.pkl', 'rb') as f:\n",
    "            expt_dict[expt]['nj'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} xla')\n",
    "    try:\n",
    "        with open(pickle_name+'_nojit_novmap.pkl', 'rb') as f:\n",
    "            expt_dict[expt]['nvj'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} xla')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expt_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom TFP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expts = ['logreg', 'ffnn', 'mnist', 'embed', 'lstm', 'cifar10']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_expt_dict = {}\n",
    "for expt in expts:\n",
    "    tf_expt_dict[expt] = {}\n",
    "    pickle_name = f'./raw/tf2dp_{expt}_bs_128_priv_True_xla_'\n",
    "    try:\n",
    "        with open(pickle_name+'False.pkl', 'rb') as f:\n",
    "            tf_expt_dict[expt]['base'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} base')\n",
    "    try:\n",
    "        with open(pickle_name+'True.pkl', 'rb') as f:\n",
    "            tf_expt_dict[expt]['xla'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} xla')\n",
    "    try:\n",
    "        with open(pickle_name+'False_novmap.pkl', 'rb') as f:\n",
    "            tf_expt_dict[expt]['nv'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} nv')\n",
    "    try:\n",
    "        with open(pickle_name+'False_nojit_novmap.pkl', 'rb') as f:\n",
    "            tf_expt_dict[expt]['nvj'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} nvj')\n",
    "    try:\n",
    "        with open(pickle_name+'True_novmap.pkl', 'rb') as f:\n",
    "            tf_expt_dict[expt]['xla_nv'] = np.median(pickle.load(f)['timings'])\n",
    "    except: print(f'Failed {expt} xla_nv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_expt_dict"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
