{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1622693515966,
     "user": {
      "displayName": "Branislav Kveton",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhjPg1PJKD4ixGqLaZ4X3dpkI6w6dO_eTj8z63Y=s64",
      "userId": "09350298559467540088"
     },
     "user_tz": 0
    },
    "id": "dUrLMhAcZYDY",
    "jupyter": {
     "source_hidden": true
    },
    "outputId": "31bb0c27-489c-4d61-85e8-6002f5252ed5",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "python 3.9.15\n",
      "matplotlib 3.5.1\n",
      "8 joblib CPUs\n"
     ]
    }
   ],
   "source": [
    "# Imports and defaults\n",
    "from bandit import *\n",
    "import joblib\n",
    "from joblib import Parallel, delayed\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.stats\n",
    "import time\n",
    "\n",
    "mpl.style.use(\"classic\")\n",
    "mpl.rcParams[\"figure.figsize\"] = [5, 3]\n",
    "\n",
    "mpl.rcParams[\"axes.linewidth\"] = 0.75\n",
    "mpl.rcParams[\"errorbar.capsize\"] = 3\n",
    "mpl.rcParams[\"figure.facecolor\"] = \"w\"\n",
    "mpl.rcParams[\"grid.linewidth\"] = 0.75\n",
    "mpl.rcParams[\"lines.linewidth\"] = 0.75\n",
    "mpl.rcParams[\"patch.linewidth\"] = 0.75\n",
    "mpl.rcParams[\"xtick.major.size\"] = 3\n",
    "mpl.rcParams[\"ytick.major.size\"] = 3\n",
    "\n",
    "mpl.rcParams[\"pdf.fonttype\"] = 42\n",
    "mpl.rcParams[\"ps.fonttype\"] = 42\n",
    "mpl.rcParams[\"font.size\"] = 10\n",
    "mpl.rcParams[\"axes.titlesize\"] = \"medium\"\n",
    "mpl.rcParams[\"legend.fontsize\"] = \"medium\"\n",
    "\n",
    "import platform\n",
    "print(\"python %s\" % platform.python_version())\n",
    "print(\"matplotlib %s\" % mpl.__version__)\n",
    "print(\"%d joblib CPUs\" % joblib.cpu_count())\n",
    "\n",
    "def linestyle2dashes(style):\n",
    "  if style == \"--\":\n",
    "    return (3, 3)\n",
    "  elif style == \":\":\n",
    "    return (0.5, 2.5)\n",
    "  else:\n",
    "    return (None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mu0_max | sigma0 |   UCB1 | Log bound | BayesUCB | Log bound | Sqrt bound\n",
      "   1.00 |   1.00 |   87.8 |     743.8 |     75.3 |     727.3 |     1345.5\n",
      "   1.00 |   0.90 |   91.2 |     812.5 |     77.7 |     794.0 |     1330.7\n",
      "   1.00 |   0.80 |   95.3 |     847.1 |     79.2 |     825.6 |     1312.5\n",
      "   1.00 |   0.70 |   99.2 |     920.9 |     79.8 |     895.5 |     1289.5\n",
      "   1.00 |   0.60 |  104.0 |     974.5 |     79.2 |     943.4 |     1259.5\n",
      "   1.00 |   0.50 |  107.5 |    1031.0 |     74.2 |     990.3 |     1218.9\n",
      "   1.00 |   0.40 |  110.1 |     957.2 |     59.4 |     897.8 |     1160.8\n",
      "   1.00 |   0.30 |  112.2 |     787.7 |     30.5 |     687.4 |     1071.6\n",
      "   1.00 |   0.20 |  109.9 |     560.1 |      1.3 |     337.1 |      918.9\n",
      "   1.00 |   0.10 |  106.7 |     508.5 |      0.0 |       3.7 |      615.8\n",
      "   1.00 |   0.05 |  105.6 |     499.4 |      0.0 |       0.0 |      351.0\n",
      "   1.50 |   1.00 |   83.4 |     662.5 |     69.0 |     644.4 |     1345.5\n",
      "   2.00 |   1.00 |   77.3 |     520.2 |     59.3 |     499.5 |     1345.5\n",
      "   2.50 |   1.00 |   70.5 |     411.5 |     47.7 |     387.6 |     1345.5\n",
      "   3.00 |   1.00 |   63.7 |     302.7 |     35.6 |     275.1 |     1345.5\n",
      "   3.50 |   1.00 |   59.0 |     226.8 |     24.1 |     195.1 |     1345.5\n",
      "   4.00 |   1.00 |   55.7 |     167.7 |     14.1 |     131.7 |     1345.5\n",
      "   4.50 |   1.00 |   53.9 |     132.6 |      7.0 |      92.2 |     1345.5\n",
      "   5.00 |   1.00 |   53.9 |     113.1 |      3.0 |      68.6 |     1345.5\n"
     ]
    }
   ],
   "source": [
    "def mab_bound(mu0, sigma0, sigma, n, envs, kind=\"log\"):\n",
    "  K = mu0.size\n",
    "  gap_cap = np.log(n) / n\n",
    "\n",
    "  bound = np.zeros(len(envs))\n",
    "  if \"log\" in kind:\n",
    "    for env_ndx, env in enumerate(envs):\n",
    "      best_arm = np.argmax(env.mu)\n",
    "      for i in range(K):\n",
    "        if i != best_arm:\n",
    "          gap = np.maximum(env.mu[best_arm] - env.mu[i], gap_cap)\n",
    "          if \"ucb1\" in kind:\n",
    "            bound[env_ndx] += 8 * np.square(sigma) * np.log(1 / delta) / gap\n",
    "          else:\n",
    "            bound[env_ndx] += max(8 * np.square(sigma) * np.log(1 / delta) / gap -\n",
    "              np.square(sigma) * gap / np.square(sigma0), 0)\n",
    "  elif \"sqrt\" in kind:\n",
    "    c = np.square(sigma) * K / np.square(sigma0)\n",
    "    bound = 4 * np.sqrt(2 * np.square(sigma) * K * np.log(n)) * (np.sqrt(n + c) - np.sqrt(c)) * np.ones(len(envs))\n",
    "\n",
    "  return bound\n",
    "\n",
    "\n",
    "K = 10\n",
    "n = 1000\n",
    "num_runs = 10000\n",
    "\n",
    "mu0 = np.zeros(K)\n",
    "mu0[0] = 1.0\n",
    "sigma0 = 1.0\n",
    "sigma = 1.0\n",
    "delta = 1 / n\n",
    "\n",
    "mu0_max_list = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]\n",
    "sigma0_list = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]\n",
    "\n",
    "algs = [\n",
    "  [\"BayesUCB\", {}],\n",
    "  [\"UCB1\", {}],\n",
    "  [\"BayesUCB_log_bound\", {}],\n",
    "  [\"UCB1_log_bound\", {}],\n",
    "  [\"Sqrt_bound\", {}]]\n",
    "\n",
    "print(\"mu0_max | sigma0 |   UCB1 | Log bound | BayesUCB | Log bound | Sqrt bound\")\n",
    "for mu0_max in mu0_max_list:\n",
    "  mu0[0] = mu0_max\n",
    "  for sigma0 in sigma0_list:\n",
    "    if (mu0_max == 1.0) or (sigma0 == 1.0):\n",
    "      envs = []\n",
    "      for run in range(num_runs):\n",
    "        mu = mu0 + sigma0 * np.random.randn(K)\n",
    "        envs.append(GaussBandit(mu, sigma))\n",
    "\n",
    "      # algorithm parameters\n",
    "      algs[0][1] = {\"mu0\": mu0, \"sigma0\": sigma0 * np.ones(K), \"sigma\": sigma}\n",
    "      algs[1][1] = {\"sigma\": sigma}\n",
    "\n",
    "      results = {}\n",
    "      for alg in algs:\n",
    "        if alg[0] == \"UCB1_log_bound\":\n",
    "          regret = mab_bound(mu0, sigma0, sigma, n, envs, \"ucb1-log\")\n",
    "        elif alg[0] == \"BayesUCB_log_bound\":\n",
    "          regret = mab_bound(mu0, sigma0, sigma, n, envs, \"log\")\n",
    "        elif alg[0] == \"Sqrt_bound\":\n",
    "          regret = mab_bound(mu0, sigma0, sigma, n, envs, \"sqrt\")\n",
    "        else:\n",
    "          alg_class = globals()[alg[0]]\n",
    "          regret, _ = evaluate(alg_class, alg[1], envs, n, printout=False)\n",
    "          regret = regret.sum(axis=0)\n",
    "        results[alg[0]] = regret\n",
    "\n",
    "      print(\"%7.2f | %6.2f | %6.1f | %9.1f | %8.1f | %9.1f | %10.1f\" % (\n",
    "        mu0_max, sigma0,\n",
    "        results[\"UCB1\"].mean(),\n",
    "        results[\"UCB1_log_bound\"].mean(),\n",
    "        results[\"BayesUCB\"].mean(),\n",
    "        results[\"BayesUCB_log_bound\"].mean(),\n",
    "        results[\"Sqrt_bound\"].mean()))\n",
    "\n",
    "      fname = \"Results/mab_mu0_max=%.3f_sigma0=%.3f.npy\" % (mu0_max, sigma0)\n",
    "      np.save(fname, results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "theta0_max | sigma0 | BayesUCB | Log bound | Sqrt bound\n",
      "      1.00 |   1.00 |     85.4 |   18246.6 |     1918.2\n",
      "      1.00 |   0.90 |     84.3 |   16460.9 |     1823.3\n",
      "      1.00 |   0.80 |     82.7 |   14267.6 |     1727.5\n",
      "      1.00 |   0.70 |     79.1 |   11687.9 |     1629.8\n",
      "      1.00 |   0.60 |     75.7 |    9478.1 |     1528.5\n",
      "      1.00 |   0.50 |     69.7 |    7921.9 |     1420.3\n",
      "      1.00 |   0.40 |     59.6 |    5168.3 |     1299.2\n",
      "      1.00 |   0.30 |     43.3 |    2628.5 |     1152.8\n",
      "      1.00 |   0.20 |     19.3 |    1107.1 |      952.4\n",
      "      1.00 |   0.10 |      0.4 |     271.8 |      620.5\n",
      "      1.00 |   0.05 |      0.0 |      77.1 |      351.4\n",
      "      2.00 |   1.00 |     58.0 |    8643.2 |     1918.2\n",
      "      3.00 |   1.00 |     33.0 |    3294.1 |     1918.2\n",
      "      4.00 |   1.00 |     12.4 |    1618.1 |     1918.2\n",
      "      5.00 |   1.00 |      3.0 |    1227.7 |     1918.2\n",
      "      6.00 |   1.00 |      0.5 |    1055.2 |     1918.2\n",
      "      7.00 |   1.00 |      0.1 |     954.9 |     1918.2\n",
      "      8.00 |   1.00 |      0.0 |     871.4 |     1918.2\n",
      "      9.00 |   1.00 |      0.0 |     805.7 |     1918.2\n",
      "     10.00 |   1.00 |      0.0 |     749.9 |     1918.2\n"
     ]
    }
   ],
   "source": [
    "def linear_bound(theta0, sigma0, sigma, n, envs, kind=\"log\"):\n",
    "  d = theta0.size\n",
    "  gap_cap = np.log(n) / n\n",
    "\n",
    "  c = np.square(sigma0) / np.square(sigma)\n",
    "  bound = np.zeros(len(envs))\n",
    "  if \"log\" in kind:\n",
    "    for env_ndx, env in enumerate(envs):\n",
    "      sorted_mu = np.sort(env.mu)\n",
    "      minimum_gap = np.maximum(sorted_mu[-1] - sorted_mu[-2], gap_cap)\n",
    "      bound[env_ndx] = 1 / minimum_gap\n",
    "    bound *= 8 * np.square(sigma0) * d * \\\n",
    "      np.log(1 + c * n / d) * np.log(1 / delta) / np.log(1 + c)\n",
    "  elif \"sqrt\" in kind:\n",
    "    bound = 2 * np.sqrt(2 * np.square(sigma0) * d * n *\n",
    "      np.log(1 + c * n / d) * np.log(1 / delta) / np.log(1 + c)) * np.ones(len(envs))\n",
    "\n",
    "  return bound\n",
    "\n",
    "\n",
    "d = 10\n",
    "K = 30\n",
    "n = 1000\n",
    "num_runs = 10000\n",
    "\n",
    "theta0 = - np.ones(d)\n",
    "theta0[0] = 1.0\n",
    "Sigma0 = np.eye(d)\n",
    "sigma = 1.0\n",
    "delta = 1 / n\n",
    "\n",
    "theta0_max_list = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]\n",
    "sigma0_list = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]\n",
    "\n",
    "algs = [\n",
    "  [\"LinBayesUCB\", {}],\n",
    "  [\"LinBayesUCB_log_bound\", {}],\n",
    "  [\"Sqrt_bound\", {}]]\n",
    "\n",
    "print(\"theta0_max | sigma0 | BayesUCB | Log bound | Sqrt bound\")\n",
    "for theta0_max in theta0_max_list:\n",
    "  theta0[0] = theta0_max\n",
    "  for sigma0 in sigma0_list:\n",
    "    Sigma0 = sigma0 * np.eye(d)\n",
    "\n",
    "    if (theta0_max == 1.0) or (sigma0 == 1.0):\n",
    "      envs = []\n",
    "      for run in range(num_runs):\n",
    "        # sample model parameter\n",
    "        theta = np.random.multivariate_normal(theta0, Sigma0)\n",
    "        # sample arm features from the positive orthant of a unit sphere\n",
    "        X = np.abs(np.random.randn(K, d))\n",
    "        X /= np.linalg.norm(X, axis=-1)[:, np.newaxis]\n",
    "        X[: d, :] = np.eye(d)  # canonical basis\n",
    "        # initialize bandit environment\n",
    "        envs.append(LinBandit(X, theta, sigma))\n",
    "\n",
    "      # algorithm parameters\n",
    "      algs[0][1] = {\"theta0\": theta0, \"Sigma0\": Sigma0, \"sigma\": sigma}\n",
    "\n",
    "      results = {}\n",
    "      for alg in algs:\n",
    "        if alg[0] == \"LinBayesUCB_log_bound\":\n",
    "          regret = linear_bound(theta0, sigma0, sigma, n, envs, \"log\")\n",
    "        elif alg[0] == \"Sqrt_bound\":\n",
    "          regret = linear_bound(theta0, sigma0, sigma, n, envs, \"sqrt\")\n",
    "        else:\n",
    "          alg_class = globals()[alg[0]]\n",
    "          regret, _ = evaluate(alg_class, alg[1], envs, n, printout=False)\n",
    "          regret = regret.sum(axis=0)\n",
    "        results[alg[0]] = regret\n",
    "\n",
    "      print(\"%10.2f | %6.2f | %8.1f | %9.1f | %10.1f\" % (\n",
    "        theta0_max, sigma0,\n",
    "        results[\"LinBayesUCB\"].mean(),\n",
    "        results[\"LinBayesUCB_log_bound\"].mean(),\n",
    "        results[\"Sqrt_bound\"].mean()))\n",
    "\n",
    "      fname = \"Results/linear_theta0_max=%.3f_sigma0=%.3f.npy\" % (theta0_max, sigma0)\n",
    "      np.save(fname, results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
