{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import pickle\n",
        "import matplotlib.pyplot as plt\n",
        "import scipy.stats as st\n",
        "\n",
        "class UCB_strong:\n",
        "\n",
        "    def __init__(self,arms,R, x):\n",
        "        self.num_arms=len(arms)\n",
        "        self.means=[0 for arm in arms]\n",
        "        self.R=R\n",
        "        self.pull_counts=[0 for arm in arms]\n",
        "        self.delta=0.01\n",
        "        self.eps = 0.1\n",
        "        self.bounds=[R for arm in arms]\n",
        "\n",
        "    def choose_arm(self):\n",
        "        max_val = np.max(self.bounds)\n",
        "        arm = np.random.choice(np.where(self.bounds == max_val)[0])\n",
        "        return arm\n",
        "\n",
        "    def update_params(self,reward,arm, t):\n",
        "        self.pull_counts[arm]+=1\n",
        "        self.means[arm]=(self.means[arm]*(self.pull_counts[arm]-1)+reward[0])/(self.pull_counts[arm])\n",
        "        for arm1 in range(self.num_arms):\n",
        "          if self.pull_counts[arm1] != 0:\n",
        "            term = np.log(1+t)\n",
        "            self.bounds[arm1] = min(self.means[arm]+self.R*np.sqrt(2*term/self.pull_counts[arm]), self.R)\n",
        "\n",
        "class UCBV_strong:\n",
        "    def __init__(self,arms,R, x):\n",
        "        self.num_arms=len(arms)\n",
        "        self.means=[0 for arm in arms]\n",
        "        self.R=R\n",
        "        self.pull_counts=[0 for arm in arms]\n",
        "        self.sq_rew_avg = [0 for arm in arms]\n",
        "        self.bounds=[R for arm in arms]\n",
        "\n",
        "    def choose_arm(self):\n",
        "        max_val = np.max(self.bounds)\n",
        "        arm = np.random.choice(np.where(self.bounds == max_val)[0])\n",
        "        return arm\n",
        "\n",
        "    def update_params(self,reward,arm, t):\n",
        "        self.pull_counts[arm]+=1\n",
        "        self.means[arm] += (reward[0]-self.means[arm])/(self.pull_counts[arm])\n",
        "        self.sq_rew_avg[arm] += (reward[0]**2-self.sq_rew_avg[arm])/(self.pull_counts[arm])\n",
        "        for arm1 in range(self.num_arms):\n",
        "          if self.pull_counts[arm1] != 0:\n",
        "            term = np.log(1+t)\n",
        "            var = self.sq_rew_avg[arm1]-self.means[arm1]**2\n",
        "            if var < 0:\n",
        "              var = 0\n",
        "            bound = np.sqrt(2*var*term/self.pull_counts[arm1])+3*self.R*term/self.pull_counts[arm1]\n",
        "            self.bounds[arm1] = min(self.means[arm1]+bound, self.R)#, self.means[arm1]+2*self.R*np.sqrt(term/self.pull_counts[arm1]),\n",
        "\n",
        "class UCB_weak:\n",
        "\n",
        "    def __init__(self,arms,R,x):\n",
        "        self.num_arms=len(arms)\n",
        "        self.true_means=[0 for arm in arms]\n",
        "        self.count_gt=[0 for arm in arms]\n",
        "        self.count_c=[0 for arm in arms]\n",
        "        self.R=R\n",
        "        self.x=x\n",
        "        self.bounds=[R for arm in arms]\n",
        "\n",
        "    def choose_arm(self):\n",
        "        max_val = np.max(self.bounds)\n",
        "        arm = np.random.choice(np.where(self.bounds == max_val)[0])\n",
        "        return arm\n",
        "\n",
        "    def update_params(self,reward,arm,t):\n",
        "        if self.count_gt[arm]<=self.x*self.count_c[arm]:\n",
        "            self.count_gt[arm]+=1\n",
        "            delta_g = reward[0]-self.true_means[arm]\n",
        "            self.true_means[arm] += delta_g/self.count_gt[arm]\n",
        "        else:\n",
        "            self.count_c[arm]+=1\n",
        "        for arm in range(self.num_arms):\n",
        "          if self.count_gt[arm] == 0:\n",
        "              self.bounds[arm] = self.R\n",
        "          else:\n",
        "            term = np.log(1+t)\n",
        "            self.bounds[arm] = min(self.true_means[arm]+2*self.R*np.sqrt(term/self.count_gt[arm]), self.R)\n",
        "\n",
        "class UCBV_coarse:\n",
        "\n",
        "    def __init__(self,arms,R,x, b = 3):\n",
        "        self.num_arms=len(arms)\n",
        "        self.sum_rew=[0 for arm in arms]\n",
        "        self.sum_sq_rew = [0 for arm in arms]\n",
        "        self.count=[0 for arm in arms]\n",
        "        self.R=R\n",
        "        self.b = b\n",
        "        self.bounds=[R for arm in arms]\n",
        "\n",
        "    def choose_arm(self):\n",
        "        max_val = np.max(self.bounds)\n",
        "        arm = np.random.choice(np.where(self.bounds == max_val)[0])\n",
        "        return arm\n",
        "\n",
        "    def update_params(self,reward,arm, t):\n",
        "        self.count[arm]+=1\n",
        "        self.sum_rew[arm] += reward[1]\n",
        "        self.sum_sq_rew[arm] += reward[1]**2\n",
        "\n",
        "        for arm1 in range(self.num_arms):\n",
        "          if self.count[arm1] == 0:\n",
        "              self.bounds[arm1] = self.R\n",
        "          else:\n",
        "            term = np.log(1+t)\n",
        "            mean = self.sum_rew[arm1]/self.count[arm1]\n",
        "            var = self.sum_sq_rew[arm1]/self.count[arm1]-mean**2\n",
        "            if var < 0:\n",
        "              var = 0\n",
        "            bound = np.sqrt(2*var*term/self.count[arm1])+self.b*self.R*term/self.count[arm1]\n",
        "            self.bounds[arm1] = min(mean+bound, self.R)\n",
        "\n",
        "class UCBV_weak:\n",
        "\n",
        "    def __init__(self,arms,R,x,b = 3):\n",
        "        self.num_arms=len(arms)\n",
        "        self.sum_rew_gt=[0 for arm in arms]\n",
        "        self.count_gt=[0 for arm in arms]\n",
        "        self.sum_sq_rew_gt = [0 for arm in arms]\n",
        "        self.count_c=[0 for arm in arms]\n",
        "        self.R=R\n",
        "        self.b = b\n",
        "        self.x=x\n",
        "        self.bounds=[R for arm in arms]\n",
        "\n",
        "    def choose_arm(self):\n",
        "        max_val = np.max(self.bounds)\n",
        "        arm = np.random.choice(np.where(self.bounds == max_val)[0])\n",
        "        return arm\n",
        "\n",
        "    def update_params(self,reward,arm, t):\n",
        "        if self.count_gt[arm]<=self.x*self.count_c[arm]:\n",
        "            self.count_gt[arm]+=1\n",
        "            self.sum_rew_gt[arm] += reward[0]\n",
        "            self.sum_sq_rew_gt[arm] += reward[0]**2\n",
        "        else:\n",
        "            self.count_c[arm]+=1\n",
        "\n",
        "        for arm1 in range(self.num_arms):\n",
        "          if self.count_gt[arm1] == 0:\n",
        "              self.bounds[arm1] = self.R\n",
        "          else:\n",
        "            term = np.log(1+t)\n",
        "            mean = self.sum_rew_gt[arm1]/self.count_gt[arm1]\n",
        "            var = self.sum_sq_rew_gt[arm1]/self.count_gt[arm1]-mean**2\n",
        "            if var < 0:\n",
        "              var = 0\n",
        "            bound = np.sqrt(2*var*term/self.count_gt[arm1])+self.b*self.R*term/self.count_gt[arm1]\n",
        "            self.bounds[arm1] = min(mean+bound, self.R)\n",
        "\n",
        "class CUUCB:\n",
        "\n",
        "    def __init__(self,arms,R,x, gamma = 1, b = 3):\n",
        "        self.num_arms=len(arms)\n",
        "        self.sum_rew_gt=[0 for arm in arms]\n",
        "        self.sum_rew_c=[0 for arm in arms]\n",
        "        self.sum_rew_cgt=[0 for arm in arms]\n",
        "        self.count_gt=[0 for arm in arms]\n",
        "        self.count_c=[0 for arm in arms]\n",
        "        self.sum_sq_rew_gt = [0 for arm in arms]\n",
        "        self.sum_sq_rew_c = [0 for arm in arms]\n",
        "        self.sum_sq_rew_cgt = [0 for arm in arms]\n",
        "        self.sum_prod_rew = [0 for arm in arms]\n",
        "        self.R=R\n",
        "        self.x=x\n",
        "        self.b = b\n",
        "        self.gamma = gamma\n",
        "        self.bounds=[R for arm in arms]\n",
        "\n",
        "    def choose_arm(self):\n",
        "        max_val = np.max(self.bounds)\n",
        "        arm = np.random.choice(np.where(self.bounds == max_val)[0])\n",
        "        return arm\n",
        "\n",
        "    def update_params(self,reward,arm, t):\n",
        "        if self.count_gt[arm]<=self.x*self.count_c[arm]:\n",
        "            self.count_gt[arm]+=1\n",
        "            self.sum_rew_gt[arm] += reward[0]\n",
        "            self.sum_sq_rew_gt[arm] += reward[0]**2\n",
        "            self.sum_rew_cgt[arm] += reward[1]\n",
        "            self.sum_sq_rew_cgt[arm] += reward[1]**2\n",
        "            self.sum_prod_rew[arm] += reward[0]*reward[1]\n",
        "        else:\n",
        "            self.count_c[arm]+=1\n",
        "            self.sum_rew_c[arm] += reward[1]\n",
        "            self.sum_sq_rew_c[arm] += reward[1]**2\n",
        "\n",
        "        for arm1 in range(self.num_arms):\n",
        "          if self.count_gt[arm1] == 0 or self.count_c[arm1] == 0:\n",
        "            self.bounds[arm1] = self.R\n",
        "          else:\n",
        "            term = np.log(1+t)\n",
        "            x_t = self.count_gt[arm1]/self.count_c[arm1]\n",
        "            alpha_t = x_t/(1+x_t)\n",
        "            mean_c = self.sum_rew_c[arm1]/self.count_c[arm1]\n",
        "            mean_gt = self.sum_rew_gt[arm1]/self.count_gt[arm1]\n",
        "            mean_cgt = self.sum_rew_cgt[arm1]/self.count_gt[arm1]\n",
        "            var_c = self.sum_sq_rew_c[arm1]/self.count_c[arm1]-mean_c**2\n",
        "            var_gt = self.sum_sq_rew_gt[arm1]/self.count_gt[arm1]-mean_gt**2\n",
        "            var_cgt = self.sum_sq_rew_cgt[arm1]/self.count_gt[arm1]-mean_cgt**2\n",
        "            cov = self.sum_prod_rew[arm1]/self.count_gt[arm1]-mean_gt*mean_cgt\n",
        "            var_c_hat = alpha_t*var_c + (1-alpha_t)*var_cgt\n",
        "            if var_c_hat == 0:\n",
        "              ratio = np.sign(cov)*self.gamma\n",
        "            else:\n",
        "              ratio = cov/var_c_hat\n",
        "            if ratio>=self.gamma:\n",
        "              ratio = self.gamma\n",
        "            if ratio<=-self.gamma:\n",
        "              ratio = -self.gamma\n",
        "            lam = ratio/(1+x_t)\n",
        "            var = var_gt - 2*lam*cov + (1+x_t)*var_c_hat*lam**2\n",
        "            mean_est = mean_gt - lam*(mean_cgt-mean_c)\n",
        "            if var < 0:\n",
        "              var = 0\n",
        "            bound=np.sqrt(2*var*term/self.count_gt[arm1]) + self.b*self.R*term/self.count_gt[arm1]\n",
        "            self.bounds[arm1] = min(mean_est+bound, self.R)\n",
        "\n",
        "def sample(arm,dataset):\n",
        "  N = len(dataset)\n",
        "  index=np.random.choice(N)\n",
        "  r_gt = dataset[index, arm, 0]\n",
        "  r_coarse = dataset[index, arm, 1]\n",
        "  return [r_gt,r_coarse]\n",
        "\n",
        "\n",
        "def runAlg(alg,new_dataset,num_arms,T):\n",
        "\n",
        "    regret=[]\n",
        "    true_means=[np.mean(new_dataset[:,i,0]) for i in range(num_arms)]\n",
        "    synthetic_means=[np.mean(new_dataset[:,i,1]) for i in range(num_arms)]\n",
        "    best_mean = np.max(true_means)\n",
        "    #  run algorithm\n",
        "    for t in range(T):\n",
        "        arm=alg.choose_arm()\n",
        "        regret.append(best_mean-true_means[arm])\n",
        "        alg.update_params(sample(arm,new_dataset),arm, t)\n",
        "\n",
        "    return regret\n",
        "\n",
        "\n",
        "\n",
        "# Set global Matplotlib parameters for ICML formatting\n",
        "plt.rcParams.update({\n",
        "    \"axes.labelsize\": 10,                # Label font size\n",
        "    \"font.size\": 10,                     # General font size\n",
        "    \"legend.fontsize\": 8,                # Legend font size\n",
        "    \"xtick.labelsize\": 8,                # X-axis tick font size\n",
        "    \"ytick.labelsize\": 8,                # Y-axis tick font size\n",
        "    \"figure.dpi\": 300,                   # Ensure high-resolution figures\n",
        "    \"savefig.dpi\": 300,                  # Save high-resolution figures\n",
        "    # \"savefig.format\": \"pdf\",             # Save in PDF format (ICML requirement)\n",
        "    \"figure.figsize\": (8, 3),       # ICML-compatible figure size (in inches)\n",
        "    \"lines.linewidth\": 1.0,              # Line width\n",
        "    \"lines.markersize\": 4,               # Marker size\n",
        "})"
      ],
      "metadata": {
        "id": "L_YdrioC7GRd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "with open('gemini2.5vs1.5_LLMrankings.pickle', 'rb') as handle:\n",
        "  dataset= pickle.load(handle)\n",
        "dataset = (6-dataset)/5\n",
        "print(len(dataset))\n",
        "new_dataset = dataset"
      ],
      "metadata": {
        "id": "WIrBDEaVae8M"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "with open('llama3-1bvs3b_LLMdataset.pickle', 'rb') as handle:\n",
        "  dataset= pickle.load(handle)\n",
        "dataset = (7-dataset)/6\n",
        "print(len(dataset))\n",
        "new_dataset = dataset"
      ],
      "metadata": {
        "id": "BHBaNgmbjFnX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(new_dataset.mean(axis=0))"
      ],
      "metadata": {
        "id": "5UgCtfw5UFbm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for i in range(6):\n",
        "  print(np.cov(new_dataset[:,i,0], new_dataset[:,i,1]))"
      ],
      "metadata": {
        "id": "vvSfMimwT82I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for i in range(6):\n",
        "  print(np.corrcoef(new_dataset[:,i,0], new_dataset[:,i,1]))"
      ],
      "metadata": {
        "id": "ss20EhoT2yXk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#LLM benchmarking\n",
        "num_arms=6\n",
        "x = 1/9\n",
        "gamma = 1.5\n",
        "num_tries=40\n",
        "T= 50000\n",
        "R= 1\n",
        "\n",
        "for x in [1/9, 1/3, 1]:\n",
        "  ucbv_regs=[]\n",
        "  cuucb_regs=[]\n",
        "  for t in range(num_tries):\n",
        "\n",
        "      alg2 = UCBV_weak(range(num_arms),R, x, 0.5)\n",
        "      alg_regret = runAlg(alg2,new_dataset,num_arms,T)\n",
        "      ucbv_regs.append(np.array(np.cumsum(alg_regret)))\n",
        "\n",
        "      alg3=CUUCB(range(num_arms),R,x, gamma, 0.5)\n",
        "      alg_regret = runAlg(alg3,new_dataset,num_arms,T)\n",
        "      cuucb_regs.append(np.array(np.cumsum(alg_regret)))\n",
        "\n",
        "  ucbv_regs = np.array(ucbv_regs)\n",
        "  cuucb_regs = np.array(cuucb_regs)\n",
        "  ucbv_mean = np.mean(ucbv_regs,axis=0)\n",
        "  cuucb_mean = np.mean(cuucb_regs,axis=0)\n",
        "  np.savez('LLM_benchmarking_alpha'+str(round(x/(1+x), 2))+'.npz', ucbv = ucbv_mean, cuucb=cuucb_mean)\n",
        "  plt.figure(figsize=(7, 5))\n",
        "  plt.plot(np.mean(ucbv_regs,axis=0),'darkorange',linewidth=1.5, label='UCB-V')\n",
        "  plt.plot(np.mean(cuucb_regs,axis=0),'dodgerblue', linestyle = '--', linewidth=1.5, label='CUUCB')\n",
        "  plt.legend()\n",
        "  plt.xlabel('time')\n",
        "  plt.ylabel('Regret')\n",
        "  # To set the maximum number of ticks on the X-axis to 4 (nbins=3)\n",
        "  plt.locator_params(axis='x', nbins=5)\n",
        "  # To set the maximum number of ticks on the Y-axis to 5 (nbins=4)\n",
        "  plt.locator_params(axis='y', nbins=4)\n",
        "  filename = 'LLM_benchmarking_alpha' + str(round(x/(1+x), 2)) + '.png'\n",
        "  plt.savefig(filename, dpi=300, bbox_inches='tight')\n",
        "  plt.close()"
      ],
      "metadata": {
        "id": "gMAPO7YanrC8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#Misalignment\n",
        "num_arms=6\n",
        "x = 1/9\n",
        "gamma = 1.5\n",
        "num_tries=40\n",
        "T= 120000\n",
        "R= 1\n",
        "\n",
        "for x in [3/17]:\n",
        "  # ucbv_regs=[]\n",
        "  cuucb_regs=[]\n",
        "  for t in range(num_tries):\n",
        "    # run UCBV_coarse as well to obtain linear regret due to mis-alignment\n",
        "    alg3=CUUCB(range(num_arms),R,x, gamma, 0.5)\n",
        "    alg_regret = runAlg(alg3,new_dataset,num_arms,T)\n",
        "    cuucb_regs.append(np.array(np.cumsum(alg_regret)))\n",
        "\n",
        "  # ucbv_regs = np.array(ucbv_regs)\n",
        "  cuucb_regs = np.array(cuucb_regs)\n",
        "  # ucbv_mean = np.mean(ucbv_regs,axis=0)\n",
        "  cuucb_mean = np.mean(cuucb_regs,axis=0)\n",
        "  np.savez('cuucb_alpha_'+str(round(x/(1+x),2))+'.npz', cuucb=cuucb_mean)\n",
        "  plt.figure(figsize=(7, 5))\n",
        "  plt.plot(np.mean(ucbv_regs,axis=0),'darkorange',linewidth=1.5, label='UCB-V')\n",
        "  plt.plot(np.mean(cuucb_regs,axis=0),'dodgerblue', linestyle = '--', linewidth=1.5, label='CUUCB')\n",
        "  plt.legend()\n",
        "  plt.xlabel('time')\n",
        "  plt.ylabel('Regret')\n",
        "  # To set the maximum number of ticks on the X-axis to 4 (nbins=3)\n",
        "  plt.locator_params(axis='x', nbins=5)\n",
        "  # To set the maximum number of ticks on the Y-axis to 5 (nbins=4)\n",
        "  plt.locator_params(axis='y', nbins=4)\n",
        "  filename = 'misalignment.png'\n",
        "  plt.savefig(filename, dpi=300, bbox_inches='tight')\n",
        "  plt.close()"
      ],
      "metadata": {
        "id": "OxSgXHpxZHQM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "data = np.load('misalignment.npz')\n",
        "data_0 = np.load('cuucb_alpha_0.05.npz')['cuucb']\n",
        "data_2 = np.load('cuucb_alpha_0.15.npz')['cuucb']\n",
        "plt.figure(figsize=(7, 5))\n",
        "plt.plot(data['ucbv'][:100000],'darkorange',linewidth=1.5, label='UCB-V')\n",
        "\n",
        "plt.plot(data_0[:100000],'lightcoral', linewidth=1.5, linestyle = '--', label='CUUCB ($\\\\alpha$ = 0.05)')\n",
        "plt.plot(data['cuucb'][:100000],'dodgerblue', linestyle = '--', linewidth=1.5, label='CUUCB ($\\\\alpha$ = 0.1)')\n",
        "plt.plot(data_2[:100000],'purple', linewidth=1.5,linestyle = '--', label='CUUCB ($\\\\alpha$ = 0.15)')\n",
        "plt.legend()\n",
        "plt.xlabel('time')\n",
        "plt.ylabel('Regret')\n",
        "# To set the maximum number of ticks on the X-axis to 4 (nbins=3)\n",
        "plt.locator_params(axis='x', nbins=5)\n",
        "# To set the maximum number of ticks on the Y-axis to 5 (nbins=4)\n",
        "plt.locator_params(axis='y', nbins=4)\n",
        "filename = 'misalignment.png'\n",
        "plt.savefig(filename, dpi=300, bbox_inches='tight')\n",
        "plt.close()"
      ],
      "metadata": {
        "id": "cNOoB13Wx9pr"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}