{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_instance(key, avg, ub, lb, save, show=False):\n",
    "    \"\"\"Plot each arm's mean reward together with its [lb, ub] interval.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    key : hashable\n",
    "        Instance identifier, used only in the output file name.\n",
    "    avg, ub, lb : array-like, same length\n",
    "        Per-arm mean, upper bound and lower bound.\n",
    "    save : bool\n",
    "        If True, write the figure to ``./Results/Instance_<key>.png``,\n",
    "        creating the ``./Results`` directory if it does not exist.\n",
    "    show : bool, optional\n",
    "        If True, display the figure before closing it. The default (False)\n",
    "        keeps the original behaviour: the figure is closed without being shown.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    avg = np.asarray(avg)\n",
    "    arms = np.arange(1, 1 + avg.size, 1)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 5))\n",
    "    ax.plot(arms, avg, 'ro', label='Mean')\n",
    "    ax.plot(arms, ub, 'gs', label='Upper bound')\n",
    "    ax.plot(arms, lb, 'bs', label='Lower bound')\n",
    "    # One vertical segment per arm spanning its [lb, ub] interval.\n",
    "    ax.vlines(arms, ymin=ub, ymax=lb, color='k')\n",
    "    ax.set_xticks(arms)\n",
    "    ax.set_xlabel('Arms', fontsize=14)\n",
    "    ax.set_ylabel('Reward', fontsize=14)\n",
    "    ax.legend()\n",
    "    if save:\n",
    "        os.makedirs('./Results', exist_ok=True)\n",
    "        fig.savefig('./Results/Instance_' + str(key) + '.png')\n",
    "    if show:\n",
    "        plt.show()\n",
    "    # Close this figure explicitly so repeated calls do not pile up open figures.\n",
    "    plt.close(fig)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_data(data, visualize, save, n_arms=15, seed=200):\n",
    "    \"\"\"Subsample every bandit instance to a fixed set of arms and attach statistics.\n",
    "\n",
    "    Each ``data[key]`` is read as ``[reward_matrix, weights]`` (matrix rows are\n",
    "    weighted by ``weights``, matrix columns are arms) and is REPLACED in place\n",
    "    by ``[rew_mat, mean, dist, ub, lb]``:\n",
    "\n",
    "    * ``rew_mat`` -- float copy of the matrix restricted to ``n_arms`` columns\n",
    "      drawn without replacement with ``np.random.RandomState(seed)`` (kept in\n",
    "      their original order); in each row, a tie for the largest entry is\n",
    "      broken by repeatedly lowering the runner-up by 0.001;\n",
    "    * ``mean`` -- per-arm reward averaged over rows with weights ``dist``;\n",
    "    * ``dist`` -- ``weights`` normalised to sum to 1;\n",
    "    * ``ub``, ``lb`` -- per-arm bounds from ``get_ub_lb_new(rew_mat, dist)``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : dict\n",
    "        Mapping key -> [reward_matrix, weights]. Mutated in place.\n",
    "    visualize : bool\n",
    "        If True, plot the mean and bound interval of every arm for each\n",
    "        instance, then print the largest lower bound and the gap between the\n",
    "        two largest means.\n",
    "    save : bool\n",
    "        If True (and ``visualize``), save each plot to\n",
    "        ``./Results/Instance_<key>.png``, creating ``./Results`` if needed.\n",
    "    n_arms : int, optional\n",
    "        Number of arms (columns) kept per instance. Default 15.\n",
    "    seed : int, optional\n",
    "        Seed of the arm subsample; every instance uses the same draw.\n",
    "        Default 200.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        The same ``data`` object.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If an instance has fewer than ``n_arms`` arms.\n",
    "    \"\"\"\n",
    "    for key in data.keys():\n",
    "        rew_mat = np.asarray(data[key][0], dtype=float)\n",
    "        weights = np.asarray(data[key][1], dtype=float)\n",
    "        dist = weights / weights.sum()\n",
    "\n",
    "        n_total = rew_mat.shape[1]\n",
    "        if n_total < n_arms:\n",
    "            raise ValueError('Instance %r has %d arms, fewer than n_arms=%d'\n",
    "                             % (key, n_total, n_arms))\n",
    "        # A private RandomState yields the same draw as np.random.seed(seed)\n",
    "        # + np.random.choice, without resetting NumPy's global RNG afterwards.\n",
    "        rng = np.random.RandomState(seed)\n",
    "        kept = np.sort(rng.choice(np.arange(n_total), size=n_arms, replace=False))\n",
    "        rew_mat = rew_mat[:, kept]  # fancy indexing copies, so the input is untouched\n",
    "\n",
    "        # Make the largest entry of every row unique so that get_ub_lb_new\n",
    "        # sees a strictly positive best-vs-runner-up gap in each row.\n",
    "        for row in rew_mat:  # each row is a view, so edits land in rew_mat\n",
    "            order = np.argsort(row)[::-1]\n",
    "            while row[order[0]] == row[order[1]]:\n",
    "                row[order[1]] -= 0.001\n",
    "                order = np.argsort(row)[::-1]\n",
    "\n",
    "        ub, lb = get_ub_lb_new(rew_mat, dist)\n",
    "        mean = np.average(rew_mat, axis=0, weights=dist)\n",
    "        data[key] = [rew_mat, mean, dist, ub, lb]\n",
    "\n",
    "    if visualize:\n",
    "        if save:\n",
    "            os.makedirs('./Results', exist_ok=True)\n",
    "        for key in data.keys():\n",
    "            print('Key = ',key)\n",
    "            m, u, l = data[key][1], data[key][-2], data[key][-1]\n",
    "            arms = range(1, 1 + len(m))\n",
    "            fig, ax = plt.subplots()\n",
    "            ax.plot(arms, m, 'ro', label='Mean')\n",
    "            ax.plot(arms, u, 'gs', label='Upper bound')\n",
    "            ax.plot(arms, l, 'bs', label='Lower bound')\n",
    "            ax.vlines(arms, ymin=u, ymax=l, color='k')\n",
    "            ax.set_xlabel('Arms', fontsize=14)\n",
    "            ax.set_ylabel('Reward', fontsize=14)\n",
    "            ax.legend()\n",
    "            ax.grid()\n",
    "            if save:\n",
    "                fig.savefig('./Results/Instance_'+str(key)+'.png', bbox_inches='tight', pad_inches=0)\n",
    "            plt.show()\n",
    "            print('Largest Lower Bound = ', max(l))\n",
    "            idx = np.argsort(m)[::-1]\n",
    "            print('Minimum Gap = ', m[idx[0]] - m[idx[1]])\n",
    "            plt.close(fig)\n",
    "\n",
    "    return data\n",
    "    \n",
    "    \n",
    "def get_ub_lb_new(means,dist):\n",
    "    max_rew_vec = np.zeros(means.shape[0]) # \\mu^*_{z,u}\n",
    "    opt_act = np.zeros(means.shape[0]) # k^*_{z,u}\n",
    "    rew_k = np.zeros(means.shape[1]) # \\mu_z(k)\n",
    "    prob = np.zeros(means.shape[1]) # p_z(k)\n",
    "    \n",
    "    del_min = 10\n",
    "    del_max = 0\n",
    "    \n",
    "    for i in range(means.shape[0]):\n",
    "        max_rew_vec[i] = max(means[i,:])\n",
    "        opt_act[i] = np.argmax(means[i,:])\n",
    "        temp = np.sort(means[i,:])\n",
    "        del_min = min(del_min, temp[-1] - temp[-2])\n",
    "        del_max = max(del_max, temp[-1] - temp[0])\n",
    "        \n",
    "    max_rew = np.average(max_rew_vec, weights = dist) # \\mu_z\n",
    "    \n",
    "    for i in range(means.shape[1]):\n",
    "        temp1,temp2 = 0,0\n",
    "        for j in range(means.shape[0]):\n",
    "            if opt_act[j] == i:\n",
    "                temp1 += max_rew_vec[j]*dist[j]\n",
    "                temp2 += dist[j]\n",
    "        rew_k[i], prob[i] = temp1, temp2\n",
    "    \n",
    "    ub, lb = np.zeros(means.shape[1]), np.zeros(means.shape[1])\n",
    "    \n",
    "    for i in range(means.shape[1]):\n",
    "        ub[i] = max_rew - del_min*(1-prob[i])\n",
    "        \n",
    "        set_greater = []\n",
    "        for j in range(means.shape[1]):\n",
    "            if j != i and rew_k[j]> del_max*prob[j]: set_greater += [j]\n",
    "        acc = rew_k[i]\n",
    "        for j in set_greater:\n",
    "            acc += rew_k[j] - del_max*prob[j] \n",
    "        lb[i] = acc\n",
    "        \n",
    "    return ub,lb               "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled MovieLens instances for one feature and preprocess them.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.\n",
    "feature = 3\n",
    "pickle_path = './Files/Movielens_data_' + str(feature) + '_hidden.p'\n",
    "with open(pickle_path, 'rb') as f:\n",
    "    data_raw = pickle.load(f)\n",
    "data = filter_data(data_raw, visualize=True, save=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
