{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import timeit\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def cal_inf(G, S_A):\n",
    "    inf_A = np.sum((1 - np.prod(1 - G[S_A], axis=0)))\n",
    "    return inf_A\n",
    "\n",
    "def greedy(G, k_A):\n",
    "    S_A = []\n",
    "    for _ in range(k_A):\n",
    "        inf_out = np.zeros(n_s)\n",
    "        for i in range(n_s):\n",
    "            if i not in S_A:\n",
    "                inf_out[i] = cal_inf(G, S_A+[i])\n",
    "#                 print('node '+str(i)+', inf '+str(inf_out[i]))\n",
    "        i_max = np.argmax(inf_out)\n",
    "        S_A.append(i_max)\n",
    "#         print('select node '+str(i_max))\n",
    "#         print('inf '+str(inf_out.max()))\n",
    "    return S_A, inf_out.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_s = 10\n",
    "n_t = 20\n",
    "k_A = 5\n",
    "T = 20000\n",
    "N_exp = 10\n",
    "C1 = 3\n",
    "Bv = 3*np.sqrt(2*n_t)/2\n",
    "explore_ratio = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reg_cucb = np.zeros((N_exp,T))\n",
    "reg_bcucb = np.zeros((N_exp,T))\n",
    "reg_sescb = np.zeros((N_exp,T))\n",
    "reg_escb = np.zeros((N_exp,T))\n",
    "np.random.seed(0)\n",
    "G = np.random.rand(n_s,n_t)*0.01 + 0.05\n",
    "mu_init = (np.random.rand(n_s,n_t) < G).astype(float)\n",
    "mu_hat = np.zeros((n_s,n_t))\n",
    "for exp in range(N_exp):\n",
    "    np.random.seed(exp)\n",
    "    S_A_opt, opt = greedy(G, k_A)\n",
    "    print('exp '+str(exp)+', opt: '+str(opt))\n",
    "    np.random.seed(exp)\n",
    "#     ESCB\n",
    "    start = timeit.default_timer()\n",
    "    print('ESCB')\n",
    "    L = n_s\n",
    "    K = k_A\n",
    "    comb_list = rSubset(np.arange(L), K)\n",
    "    comb_len = len(comb_list)\n",
    "    rho = np.zeros(comb_len)\n",
    "    UCB = np.zeros(comb_len)\n",
    "    T_i = np.ones(L)\n",
    "    mu_hat = mu_init\n",
    "    T_mu = np.ones((n_s,n_t))\n",
    "    inf_t = np.zeros(T)\n",
    "    for t in range(T):\n",
    "        for i in range(comb_len):\n",
    "            comb_tmp = list(comb_list[i])\n",
    "            T_min = np.min(T_i[comb_tmp])\n",
    "            a1 = n_t * np.sum(1 / T_i[comb_tmp])\n",
    "            delta = 2 * np.log(t+3) + 2*(n_s*n_t + 2) * np.log(np.log(t+3)) + 1\n",
    "            rho[i] = np.sqrt(delta * a1)\n",
    "            UCB[i] = cal_inf(mu_hat, comb_tmp) + explore_ratio*rho[i]\n",
    "        S_A = list(comb_list[np.argmax(UCB)])\n",
    "        T_i[S_A] += 1\n",
    "        for i in S_A:\n",
    "            T_mu[i] += 1\n",
    "            X_i = (np.random.rand(n_t) < G[i]).astype(int)\n",
    "            mu_hat[i] = mu_hat[i] + (X_i - mu_hat[i]) / T_mu[i]\n",
    "        inf_t[t] = cal_inf(G, S_A)\n",
    "    rho_ESCB = rho\n",
    "    reg_escb[exp] = opt - inf_t\n",
    "    stop = timeit.default_timer()\n",
    "    print('ESCB Time: ', stop - start)  \n",
    "\n",
    "#         BCUCB\n",
    "    start = timeit.default_timer()\n",
    "    print('BCUCB')\n",
    "    mu_hat = mu_init\n",
    "    var_hat = np.zeros((n_s,n_t))\n",
    "    rho = np.zeros((n_s,n_t))\n",
    "    T_i = np.ones((n_s,n_t))\n",
    "    inf_t = np.zeros(T)\n",
    "    for t in range(T):\n",
    "        rho = np.sqrt(6 * var_hat * np.log(t+1) / (T_i)) + (9 * np.log(t+1) / (T_i))\n",
    "        UCB = np.clip(mu_hat + explore_ratio * rho, a_min=0, a_max=1)\n",
    "        G_t = UCB\n",
    "        S_A, _ = greedy(G_t, k_A)\n",
    "        for i in S_A:\n",
    "            T_i[i] += 1\n",
    "            X_i = (np.random.rand(n_t) < G[i]).astype(int)\n",
    "            var_hat[i] = (T_i[i] - 1) / (T_i[i]**2) * ((X_i - mu_hat[i])**2) + (T_i[i]-1)/T_i[i]*var_hat[i]\n",
    "            mu_hat[i] = mu_hat[i] + (X_i - mu_hat[i]) / T_i[i]\n",
    "        inf_t[t] = cal_inf(G, S_A)\n",
    "    reg_bcucb[exp] = opt - inf_t\n",
    "    stop = timeit.default_timer()\n",
    "    print('BCUCB Time: ', stop - start)      \n",
    "    \n",
    "    # SESCB\n",
    "    start = timeit.default_timer()\n",
    "    print('SESCB')\n",
    "    L = n_s\n",
    "    K = k_A\n",
    "    comb_list = rSubset(np.arange(L), K)\n",
    "    comb_len = len(comb_list)\n",
    "    rho = np.zeros(comb_len)\n",
    "    UCB = np.zeros(comb_len)\n",
    "    T_i = np.ones(L)\n",
    "    mu_hat = mu_init\n",
    "    T_mu = np.ones((n_s,n_t))\n",
    "    inf_t = np.zeros(T)\n",
    "    for t in range(T):\n",
    "        for i in range(comb_len):\n",
    "            comb_tmp = list(comb_list[i])\n",
    "            T_min = np.min(T_i[comb_tmp])\n",
    "            a1 = n_t * np.sum(C1 / T_i[comb_tmp])\n",
    "            a2 = n_t * 8 * C1 * np.sqrt(np.sum(np.log(2*comb_len*(t+1)) / (T_i[comb_tmp]**2)))\n",
    "            a3 = 8 * C1 * np.log(2*comb_len*(t+1)) / T_min\n",
    "            rho[i] = Bv * np.sqrt(a1 + np.max((a2, a3)))\n",
    "            UCB[i] = cal_inf(mu_hat, comb_tmp) + explore_ratio*rho[i]\n",
    "        S_A = list(comb_list[np.argmax(UCB)])\n",
    "        T_i[S_A] += 1\n",
    "        for i in S_A:\n",
    "            T_mu[i] += 1\n",
    "            X_i = (np.random.rand(n_t) < G[i]).astype(int)\n",
    "            mu_hat[i] = mu_hat[i] + (X_i - mu_hat[i]) / T_mu[i]\n",
    "        inf_t[t] = cal_inf(G, S_A)\n",
    "    rho_SESCB = rho\n",
    "    reg_sescb[exp] = opt - inf_t\n",
    "    stop = timeit.default_timer()\n",
    "    print('SESCB Time: ', stop - start) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft = 14\n",
    "t = np.arange(10000)\n",
    "X1 = reg_escb[0:10].cumsum(axis=1)\n",
    "X2 = reg_bcucb[0:10].cumsum(axis=1)\n",
    "X3 = reg_sescb[0:10].cumsum(axis=1)\n",
    "mu1 = X1.mean(axis=0)\n",
    "sigma1 = X1.std(axis=0)\n",
    "mu2 = X2.mean(axis=0)\n",
    "sigma2 = X2.std(axis=0)\n",
    "mu3 = X3.mean(axis=0)\n",
    "sigma3 = X3.std(axis=0)\n",
    "\n",
    "fig, ax = plt.subplots(1)\n",
    "\n",
    "ax.plot(t, mu1, lw=3, label='ESCB', color='green', linestyle = '-')\n",
    "ax.plot(t, mu2, lw=3, label='BCUCB-T', color='red', linestyle = '--')\n",
    "ax.plot(t, mu3, lw=3, label='SESCB', color='blue', linestyle = '-.')\n",
    "ax.fill_between(t, mu1+sigma1, mu1-sigma1, facecolor='green', alpha=0.1)\n",
    "ax.fill_between(t, mu2+sigma2, mu2-sigma2, facecolor='red', alpha=0.1)\n",
    "ax.fill_between(t, mu3+sigma3, mu3-sigma3, facecolor='blue', alpha=0.1)\n",
    "\n",
    "ax.legend(loc='upper left',fontsize=ft) \n",
    "ax.set_xlabel('Round',fontsize=ft) \n",
    "ax.set_ylabel('Cumulative Regret',fontsize=ft)\n",
    "fig.savefig('PMC.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
