{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fixed Confidence Best Arm Identification in Bayesian Settings\n",
    "\n",
    "This code is the official implementation of 'Fixed Confidence Best Arm Identification in Bayesian Settings.'\n",
    "\n",
    "To proceed, simply press 'shift+enter' for each cell. \n",
    "\n",
    "### Requirements\n",
    "The following cell includes all the required packages for this code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 126,
     "status": "ok",
     "timestamp": 1702800713658,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "eaxTLEc5WWBD"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import norm\n",
    "from scipy.integrate import quad\n",
    "from scipy.optimize import minimize, minimize_scalar\n",
    "import random\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prior Distribution class\n",
    "\n",
    "This class is for implementing prior distribution $H$ and computing prior-dependent constants such as $L(H)$ efficiently. \n",
    "*prior\\_distribution* takes three main parameters:\n",
    "- *prior\\_mean*: corresponds to $(m_i)_{i=1}^k$, the mean vector of the prior distribution.\n",
    "- *prior\\_std*: corrsponds to $(\\xi_i)_{i=1}^k$, the vector of standard deviations.\n",
    "- *instance\\_std*: corresponds to $(\\sigma_i)_{i=1}^k$, the standard deviations of the reward distributions. \n",
    "\n",
    "We will mainly use two methods from this distribution:\n",
    "- sample\\_instance(): Sample $(\\mu_i)_{i=1}^k$ from the prior distribution $H$. \n",
    "- get_Delta_0($\\delta$): Compute $\\Delta_0$, which is defined in Algorithm 1. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 196,
     "status": "ok",
     "timestamp": 1702800715193,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "Xedzt0j2dqK5"
   },
   "outputs": [],
   "source": [
    "class prior_distribution:\n",
    "  def __init__(self, prior_mean, prior_std, instance_std):\n",
    "    self.K=np.size(prior_mean)\n",
    "    self.prior_mean = prior_mean\n",
    "    self.prior_std = prior_std\n",
    "    self.prior_cov_mat = np.diag(self.prior_std**2)\n",
    "    self.instance_std = instance_std\n",
    "    self._get_Lij()\n",
    "\n",
    "  def sample_instance(self):\n",
    "    return np.random.multivariate_normal(self.prior_mean, self.prior_cov_mat)\n",
    "\n",
    "  def _get_Lij(self):\n",
    "    self.Lij_whole = np.zeros((self.K, self.K)) #whole list of Lij\n",
    "    for i in range(self.K):\n",
    "      for j in range(self.K):\n",
    "        if i==j:\n",
    "          self.Lij_whole[i][j]=0\n",
    "        else:\n",
    "          integrand= lambda x: self._integrand(x,i,j)\n",
    "          self.Lij_whole[i][j]=quad(integrand, -np.inf, np.inf)[0]\n",
    "    self.Lij=np.sum(self.Lij_whole)\n",
    "\n",
    "  def _integrand(self,x,i,j):\n",
    "    product=1\n",
    "    for s in range(self.K):\n",
    "      if (s==i or s==j):\n",
    "        product=product*norm.pdf(self._standardize(x,s))\n",
    "      else:\n",
    "        product=product*norm.cdf(self._standardize(x,s))\n",
    "    return product\n",
    "\n",
    "  def _standardize(self,x,i):\n",
    "    return (x-self.prior_mean[i])/self.prior_std[i]\n",
    "\n",
    "  def get_Delta_0(self, delta):\n",
    "    return delta/(4*self.Lij)\n",
    "\n"
   ]
  },
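  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The constant computed by `_get_Lij` can be checked by hand in a small case. The sketch below is not part of the official implementation: it mirrors `_integrand` using only the standard library, replacing `scipy.integrate.quad` with a plain trapezoid rule. The names `pairwise_constant` and `delta_0`, the integration range of 12 prior standard deviations, and the step count are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def normal_pdf(z):\n",
    "    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)\n",
    "\n",
    "def normal_cdf(z):\n",
    "    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))\n",
    "\n",
    "def pairwise_constant(i, j, prior_mean, prior_std, steps=20000):\n",
    "    # Integration range: 12 prior standard deviations around every prior mean (assumption).\n",
    "    lo = min(m - 12.0 * s for m, s in zip(prior_mean, prior_std))\n",
    "    hi = max(m + 12.0 * s for m, s in zip(prior_mean, prior_std))\n",
    "\n",
    "    def integrand(x):\n",
    "        # pdf factor for arms i and j, cdf factor for every other arm (as in _integrand).\n",
    "        value = 1.0\n",
    "        for s in range(len(prior_mean)):\n",
    "            z = (x - prior_mean[s]) / prior_std[s]\n",
    "            value *= normal_pdf(z) if s in (i, j) else normal_cdf(z)\n",
    "        return value\n",
    "\n",
    "    # Composite trapezoid rule on [lo, hi].\n",
    "    h = (hi - lo) / steps\n",
    "    total = 0.5 * (integrand(lo) + integrand(hi))\n",
    "    for k in range(1, steps):\n",
    "        total += integrand(lo + k * h)\n",
    "    return total * h\n",
    "\n",
    "def delta_0(delta, prior_mean, prior_std):\n",
    "    # Sum over all ordered pairs (i, j) with i != j, then delta / (4 * sum) as in get_Delta_0.\n",
    "    n_arms = len(prior_mean)\n",
    "    total = sum(pairwise_constant(i, j, prior_mean, prior_std)\n",
    "                for i in range(n_arms) for j in range(n_arms) if i != j)\n",
    "    return delta / (4.0 * total)\n",
    "\n",
    "demo_L12 = pairwise_constant(0, 1, [0.0, 0.0], [1.0, 1.0])\n",
    "demo_delta0 = delta_0(0.1, [0.0, 0.0], [1.0, 1.0])\n",
    "print(demo_L12, demo_delta0)\n",
    "```\n",
    "\n",
    "For two arms with standard normal priors the integrand is $\\phi(x)^2$, whose integral is $1/(2\\sqrt{\\pi}) \\approx 0.2821$, so `demo_L12` should agree with that closed form."
   ]
  },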
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XnOWbW4-2H6H"
   },
   "source": [
    "# Best-Arm-Identification Algorithm\n",
    "## 0) Parent Class\n",
    "Each BAI algorithm will have the following four methonds as its main method:\n",
    "- \\_\\_init\\_\\_: for initialization\n",
    "- sample(): A sampling rule $(A_t)_t$, which determines the arm to draw at round $t$ based on the previous history.\n",
    "- stopping\\_criterion(): when to stop the sampling\n",
    "- update(): Update the information after each sampling.\n",
    "- recommendation(): A decision rule $J$, which determines the arm the forecaster recommends based on his sampling history\n",
    "\n",
    "All algorithms take only two inputs:\n",
    "- prior_dist: corresponds to $H$ in our main paper. It will be *prior\\_distribution* instance.\n",
    "- delta: the confidence level $\\delta$ in our main paper. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 153,
     "status": "ok",
     "timestamp": 1702800717566,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "P7X4dR0Qd3YL"
   },
   "outputs": [],
   "source": [
    "from ast import Pass\n",
    "class BAI_Algorithm:                    #Parent class for all algorithms\n",
    "  def __init__(self, prior_dist, delta):\n",
    "    Pass\n",
    "\n",
    "  def sample(self):\n",
    "    raise NotImplementedError\n",
    "\n",
    "  def stopping_criterion(self): # False for continue, True for stop\n",
    "    raise NotImplementedError\n",
    "\n",
    "  def update(self):\n",
    "    raise NotImplementedError\n",
    "\n",
    "  def recommendation(self):\n",
    "    raise NotImplementedError"
   ]
  },
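  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the interface concrete, here is a minimal sketch of how the methods interact in a sampling loop. `RoundRobinBudget` and `run` are hypothetical names, and the toy rule (pull arms in turn, stop after a fixed number of pulls) is only a stand-in to show the call order; it is not a fixed-confidence algorithm and not one of the algorithms compared in this notebook.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "class RoundRobinBudget:\n",
    "    # Hypothetical toy algorithm exposing the same method names as BAI_Algorithm.\n",
    "    def __init__(self, n_arms, budget):\n",
    "        self.n_arms = n_arms\n",
    "        self.budget = budget\n",
    "        self.counts = [0] * n_arms\n",
    "        self.means = [0.0] * n_arms\n",
    "\n",
    "    def sample(self):\n",
    "        # Sampling rule: cycle through the arms.\n",
    "        return sum(self.counts) % self.n_arms\n",
    "\n",
    "    def stopping_criterion(self):\n",
    "        # True means stop, False means continue.\n",
    "        return sum(self.counts) >= self.budget\n",
    "\n",
    "    def update(self, arm, reward):\n",
    "        # Incremental update of the sample mean of the pulled arm.\n",
    "        self.counts[arm] += 1\n",
    "        self.means[arm] += (reward - self.means[arm]) / self.counts[arm]\n",
    "\n",
    "    def recommendation(self):\n",
    "        # Decision rule: arm with the largest sample mean.\n",
    "        return max(range(self.n_arms), key=lambda a: self.means[a])\n",
    "\n",
    "def run(alg, true_means, noise_std, rng):\n",
    "    # Check the stopping rule, draw an arm, observe a Gaussian reward, update.\n",
    "    pulls = 0\n",
    "    while not alg.stopping_criterion():\n",
    "        arm = alg.sample()\n",
    "        alg.update(arm, rng.gauss(true_means[arm], noise_std))\n",
    "        pulls += 1\n",
    "    return pulls, alg.recommendation()\n",
    "\n",
    "demo_rng = random.Random(0)\n",
    "demo_pulls, demo_arm = run(RoundRobinBudget(3, 30), [0.0, 0.5, 1.0], 1.0, demo_rng)\n",
    "print(demo_pulls, demo_arm)\n",
    "```\n",
    "\n",
    "The loop has the same shape as the one in `single_experiment` further down: the stopping rule is checked before every draw."
   ]
  },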
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1) Top-Two Thompson Sampling\n",
    "\n",
    "is a class that implements the algorithm devised in the paper ['Simple bayesian algorithms for best arm identification'](https://proceedings.mlr.press/v49/russo16.html)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "executionInfo": {
     "elapsed": 151,
     "status": "ok",
     "timestamp": 1702800955389,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "WxtRkL315nsK"
   },
   "outputs": [],
   "source": [
    "class TTTS(BAI_Algorithm):              #Top Two Thompson Sampling - UNDER PROGRESS\n",
    "  def __init__(self, prior_dist, delta):\n",
    "    self.beta = 0.5\n",
    "    self.prior_dist = prior_dist\n",
    "    self.K= prior_dist.K\n",
    "    self.delta = delta\n",
    "    self.C = 1\n",
    "    self.alpha=1\n",
    "    self.hist=[ [] for _ in range(self.K) ] # list of empty lists\n",
    "    self.pos_mean=np.zeros(self.K)\n",
    "    self.pos_std=np.zeros(self.K)\n",
    "    self.n_list=np.zeros(self.K)\n",
    "    self.m_hat = np.zeros(self.K)\n",
    "\n",
    "  def sample(self):\n",
    "    if np.min(self.n_list)==0:\n",
    "      return np.argmin(self.n_list)\n",
    "    alpha=np.zeros(self.K)\n",
    "    index_sample=np.random.binomial(1,self.beta)\n",
    "    alpha=self.posterior_sample()\n",
    "    bestI=np.argmax(alpha)\n",
    "\n",
    "    final_index=bestI         #With probability beta, return best index from the sample\n",
    "\n",
    "    if index_sample==0:       #With prob. 1-beta, repeat sampling until new best index appears\n",
    "      bestJ=bestI\n",
    "      while bestI==bestJ:\n",
    "        alpha=self.posterior_sample()\n",
    "        _mask = np.zeros(self.K, dtype=bool)\n",
    "        _mask[bestI] = True\n",
    "        masked = np.ma.array(alpha, mask=_mask)\n",
    "        bestJ=np.argmax(masked)\n",
    "      final_index=bestJ\n",
    "    return final_index\n",
    "\n",
    "  def posterior_sample(self):               #Function that calculates posterior distribution\n",
    "\n",
    "    sample_result=np.zeros(self.K)\n",
    "    for i in range(self.K):\n",
    "      sample_result[i]=np.random.normal(self.pos_mean[i],self.pos_std[i])\n",
    "    return sample_result\n",
    "\n",
    "\n",
    "  def update(self, a, reward):\n",
    "                          #Posterior update\n",
    "    self.m_hat[a] = (reward+self.n_list[a]*self.m_hat[a])/(self.n_list[a]+1)               #sample mean\n",
    "    self.n_list[a] = self.n_list[a]+1\n",
    "    m_p   = self.prior_dist.prior_mean[a]       #prior mean\n",
    "    var_a = self.prior_dist.instance_std[a]**2  # variance of the arm\n",
    "    var_p = self.prior_dist.prior_std[a]**2     # variance of the prior distribution\n",
    "    self.pos_mean[a] = (m_p*var_p + self.n_list[a]*self.m_hat[a]*var_a)/(self.n_list[a]*var_a + var_p)\n",
    "    self.pos_std[a]  = np.sqrt(var_a*var_p/(self.n_list[a]*var_a + var_p))\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "  def stopping_criterion(self):\n",
    "    T=np.sum(self.n_list)\n",
    "    if np.min(self.n_list)==0:\n",
    "      return False\n",
    "    i_max=np.argmax(self.m_hat)\n",
    "    W = np.zeros(self.K)\n",
    "    for j in range(self.K):\n",
    "      if i_max==j:\n",
    "          W[j]=np.inf\n",
    "      else:\n",
    "        tempi=self.n_list[i_max]/self.prior_dist.instance_std[i_max]**2\n",
    "        tempj=self.n_list[j]/self.prior_dist.instance_std[j]**2\n",
    "        infx=(tempi*self.m_hat[i_max]+tempj*self.m_hat[j])/(tempi+tempj)\n",
    "        W[j]=self.n_list[i_max]*self.Kinf_neg(i_max,infx)+self.n_list[j]*self.Kinf_pos(j,infx)\n",
    "\n",
    "    minWij=np.min(W)\n",
    "    threshold = np.log(self.C*(T**self.alpha) / self.delta) ####################### Need to be adjusted later\n",
    "\n",
    "    if minWij>threshold:\n",
    "      return True\n",
    "    else:\n",
    "      return False\n",
    "\n",
    "\n",
    "\n",
    "  def Kinf_neg(self, i, x):\n",
    "    if x>self.m_hat[i]:\n",
    "      return 0\n",
    "    else:\n",
    "      return self.KL_div(i,x)\n",
    "\n",
    "  def Kinf_pos(self,j,x):\n",
    "    if x<self.m_hat[j]:\n",
    "      return 0\n",
    "    else:\n",
    "      return self.KL_div(j,x)\n",
    "\n",
    "  def KL_div(self,i,x):\n",
    "    return (x-self.m_hat[i])**2/(2*prior_dist.instance_std[i]**2)\n",
    "\n",
    "  def recommendation(self):\n",
    "    return np.argmax(self.m_hat)\n",
    "\n"
   ]
  },
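  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The posterior that TTTS samples from is the standard conjugate update for a Gaussian prior with Gaussian rewards of known variance. As a hedged, standalone sketch (the function name `gaussian_posterior` and its argument names are assumptions, not part of the official code), the update can be written in terms of precisions:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def gaussian_posterior(prior_mean, prior_var, noise_var, n_pulls, sample_mean):\n",
    "    # Precisions (inverse variances) of the prior and of the n observations add up.\n",
    "    precision = 1.0 / prior_var + n_pulls / noise_var\n",
    "    # The posterior mean is the precision-weighted average of prior mean and sample mean.\n",
    "    post_mean = (prior_mean / prior_var + n_pulls * sample_mean / noise_var) / precision\n",
    "    return post_mean, 1.0 / precision\n",
    "\n",
    "# Prior N(0, 4), unit-variance rewards, 3 pulls with sample mean 1.\n",
    "demo_mean, demo_var = gaussian_posterior(0.0, 4.0, 1.0, 3, 1.0)\n",
    "print(demo_mean, demo_var, math.sqrt(demo_var))\n",
    "```\n",
    "\n",
    "In this example the posterior mean is $12/13$ and the posterior variance is $4/13$: three observations already outweigh the wide prior."
   ]
  },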
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2) Top-Two UCB\n",
    "is a class that implements the algorithm devised in the paper ['Non-asymptotic analysis of a ucb-based top two algorithm'](https://proceedings.neurips.cc/paper_files/paper/2023/hash/d9b564716709357b4bccec9fc9ad04d2-Abstract-Conference.html)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TTUCB(BAI_Algorithm):              #Top Two Thompson Sampling - UNDER PROGRESS\n",
    "  def __init__(self, prior_dist, delta):\n",
    "    self.beta = 0.5\n",
    "    self.prior_dist = prior_dist\n",
    "    self.K= prior_dist.K\n",
    "    self.delta = delta\n",
    "    self.n_la = np.zeros((self.K, self.K)) # number of arm 'a'(second) pull when the leader was 'l'(first)\n",
    "    self.n_list=np.zeros(self.K) # number of each arm pull\n",
    "    self.leader_list=np.zeros(self.K)\n",
    "    self.m_hat = np.zeros(self.K)\n",
    "    self.width = np.ones(self.K) # Not precisely width, without log factor\n",
    "    self.UCB = self.m_hat+self.width*np.inf\n",
    "\n",
    "  def sample(self):\n",
    "    if np.min(self.n_list)==0: #Pull each arm at least once\n",
    "        return np.argmin(self.n_list)\n",
    "    bestI=np.argmax(self.m_hat)\n",
    "    leader_ind = np.argmax(self.UCB)\n",
    "    if self.beta*self.leader_list[leader_ind]<self.n_la[leader_ind][leader_ind]: #Pick Challenger\n",
    "        challenger_measure = np.zeros(self.K)\n",
    "        for i in range(self.K):\n",
    "            if i==leader_ind:\n",
    "                challenger_measure[i]=np.inf\n",
    "            else:\n",
    "                challenger_measure[i]=np.max(self.m_hat[leader_ind]-self.m_hat[i], 0)/np.sqrt(1/self.n_list[leader_ind]+1/self.n_list[i])\n",
    "        challenger_ind=np.argmin(challenger_measure)\n",
    "        return challenger_ind\n",
    "    return leader_ind\n",
    "    \n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "  def update(self, a, reward):\n",
    "      leader_ind = np.argmax(self.UCB)\n",
    "      self.n_la[leader_ind][a]=self.n_la[leader_ind][a]+1\n",
    "      self.leader_list[leader_ind]=self.leader_list[leader_ind]+1\n",
    "      self.m_hat[a] = (reward+self.n_list[a]*self.m_hat[a])/(self.n_list[a]+1)               #sample mean\n",
    "      self.n_list[a]=self.n_list[a]+1\n",
    "\n",
    "      T=np.sum(self.n_list)\n",
    "      self.width[a]=1/np.sqrt(self.n_list[a])\n",
    "      self.UCB=self.m_hat+self.width*np.sqrt(4*np.log(T+1))\n",
    "      \n",
    "\n",
    "\n",
    "  def stopping_criterion(self):\n",
    "    T=np.sum(self.n_list)\n",
    "    if np.min(self.n_list)==0:\n",
    "      return False\n",
    "    i_max=np.argmax(self.m_hat)\n",
    "    W = np.zeros(self.K)\n",
    "    minW = np.inf\n",
    "    for j in range(self.K):\n",
    "        if i_max==j:\n",
    "            W[j]=np.inf\n",
    "        else:\n",
    "            W[j]=(self.m_hat[i_max]-self.m_hat[j])/np.sqrt(1/self.n_list[i_max]+1/self.n_list[j])\n",
    "        minW=np.min((minW, W[j]))\n",
    "\n",
    "    C=2*self.C_G(0.5*np.log((self.K-1)/self.delta))+4*np.log(4+np.log((T-1)/2))\n",
    "    threshold = np.sqrt(2*C)\n",
    "    if minW>threshold:\n",
    "      return True\n",
    "    else:\n",
    "      return False\n",
    "\n",
    "\n",
    "  def recommendation(self):\n",
    "    return np.argmax(self.m_hat)\n",
    "\n",
    "  def C_G(self,x):\n",
    "      return x+ np.log(x)"
   ]
  },
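  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The challenger step of `sample()` picks the arm that is hardest to separate from the leader, measured by the positive part of the gap in sample means divided by $\\sqrt{1/N_{leader}+1/N_a}$. A minimal standalone sketch of that selection, assuming unit reward variances as in the class above (`pick_challenger` is a hypothetical helper name):\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def pick_challenger(leader, means, counts):\n",
    "    best_arm, best_cost = None, math.inf\n",
    "    for arm in range(len(means)):\n",
    "        if arm == leader:\n",
    "            continue  # the leader cannot be its own challenger\n",
    "        # Positive part of the gap between the leader and this arm.\n",
    "        gap = max(means[leader] - means[arm], 0.0)\n",
    "        cost = gap / math.sqrt(1.0 / counts[leader] + 1.0 / counts[arm])\n",
    "        if cost < best_cost:\n",
    "            best_arm, best_cost = arm, cost\n",
    "    return best_arm\n",
    "\n",
    "# Leader 0 with sample means 1.0, 0.8, 0.2 and 10 pulls per arm.\n",
    "demo_challenger = pick_challenger(0, [1.0, 0.8, 0.2], [10, 10, 10])\n",
    "print(demo_challenger)\n",
    "```\n",
    "\n",
    "Here arm 1 is chosen, since its gap to the leader is the smallest. When the leader is not the empirically best arm, the positive part clips negative gaps to zero and ties are broken by the lowest index."
   ]
  },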
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3) Our Algorithm (Algorithm 1, Successive Elimination)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 140,
     "status": "ok",
     "timestamp": 1702800721936,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "bCA2xpUzNyfU"
   },
   "outputs": [],
   "source": [
    "class Elimination(BAI_Algorithm):\n",
    "  def __init__(self, prior_dist, delta):\n",
    "    self.prior_dist = prior_dist\n",
    "    self.K= prior_dist.K\n",
    "    self.delta=delta\n",
    "    self.Delta0=prior_dist.get_Delta_0(delta)\n",
    "    self.survived_arms = list(range(self.K))\n",
    "    self.need_elimination = False\n",
    "    self.Delta_safe = np.inf\n",
    "\n",
    "    #self.hist=[ [] for _ in range(self.K) ] # list of empty lists\n",
    "    self.n_list = np.zeros(self.K)\n",
    "    self.m_hat = np.zeros(self.K)\n",
    "\n",
    "\n",
    "  def sample(self):\n",
    "#    n_list = [len(self.hist[i]) for i in self.survived_arms]\n",
    "    chosen = np.argmin(self.n_list[self.survived_arms])\n",
    "    if chosen==len(self.survived_arms)-1:\n",
    "      self.need_elimination=True\n",
    "    return self.survived_arms[chosen]\n",
    "\n",
    "  def recommendation(self):\n",
    "    if len(self.survived_arms)==1:\n",
    "      return self.survived_arms[0]\n",
    "    elif self.Delta_safe<self.Delta0:\n",
    "      return random.choice(self.survived_arms)\n",
    "    else:\n",
    "      print('Error: recommendation should be done only after stopping criterion')\n",
    "      return -1\n",
    "\n",
    "  def stopping_criterion(self):\n",
    "    if self.need_elimination:\n",
    "      self.elim()\n",
    "\n",
    "    if len(self.survived_arms)==1:                # Note that these two parameters, self,survived_arms and self.Delta_safe only changes when elim() has been called.\n",
    "      return True\n",
    "    elif self.Delta_safe<self.Delta0:\n",
    "      return True\n",
    "    else:\n",
    "      return False\n",
    "\n",
    "\n",
    "  def elim(self):\n",
    "    ucbs=np.zeros(len(self.survived_arms))\n",
    "    lcbs=np.zeros(len(self.survived_arms))\n",
    "    next_survived=[]\n",
    "\n",
    "    for i in range(len(self.survived_arms)):                      # Compute UCB and LCB\n",
    "      ucbs[i], lcbs[i] = self.conf_bounds(self.m_hat,self.survived_arms[i])\n",
    "    lcbmax=np.max(lcbs)                               # Maximum LCB\n",
    "    ucbmax=np.max(ucbs)                               # Maximum UCB\n",
    "    for i in range(len(self.survived_arms)):                       # Include all arms which satisfies UCB>LCBMAX to a new basket\n",
    "      if ucbs[i]>lcbmax:\n",
    "        next_survived.append(self.survived_arms[i])\n",
    "\n",
    "    self.Delta_safe=ucbmax-lcbmax\n",
    "    self.survived_arms=next_survived\n",
    "    self.need_elimination = False\n",
    "\n",
    "  def conf_bounds(self,m_hat,i):\n",
    "    m_hat=self.m_hat[i]\n",
    "    n=self.n_list[i]\n",
    "    width=np.sqrt(\n",
    "        2*self.prior_dist.instance_std[i]**2\n",
    "        *np.log(12*self.K*n**2/self.delta**2/np.pi**2)/n\n",
    "    )\n",
    "    return m_hat+width, m_hat-width\n",
    "\n",
    "  def update(self, a, reward):\n",
    "    self.m_hat[a]=(self.m_hat[a]*self.n_list[a]+reward)/(self.n_list[a]+1)\n",
    "    self.n_list[a]=self.n_list[a]+1\n"
   ]
  },
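  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A single elimination round can be traced on fixed numbers. The sketch below restates the confidence width from `conf_bounds` and the filtering rule from `elim()` with the standard library only; `confidence_width` and `eliminate` are hypothetical helper names, and the inputs are made up for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def confidence_width(n_pulls, noise_std, n_arms, delta):\n",
    "    # Same expression as in conf_bounds: sqrt(2 sigma^2 log(12 K n^2 / (delta^2 pi^2)) / n).\n",
    "    log_term = math.log(12 * n_arms * n_pulls ** 2 / delta ** 2 / math.pi ** 2)\n",
    "    return math.sqrt(2 * noise_std ** 2 * log_term / n_pulls)\n",
    "\n",
    "def eliminate(survivors, means, counts, noise_std, delta):\n",
    "    n_arms = len(means)  # K counts all arms, including eliminated ones\n",
    "    ucb, lcb = {}, {}\n",
    "    for arm in survivors:\n",
    "        w = confidence_width(counts[arm], noise_std[arm], n_arms, delta)\n",
    "        ucb[arm], lcb[arm] = means[arm] + w, means[arm] - w\n",
    "    lcb_max = max(lcb.values())\n",
    "    # Keep every arm whose UCB still exceeds the largest LCB.\n",
    "    kept = [arm for arm in survivors if ucb[arm] > lcb_max]\n",
    "    # Second value plays the role of Delta_safe: largest UCB minus largest LCB.\n",
    "    return kept, max(ucb.values()) - lcb_max\n",
    "\n",
    "demo_kept, demo_gap = eliminate([0, 1, 2], [0.0, 1.0, 5.0], [100, 100, 100], [1.0, 1.0, 1.0], 0.1)\n",
    "print(demo_kept, demo_gap)\n",
    "```\n",
    "\n",
    "With these inputs only arm 2 survives, and the returned gap equals twice the confidence width of that arm."
   ]
  },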
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p-v2O_fU2Fd2"
   },
   "source": [
    "# General Experiment Design\n",
    "A class designed to effectively manage the experimental environment. This is mainly for guaranteeing same instance samples $(\\mu_i)_{i=1}^k$ for all algorithms. It takes the following major inputs:\n",
    "- prior\\_dist: corresponds to $H$ in our main paper. It will be *prior\\_distribution* instance.\n",
    "- $\\delta$: Error probability. \n",
    "- algolist: list of algorithms for this experiment. \n",
    "\n",
    "Methods:\n",
    "- single\\_experiment(mean, algoname): run single experiment for *algoname* with given mean $(\\mu_i)_{i=1}^k$. Returns stopping time and whether the prediction is correct or wrong. \n",
    "- monte\\_carlo\\_experiment(num\\_of\\_exp): repeat experiment *num\\_of\\_exp* times, but for each repetition all algorithms share same $(\\mu_i)_{i=1}^k$. Output is expected stopping time and success rate of each algorithm. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 132,
     "status": "ok",
     "timestamp": 1702800724516,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "nY_V_--KlYth"
   },
   "outputs": [],
   "source": [
    "class Experiment:\n",
    "  def __init__(self, prior_dist, delta, algolist):\n",
    "    self.delta=delta\n",
    "    self.algolist=algolist                              # Strings of algorithm names, such as ['Elim', 'TTTS']\n",
    "    self.prior_dist=prior_dist\n",
    "    self.K=self.prior_dist.K                            # Number of Arms\n",
    "    self.stopping_time_hist = []                        # Record of Stopping time\n",
    "    self.success_hist = []             \n",
    "    self.time_spent=[]                                  # Record of time spent (unit: sec)\n",
    "\n",
    "\n",
    "  def monte_carlo_experiment(self, num_of_exp): # running multiple experiment - num_of_exp times\n",
    "    exp_stopping_time=np.zeros(len(self.algolist))\n",
    "    success_rate=np.zeros(len(self.algolist))\n",
    "    for i in range(num_of_exp):\n",
    "      mean=self.prior_dist.sample_instance()\n",
    "      for alg in range(len(self.algolist)):\n",
    "        algoname=self.algolist[alg]\n",
    "        start_time=time.time()\n",
    "        sample_stopping_time, sample_success = self.single_experiment(mean, algoname)\n",
    "        elapsed=time.time()-start_time\n",
    "        self.stopping_time_hist.append(sample_stopping_time)\n",
    "        self.success_hist.append(sample_success)\n",
    "        self.time_spent.append(elapsed)\n",
    "        exp_stopping_time[alg] =exp_stopping_time[alg]+sample_stopping_time/num_of_exp\n",
    "        success_rate[alg]      =success_rate[alg]+sample_success/num_of_exp\n",
    "    return exp_stopping_time, success_rate\n",
    "\n",
    "\n",
    "\n",
    "  def single_experiment(self, mean, algoname):\n",
    "    print(\"Starting an experiment - mean \")\n",
    "    print(mean)\n",
    "    alg=BAI_Algorithm(self.prior_dist,self.delta)\n",
    "    if algoname=='Elim':\n",
    "      print(\"Algo: Elim\")\n",
    "      alg = Elimination(self.prior_dist, self.delta)\n",
    "    elif algoname=='TTTS':\n",
    "      print(\"Algo: TTTS\")\n",
    "      alg = TTTS(self.prior_dist, self.delta)\n",
    "    elif algoname=='TTUCB':\n",
    "      print(\"Algo: TTUCB\")\n",
    "      alg = TTUCB(self.prior_dist, self.delta)\n",
    "    answer=np.argmax(mean)\n",
    "    std=self.prior_dist.instance_std\n",
    "    stopping_time=0\n",
    "    while not alg.stopping_criterion():\n",
    "      a=alg.sample()\n",
    "      reward = np.random.normal(mean[a],std[a])\n",
    "      alg.update(a, reward)\n",
    "      stopping_time=stopping_time+1\n",
    "    print('Final stopping time: %d'%stopping_time)\n",
    "    if answer==alg.recommendation():\n",
    "      return stopping_time,1\n",
    "    else:\n",
    "      return stopping_time,0\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start of the Main Code\n",
    "Parameters\n",
    "- $K$: number of arms.\n",
    "- *prior\\_mean*: corresponds to $(m_i)_{i=1}^k$, the mean vector of the prior distribution.\n",
    "- *prior\\_std*: corrsponds to $(\\xi_i)_{i=1}^k$, the vector of standard deviations.\n",
    "- *instance\\_std*: corresponds to $(\\sigma_i)_{i=1}^k$, the standard deviations of the reward distributions. \n",
    "- $\\delta$: error probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "executionInfo": {
     "elapsed": 144,
     "status": "ok",
     "timestamp": 1702800727426,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "CtsNGPqo9Mr_"
   },
   "outputs": [],
   "source": [
    "K=2\n",
    "prior_mean=np.zeros(K)\n",
    "prior_std=np.ones(K)\n",
    "instance_std=np.ones(K)\n",
    "delta=0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create prior distribution $H$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prior_dist=prior_distribution(prior_mean, prior_std,instance_std)\n",
    "prior_dist.get_Delta_0(delta)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Experiment instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 166,
     "status": "ok",
     "timestamp": 1702800962508,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "MqKIm4yh-Uwq"
   },
   "outputs": [],
   "source": [
    "myExp=Experiment(prior_dist, delta, ['TTTS', 'TTUCB', 'Elim'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the experiment 1000 times.\n",
    "- res: average stopping time variable\n",
    "\n",
    "- ult: success rate variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res, ult= myExp.monte_carlo_experiment(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Stopping time variables\n",
    "\n",
    "- ElimStop: Stopping time history for our algorithm\n",
    "- TTTSStop: Stopping time history for TTTS\n",
    "- TTUCBStop: Stopping time history for TTUCB\n",
    "\n",
    "### Computation time variables\n",
    "\n",
    "- ElimCompTime: Computation time history for our algorithm\n",
    "- TTTSComp: Computation time history for TTTS\n",
    "- TTUCBComp: Computation time history for TTUCB\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 141,
     "status": "ok",
     "timestamp": 1702786024399,
     "user": {
      "displayName": "Kyoungseok Jang",
      "userId": "07027991917961288603"
     },
     "user_tz": 300
    },
    "id": "DQlyiFmxwODu",
    "outputId": "102e3eb6-8861-4586-b366-3425dc420200"
   },
   "outputs": [],
   "source": [
    "lnow=1000\n",
    "ElimStop = [myExp.stopping_time_hist[3*i+2] for i in range(lnow)]\n",
    "TTTSStop = [myExp.stopping_time_hist[3*i+1] for i in range(lnow)]\n",
    "TTUCBStop = [myExp.stopping_time_hist[3*i] for i in range(lnow)]\n",
    "ElimCompTime = [myExp.time_spent[3*i+2] for i in range(lnow)]\n",
    "TTTSCompTime = [myExp.time_spent[3*i+1] for i in range(lnow)]\n",
    "TTUCBCompTime = [myExp.time_spent[3*i] for i in range(lnow)]\n",
    "np.savetxt('Elim_K2_1000.txt', ElimStop, delimiter=',')\n",
    "np.savetxt('TTTS_K2_1000.txt', TTTSStop, delimiter=',')\n",
    "np.savetxt('TTUCB_K2_1000.txt', TTUCBStop, delimiter=',')\n",
    "np.savetxt('Elim_Comptime_K2_1000.txt', ElimCompTime, delimiter=',')\n",
    "np.savetxt('TTTS_Comptime_K2_1000.txt', TTTSCompTime, delimiter=',')\n",
    "np.savetxt('TTUCB_Comptime_K2_1000.txt', TTUCBCompTime, delimiter=',')"
   ]
  },
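  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `monte_carlo_experiment` appends one entry per algorithm in the order of `algolist` within every repetition, the flat histories can also be split by name instead of by hand-written offsets. A small sketch, assuming that interleaved layout (`split_history` is a hypothetical helper and the numbers are made up):\n",
    "\n",
    "```python\n",
    "def split_history(history, algo_names):\n",
    "    # Entry k of every block of len(algo_names) values belongs to algo_names[k].\n",
    "    n_algos = len(algo_names)\n",
    "    return {name: history[k::n_algos] for k, name in enumerate(algo_names)}\n",
    "\n",
    "# Two repetitions of three algorithms, stored back to back.\n",
    "demo_hist = [10, 20, 30, 11, 21, 31]\n",
    "demo_split = split_history(demo_hist, ['TTTS', 'TTUCB', 'Elim'])\n",
    "print(demo_split)\n",
    "```\n",
    "\n",
    "Keying the result by algorithm name keeps the bookkeeping correct if the order or the number of entries in `algolist` changes."
   ]
  },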
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.mean(ElimStop))\n",
    "print(np.mean(TTTSStop))\n",
    "print(np.mean(TTUCBStop))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.mean(ElimCompTime))\n",
    "print(np.mean(TTTSCompTime))\n",
    "print(np.mean(TTUCBCompTime))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.max(ElimStop))\n",
    "print(np.max(TTTSStop))\n",
    "print(np.max(TTUCBStop))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyM3WMzKUtrcLePPPBMWj71C",
   "name": "",
   "version": ""
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
