{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1622693515966,
     "user": {
      "displayName": "Branislav Kveton",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhjPg1PJKD4ixGqLaZ4X3dpkI6w6dO_eTj8z63Y=s64",
      "userId": "09350298559467540088"
     },
     "user_tz": 0
    },
    "id": "dUrLMhAcZYDY",
    "jupyter": {
     "source_hidden": true
    },
    "outputId": "31bb0c27-489c-4d61-85e8-6002f5252ed5",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "python 3.9.15\n",
      "matplotlib 3.5.1\n",
      "8 joblib CPUs\n"
     ]
    }
   ],
   "source": [
    "# Imports and defaults\n",
    "from bandit import *\n",
    "import joblib\n",
    "from joblib import Parallel, delayed\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import scipy.stats\n",
    "import time\n",
    "\n",
    "# Plot defaults: start from the classic style, then override selected\n",
    "# rcParams in a single update (same keys, values, and order as before)\n",
    "mpl.style.use(\"classic\")\n",
    "mpl.rcParams.update({\n",
    "  \"figure.figsize\": [5, 3],\n",
    "  # line widths, error bar caps, and tick sizes\n",
    "  \"axes.linewidth\": 0.75,\n",
    "  \"errorbar.capsize\": 3,\n",
    "  \"figure.facecolor\": \"w\",\n",
    "  \"grid.linewidth\": 0.75,\n",
    "  \"lines.linewidth\": 0.75,\n",
    "  \"patch.linewidth\": 0.75,\n",
    "  \"xtick.major.size\": 3,\n",
    "  \"ytick.major.size\": 3,\n",
    "  # font type used in PDF/PS output and font sizes\n",
    "  \"pdf.fonttype\": 42,\n",
    "  \"ps.fonttype\": 42,\n",
    "  \"font.size\": 10,\n",
    "  \"axes.titlesize\": \"medium\",\n",
    "  \"legend.fontsize\": \"medium\",\n",
    "})\n",
    "\n",
    "import platform\n",
    "# Print the Python and matplotlib versions and the joblib CPU count\n",
    "for _fmt, _val in (\n",
    "    (\"python %s\", platform.python_version()),\n",
    "    (\"matplotlib %s\", mpl.__version__),\n",
    "    (\"%d joblib CPUs\", joblib.cpu_count())):\n",
    "  print(_fmt % _val)\n",
    "\n",
    "def linestyle2dashes(style):\n",
    "  \"\"\"Return the 2-tuple associated with a line style string.\n",
    "\n",
    "  \"--\" maps to (3, 3) and \":\" maps to (0.5, 2.5); any other value\n",
    "  maps to (None, None). The tuple is presumably meant as a dash pattern\n",
    "  for a matplotlib `dashes` argument -- confirm at the call sites.\n",
    "  \"\"\"\n",
    "  if style == \"--\":\n",
    "    return (3, 3)\n",
    "  if style == \":\":\n",
    "    return (0.5, 2.5)\n",
    "  # every other style falls back to the pair of Nones\n",
    "  return (None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " K | sigma | mu0_max | sigma0 | Gaussian noise | Box noise\n",
      " 5 |  0.50 |    0.50 |   0.50 |          2.175 |     2.326\n",
      " 5 |  0.50 |    0.50 |   1.00 |          0.921 |     0.953\n",
      " 5 |  0.50 |    0.50 |   2.00 |          0.117 |     0.468\n",
      " 5 |  0.50 |    1.00 |   0.50 |          3.540 |     3.930\n",
      " 5 |  0.50 |    1.00 |   1.00 |          1.048 |     0.919\n",
      " 5 |  0.50 |    1.00 |   2.00 |          0.303 |     0.274\n",
      " 5 |  0.50 |    2.00 |   0.50 |          8.932 |     9.011\n",
      " 5 |  0.50 |    2.00 |   1.00 |          1.554 |     1.868\n",
      " 5 |  0.50 |    2.00 |   2.00 |          0.049 |     0.265\n",
      " 5 |  1.00 |    0.50 |   0.50 |          5.886 |     7.233\n",
      " 5 |  1.00 |    0.50 |   1.00 |          2.587 |     1.696\n",
      " 5 |  1.00 |    0.50 |   2.00 |          1.855 |     1.458\n",
      " 5 |  1.00 |    1.00 |   0.50 |         13.982 |    14.505\n",
      " 5 |  1.00 |    1.00 |   1.00 |          3.709 |     3.748\n",
      " 5 |  1.00 |    1.00 |   2.00 |          1.476 |     1.789\n",
      " 5 |  1.00 |    2.00 |   0.50 |         27.357 |    27.840\n",
      " 5 |  1.00 |    2.00 |   1.00 |          7.352 |     7.186\n",
      " 5 |  1.00 |    2.00 |   2.00 |          2.413 |     1.329\n",
      " 5 |  2.00 |    0.50 |   0.50 |         26.844 |    24.412\n",
      " 5 |  2.00 |    0.50 |   1.00 |         10.350 |    11.361\n",
      " 5 |  2.00 |    0.50 |   2.00 |          4.269 |     6.255\n",
      " 5 |  2.00 |    1.00 |   0.50 |         52.433 |    49.125\n",
      " 5 |  2.00 |    1.00 |   1.00 |         16.235 |    13.383\n",
      " 5 |  2.00 |    1.00 |   2.00 |          9.205 |     7.576\n",
      " 5 |  2.00 |    2.00 |   0.50 |         89.494 |    89.404\n",
      " 5 |  2.00 |    2.00 |   1.00 |         30.115 |    30.176\n",
      " 5 |  2.00 |    2.00 |   2.00 |          8.027 |     8.220\n",
      "10 |  0.50 |    0.50 |   0.50 |          6.556 |     6.364\n",
      "10 |  0.50 |    0.50 |   1.00 |          2.592 |     2.197\n",
      "10 |  0.50 |    0.50 |   2.00 |          0.593 |     0.612\n",
      "10 |  0.50 |    1.00 |   0.50 |          9.123 |     8.777\n",
      "10 |  0.50 |    1.00 |   1.00 |          2.251 |     3.144\n",
      "10 |  0.50 |    1.00 |   2.00 |          1.074 |     0.761\n",
      "10 |  0.50 |    2.00 |   0.50 |         20.350 |    20.910\n",
      "10 |  0.50 |    2.00 |   1.00 |          4.046 |     4.715\n",
      "10 |  0.50 |    2.00 |   2.00 |          0.792 |     1.164\n",
      "10 |  1.00 |    0.50 |   0.50 |         20.821 |    20.412\n",
      "10 |  1.00 |    0.50 |   1.00 |         10.846 |    11.245\n",
      "10 |  1.00 |    0.50 |   2.00 |          5.655 |     4.851\n",
      "10 |  1.00 |    1.00 |   0.50 |         33.988 |    33.523\n",
      "10 |  1.00 |    1.00 |   1.00 |         12.929 |    12.434\n",
      "10 |  1.00 |    1.00 |   2.00 |          5.365 |     4.385\n",
      "10 |  1.00 |    2.00 |   0.50 |         60.972 |    60.736\n",
      "10 |  1.00 |    2.00 |   1.00 |         17.973 |    18.418\n",
      "10 |  1.00 |    2.00 |   2.00 |          4.725 |     4.273\n",
      "10 |  2.00 |    0.50 |   0.50 |         66.922 |    67.488\n",
      "10 |  2.00 |    0.50 |   1.00 |         39.105 |    36.767\n",
      "10 |  2.00 |    0.50 |   2.00 |         23.182 |    21.465\n",
      "10 |  2.00 |    1.00 |   0.50 |        111.802 |   110.683\n",
      "10 |  2.00 |    1.00 |   1.00 |         44.549 |    44.353\n",
      "10 |  2.00 |    1.00 |   2.00 |         23.792 |    22.655\n",
      "10 |  2.00 |    2.00 |   0.50 |        197.818 |   198.865\n",
      "10 |  2.00 |    2.00 |   1.00 |         67.340 |    73.345\n",
      "10 |  2.00 |    2.00 |   2.00 |         26.078 |    23.986\n",
      "20 |  0.50 |    0.50 |   0.50 |         16.690 |    15.771\n",
      "20 |  0.50 |    0.50 |   1.00 |          6.266 |     6.065\n",
      "20 |  0.50 |    0.50 |   2.00 |          1.692 |     1.601\n",
      "20 |  0.50 |    1.00 |   0.50 |         19.403 |    20.239\n",
      "20 |  0.50 |    1.00 |   1.00 |          6.418 |     6.547\n",
      "20 |  0.50 |    1.00 |   2.00 |          2.088 |     0.963\n",
      "20 |  0.50 |    2.00 |   0.50 |         43.558 |    44.119\n",
      "20 |  0.50 |    2.00 |   1.00 |          8.774 |     9.100\n",
      "20 |  0.50 |    2.00 |   2.00 |          2.031 |     1.938\n",
      "20 |  1.00 |    0.50 |   0.50 |         53.804 |    57.891\n",
      "20 |  1.00 |    0.50 |   1.00 |         30.042 |    31.745\n",
      "20 |  1.00 |    0.50 |   2.00 |         11.459 |    11.842\n",
      "20 |  1.00 |    1.00 |   0.50 |         74.204 |    72.922\n",
      "20 |  1.00 |    1.00 |   1.00 |         34.868 |    30.103\n",
      "20 |  1.00 |    1.00 |   2.00 |         10.668 |    11.598\n",
      "20 |  1.00 |    2.00 |   0.50 |        127.269 |   127.560\n",
      "20 |  1.00 |    2.00 |   1.00 |         42.864 |    42.579\n",
      "20 |  1.00 |    2.00 |   2.00 |         13.060 |    13.024\n",
      "20 |  2.00 |    0.50 |   0.50 |        145.077 |   145.025\n",
      "20 |  2.00 |    0.50 |   1.00 |        101.352 |   103.956\n",
      "20 |  2.00 |    0.50 |   2.00 |         55.262 |    61.201\n",
      "20 |  2.00 |    1.00 |   0.50 |        233.271 |   229.339\n",
      "20 |  2.00 |    1.00 |   1.00 |        108.185 |   114.871\n",
      "20 |  2.00 |    1.00 |   2.00 |         62.553 |    66.730\n",
      "20 |  2.00 |    2.00 |   0.50 |        408.379 |   406.845\n",
      "20 |  2.00 |    2.00 |   1.00 |        148.668 |   148.601\n",
      "20 |  2.00 |    2.00 |   2.00 |         63.468 |    63.604\n",
      "\n",
      "655.4 seconds\n"
     ]
    }
   ],
   "source": [
    "# UCB1 versus BayesUCB: for every setting on the grid below, sample num_runs\n",
    "# problem instances with means mu = mu0 + sigma0 * N(0, 1), evaluate both\n",
    "# algorithms under Gaussian and box reward noise, and print the differences\n",
    "# in mean regret (UCB1 minus BayesUCB). Per-run regrets are saved to fname.\n",
    "n = 1000  # passed to evaluate(); presumably the horizon -- confirm in bandit\n",
    "num_runs = 1000  # sampled problem instances per setting\n",
    "SEED = 0  # seed of the NumPy global generator that samples the instances\n",
    "\n",
    "# experimental grid\n",
    "K_values = [5, 10, 20]\n",
    "sigma_values = [0.5, 1.0, 2.0]\n",
    "mu0_max_values = [0.5, 1.0, 2.0]\n",
    "sigma0_values = [0.5, 1.0, 2.0]\n",
    "# derived from the grid so that the size of results cannot drift from the loops\n",
    "num_exp = (len(K_values) * len(sigma_values) *\n",
    "  len(mu0_max_values) * len(sigma0_values))\n",
    "\n",
    "# [class name looked up in globals(), parameters, ...]; the remaining fields\n",
    "# look like plotting attributes (color, line style, label) and are unused here\n",
    "algs = [\n",
    "  [\"UCB1\", {}, \"blue\", \"-\", \"UCB1\"],\n",
    "  [\"BayesUCB\", {}, \"red\", \"-\", \"BayesUCB\"]]\n",
    "\n",
    "# create the output directory up front so that the final save cannot fail\n",
    "# on a missing directory after the long-running experiment\n",
    "fname = \"Results/ucb1_vs_bayesucb.npy\"\n",
    "os.makedirs(os.path.dirname(fname), exist_ok=True)\n",
    "\n",
    "# fixes the sampled problem instances; NOTE(review): randomness inside\n",
    "# evaluate() may not be covered by this seed -- confirm in bandit\n",
    "np.random.seed(SEED)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "# one row per (setting, noise model, algorithm), one column per run\n",
    "num_noises = 2\n",
    "results = np.zeros((num_noises * len(algs) * num_exp, num_runs))\n",
    "result_ndx = 0\n",
    "\n",
    "print(\" K | sigma | mu0_max | sigma0 | Gaussian noise | Box noise\")\n",
    "for K in K_values:\n",
    "  for sigma in sigma_values:\n",
    "    for mu0_max in mu0_max_values:\n",
    "      # prior mean: the first arm has mean mu0_max, all other arms have mean 0\n",
    "      mu0 = np.zeros(K)\n",
    "      mu0[0] = mu0_max\n",
    "      for sigma0 in sigma0_values:\n",
    "        # Gaussian and box noise environments share the same sampled means\n",
    "        envs = []\n",
    "        envs2 = []\n",
    "        for run in range(num_runs):\n",
    "          mu = mu0 + sigma0 * np.random.randn(K)\n",
    "          envs.append(GaussBandit(mu, sigma))\n",
    "          envs2.append(BoxBandit(mu, sigma))\n",
    "\n",
    "        # algorithm parameters\n",
    "        algs[0][1] = {\"sigma\": sigma}\n",
    "        algs[1][1] = {\"mu0\": mu0, \"sigma0\": sigma0 * np.ones(K), \"sigma\": sigma}\n",
    "\n",
    "        # Gaussian noise first, then box noise; algorithms in the order of algs\n",
    "        for noise_envs in (envs, envs2):\n",
    "          for alg in algs:\n",
    "            alg_class = globals()[alg[0]]\n",
    "            regret, _ = evaluate(alg_class, alg[1], noise_envs, n, printout=False)\n",
    "            # sum over axis 0 to get one total regret per run\n",
    "            results[result_ndx, :] = regret.sum(axis=0)\n",
    "            result_ndx += 1\n",
    "\n",
    "        # the last four rows are (UCB1, BayesUCB) under Gaussian noise\n",
    "        # followed by (UCB1, BayesUCB) under box noise\n",
    "        print(\"%2d | %5.2f | %7.2f | %6.2f | %14.3f | %9.3f\" % (\n",
    "          K, sigma, mu0_max, sigma0,\n",
    "          results[result_ndx - 4].mean() - results[result_ndx - 3].mean(),\n",
    "          results[result_ndx - 2].mean() - results[result_ndx - 1].mean()))\n",
    "\n",
    "# every row of results must have been filled exactly once\n",
    "assert result_ndx == results.shape[0]\n",
    "\n",
    "print()\n",
    "print(\"%.1f seconds\" % (time.time() - start))\n",
    "\n",
    "np.save(fname, results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
