{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_mean_std(curves1, curves2, ylabel, errorevery=500):\n",
    "    \"\"\"Draw mean +/- std errorbar curves for OFUL and R-OFUL on the current axes.\n",
    "\n",
    "    curves1, curves2 -- arrays averaged over axis 0 (presumably shape (runs, rounds);\n",
    "                        confirm), for OFUL and R-OFUL respectively.\n",
    "    ylabel           -- y-axis label.\n",
    "    errorevery       -- draw an error bar only every `errorevery` rounds so the bars stay legible.\n",
    "    \"\"\"\n",
    "    plt.rcParams['font.family'] = 'sans-serif'\n",
    "    for curves, label in ((curves1, 'OFUL'), (curves2, 'R-OFUL')):\n",
    "        mean = np.mean(curves, axis=0)\n",
    "        std = np.std(curves, axis=0)\n",
    "        plt.errorbar(np.arange(len(mean)), mean, yerr=std, errorevery=errorevery,\n",
    "                     capsize=10, linewidth=2.5, label=label)\n",
    "    plt.grid()\n",
    "    plt.legend(fontsize=15)\n",
    "    plt.xticks(fontsize=12)\n",
    "    plt.yticks(fontsize=12)\n",
    "    plt.xlabel('Rounds', fontsize=15)\n",
    "    plt.ylabel(ylabel, fontsize=15)\n",
    "\n",
    "def plot_single_regret(path, savename):\n",
    "    \"\"\"Plot per-arm rewards and cumulative regret for the runs stored in `path`.\n",
    "\n",
    "    Left panel: per-arm mean reward (as returned by get_files) with its upper and\n",
    "    lower bounds. Right panel: mean +/- std cumulative regret of OFUL and R-OFUL.\n",
    "    The figure is saved as `savename + 'REGRET.png'` and then shown.\n",
    "    \"\"\"\n",
    "    r1, r2, _, _, avg, ub, lb = get_files(path, True)\n",
    "    plt.figure(figsize=(14, 6))\n",
    "\n",
    "    plt.subplot(1, 2, 1)\n",
    "    arms = np.arange(1, 1 + len(avg))\n",
    "    plt.plot(arms, avg, 'ro', label='Mean')\n",
    "    plt.plot(arms, ub, 'gs', label='Upper bound')\n",
    "    plt.plot(arms, lb, 'bs', label='Lower bound')\n",
    "    plt.vlines(arms, ymin=ub, ymax=lb, color='k')\n",
    "    plt.xticks(arms)\n",
    "    plt.xlabel('Arms', fontsize=14)\n",
    "    plt.ylabel('Reward', fontsize=14)\n",
    "    plt.legend()\n",
    "    plt.grid()\n",
    "\n",
    "    plt.subplot(1, 2, 2)\n",
    "    # NOTE(review): row 0 of the stacked regret arrays is excluded here but the timing\n",
    "    # plot uses every row -- confirm this is intentional.\n",
    "    _plot_mean_std(r1[1:, :], r2[1:, :], 'Cumulative Regret')\n",
    "    plt.savefig(savename + 'REGRET.png', bbox_inches='tight', pad_inches=0.05)\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "def plot_single_time(path, savename):\n",
    "    \"\"\"Plot cumulative compute time of OFUL and R-OFUL for the runs stored in `path`.\n",
    "\n",
    "    Prints the mean total time of each algorithm, saves the figure as\n",
    "    `savename + 'TIME.png'` and then shows it.\n",
    "    \"\"\"\n",
    "    _, _, t1, t2, _, _, _ = get_files(path, False)\n",
    "    # Running totals along axis 1 (the round axis) turn per-round times into cumulative time.\n",
    "    time_traj1 = np.cumsum(t1, axis=1)\n",
    "    time_traj2 = np.cumsum(t2, axis=1)\n",
    "    for traj, label in ((time_traj1, 'OFUL'), (time_traj2, 'R-OFUL')):\n",
    "        mean_traj = np.mean(traj, axis=0)\n",
    "        print(f'Mean time to finish {len(mean_traj)} iterations {label}: {mean_traj[-1]}')\n",
    "    # Fresh figure so this plot never lands on a figure left open by another cell.\n",
    "    plt.figure()\n",
    "    _plot_mean_std(time_traj1, time_traj2, 'Cumulative Time to compute in seconds')\n",
    "    plt.savefig(savename + 'TIME.png', bbox_inches='tight', pad_inches=0.05)\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "def get_files(path, verbose):\n",
    "    \"\"\"Load every pickled run in `path` and stack the per-run results.\n",
    "\n",
    "    Each regular file in `path` must be a pickle of a dict with the keys 'reg1', 'reg2',\n",
    "    'time_l1', 'time_l2', 'theta_z', 'theta_u', 'arms_z', 'arms_u', 'ub' and 'lb'.\n",
    "    Files are read in sorted name order so the row order of the stacked arrays is\n",
    "    reproducible (os.listdir order is arbitrary).\n",
    "\n",
    "    WARNING: pickle.load can execute arbitrary code -- only point this at trusted files.\n",
    "\n",
    "    Returns (r1, r2, t1, t2, avg, ub, lb):\n",
    "      r1, r2 -- 'reg1' / 'reg2' of all files stacked row-wise, as float,\n",
    "      t1, t2 -- 'time_l1' / 'time_l2' reshaped to match r1 / r2 and stacked the same way,\n",
    "      avg    -- dot(theta_u.T, arms_u[i]) for each arm i,\n",
    "      ub, lb -- the stored bounds, flattened to one value per arm.\n",
    "    The arm quantities come from the last file read.\n",
    "\n",
    "    Raises FileNotFoundError if `path` contains no files.\n",
    "    \"\"\"\n",
    "    files = sorted(\n",
    "        os.path.join(path, name) for name in os.listdir(path)\n",
    "        if os.path.isfile(os.path.join(path, name))\n",
    "    )\n",
    "    if not files:\n",
    "        raise FileNotFoundError(f'No result files found in {path!r}')\n",
    "\n",
    "    # Collect per-file arrays and stack once at the end (repeated vstack is quadratic).\n",
    "    reg1_parts, reg2_parts, time1_parts, time2_parts = [], [], [], []\n",
    "    for file_name in files:\n",
    "        with open(file_name, 'rb') as f:\n",
    "            reg_dict = pickle.load(f)\n",
    "        reg1 = reg_dict['reg1'].astype(float)\n",
    "        reg2 = reg_dict['reg2'].astype(float)\n",
    "        reg1_parts.append(reg1)\n",
    "        reg2_parts.append(reg2)\n",
    "        time1_parts.append(np.asarray(reg_dict['time_l1']).reshape(reg1.shape))\n",
    "        time2_parts.append(np.asarray(reg_dict['time_l2']).reshape(reg2.shape))\n",
    "    r1, r2 = np.vstack(reg1_parts), np.vstack(reg2_parts)\n",
    "    t1, t2 = np.vstack(time1_parts), np.vstack(time2_parts)\n",
    "\n",
    "    # NOTE(review): arm data is taken from the last file only -- assumes every run\n",
    "    # shares the same arms and parameters; confirm.\n",
    "    theta_u, arms_z, arms_u = reg_dict['theta_u'], reg_dict['arms_z'], reg_dict['arms_u']\n",
    "    n_arms = len(arms_z)\n",
    "    if verbose:\n",
    "        print(len(arms_z), len(arms_u), reg_dict['theta_z'].shape, theta_u.shape)\n",
    "    # NOTE(review): arms_u is indexed over len(arms_z); presumably both have the same length.\n",
    "    avg = np.asarray([np.dot(theta_u.T, arms_u[i]) for i in range(n_arms)]).reshape(n_arms)\n",
    "    ub = np.asarray(reg_dict['ub']).reshape(n_arms)\n",
    "    lb = np.asarray(reg_dict['lb']).reshape(n_arms)\n",
    "    if verbose:\n",
    "        print('Total number of arms:', n_arms)\n",
    "    return r1, r2, t1, t2, avg, ub, lb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './Files/LinBanditRun4'\n",
    "# Prefix for the saved figure files (this run uses 10 arms).\n",
    "save_name = 'Setting2'\n",
    "plot_single_regret(path, save_name)\n",
    "plot_single_time(path, save_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
