{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment **4.2**: Discovering Non-Obvious Cautious Actions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torchvision, torch\n",
    "from torchvision import transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import style\n",
    "import scipy\n",
    "import colorsys\n",
    "import seaborn as sns\n",
    "transform = transforms.ToTensor()\n",
    "## plot style and font\n",
    "style.use('seaborn-colorblind')\n",
    "transform = transforms.ToTensor()\n",
    "font_params = {}\n",
    "font_params['family'] = 'serif'\n",
    "plt.rc('font', **font_params)\n",
    "plt.rc('pdf', fonttype=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_confidence_interval(data, confidence=0.95):\n",
    "    n = len(data)\n",
    "    m, se = np.mean(data), scipy.stats.sem(data)\n",
    "    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)\n",
    "    return h\n",
    "\n",
    "dark_hues = [0, 0.1, 0.4, 0.55, 0.65, 0.75, 0.9]\n",
    "dark_lightness = [0.4, 0.2, 0.15, 0.4, 0.4, 0.4, 0.4]\n",
    "light_hues = [1, 0.1, 0.3, 0.5, 0.6, 0.75, 0.85]\n",
    "light_lightness = [0.75, 0.6, 0.45, 0.6, 0.65, 0.8, 0.55]\n",
    "dark_colors = [colorsys.hls_to_rgb(h, dark_lightness[i], 1) for i, h in enumerate(dark_hues)]\n",
    "sns.palplot(dark_colors)\n",
    "light_colors = [colorsys.hls_to_rgb(h, light_lightness[i], 1) for i, h in enumerate(light_hues)]\n",
    "sns.palplot(light_colors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "### Load test datasets (MNIST, Fashion-MNIST, and E-MNIST letters) and their labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_test = torchvision.datasets.MNIST(root='datasets', train=False, download=True, transform=transform)\n",
    "\n",
    "fashion_test = torchvision.datasets.FashionMNIST(root='datasets', train=False, download=True, transform=transform)\n",
    "\n",
    "emnist_test  = torchvision.datasets.EMNIST(root=\"datasets\", train=False, transform=transform, target_transform=None, download=True, split=\"letters\")\n",
    "\n",
    "mnist_labels = torch.zeros((len(mnist_test)))\n",
    "fashion_labels = torch.zeros((len(fashion_test)))\n",
    "emnist_labels = torch.zeros((len(emnist_test)))\n",
    "        \n",
    "for i in range (len(mnist_test)):\n",
    "    mnist_labels[i] = mnist_test[i][1]\n",
    "\n",
    "    fashion_labels[i] = fashion_test[i][1]\n",
    "\n",
    "    emnist_labels[i] = emnist_test[i][1]\n",
    "\n",
    "for j in range (i, len(emnist_test)):\n",
    "    emnist_labels[j] = emnist_test[j][1]\n",
    "\n",
    "fashion_labels_names = [\"T-shirt/top\",  \"Trouser\", \"Pullover\" , \"Dress\" , \"Coat\" , \"Sandal\" , \"Shirt\", \"Sneaker\" ,\"Bag\" ,\"Ankle boot\"]\n",
    "emnist_labels_names = [\"A/a\", \"B/b\", \"C/c\", \"D/d\", \"E/e\", \"F/f\", \"G/g\", \"H/h\", \"I/i\", \"J/j\", \"K/k\", \"L/l\", \"M/m\", \"N/n\", \"O/o\", \"P/p\", \"Q/q\", \"R/r\", \"S/s\", \"T/t\", \"U/u\", \"V/v\", \"W/w\", \"X/x\", \"Y/y\", \"Z/z\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "# Plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Load k-of-n actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models_dir = \"models/Discovering Non-Obvious Cautious Actions\" # directory where each model's papamters will be saved\n",
    "models_dir_split = models_dir.split('/')\n",
    "fig_dir = 'fig/' + models_dir_split[1]\n",
    "if not os.path.exists(fig_dir):\n",
    "    os.makedirs(fig_dir)\n",
    "n_models = 2000\n",
    "n_actions = 10\n",
    "n_runs = 10\n",
    "n_itr = 100\n",
    "fontsize = 25\n",
    "    \n",
    "ks = [1, 1, 5, 10]\n",
    "ns = [20, 10, 10, 10]\n",
    "\n",
    "test_k_of_n_mnist = np.zeros((len(ks), n_runs, n_actions))\n",
    "test_k_of_n_fashion = np.zeros((len(ks), n_runs, n_actions))\n",
    "test_k_of_n_emnist = np.zeros((len(ks), n_runs,  n_actions))\n",
    "\n",
    "TEST_k_of_n_mnist = np.zeros((len(ks), n_runs, n_actions))\n",
    "TEST_k_of_n_fashion = np.zeros((len(ks), n_runs, n_actions))\n",
    "TEST_k_of_n_emnist = np.zeros((len(ks), n_runs,  n_actions))\n",
    "\n",
    "for run in range (n_runs):\n",
    "    for k in range (len(ks)):\n",
    "        TEST_k_of_n_mnist[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_mnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k])), axis=0)\n",
    "        TEST_k_of_n_fashion[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_fashion_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k])), axis=0)\n",
    "        TEST_k_of_n_emnist[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_emnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k])), axis=0)\n",
    "        \n",
    "        test_k_of_n_mnist[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_mnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k])), axis=0)\n",
    "        test_k_of_n_fashion[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_fashion_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k])), axis=0)\n",
    "        test_k_of_n_emnist[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_emnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k])), axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Load standard-RL (greedy) baseline actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_dir = 'actions/all-images regime/'+models_dir_split[1]\n",
    "greedy_mnist = np.mean(np.load('{}/hard_max_mnist_actions.npy'.format(baseline_dir)), axis=1) # take each model as a policy and average over polcies predictions\n",
    "greedy_fashion = np.mean(np.load('{}/hard_max_fashion_actions.npy'.format(baseline_dir)), axis=1) # take each model as a policy and average over polcies predictions\n",
    "greedy_emnist = np.mean(np.load('{}/hard_max_emnist_actions.npy'.format(baseline_dir)), axis=1) # take each model as a policy and average over polcies predictions\n",
    "\n",
    "error_greedy_mnist = mean_confidence_interval(np.sum(greedy_mnist*np.arange(n_actions), axis=1))\n",
    "error_greedy_fashion = mean_confidence_interval(np.sum(greedy_fashion*np.arange(n_actions), axis=1))\n",
    "error_greedy_emnist = mean_confidence_interval(np.sum(greedy_emnist*np.arange(n_actions), axis=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Fashion and E-MNIST: expected action index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.05\n",
    "space_bars = 0.002\n",
    "space_between_datasets = 0.45\n",
    "scatter_x = np.arange(-5, n_runs-5) / 210\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy'+r\"$(\\hat{r})$\"]\n",
    "\n",
    "xx = np.arange(1)\n",
    "fig = plt.figure(figsize=(20,5))\n",
    "plt.subplot(1, 2, 1)\n",
    "space = 0\n",
    "for i in range (len(ns)):\n",
    "    plt.bar(xx + space, sum(np.mean(TEST_k_of_n_fashion[i], 0)*np.arange(n_actions)), width=w, alpha=0.8, label=\"{}-of-{}\".format(ks[i], ns[i]), color=light_colors[i])\n",
    "    plt.bar(xx + space + space_between_datasets, sum(np.mean(TEST_k_of_n_emnist[i], 0)*np.arange(n_actions)), width=w, alpha=0.8, color=light_colors[i])\n",
    "    points_fashion, points_emnist = [], []\n",
    "    for run in range (n_runs):\n",
    "        points_fashion.append(sum(TEST_k_of_n_fashion[i, run]*np.arange(n_actions)))\n",
    "        points_emnist.append(sum(TEST_k_of_n_emnist[i, run]*np.arange(n_actions)))\n",
    "    plt.scatter(xx + space + scatter_x, np.sort(points_fashion), s=55, marker='*', color=dark_colors[i])\n",
    "    plt.scatter(xx + space + scatter_x + space_between_datasets, np.sort(points_emnist), s=55, marker='*', color=dark_colors[i])\n",
    "\n",
    "    space = space + w + space_bars\n",
    "plt.bar(xx + space, sum(np.arange(n_actions)*np.mean(greedy_fashion, axis=0)), width=w, yerr=error_greedy_fashion, linewidth=10, capsize=20, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[i+1])\n",
    "plt.bar(xx + space + space_between_datasets, sum(np.arange(n_actions)*np.mean(greedy_emnist, axis=0)), yerr=error_greedy_emnist, linewidth=10, capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    " \n",
    "plt.text(0.095, -1.3, 'f', fontsize = fontsize+5)\n",
    "plt.text(0.54, -1.3, r'$\\alpha$', fontsize = fontsize+5)\n",
    "\n",
    "plt.xticks([], fontsize=0)\n",
    "plt.yticks(np.arange(0, 10, 3),fontsize = fontsize)\n",
    "plt.ylabel(\"action index\", fontsize = fontsize+2)\n",
    "plt.ylim(0, 9)\n",
    "plt.grid(axis=\"y\")  \n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "space = 0\n",
    "bars = []\n",
    "for i in range (len(ns)):\n",
    "    bars.append(plt.bar(xx + space, sum(np.mean(test_k_of_n_fashion[i], 0)*np.arange(n_actions)), width=w, alpha=0.8, label=\"{}-of-{}\".format(ks[i], ns[i]), color=light_colors[i]))\n",
    "    plt.bar(xx + space + space_between_datasets, sum(np.mean(test_k_of_n_emnist[i], 0)*np.arange(n_actions)), width=w, alpha=0.8, color=light_colors[i])\n",
    "    points_fashion, points_emnist = [], []\n",
    "    for run in range (n_runs):\n",
    "        points_fashion.append(sum(test_k_of_n_fashion[i, run]*np.arange(n_actions)))\n",
    "        points_emnist.append(sum(test_k_of_n_emnist[i, run]*np.arange(n_actions)))\n",
    "    plt.scatter(xx + space + scatter_x, np.sort(points_fashion), s=55, marker='*', color=dark_colors[i])\n",
    "    plt.scatter(xx + space + scatter_x + space_between_datasets, np.sort(points_emnist), s=55, marker='*', color=dark_colors[i])\n",
    "\n",
    "    space = space + w + space_bars\n",
    "bars.append(plt.bar(xx + space, sum(np.arange(n_actions)*np.mean(greedy_fashion, axis=0)), yerr=error_greedy_fashion, linewidth=10, capsize=20, width=w, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[i+1]))\n",
    "plt.bar(xx + space + space_between_datasets, sum(np.arange(n_actions)*np.mean(greedy_emnist, axis=0)), yerr=error_greedy_emnist, linewidth=10, capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "\n",
    "plt.xticks([], fontsize=0)\n",
    "plt.yticks(np.arange(0, 10, 3), fontsize=0)\n",
    "plt.text(0.095, -1.3, 'f', fontsize = fontsize+5)\n",
    "plt.text(0.54, -1.3, r'$\\alpha$', fontsize = fontsize+5)\n",
    "plt.ylim(0, 9)\n",
    "plt.grid(axis=\"y\")\n",
    "fig.legend(bars, legends, fontsize=fontsize+2, ncol=len(ns)+1, loc=(0.12, .81), )\n",
    "plt.subplots_adjust(bottom=0.5, top=0.6)\n",
    "plt.tight_layout(w_pad=1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Fashion and E-MNIST action histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.12\n",
    "space_between_bars = 0.00\n",
    "fontsize = 25\n",
    "hist_fig = plt.figure(figsize=(20,10))\n",
    "hist_grid = plt.GridSpec(2, 2, figure=hist_fig, width_ratios=[1, 1])\n",
    "xx = np.arange(n_actions)\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy'+r\"$(\\hat{r})$\"]\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]\n",
    "\n",
    "space = 0\n",
    "plt.subplot(hist_grid[0, 0])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(TEST_k_of_n_fashion[k], axis=0), width=w, color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "    plt.axvline(x=np.sum(np.mean(TEST_k_of_n_fashion[k], axis=0)*np.arange(n_actions)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "plt.bar(xx+space, np.mean(greedy_fashion, axis=0),width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_fashion, axis=0)*np.arange(n_actions)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.yticks(np.arange(0, 0.9, 0.2), fontsize = fontsize)\n",
    "plt.xticks([])\n",
    "plt.ylabel(\"frequency\", fontsize = fontsize+15)\n",
    "plt.grid(axis=\"y\")\n",
    "plt.ylim(0, 0.8)\n",
    "\n",
    "space = 0\n",
    "plt.subplot(hist_grid[0, 1])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(TEST_k_of_n_emnist[k], axis=0), width=w, color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "    plt.axvline(x=np.sum(np.mean(TEST_k_of_n_emnist[k], axis=0)*np.arange(n_actions)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "plt.bar(xx+space, np.mean(greedy_emnist, axis=0), width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_emnist, axis=0)*np.arange(n_actions)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.xticks([])\n",
    "plt.yticks(np.arange(0, 0.9, 0.2), fontsize=0)\n",
    "plt.grid(axis=\"y\")\n",
    "plt.ylim(0, 0.8)\n",
    "\n",
    "space = 0       \n",
    "plt.subplot(hist_grid[1, 0])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(test_k_of_n_fashion[k], axis=0), width=w, color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "    plt.axvline(x=np.sum(np.mean(test_k_of_n_fashion[k], axis=0)*np.arange(n_actions)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "plt.bar(xx+space, np.mean(greedy_fashion, axis=0), width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_fashion, axis=0)*np.arange(n_actions)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.xticks(0.2+xx, labelss, fontsize = fontsize+5)\n",
    "plt.yticks(np.arange(0, 0.9, 0.2), fontsize = fontsize)\n",
    "plt.ylabel(\"frequency\", fontsize = fontsize+15)\n",
    "plt.xlabel('action', fontsize=fontsize+15)\n",
    "plt.ylim(0, 0.8)\n",
    "plt.grid(axis=\"y\")\n",
    "  \n",
    "space = 0\n",
    "plt.subplot(hist_grid[1, 1])\n",
    "hist_bars = []\n",
    "for k in range (len(ks)):\n",
    "    hist_bars.append(plt.bar(xx+space, np.mean(test_k_of_n_emnist[k], axis=0), width=w, color=light_colors[k]))\n",
    "    space = space + w + space_between_bars\n",
    "    plt.axvline(x=np.sum(np.mean(test_k_of_n_emnist[k], axis=0)*np.arange(n_actions)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "hist_bars.append(plt.bar(xx+space, np.mean(greedy_emnist, axis=0), width=w, color=light_colors[k+1]))\n",
    "plt.axvline(x=np.sum(np.mean(greedy_emnist, axis=0)*np.arange(n_actions)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "\n",
    "plt.xticks(0.2+xx, labelss, fontsize=fontsize+5)\n",
    "plt.xlabel('action', fontsize=fontsize+15)\n",
    "plt.yticks(np.arange(0, 0.9, 0.2), fontsize=0)\n",
    "plt.ylim(0, 0.8)\n",
    "plt.grid(axis=\"y\")\n",
    "        \n",
    "hist_fig.legend(hist_bars, legends, fontsize=fontsize+5, ncol=len(ns)+1, loc=(0.09, .91))\n",
    "plt.subplots_adjust(bottom=0.5, top=0.7)\n",
    "plt.tight_layout(w_pad=3.0, h_pad=-9.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## k-of-n convergence (Expected Value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp = np.zeros((6, len(ks), n_runs, n_itr))\n",
    "exp_conf = np.zeros((6, len(ks), n_itr))\n",
    "\n",
    "for k in range (len(ks)):\n",
    "    for run in range (n_runs):\n",
    "        exp[0, k, run] = np.load('actions/all-images regime/{}/run_{}_expected_value_mnist_{}-of-{}_n_itr_{}.npy'.format(models_dir_split[1], run, ks[k], ns[k], n_itr)) / (n_actions * len(mnist_labels))\n",
    "        exp[1, k, run] = np.load('actions/all-images regime/{}/run_{}_expected_value_fashion_{}-of-{}_n_itr_{}.npy'.format(models_dir_split[1], run, ks[k], ns[k], n_itr)) / (n_actions *len(fashion_labels))\n",
    "        exp[2, k, run] = np.load('actions/all-images regime/{}/run_{}_expected_value_emnist_{}-of-{}_n_itr_{}.npy'.format(models_dir_split[1], run, ks[k], ns[k], n_itr)) / (n_actions *len(emnist_labels))\n",
    "        \n",
    "        exp[3, k, run] = np.load('actions/single-image regime/{}/run_{}_expected_value_mnist_{}-of-{}_n_itr_{}.npy'.format(models_dir_split[1], run, ks[k], ns[k], n_itr)) / (n_actions *len(mnist_labels))\n",
    "        exp[4, k, run] = np.load('actions/single-image regime/{}/run_{}_expected_value_fashion_{}-of-{}_n_itr_{}.npy'.format(models_dir_split[1], run, ks[k], ns[k], n_itr)) / (n_actions *len(fashion_labels))\n",
    "        exp[5, k, run] = np.load('actions/single-image regime/{}/run_{}_expected_value_emnist_{}-of-{}_n_itr_{}.npy'.format(models_dir_split[1], run, ks[k], ns[k], n_itr)) / (n_actions *len(emnist_labels))\n",
    "\n",
    "        for i in range (6):\n",
    "            for itr in range (n_itr):\n",
    "                exp_conf[i, k, itr] = mean_confidence_interval(exp[i, k, :, itr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10']\n",
    "exp_fig = plt.figure(figsize=(20,8))\n",
    "exp_bars = []\n",
    "for f in range (6):\n",
    "    plt.subplot(2, 3, f+1)\n",
    "    for k in range (len(ks)):\n",
    "        m = np.mean(exp[f, k], axis=0)\n",
    "        for run in range (n_runs):\n",
    "            plt.plot(exp[f, k, run], color=light_colors[k], lw=1, alpha=0.5)\n",
    "        a, = plt.plot(m, label=\"{}-of-{}\".format(ks[k], ns[k]), color=light_colors[k], lw=4)\n",
    "#         plt.fill_between(np.arange(n_itr), m - exp_conf[f, k], m + exp_conf[f, k], alpha=0.3, color=light_colors[k])\n",
    "        if f == len(ks)-1:\n",
    "            exp_bars.append(a)\n",
    "        if f%3 == 0:\n",
    "            plt.text(-32, 0.1, s=r'$u_{\\mathcal{M}}(\\pi^{t}; \\bar{r}^{t})$', fontsize=fontsize+5, rotation='vertical')\n",
    "#             plt.ylabel(ylabel=r'$u_{\\mathcal{M}}(\\pi^{t}; \\bar{r}^{t})$', fontsize=fontsize+5, position=(3*f, 0.5))\n",
    "        if f > 2:\n",
    "            plt.xlabel('iteration', fontsize=fontsize+10)\n",
    "            plt.xticks(fontsize=fontsize)\n",
    "        if f < 3:\n",
    "            plt.xticks(fontsize=0)\n",
    "            \n",
    "#     plt.ylim(0, )\n",
    "    plt.yticks(fontsize=fontsize)\n",
    "    plt.grid(axis=\"y\")\n",
    "exp_fig.legend(exp_bars, exp_legends, fontsize=fontsize+5, ncol=len(ks), loc=(0.18, .9), bbox_transform = plt.gcf().transFigure)\n",
    "plt.subplots_adjust(top=0.7, bottom=0.4)\n",
    "plt.tight_layout(w_pad=2.5, h_pad=-4.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## E-MNIST heatmaps: 1-of-20 (all-images regime), 1-of-20 (single-image regime), and greedy baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEST_emnist_heatmap = np.zeros((n_runs, len(emnist_labels), n_actions))\n",
    "test_emnist_heatmap = np.zeros((n_runs, len(emnist_labels), n_actions))\n",
    "baseline_emnist_heatmap = np.load('{}/hard_max_emnist_actions.npy'.format(baseline_dir))\n",
    "for run in range (n_runs):\n",
    "    TEST_emnist_heatmap[run] = np.load('actions/all-images regime/{}/run_{}_emnist_actions_1-of-20_n_itr_100.npy'.format(models_dir_split[1], run))\n",
    "    test_emnist_heatmap[run] = np.load('actions/single-image regime/{}/run_{}_emnist_actions_1-of-20_n_itr_100.npy'.format(models_dir_split[1], run))\n",
    "\n",
    "TEST_emnist_m = np.zeros((len(emnist_labels_names), n_actions))  \n",
    "test_emnist_m = np.zeros((len(emnist_labels_names), n_actions))  \n",
    "baseline_emnist_m = np.zeros((len(emnist_labels_names), n_actions))  \n",
    "for i in range (len(emnist_labels_names)):\n",
    "    TEST_emnist_m[i] = np.round(np.mean(np.mean(TEST_emnist_heatmap[:, np.where(emnist_labels.numpy()==(i+1))[0],:], 1), 0), 1)\n",
    "    test_emnist_m[i] = np.round(np.mean(np.mean(test_emnist_heatmap[:, np.where(emnist_labels.numpy()==(i+1))[0],:], 1), 0), 1)    \n",
    "    baseline_emnist_m[i] = np.round(np.mean(np.mean(baseline_emnist_heatmap[:, np.where(emnist_labels.numpy()==(i+1))[0],:], 1), 0), 1)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set(font_scale=1.4)\n",
    "fig_heatmap = plt.figure(figsize=(20, 10))\n",
    "grid = plt.GridSpec(1, 3, figure=fig_heatmap, width_ratios=[1, 1, 1.3])\n",
    "\n",
    "plt.subplot(grid[0])\n",
    "ax = sns.heatmap(TEST_emnist_m, annot=True, linewidths=1,  cmap=\"Blues\" , vmin=0, vmax=1, cbar=False, linecolor='black')\n",
    "plt.yticks(np.arange(len(emnist_labels_names))+0.5, emnist_labels_names, rotation='horizontal', fontsize=fontsize-6)\n",
    "plt.xlabel(\"action\", fontsize=fontsize+10)\n",
    "plt.ylabel(\"letter\", fontsize=fontsize+10)\n",
    "plt.xticks(np.arange(n_actions)+0.5, np.arange(n_actions), fontsize=fontsize-3)\n",
    "ax.xaxis.tick_top()\n",
    "\n",
    "plt.subplot(grid[1])\n",
    "ax = sns.heatmap(test_emnist_m, annot=True, linewidths=1,  cmap=\"Blues\" , vmin=0, vmax=1, cbar=False, linecolor='black')\n",
    "plt.yticks(np.arange(len(emnist_labels_names))+0.5, emnist_labels_names, rotation='horizontal', fontsize=fontsize-6)\n",
    "plt.xlabel(\"action\", fontsize=fontsize+10)\n",
    "plt.xticks(np.arange(n_actions)+0.5, np.arange(n_actions), fontsize=fontsize-3)\n",
    "ax.xaxis.tick_top()\n",
    "\n",
    "plt.subplot(grid[2])\n",
    "ax = sns.heatmap(baseline_emnist_m, annot=True, linewidths=1,  cmap=\"Blues\" , vmin=0, vmax=1,linecolor='black')\n",
    "plt.yticks(np.arange(len(emnist_labels_names))+0.5, emnist_labels_names, rotation='horizontal', fontsize=fontsize-6)\n",
    "plt.xlabel(\"action\", fontsize=fontsize+10)\n",
    "plt.xticks(np.arange(n_actions)+0.5, np.arange(n_actions), fontsize=fontsize-3)\n",
    "ax.xaxis.tick_top()\n",
    "\n",
    "plt.tight_layout(w_pad=1.0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
