{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9011718a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from env import Bandit\n",
    "# NOTE(review): UGapE (used further down) is never imported by name;\n",
    "# presumably it comes from one of these star imports - consider importing\n",
    "# it explicitly so its origin is visible.\n",
    "from baseline import *\n",
    "from VarDE import *\n",
    "from plotexp import plot_experiment_results\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "\n",
    "# Seed both global RNGs (stdlib and NumPy) so that a Restart & Run All is\n",
    "# repeatable whichever of the two the imported env/agent code draws from.\n",
    "random.seed(42)\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1abd8444",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment setup: a five-armed Gaussian bandit in which arm 0 has mean 1\n",
    "# and std 2, and every other arm has mean 0 and std 1.\n",
    "means = [1, 0, 0, 0, 0]\n",
    "stds = [2, 1, 1, 1, 1]\n",
    "\n",
    "T = 1000  # handed to the agent as its T argument\n",
    "true_best = 0  # arm 0 is the one with the largest entry in `means`\n",
    "\n",
    "env = Bandit(\n",
    "    distribution='gaussian',\n",
    "    means=means,\n",
    "    stds=stds,\n",
    "    seed=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "474ed703",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the UGapE agent on the bandit environment, run it once, and print\n",
    "# whatever run() returns. `agent` is kept around for the plotting cell.\n",
    "agent = UGapE(env, T=T, a=8.0)\n",
    "run_output = agent.run()\n",
    "print(run_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b74b523b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global Matplotlib styling applied to figures created after this cell.\n",
    "\n",
    "# Render text through LaTeX using a serif family.\n",
    "plt.rcParams.update({\"text.usetex\": True, \"font.family\": \"serif\"})\n",
    "\n",
    "# Embed fonts as TrueType in PDF and PostScript output.\n",
    "plt.rcParams.update(dict.fromkeys((\"pdf.fonttype\", \"ps.fonttype\"), 42))\n",
    "\n",
    "# Sizes: one value for the title, a common value for everything else.\n",
    "plt.rcParams[\"axes.titlesize\"] = 22\n",
    "plt.rcParams.update(dict.fromkeys(\n",
    "    (\"axes.labelsize\", \"legend.fontsize\", \"xtick.labelsize\", \"ytick.labelsize\"),\n",
    "    17,\n",
    "))\n",
    "\n",
    "# Times-like fonts, matching paper. The whitespace inside the literal is\n",
    "# part of the preamble string, hence the unusual indentation below.\n",
    "plt.rcParams[\"text.latex.preamble\"] = r\"\"\"\n",
    "        \\usepackage{newtxtext}\n",
    "        \\usepackage{newtxmath}\n",
    "    \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "586eda37",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Plotting results: number of pulls of each arm over the recorded steps.\n",
    "# Column k of the array built from agent.n_history is the series for arm k.\n",
    "n_arr = np.array(agent.n_history)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 4))\n",
    "for k in range(env.K):\n",
    "    # Raw f-string: the LaTeX backslashes stay literal (in a plain f-string\n",
    "    # \\m and \\s are invalid escape sequences) while the arm index, mean and\n",
    "    # std are still interpolated, so labels follow the configured means/stds.\n",
    "    arm_label = rf'Arm {k} ($\\mu={means[k]}, \\sigma={stds[k]}$)'\n",
    "    ax.plot(n_arr[:, k], label=arm_label)\n",
    "ax.set_title(r'\\textsc{UGapE} - Number of Pulls per Arm Over Time')\n",
    "ax.set_xlabel('Time Steps')\n",
    "ax.set_ylabel('Number of Pulls')\n",
    "ax.legend(ncol=1)\n",
    "ax.grid(True, linestyle='--', alpha=0.4)\n",
    "fig.tight_layout()\n",
    "\n",
    "# savefig does not create missing directories, so make sure results/ exists\n",
    "# before writing the PDF (otherwise a fresh checkout fails here).\n",
    "output_path = Path('results') / 'ugape_pulls_per_arm.pdf'\n",
    "output_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "fig.savefig(output_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "varde",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
