{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def comp_reg(reg_array, avg, ub, lb, label_array, name, per=5, save=False,\n",
    "             errorevery=10000, show_band=False):\n",
    "    \"\"\"Plot per-arm reward bounds next to the regret curve of each algorithm.\n",
    "\n",
    "    Left panel: mean reward of every arm (red dots) with its upper (green)\n",
    "    and lower (blue) bound joined by a vertical line.\n",
    "    Right panel: for each entry of `reg_array`, the mean regret over axis 0\n",
    "    with +/- one standard deviation error bars every `errorevery` steps.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    reg_array : sequence of array-like\n",
    "        One entry per algorithm (any number). Statistics are taken over\n",
    "        axis 0 (presumably runs x time -- TODO confirm against get_files).\n",
    "    avg, ub, lb : np.ndarray\n",
    "        Per-arm mean reward, upper bound and lower bound.\n",
    "    label_array : sequence of str\n",
    "        Legend label for each entry of `reg_array`.\n",
    "    name : str\n",
    "        File name (without extension) used when `save` is True.\n",
    "    per : float\n",
    "        Percentile of the optional band, which spans the smaller and larger\n",
    "        of per and 100 - per (so per=5 and per=95 give the same band).\n",
    "    save : bool\n",
    "        If True, save the figure as ./Results/{name}.png, creating the\n",
    "        directory if needed.\n",
    "    errorevery : int\n",
    "        Draw an error bar every `errorevery` time steps.\n",
    "    show_band : bool\n",
    "        If True, also shade the percentile band of each algorithm.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    plt.rcParams[\"font.family\"] = \"sans-serif\"\n",
    "    plt.figure(figsize=(12, 4))\n",
    "\n",
    "    # Left panel: mean reward and [lb, ub] interval of every arm.\n",
    "    plt.subplot(1, 2, 1)\n",
    "    arms = np.arange(1, 1 + avg.size, 1)\n",
    "    plt.plot(arms, avg, 'ro', label='Mean')\n",
    "    plt.plot(arms, ub, 'gs', label='Upper bound')\n",
    "    plt.plot(arms, lb, 'bs', label='Lower bound')\n",
    "    plt.vlines(arms, ymin=ub, ymax=lb, color='k')\n",
    "    plt.xticks(arms)\n",
    "    plt.xlabel('Arms', fontsize=14)\n",
    "    plt.ylabel('Reward', fontsize=14)\n",
    "    plt.legend(fontsize=15)\n",
    "    plt.grid()\n",
    "\n",
    "    # Right panel: one regret curve per entry of reg_array; colours cycle\n",
    "    # when there are more than four algorithms.\n",
    "    col_mean = ['lightcoral', 'limegreen', 'tan', 'royalblue']\n",
    "    col_bet = ['mistyrose', 'palegreen', 'blanchedalmond', 'lavender']\n",
    "    plt.subplot(1, 2, 2)\n",
    "    for i, runs in enumerate(reg_array):\n",
    "        runs = np.asarray(runs)\n",
    "        reg_mean = np.mean(runs, axis=0)\n",
    "        reg_std = np.std(runs, axis=0)\n",
    "        steps = np.arange(len(reg_std))\n",
    "        if show_band:\n",
    "            # smaller / larger percentile, e.g. 5th and 95th for per=5\n",
    "            reg_lower = np.percentile(runs, min(per, 100 - per), axis=0)\n",
    "            reg_upper = np.percentile(runs, max(per, 100 - per), axis=0)\n",
    "            plt.fill_between(steps, reg_lower, reg_upper,\n",
    "                             color=col_bet[i % len(col_bet)], alpha=0.5)\n",
    "        plt.errorbar(steps, reg_mean, yerr=reg_std, errorevery=errorevery,\n",
    "                     capsize=10, fmt=col_mean[i % len(col_mean)],\n",
    "                     linewidth=3.0, label=label_array[i])\n",
    "    plt.xlabel('Time', fontsize=14)\n",
    "    plt.ylabel('Regret', fontsize=14)\n",
    "    plt.legend(loc='lower right', fontsize=15)\n",
    "    plt.grid()\n",
    "\n",
    "    if save:\n",
    "        # savefig fails if the output directory does not exist yet\n",
    "        os.makedirs('./Results', exist_ok=True)\n",
    "        name_png = './Results/' + str(name) + '.png'\n",
    "        plt.savefig(name_png, bbox_inches='tight', pad_inches=0)\n",
    "\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_files(path, numfiles):\n",
    "    \"\"\"Load and stack the regret results stored in {path}_1.p ... {path}_{numfiles}.p.\n",
    "\n",
    "    Each file is a pickled object indexed by integers: entries 1-4 hold the\n",
    "    regret data of the algorithms named in `label` (same order) and entries\n",
    "    6, 7, 8 hold the per-arm mean, upper bound and lower bound.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    arr : list of np.ndarray\n",
    "        For each algorithm, its regret data from all files stacked with np.vstack.\n",
    "    label : list of str\n",
    "        Algorithm names, aligned with `arr`.\n",
    "    avg, ub, lb\n",
    "        Entries 6, 7, 8 of the first file.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If numfiles < 1 (there would be nothing to return).\n",
    "\n",
    "    Note: pickle.load can execute arbitrary code -- only load trusted files.\n",
    "    \"\"\"\n",
    "    if numfiles < 1:\n",
    "        raise ValueError(f'numfiles must be at least 1, got {numfiles}')\n",
    "    label = ['B-UCB','ImprovedUCB','B-KL-UCB','GLUE']\n",
    "    algo_keys = range(1, len(label) + 1)\n",
    "    chunks = {k: [] for k in algo_keys}\n",
    "    for i in range(1, numfiles + 1):\n",
    "        with open(path+'_'+str(i)+'.p', 'rb') as fp:\n",
    "            reg = pickle.load(fp)\n",
    "        for k in algo_keys:\n",
    "            chunks[k].append(reg[k])\n",
    "        if i == 1:\n",
    "            avg, ub, lb = reg[6], reg[7], reg[8]\n",
    "    # Stack once per algorithm instead of re-stacking after every file.\n",
    "    arr = [np.vstack(chunks[k]) for k in algo_keys]\n",
    "    return arr, label, avg, ub, lb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefix of the result files: get_files reads {path}_1.p ... {path}_{numfiles}.p\n",
    "path = './NEWRUN/regfile'\n",
    "numfiles = 10\n",
    "reg_array, label, avg, ub, lb = get_files(path=path, numfiles=numfiles)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Regret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Figure name without extension; with save=True comp_reg writes ./Results/{save_name}.png\n",
    "# (pass save=False to only display the figure).\n",
    "save_name = 'Academic_long'\n",
    "comp_reg(reg_array, avg, ub, lb, label_array=label, name=save_name, save=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "def get_ub_lb_new(means, dist):\n",
    "    \"\"\"Per-arm upper and lower bounds computed from context-wise mean rewards.\n",
    "\n",
    "    means[z, k] is the mean reward of arm k in context z and dist[z] the\n",
    "    weight of context z (used as np.average weights). With\n",
    "      mu*_z   = max_k means[z, k], k*_z its argmax (first index on ties),\n",
    "      mu*     = dist-weighted average of mu*_z,\n",
    "      p(k)    = total weight of the contexts with k*_z = k,\n",
    "      mu(k)   = sum of dist[z] * mu*_z over those contexts,\n",
    "      del_min = smallest best-minus-second-best gap over contexts,\n",
    "      del_max = largest best-minus-worst gap over contexts,\n",
    "    the bounds are\n",
    "      ub[k] = mu* - del_min * (1 - p(k))\n",
    "      lb[k] = mu(k) + sum of (mu(j) - del_max * p(j))\n",
    "              over j != k with mu(j) > del_max * p(j).\n",
    "\n",
    "    Returns (ub, lb), two arrays of length means.shape[1].\n",
    "    Raises ValueError unless means is 2-D with >= 1 context and >= 2 arms.\n",
    "    \"\"\"\n",
    "    means = np.asarray(means, dtype=float)\n",
    "    dist = np.asarray(dist, dtype=float)\n",
    "    if means.ndim != 2 or means.shape[0] < 1 or means.shape[1] < 2:\n",
    "        raise ValueError(f'means must be 2-D with >= 1 context and >= 2 arms, got shape {means.shape}')\n",
    "    n_arms = means.shape[1]\n",
    "\n",
    "    max_rew_vec = means.max(axis=1)   # mu*_z\n",
    "    opt_act = means.argmax(axis=1)    # k*_z\n",
    "    sorted_means = np.sort(means, axis=1)\n",
    "    # exact extrema over contexts (no fixed cap on del_min)\n",
    "    del_min = np.min(sorted_means[:, -1] - sorted_means[:, -2])\n",
    "    del_max = np.max(sorted_means[:, -1] - sorted_means[:, 0])\n",
    "\n",
    "    max_rew = np.average(max_rew_vec, weights=dist)  # mu*\n",
    "    # group the contexts by their best arm\n",
    "    rew_k = np.bincount(opt_act, weights=max_rew_vec * dist, minlength=n_arms)  # mu(k)\n",
    "    prob = np.bincount(opt_act, weights=dist, minlength=n_arms)  # p(k)\n",
    "\n",
    "    ub = max_rew - del_min * (1 - prob)\n",
    "\n",
    "    surplus = rew_k - del_max * prob\n",
    "    lb = np.empty(n_arms)\n",
    "    for k in range(n_arms):\n",
    "        others = (surplus > 0) & (np.arange(n_arms) != k)\n",
    "        lb[k] = rew_k[k] + surplus[others].sum()\n",
    "    return ub, lb\n",
    "\n",
    "\n",
    "def filter_data(data, n_arms=20, seed=10):\n",
    "    \"\"\"Subsample arms and attach per-arm bounds for every entry of `data`.\n",
    "\n",
    "    data[key] is read as [rew_mat, weights, ...]: rew_mat is a 2-D array\n",
    "    indexed [context, arm] and weights holds one weight per context\n",
    "    (presumably per-context mean ratings and context frequencies -- TODO\n",
    "    confirm against the pickle producer). For each key:\n",
    "\n",
    "    1. Normalise weights into a distribution dist.\n",
    "    2. Keep n_arms randomly chosen arms (all arms if there are no more than\n",
    "       n_arms), redrawing until the arm with the highest weighted mean is\n",
    "       kept. A private RandomState(seed) drives the draw, so results are\n",
    "       reproducible and NumPy's global generator is left untouched.\n",
    "    3. In each context, lower the runner-up by 0.001 until it no longer\n",
    "       ties the best arm, so every context has a unique best arm.\n",
    "    4. Compute per-arm bounds with get_ub_lb_new and the weighted means.\n",
    "\n",
    "    Returns a new dict {key: [rew_mat, mean, dist, ub, lb]}; `data` itself\n",
    "    is not modified.\n",
    "    \"\"\"\n",
    "    filtered = {}\n",
    "    for key, entry in data.items():\n",
    "        weights = np.asarray(entry[1], dtype=float)\n",
    "        dist = (1 / sum(weights)) * weights\n",
    "        # float copy: the in-place tie-break below must neither truncate the\n",
    "        # 0.001 step (integer input) nor touch the caller's array\n",
    "        rew_mat = np.array(entry[0], dtype=float)\n",
    "        n_total = rew_mat.shape[1]\n",
    "\n",
    "        mean = np.array([np.average(rew_mat[:, k], weights=dist) for k in range(n_total)])\n",
    "        best_arm = np.argsort(mean)[-1]\n",
    "\n",
    "        rng = np.random.RandomState(seed)\n",
    "        ids = []\n",
    "        while best_arm not in ids:\n",
    "            ids = rng.choice(np.arange(n_total), size=min(n_arms, n_total), replace=False)\n",
    "        # keep the chosen columns in their original order\n",
    "        rew_mat = rew_mat[:, np.sort(ids)]\n",
    "\n",
    "        for row in rew_mat:  # row is a view, so edits land in rew_mat\n",
    "            order = np.argsort(row)[::-1]\n",
    "            while row[order[0]] == row[order[1]]:\n",
    "                row[order[1]] -= 0.001\n",
    "                order = np.argsort(row)[::-1]\n",
    "\n",
    "        ub, lb = get_ub_lb_new(rew_mat, dist)\n",
    "        mean = np.array([np.average(rew_mat[:, k], weights=dist) for k in range(rew_mat.shape[1])])\n",
    "        filtered[key] = [rew_mat, mean, dist, ub, lb]\n",
    "    return filtered\n",
    "\n",
    "# Load the Movielens data for one hidden-feature setting and filter it.\n",
    "# Only load trusted pickles: unpickling can execute arbitrary code.\n",
    "feature = 3\n",
    "with open('./Files/Movielens_data_'+str(feature)+'_hidden.p','rb') as f:\n",
    "    data_raw = pickle.load(f)\n",
    "data = filter_data(data_raw)\n",
    "\n",
    "# Pick one user group and unpack its filtered entry.\n",
    "key = 'student'\n",
    "true_means, avg, dist, ub, lb = data[key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weighted mean reward of each retained arm for the selected key\n",
    "avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weighted per-arm mean recomputed from the reward matrix (should match avg)\n",
    "t_avg = np.average(true_means, axis=0, weights=dist)\n",
    "t_avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest per-arm lower bound\n",
    "lmax = max(lb)\n",
    "lmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arms whose upper bound lies strictly below lmax, the largest lower bound\n",
    "pruned_ids = [arm for arm, upper in enumerate(ub) if upper < lmax]\n",
    "pruned_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Highest weighted mean reward among the retained arms\n",
    "max_rew = max(avg)\n",
    "max_rew"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-arm upper bounds computed by get_ub_lb_new\n",
    "ub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arms whose upper bound reaches the best mean reward max_rew\n",
    "k2 = [arm for arm, upper in enumerate(ub) if upper >= max_rew]\n",
    "k2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import brentq\n",
    "\n",
    "def kl(p, q):\n",
    "    \"\"\"Bernoulli KL divergence KL(Ber(p) || Ber(q)); finite for 0 < p, q < 1.\"\"\"\n",
    "    return p*np.log(p/q) + (1-p)*np.log((1-p)/(1-q))\n",
    "\n",
    "def func(x, m):\n",
    "    \"\"\"Stationarity condition psi'(x) - (2/x)*psi(x) of x -> 2*psi(x)/x**2,\n",
    "    where psi is the centred log-MGF of a Bernoulli(m) variable.\n",
    "\n",
    "    Its positive root maximises that ratio; brentq(func, a, b, args=m) can\n",
    "    be used to cross-check get_sg for 0 < m < 0.5 away from 0.5.\n",
    "    \"\"\"\n",
    "    fx = m*np.exp(x*(1-m)) + (1-m)*np.exp(-x*m)\n",
    "    gx = m*(1-m)*(np.exp((1-m)*x) - np.exp(-m*x))/fx - (2/x)*np.log(fx)\n",
    "    return gx\n",
    "\n",
    "def get_sg(m):\n",
    "    \"\"\"Optimal sub-Gaussian variance proxy of a Bernoulli(m) variable.\n",
    "\n",
    "    This is the maximum over x of 2*psi(x)/x**2, attained at the root of\n",
    "    `func`, and equals (1 - 2m) / (2*log((1 - m)/m)) with limit 1/4 at\n",
    "    m = 0.5. Unlike a brentq root search of `func` on [0.0005, 100], the\n",
    "    closed form is also defined at and near m = 0.5, where `func` does not\n",
    "    change sign on that bracket.\n",
    "\n",
    "    Raises ValueError if m is outside [0, 1]; returns 0.0 for m in {0, 1}\n",
    "    (a constant variable).\n",
    "    \"\"\"\n",
    "    if not 0 <= m <= 1:\n",
    "        raise ValueError(f'm must lie in [0, 1], got {m}')\n",
    "    if m == 0 or m == 1:\n",
    "        return 0.0\n",
    "    if m == 0.5:\n",
    "        return 0.25\n",
    "    # log((1-m)/m) written as log1p((1-2m)/m) for accuracy near m = 0.5\n",
    "    return (1 - 2*m) / (2*np.log1p((1 - 2*m)/m))\n",
    "\n",
    "# Asymptotic regret constants over the suboptimal arms, each of the form\n",
    "# c / gap with gap = max_rew - avg[k]; the printed lengths are the number\n",
    "# of suboptimal arms. 0.25 is presumably the sub-Gaussian proxy of a reward\n",
    "# bounded in [0, 1] -- TODO confirm rewards are scaled to [0, 1].\n",
    "sub_means = [a for a in avg if a != max_rew]\n",
    "gaps = [max_rew - a for a in sub_means]\n",
    "asym_bucb = [2*0.25/gap for gap in gaps]\n",
    "print(len(asym_bucb))\n",
    "asym_improveducb = [2*0.25/gap for gap in gaps]  # same constant as B-UCB\n",
    "print(len(asym_improveducb))\n",
    "asym_bkl = [gap/kl(a, max_rew) for a, gap in zip(sub_means, gaps)]\n",
    "print(len(asym_bkl))\n",
    "# GLUE: Bernoulli sub-Gaussian proxy at the largest lower bound lmax\n",
    "psi = get_sg(lmax)\n",
    "asym_glue = [2*psi/gap for gap in gaps]\n",
    "print(len(asym_glue))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total asymptotic regret constant per algorithm\n",
    "totals = {'BUCB': asym_bucb, 'ImprovedUCB': asym_improveducb, 'BKL': asym_bkl, 'GLUE': asym_glue}\n",
    "for algo, consts in totals.items():\n",
    "    print(f'{algo}: {sum(consts)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
