{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_single_regret(path, savename):\n",
    "    \"\"\"Plot mean cumulative regret with std error bars for both algorithms.\n",
    "\n",
    "    Loads all result files under path via get_files (verbose, so the run\n",
    "    settings are printed), takes the mean and standard deviation over the\n",
    "    stacked rows, saves the figure to savename + 'REGRET.png' and shows it.\n",
    "\n",
    "    Args:\n",
    "        path: Directory containing the pickled result files.\n",
    "        savename: Output path prefix; 'REGRET.png' is appended to it.\n",
    "    \"\"\"\n",
    "    r1, r2, _, _ = get_files(path, True)\n",
    "    # NOTE(review): the first stacked row is dropped here but kept in\n",
    "    # plot_single_time -- presumably a placeholder/warm-up row; confirm.\n",
    "    runs1, runs2 = r1[1:, :], r2[1:, :]\n",
    "    mean_reg1, std_reg1 = np.mean(runs1, axis=0), np.std(runs1, axis=0)\n",
    "    mean_reg2, std_reg2 = np.mean(runs2, axis=0), np.std(runs2, axis=0)\n",
    "\n",
    "    name = savename + 'REGRET.png'\n",
    "    # savefig does not create missing directories, so create the target first.\n",
    "    out_dir = os.path.dirname(name)\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    plt.rcParams['font.family'] = 'sans-serif'\n",
    "    fig, ax = plt.subplots()\n",
    "    # Draw an error bar only every 500th round to keep the curves readable.\n",
    "    ax.errorbar(np.arange(len(std_reg1)), mean_reg1, yerr=std_reg1, errorevery=500, capsize=10, linewidth=2.5, label='OFUL')\n",
    "    ax.errorbar(np.arange(len(std_reg2)), mean_reg2, yerr=std_reg2, errorevery=500, capsize=10, linewidth=2.5, label='R-OFUL')\n",
    "    ax.grid()\n",
    "    ax.legend(fontsize=15)\n",
    "    ax.tick_params(axis='both', labelsize=12)\n",
    "    ax.set_xlabel('Rounds', fontsize=15)\n",
    "    ax.set_ylabel('Cumulative Regret', fontsize=15)\n",
    "    fig.savefig(name, bbox_inches='tight', pad_inches=0.05)\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "def plot_single_time(path, savename):\n",
    "    \"\"\"Plot mean cumulative compute time with std error bars for both algorithms.\n",
    "\n",
    "    Loads all result files under path via get_files (non-verbose), turns the\n",
    "    timings into cumulative sums along axis 1, prints the mean\n",
    "    total time of each algorithm, saves the figure to savename + 'TIME.png'\n",
    "    and shows it.\n",
    "\n",
    "    Args:\n",
    "        path: Directory containing the pickled result files.\n",
    "        savename: Output path prefix; 'TIME.png' is appended to it.\n",
    "    \"\"\"\n",
    "    _, _, t1, t2 = get_files(path, False)\n",
    "    time_traj1 = np.cumsum(t1, axis=1)\n",
    "    time_traj2 = np.cumsum(t2, axis=1)\n",
    "    mean_traj1, std_traj1 = np.mean(time_traj1, axis=0), np.std(time_traj1, axis=0)\n",
    "    mean_traj2, std_traj2 = np.mean(time_traj2, axis=0), np.std(time_traj2, axis=0)\n",
    "\n",
    "    # Report each algorithm under its legend label (both lines used to say\n",
    "    # 'LinUCB') and use the real number of rounds instead of a fixed 5000.\n",
    "    print(f'Mean time to finish {len(mean_traj1)} iterations OFUL: {mean_traj1[-1]}')\n",
    "    print(f'Mean time to finish {len(mean_traj2)} iterations R-OFUL: {mean_traj2[-1]}')\n",
    "\n",
    "    name = savename + 'TIME.png'\n",
    "    # savefig does not create missing directories, so create the target first.\n",
    "    out_dir = os.path.dirname(name)\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    plt.rcParams['font.family'] = 'sans-serif'\n",
    "    fig, ax = plt.subplots()\n",
    "    # Draw an error bar only every 200th round to keep the curves readable.\n",
    "    ax.errorbar(np.arange(len(std_traj1)), mean_traj1, yerr=std_traj1, errorevery=200, capsize=10, linewidth=2.5, label='OFUL')\n",
    "    ax.errorbar(np.arange(len(std_traj2)), mean_traj2, yerr=std_traj2, errorevery=200, capsize=10, linewidth=2.5, label='R-OFUL')\n",
    "    ax.grid()\n",
    "    ax.legend(fontsize=15)\n",
    "    ax.tick_params(axis='both', labelsize=12)\n",
    "    ax.set_xlabel('Rounds', fontsize=15)\n",
    "    ax.set_ylabel('Cumulative Time to compute in seconds', fontsize=15)\n",
    "    fig.savefig(name, bbox_inches='tight', pad_inches=0.05)\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "def get_files(path, verbose):\n",
    "    files = [os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]\n",
    "    for i in range(len(files)):\n",
    "        with open(files[i], 'rb') as f:\n",
    "            reg_dict = pickle.load(f)\n",
    "        if i == 0:\n",
    "            r1 = reg_dict['reg1'].astype(float)\n",
    "            r2 = reg_dict['reg2'].astype(float)\n",
    "            t1 = np.asarray(reg_dict['time_l1']).reshape(r1.shape)\n",
    "            t2 = np.asarray(reg_dict['time_l2']).reshape(r2.shape)\n",
    "        else:\n",
    "            r1 = np.vstack((r1, reg_dict['reg1'].astype(float)))\n",
    "            r2 = np.vstack((r2, reg_dict['reg2'].astype(float)))\n",
    "            temp1 = np.asarray(reg_dict['time_l1']).reshape(reg_dict['reg1'].shape)\n",
    "            t1 = np.vstack((t1, temp1))\n",
    "            temp2 = np.asarray(reg_dict['time_l2']).reshape(reg_dict['reg2'].shape)\n",
    "            t2 = np.vstack((t2, temp2))\n",
    "    if verbose:\n",
    "        print('Total number of arms:', len(reg_dict['arms']))\n",
    "        print('Bound Noise Level:', reg_dict['bound_noise_level'])\n",
    "        print('Reward Noise std:', reg_dict['rew_noise_std'])\n",
    "    return r1, r2, t1, t2\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './Files/LinBanditRun2'\n",
    "# 10 arms per round (as encoded in the save name); figures are saved under ./Results/.\n",
    "save_name = './Results/LinBanditRun_100arms_10armsperround_5bound_8rew'\n",
    "plot_single_regret(path, save_name)\n",
    "plot_single_time(path, save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './Files/LinBanditRun3'\n",
    "# 5 arms per round\n",
    "# NOTE(review): unlike the LinBanditRun2 cell, this save name has no './Results/'\n",
    "# prefix, so the figures are written to the working directory -- confirm intended.\n",
    "save_name = 'LinBanditRun_100arms_5armsperround_5bound_8rew'\n",
    "plot_single_regret(path, save_name)\n",
    "plot_single_time(path, save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './Files/LinBanditRun4'\n",
    "# 15 arms per round\n",
    "# NOTE(review): unlike the LinBanditRun2 cell, this save name has no './Results/'\n",
    "# prefix, so the figures are written to the working directory -- confirm intended.\n",
    "save_name = 'LinBanditRun_100arms_15armsperround_5bound_8rew'\n",
    "plot_single_regret(path,save_name)\n",
    "plot_single_time(path,save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
