{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning to Ask for Help"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1- Loading datasets and converting them to a multi-armed bandit setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the environment of the running kernel.\n",
    "%pip install torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from IPython import display\n",
    "from torchvision import transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and transform data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every image is converted to a torch tensor when it is read from the dataset.\n",
    "transform = transforms.ToTensor()\n",
    "\n",
    "# MNIST training split, stored under the 'datasets' directory.\n",
    "mnist_train = torchvision.datasets.MNIST(\n",
    "    root='datasets',\n",
    "    train=True,\n",
    "    transform=transform,\n",
    "    download=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert it to a multi-armed bandit setting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reward for pulling an arm on a given image is\n",
    "\n",
    "$$\n",
    "r(\\text{arm}) = \\begin{cases}\n",
    "    1, & \\text{if arm = correct label}\\\\\n",
    "    0.25, & \\text{if arm = help arm}\\\\\n",
    "    0, & \\text{otherwise}.\n",
    "\\end{cases}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def action_to_reward(labels, n_arms=11, last_action_value=0.25):\n",
    "    '''\n",
    "    Build the bandit reward vector(s) for one label or for a batch of labels.\n",
    "\n",
    "    Each row holds one reward per arm: the arm whose index equals the label\n",
    "    gets 1, every other class arm gets 0, and the last arm (the help arm)\n",
    "    always gets `last_action_value`.\n",
    "\n",
    "    Args:\n",
    "        labels: a single class index (int or 0-d tensor) or a 1-D sequence /\n",
    "            tensor of class indices.\n",
    "        n_arms: total number of arms (the class arms plus the help arm).\n",
    "        last_action_value: reward assigned to the help arm.\n",
    "\n",
    "    Returns:\n",
    "        Float tensor of shape (number of labels, n_arms); a single label\n",
    "        therefore yields shape (1, n_arms).\n",
    "    '''\n",
    "    # Column vector of labels so that one row is produced per label.\n",
    "    labels = torch.as_tensor(labels).reshape(-1, 1)\n",
    "    rewards = torch.full((labels.shape[0], n_arms), float(last_action_value))\n",
    "    # Broadcasting (B, 1) against (n_arms - 1,) marks the matching class arm in each row.\n",
    "    rewards[:, :-1] = (labels == torch.arange(n_arms - 1)).float()\n",
    "    return rewards\n",
    "# One row per training sample: the 784 flattened pixel values followed by\n",
    "# the 11 per-arm rewards (795 columns in total).\n",
    "mnist_training_set = np.zeros((len(mnist_train), 795))\n",
    "for row, (image, digit) in enumerate(mnist_train):\n",
    "    pixels = image.view(-1).numpy()\n",
    "    arm_rewards = action_to_reward(digit).numpy()\n",
    "    mnist_training_set[row, :784] = pixels\n",
    "    mnist_training_set[row, 784:] = arm_rewards"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2- Train a deep ensemble as a reward distribution to capture the epistemic uncertainty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neural network hyper-parameters\n",
    "\n",
    "# Seed the NumPy and PyTorch generators so that a fresh run is repeatable.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Fall back to the CPU when no CUDA device is available instead of failing at model creation.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "n_models = 100 # number of models in the deep ensemble\n",
    "n_epochs = 100 # number of epochs each model is trained for\n",
    "batch_size = 512 # batch size\n",
    "learning_rate = 1.6e-3 # learning rate\n",
    "l2 = 0.0 # L2 regularization (passed to the optimizer as weight decay)\n",
    "loss_fun = nn.MSELoss() # loss function: mean squared error\n",
    "training_loss = np.zeros((n_models, n_epochs)) # mean training loss of each model at each epoch\n",
    "models = [] # trained models of the deep ensemble\n",
    "models_saving_dir = \"models/Learning to Ask for Help\" # directory where each trained model is saved"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train neural networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_ensemble_member():\n",
    "    '''\n",
    "    Create one freshly initialised CNN of the ensemble.\n",
    "\n",
    "    The network takes image batches of shape (B, 1, 28, 28) and outputs 11\n",
    "    values per image, one predicted reward per arm.\n",
    "    '''\n",
    "    return nn.Sequential(\n",
    "        nn.Conv2d(1, 64, (4, 4)),\n",
    "        nn.MaxPool2d((2, 2)),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Conv2d(64, 16, (4, 4)),\n",
    "        nn.MaxPool2d((2, 2)),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        # For 28x28 inputs the feature map is 16 channels of 4x4, i.e. 256 features.\n",
    "        nn.Flatten(),\n",
    "\n",
    "        nn.Linear(256, 50),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Linear(50, 15),\n",
    "        nn.ReLU(),\n",
    "\n",
    "        nn.Linear(15, 11),\n",
    "    )\n",
    "\n",
    "\n",
    "# Start from an empty list so that re-running this cell does not append a second ensemble.\n",
    "models.clear()\n",
    "# torch.save and np.save do not create missing directories.\n",
    "os.makedirs(models_saving_dir, exist_ok=True)\n",
    "\n",
    "for m in range(n_models):\n",
    "    model = build_ensemble_member().to(device)\n",
    "    opt = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=l2)\n",
    "\n",
    "    for ep in range(n_epochs):\n",
    "        # Reorder the rows in place so that each epoch visits the samples in a new order.\n",
    "        np.random.shuffle(mnist_training_set)\n",
    "        epoch_loss = 0\n",
    "        n_batches = 0\n",
    "        for batch in range(0, mnist_training_set.shape[0], batch_size):\n",
    "            n_batches += 1\n",
    "            rows = mnist_training_set[batch: batch + batch_size]\n",
    "            # First 784 columns are the pixels, the remaining ones the per-arm rewards.\n",
    "            x = torch.tensor(rows[:, :784], device=device, dtype=torch.float)\n",
    "            x = x.view(x.shape[0], 1, 28, 28)\n",
    "            y = torch.tensor(rows[:, 784:], device=device, dtype=torch.float)\n",
    "            loss = loss_fun(model(x), y)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        print(\"[model]: %i , [EPOCH]: %i, [training LOSS]: %.5f\" % (m, ep+1, epoch_loss/n_batches))\n",
    "        # wait=True postpones the clearing until the next output arrives, so the line above stays visible.\n",
    "        display.clear_output(wait=True)\n",
    "        training_loss[m, ep] = epoch_loss/n_batches\n",
    "\n",
    "    # The whole model object is saved, one file per ensemble member.\n",
    "    torch.save(model, \"{}/ensemble_model_{}\".format(models_saving_dir, m))\n",
    "    models.append(model)\n",
    "np.save(\"{}/training_loss\".format(models_saving_dir), training_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3- Construct robust policies with k-of-n game"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load k-of-n functions from k_of_n script\n",
    "from k_of_n import sample_rewards_from_ensemble, sort_and_k_least, run_k_of_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Game mode: \"TEST\" plays k-of-n on the full dataset, \"test\" plays it sample by sample.\n",
    "method = \"TEST\"\n",
    "# k and n values; entry i of `ks` is paired with entry i of `ns`.\n",
    "ks = [10, 5, 1]\n",
    "ns = [10, 10, 10]\n",
    "# Number of iterations of the k-of-n game.\n",
    "n_itr = 10\n",
    "# How many times each k-of-n policy is repeated.\n",
    "n_runs = 1\n",
    "# Batch size for k-of-n. NOTE: this rebinds the `batch_size` name that the training cell also reads.\n",
    "batch_size = 1024\n",
    "# Number of GPUs, -1 for CPU.\n",
    "gpu = 0\n",
    "# Number of actions (arms).\n",
    "n_actions = 11\n",
    "# True to sample reward functions from the reward distribution with replacement.\n",
    "replacment = False\n",
    "# Directory in which the k-of-n policies are saved.\n",
    "output_policies_dir = \"models/actions/{}/Learning to Ask for Help\".format(method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate the k-of-n hyper-parameters before the game is started.\n",
    "if len(ks) != len(ns):\n",
    "    raise ValueError(\"ks's length is not equal ns's length\")\n",
    "\n",
    "# Sampling without replacement is rejected when max(ns) * n_itr exceeds the number of models.\n",
    "if not replacment and max(ns) * n_itr > n_models:\n",
    "    raise ValueError(\"without replacment requreies more models\")\n",
    "\n",
    "# Every (k, n) pair must satisfy k <= n.\n",
    "for k_value, n_value in zip(ks, ns):\n",
    "    if n_value < k_value:\n",
    "        raise ValueError(\"n value={} should be greater than or equal to k value={}\".format(n_value, k_value))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run k-of-n game"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Play the k-of-n game on each dataset in turn: MNIST, then Fashion-MNIST, then E-MNIST.\n",
    "for dataset in (\"MNIST\", \"MNIST-Fashion\", \"E-MNIST\"):\n",
    "    run_k_of_n(ks, ns, n_runs, n_itr, method, n_models, batch_size, models_saving_dir, output_policies_dir, device, dataset, n_actions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4- Show the behavior of the different robust policies"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean and std of probabilities of correct arms and help arm of k-of-n policies for MNIST, Fashion-MNIST and E-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sub-directory of \"actions/\" that holds the saved k-of-n policies.\n",
    "# NOTE(review): `dir_path` is not defined by any earlier cell; this value mirrors the directory named\n",
    "# in the \"load saved policies\" cell further down -- confirm it matches where run_k_of_n writes.\n",
    "dir_path = \"all-images regime/Learning to Ask for Help\"\n",
    "\n",
    "# MNIST test labels used to score the chosen arms.\n",
    "# NOTE(review): `testing_set_y` is not defined by any earlier cell; this assumes the saved MNIST policies\n",
    "# follow the order of the torchvision MNIST test split -- confirm against run_k_of_n.\n",
    "testing_set_y = torchvision.datasets.MNIST(root='datasets', train=False, download=True).targets\n",
    "\n",
    "\n",
    "def load_policies(dataset_tag, n_samples):\n",
    "    '''\n",
    "    Load the saved policies of every run and every (k, n) pair for one dataset.\n",
    "\n",
    "    Args:\n",
    "        dataset_tag: dataset part of the file name (\"mnist\", \"emnist\" or \"fashion\").\n",
    "        n_samples: number of samples stored in each policy file.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (n_runs, len(ks), n_samples, n_actions).\n",
    "    '''\n",
    "    policies = np.zeros((n_runs, len(ks), n_samples, n_actions))\n",
    "    for i in range(n_runs):\n",
    "        for j in range(len(ks)):\n",
    "            policies[i, j] = np.load(\"actions/{}/run_{}_{}_actions_{}-of-{}_n_itr_{}.npy\".format(dir_path, i, dataset_tag, ks[j], ns[j], n_itr))\n",
    "    return policies\n",
    "\n",
    "\n",
    "mnist_arg = load_policies(\"mnist\", 10000)\n",
    "emnist_arg = load_policies(\"emnist\", 20800)\n",
    "fashion_arg = load_policies(\"fashion\", 10000)\n",
    "\n",
    "# Mean probability of the help arm (last action) for each (k, n) policy, averaged over runs and\n",
    "# samples. The values are already scaled to percent.\n",
    "mnist_m = 100 * np.round(mnist_arg[..., -1].mean(axis=(0, 2)), 4)\n",
    "emnist_m = 100 * np.round(emnist_arg[..., -1].mean(axis=(0, 2)), 4)\n",
    "fashion_m = 100 * np.round(fashion_arg[..., -1].mean(axis=(0, 2)), 4)\n",
    "\n",
    "# MNIST accuracy in percent: share of samples whose most probable action equals the label.\n",
    "mnist_correct = np.argmax(mnist_arg, axis=-1) == testing_set_y.numpy()\n",
    "mnist_acc_m = 100 * np.round(mnist_correct.mean(axis=(0, 2)), 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The summary arrays are already scaled to percent where they are computed,\n",
    "# so they are printed as they are (multiplying by 100 again would inflate them).\n",
    "print(\"MNIST accuracy\", mnist_acc_m)\n",
    "print(\"MNIST P(help)\", mnist_m)\n",
    "print(\"Fashion-MNIST P(help)\", fashion_m)\n",
    "print(\"E-MNIST P(help)\", emnist_m)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### heat-map for E-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All options are passed by keyword: a positional argument may not follow keyword arguments.\n",
    "emnist_test = torchvision.datasets.EMNIST(root=\"datasets\", split=\"letters\", train=False, transform=transform, target_transform=None, download=True)\n",
    "\n",
    "# Integer label of every test image, read from the dataset's target tensor\n",
    "# (no target_transform is set, so these equal the labels returned by indexing).\n",
    "emnist_labels = emnist_test.targets.to(torch.int)\n",
    "\n",
    "# Row names for the heat-maps, one per letter.\n",
    "emnist_labels_names = list(\"abcdefghijklmnopqrstuvwxyz\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load saved policies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sub-directory of \"actions/\" that holds the saved k-of-n policies.\n",
    "output_policies_dir = \"all-images regime/Learning to Ask for Help\"\n",
    "\n",
    "# One policy (a vector over the actions per test image) for every run and every (k, n) pair.\n",
    "emnist_policies = np.zeros((n_runs, len(ks), len(emnist_test), n_actions))\n",
    "\n",
    "for i in range(n_runs):\n",
    "    for j in range(len(ks)):\n",
    "        # The files are read from the directory configured at the top of this cell.\n",
    "        emnist_policies[i, j] = np.load(\"actions/{}/run_{}_emnist_actions_{}-of-{}_n_itr_{}.npy\".format(output_policies_dir, i, ks[j], ns[j], n_itr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average policy per letter: for every (k, n) policy and every letter, the mean action\n",
    "# distribution over all test images of that letter and over all runs.\n",
    "map_emnist_m = np.zeros((len(ks), len(emnist_labels_names), n_actions))\n",
    "emnist_label_values = emnist_labels.numpy()\n",
    "\n",
    "for k in range(len(ks)):\n",
    "    for i in range(len(emnist_labels_names)):\n",
    "        # Entry i of emnist_labels_names is matched with label value i + 1.\n",
    "        letter_rows = np.where(emnist_label_values == (i + 1))[0]\n",
    "        map_emnist_m[k, i] = np.mean(np.mean(emnist_policies[:, k, letter_rows], 1), 0)\n",
    "\n",
    "# Tick labels for the actions: one per class arm followed by the help arm.\n",
    "# Derived from n_actions so that the list is also defined when n_actions != 11.\n",
    "labelss = [str(arm) for arm in range(n_actions - 1)] + [\"help\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### plot heat-map figures for each policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One annotated heat-map per (k, n) policy: rows are letters, columns are actions.\n",
    "for i in range(len(ks)):\n",
    "    fig, ax = plt.subplots(figsize=(10, 10))\n",
    "    sns.heatmap(np.round(map_emnist_m[i], 2), annot=True, linewidths=.5, cmap=\"YlGnBu\", vmin=0, vmax=1, ax=ax)\n",
    "    # The title identifies the policy so that each figure can be read on its own.\n",
    "    ax.set_title(\"E-MNIST: {}-of-{} policy\".format(ks[i], ns[i]))\n",
    "    ax.set_ylabel(\"Letter\")\n",
    "    # Ticks are shifted by 0.5 so that they sit in the middle of each heat-map cell.\n",
    "    ax.set_yticks(np.arange(len(emnist_labels_names)) + 0.5)\n",
    "    ax.set_yticklabels(emnist_labels_names, rotation='horizontal')\n",
    "    ax.set_xlabel(\"Actions\")\n",
    "    ax.set_xticks(np.arange(n_actions) + 0.5)\n",
    "    ax.set_xticklabels(labelss)\n",
    "    fig.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
