{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fc614513",
   "metadata": {},
   "source": [
    "# Fixed Confidence Best Arm Identification in Bayesian Settings\n",
    "\n",
    "This code is the official implementation of second experiment (Table 2) of 'Fixed Confidence Best Arm Identification in Bayesian Settings.'\n",
    "\n",
    "To proceed, simply press 'shift+enter' for each cell. \n",
    "\n",
    "### Requirements\n",
    "The following cell includes all the required packages for this code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "1e84c8de-84c7-43b2-91d0-ccfea06e73a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import norm\n",
    "from scipy.integrate import quad\n",
    "from scipy.optimize import minimize, minimize_scalar\n",
    "import random\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1c87975",
   "metadata": {},
   "source": [
    "## Prior Distribution class\n",
    "\n",
    "This class is for implementing prior distribution $H$ and computing prior-dependent constants such as $L(H)$ efficiently. \n",
    "*prior\\_distribution* takes three main parameters:\n",
    "- *prior\\_mean*: corresponds to $(m_i)_{i=1}^k$, the mean vector of the prior distribution.\n",
    "- *prior\\_std*: corrsponds to $(\\xi_i)_{i=1}^k$, the vector of standard deviations.\n",
    "- *instance\\_std*: corresponds to $(\\sigma_i)_{i=1}^k$, the standard deviations of the reward distributions. \n",
    "\n",
    "We will mainly use two methods from this distribution:\n",
    "- sample\\_instance(): Sample $(\\mu_i)_{i=1}^k$ from the prior distribution $H$. \n",
    "- get_Delta_0($\\delta$): Compute $\\Delta_0$, which is defined in Algorithm 1. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "df8fb3fd-5d63-43f8-8bf0-03f28ea32cb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class prior_distribution:\n",
    "  def __init__(self, prior_mean, prior_std, instance_std):\n",
    "    self.K=np.size(prior_mean)\n",
    "    self.prior_mean = prior_mean\n",
    "    self.prior_std = prior_std\n",
    "    self.prior_cov_mat = np.diag(self.prior_std**2)\n",
    "    self.instance_std = instance_std\n",
    "    self._get_Lij()\n",
    "\n",
    "  def sample_instance(self):\n",
    "    return np.random.multivariate_normal(self.prior_mean, self.prior_cov_mat)\n",
    "\n",
    "  def _get_Lij(self):\n",
    "    self.Lij_whole = np.zeros((self.K, self.K)) #whole list of Lij\n",
    "    for i in range(self.K):\n",
    "      for j in range(self.K):\n",
    "        if i==j:\n",
    "          self.Lij_whole[i][j]=0\n",
    "        else:\n",
    "          integrand= lambda x: self._integrand(x,i,j)\n",
    "          self.Lij_whole[i][j]=quad(integrand, -np.inf, np.inf)[0]\n",
    "    self.Lij=np.sum(self.Lij_whole)\n",
    "\n",
    "  def _integrand(self,x,i,j):\n",
    "    product=1\n",
    "    for s in range(self.K):\n",
    "      if (s==i or s==j):\n",
    "        product=product*norm.pdf(self._standardize(x,s))\n",
    "      else:\n",
    "        product=product*norm.cdf(self._standardize(x,s))\n",
    "    return product\n",
    "\n",
    "  def _standardize(self,x,i):\n",
    "    return (x-self.prior_mean[i])/self.prior_std[i]\n",
    "\n",
    "  def get_Delta_0(self, delta):\n",
    "    return delta/self.Lij\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e59fab7",
   "metadata": {},
   "source": [
    "# Best-Arm-Identification Algorithm\n",
    "## 0) Parent Class\n",
    "Each BAI algorithm will have the following four methonds as its main method:\n",
    "- \\_\\_init\\_\\_: for initialization\n",
    "- sample(): A sampling rule $(A_t)_t$, which determines the arm to draw at round $t$ based on the previous history.\n",
    "- stopping\\_criterion(): when to stop the sampling\n",
    "- update(): Update the information after each sampling.\n",
    "- recommendation(): A decision rule $J$, which determines the arm the forecaster recommends based on his sampling history\n",
    "\n",
    "All algorithms take only two inputs:\n",
    "- prior_dist: corresponds to $H$ in our main paper. It will be *prior\\_distribution* instance.\n",
    "- delta: the confidence level $\\delta$ in our main paper. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f70c76f4-a927-434c-9fae-7419ff69ae1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ast import Pass\n",
    "class BAI_Algorithm:                    #Parent class for all algorithms\n",
    "  def __init__(self, prior_dist, delta):\n",
    "    Pass\n",
    "\n",
    "  def sample(self):\n",
    "    raise NotImplementedError\n",
    "\n",
    "  def stopping_criterion(self): # False for continue, True for stop\n",
    "    raise NotImplementedError\n",
    "\n",
    "  def update(self):\n",
    "    raise NotImplementedError\n",
    "\n",
    "  def recommendation(self):\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a150b521",
   "metadata": {},
   "source": [
    "## 1) Our Algorithm (Algorithm 1, Successive Elimination)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b5ad9471-b281-46dd-8e51-97b47e7277e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Elimination(BAI_Algorithm):\n",
    "  def __init__(self, prior_dist, delta):\n",
    "    self.prior_dist = prior_dist\n",
    "    self.K= prior_dist.K\n",
    "    self.delta=delta\n",
    "    self.Delta0=prior_dist.get_Delta_0(delta)\n",
    "    self.survived_arms = list(range(self.K))\n",
    "    self.need_elimination = False\n",
    "    self.Delta_safe = np.inf\n",
    "\n",
    "    #self.hist=[ [] for _ in range(self.K) ] # list of empty lists\n",
    "    self.n_list = np.zeros(self.K)\n",
    "    self.m_hat = np.zeros(self.K)\n",
    "\n",
    "\n",
    "  def sample(self):\n",
    "#    n_list = [len(self.hist[i]) for i in self.survived_arms]\n",
    "    chosen = np.argmin(self.n_list[self.survived_arms])\n",
    "    if chosen==len(self.survived_arms)-1:\n",
    "      self.need_elimination=True\n",
    "    return self.survived_arms[chosen]\n",
    "\n",
    "  def recommendation(self):\n",
    "    if len(self.survived_arms)==1:\n",
    "      return self.survived_arms[0]\n",
    "    elif self.Delta_safe<self.Delta0:\n",
    "      return random.choice(self.survived_arms)\n",
    "    else:\n",
    "      print('Error: recommendation should be done only after stopping criterion')\n",
    "      return -1\n",
    "\n",
    "  def stopping_criterion(self):\n",
    "    if self.need_elimination:\n",
    "      self.elim()\n",
    "\n",
    "    if len(self.survived_arms)==1:                # Note that these two parameters, self,survived_arms and self.Delta_safe only changes when elim() has been called.\n",
    "      return True\n",
    "    elif self.Delta_safe<self.Delta0:\n",
    "      return True\n",
    "    else:\n",
    "      return False\n",
    "\n",
    "\n",
    "  def elim(self):\n",
    "    ucbs=np.zeros(len(self.survived_arms))\n",
    "    lcbs=np.zeros(len(self.survived_arms))\n",
    "    next_survived=[]\n",
    "\n",
    "    for i in range(len(self.survived_arms)):                      # Compute UCB and LCB\n",
    "      ucbs[i], lcbs[i] = self.conf_bounds(self.m_hat,self.survived_arms[i])\n",
    "    lcbmax=np.max(lcbs)                               # Maximum LCB\n",
    "    ucbmax=np.max(ucbs)                               # Maximum UCB\n",
    "    for i in range(len(self.survived_arms)):                       # Include all arms which satisfies UCB>LCBMAX to a new basket\n",
    "      if ucbs[i]>lcbmax:\n",
    "        next_survived.append(self.survived_arms[i])\n",
    "\n",
    "    self.Delta_safe=ucbmax-lcbmax\n",
    "    self.survived_arms=next_survived\n",
    "    self.need_elimination = False\n",
    "\n",
    "  def conf_bounds(self,m_hat,i):\n",
    "    m_hat=self.m_hat[i]\n",
    "    n=self.n_list[i]\n",
    "    width=np.sqrt(\n",
    "        2*self.prior_dist.instance_std[i]**2\n",
    "        *np.log(12*self.K*n**2/self.delta**2/np.pi**2)/n\n",
    "    )\n",
    "    return m_hat+width, m_hat-width\n",
    "\n",
    "  def update(self, a, reward):\n",
    "    self.m_hat[a]=(self.m_hat[a]*self.n_list[a]+reward)/(self.n_list[a]+1)\n",
    "    self.n_list[a]=self.n_list[a]+1\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "322dcb1c",
   "metadata": {},
   "source": [
    "## 2) Modified Algorithm (Algorithm 2, NoElim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "537d7170-81e0-4e59-9021-fbdda73750eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoElimination(BAI_Algorithm):\n",
    "  def __init__(self, prior_dist, delta):\n",
    "    self.prior_dist = prior_dist\n",
    "    self.K= prior_dist.K\n",
    "    self.delta=delta\n",
    "    self.Delta0=prior_dist.get_Delta_0(delta)\n",
    "    self.survived_arms = list(range(self.K))\n",
    "    self.true_survived = list(range(self.K))\n",
    "    self.need_elimination = False\n",
    "    self.Delta_safe = np.inf\n",
    "\n",
    "    #self.hist=[ [] for _ in range(self.K) ] # list of empty lists\n",
    "    self.n_list = np.zeros(self.K)\n",
    "    self.m_hat = np.zeros(self.K)\n",
    "\n",
    "\n",
    "  def sample(self):\n",
    "#    n_list = [len(self.hist[i]) for i in self.survived_arms]\n",
    "    chosen = np.argmin(self.n_list[self.survived_arms])\n",
    "    if chosen==len(self.survived_arms)-1:\n",
    "      self.need_elimination=True\n",
    "    return self.survived_arms[chosen]\n",
    "\n",
    "  def recommendation(self):\n",
    "      return np.argmax(self.m_hat)\n",
    "\n",
    "  def stopping_criterion(self):\n",
    "    if self.need_elimination:\n",
    "      self.elim()\n",
    "\n",
    "    if len(self.true_survived)==1:                # Note that these two parameters, self,survived_arms and self.Delta_safe only changes when elim() has been called.\n",
    "      return True\n",
    "    elif self.Delta_safe<self.Delta0:\n",
    "      return True\n",
    "    else:\n",
    "      return False\n",
    "\n",
    "\n",
    "  def elim(self):\n",
    "    ucbs=np.zeros(len(self.survived_arms))\n",
    "    lcbs=np.zeros(len(self.survived_arms))\n",
    "    next_true_survived=[]\n",
    "\n",
    "    for i in range(len(self.survived_arms)):                      # Compute UCB and LCB\n",
    "      ucbs[i], lcbs[i] = self.conf_bounds(self.m_hat,self.survived_arms[i])\n",
    "    lcbmax=np.max(lcbs)                               # Maximum LCB\n",
    "    ucbmax=np.max(ucbs)                               # Maximum UCB\n",
    "    for i in range(len(self.survived_arms)):                       # Include all arms which satisfies UCB>LCBMAX to a new basket\n",
    "      if ucbs[i]>lcbmax:\n",
    "        next_true_survived.append(self.survived_arms[i])\n",
    "\n",
    "    self.Delta_safe=ucbmax-lcbmax\n",
    "    self.true_survived=next_true_survived\n",
    "    self.need_elimination = False\n",
    "\n",
    "  def conf_bounds(self,m_hat,i):\n",
    "    m_hat=self.m_hat[i]\n",
    "    n=self.n_list[i]\n",
    "    width=np.sqrt(\n",
    "        2*self.prior_dist.instance_std[i]**2\n",
    "        *np.log(12*self.K*n**2/self.delta**2/np.pi**2)/n\n",
    "    )\n",
    "    return m_hat+width, m_hat-width\n",
    "\n",
    "  def update(self, a, reward):\n",
    "    self.m_hat[a]=(self.m_hat[a]*self.n_list[a]+reward)/(self.n_list[a]+1)\n",
    "    self.n_list[a]=self.n_list[a]+1\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3b35b89",
   "metadata": {},
   "source": [
    "# General Experiment Design\n",
    "A class designed to effectively manage the experimental environment. This is mainly for guaranteeing same instance samples $(\\mu_i)_{i=1}^k$ for all algorithms. It takes the following major inputs:\n",
    "- prior\\_dist: corresponds to $H$ in our main paper. It will be *prior\\_distribution* instance.\n",
    "- $\\delta$: Error probability. \n",
    "- algolist: list of algorithms for this experiment. \n",
    "\n",
    "Methods:\n",
    "- single\\_experiment(mean, algoname): run single experiment for *algoname* with given mean $(\\mu_i)_{i=1}^k$. Returns stopping time and whether the prediction is correct or wrong. \n",
    "- monte\\_carlo\\_experiment(num\\_of\\_exp): repeat experiment *num\\_of\\_exp* times, but for each repetition all algorithms share same $(\\mu_i)_{i=1}^k$. Output is expected stopping time and success rate of each algorithm. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "cd4e562f-ace5-4a05-bbdd-c04618d5b2ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Experiment:\n",
    "  def __init__(self, prior_dist, delta, algolist):\n",
    "    self.delta=delta\n",
    "    self.algolist=algolist                                  # Strings of algorithm names, such as ['Elim', 'TTTS']\n",
    "    self.prior_dist=prior_dist\n",
    "    self.K=self.prior_dist.K\n",
    "    self.stopping_time_hist = []\n",
    "    self.success_hist = []\n",
    "    self.time_spent=[]\n",
    "\n",
    "\n",
    "  def monte_carlo_experiment(self, num_of_exp):\n",
    "    exp_stopping_time=np.zeros(len(self.algolist))\n",
    "    success_rate=np.zeros(len(self.algolist))\n",
    "    for i in range(num_of_exp):\n",
    "      mean=self.prior_dist.sample_instance()\n",
    "      for alg in range(len(self.algolist)):\n",
    "        algoname=self.algolist[alg]\n",
    "        start_time=time.time()\n",
    "        sample_stopping_time, sample_success = self.single_experiment(mean, algoname)\n",
    "        elapsed=time.time()-start_time\n",
    "        self.stopping_time_hist.append(sample_stopping_time)\n",
    "        self.success_hist.append(sample_success)\n",
    "        self.time_spent.append(elapsed)\n",
    "        exp_stopping_time[alg] =exp_stopping_time[alg]+sample_stopping_time/num_of_exp\n",
    "        success_rate[alg]      =success_rate[alg]+sample_success/num_of_exp\n",
    "    return exp_stopping_time, success_rate\n",
    "\n",
    "\n",
    "\n",
    "  def single_experiment(self, mean, algoname):\n",
    "    print(\"Starting an experiment - mean \")\n",
    "    print(mean)\n",
    "    alg=BAI_Algorithm(self.prior_dist,self.delta)\n",
    "    if algoname=='Elim':\n",
    "      print(\"Algo: Elim\")\n",
    "      alg = Elimination(self.prior_dist, self.delta)\n",
    "    elif algoname=='TTTS':\n",
    "      print(\"Algo: TTTS\")\n",
    "      alg = TTTS(self.prior_dist, self.delta)\n",
    "    elif algoname=='TTUCB':\n",
    "      print(\"Algo: TTUCB\")\n",
    "      alg = TTUCB(self.prior_dist, self.delta)\n",
    "    elif algoname=='NoElim':\n",
    "      print(\"Algo: NoElim\")\n",
    "      alg = NoElimination(self.prior_dist, self.delta)\n",
    "    answer=np.argmax(mean)\n",
    "    std=self.prior_dist.instance_std\n",
    "    stopping_time=0\n",
    "    while not alg.stopping_criterion():\n",
    "      a=alg.sample()\n",
    "      reward = np.random.normal(mean[a],std[a])\n",
    "      alg.update(a, reward)\n",
    "      stopping_time=stopping_time+1\n",
    "    print('Final stopping time: %d'%stopping_time)\n",
    "    if answer==alg.recommendation():\n",
    "      return stopping_time,1\n",
    "    else:\n",
    "      return stopping_time,0\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf9d6bac",
   "metadata": {},
   "source": [
    "# Start of the Main Code\n",
    "Parameters\n",
    "- $K$: number of arms.\n",
    "- *prior\\_mean*: corresponds to $(m_i)_{i=1}^k$, the mean vector of the prior distribution.\n",
    "- *prior\\_std*: corrsponds to $(\\xi_i)_{i=1}^k$, the vector of standard deviations.\n",
    "- *instance\\_std*: corresponds to $(\\sigma_i)_{i=1}^k$, the standard deviations of the reward distributions. \n",
    "- $\\delta$: error probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "d8e9c4a3-21a4-4c86-aa65-bbfd9154407b",
   "metadata": {},
   "outputs": [],
   "source": [
    "K=10\n",
    "prior_mean=np.zeros(K)\n",
    "prior_std=np.ones(K)\n",
    "instance_std=np.ones(K)\n",
    "delta=0.01"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a719dec9",
   "metadata": {},
   "source": [
    "Create prior distribution $H$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a29c2f9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "prior_dist=prior_distribution(prior_mean, prior_std,instance_std)\n",
    "prior_dist.get_Delta_0(delta)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90d97c0f",
   "metadata": {},
   "source": [
    "Experiment instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "94fdad53-7721-45f6-81c7-474ec07fb01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "myExp=Experiment(prior_dist, delta, ['NoElim','Elim'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3520ec0",
   "metadata": {},
   "source": [
    "Run the experiment 1000 times.\n",
    "- res: average stopping time variable\n",
    "\n",
    "- ult: success rate variable- ult: success rate variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ac465cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "res, ult= myExp.monte_carlo_experiment(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c20b0357",
   "metadata": {},
   "source": [
    "### Stopping time variables\n",
    "\n",
    "- ElimStop: Stopping time history for our algorithm\n",
    "- NoElimStop: Stopping time history for NoElim\n",
    "\n",
    "### Computation time variables\n",
    "\n",
    "- ElimCompTime: Computation time history for our algorithm\n",
    "- NoElimComp: Computation time history for NoElim\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cc91ff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "lnow=1000\n",
    "ElimCompTime = [myExp.time_spent[2*i+1] for i in range(lnow)]\n",
    "NoElimCompTime = [myExp.time_spent[2*i] for i in range(lnow)]\n",
    "print(np.mean(NoElimCompTime))\n",
    "print(np.mean(ElimCompTime))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78a87d58",
   "metadata": {},
   "outputs": [],
   "source": [
    "ElimStop = [myExp.stopping_time_hist[2*i+1] for i in range(lnow)]\n",
    "NoElimStop = [myExp.stopping_time_hist[2*i] for i in range(lnow)]\n",
    "np.max(ElimStop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12bb84bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.max(NoElimStop))\n",
    "print(np.max(ElimStop))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
