{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment **4.1 and 4.4**: \n",
    "\n",
    "- **4.1 Learning to Ask for Help**\n",
    "\n",
    "- **4.4 How Caution Depends on the Extent of Training Data**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torchvision, torch\n",
    "from IPython import display\n",
    "from torchvision import transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import style\n",
    "import scipy.stats  # also binds `scipy`; plain `import scipy` does not load `scipy.stats` on older SciPy\n",
    "import colorsys\n",
    "import seaborn as sns\n",
    "\n",
    "# Transform applied to every dataset image when a sample is indexed.\n",
    "transform = transforms.ToTensor()\n",
    "\n",
    "## plot style and font\n",
    "# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the\n",
    "# old names, so use whichever name this matplotlib version provides.\n",
    "_colorblind_style = 'seaborn-v0_8-colorblind' if 'seaborn-v0_8-colorblind' in style.available else 'seaborn-colorblind'\n",
    "style.use(_colorblind_style)\n",
    "font_params = {'family': 'serif'}\n",
    "plt.rc('font', **font_params)\n",
    "# Embed TrueType (Type 42) fonts in saved PDFs instead of Type 3 fonts.\n",
    "plt.rc('pdf', fonttype=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_confidence_interval(data, confidence=0.95):\n",
    "    \"\"\"Half-width of the Student-t confidence interval for the mean of `data`.\n",
    "\n",
    "    Despite the name, only the half-width h is returned (the interval is mean +/- h);\n",
    "    the callers use it directly as a symmetric error-bar size.\n",
    "\n",
    "    Args:\n",
    "        data: sequence of samples (callers pass 1-D slices).\n",
    "        confidence: two-sided confidence level (default 0.95).\n",
    "\n",
    "    Returns:\n",
    "        sem(data) * t.ppf((1 + confidence) / 2, n - 1); nan for fewer than two samples.\n",
    "    \"\"\"\n",
    "    n = len(data)\n",
    "    se = scipy.stats.sem(data)  # the mean itself is not needed for the half-width\n",
    "    return se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)\n",
    "\n",
    "# Paired palettes: dark_colors[i] marks individual runs (scatter stars),\n",
    "# light_colors[i] the matching mean bars/lines of the same setting.\n",
    "dark_hues = [0, 0.1, 0.4, 0.55, 0.65, 0.75, 0.9]\n",
    "dark_lightness = [0.4, 0.2, 0.15, 0.4, 0.4, 0.4, 0.4]\n",
    "light_hues = [1, 0.1, 0.3, 0.5, 0.6, 0.75, 0.85]\n",
    "light_lightness = [0.75, 0.6, 0.45, 0.6, 0.65, 0.8, 0.55]\n",
    "dark_colors = [colorsys.hls_to_rgb(hue, lightness, 1) for hue, lightness in zip(dark_hues, dark_lightness)]\n",
    "sns.palplot(dark_colors)\n",
    "light_colors = [colorsys.hls_to_rgb(hue, lightness, 1) for hue, lightness in zip(light_hues, light_lightness)]\n",
    "sns.palplot(light_colors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load datasets (MNIST, Fashion, and E-MNIST) and pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_test = torchvision.datasets.MNIST(root='datasets', train=False, download=True, transform=transform)\n",
    "\n",
    "fashion_test = torchvision.datasets.FashionMNIST(root='datasets', train=False, download=True, transform=transform)\n",
    "\n",
    "emnist_test  = torchvision.datasets.EMNIST(root=\"datasets\", train=False, transform=transform, target_transform=None, download=True, split=\"letters\")\n",
    "\n",
    "# Take the label tensors straight from the datasets: indexing `dataset[i]` loads and\n",
    "# transforms the image of every test sample only to throw it away. `.targets` holds the\n",
    "# same labels because no target_transform is set. Cast to float32.\n",
    "# E-MNIST `letters` labels are 1-based: 1 = A/a ... 26 = Z/z.\n",
    "mnist_labels = torch.as_tensor(mnist_test.targets, dtype=torch.float32)\n",
    "fashion_labels = torch.as_tensor(fashion_test.targets, dtype=torch.float32)\n",
    "emnist_labels = torch.as_tensor(emnist_test.targets, dtype=torch.float32)\n",
    "\n",
    "fashion_labels_names = [\"T-shirt/top\",  \"Trouser\", \"Pullover\" , \"Dress\" , \"Coat\" , \"Sandal\" , \"Shirt\", \"Sneaker\" ,\"Bag\" ,\"Ankle boot\"]\n",
    "emnist_labels_names = [\"A/a\", \"B/b\", \"C/c\", \"D/d\", \"E/e\", \"F/f\", \"G/g\", \"H/h\", \"I/i\", \"J/j\", \"K/k\", \"L/l\", \"M/m\", \"N/n\", \"O/o\", \"P/p\", \"Q/q\", \"R/r\", \"S/s\", \"T/t\", \"U/u\", \"V/v\", \"W/w\", \"X/x\", \"Y/y\", \"Z/z\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load k-of-n actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models_dir = \"models/last_action_3_fc\" # directory where each model's parameters are saved\n",
    "models_dir_split = models_dir.split('/')\n",
    "fig_dir = 'fig/' + models_dir_split[1]\n",
    "os.makedirs(fig_dir, exist_ok=True)\n",
    "n_models = 2000\n",
    "n_itr = 100\n",
    "n_actions = 11\n",
    "n_runs = 10\n",
    "\n",
    "ks = [1, 1, 5, 10]\n",
    "ns = [20, 10, 10, 10]\n",
    "\n",
    "def _mean_k_of_n_actions(regime, dataset, run, k, n):\n",
    "    \"\"\"Load the saved k-of-n actions of one run and average them over axis 0.\n",
    "\n",
    "    `regime` is 'all-images regime' or 'single-image regime'; `dataset` is 'mnist',\n",
    "    'fashion' or 'emnist'. The file name uses `n_itr` so it follows the setting above.\n",
    "    \"\"\"\n",
    "    path = 'actions/{}/{}/run_{}_{}_actions_{}-of-{}_n_itr_{}.npy'.format(regime, models_dir_split[1], run, dataset, k, n, n_itr)\n",
    "    return np.mean(np.load(path), axis=0)\n",
    "\n",
    "# Mean action frequencies, shape (len(ks), n_runs, n_actions).\n",
    "# Lower-case test_*: single-image regime; upper-case TEST_*: all-images regime.\n",
    "test_k_of_n_mnist = np.zeros((len(ks), n_runs, n_actions))\n",
    "test_k_of_n_fashion = np.zeros((len(ks), n_runs, n_actions))\n",
    "test_k_of_n_emnist = np.zeros((len(ks), n_runs,  n_actions))\n",
    "\n",
    "TEST_k_of_n_mnist = np.zeros((len(ks), n_runs, n_actions))\n",
    "TEST_k_of_n_fashion = np.zeros((len(ks), n_runs, n_actions))\n",
    "TEST_k_of_n_emnist = np.zeros((len(ks), n_runs,  n_actions))\n",
    "\n",
    "for run in range(n_runs):\n",
    "    for k in range(len(ks)):\n",
    "        TEST_k_of_n_mnist[k, run] = _mean_k_of_n_actions('all-images regime', 'mnist', run, ks[k], ns[k])\n",
    "        TEST_k_of_n_fashion[k, run] = _mean_k_of_n_actions('all-images regime', 'fashion', run, ks[k], ns[k])\n",
    "        TEST_k_of_n_emnist[k, run] = _mean_k_of_n_actions('all-images regime', 'emnist', run, ks[k], ns[k])\n",
    "\n",
    "        test_k_of_n_mnist[k, run] = _mean_k_of_n_actions('single-image regime', 'mnist', run, ks[k], ns[k])\n",
    "        test_k_of_n_fashion[k, run] = _mean_k_of_n_actions('single-image regime', 'fashion', run, ks[k], ns[k])\n",
    "        test_k_of_n_emnist[k, run] = _mean_k_of_n_actions('single-image regime', 'emnist', run, ks[k], ns[k])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load standard-RL (greedy) baseline actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_dir = 'actions/all-images regime/' + models_dir_split[1]\n",
    "\n",
    "def _load_greedy_actions(dataset):\n",
    "    \"\"\"Standard-RL (hard-max) actions of `dataset`, averaged over axis 1.\n",
    "\n",
    "    Every saved model is treated as one policy, giving one row of mean actions per model.\n",
    "    \"\"\"\n",
    "    return np.mean(np.load('{}/hard_max_{}_actions.npy'.format(baseline_dir, dataset)), axis=1)\n",
    "\n",
    "greedy_mnist, greedy_fashion, greedy_emnist = [_load_greedy_actions(name) for name in ('mnist', 'fashion', 'emnist')]\n",
    "\n",
    "# Error bars for the last action (\"help\"), taken across models.\n",
    "error_greedy_fashion, error_greedy_emnist = [mean_confidence_interval(actions[:, -1]) for actions in (greedy_fashion, greedy_emnist)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fashion and E-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.055\n",
    "space_bars = 0.002\n",
    "space_between_datasets = 0.4\n",
    "fontsize = 25\n",
    "scatter_x = np.arange(-5, n_runs-5)/200  # horizontal jitter of the per-run markers\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy'+r\"$(\\hat{r})$\"]\n",
    "labelss = [\"help\"]  # only the last action (\"help\") is shown in this figure\n",
    "\n",
    "xx = np.arange(len(labelss))\n",
    "\n",
    "def _plot_help_frequency(k_of_n_fashion, k_of_n_emnist):\n",
    "    \"\"\"Draw help-action frequencies on the current axes: Fashion group left, E-MNIST group right.\n",
    "\n",
    "    Each k-of-n setting gets a bar at its mean over runs plus one star per run; the greedy\n",
    "    baseline follows with a confidence-interval error bar. Returns the Fashion bar\n",
    "    containers in legend order.\n",
    "    \"\"\"\n",
    "    handles = []\n",
    "    offset = 0\n",
    "    for i in range(len(ns)):\n",
    "        plt.scatter(xx + offset + scatter_x, np.sort(k_of_n_fashion[i, :, -1]), s=55, marker='*', color=dark_colors[i])\n",
    "        handles.append(plt.bar(xx + offset, np.mean(k_of_n_fashion[i], axis=0)[-1], width=w, alpha=0.8, label=\"{}-of-{}\".format(ks[i], ns[i]), color=light_colors[i]))\n",
    "        plt.bar(xx + offset + space_between_datasets, np.mean(k_of_n_emnist[i], axis=0)[-1], width=w, alpha=0.8, color=light_colors[i])\n",
    "        plt.scatter(xx + offset + scatter_x + space_between_datasets, np.sort(k_of_n_emnist[i, :, -1]), s=55, marker='*', color=dark_colors[i])\n",
    "        offset = offset + w + space_bars\n",
    "\n",
    "    # The greedy baseline takes the colour right after the k-of-n settings.\n",
    "    handles.append(plt.bar(xx + offset, np.mean(greedy_fashion, axis=0)[-1], width=w, yerr=error_greedy_fashion, linewidth=10, capsize=10, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[len(ns)]))\n",
    "    plt.bar(xx + offset + space_between_datasets, np.mean(greedy_emnist, axis=0)[-1], width=w, yerr=error_greedy_emnist, linewidth=10, capsize=10, alpha=0.8, color=light_colors[len(ns)])\n",
    "\n",
    "    plt.text(0.1, -0.18, 'f', fontsize = fontsize+10)\n",
    "    plt.text(0.5, -0.18, r'$\\alpha$', fontsize = fontsize+10)\n",
    "    plt.xticks([])\n",
    "    plt.ylim(-0.001, 1)\n",
    "    plt.grid(axis=\"y\")\n",
    "    return handles\n",
    "\n",
    "fig = plt.figure(figsize=(20,5))\n",
    "\n",
    "# Left panel: all-images regime (TEST_*); carries the y tick labels and the y label.\n",
    "plt.subplot(1, 2, 1)\n",
    "_plot_help_frequency(TEST_k_of_n_fashion, TEST_k_of_n_emnist)\n",
    "plt.yticks(fontsize = fontsize)\n",
    "plt.ylabel(\"frequency\", fontsize = fontsize+15)\n",
    "\n",
    "# Right panel: single-image regime (test_*); y tick labels hidden, provides the legend handles.\n",
    "plt.subplot(1, 2, 2)\n",
    "bars = _plot_help_frequency(test_k_of_n_fashion, test_k_of_n_emnist)\n",
    "plt.yticks(fontsize=0)\n",
    "\n",
    "fig.legend(bars, legends, fontsize=fontsize+5, ncol=len(ns)+1, loc=(0.08, .8), bbox_transform = plt.gcf().transFigure)\n",
    "plt.subplots_adjust(top=0.9, bottom=0.7)\n",
    "plt.tight_layout(w_pad=1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fashion and E-MNIST histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.1\n",
    "space_between_bars = 0.00\n",
    "fontsize = 25\n",
    "hist_fig = plt.figure(figsize=(20,10))\n",
    "hist_grid = plt.GridSpec(2, 2, figure=hist_fig, width_ratios=[1, 1])\n",
    "xx = np.arange(n_actions)\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy'+r\"$(\\hat{r})$\"]\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"help\"]\n",
    "\n",
    "def _draw_action_histogram(cell, k_of_n, greedy):\n",
    "    \"\"\"Grouped bars of mean action frequencies in grid cell `cell`, one group per action.\n",
    "\n",
    "    Bars inside a group: every k-of-n setting (mean over runs), then the greedy baseline\n",
    "    (mean over models). Returns the bar containers in that order.\n",
    "    \"\"\"\n",
    "    plt.subplot(cell)\n",
    "    containers = []\n",
    "    offset = 0\n",
    "    for colour_idx, series in enumerate([*k_of_n, greedy]):\n",
    "        containers.append(plt.bar(xx + offset, np.mean(series, axis=0), width=w, color=light_colors[colour_idx]))\n",
    "        offset = offset + w + space_between_bars\n",
    "    return containers\n",
    "\n",
    "def _style_histogram_axes(bottom_row, left_column):\n",
    "    \"\"\"Shared axis styling; tick and axis labels only on the outer edges of the 2x2 grid.\"\"\"\n",
    "    if bottom_row:\n",
    "        plt.xticks(0.2+xx, labelss, fontsize=fontsize+5)  # 0.2 = centre of a 5-bar group\n",
    "        plt.xlabel('action', fontsize=fontsize+15)\n",
    "    else:\n",
    "        plt.xticks([])\n",
    "    plt.yticks(np.arange(0, 1.1, 0.2), fontsize=fontsize if left_column else 0)\n",
    "    if left_column:\n",
    "        plt.ylabel(\"frequency\", fontsize = fontsize+15)\n",
    "    plt.grid(axis=\"y\")\n",
    "    plt.ylim(0, 1)\n",
    "\n",
    "# Rows: all-images regime (TEST_*) on top, single-image regime (test_*) below.\n",
    "# Columns: Fashion on the left, E-MNIST on the right.\n",
    "_draw_action_histogram(hist_grid[0, 0], TEST_k_of_n_fashion, greedy_fashion)\n",
    "_style_histogram_axes(bottom_row=False, left_column=True)\n",
    "_draw_action_histogram(hist_grid[0, 1], TEST_k_of_n_emnist, greedy_emnist)\n",
    "_style_histogram_axes(bottom_row=False, left_column=False)\n",
    "_draw_action_histogram(hist_grid[1, 0], test_k_of_n_fashion, greedy_fashion)\n",
    "_style_histogram_axes(bottom_row=True, left_column=True)\n",
    "hist_bars = _draw_action_histogram(hist_grid[1, 1], test_k_of_n_emnist, greedy_emnist)\n",
    "_style_histogram_axes(bottom_row=True, left_column=False)\n",
    "\n",
    "hist_fig.legend(hist_bars, legends, fontsize=fontsize+5, ncol=len(ns)+1, loc=(0.09, .91))\n",
    "plt.subplots_adjust(bottom=0.5, top=0.7)\n",
    "plt.tight_layout(w_pad=3.0, h_pad=-9.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## k-of-n convergence (Expected Value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected value per iteration, normalised by n_actions * (number of test images).\n",
    "# Panels 0-2: all-images regime (mnist, fashion, emnist); panels 3-5: single-image regime.\n",
    "exp_sources = [(regime, dataset, labels)\n",
    "               for regime in ('all-images regime', 'single-image regime')\n",
    "               for dataset, labels in (('mnist', mnist_labels), ('fashion', fashion_labels), ('emnist', emnist_labels))]\n",
    "\n",
    "exp = np.zeros((len(exp_sources), len(ks), n_runs, n_itr))\n",
    "exp_conf = np.zeros((len(exp_sources), len(ks), n_itr))\n",
    "\n",
    "for p, (regime, dataset, labels) in enumerate(exp_sources):\n",
    "    for k in range(len(ks)):\n",
    "        for run in range(n_runs):\n",
    "            exp[p, k, run] = np.load('actions/{}/{}/run_{}_expected_value_{}_{}-of-{}_n_itr_{}.npy'.format(regime, models_dir_split[1], run, dataset, ks[k], ns[k], n_itr)) / (n_actions * len(labels))\n",
    "\n",
    "# Confidence half-width across runs, computed once after every run has been loaded so each\n",
    "# interval sees all n_runs samples (no recomputation on partially-filled data).\n",
    "for p in range(len(exp_sources)):\n",
    "    for k in range(len(ks)):\n",
    "        for itr in range(n_itr):\n",
    "            exp_conf[p, k, itr] = mean_confidence_interval(exp[p, k, :, itr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 25\n",
    "exp_legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10']\n",
    "exp_fig = plt.figure(figsize=(20,8))\n",
    "exp_bars = []  # legend handles: one mean line per k-of-n setting\n",
    "for f in range (6):\n",
    "    plt.subplot(2, 3, f+1)\n",
    "    for k in range (len(ks)):\n",
    "        # thin lines: individual runs; thick line: mean over runs\n",
    "        for run in range (n_runs):\n",
    "            plt.plot(exp[f, k, run], color=light_colors[k], lw=1, alpha=0.5)\n",
    "        a, = plt.plot(np.mean(exp[f, k], axis=0), label=\"{}-of-{}\".format(ks[k], ns[k]), color=light_colors[k], lw=4)\n",
    "        if f == 0:\n",
    "            exp_bars.append(a)  # all panels are styled identically; take the handles from the first\n",
    "\n",
    "    # Per-panel decorations: drawn once per panel, not once per k-of-n setting\n",
    "    # (repeating plt.text inside the k loop stacked len(ks) copies of the same label).\n",
    "    if f%3 == 0:\n",
    "        plt.text(-32, 0.1, s=r'$u_{\\mathcal{M}}(\\pi^{t}; \\bar{r}^{t})$', fontsize=fontsize+5, rotation='vertical')\n",
    "    if f > 2:\n",
    "        plt.xlabel('iteration', fontsize=fontsize+10)\n",
    "        plt.xticks(fontsize=fontsize)\n",
    "    else:\n",
    "        plt.xticks(fontsize=0)\n",
    "    plt.yticks(fontsize=fontsize)\n",
    "    plt.grid(axis=\"y\")\n",
    "exp_fig.legend(exp_bars, exp_legends, fontsize=fontsize+5, ncol=len(ks), loc=(0.18, .9), bbox_transform = plt.gcf().transFigure)\n",
    "plt.subplots_adjust(top=0.7, bottom=0.4)\n",
    "plt.tight_layout(w_pad=2.5, h_pad=-4.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heatmap (E-MNIST): 1-of-20 all-images regime (TEST), 1-of-20 single-image regime (test), and greedy baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-image actions of the 1-of-20 policies: all-images regime (TEST_*) and single-image regime (test_*).\n",
    "TEST_emnist_heatmap = np.zeros((n_runs, len(emnist_labels), n_actions))\n",
    "test_emnist_heatmap = np.zeros((n_runs, len(emnist_labels), n_actions))\n",
    "baseline_emnist_heatmap = np.load('{}/hard_max_emnist_actions.npy'.format(baseline_dir))\n",
    "\n",
    "for run in range (n_runs):\n",
    "    TEST_emnist_heatmap[run] = np.load('actions/all-images regime/{}/run_{}_emnist_actions_1-of-20_n_itr_100.npy'.format(models_dir_split[1], run))\n",
    "    test_emnist_heatmap[run] = np.load('actions/single-image regime/{}/run_{}_emnist_actions_1-of-20_n_itr_100.npy'.format(models_dir_split[1], run))\n",
    "\n",
    "def _letter_profile(per_image_actions, image_idx):\n",
    "    \"\"\"Mean actions over the selected images (axis 1), then over axis 0, rounded to one decimal.\"\"\"\n",
    "    return np.round(np.mean(np.mean(per_image_actions[:, image_idx, :], 1), 0), 1)\n",
    "\n",
    "n_letters = len(emnist_labels_names)\n",
    "TEST_emnist_m = np.zeros((n_letters, n_actions))\n",
    "test_emnist_m = np.zeros((n_letters, n_actions))\n",
    "baseline_emnist_m = np.zeros((n_letters, n_actions))\n",
    "\n",
    "emnist_label_values = emnist_labels.numpy()\n",
    "for letter in range(n_letters):\n",
    "    # E-MNIST letter labels start at 1 (A/a): row `letter` collects the images labelled letter + 1\n",
    "    image_idx = np.where(emnist_label_values == (letter + 1))[0]\n",
    "    TEST_emnist_m[letter] = _letter_profile(TEST_emnist_heatmap, image_idx)\n",
    "    test_emnist_m[letter] = _letter_profile(test_emnist_heatmap, image_idx)\n",
    "    baseline_emnist_m[letter] = _letter_profile(baseline_emnist_heatmap, image_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 25\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"help\"]\n",
    "\n",
    "# Left to right: 1-of-20 all-images regime, 1-of-20 single-image regime, greedy baseline.\n",
    "heatmap_panels = [TEST_emnist_m, test_emnist_m, baseline_emnist_m]\n",
    "\n",
    "# Scale fonts for this figure only. `sns.set(font_scale=1.4)` would switch the global\n",
    "# rcParams to the full seaborn theme (darkgrid style, sans-serif font) for every later\n",
    "# figure, overriding the style and serif font configured in the imports cell.\n",
    "with sns.plotting_context(\"notebook\", font_scale=1.4):\n",
    "    fig_heatmap = plt.figure(figsize=(20, 10))\n",
    "    grid = plt.GridSpec(1, 3, figure=fig_heatmap, width_ratios=[1, 1, 1.3])\n",
    "    for col, panel in enumerate(heatmap_panels):\n",
    "        ax = plt.subplot(grid[col])\n",
    "        # all panels share vmin/vmax; only the (wider) right-most one draws the colour bar\n",
    "        sns.heatmap(panel, annot=True, linewidths=1,  cmap=\"Blues\" , vmin=0, vmax=1, cbar=(col == len(heatmap_panels) - 1), linecolor='black')\n",
    "        plt.yticks(np.arange(len(emnist_labels_names))+0.5, emnist_labels_names, rotation='horizontal', fontsize=fontsize-6)\n",
    "        plt.xlabel(\"action\", fontsize=fontsize+10)\n",
    "        if col == 0:\n",
    "            plt.ylabel(\"letter\", fontsize=fontsize+10)\n",
    "        plt.xticks(np.arange(n_actions)+0.4, labelss, fontsize=fontsize-3)\n",
    "        ax.xaxis.tick_top()\n",
    "    plt.tight_layout(w_pad=1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How Caution Depends on the Extent of Training Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "std_models_dir_100 = \"actions/How Caution Depends on the Extent of Training Data/100%\"\n",
    "std_models_dir_10 = \"actions/How Caution Depends on the Extent of Training Data/10%\"\n",
    "std_models_dir_1 = \"actions/How Caution Depends on the Extent of Training Data/1%\" \n",
    "\n",
    "# Result sub-directories (relative to `actions/`) for the 1%, 10% and 100% training-data runs.\n",
    "# Taking only component [1] of each path gives the shared parent folder for all three, so every\n",
    "# fraction would load the same files; keep everything after `actions/` instead.\n",
    "# NOTE(review): assumes the files live under 'actions/<regime>/How Caution Depends on the\n",
    "# Extent of Training Data/<fraction>/' -- confirm against the directory layout on disk.\n",
    "std_models_dir = ['/'.join(path.split('/')[1:]) for path in (std_models_dir_1, std_models_dir_10, std_models_dir_100)]\n",
    "\n",
    "std_models_dir_split = std_models_dir_100.split('/')[1]\n",
    "\n",
    "std_fig_dir = 'fig/' + std_models_dir_split\n",
    "os.makedirs(std_fig_dir, exist_ok=True)\n",
    "\n",
    "ks = [1, 1, 5, 10]\n",
    "ns = [20, 10, 10, 10]\n",
    "n_runs = 10\n",
    "n_actions = 11\n",
    "\n",
    "# Mean action frequencies, shape (training-data fraction, len(ks), n_runs, n_actions).\n",
    "# Upper-case TEST_*: all-images regime; lower-case test_*: single-image regime.\n",
    "TEST_mnist_std = np.zeros((3, len(ks), n_runs, n_actions))\n",
    "TEST_fashion_std = np.zeros((3, len(ks), n_runs, n_actions))\n",
    "TEST_emnist_std = np.zeros((3, len(ks), n_runs, n_actions))\n",
    "\n",
    "test_mnist_std = np.zeros((3, len(ks), n_runs, n_actions))\n",
    "test_fashion_std = np.zeros((3, len(ks), n_runs, n_actions))\n",
    "test_emnist_std = np.zeros((3, len(ks), n_runs, n_actions))\n",
    "\n",
    "for i in range (3):\n",
    "    for j, k in enumerate (ks):\n",
    "        for run in range (n_runs):\n",
    "            TEST_mnist_std[i, j, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_mnist_actions_{}-of-{}_n_itr_{}.npy'.format(std_models_dir[i], run, k, ns[j], 100)), axis=0)\n",
    "            TEST_fashion_std[i, j, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_fashion_actions_{}-of-{}_n_itr_{}.npy'.format(std_models_dir[i], run, k, ns[j], 100)), axis=0)\n",
    "            TEST_emnist_std[i, j, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_emnist_actions_{}-of-{}_n_itr_{}.npy'.format(std_models_dir[i], run, k, ns[j], 100)), axis=0)\n",
    "\n",
    "            test_mnist_std[i, j, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_mnist_actions_{}-of-{}_n_itr_{}.npy'.format(std_models_dir[i], run, k, ns[j], 100)), axis=0)\n",
    "            test_fashion_std[i, j, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_fashion_actions_{}-of-{}_n_itr_{}.npy'.format(std_models_dir[i], run, k, ns[j], 100)), axis=0)\n",
    "            test_emnist_std[i, j, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_emnist_actions_{}-of-{}_n_itr_{}.npy'.format(std_models_dir[i], run, k, ns[j], 100)), axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baselines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedy (hard-max) baseline per training-data fraction: mean actions per model, shape\n",
    "# (fraction, n_models, n_actions), plus confidence half-widths of the last (\"help\") action.\n",
    "n_fractions = len(std_models_dir)\n",
    "std_greedy_mnist = np.zeros((n_fractions, n_models, n_actions))\n",
    "std_greedy_fashion = np.zeros((n_fractions, n_models, n_actions))\n",
    "std_greedy_emnist = np.zeros((n_fractions, n_models, n_actions))\n",
    "\n",
    "std_error_greedy_mnist = np.zeros(n_fractions)\n",
    "std_error_greedy_fashion = np.zeros(n_fractions)\n",
    "std_error_greedy_emnist = np.zeros(n_fractions)\n",
    "\n",
    "greedy_targets = (('mnist', std_greedy_mnist, std_error_greedy_mnist),\n",
    "                  ('fashion', std_greedy_fashion, std_error_greedy_fashion),\n",
    "                  ('emnist', std_greedy_emnist, std_error_greedy_emnist))\n",
    "for i in range (3):\n",
    "    for dataset, greedy, error in greedy_targets:\n",
    "        greedy[i] = np.mean(np.load('actions/all-images regime/{}/hard_max_{}_actions.npy'.format(std_models_dir[i], dataset)), axis=1)\n",
    "        error[i] = mean_confidence_interval(greedy[i, :, -1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.006\n",
    "space_bars = 0.000\n",
    "space_between_datasets = 0.15\n",
    "space_percent = 0.015  # extra gap between the 1% / 10% / 100% groups\n",
    "n_itr = 100\n",
    "n_runs = 10\n",
    "fontsize = 25\n",
    "scatter_x = np.arange(-5, n_runs-5)/1500  # horizontal jitter of the per-run markers\n",
    "ns = [20, 10, 10, 10]\n",
    "ks = [1, 1, 5, 10]\n",
    "x_labels_precent = ['1%', '10%', '100%']\n",
    "x_labels = ['f', 'greedy' + r\"$(\\hat{r})$\"]\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy' + r\"$(\\hat{r})$\"]\n",
    "labelss = [\"help\"]\n",
    "\n",
    "# Both panels: for each training-data fraction, one bar per k-of-n setting (mean over runs)\n",
    "# with one star per run, then the greedy baseline with its confidence-interval error bar.\n",
    "# Fashion groups start at x = 0; E-MNIST groups are shifted by space_between_datasets.\n",
    "xx = np.arange(len(labelss))\n",
    "fig = plt.figure(figsize=(20, 5))\n",
    "\n",
    "# Left panel: all-images regime (TEST_*).\n",
    "plt.subplot(1, 2, 1)\n",
    "space = 0\n",
    "for i in range (len(x_labels_precent)):\n",
    "    for k in range (len(ns)):\n",
    "        plt.scatter(xx + space + scatter_x, np.sort(TEST_fashion_std[i,k, :, -1]), s=20, marker='*', color=dark_colors[k])\n",
    "        plt.bar(xx + space , np.mean(TEST_fashion_std[i,k], axis=0)[-1], width=w,  alpha=0.8, label=\"{}-of-{}\".format(ks[k], ns[k]), color=light_colors[k])\n",
    "        plt.bar(xx + space + space_between_datasets , np.mean(TEST_emnist_std[i, k], axis=0)[-1], width=w, alpha=0.8, color=light_colors[k])\n",
    "        plt.scatter(xx + space + scatter_x + space_between_datasets, np.sort(TEST_emnist_std[i,k, :, -1]), s=20, marker='*', color=dark_colors[k])\n",
    "        space = space + w + space_bars\n",
    "\n",
    "    plt.bar(xx + space, np.mean(std_greedy_fashion[i], axis=0)[-1], width=w, yerr=std_error_greedy_fashion[i], linewidth=10, capsize=5, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[k+1])\n",
    "    plt.bar(xx + space + space_between_datasets, np.mean(std_greedy_emnist[i], axis=0)[-1], width=w, yerr=std_error_greedy_emnist[i], linewidth=10, capsize=5, alpha=0.8, color=light_colors[k+1])\n",
    "\n",
    "    plt.text(-0.003+i*0.038, -0.1, x_labels_precent[i], fontsize=fontsize)\n",
    "    plt.text(-0.003+i*0.038 + space_between_datasets, -0.1, x_labels_precent[i], fontsize=fontsize)\n",
    "\n",
    "    space += space_percent\n",
    "\n",
    "plt.text(0.04, -0.25, 'f', fontsize = fontsize+10)\n",
    "plt.text(0.18, -0.25, r'$\\alpha$', fontsize = fontsize+10)\n",
    "plt.xticks([])\n",
    "plt.yticks(fontsize = fontsize)\n",
    "plt.ylabel(\"frequency\", fontsize = fontsize+10)\n",
    "plt.ylim(-0.001, 1)\n",
    "plt.grid(axis=\"y\")\n",
    "\n",
    "# Right panel: single-image regime (test_*); its last fraction group provides the legend handles.\n",
    "plt.subplot(1, 2, 2)\n",
    "space = 0\n",
    "for i in range (len(x_labels_precent)):\n",
    "    bars = []\n",
    "    for k in range (len(ns)):\n",
    "        plt.scatter(xx + space + scatter_x, np.sort(test_fashion_std[i,k, :, -1]), s=20, marker='*', color=dark_colors[k])\n",
    "        bars.append(plt.bar(xx + space , np.mean(test_fashion_std[i,k], axis=0)[-1], width=w,  alpha=0.8, label=\"{}-of-{}\".format(ks[k], ns[k]), color=light_colors[k]))\n",
    "        plt.bar(xx + space + space_between_datasets , np.mean(test_emnist_std[i, k], axis=0)[-1], width=w, alpha=0.8, color=light_colors[k])\n",
    "        plt.scatter(xx + space + scatter_x + space_between_datasets, np.sort(test_emnist_std[i,k, :, -1]), s=20, marker='*', color=dark_colors[k])\n",
    "        space = space + w + space_bars\n",
    "\n",
    "    bars.append(plt.bar(xx + space, np.mean(std_greedy_fashion[i], axis=0)[-1], width=w, yerr=std_error_greedy_fashion[i], linewidth=10, capsize=5, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[k+1]))\n",
    "    # same capsize as every other greedy error bar in this figure\n",
    "    plt.bar(xx + space + space_between_datasets, np.mean(std_greedy_emnist[i], axis=0)[-1], width=w, yerr=std_error_greedy_emnist[i], linewidth=10, capsize=5, alpha=0.8, color=light_colors[k+1])\n",
    "\n",
    "    plt.text(-0.003+i*0.038, -0.1, x_labels_precent[i], fontsize=fontsize)\n",
    "    plt.text(-0.003+i*0.038 + space_between_datasets, -0.1, x_labels_precent[i], fontsize=fontsize)\n",
    "\n",
    "    space += space_percent\n",
    "\n",
    "plt.text(0.04, -0.25, 'f', fontsize = fontsize+10)\n",
    "plt.text(0.18, -0.25, r'$\\alpha$', fontsize = fontsize+10)\n",
    "plt.ylim(-0.001, 1)\n",
    "plt.grid(axis=\"y\")\n",
    "plt.xticks([])\n",
    "plt.yticks(fontsize=0)\n",
    "fig.legend(bars, legends, fontsize=fontsize+5, ncol=len(ns)+1, loc=(0.09, .825), bbox_transform = plt.gcf().transFigure)\n",
    "plt.subplots_adjust(top=0.9, bottom=0.7)\n",
    "fig.tight_layout(w_pad=1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## k-of-n convergence (Expected Value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected value per iteration for the 'last_action_std_0.1' results, normalised by\n",
    "# n_actions * (number of test images).\n",
    "# NOTE(review): this directory is hard-coded rather than taken from `std_models_dir` -- confirm it is intended.\n",
    "# Panels 0-2: all-images regime (mnist, fashion, emnist); panels 3-5: single-image regime.\n",
    "std_exp_dir = 'last_action_std_0.1'\n",
    "exp_sources = [(regime, dataset, labels)\n",
    "               for regime in ('all-images regime', 'single-image regime')\n",
    "               for dataset, labels in (('mnist', mnist_labels), ('fashion', fashion_labels), ('emnist', emnist_labels))]\n",
    "\n",
    "exp = np.zeros((len(exp_sources), len(ks), n_runs, n_itr))\n",
    "exp_conf = np.zeros((len(exp_sources), len(ks), n_itr))\n",
    "\n",
    "for p, (regime, dataset, labels) in enumerate(exp_sources):\n",
    "    for k in range(len(ks)):\n",
    "        for run in range(n_runs):\n",
    "            exp[p, k, run] = np.load('actions/{}/{}/run_{}_expected_value_{}_{}-of-{}_n_itr_{}.npy'.format(regime, std_exp_dir, run, dataset, ks[k], ns[k], n_itr)) / (n_actions * len(labels))\n",
    "\n",
    "# Confidence half-width across runs, computed once after every run has been loaded so each\n",
    "# interval sees all n_runs samples (no recomputation on partially-filled data).\n",
    "for p in range(len(exp_sources)):\n",
    "    for k in range(len(ks)):\n",
    "        for itr in range(n_itr):\n",
    "            exp_conf[p, k, itr] = mean_confidence_interval(exp[p, k, :, itr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10']\n",
    "exp_fig = plt.figure(figsize=(20,8))\n",
    "exp_bars = []  # legend handles: one mean line per k-of-n setting\n",
    "for f in range (6):\n",
    "    plt.subplot(2, 3, f+1)\n",
    "    for k in range (len(ks)):\n",
    "        # thin lines: individual runs; thick line: mean over runs\n",
    "        for run in range (n_runs):\n",
    "            plt.plot(exp[f, k, run], color=light_colors[k], lw=1, alpha=0.5)\n",
    "        a, = plt.plot(np.mean(exp[f, k], axis=0), label=\"{}-of-{}\".format(ks[k], ns[k]), color=light_colors[k], lw=4)\n",
    "        if f == 0:\n",
    "            exp_bars.append(a)  # all panels are styled identically; take the handles from the first\n",
    "\n",
    "    # Per-panel decorations: drawn once per panel, not once per k-of-n setting\n",
    "    # (repeating plt.text inside the k loop stacked len(ks) copies of the same label).\n",
    "    if f%3 == 0:\n",
    "        plt.text(-32, 0.1, s=r'$u_{\\mathcal{M}}(\\pi^{t}; \\bar{r}^{t})$', fontsize=fontsize+5, rotation='vertical')\n",
    "    if f > 2:\n",
    "        plt.xlabel('iteration', fontsize=fontsize+10)\n",
    "        plt.xticks(fontsize=fontsize)\n",
    "    else:\n",
    "        plt.xticks(fontsize=0)\n",
    "    plt.yticks(fontsize=fontsize)\n",
    "    plt.grid(axis=\"y\")\n",
    "exp_fig.legend(exp_bars, exp_legends, fontsize=fontsize+5, ncol=len(ks), loc=(0.18, .9), bbox_transform = plt.gcf().transFigure)\n",
    "plt.subplots_adjust(top=0.7, bottom=0.4)\n",
    "plt.tight_layout(w_pad=2.5, h_pad=-4.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "heatmap_dir = 'last_action_std_0.1_600_samples'\n",
    "\n",
    "# Per-run action arrays: TEST_* are loaded from the all-images regime, test_* from the single-image regime.\n",
    "TEST_fashion_heatmap_std = np.zeros((n_runs, len(fashion_labels), n_actions))\n",
    "TEST_emnist_heatmap_std = np.zeros((n_runs, len(emnist_labels), n_actions))\n",
    "test_fashion_heatmap_std = np.zeros((n_runs, len(fashion_labels), n_actions))\n",
    "test_emnist_heatmap_std = np.zeros((n_runs, len(emnist_labels), n_actions))\n",
    "\n",
    "# hard-max baseline actions (one file per dataset)\n",
    "baseline_emnist_heatmap_std = np.load('actions/all-images regime/{}/hard_max_emnist_actions.npy'.format(heatmap_dir))\n",
    "baseline_fashion_heatmap_std = np.load('actions/all-images regime/{}/hard_max_fashion_actions.npy'.format(heatmap_dir))\n",
    "\n",
    "for run in range(n_runs):\n",
    "    TEST_fashion_heatmap_std[run] = np.load('actions/all-images regime/{}/run_{}_fashion_actions_1-of-20_n_itr_100.npy'.format(heatmap_dir, run))\n",
    "    test_fashion_heatmap_std[run] = np.load('actions/single-image regime/{}/run_{}_fashion_actions_1-of-20_n_itr_100.npy'.format(heatmap_dir, run))\n",
    "    TEST_emnist_heatmap_std[run] = np.load('actions/all-images regime/{}/run_{}_emnist_actions_1-of-20_n_itr_100.npy'.format(heatmap_dir, run))\n",
    "    test_emnist_heatmap_std[run] = np.load('actions/single-image regime/{}/run_{}_emnist_actions_1-of-20_n_itr_100.npy'.format(heatmap_dir, run))\n",
    "\n",
    "def letter_action_means(heatmaps):\n",
    "    \"\"\"Average an action array per EMNIST letter, rounded to 1 decimal.\n",
    "\n",
    "    Row i is heatmaps[:, images whose emnist label is i+1, :] averaged over axis 1 and then axis 0.\n",
    "    \"\"\"\n",
    "    label_ids = emnist_labels.numpy()\n",
    "    means = np.zeros((len(emnist_labels_names), n_actions))\n",
    "    for row in range(len(emnist_labels_names)):\n",
    "        members = np.where(label_ids == (row + 1))[0]\n",
    "        means[row] = np.round(np.mean(np.mean(heatmaps[:, members, :], 1), 0), 1)\n",
    "    return means\n",
    "\n",
    "TEST_emnist_m_std = letter_action_means(TEST_emnist_heatmap_std)\n",
    "test_emnist_m_std = letter_action_means(test_emnist_heatmap_std)\n",
    "baseline_emnist_m_std = letter_action_means(baseline_emnist_heatmap_std)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fashion-MNIST and EMNIST action histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.1  # bar width\n",
    "space_between_bars = 0.00\n",
    "fontsize = 25\n",
    "xx = np.arange(n_actions)\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy'+r\"$(\\hat{r})$\"]\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"help\"]\n",
    "# x offset of the centre of each group of len(ks)+1 bars (0.2 for 4 ks with w=0.1)\n",
    "group_center = len(ks) * (w + space_between_bars) / 2\n",
    "\n",
    "def plot_action_hist_panel(grid_cell, heatmaps, baseline, bottom_row, left_col):\n",
    "    \"\"\"Draw one grouped bar panel of mean action values and return its bar containers.\n",
    "\n",
    "    For every action there is one bar per entry of `ks`, followed by the baseline bar.\n",
    "    x tick labels are shown only on the bottom row and y labels only on the left column.\n",
    "    Reads w, space_between_bars, xx, group_center, labelss, fontsize, ks and light_colors from the notebook.\n",
    "    \"\"\"\n",
    "    plt.subplot(grid_cell)\n",
    "    bars = []\n",
    "    offset = 0\n",
    "    for k in range(len(ks)):\n",
    "        # FIXME(review): axis 0 of `heatmaps` is the run axis (every run was loaded from the 1-of-20 files),\n",
    "        # so bar k shows run k of 1-of-20 while the legend labels it legends[k]. Confirm which data is intended.\n",
    "        bars.append(plt.bar(xx+offset, np.mean(heatmaps[k], axis=0), width=w, color=light_colors[k]))\n",
    "        offset = offset + w + space_between_bars\n",
    "    # the baseline takes the colour that follows the len(ks) colours used above\n",
    "    bars.append(plt.bar(xx+offset, np.mean(np.mean(baseline, axis=1), axis=0), width=w, color=light_colors[len(ks)]))\n",
    "    if bottom_row:\n",
    "        plt.xticks(group_center+xx, labelss, fontsize=fontsize+5)\n",
    "        plt.xlabel('action', fontsize=fontsize+15)\n",
    "    else:\n",
    "        plt.xticks([])\n",
    "    if left_col:\n",
    "        plt.yticks(np.arange(0, 1.1, 0.2), fontsize=fontsize)\n",
    "        plt.ylabel(\"frequency\", fontsize=fontsize+15)\n",
    "    else:\n",
    "        plt.yticks(np.arange(0, 1.1, 0.2), fontsize=0)  # keep the gridlines, hide the labels\n",
    "    plt.grid(axis=\"y\")\n",
    "    plt.ylim(0, 1)\n",
    "    return bars\n",
    "\n",
    "hist_fig = plt.figure(figsize=(20,10))\n",
    "hist_grid = plt.GridSpec(2, 2, figure=hist_fig, width_ratios=[1, 1])\n",
    "# top row: all-images regime (TEST_*), bottom row: single-image regime (test_*); left: Fashion, right: EMNIST\n",
    "plot_action_hist_panel(hist_grid[0, 0], TEST_fashion_heatmap_std, baseline_fashion_heatmap_std, bottom_row=False, left_col=True)\n",
    "plot_action_hist_panel(hist_grid[0, 1], TEST_emnist_heatmap_std, baseline_emnist_heatmap_std, bottom_row=False, left_col=False)\n",
    "plot_action_hist_panel(hist_grid[1, 0], test_fashion_heatmap_std, baseline_fashion_heatmap_std, bottom_row=True, left_col=True)\n",
    "hist_bars = plot_action_hist_panel(hist_grid[1, 1], test_emnist_heatmap_std, baseline_emnist_heatmap_std, bottom_row=True, left_col=False)\n",
    "\n",
    "hist_fig.legend(hist_bars, legends, fontsize=fontsize+5, ncol=len(ns)+1, loc=(0.09, .91))\n",
    "plt.subplots_adjust(bottom=0.5, top=0.7)\n",
    "plt.tight_layout(w_pad=3.0, h_pad=-9.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 25\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"help\"]\n",
    "\n",
    "sns.set(font_scale=1.4)  # global seaborn theme setting: also affects any figure drawn after this cell\n",
    "\n",
    "def plot_letter_heatmap(grid_cell, values, cbar=False, ylabel=None):\n",
    "    \"\"\"Draw an annotated letter-by-action heatmap (colour range 0..1) with action labels on top; return its axes.\"\"\"\n",
    "    ax = plt.subplot(grid_cell)\n",
    "    sns.heatmap(values, annot=True, linewidths=1, cmap=\"Blues\", vmin=0, vmax=1, cbar=cbar, linecolor='black', ax=ax)\n",
    "    # seaborn draws cell (r, c) on [c, c+1] x [r, r+1], so +0.5 puts both tick sets on the cell centres\n",
    "    plt.yticks(np.arange(len(emnist_labels_names))+0.5, emnist_labels_names, rotation='horizontal', fontsize=fontsize-6)\n",
    "    plt.xticks(np.arange(n_actions)+0.5, labelss, fontsize=fontsize-3)\n",
    "    plt.xlabel(\"action\", fontsize=fontsize+10)\n",
    "    if ylabel is not None:\n",
    "        plt.ylabel(ylabel, fontsize=fontsize+10)\n",
    "    ax.xaxis.tick_top()\n",
    "    return ax\n",
    "\n",
    "fig_heatmap = plt.figure(figsize=(20, 10))\n",
    "# the last panel is wider, presumably to make room for its colour bar\n",
    "grid = plt.GridSpec(1, 3, figure=fig_heatmap, width_ratios=[1, 1, 1.3])\n",
    "\n",
    "# left: all-images regime, middle: single-image regime, right: hard-max baseline (only panel with a colour bar)\n",
    "plot_letter_heatmap(grid[0], TEST_emnist_m_std, ylabel=\"letter\")\n",
    "plot_letter_heatmap(grid[1], test_emnist_m_std)\n",
    "ax = plot_letter_heatmap(grid[2], baseline_emnist_m_std, cbar=True)\n",
    "\n",
    "plt.tight_layout(w_pad=1.0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
