{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "682ef10d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import wandb\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "from os.path import exists\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import math\n",
    "import ast\n",
    "import scipy as sp\n",
    "import scipy.stats as sps\n",
    "\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "#...\n",
    "\n",
    "font = {'family' : 'times',\n",
    "        'size'   : 14}\n",
    "\n",
    "matplotlib.rc('font', **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06d4e0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Experiment:\n",
    "    def __init__(self, run):\n",
    "        self.name = run.name\n",
    "        self.config = run.config\n",
    "        self.summary = run.summary\n",
    "        self.history = run.history()\n",
    "        self.tags = run.tags\n",
    "        self.run = run\n",
    "        \n",
    "    def get_id(self):\n",
    "        return (self.config['formula'],self.config['mol_idx'])\n",
    "        \n",
    "    def get_history(self):\n",
    "        return np.array(list(self.history['additional_steps'])).cumsum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d004a6d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fetch(project):\n",
    "    api = wandb.Api()\n",
    "    entity = \"bogp\"\n",
    "    hdata = []\n",
    "    runs = api.runs(entity + \"/\" + project)\n",
    "    for run in tqdm(runs):\n",
    "        try:\n",
    "            hdata.append(Experiment(run))\n",
    "        except:\n",
    "            pass\n",
    "    return hdata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28e2b924",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw = fetch(\"scale_master\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c7a4097",
   "metadata": {},
   "source": [
    "# Width"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "048221ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "exps = {}\n",
    "for exp in raw:\n",
    "    if exp.run.group == \"bayes_wide\":\n",
    "        print(exp.name)\n",
    "        exps[exp.config[\"num_particles\"]] = exp\n",
    "exps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d25b9c9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc(exps, p):\n",
    "    d = ast.literal_eval(exps[p].history.orig_dist0[11])\n",
    "    d_swag = ast.literal_eval(exps[p].history.max_dist0[len(exps[p].history.max_dist0) - 1])\n",
    "    misclass = 0\n",
    "    misclass_swag = 0\n",
    "    total = 0\n",
    "    for c in range(10):\n",
    "        misclass += sum(d[c]) - d[c][c]\n",
    "        total += sum(d[c])\n",
    "        misclass_swag += sum(d_swag[c]) - d_swag[c][c]\n",
    "        # print(\"Orig\", c, sps.entropy([x / sum(d[c]) for x in d[c]]))\n",
    "        # print(\"Swag\", c, sps.entropy([x / sum(d_swag[c]) for x in d_swag[c]]))\n",
    "    print(\"original misclass\", misclass, \"mswag misclass\", misclass_swag)\n",
    "    return 1 - (misclass/total), 1- (misclass_swag/total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a799c16c",
   "metadata": {},
   "outputs": [],
   "source": [
    "ps = [1, 2, 4, 8, 16, 32]\n",
    "orig = []\n",
    "mswag = []\n",
    "for p in ps:\n",
    "    m1, m2 = calc(exps, p)\n",
    "    orig += [m1]\n",
    "    mswag += [m2]\n",
    "plt.plot(ps, orig, label='Standard', marker='s', linestyle=\"--\" )\n",
    "plt.plot(ps, mswag, label=\"Multi-Swag\", marker='o', linestyle=\":\")\n",
    "plt.xlabel(\"Particles\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.title(\"Standard Training vs. Multi-Swag on MNIST\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5dd40cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = []\n",
    "for p in ps:\n",
    "    params += [exps[p].config[\"num_params\"]]\n",
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d526059",
   "metadata": {},
   "outputs": [],
   "source": [
    "orig, mswag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18693adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame({\n",
    "    \"parameters\": params,\n",
    "    \"original accuracy\": orig,\n",
    "    \"particles\": ps,\n",
    "    \"mswag accuracy\": mswag,\n",
    "})\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b6c0c3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_latex(buf=\"table_width.tex\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62b73776",
   "metadata": {},
   "source": [
    "# Depth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d7e68d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "bayes5 = {}\n",
    "for exp in raw:\n",
    "    if exp.run.group == \"bayes_deep\":\n",
    "        print(exp.name)\n",
    "        bayes5[exp.config[\"num_particles\"]] = exp\n",
    "bayes5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51f0f196",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc2(exps, p):\n",
    "    d = ast.literal_eval(exps[p].history.orig_dist0[11])\n",
    "    d_swag = ast.literal_eval(exps[p].history.max_dist0[len(exps[p].history.max_dist0) - 1])\n",
    "    misclass = 0\n",
    "    misclass_swag = 0\n",
    "    total = 0\n",
    "    for c in range(10):\n",
    "        misclass += sum(d[c]) - d[c][c]\n",
    "        total += sum(d[c])\n",
    "        misclass_swag += sum(d_swag[c]) - d_swag[c][c]\n",
    "        # print(\"Orig\", c, sps.entropy([x / sum(d[c]) for x in d[c]]))\n",
    "        # print(\"Swag\", c, sps.entropy([x / sum(d_swag[c]) for x in d_swag[c]]))\n",
    "    print(\"original misclass\", misclass, \"mswag misclass\", misclass_swag)\n",
    "    return 1 - (misclass/total), 1 - (misclass_swag/total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf77267e",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = []\n",
    "for p in ps:\n",
    "    params += [bayes5[p].config[\"num_params\"]]\n",
    "orig = []\n",
    "mswag = []\n",
    "for p in ps:\n",
    "    m1, m2 = calc2(bayes5, p)\n",
    "    orig += [m1]\n",
    "    mswag += [m2]\n",
    "    \n",
    "df = pd.DataFrame({\n",
    "    \"parameters\": params,\n",
    "    \"original accuracy\": orig,\n",
    "    \"particles\": ps,\n",
    "    \"mswag accuracy\": mswag,\n",
    "})\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0294b9ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_latex(buf=\"table_depth.tex\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca0ede91",
   "metadata": {},
   "source": [
    "# Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65c425b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "ps = [1, 2, 4, 8, 16, 32]\n",
    "orig = []\n",
    "mswag = []\n",
    "for p in ps:\n",
    "    m1, m2 = calc2(exps, p)\n",
    "    orig += [m1]\n",
    "    mswag += [m2]\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(ps, mswag, label=\"Multi-Swag\", marker='o', linestyle=\":\")\n",
    "ax.plot(ps, orig, label='Standard', marker='s', linestyle=\"--\" )\n",
    "\n",
    "def foo(x):\n",
    "    print(x)\n",
    "    return exps[x].config[\"num_params\"]\n",
    "\n",
    "params_to_p = {}\n",
    "for p in ps:\n",
    "    params_to_p[exps[p].config[\"num_params\"]] = p\n",
    "\n",
    "def foo_inv(x):\n",
    "    return params_to_p[x]\n",
    "    \n",
    "secax = ax.secondary_xaxis('top', functions=(foo, foo_inv))\n",
    "\n",
    "\n",
    "ax.set_xlabel(\"Particles\")\n",
    "ax.set_ylabel(\"Acurracy\")\n",
    "ax.set_title(\"Standard Training vs. Multi-Swag on MNIST\")\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b768c12",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c88198c3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e351feb6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
