{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ask for Help Only When It's Available: Training a Deep Ensemble Model\n",
    "\n",
    "- Train a deep ensemble on MNIST, framed as a regression task, to learn a reward distribution that k-of-n will later sample from.\n",
    "\n",
    "- Convert each MNIST label into a reward vector with one entry per action $a \\in \\{0, \\dots, 9\\}$, plus a reward for asking for help:\n",
    "\n",
    "    - $R(a) = \\frac{label+1}{10}$ if $a$ is the correct label (reward)\n",
    "    \n",
    "    - $R(a) = -\\frac{a}{9 \\cdot 10}$ if $a$ is a wrong label (risk)\n",
    "    \n",
    "    - $R(help) = 0.05$ if help is available, otherwise $R(help) = -0.1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 128
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 26422,
     "status": "ok",
     "timestamp": 1596084097174,
     "user": {
      "displayName": "Montaser Fathelrhman Hussen Mohammedala",
      "photoUrl": "",
      "userId": "10501124642310264932"
     },
     "user_tz": 360
    },
    "id": "t2egOflEdhcs",
    "outputId": "ec1f00a0-e940-4405-af8e-80740c51d79d"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from IPython import display\n",
    "from torchvision import transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Gh17GqWZdhc8"
   },
   "source": [
    "## datasets and pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 232,
     "referenced_widgets": [
      "b92fb5fc578a4e1d8a089194b816fd70",
      "eba51d6745224b078e42d9ada9abc430",
      "b2bcb67e6be2458a8733c357438499b7",
      "46f887125eb3459088b9b6e3298d4e65",
      "9be8fab804f64d38af5923b716a00c2f",
      "1c8e407422fd418c8ae6ba29aa8040d2",
      "e42491b54d494b79b82626173868e59d",
      "75734eadca3d4cceb0ff63292740a073"
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 81759,
     "status": "ok",
     "timestamp": 1596110097016,
     "user": {
      "displayName": "Montaser Fathelrhman Hussen Mohammedala",
      "photoUrl": "",
      "userId": "10501124642310264932"
     },
     "user_tz": 360
    },
    "id": "jWzGtDtCdhdM",
    "outputId": "8de11aa3-0b79-4936-a28d-a0dc55aef559"
   },
   "outputs": [],
   "source": [
    "# Download (if needed) and load the MNIST training split; ToTensor turns each image into a float tensor.\n",
    "transform = transforms.ToTensor()\n",
    "mnist_train = torchvision.datasets.MNIST(\n",
    "    'datasets',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transform,\n",
    ")\n",
    "# batch_size=1 (the DataLoader default) yields one image per iteration; shuffle=False keeps dataset order.\n",
    "trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=False)\n",
    "\n",
    "def action_to_reward (a, risk=9, pluse_value=1):\n",
    "    '''\n",
    "    Convert a batch of integer labels into per-action reward vectors (one entry per digit action).\n",
    "\n",
    "    For a sample with label y, action j receives\n",
    "      (y + pluse_value) / (9 + pluse_value)    if j == y  (reward for the correct answer)\n",
    "      -(j / risk) / (9 + pluse_value)          otherwise  (risk of a wrong answer)\n",
    "    This is a signed reward vector, not a one-hot encoding.\n",
    "\n",
    "    a: 1-D integer labels, shape (n,); assumed to lie in [0, 9]\n",
    "    risk: divisor applied to the wrong-answer penalty\n",
    "    pluse_value: offset added to the correct-answer reward\n",
    "    returns: float tensor (default torch dtype) of shape (n, 10)\n",
    "    '''\n",
    "    labels = torch.as_tensor(a, dtype=torch.long)\n",
    "    n = labels.shape[0]\n",
    "    # Start every row at the wrong-answer penalty -j / risk for each action j ...\n",
    "    new_y = (-(1 / risk) * torch.arange(10)).repeat(n, 1)\n",
    "    # ... then overwrite each row's correct action with label + pluse_value.\n",
    "    new_y[torch.arange(n), labels] = (labels + pluse_value).to(new_y.dtype)\n",
    "    # Normalize so the largest possible reward (label 9) equals 1.\n",
    "    return new_y / (9 + pluse_value)\n",
    "\n",
    "# Training matrix: every image appears twice, once per help condition.\n",
    "# Columns [0, 784): pixels, [784, 794): per-action rewards, 794: reward for asking for help\n",
    "# (0.05 when help is available, -0.1 when it is not).\n",
    "training_set = np.zeros((2 * len(trainloader), 795))\n",
    "l = 0\n",
    "for img, label in trainloader:\n",
    "    pixels = img.view(-1).numpy()\n",
    "    rewards = action_to_reward(label).numpy()  # computed once, shared by both copies\n",
    "    training_set[l:l + 2, :784] = pixels\n",
    "    training_set[l:l + 2, 784:-1] = rewards\n",
    "    training_set[l, -1] = 0.05      # help available\n",
    "    training_set[l + 1, -1] = -0.1  # help unavailable\n",
    "    l += 2\n",
    "\n",
    "# Per-output loss weights: action i is weighted 1/(i+1)^2, the help output keeps weight 1.\n",
    "# Fall back to CPU so this cell also runs on machines without CUDA.\n",
    "ss = torch.ones(11, device=\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "for i in range(10):\n",
    "    ss[i] = 1 / (i + 1) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 661,
     "status": "ok",
     "timestamp": 1596084194065,
     "user": {
      "displayName": "Montaser Fathelrhman Hussen Mohammedala",
      "photoUrl": "",
      "userId": "10501124642310264932"
     },
     "user_tz": 360
    },
    "id": "bbZ71ILOdhdq"
   },
   "outputs": [],
   "source": [
    "device = \"cuda:0\"\n",
    "n_esambles = 100\n",
    "n_epochs = 100\n",
    "batch_size = 512\n",
    "l2 = 0\n",
    "model_dir = \"models/Ask for Help Only When it is Available\"\n",
    "# torch.save / np.save do not create missing directories: create it up front so\n",
    "# a long run does not crash at the first save.\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "# Per-output loss weights from the data cell, moved to the training device.\n",
    "loss_weights = ss.to(device)\n",
    "training_loss_retrained = np.zeros((n_esambles, n_epochs))\n",
    "times = []\n",
    "for m in range(n_esambles):\n",
    "    t = time.perf_counter()\n",
    "    # Seed each member differently: runs are reproducible, members still differ.\n",
    "    torch.manual_seed(m)\n",
    "    np.random.seed(m)\n",
    "\n",
    "    model = nn.Sequential(\n",
    "        nn.Conv2d(1, 64, (4, 4)),\n",
    "        nn.MaxPool2d((2, 2)),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Conv2d(64, 16, (4, 4)),\n",
    "        nn.MaxPool2d((2, 2)),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Flatten(),  # model[:7] ends here: 16 x 4 x 4 = 256 features for a 28x28 input\n",
    "        # Dense head (model[7:]): 256 conv features + 1 bit, 0 if help is available, 1 if not\n",
    "        nn.Linear(256+1, 50),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Linear(50, 15),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Linear(15, 11),  # 10 digit actions + 1 help action\n",
    "    ).to(device)\n",
    "\n",
    "    opt = torch.optim.Adam(params=model.parameters(), lr=1.6e-3, weight_decay=l2)\n",
    "    # reduction='none' returns per-element squared errors so they can be re-weighted;\n",
    "    # it replaces the deprecated size_average=False, reduce=False.\n",
    "    loss_fun = nn.MSELoss(reduction='none')\n",
    "\n",
    "    for ep in range(n_epochs):\n",
    "        ep_loss, n_batches = 0, 0\n",
    "        np.random.shuffle(training_set)  # shuffles rows in place\n",
    "        for batch in range(0, training_set.shape[0], batch_size):\n",
    "            n_batches += 1\n",
    "            rows = training_set[batch : batch + batch_size]\n",
    "            x = torch.tensor(rows[:, :784], device=device, dtype=torch.float).view(-1, 1, 28, 28)\n",
    "            y = torch.tensor(rows[:, 784:], device=device, dtype=torch.float)\n",
    "\n",
    "            # Head input: conv features + help bit, set to 1.0 where the help reward is\n",
    "            # negative (-0.1, help unavailable). A device-side mask avoids an exact float\n",
    "            # comparison and a GPU->CPU copy every batch.\n",
    "            new_x = torch.zeros((x.shape[0], 257), device=device, dtype=torch.float)\n",
    "            new_x[:, :-1] = model[:7](x)\n",
    "            new_x[y[:, -1] < 0, -1] = 1.0\n",
    "\n",
    "            # Re-weighted MSE, averaged over batch and outputs\n",
    "            loss = torch.mean(loss_fun(model[7:](new_x), y) * loss_weights)\n",
    "            ep_loss += loss.item()\n",
    "\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        print(\"model: %i , [EPOCH]: %i, [train LOSS]: %.6f\" % (m, ep+1, ep_loss / n_batches))\n",
    "        display.clear_output(wait=True)\n",
    "        training_loss_retrained[m, ep] = ep_loss / n_batches\n",
    "\n",
    "    torch.save(model, os.path.join(model_dir, \"ensemble_model_{}\".format(m)))\n",
    "    times.append(time.perf_counter() - t)\n",
    "np.save(os.path.join(model_dir, \"training_loss\"), training_loss_retrained)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the last trained ensemble member on the training set (no held-out data here).\n",
    "# acc: % of rows whose greedy action equals the true label.\n",
    "# sec: % of rows whose greedy action is a digit exactly one away from the true label.\n",
    "y_pre = np.zeros(training_set.shape[0])\n",
    "# argmax over the 11 reward columns: the correct-label reward (>= 0.1) beats the help\n",
    "# reward (<= 0.05) and the non-positive wrong-answer penalties, so this is the label.\n",
    "y_i = np.argmax(training_set[:, 784:], axis=1)\n",
    "with torch.no_grad():  # inference only; do not build autograd graphs\n",
    "    for b in range(0, training_set.shape[0], batch_size):\n",
    "        rows = training_set[b : b + batch_size]\n",
    "        x_i = torch.tensor(rows[:, :784], device=device, dtype=torch.float)\n",
    "        new_x_i = torch.zeros((x_i.shape[0], 257), device=device, dtype=torch.float)\n",
    "        new_x_i[:, :-1] = model[:7](x_i.view(x_i.shape[0], 1, 28, 28))\n",
    "        # help-unavailable bit (help reward -0.1), same encoding as in training\n",
    "        new_x_i[torch.as_tensor(rows[:, -1] < 0, device=device), -1] = 1.0\n",
    "        y_pre[b : b + batch_size] = np.argmax(model[7:](new_x_i).cpu().numpy(), axis=1)\n",
    "\n",
    "acc = int(np.sum(y_pre == y_i))\n",
    "# Only digit predictions (< 10) can be neighbours: index 10 is the help action,\n",
    "# not a digit adjacent to label 9.\n",
    "sec = int(np.sum((y_pre < 10) & (np.abs(y_pre - y_i) == 1)))\n",
    "print('Accuracy', 100*acc/training_set.shape[0], 100*sec/training_set.shape[0])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "k-of-n-bit.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "1c8e407422fd418c8ae6ba29aa8040d2": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "46f887125eb3459088b9b6e3298d4e65": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_75734eadca3d4cceb0ff63292740a073",
      "placeholder": "​",
      "style": "IPY_MODEL_e42491b54d494b79b82626173868e59d",
      "value": " 561758208/? [00:43&lt;00:00, 20475028.56it/s]"
     }
    },
    "75734eadca3d4cceb0ff63292740a073": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9be8fab804f64d38af5923b716a00c2f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "b2bcb67e6be2458a8733c357438499b7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_1c8e407422fd418c8ae6ba29aa8040d2",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_9be8fab804f64d38af5923b716a00c2f",
      "value": 1
     }
    },
    "b92fb5fc578a4e1d8a089194b816fd70": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_b2bcb67e6be2458a8733c357438499b7",
       "IPY_MODEL_46f887125eb3459088b9b6e3298d4e65"
      ],
      "layout": "IPY_MODEL_eba51d6745224b078e42d9ada9abc430"
     }
    },
    "e42491b54d494b79b82626173868e59d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "eba51d6745224b078e42d9ada9abc430": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
