{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "13dnqIY7DfmL"
   },
   "source": [
    "# How caution depends on the extent of training data: training a deep ensemble\n",
    "\n",
    "- Train a deep ensemble of regression models on MNIST to build a reward distribution, which the k-of-n procedure will later sample from.\n",
    "\n",
    "- Convert each MNIST label into an 11-entry reward vector (one entry per digit plus a \"help\" arm), where every entry carries Gaussian noise with standard deviation 0.1:\n",
    "\n",
    "    - $R(k) \\sim \\mathcal{N}(1, 0.1^2)$ if arm $k$ is the correct digit\n",
    "    \n",
    "    - $R(k) \\sim \\mathcal{N}(0, 0.1^2)$ if arm $k$ is a wrong digit\n",
    "    \n",
    "    - $R(\\text{help}) \\sim \\mathcal{N}(0.25, 0.1^2)$ for the help arm (last entry of the vector), whatever the label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2689,
     "status": "ok",
     "timestamp": 1596511981556,
     "user": {
      "displayName": "Montaser Fathelrhman Hussen Mohammedala",
      "photoUrl": "",
      "userId": "10501124642310264932"
     },
     "user_tz": 360
    },
    "id": "Q0mjeLjlDfmN",
    "outputId": "fb132c79-5c61-4f3c-cf32-37fbe942ce2d"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from IPython import display\n",
    "from torchvision import transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_esambles = 50  # number of models in the ensemble\n",
    "n_epochs = 1000  # number of epochs to train each model\n",
    "batch_size = 128  # mini-batch size\n",
    "learning_rate = 1.6e-3  # Adam learning rate\n",
    "output_models_dir = \"models/How Caution Depends on the Extent of Training Data/\"  # directory where the trained models are saved\n",
    "# Fall back to the CPU when no CUDA device is available.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Seed PyTorch (weight initialisation, shuffling) and NumPy (reward noise) for reproducibility.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "6XdUVnlVcqzf"
   },
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gBrKLSo6cuAK"
   },
   "source": [
    "## Data pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1882,
     "status": "ok",
     "timestamp": 1596511981557,
     "user": {
      "displayName": "Montaser Fathelrhman Hussen Mohammedala",
      "photoUrl": "",
      "userId": "10501124642310264932"
     },
     "user_tz": 360
    },
    "id": "e9ZYDaeqDfmX"
   },
   "outputs": [],
   "source": [
    "# MNIST training split; ToTensor turns each image into a float tensor scaled to [0, 1].\n",
    "transform = transforms.ToTensor()\n",
    "mnist_train = torchvision.datasets.MNIST(root='datasets', train=True, download=True, transform=transform)\n",
    "# batch_size=1 is the DataLoader default: later cells read one (image, label) pair per step, in file order.\n",
    "trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def action_to_reward(a, n_labels=11):\n",
    "    '''\n",
    "    Convert MNIST labels into noisy reward vectors.\n",
    "\n",
    "    Entry k of a row is the reward for answering k: N(1, 0.1^2) for the true label,\n",
    "    N(0, 0.1^2) for any other digit, and N(0.25, 0.1^2) for the last (help) entry\n",
    "    (assuming labels are MNIST digits, i.e. always < n_labels - 1).\n",
    "    Args:\n",
    "    a: (tensor) has shape (number of samples, ) integer MNIST labels\n",
    "    n_labels: (int) length of the reward vector (10 digits + help arm)\n",
    "    return:\n",
    "    new_y: (tensor) float64, has shape (number of samples, n_labels) reward vector\n",
    "    '''\n",
    "    n = a.shape[0]\n",
    "    # Labels as a column so the comparison broadcasts to (n, n_labels) for any batch size.\n",
    "    one_hot = (a.reshape(-1, 1) == torch.arange(n_labels).reshape(1, n_labels)).double()\n",
    "    # Independent Gaussian noise (std 0.1) for every sample and every action.\n",
    "    noise = torch.from_numpy(np.random.normal(0, scale=0.1, size=(n, n_labels)))\n",
    "    new_y = one_hot + noise\n",
    "    new_y[:, -1] += 0.25  # help arm: constant 0.25 offset\n",
    "    return new_y\n",
    "\n",
    "# Full training set: one row per image = 784 pixels followed by the 11-entry reward vector.\n",
    "training_set = np.zeros((len(trainloader), 784 + 11))\n",
    "for i, (img, label) in enumerate(trainloader):\n",
    "    training_set[i, :784] = img.view(-1).numpy()\n",
    "    training_set[i, 784:] = action_to_reward(label).numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1% training subset (600 images, 60 per digit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples = 600\n",
    "per_class = n_samples // 10  # balanced subset: the same number of images for every digit\n",
    "training_set = np.zeros((n_samples, 795))  # 784 pixels + 11 rewards per row\n",
    "counts = np.zeros(10)  # images taken so far, per digit\n",
    "ll = 0  # next free row of training_set\n",
    "for img, label in trainloader:\n",
    "    # Stop as soon as every digit has reached its quota.\n",
    "    if (counts == per_class).all():\n",
    "        break\n",
    "    digit = label.item()\n",
    "    if counts[digit] < per_class:\n",
    "        counts[digit] += 1\n",
    "        training_set[ll, :784] = img.view(-1).numpy()\n",
    "        training_set[ll, 784:] = action_to_reward(label).numpy()\n",
    "        ll += 1\n",
    "\n",
    "# If n_samples is not a multiple of 10, drop the unused (all-zero) rows.\n",
    "training_set = training_set[:ll]\n",
    "assert (counts == per_class).all(), \"not enough images for some digit\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "McBI9nyIDfmn"
   },
   "source": [
    "## Train Ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_loss = np.zeros((n_esambles, n_epochs))  # mean training loss per (model, epoch)\n",
    "times = []  # wall-clock training time of each model, in seconds\n",
    "loss_fun = nn.MSELoss()  # regress the whole reward vector\n",
    "\n",
    "# torch.save below fails if the output directory does not exist yet.\n",
    "os.makedirs(output_models_dir, exist_ok=True)\n",
    "\n",
    "# Move the training data to the device once rather than once per mini-batch.\n",
    "# Columns 0..783 hold the flattened 28x28 image, columns 784.. the reward vector.\n",
    "train_data = torch.tensor(training_set, device=device, dtype=torch.float)\n",
    "n_train = train_data.shape[0]\n",
    "\n",
    "\n",
    "def build_model():\n",
    "    '''Return a freshly initialised CNN mapping (N, 1, 28, 28) images to (N, 11) predicted rewards.'''\n",
    "    return nn.Sequential(\n",
    "        nn.Conv2d(1, 64, (4, 4)),\n",
    "        nn.MaxPool2d((2, 2)),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Conv2d(64, 16, (4, 4)),\n",
    "        nn.MaxPool2d((2, 2)),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Flatten(),  # 16 channels x 4 x 4 = 256 features\n",
    "\n",
    "        nn.Linear(256, 50),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Linear(50, 15),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Linear(15, 11),\n",
    "    ).to(device)\n",
    "\n",
    "\n",
    "for m in range(n_esambles):\n",
    "    t = time.perf_counter()  # reset timer for each model\n",
    "    model = build_model()\n",
    "    opt = torch.optim.Adam(params=model.parameters(), lr=learning_rate)\n",
    "\n",
    "    for ep in range(n_epochs):\n",
    "        ep_loss, n_batches = 0, 0\n",
    "        perm = torch.randperm(n_train, device=device)  # new sample order every epoch\n",
    "        for batch in range(0, n_train, batch_size):\n",
    "            n_batches += 1\n",
    "            idx = perm[batch : batch + batch_size]\n",
    "            x = train_data[idx, :784].reshape(-1, 1, 28, 28)\n",
    "            y = train_data[idx, 784:]\n",
    "\n",
    "            loss = loss_fun(model(x), y)\n",
    "            ep_loss += loss.item()\n",
    "\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        print(\"model: %i ,[EPOCH]: %i, [training loss]: %.6f\" % (m, ep+1, ep_loss / n_batches))\n",
    "        display.clear_output(wait=True)\n",
    "        training_loss[m, ep] = ep_loss / n_batches\n",
    "\n",
    "    torch.save(model, output_models_dir + \"ensemble_model_{}\".format(m))  # save each model by model number\n",
    "    times.append(time.perf_counter() - t)\n",
    "np.save(output_models_dir + \"training_loss\", training_loss)  # save training loss for all models in the ensemble"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "F_pXMoh4c1EV"
   },
   "source": [
    "## Calculate training time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 746,
     "status": "ok",
     "timestamp": 1596515428771,
     "user": {
      "displayName": "Montaser Fathelrhman Hussen Mohammedala",
      "photoUrl": "",
      "userId": "10501124642310264932"
     },
     "user_tz": 360
    },
    "id": "hRS_d22DNr21",
    "outputId": "2bf24f18-bb2b-41a8-943b-1f465a4ffd2b"
   },
   "outputs": [],
   "source": [
    "avg_minutes = np.round(np.mean(times) / 60, 3)  # times holds perf_counter deltas in seconds\n",
    "print(\"average training time for a single model is {} min\".format(avg_minutes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training-set accuracy of the last trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 945,
     "status": "ok",
     "timestamp": 1596518041151,
     "user": {
      "displayName": "Montaser Fathelrhman Hussen Mohammedala",
      "photoUrl": "",
      "userId": "10501124642310264932"
     },
     "user_tz": 360
    },
    "id": "wF5YXhxJo-Br",
    "outputId": "ac7a24ee-4251-43fb-b168-468118ce5ac4"
   },
   "outputs": [],
   "source": [
    "# Training-set accuracy of the last trained model: the predicted best action vs. the argmax of the\n",
    "# stored (noisy) reward vector, which is normally the true digit.\n",
    "y_i = np.argmax(training_set[:, 784:], axis=1)\n",
    "y_pre = np.zeros(training_set.shape[0])\n",
    "with torch.no_grad():  # inference only: no autograd graph needed\n",
    "    for b in range(0, training_set.shape[0], batch_size):\n",
    "        x_i = torch.tensor(training_set[b : b + batch_size, :784], device=device, dtype=torch.float)\n",
    "        y_pre[b : b + batch_size] = np.argmax(model(x_i.view(x_i.shape[0], 1, 28, 28)).cpu().numpy(), axis=1)\n",
    "\n",
    "acc = int(np.sum(y_pre == y_i))  # number of correct predictions\n",
    "print(\"model accuracy: {}\".format(np.round(100*acc/training_set.shape[0], 3)))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "risk_reward_ensamble_mnist_regression.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
