{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment **4.3**: Ask for Help Only When it is Available"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torchvision, torch\n",
    "from torchvision import transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import style\n",
    "import colorsys\n",
    "import scipy\n",
    "import seaborn as sns\n",
    "\n",
    "transform = transforms.ToTensor()\n",
    "## plot style and font\n",
    "style.use('seaborn-colorblind')\n",
    "transform = transforms.ToTensor()\n",
    "font_params = {}\n",
    "font_params['family'] = 'serif'\n",
    "plt.rc('font', **font_params)\n",
    "plt.rc('pdf', fonttype=42)\n",
    "\n",
    "dark_hues = [0, 0.1, 0.4, 0.55, 0.65, 0.75, 0.9]\n",
    "dark_lightness = [0.4, 0.2, 0.15, 0.4, 0.4, 0.4, 0.4]\n",
    "light_hues = [1, 0.1, 0.3, 0.5, 0.6, 0.75, 0.85]\n",
    "light_lightness = [0.75, 0.6, 0.45, 0.6, 0.65, 0.8, 0.55]\n",
    "dark_colors = [colorsys.hls_to_rgb(h, dark_lightness[i], 1) for i, h in enumerate(dark_hues)]\n",
    "sns.palplot(dark_colors)\n",
    "light_colors = [colorsys.hls_to_rgb(h, light_lightness[i], 1) for i, h in enumerate(light_hues)]\n",
    "sns.palplot(light_colors)\n",
    "\n",
    "def mean_confidence_interval(data, confidence=0.95):\n",
    "    n = len(data)\n",
    "    m, se = np.mean(data), scipy.stats.sem(data)\n",
    "    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)\n",
    "    return h\n",
    "def to_x_axis(x):\n",
    "    return 0.015+(x*0.6 - 0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load datasets (MNIST, Fashion, and E-MNIST) and pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_test = torchvision.datasets.MNIST(root='datasets', train=False, download=True, transform=transform)\n",
    "\n",
    "fashion_test = torchvision.datasets.FashionMNIST(root='datasets', train=False, download=True, transform=transform)\n",
    "\n",
    "emnist_test  = torchvision.datasets.EMNIST(root=\"datasets\", train=False, transform=transform, target_transform=None, download=True, split=\"letters\")\n",
    "\n",
    "pos = np.arange(0, 2*len(mnist_test), 2)\n",
    "neg = np.arange(1, 2*len(mnist_test), 2)\n",
    "\n",
    "pos_emnist = np.arange(0,2*len(emnist_test) , 2)    \n",
    "neg_emnist = np.arange(1,2*len(emnist_test) , 2) \n",
    "\n",
    "mnist_labels = torch.zeros((2*len(mnist_test)), dtype=torch.int)\n",
    "\n",
    "fashion_labels = torch.zeros((2*len(fashion_test)))\n",
    "\n",
    "emnist_labels = torch.zeros((2*len(emnist_test)))\n",
    "\n",
    "l=0 \n",
    "for i in range (len(mnist_test)):\n",
    "    mnist_labels[l] = mnist_test[i][1]\n",
    "    mnist_labels[l+1] = mnist_test[i][1]\n",
    "\n",
    "    fashion_labels[l] = fashion_test[i][1]\n",
    "    fashion_labels[l+1] = fashion_test[i][1]\n",
    "\n",
    "    emnist_labels[l] = emnist_test[i][1]\n",
    "    emnist_labels[l+1] = emnist_test[i][1]\n",
    "    l+=2\n",
    "\n",
    "\n",
    "for j in range (i+1, len(emnist_test)): \n",
    "    emnist_labels[l] = emnist_test[j][1]\n",
    "    emnist_labels[l+1] = emnist_test[j][1]\n",
    "    l+=2\n",
    "emnist_labels_names = [\"A/a\", \"B/b\", \"C/c\", \"D/d\", \"E/e\", \"F/f\", \"G/g\", \"H/h\", \"I/i\", \"J/j\", \"K/k\", \"L/l\", \"M/m\", \"N/n\", \"O/o\", \"P/p\", \"Q/q\", \"R/r\", \"S/s\", \"T/t\", \"U/u\", \"V/v\", \"W/w\", \"X/x\", \"Y/y\", \"Z/z\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load k-of-n actions "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models_dir = \"models/Ask for Help Only When it is Available\" # directory where each model's papamters will be saved\n",
    "models_dir_split = models_dir.split('/')\n",
    "fig_dir = 'fig/' + models_dir_split[1]\n",
    "if not os.path.exists(fig_dir):\n",
    "    os.makedirs(fig_dir)\n",
    "n_models = 2000\n",
    "n_actions = 11\n",
    "n_runs = 10\n",
    "n_itr= 100\n",
    "fontsize = 25\n",
    "\n",
    "ks = [1, 1, 5, 10]\n",
    "ns = [20, 10, 10, 10]\n",
    "\n",
    "test_k_of_n_mnist_pos = np.zeros((len(ks), n_runs, n_actions))\n",
    "test_k_of_n_mnist_neg = np.zeros((len(ks), n_runs, n_actions))\n",
    "test_k_of_n_fashion_pos = np.zeros((len(ks), n_runs, n_actions))\n",
    "test_k_of_n_fashion_neg = np.zeros((len(ks), n_runs, n_actions))\n",
    "test_k_of_n_emnist_pos = np.zeros((len(ks), n_runs,  n_actions))\n",
    "test_k_of_n_emnist_neg = np.zeros((len(ks), n_runs,  n_actions))\n",
    "\n",
    "TEST_k_of_n_mnist_pos = np.zeros((len(ks), n_runs, n_actions))\n",
    "TEST_k_of_n_mnist_neg = np.zeros((len(ks), n_runs, n_actions))\n",
    "TEST_k_of_n_fashion_pos = np.zeros((len(ks), n_runs, n_actions))\n",
    "TEST_k_of_n_fashion_neg = np.zeros((len(ks), n_runs, n_actions))\n",
    "TEST_k_of_n_emnist_pos = np.zeros((len(ks), n_runs,  n_actions))\n",
    "TEST_k_of_n_emnist_neg = np.zeros((len(ks), n_runs,  n_actions))\n",
    "\n",
    "for run in range (n_runs):\n",
    "    for k in range (len(ks)):\n",
    "        TEST_k_of_n_mnist_pos[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_mnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[pos], axis=0)\n",
    "        TEST_k_of_n_mnist_neg[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_mnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[neg], axis=0)\n",
    "\n",
    "        TEST_k_of_n_fashion_pos[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_fashion_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[pos], axis=0)\n",
    "        TEST_k_of_n_fashion_neg[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_fashion_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[neg], axis=0)\n",
    "\n",
    "        TEST_k_of_n_emnist_pos[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_emnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[pos_emnist], axis=0)\n",
    "        TEST_k_of_n_emnist_neg[k, run] = np.mean(np.load('actions/all-images regime/{}/run_{}_emnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[neg_emnist], axis=0)\n",
    "\n",
    "        test_k_of_n_mnist_pos[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_mnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[pos], axis=0)\n",
    "        test_k_of_n_mnist_neg[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_mnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[neg], axis=0)\n",
    "\n",
    "        test_k_of_n_fashion_pos[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_fashion_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[pos], axis=0)\n",
    "        test_k_of_n_fashion_neg[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_fashion_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[neg], axis=0)\n",
    "\n",
    "        test_k_of_n_emnist_pos[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_emnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[pos_emnist], axis=0)\n",
    "        test_k_of_n_emnist_neg[k, run] = np.mean(np.load('actions/single-image regime/{}/run_{}_emnist_actions_{}-of-{}_n_itr_100.npy'.format(models_dir_split[1], run, ks[k], ns[k]))[neg_emnist], axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load Standard RL actions "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline_dir = 'actions/all-images regime/'+models_dir_split[1]\n",
    "greedy_mnist_pos = np.mean(np.load('{}/hard_max_mnist_actions.npy'.format(baseline_dir))[:, pos], axis=1) # take each model as a policy and average over polcies predictions\n",
    "greedy_mnist_neg = np.mean(np.load('{}/hard_max_mnist_actions.npy'.format(baseline_dir))[:, neg], axis=1) # take each model as a policy and average over polcies predictions\n",
    "\n",
    "greedy_fashion_pos = np.mean(np.load('{}/hard_max_fashion_actions.npy'.format(baseline_dir))[:, pos], axis=1) # take each model as a policy and average over polcies predictions\n",
    "greedy_fashion_neg = np.mean(np.load('{}/hard_max_fashion_actions.npy'.format(baseline_dir))[:, neg], axis=1) # take each model as a policy and average over polcies predictions\n",
    "\n",
    "greedy_emnist_pos = np.mean(np.load('{}/hard_max_emnist_actions.npy'.format(baseline_dir))[:, pos_emnist], axis=1) # take each model as a policy and average over polcies predictions\n",
    "greedy_emnist_neg = np.mean(np.load('{}/hard_max_emnist_actions.npy'.format(baseline_dir))[:, neg_emnist], axis=1) # take each model as a policy and average over polcies predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_greedy_mnist_pos = mean_confidence_interval(np.sum(greedy_mnist_pos[:, :-1]*np.arange(10), axis=1))\n",
    "error_greedy_mnist_neg = mean_confidence_interval(np.sum(greedy_mnist_neg[:, :-1]*np.arange(10), axis=1))\n",
    "\n",
    "error_greedy_fashion_pos = mean_confidence_interval(np.sum(greedy_fashion_pos[:, :-1]*np.arange(10), axis=1))\n",
    "error_greedy_fashion_neg = mean_confidence_interval(np.sum(greedy_fashion_neg[:, :-1]*np.arange(10), axis=1))\n",
    "\n",
    "error_greedy_emnist_pos = mean_confidence_interval(np.sum(greedy_emnist_pos[:, :-1]*np.arange(10), axis=1))\n",
    "error_greedy_emnist_neg = mean_confidence_interval(np.sum(greedy_emnist_neg[:, :-1]*np.arange(10), axis=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# fashion and E-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.04\n",
    "space_bars = 0.002\n",
    "space_between_datasets = 0.35\n",
    "y_space = 0.48\n",
    "y_space_between_datasets = 3.38\n",
    "scatter_x = np.arange(-5, n_runs - 5) / 300\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy' + r\"$(\\hat{r})$\"]\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"]\n",
    "\n",
    "xx = np.arange(1)\n",
    "fig = plt.figure(figsize=(20,10))\n",
    "grid = plt.GridSpec(2, 2,figure=fig, height_ratios=[1.6, 1])\n",
    "plt.subplot(grid[0, 0])\n",
    "space = 0\n",
    "start_y = 11\n",
    "for i in range (len(ns)):\n",
    "    plt.bar(xx + space, sum(np.mean(TEST_k_of_n_fashion_pos[i], 0)[:-1]*np.arange(10)), width=w, alpha=0.8, label=\"{}-of-{}\".format(ks[i], ns[i]), color=light_colors[i])\n",
    "    plt.bar(xx + space + space_between_datasets, sum(np.mean(TEST_k_of_n_emnist_pos[i], 0)[:-1]*np.arange(10)), width=w, alpha=0.8, color=light_colors[i])\n",
    "    \n",
    "    plt.bar(xx + space, 6*np.mean(TEST_k_of_n_fashion_pos[i, :, -1]), bottom=start_y,  width=w, alpha=0.8, color=light_colors[i])\n",
    "    plt.bar(xx + space + space_between_datasets, 6*np.mean(TEST_k_of_n_emnist_pos[i, :, -1]), bottom=start_y, width=w, alpha=0.8, color=light_colors[i])\n",
    "\n",
    "    points_fashion, points_emnist = [], []\n",
    "    for run in range (n_runs):\n",
    "        points_fashion.append(sum(TEST_k_of_n_fashion_pos[i, run, :-1]*np.arange(10)))\n",
    "        points_emnist.append(sum(TEST_k_of_n_emnist_pos[i, run, :-1]*np.arange(10)))\n",
    "    plt.scatter(xx + space + scatter_x, np.sort(points_fashion), s=45, marker='*', color=dark_colors[i])\n",
    "    plt.scatter(xx + space + scatter_x + space_between_datasets, np.sort(points_emnist), s=45, marker='*', color=dark_colors[i])\n",
    "    \n",
    "    plt.scatter(xx + space + scatter_x, start_y + 6*np.sort(TEST_k_of_n_fashion_pos[i, :, -1]), s=45, marker='*', color=dark_colors[i])\n",
    "    plt.scatter(xx + space + scatter_x + space_between_datasets, start_y + 6*np.sort(TEST_k_of_n_emnist_pos[i, :, -1]), s=45, marker='*', color=dark_colors[i])\n",
    "    \n",
    "    space = space + w + space_bars\n",
    "\n",
    "plt.bar(xx + space, sum(np.arange(10)*np.mean(greedy_fashion_pos, axis=0)[:-1]), width=w, yerr=error_greedy_fashion_pos, linewidth=10, capsize=20, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[i+1])\n",
    "plt.bar(xx + space + space_between_datasets, sum(np.arange(10)*np.mean(greedy_emnist_pos, axis=0)[:-1]), yerr=error_greedy_emnist_neg, linewidth=10, capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "\n",
    "plt.bar(xx + space, 6*np.mean(greedy_fashion_pos[:, -1]),  bottom=start_y, yerr=mean_confidence_interval(greedy_fashion_pos[:, -1]), capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "plt.bar(xx + space + space_between_datasets, 6*np.mean(greedy_emnist_pos[:, -1]), bottom=start_y, yerr=mean_confidence_interval(greedy_emnist_pos[:, -1]), capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "plt.text(-0.12, 12.5, 'help', fontsize=fontsize+5, rotation='vertical')\n",
    "plt.xticks([], fontsize=0)\n",
    "plt.yticks(np.arange(0, len(labelss), 3), fontsize = fontsize)\n",
    "plt.ylabel(\"action index\", fontsize = fontsize+4, position=(-9, .26))\n",
    "plt.ylim(0, 17)\n",
    "plt.grid(axis=\"y\")  \n",
    "\n",
    "help_h = np.arange(0, 11, 5)/10\n",
    "start_h = 11\n",
    "for h in range (len(help_h)):\n",
    "    plt.axhline(y=start_h, xmin =-2, xmax=1.1, color='gray', lw=0.6)\n",
    "    plt.text(-0.085, start_h - 0.25, str(help_h[h]), fontsize=fontsize)\n",
    "    start_h += 3\n",
    "    \n",
    "plt.subplot(grid[0, 1])\n",
    "space = 0\n",
    "bars = []\n",
    "start_y = 11\n",
    "for i in range (len(ns)):\n",
    "    bars.append(plt.bar(xx + space, sum(np.mean(test_k_of_n_fashion_pos[i], 0)[:-1]*np.arange(10)), width=w, alpha=0.8, label=\"{}-of-{}\".format(ks[i], ns[i]), color=light_colors[i]))\n",
    "    plt.bar(xx + space + space_between_datasets, sum(np.mean(test_k_of_n_emnist_pos[i], 0)[:-1]*np.arange(10)), width=w, alpha=0.8, color=light_colors[i])\n",
    "    points_fashion, points_emnist = [], []\n",
    "    for run in range (n_runs):\n",
    "        points_fashion.append(sum(test_k_of_n_fashion_pos[i, run, :-1]*np.arange(10)))\n",
    "        points_emnist.append(sum(test_k_of_n_emnist_pos[i, run, :-1]*np.arange(10)))\n",
    "    plt.scatter(xx + space + scatter_x, np.sort(points_fashion), s=45, marker='*', color=dark_colors[i])\n",
    "    plt.scatter(xx + space + scatter_x + space_between_datasets, np.sort(points_emnist), s=45, marker='*', color=dark_colors[i])\n",
    "    \n",
    "    plt.scatter(xx + space + scatter_x, start_y + 6*np.sort(test_k_of_n_fashion_pos[i, :, -1]), s=45, marker='*', color=dark_colors[i])\n",
    "    plt.scatter(xx + space + scatter_x + space_between_datasets, start_y + 6*np.sort(test_k_of_n_emnist_pos[i, :, -1]), s=45, marker='*', color=dark_colors[i])\n",
    "    \n",
    "    plt.bar(xx + space, 6*np.mean(test_k_of_n_fashion_pos[i, :, -1]), bottom=start_y,  width=w, alpha=0.8, color=light_colors[i])\n",
    "    plt.bar(xx + space + space_between_datasets, 6*np.mean(test_k_of_n_emnist_pos[i, :, -1]), bottom=start_y, width=w, alpha=0.8, color=light_colors[i])\n",
    "\n",
    "    space = space + w + space_bars\n",
    "bars.append(plt.bar(xx + space, sum(np.arange(10)*np.mean(greedy_fashion_pos, axis=0)[:-1]), yerr=error_greedy_fashion_pos, linewidth=10, capsize=20, width=w, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[i+1]))\n",
    "plt.bar(xx + space + space_between_datasets, sum(np.arange(10)*np.mean(greedy_emnist_pos, axis=0)[:-1]), yerr=error_greedy_emnist_pos, linewidth=10, capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "\n",
    "plt.bar(xx + space, sum(np.arange(10)*np.mean(greedy_fashion_pos, axis=0)[:-1]), width=w, yerr=error_greedy_fashion_pos, linewidth=10, capsize=20, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[i+1])\n",
    "plt.bar(xx + space + space_between_datasets, sum(np.arange(10)*np.mean(greedy_emnist_pos, axis=0)[:-1]), yerr=error_greedy_emnist_neg, linewidth=10, capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "\n",
    "plt.bar(xx + space, 6*np.mean(greedy_fashion_pos[:, -1]),  bottom=start_y, yerr=mean_confidence_interval(greedy_fashion_pos[:, -1]), capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "plt.bar(xx + space + space_between_datasets, 6*np.mean(greedy_emnist_pos[:, -1]), bottom=start_y, yerr=mean_confidence_interval(greedy_emnist_pos[:, -1]), capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "\n",
    "\n",
    "plt.xticks([], fontsize=0)\n",
    "plt.yticks(np.arange(0, len(labelss), 3), fontsize=0)\n",
    "plt.ylim(0, 17)\n",
    "plt.grid(axis=\"y\")\n",
    "help_h = np.arange(0, 11, 2)/10\n",
    "start_h = 11\n",
    "for h in range (len(help_h)):\n",
    "    plt.axhline(y=start_h, xmax=1.1, color='gray', lw=0.6)\n",
    "    start_h += 3\n",
    "    \n",
    "plt.subplot(grid[1, 0])\n",
    "space = 0\n",
    "start_y = 11\n",
    "for i in range (len(ns)):\n",
    "    plt.bar(xx + space, sum(np.mean(TEST_k_of_n_fashion_neg[i], 0)[:-1]*np.arange(10)), width=w, alpha=0.8, label=\"{}-of-{}\".format(ks[i], ns[i]), color=light_colors[i])\n",
    "    plt.bar(xx + space + space_between_datasets, sum(np.mean(TEST_k_of_n_emnist_neg[i], 0)[:-1]*np.arange(10)), width=w, alpha=0.8, color=light_colors[i])\n",
    "\n",
    "    points_fashion, points_emnist = [], []\n",
    "    for run in range (n_runs):\n",
    "        points_fashion.append(sum(TEST_k_of_n_fashion_neg[i, run, :-1]*np.arange(10)))\n",
    "        points_emnist.append(sum(TEST_k_of_n_emnist_neg[i, run, :-1]*np.arange(10)))\n",
    "    plt.scatter(xx + space + scatter_x, np.sort(points_fashion), s=45, marker='*', color=dark_colors[i])\n",
    "    plt.scatter(xx + space + scatter_x + space_between_datasets, np.sort(points_emnist), s=45, marker='*', color=dark_colors[i])\n",
    "    \n",
    "    space = space + w + space_bars\n",
    "    start_y += y_space\n",
    "plt.bar(xx + space, sum(np.arange(10)*np.mean(greedy_fashion_neg, axis=0)[:-1]), width=w, yerr=error_greedy_fashion_neg, linewidth=10, capsize=20, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[i+1])\n",
    "plt.bar(xx + space + space_between_datasets, sum(np.arange(10)*np.mean(greedy_emnist_neg, axis=0)[:-1]), yerr=error_greedy_emnist_neg, linewidth=10, capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "plt.text(0.075, -1.5, 'f', fontsize = fontsize+5)\n",
    "plt.text(0.425, -1.5, r'$\\alpha$', fontsize = fontsize+5)\n",
    "plt.xticks([], fontsize=0)\n",
    "plt.yticks(np.arange(0, len(labelss), 3), fontsize = fontsize)\n",
    "plt.ylabel(\"action index\", fontsize = fontsize+4, position=(9, 0.5))\n",
    "plt.ylim(0, 9)\n",
    "plt.grid(axis=\"y\")  \n",
    "\n",
    "plt.subplot(grid[1, 1])\n",
    "space = 0\n",
    "bars = []\n",
    "start_y = 11\n",
    "for i in range (len(ns)):\n",
    "    bars.append(plt.bar(xx + space, sum(np.mean(test_k_of_n_fashion_neg[i], 0)[:-1]*np.arange(10)), width=w, alpha=0.8, label=\"{}-of-{}\".format(ks[i], ns[i]), color=light_colors[i]))\n",
    "    plt.bar(xx + space + space_between_datasets, sum(np.mean(test_k_of_n_emnist_neg[i], 0)[:-1]*np.arange(10)), width=w, alpha=0.8, color=light_colors[i])\n",
    "    points_fashion, points_emnist = [], []\n",
    "    for run in range (n_runs):\n",
    "        points_fashion.append(sum(test_k_of_n_fashion_neg[i, run, :-1]*np.arange(10)))\n",
    "        points_emnist.append(sum(test_k_of_n_emnist_neg[i, run, :-1]*np.arange(10)))\n",
    "    plt.scatter(xx + space + scatter_x, np.sort(points_fashion), s=45, marker='*', color=dark_colors[i])\n",
    "    plt.scatter(xx + space + scatter_x + space_between_datasets, np.sort(points_emnist), s=45, marker='*', color=dark_colors[i])\n",
    "\n",
    "    start_y += y_space\n",
    "    space = space + w + space_bars\n",
    "bars.append(plt.bar(xx + space, sum(np.arange(10)*np.mean(greedy_fashion_neg, axis=0)[:-1]), width=w, yerr=error_greedy_fashion_neg, linewidth=10, capsize=20, alpha=0.8, label='greedy'+r\"$(\\hat{r})$\", color=light_colors[i+1]))\n",
    "plt.bar(xx + space + space_between_datasets, sum(np.arange(10)*np.mean(greedy_emnist_neg, axis=0)[:-1]), yerr=error_greedy_emnist_neg, linewidth=10, capsize=20, width=w, alpha=0.8, color=light_colors[i+1])\n",
    "\n",
    "plt.xticks([], fontsize=0)\n",
    "plt.yticks(np.arange(0, len(labelss), 3), fontsize=0)\n",
    "plt.text(0.075, -1.5, 'f', fontsize = fontsize+5)\n",
    "plt.text(0.425, -1.5, r'$\\alpha$', fontsize = fontsize+5)\n",
    "\n",
    "plt.ylim(0, 9)\n",
    "plt.grid(axis=\"y\")\n",
    "fig.legend(bars, legends, fontsize=fontsize+4, ncol=len(ns)+1, loc=(0.1, .9), )\n",
    "plt.subplots_adjust(bottom=0.6,top=0.66)\n",
    "plt.tight_layout(w_pad=1.5, h_pad=-12.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Histogram of Fashion and E-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.1\n",
    "space_between_bars = 0.00\n",
    "fontsize = 25\n",
    "hist_fig = plt.figure(figsize=(20,10))\n",
    "hist_grid = plt.GridSpec(2, 2, figure=hist_fig, width_ratios=[1, 1])\n",
    "xx = np.arange(n_actions)\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy'+r\"$(\\hat{r})$\"]\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"help\"]\n",
    "\n",
    "space = 0\n",
    "plt.subplot(hist_grid[0, 0])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(TEST_k_of_n_fashion_pos[k], axis=0), width=w, color=light_colors[k])\n",
    "    plt.axvline(x=np.sum(np.mean(TEST_k_of_n_fashion_pos[k], axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "plt.bar(xx+space, np.mean(greedy_fashion_pos, axis=0),width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_fashion_pos, axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.yticks(np.arange(0, 1.1, 0.2), fontsize = fontsize)\n",
    "plt.xticks([])\n",
    "plt.ylabel(\"frequency\", fontsize = fontsize+15)\n",
    "plt.grid(axis=\"y\")\n",
    "plt.ylim(0, 1)\n",
    "\n",
    "space = 0\n",
    "plt.subplot(hist_grid[0, 1])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(TEST_k_of_n_emnist_pos[k], axis=0), width=w, color=light_colors[k])\n",
    "    plt.axvline(x=np.sum(np.mean(TEST_k_of_n_emnist_pos[k], axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "plt.bar(xx+space, np.mean(greedy_emnist_pos, axis=0), width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_emnist_pos, axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.xticks([])\n",
    "plt.yticks(np.arange(0, 1.1, 0.2), fontsize=0)\n",
    "plt.grid(axis=\"y\")\n",
    "plt.ylim(0, 1)\n",
    "\n",
    "space = 0       \n",
    "plt.subplot(hist_grid[1, 0])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(test_k_of_n_fashion_pos[k], axis=0), width=w, color=light_colors[k])\n",
    "    plt.axvline(x=np.sum(np.mean(test_k_of_n_fashion_pos[k], axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "plt.bar(xx+space, np.mean(greedy_fashion_pos, axis=0), width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_fashion_pos, axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.xticks(0.2+xx, labelss, fontsize = fontsize+5)\n",
    "plt.yticks(np.arange(0, 1.1, 0.2), fontsize = fontsize)\n",
    "plt.ylabel(\"frequency\", fontsize = fontsize+15)\n",
    "plt.xlabel('action', fontsize=fontsize+15)\n",
    "plt.ylim(0, 1)\n",
    "plt.grid(axis=\"y\")\n",
    "  \n",
    "space = 0\n",
    "plt.subplot(hist_grid[1, 1])\n",
    "hist_bars = []\n",
    "for k in range (len(ks)):\n",
    "    hist_bars.append(plt.bar(xx+space, np.mean(test_k_of_n_emnist_pos[k], axis=0), width=w, color=light_colors[k]))\n",
    "    plt.axvline(x=np.sum(np.mean(test_k_of_n_emnist_pos[k], axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "hist_bars.append(plt.bar(xx+space, np.mean(greedy_emnist_pos, axis=0), width=w, color=light_colors[k+1]))\n",
    "plt.axvline(x=np.sum(np.mean(greedy_emnist_pos, axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.xticks(0.2+xx, labelss, fontsize=fontsize+5)\n",
    "plt.xlabel('action', fontsize=fontsize+15)\n",
    "plt.yticks(np.arange(0, 1.1, 0.2), fontsize=0)\n",
    "plt.ylim(0, 1)\n",
    "plt.grid(axis=\"y\")\n",
    "        \n",
    "hist_fig.legend(hist_bars, legends, fontsize=fontsize+5, ncol=len(ns)+1, loc=(0.09, .91))\n",
    "plt.subplots_adjust(bottom=0.5, top=0.7)\n",
    "plt.tight_layout(w_pad=3.0, h_pad=-9.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = 0.1\n",
    "space_between_bars = 0.00\n",
    "fontsize = 25\n",
    "hist_fig = plt.figure(figsize=(20,10))\n",
    "hist_grid = plt.GridSpec(2, 2, figure=hist_fig, width_ratios=[1, 1])\n",
    "xx = np.arange(n_actions)\n",
    "legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10', 'greedy'+r\"$(\\hat{r})$\"]\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"help\"]\n",
    "\n",
    "space = 0\n",
    "plt.subplot(hist_grid[0, 0])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(TEST_k_of_n_fashion_neg[k], axis=0), width=w, color=light_colors[k])\n",
    "    plt.axvline(x=np.sum(np.mean(TEST_k_of_n_fashion_neg[k], axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "plt.bar(xx+space, np.mean(greedy_fashion_neg, axis=0),width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_fashion_neg, axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.yticks(np.arange(0, 0.9, 0.2), fontsize = fontsize)\n",
    "plt.xticks([])\n",
    "plt.ylabel(\"frequency\", fontsize = fontsize+15)\n",
    "plt.grid(axis=\"y\")\n",
    "plt.ylim(0, 0.8)\n",
    "\n",
    "space = 0\n",
    "plt.subplot(hist_grid[0, 1])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(TEST_k_of_n_emnist_neg[k], axis=0), width=w, color=light_colors[k])\n",
    "    plt.axvline(x=np.sum(np.mean(TEST_k_of_n_emnist_neg[k], axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "plt.bar(xx+space, np.mean(greedy_emnist_neg, axis=0), width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_emnist_neg, axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.xticks([])\n",
    "plt.yticks(np.arange(0, 0.9, 0.2), fontsize=0)\n",
    "plt.grid(axis=\"y\")\n",
    "plt.ylim(0, 0.8)\n",
    "\n",
    "space = 0       \n",
    "plt.subplot(hist_grid[1, 0])\n",
    "for k in range (len(ks)):\n",
    "    plt.bar(xx+space, np.mean(test_k_of_n_fashion_neg[k], axis=0), width=w, color=light_colors[k])\n",
    "    plt.axvline(x=np.sum(np.mean(test_k_of_n_fashion_neg[k], axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "plt.bar(xx+space, np.mean(greedy_fashion_neg, axis=0), width=w, color=light_colors[k+1])\n",
    "plt.axvline(x=np.sum(np.mean(greedy_fashion_neg, axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.xticks(0.2+xx, labelss, fontsize = fontsize+5)\n",
    "plt.yticks(np.arange(0, 0.9, 0.2), fontsize = fontsize)\n",
    "plt.ylabel(\"frequency\", fontsize = fontsize+15)\n",
    "plt.xlabel('action', fontsize=fontsize+15)\n",
    "plt.ylim(0, 0.8)\n",
    "plt.grid(axis=\"y\")\n",
    "  \n",
    "space = 0\n",
    "plt.subplot(hist_grid[1, 1])\n",
    "hist_bars = []\n",
    "for k in range (len(ks)):\n",
    "    hist_bars.append(plt.bar(xx+space, np.mean(test_k_of_n_emnist_neg[k], axis=0), width=w, color=light_colors[k]))\n",
    "    plt.axvline(x=np.sum(np.mean(test_k_of_n_emnist_neg[k], axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k])\n",
    "    space = space + w + space_between_bars\n",
    "hist_bars.append(plt.bar(xx+space, np.mean(greedy_emnist_neg, axis=0), width=w, color=light_colors[k+1]))\n",
    "plt.axvline(x=np.sum(np.mean(greedy_emnist_neg, axis=0)[:-1]*np.arange(10)), lw=3, linestyle=\"-.\", color=light_colors[k+1])\n",
    "plt.xticks(0.2+xx, labelss, fontsize=fontsize+5)\n",
    "plt.xlabel('action', fontsize=fontsize+15)\n",
    "plt.yticks(np.arange(0, 0.9, 0.2), fontsize=0)\n",
    "plt.ylim(0, 0.8)\n",
    "plt.grid(axis=\"y\")\n",
    "        \n",
    "hist_fig.legend(hist_bars, legends, fontsize=fontsize+5, ncol=len(ns)+1, loc=(0.09, .91))\n",
    "plt.subplots_adjust(bottom=0.5, top=0.7)\n",
    "plt.tight_layout(w_pad=3.0, h_pad=-9.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## k-of-n convergence (Expected Value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected value per iteration, divided by n_actions * (number of labels of the dataset).\n",
    "# First axis: 0-2 = all-images regime (mnist, fashion, emnist), 3-5 = single-image regime (same order).\n",
    "exp = np.zeros((6, len(ks), n_runs, n_itr))\n",
    "# Value returned by mean_confidence_interval across runs, for every iteration.\n",
    "exp_conf = np.zeros((6, len(ks), n_itr))\n",
    "\n",
    "exp_path = 'actions/{}/{}/run_{}_expected_value_{}_{}-of-{}_n_itr_{}.npy'\n",
    "# (regime directory, dataset tag used in the file name, labels used for the normalisation)\n",
    "exp_sources = [\n",
    "    ('all-images regime', 'mnist', mnist_labels),\n",
    "    ('all-images regime', 'fashion', fashion_labels),\n",
    "    ('all-images regime', 'emnist', emnist_labels),\n",
    "    ('single-image regime', 'mnist', mnist_labels),\n",
    "    ('single-image regime', 'fashion', fashion_labels),\n",
    "    ('single-image regime', 'emnist', emnist_labels),\n",
    "]\n",
    "\n",
    "for k in range(len(ks)):\n",
    "    for run in range(n_runs):\n",
    "        for i, (exp_regime, exp_dataset, exp_labels) in enumerate(exp_sources):\n",
    "            exp_file = exp_path.format(exp_regime, models_dir_split[1], run, exp_dataset, ks[k], ns[k], n_itr)\n",
    "            exp[i, k, run] = np.load(exp_file) / (n_actions * len(exp_labels))\n",
    "\n",
    "    # Compute the interval once per k, after every run has been loaded; doing it inside\n",
    "    # the run loop repeated the work n_runs times on partially filled data.\n",
    "    for i in range(len(exp_sources)):\n",
    "        for itr in range(n_itr):\n",
    "            exp_conf[i, k, itr] = mean_confidence_interval(exp[i, k, :, itr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend texts; assumed to list the k-of-n settings in the same order as ks/ns (TODO confirm).\n",
    "exp_legends = ['1-of-20', '1-of-10', '5-of-10', '10-of-10']\n",
    "exp_fig = plt.figure(figsize=(20,8))\n",
    "exp_bars = []\n",
    "# One subplot per first-axis entry of exp, laid out on a 2x3 grid.\n",
    "for f in range(6):\n",
    "    plt.subplot(2, 3, f+1)\n",
    "    for k in range(len(ks)):\n",
    "        mean_curve = np.mean(exp[f, k], axis=0)\n",
    "        # Thin lines: individual runs. Thick line and band: mean across runs +/- exp_conf.\n",
    "        for run in range(n_runs):\n",
    "            plt.plot(exp[f, k, run], color=light_colors[k], lw=1, alpha=0.5)\n",
    "        mean_line, = plt.plot(mean_curve, label=\"{}-of-{}\".format(ks[k], ns[k]), color=light_colors[k], lw=4)\n",
    "        plt.fill_between(np.arange(n_itr), mean_curve - exp_conf[f, k], mean_curve + exp_conf[f, k], alpha=0.3, color=light_colors[k])\n",
    "        # Legend handles are taken from the first subplot only; colours are the same in every\n",
    "        # subplot, and this no longer depends on len(ks) matching a subplot index.\n",
    "        if f == 0:\n",
    "            exp_bars.append(mean_line)\n",
    "\n",
    "    # Per-axes decoration, applied once per subplot rather than once per k.\n",
    "    if f % 3 == 0:\n",
    "        # Left column: y-axis label drawn as text.\n",
    "        plt.text(-32, 0.08, s=r'$u_{\\mathcal{M}}(\\pi^{t}; \\bar{r}^{t})$', fontsize=fontsize+5, rotation='vertical')\n",
    "    if f > 2:\n",
    "        # Bottom row: labelled x axis.\n",
    "        plt.xlabel('iteration', fontsize=fontsize+10)\n",
    "        plt.xticks(fontsize=fontsize)\n",
    "    else:\n",
    "        # Top row: zero font size on the x tick labels.\n",
    "        plt.xticks(fontsize=0)\n",
    "\n",
    "    plt.yticks(fontsize=fontsize)\n",
    "    plt.grid(axis=\"y\")\n",
    "exp_fig.legend(exp_bars, exp_legends, fontsize=fontsize+5, ncol=len(ks), loc=(0.18, .9), bbox_transform = plt.gcf().transFigure)\n",
    "plt.subplots_adjust(top=0.7, bottom=0.4)\n",
    "plt.tight_layout(w_pad=2.5, h_pad=-4.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heatmaps for E-MNIST: 1-of-20 all-images regime (TEST), 1-of-20 single-image regime (test), and baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Naming used below: TEST_* arrays come from the all-images regime, test_* arrays from the\n",
    "# single-image regime, and baseline_* arrays from the hard-max baseline file.\n",
    "# The *_pos / *_neg variants keep the images selected by pos_emnist / neg_emnist.\n",
    "TEST_emnist_heatmap_pos = np.zeros((n_runs, len(emnist_labels[pos_emnist]), n_actions))\n",
    "TEST_emnist_heatmap_neg = np.zeros((n_runs, len(emnist_labels[neg_emnist]), n_actions))\n",
    "\n",
    "test_emnist_heatmap_pos = np.zeros((n_runs, len(emnist_labels[pos_emnist]), n_actions))\n",
    "test_emnist_heatmap_neg = np.zeros((n_runs, len(emnist_labels[neg_emnist]), n_actions))\n",
    "\n",
    "# Read the baseline file once and slice it twice.\n",
    "baseline_emnist_actions = np.load('{}/hard_max_emnist_actions.npy'.format(baseline_dir))\n",
    "baseline_emnist_heatmap_pos = baseline_emnist_actions[:, pos_emnist]\n",
    "baseline_emnist_heatmap_neg = baseline_emnist_actions[:, neg_emnist]\n",
    "\n",
    "emnist_actions_path = 'actions/{}/{}/run_{}_emnist_actions_1-of-20_n_itr_100.npy'\n",
    "for run in range(n_runs):\n",
    "    # One np.load per file instead of one per slice.\n",
    "    TEST_emnist_actions = np.load(emnist_actions_path.format('all-images regime', models_dir_split[1], run))\n",
    "    TEST_emnist_heatmap_pos[run] = TEST_emnist_actions[pos_emnist]\n",
    "    TEST_emnist_heatmap_neg[run] = TEST_emnist_actions[neg_emnist]\n",
    "\n",
    "    test_emnist_actions = np.load(emnist_actions_path.format('single-image regime', models_dir_split[1], run))\n",
    "    test_emnist_heatmap_pos[run] = test_emnist_actions[pos_emnist]\n",
    "    test_emnist_heatmap_neg[run] = test_emnist_actions[neg_emnist]\n",
    "\n",
    "# Label arrays are loop-invariant, so convert them once.\n",
    "emnist_letters_pos = emnist_labels[pos_emnist].numpy()\n",
    "emnist_letters_neg = emnist_labels[neg_emnist].numpy()\n",
    "\n",
    "\n",
    "def emnist_letter_means(heatmap, letters):\n",
    "    \"\"\"Return a (len(emnist_labels_names), n_actions) array of per-letter mean actions.\n",
    "\n",
    "    Row i selects the images whose label equals i+1, averages them over the image axis\n",
    "    and then over the first axis of `heatmap`, and rounds the result to one decimal.\n",
    "    \"\"\"\n",
    "    letter_means = np.zeros((len(emnist_labels_names), n_actions))\n",
    "    for i in range(len(emnist_labels_names)):\n",
    "        letter_idx = np.where(letters == (i+1))[0]\n",
    "        letter_means[i] = np.round(np.mean(np.mean(heatmap[:, letter_idx, :], 1), 0), 1)\n",
    "    return letter_means\n",
    "\n",
    "\n",
    "TEST_emnist_m_pos = emnist_letter_means(TEST_emnist_heatmap_pos, emnist_letters_pos)\n",
    "TEST_emnist_m_neg = emnist_letter_means(TEST_emnist_heatmap_neg, emnist_letters_neg)\n",
    "\n",
    "test_emnist_m_pos = emnist_letter_means(test_emnist_heatmap_pos, emnist_letters_pos)\n",
    "test_emnist_m_neg = emnist_letter_means(test_emnist_heatmap_neg, emnist_letters_neg)\n",
    "\n",
    "baseline_emnist_m_pos = emnist_letter_means(baseline_emnist_heatmap_pos, emnist_letters_pos)\n",
    "baseline_emnist_m_neg = emnist_letter_means(baseline_emnist_heatmap_neg, emnist_letters_neg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tick labels for the action axis: the ten digits followed by the \"help\" action.\n",
    "labelss = [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"help\"]\n",
    "\n",
    "sns.set(font_scale=1.4)\n",
    "fig_heatmap = plt.figure(figsize=(20, 20))\n",
    "grid = plt.GridSpec(2, 3, figure=fig_heatmap, width_ratios=[1, 1, 1.3])\n",
    "\n",
    "# Panel layout: rows = pos / neg arrays, columns = TEST / test / baseline arrays.\n",
    "heatmap_panels = [\n",
    "    [TEST_emnist_m_pos, test_emnist_m_pos, baseline_emnist_m_pos],\n",
    "    [TEST_emnist_m_neg, test_emnist_m_neg, baseline_emnist_m_neg],\n",
    "]\n",
    "\n",
    "for panel_row in range(2):\n",
    "    for panel_col in range(3):\n",
    "        plt.subplot(grid[panel_row, panel_col])\n",
    "        # Only the last (wider) column carries the colour bar.\n",
    "        ax = sns.heatmap(heatmap_panels[panel_row][panel_col], annot=True, linewidths=1,  cmap=\"Blues\" , vmin=0, vmax=1, cbar=(panel_col == 2), linecolor='black')\n",
    "        plt.yticks(np.arange(len(emnist_labels_names))+0.5, emnist_labels_names, rotation='horizontal', fontsize=fontsize-6)\n",
    "        if panel_row == 1:\n",
    "            plt.xlabel(\"action\", fontsize=fontsize+10)\n",
    "        if panel_col == 0:\n",
    "            plt.ylabel(\"letter\", fontsize=fontsize+10)\n",
    "        if panel_row == 0:\n",
    "            # Top row shows the action names on the x axis.\n",
    "            plt.xticks(np.arange(n_actions)+0.5, labelss, fontsize=fontsize-4)\n",
    "        else:\n",
    "            # Bottom row has no x ticks.\n",
    "            plt.xticks([], fontsize=fontsize-4)\n",
    "        ax.xaxis.tick_top()\n",
    "\n",
    "plt.tight_layout(w_pad=1.0, h_pad=1.0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
