{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wotj2L12Cs9c"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pickle\n",
        "import matplotlib.pyplot as plt\n",
        "import scipy.stats as st\n",
        "\n",
        "class UCB_weak:\n",
        "    \"\"\"UCB learner that estimates each arm from ground-truth rewards only.\n",
        "\n",
        "    Every pull of an arm is booked either as a ground-truth pull (its\n",
        "    ``reward[0]`` is folded into the running mean) or as a coarse pull\n",
        "    (only counted). A pull is ground truth while\n",
        "    ``count_gt <= x * count_c`` holds for that arm.\n",
        "\n",
        "    Args:\n",
        "        arms: iterable of arms; only its length is used.\n",
        "        R: initial bound, cap on every bound and scale of the exploration term.\n",
        "        x: ratio controlling how often a pull counts as ground truth.\n",
        "        b: multiplier inside the square-root exploration radius.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arms, R, x, b=2):\n",
        "        k = len(arms)\n",
        "        self.num_arms = k\n",
        "        self.true_means = [0] * k   # running mean of reward[0] over ground-truth pulls\n",
        "        self.count_gt = [0] * k     # ground-truth pulls per arm\n",
        "        self.count_c = [0] * k      # coarse pulls per arm\n",
        "        self.R = R\n",
        "        self.b = b\n",
        "        self.x = x\n",
        "        self.bounds = [R] * k       # every arm starts at the optimistic value R\n",
        "\n",
        "    def choose_arm(self):\n",
        "        \"\"\"Return an arm whose bound is maximal; ties are broken uniformly at random.\"\"\"\n",
        "        bounds = np.asarray(self.bounds)\n",
        "        (tied,) = np.where(bounds == bounds.max())\n",
        "        return np.random.choice(tied)\n",
        "\n",
        "    def update_params(self, reward, arm, t):\n",
        "        \"\"\"Book one pull of ``arm`` at round ``t`` and recompute every arm's bound.\n",
        "\n",
        "        Only ``reward[0]`` is read. An arm with n > 0 ground-truth pulls gets\n",
        "        min(mean + R * sqrt(b * log(1 + t) / n), R); an arm with none keeps R.\n",
        "        \"\"\"\n",
        "        if self.count_gt[arm] <= self.x * self.count_c[arm]:\n",
        "            n = self.count_gt[arm] + 1\n",
        "            self.count_gt[arm] = n\n",
        "            self.true_means[arm] += (reward[0] - self.true_means[arm]) / n\n",
        "        else:\n",
        "            self.count_c[arm] += 1\n",
        "\n",
        "        log_term = np.log(1 + t)\n",
        "        for i, n_gt in enumerate(self.count_gt):\n",
        "            if not n_gt:\n",
        "                self.bounds[i] = self.R\n",
        "                continue\n",
        "            radius = self.R * np.sqrt(self.b * log_term / n_gt)\n",
        "            self.bounds[i] = min(self.true_means[i] + radius, self.R)\n",
        "\n",
        "class UCBV_weak:\n",
        "    \"\"\"UCB-V learner that estimates each arm from ground-truth rewards only.\n",
        "\n",
        "    Pulls are booked as ground truth or coarse with the same rule as the other\n",
        "    learners (ground truth while ``count_gt <= x * count_c``); only ground-truth\n",
        "    pulls contribute ``reward[0]`` to the mean and variance estimates.\n",
        "\n",
        "    Args:\n",
        "        arms: iterable of arms; only its length is used.\n",
        "        R: initial bound, cap on every bound and scale of the range term.\n",
        "        x: ratio controlling how often a pull counts as ground truth.\n",
        "        b: multiplier of the range term R * log(1 + t) / n.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, arms, R, x, b=3):\n",
        "        k = len(arms)\n",
        "        self.num_arms = k\n",
        "        self.sum_rew_gt = [0] * k      # sum of reward[0] over ground-truth pulls\n",
        "        self.count_gt = [0] * k        # ground-truth pulls per arm\n",
        "        self.sum_sq_rew_gt = [0] * k   # sum of reward[0] ** 2 over ground-truth pulls\n",
        "        self.count_c = [0] * k         # coarse pulls per arm\n",
        "        self.R = R\n",
        "        self.b = b\n",
        "        self.x = x\n",
        "        self.bounds = [R] * k          # every arm starts at the optimistic value R\n",
        "\n",
        "    def choose_arm(self):\n",
        "        \"\"\"Return an arm whose bound is maximal; ties are broken uniformly at random.\"\"\"\n",
        "        bounds = np.asarray(self.bounds)\n",
        "        (tied,) = np.where(bounds == bounds.max())\n",
        "        return np.random.choice(tied)\n",
        "\n",
        "    def update_params(self, reward, arm, t):\n",
        "        \"\"\"Book one pull of ``arm`` at round ``t`` and recompute every arm's bound.\n",
        "\n",
        "        Only ``reward[0]`` is read. An arm with n > 0 ground-truth pulls gets\n",
        "        min(mean + sqrt(2 * var * log(1+t) / n) + b * R * log(1+t) / n, R),\n",
        "        where var is the empirical variance clamped at 0; an arm with none keeps R.\n",
        "        \"\"\"\n",
        "        if self.count_gt[arm] <= self.x * self.count_c[arm]:\n",
        "            g = reward[0]\n",
        "            self.count_gt[arm] += 1\n",
        "            self.sum_rew_gt[arm] += g\n",
        "            self.sum_sq_rew_gt[arm] += g**2\n",
        "        else:\n",
        "            self.count_c[arm] += 1\n",
        "\n",
        "        log_term = np.log(1 + t)\n",
        "        for i in range(self.num_arms):\n",
        "            n = self.count_gt[i]\n",
        "            if n == 0:\n",
        "                self.bounds[i] = self.R\n",
        "                continue\n",
        "            mean = self.sum_rew_gt[i] / n\n",
        "            # E[X^2] - E[X]^2 can come out slightly negative from round-off.\n",
        "            var = max(self.sum_sq_rew_gt[i] / n - mean**2, 0)\n",
        "            width = np.sqrt(2 * var * log_term / n) + self.b * self.R * log_term / n\n",
        "            self.bounds[i] = min(mean + width, self.R)\n",
        "\n",
        "class CUUCB:\n",
        "    \"\"\"UCB learner that uses paired coarse rewards as a control variate for the ground-truth mean.\n",
        "\n",
        "    A pull of an arm is booked as a ground-truth pull while\n",
        "    ``count_gt <= x * count_c`` for that arm; such a pull records the pair\n",
        "    ``reward[0]`` (ground truth) and ``reward[1]`` (coarse). Otherwise the\n",
        "    pull is coarse-only and just ``reward[1]`` is recorded.\n",
        "\n",
        "    Args:\n",
        "        arms: iterable of arms; only its length is used.\n",
        "        R: initial bound, cap on every bound and scale of the range term.\n",
        "        x: ratio controlling how often a pull counts as ground truth.\n",
        "        gamma: the coefficient ratio cov / var_c_hat is clipped to [-gamma, gamma].\n",
        "        b: multiplier of the range term R * log(1 + t) / n_gt in the bound.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,arms,R, x, gamma = 1, b = 3):\n",
        "        self.num_arms=len(arms)\n",
        "        # Per-arm running sums. *_gt: reward[0] on ground-truth pulls;\n",
        "        # *_cgt: reward[1] on those same (paired) pulls; *_c: reward[1] on coarse-only pulls.\n",
        "        self.sum_rew_gt=[0 for arm in arms]\n",
        "        self.sum_rew_c=[0 for arm in arms]\n",
        "        self.sum_rew_cgt=[0 for arm in arms]\n",
        "        self.count_gt=[0 for arm in arms]\n",
        "        self.count_c=[0 for arm in arms]\n",
        "        self.sum_sq_rew_gt = [0 for arm in arms]\n",
        "        self.sum_sq_rew_c = [0 for arm in arms]\n",
        "        self.sum_sq_rew_cgt = [0 for arm in arms]\n",
        "        # Sum of reward[0] * reward[1] over ground-truth pulls (used for the covariance).\n",
        "        self.sum_prod_rew = [0 for arm in arms]\n",
        "        self.R=R\n",
        "        self.x=x\n",
        "        self.b = b\n",
        "        self.gamma = gamma\n",
        "        # Every arm starts at the optimistic value R.\n",
        "        self.bounds=[R for arm in arms]\n",
        "\n",
        "    def choose_arm(self):\n",
        "        \"\"\"Return an arm index with the largest bound, breaking ties uniformly at random.\"\"\"\n",
        "        max_val = np.max(self.bounds)\n",
        "        # self.bounds is a plain list; comparing it with the NumPy scalar max_val is\n",
        "        # evaluated by NumPy element-wise and yields a boolean array.\n",
        "        arm = np.random.choice(np.where(self.bounds == max_val)[0])\n",
        "        return arm\n",
        "\n",
        "    def update_params(self,reward,arm, t):\n",
        "        \"\"\"Book one pull of ``arm`` at round ``t`` and recompute every arm's bound.\n",
        "\n",
        "        Args:\n",
        "            reward: pair-like; ``reward[0]`` is the ground-truth value and\n",
        "                ``reward[1]`` the coarse value.\n",
        "            arm: index of the pulled arm.\n",
        "            t: current round, used in the log(1 + t) exploration term.\n",
        "        \"\"\"\n",
        "        # Budget rule: ground-truth (paired) pull while count_gt <= x * count_c.\n",
        "        if self.count_gt[arm]<=self.x*self.count_c[arm]:\n",
        "            self.count_gt[arm]+=1\n",
        "            self.sum_rew_gt[arm] += reward[0]\n",
        "            self.sum_sq_rew_gt[arm] += reward[0]**2\n",
        "            self.sum_rew_cgt[arm] += reward[1]\n",
        "            self.sum_sq_rew_cgt[arm] += reward[1]**2\n",
        "            self.sum_prod_rew[arm] += reward[0]*reward[1]\n",
        "        else:\n",
        "            self.count_c[arm]+=1\n",
        "            self.sum_rew_c[arm] += reward[1]\n",
        "            self.sum_sq_rew_c[arm] += reward[1]**2\n",
        "\n",
        "        for arm1 in range(self.num_arms):\n",
        "          # The estimate needs at least one paired and one coarse-only sample;\n",
        "          # until then the arm keeps the optimistic bound R.\n",
        "          if self.count_gt[arm1] == 0 or self.count_c[arm1] == 0:\n",
        "            self.bounds[arm1] = self.R\n",
        "          else:\n",
        "            term = np.log(1+t)\n",
        "            # x_t = n_gt / n_c, so alpha_t = n_gt / (n_gt + n_c).\n",
        "            x_t = self.count_gt[arm1]/self.count_c[arm1]\n",
        "            alpha_t = x_t/(1+x_t)\n",
        "            mean_c = self.sum_rew_c[arm1]/self.count_c[arm1]\n",
        "            mean_gt = self.sum_rew_gt[arm1]/self.count_gt[arm1]\n",
        "            mean_cgt = self.sum_rew_cgt[arm1]/self.count_gt[arm1]\n",
        "            # Plug-in variances as E[X^2] - E[X]^2 and the covariance of the paired samples.\n",
        "            var_c = self.sum_sq_rew_c[arm1]/self.count_c[arm1]-mean_c**2\n",
        "            var_gt = self.sum_sq_rew_gt[arm1]/self.count_gt[arm1]-mean_gt**2\n",
        "            var_cgt = self.sum_sq_rew_cgt[arm1]/self.count_gt[arm1]-mean_cgt**2\n",
        "            cov = self.sum_prod_rew[arm1]/self.count_gt[arm1]-mean_gt*mean_cgt\n",
        "            # Coarse-reward variance: mix of the coarse-only estimate (var_c) and the paired one (var_cgt).\n",
        "            # NOTE(review): alpha_t = n_gt / (n_gt + n_c) weights var_c, which is estimated from the\n",
        "            # n_c coarse-only samples; a sample-size-weighted pool would swap the weights. Confirm\n",
        "            # this matches the paper.\n",
        "            var_c_hat = alpha_t*var_c + (1-alpha_t)*var_cgt\n",
        "            if var_c_hat == 0:\n",
        "              # Degenerate coarse variance: fall back to +/- gamma according to the sign of cov.\n",
        "              ratio = np.sign(cov)*self.gamma\n",
        "            else:\n",
        "              ratio = cov/var_c_hat\n",
        "            # Clip the ratio to [-gamma, gamma].\n",
        "            if ratio>=self.gamma:\n",
        "              ratio = self.gamma\n",
        "            if ratio<=-self.gamma:\n",
        "              ratio = -self.gamma\n",
        "            # Unclipped, lam = cov / ((1 + x_t) * var_c_hat), which minimises `var` below over lam.\n",
        "            lam = ratio/(1+x_t)\n",
        "            # Variance term of the control-variate estimate (multiplied by n_gt); assumes the\n",
        "            # coarse-only samples are independent of the paired ones - TODO confirm against paper.\n",
        "            var = var_gt - 2*lam*cov + (1+x_t)*var_c_hat*lam**2\n",
        "            # Control-variate mean: correct mean_gt by the gap between paired and coarse-only coarse means.\n",
        "            mean_est = mean_gt - lam*(mean_cgt-mean_c)\n",
        "            # Round-off can push the plug-in variance slightly below zero.\n",
        "            if var < 0:\n",
        "              var = 0\n",
        "            # Width: sqrt(2 * var * log(1+t) / n_gt) + b * R * log(1+t) / n_gt; the bound is capped at R.\n",
        "            bound=np.sqrt(2*var*term/self.count_gt[arm1]) + self.b*self.R*term/self.count_gt[arm1]\n",
        "            self.bounds[arm1] = min(mean_est+bound, self.R)\n",
        "\n",
        "def sample(arm, dataset):\n",
        "    \"\"\"Pick one row of ``dataset`` uniformly at random and return ``[r_gt, r_coarse]`` for ``arm``.\n",
        "\n",
        "    ``dataset`` is indexed as ``dataset[row, arm, k]``; ``k = 0`` is read as the\n",
        "    ground-truth reward and ``k = 1`` as the coarse reward.\n",
        "    \"\"\"\n",
        "    row = np.random.choice(len(dataset))\n",
        "    return [dataset[row, arm, 0], dataset[row, arm, 1]]\n",
        "\n",
        "\n",
        "def runAlg(alg,new_dataset,num_arms,T):\n",
        "    \"\"\"Run a bandit learner for ``T`` rounds on a fixed dataset and return the per-round regret.\n",
        "\n",
        "    Args:\n",
        "        alg: learner exposing ``choose_arm()`` and ``update_params(reward, arm, t)``.\n",
        "        new_dataset: array indexed as ``[row, arm, k]``; the mean of column ``k = 0``\n",
        "            defines each arm's true (ground-truth) mean.\n",
        "        num_arms: number of arms to evaluate.\n",
        "        T: number of rounds.\n",
        "\n",
        "    Returns:\n",
        "        List of length ``T``; entry ``t`` is the best arm's true mean minus the\n",
        "        true mean of the arm chosen at round ``t``.\n",
        "    \"\"\"\n",
        "    true_means = [np.mean(new_dataset[:, i, 0]) for i in range(num_arms)]\n",
        "    best_mean = np.max(true_means)\n",
        "    regret = []\n",
        "    for t in range(T):\n",
        "        arm = alg.choose_arm()\n",
        "        regret.append(best_mean - true_means[arm])\n",
        "        # Feed the learner one freshly sampled [ground truth, coarse] reward pair.\n",
        "        alg.update_params(sample(arm, new_dataset), arm, t)\n",
        "    return regret\n",
        "\n",
        "\n",
        "\n",
        "# Global Matplotlib defaults for paper-style (ICML) figures.\n",
        "# The plotting cells below pass figsize=(7, 5) explicitly, overriding figure.figsize.\n",
        "plt.rcParams.update({\n",
        "    # Typography\n",
        "    \"font.size\": 10,\n",
        "    \"axes.labelsize\": 10,\n",
        "    \"legend.fontsize\": 8,\n",
        "    \"xtick.labelsize\": 8,\n",
        "    \"ytick.labelsize\": 8,\n",
        "    # High resolution on screen and in saved files\n",
        "    \"figure.dpi\": 300,\n",
        "    \"savefig.dpi\": 300,\n",
        "    # Default canvas size (inches) and line / marker sizes\n",
        "    \"figure.figsize\": (8, 3),\n",
        "    \"lines.linewidth\": 1.0,\n",
        "    \"lines.markersize\": 4,\n",
        "})"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
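        "#### Note: in the experiments below, `corr` is the parameter $\\rho$ described in the paper (the covariance between the ground-truth and coarse rewards is set to $4.5\\rho$)."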
      ],
      "metadata": {
        "id": "eQYOpvmxpD1c"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Synthetic experiment: cumulative regret of UCB, UCB-V and CUUCB for several\n",
        "# correlation levels between the ground-truth and coarse rewards.\n",
        "SEED = 0\n",
        "np.random.seed(SEED)  # dataset generation and every run draw from NumPy's global RNG\n",
        "\n",
        "x = 1/9          # controls the ratio of ground-truth to coarse pulls per arm\n",
        "num_arms = 4\n",
        "N = 100000       # rows in each synthetic dataset\n",
        "T = 100000       # horizon of every run\n",
        "num_tries = 40   # independent runs averaged per algorithm\n",
        "gamma = 1.5      # CUUCB clips its control-variate ratio to [-gamma, gamma]\n",
        "R = 10           # bound cap handed to the learners; rewards are clipped to [-4, 6]\n",
        "means_1 = [0.5, 0.7, 0.8, 0.2]  # ground-truth means passed to the generator (before clipping)\n",
        "\n",
        "for corr in [0.4, 0.6, 0.8]:\n",
        "  # Coarse means: ground-truth means perturbed by U(-0.1, 0.1), kept in [0, 1].\n",
        "  means_2 = np.clip(means_1 + 0.2*np.random.rand(num_arms) - 0.1, 0, 1)\n",
        "  cov_matrix = np.array([[5, 4.5*corr], [4.5*corr, 4]])\n",
        "  new_dataset = np.array([np.clip(st.multivariate_normal.rvs([means_1[i], means_2[i]], cov_matrix, size=N), -4, 6) for i in range(num_arms)])\n",
        "  new_dataset = np.swapaxes(new_dataset, 0, 1)  # -> (row, arm, [ground truth, coarse])\n",
        "\n",
        "  ucb_regs = []\n",
        "  ucbv_regs = []\n",
        "  cuucb_regs = []\n",
        "  for trial in range(num_tries):\n",
        "    alg1 = UCB_weak(range(num_arms), R, x, 1)\n",
        "    ucb_regs.append(np.cumsum(runAlg(alg1, new_dataset, num_arms, T)))\n",
        "\n",
        "    alg2 = UCBV_weak(range(num_arms), R, x, 0.5)\n",
        "    ucbv_regs.append(np.cumsum(runAlg(alg2, new_dataset, num_arms, T)))\n",
        "\n",
        "    alg3 = CUUCB(range(num_arms), R, x, gamma, 0.5)\n",
        "    cuucb_regs.append(np.cumsum(runAlg(alg3, new_dataset, num_arms, T)))\n",
        "\n",
        "  # Average cumulative regret over the independent runs and cache it for re-plotting.\n",
        "  ucb_mean = np.mean(ucb_regs, axis=0)\n",
        "  ucbv_mean = np.mean(ucbv_regs, axis=0)\n",
        "  cuucb_mean = np.mean(cuucb_regs, axis=0)\n",
        "  np.savez('data_corr_'+str(corr)+'.npz', ucb=ucb_mean, ucbv=ucbv_mean, cuucb=cuucb_mean)\n",
        "\n",
        "  fig, ax = plt.subplots(figsize=(7, 5))\n",
        "  ax.plot(ucb_mean, 'lightcoral', linewidth=1.5, label='UCB')\n",
        "  ax.plot(ucbv_mean, 'darkorange', linewidth=1.5, label='UCB-V')\n",
        "  ax.plot(cuucb_mean, 'dodgerblue', linestyle='--', linewidth=1.5, label='CUUCB')\n",
        "  ax.legend()\n",
        "  ax.set_title('$\\\\rho$ = '+str(corr))\n",
        "  ax.set_xlabel('time')\n",
        "  ax.set_ylabel('Regret')\n",
        "  ax.set_ylim(0, 17000)\n",
        "  # nbins caps the number of tick intervals, i.e. at most nbins + 1 ticks per axis.\n",
        "  ax.locator_params(axis='x', nbins=5)\n",
        "  ax.locator_params(axis='y', nbins=4)\n",
        "  fig.savefig('corr_'+str(corr)+'.png', dpi=300, bbox_inches='tight')\n",
        "  plt.close(fig)"
      ],
      "metadata": {
        "id": "1BpG8uZQoXPt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Re-plot the saved averages for one correlation level (here rho = 0.8) with a\n",
        "# tighter y-range; this overwrites the figure file written by the experiment cell.\n",
        "corr = 0.8\n",
        "# np.load returns an NpzFile that keeps the archive open until closed, so read\n",
        "# the arrays inside a with-block.\n",
        "with np.load('data_corr_'+str(corr)+'.npz') as data:\n",
        "    ucb_mean = data['ucb']\n",
        "    ucbv_mean = data['ucbv']\n",
        "    cuucb_mean = data['cuucb']\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(7, 5))\n",
        "ax.plot(ucb_mean, 'lightcoral', linewidth=1.5, label='UCB')\n",
        "ax.plot(ucbv_mean, 'darkorange', linewidth=1.5, label='UCB-V')\n",
        "ax.plot(cuucb_mean, 'dodgerblue', linestyle='--', linewidth=1.5, label='CUUCB')\n",
        "ax.legend()\n",
        "ax.set_title('$\\\\rho$ = '+str(corr))\n",
        "ax.set_xlabel('time')\n",
        "ax.set_ylabel('Regret')\n",
        "ax.set_ylim(0, 15000)\n",
        "# nbins caps the number of tick intervals, i.e. at most nbins + 1 ticks per axis.\n",
        "ax.locator_params(axis='x', nbins=5)\n",
        "ax.locator_params(axis='y', nbins=4)\n",
        "fig.savefig('corr_'+str(corr)+'.png', dpi=300, bbox_inches='tight')\n",
        "plt.close(fig)"
      ],
      "metadata": {
        "id": "GhsColwzZKan"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}