{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and general settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import os\n",
    "import sys\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import itertools\n",
    "from analysis_helpers import load_data, prep_condition, order_econs\n",
    "from analysis_helpers import main_plot, three_plots, heat_plot, plot_decisions, plot_decisions_change, main_plot_both\n",
    "from analysis_helpers import calc_econ\n",
    "from data import init_dataset\n",
    "import seaborn as sns\n",
    "\n",
    "# Keep local packages updated\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set base network and conditions\n",
    "base_network = \"Res18\"\n",
    "folder_name = \"./results/{}/\".format(base_network)\n",
    "figure_path = \"./figures/{}/\".format(base_network)\n",
    "\n",
    "# Create the figure output folder up front so the savefig calls below do not fail on a fresh checkout\n",
    "os.makedirs(figure_path, exist_ok=True)\n",
    "\n",
    "# Training conditions analysed here; every condition name is prefixed with the base network\n",
    "condition_suffixes = [\"Base_condition\",\n",
    "                      \"Plus_1ep\",\n",
    "                      \"Plus_10ep\",\n",
    "                      \"Different_optimizer\",\n",
    "                      \"Different_batchsize\",\n",
    "                      \"Different_initialisation\",\n",
    "                      \"Different_LR\",\n",
    "                      \"CUDA_nondeterministic\",\n",
    "                      \"Different_dataorder\",\n",
    "                      \"Different_architecture\",\n",
    "                      \"Different_data\",\n",
    "                      \"Half_data\",\n",
    "                      \"Combined_condition\"]\n",
    "conditions = [f\"{base_network}_{suffix}\" for suffix in condition_suffixes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get base array and number of base epochs\n",
    "base, _, _, _, num_base_epochs = prep_condition(f\"{base_network}_Base_condition\", folder_name, base_network)\n",
    "\n",
    "# Set flag whether to correct for label errors (True drops label-error images from the decision arrays)\n",
    "correct = False\n",
    "\n",
    "# Set flag whether to remove impossibles and trivials for heatmap analysis\n",
    "rm = False\n",
    "\n",
    "# Load label error flags (column 2 of the ident file; 1 marks a label error).\n",
    "# Plain `int` is used because the `np.int` alias was removed in NumPy 1.24.\n",
    "label_error = np.load(\"imagenet_val_ident.npy\", allow_pickle=True)[:, 2].astype(int)\n",
    "error_inds = np.where(label_error == 1)\n",
    "right_inds = np.where(label_error == 0)\n",
    "\n",
    "# Pre-allocate arrays\n",
    "main_array = np.zeros((num_base_epochs, len(conditions), 2), dtype=object)\n",
    "\n",
    "# Number of images kept per condition: label-error images are dropped when correcting for them\n",
    "num_items = len(base[0][0][1]) - (len(error_inds[0]) if correct else 0)\n",
    "decisions = np.zeros((len(conditions), num_items, num_base_epochs))\n",
    "decisions_correct = np.zeros((len(conditions), num_items, num_base_epochs))\n",
    "decisions_change = np.zeros((len(conditions), num_base_epochs - 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analysis of class accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct labels (index 2 of a stored result) in the base model's first epoch\n",
    "num_classes = len(np.unique(base[0][0][2]))\n",
    "\n",
    "# Per-class accuracy of the base model in its final epoch\n",
    "final_results = base[0][num_base_epochs - 1]\n",
    "final_preds, final_labels = final_results[1], final_results[2]\n",
    "equal_answers = np.equal(final_preds, final_labels).numpy()\n",
    "class_accuracies = np.zeros(num_classes)\n",
    "for class_idx in range(num_classes):\n",
    "    in_class = np.where(final_labels == class_idx)[0]\n",
    "    class_accuracies[class_idx] = np.mean(equal_answers[in_class])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if base_network == \"Res18\":\n",
    "\n",
    "    num_ident_classes = 10\n",
    "\n",
    "    # Load item idents; ind*50 assumes the validation items are grouped by class, 50 per class\n",
    "    idents = np.load('imagenet_val_ident.npy', allow_pickle=True)\n",
    "    ranking = np.argsort(class_accuracies)\n",
    "\n",
    "    print(\"Bad classes\")\n",
    "    for class_idx in ranking[:num_ident_classes]:\n",
    "        print(idents[class_idx * 50])\n",
    "\n",
    "    print(\"Good classes\")\n",
    "    for class_idx in ranking[::-1][:num_ident_classes]:\n",
    "        print(idents[class_idx * 50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class accuracies sorted from best to worst class\n",
    "sorted_accs = np.sort(class_accuracies)[::-1]\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.plot(sorted_accs, linewidth=4);\n",
    "\n",
    "# Plot settings\n",
    "sns.set_context(\"paper\", font_scale=4.0)\n",
    "sns.set_style(\"white\")\n",
    "sns.color_palette(\"viridis\", as_cmap=True)\n",
    "sns.despine(right=True, top=True, offset=10, trim=False)\n",
    "plt.xlabel(\"Class number\", fontsize=40, labelpad=20)\n",
    "plt.ylabel(\"Accuracy\", fontsize=40, labelpad=20)\n",
    "plt.ylim(0, 1.05)\n",
    "\n",
    "class_acc_file = figure_path + f'{base_network}_class_accs.png'\n",
    "plt.savefig(class_acc_file, dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analysis of accuracy in last epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation accuracy in the last base epoch per condition (mean over models when there are several)\n",
    "accs_lastep = np.zeros(len(conditions))\n",
    "\n",
    "for cond_idx, condition in enumerate(conditions):\n",
    "\n",
    "    # Load data, conditions with more epochs than base are stored in base file\n",
    "    results, val_acc, num_epochs, num_models, num_base_epochs = prep_condition(condition, folder_name, base_network)\n",
    "\n",
    "    final_accs = val_acc[:, num_base_epochs - 1]\n",
    "    accs_lastep[cond_idx] = final_accs[0] if val_acc.shape[0] == 1 else np.mean(final_accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(*[stat(accs_lastep) for stat in (np.mean, np.min, np.max, np.std)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loop through conditions to build arrays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): this resets the rm flag from the setup cell; rm_inds is only computed in a later cell\n",
    "rm = False\n",
    "rm_inds = 0\n",
    "\n",
    "# Only enable to make superimp plot without combined condition\n",
    "#conditions.pop(np.where([cond.endswith(\"Combined_condition\") for cond in conditions])[0][0])\n",
    "\n",
    "\n",
    "def base_epoch_offset(condition):\n",
    "    \"\"\"Epoch shift of a condition relative to the base condition (Plus_1ep: 1, Plus_10ep: 10, otherwise 0).\"\"\"\n",
    "    if condition.endswith(\"Plus_1ep\"):\n",
    "        return 1\n",
    "    if condition.endswith(\"Plus_10ep\"):\n",
    "        return 10\n",
    "    return 0\n",
    "\n",
    "\n",
    "def shifted_epoch(epoch, offset):\n",
    "    \"\"\"Epoch index aligned with the base condition; epochs before the shift keep their own index.\"\"\"\n",
    "    return epoch - offset if epoch >= offset else epoch\n",
    "\n",
    "\n",
    "# Loop through conditions to make error consistency arrays\n",
    "for ind, condition in enumerate(conditions):\n",
    "\n",
    "    # Load data, conditions with more epochs than base are stored in base file\n",
    "    results, val_acc, num_epochs, num_models, num_base_epochs = prep_condition(condition, folder_name, base_network)\n",
    "    offset = base_epoch_offset(condition)\n",
    "    cond_label = condition.replace(base_network, \"\").replace(\"_\", \"\\n\")\n",
    "\n",
    "    # Init econ arrays: to the base condition, and between subsequent epochs\n",
    "    econ_tobase = np.zeros((num_models, num_epochs))\n",
    "    econ_subseq = np.zeros((num_models, num_epochs-1))\n",
    "\n",
    "    # Loop through models and epochs to get error consistency for all epochs of all models of this condition\n",
    "    for model in range(num_models):\n",
    "        for epoch in range(num_epochs):\n",
    "\n",
    "            # Print where we are\n",
    "            print(f\"Condition: {condition}, Model: {model}, Epoch: {epoch}\")\n",
    "\n",
    "            if condition.endswith(\"Different_data\"):\n",
    "                # For different data compared, the two networks are only compared to each other\n",
    "                econ_tobase[model, epoch] = calc_econ(results, results,\n",
    "                                                      0+model, 1-model,\n",
    "                                                      epoch, epoch,\n",
    "                                                      error_inds, correct,\n",
    "                                                      rm_inds, rm)\n",
    "            else:\n",
    "                # Compare with the aligned base epoch (e.g. Plus_1ep epoch e vs. base epoch e-1),\n",
    "                # clamped to the last stored base epoch\n",
    "                base_epoch = min(shifted_epoch(epoch, offset), num_base_epochs-1)\n",
    "                econ_tobase[model, epoch] = calc_econ(results, base,\n",
    "                                                      model, 0,\n",
    "                                                      epoch, base_epoch,\n",
    "                                                      error_inds, correct,\n",
    "                                                      rm_inds, rm)\n",
    "\n",
    "            # Calculate error consistency compared to next epoch\n",
    "            if epoch < (num_epochs-1):\n",
    "                econ_subseq[model, epoch] = calc_econ(results, results,\n",
    "                                                      model, model,\n",
    "                                                      epoch, epoch+1,\n",
    "                                                      error_inds, correct,\n",
    "                                                      rm_inds, rm)\n",
    "\n",
    "            # Add to main_array only for the last model, once econ_tobase[:, epoch] is complete\n",
    "            if model == num_models-1:\n",
    "                main_epoch = shifted_epoch(epoch, offset)\n",
    "                main_array[main_epoch, ind, 0] = cond_label\n",
    "                main_array[main_epoch, ind, 1] = econ_tobase[:, epoch]\n",
    "\n",
    "                # Add decisions and whether they were correct to their respective arrays\n",
    "                if epoch < num_base_epochs:\n",
    "                    preds = np.array(results[model][epoch][1])\n",
    "                    is_correct = np.equal(preds, np.array(results[model][epoch][2]))\n",
    "                    if correct:\n",
    "                        # Drop label-error images, matching how the decision arrays were allocated\n",
    "                        preds = np.delete(preds, error_inds)\n",
    "                        is_correct = np.delete(is_correct, error_inds)\n",
    "                    decisions[ind, :, epoch] = preds\n",
    "                    decisions_correct[ind, :, epoch] = is_correct\n",
    "\n",
    "                # Count of decisions that stay the same from this epoch to the next\n",
    "                if epoch < num_base_epochs - 1:\n",
    "                    unchanged = np.equal(np.array(results[model][epoch][1]),\n",
    "                                         np.array(results[model][epoch+1][1]))\n",
    "                    if correct:\n",
    "                        unchanged = np.delete(unchanged, error_inds)\n",
    "                    decisions_change[ind, epoch] = np.sum(unchanged)\n",
    "\n",
    "\n",
    "    # Make the three basic plots\n",
    "    three_plots(val_acc, econ_subseq, econ_tobase, condition, figure_path, base_network)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare for and make main error consistency plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Item order for the base condition: mean correctness over epochs, best first\n",
    "base_index = conditions.index(f\"{base_network}_Base_condition\")\n",
    "base_order = np.mean(decisions_correct[base_index, :, :], axis=1).argsort()[::-1]\n",
    "\n",
    "# Superimposed decisions: per item and epoch, the fraction of conditions that were correct\n",
    "mean_decisions_correct = np.mean(decisions_correct, axis=0)\n",
    "\n",
    "# Item order by that fraction in the last epoch, best first\n",
    "mean_order = mean_decisions_correct[:, -1].argsort()[::-1]\n",
    "\n",
    "# Distinct fractions reached in the last base epoch (first = impossibles, last = trivials)\n",
    "final_fraction = mean_decisions_correct[:, num_base_epochs-1]\n",
    "entries = np.unique(final_fraction)\n",
    "\n",
    "if rm:\n",
    "    rm_inds = np.concatenate([np.where(final_fraction == entries[k])[0] for k in (-1, -2, 0, 1)])\n",
    "\n",
    "# Get y position for hlines in decision plot\n",
    "upper_line = np.count_nonzero(final_fraction == entries[-1])\n",
    "lower_line = mean_decisions_correct.shape[0] - np.count_nonzero(final_fraction == entries[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the CIFAR main array so that the Res18fc100 run can plot both together\n",
    "if base_network == \"Res18CIFAR\":\n",
    "    np.save(file=\"cifar_main_array\", arr=main_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Res18 defines the condition order of the main plot; the other networks reuse the stored order\n",
    "if base_network == \"Res18\":\n",
    "\n",
    "    # Make main plot and store its condition order\n",
    "    res18_order = main_plot(main_array, num_base_epochs-1, figure_path, base_network)\n",
    "    np.save(\"res18_main_order\", res18_order)\n",
    "\n",
    "elif base_network == \"Res18fc100\":\n",
    "\n",
    "    # Plot together with the saved CIFAR main array, in the Res18 order\n",
    "    cifar_main = np.load(\"cifar_main_array.npy\", allow_pickle=True)\n",
    "    order = np.load(\"res18_main_order.npy\", allow_pickle=True)\n",
    "    main_plot_both(main_array, cifar_main, num_base_epochs-1, figure_path, base_network, order)\n",
    "\n",
    "elif base_network in (\"VGG11\", \"Dense121\"):\n",
    "\n",
    "    # Make main plot in the Res18 order\n",
    "    order = np.load(\"res18_main_order.npy\")\n",
    "    main_plot(main_array, num_base_epochs-1, figure_path, base_network, order)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plots for appendix: main plot with accuracy and with RSA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if base_network == \"Res18\":\n",
    "\n",
    "    # Condition labels in main-plot order, without the leading newline left by the \"_\" -> newline replacement\n",
    "    labels = main_array[0, res18_order, 0]\n",
    "    labels_correct = np.array([label[1:] for label in labels], dtype=object)\n",
    "\n",
    "    from analysis_helpers import main_plot_acc\n",
    "    main_plot_acc(accs_lastep, 0, figure_path, base_network, labels_correct, res18_order)\n",
    "    acc_plot_file = figure_path + f'{base_network}_main_plot_acc.png'\n",
    "    plt.savefig(acc_plot_file, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if base_network == \"Res18\":\n",
    "    # Same layout as the accuracy plot, fed with the stored RSA correlations\n",
    "    rsa_file = f'{base_network}_rsa_corrs.npy'\n",
    "    rsa = np.load(rsa_file)\n",
    "    main_plot_acc(rsa, 0, figure_path, base_network, labels_correct, res18_order)\n",
    "    rsa_plot_file = figure_path + f'{base_network}_main_plot_rsa.png'\n",
    "    plt.savefig(rsa_plot_file, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make superimposed decisions plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Buffer and hline columns are only added while label-error images are still part of the data\n",
    "if not correct:\n",
    "\n",
    "    # Make puffer (all-NaN buffer column) and line columns, one row per image\n",
    "    puffer = np.full((decisions.shape[1], 1), np.nan)\n",
    "    lines = np.copy(puffer)\n",
    "    # Mark a 100-row band around each hline; clamp the start at 0 so a band near the top does not wrap around\n",
    "    lines[max(upper_line - 50, 0):upper_line + 50, :] = 1\n",
    "    lines[max(lower_line - 50, 0):lower_line + 50, :] = 1\n",
    "\n",
    "    # Sort items by last-epoch mean correctness and append the buffer and line columns\n",
    "    decisions_with_errs = np.concatenate((mean_decisions_correct[mean_order, :],\n",
    "                                          puffer, puffer, lines, lines),\n",
    "                                         axis=1)\n",
    "\n",
    "    # Overlay is all NaN over the decision columns and carries the same extra columns\n",
    "    overlay = np.full(mean_decisions_correct.shape, np.nan)\n",
    "    overlay_array = np.concatenate((overlay, puffer, puffer, lines, lines), axis=1)\n",
    "\n",
    "else:\n",
    "\n",
    "    # Items sorted by last-epoch mean correctness, with an all-NaN overlay\n",
    "    decisions_with_errs = mean_decisions_correct[mean_order, :]\n",
    "    overlay_array = np.full(decisions_with_errs.shape, np.nan)\n",
    "\n",
    "# Plot decisions array\n",
    "plot_decisions(decisions_with_errs, overlay_array, num_base_epochs,\n",
    "               figure_path, f\"{base_network}_superimp\", correct)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Histogram analysis with binomial observer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy.random import default_rng\n",
    "\n",
    "# Set epoch for which to make histogram\n",
    "ep = num_base_epochs-1\n",
    "num_models = len(conditions)\n",
    "\n",
    "# Binomial observer: each of the num_models models gets every image right with the overall mean accuracy.\n",
    "# default_rng must be seeded itself; np.random.seed only seeds the legacy global RandomState.\n",
    "np.random.seed(1312)\n",
    "rng = default_rng(1312)\n",
    "s = rng.binomial(num_models, np.mean(mean_decisions_correct[:, ep]), mean_decisions_correct.shape[0])\n",
    "\n",
    "# Weights for normalisation of heatmap\n",
    "weights = np.ones_like(mean_decisions_correct[:, ep]) / float(len(mean_decisions_correct[:, ep]))\n",
    "\n",
    "# Get counts (number of images per number of models that classify them correctly)\n",
    "counts = np.histogram(mean_decisions_correct[:, ep] * num_models, bins=num_models+1)\n",
    "\n",
    "# Init histogram figure\n",
    "plt.figure(figsize=(10,10))\n",
    "sns.histplot(data=mean_decisions_correct[:, ep] * num_models, bins=num_models+1,\n",
    "             color= 'blue', label='Images from validation set', stat=\"probability\")\n",
    "\n",
    "if base_network == \"Res18\" and correct == False:\n",
    "\n",
    "    # Get wrongly labelled images; indexing with error_inds[0] keeps the array 2-D even for a single error\n",
    "    mean_decisions_correct_errors = mean_decisions_correct[error_inds[0], :] * num_models\n",
    "    error_weights = np.ones_like(mean_decisions_correct_errors[:, ep]) / float(len(mean_decisions_correct_errors[:, ep]))\n",
    "    # Pad with dummy values left of the plotted x-range so that \"probability\" is a fraction of all images,\n",
    "    # not only of the label errors\n",
    "    num_padding = mean_decisions_correct.shape[0] - mean_decisions_correct_errors.shape[0]\n",
    "    padding = np.full(num_padding, -(num_models+1), dtype=float)\n",
    "    sns.histplot(data=np.concatenate((padding, mean_decisions_correct_errors[:, ep])),\n",
    "                 bins=num_models*2+1, color='red', label='Label errors', stat=\"probability\")\n",
    "\n",
    "# Plot histogram\n",
    "sns.histplot(data=s, alpha=0.5, bins=num_models,\n",
    "             color='green', label='Binomial distribution', stat=\"probability\")\n",
    "\n",
    "# Plot settings\n",
    "sns.set_context(\"paper\", font_scale=3.0)\n",
    "sns.set_style(\"white\")\n",
    "sns.color_palette(\"viridis\", as_cmap=True)\n",
    "sns.despine(right=True, top=True, offset=20, trim=False)\n",
    "plt.xlabel(\"Number of models that classify image correctly\", fontsize=30, labelpad=10)\n",
    "plt.ylabel(\"Fraction of images\", fontsize=30, labelpad=10)\n",
    "# Tick every second model count; the leading spaces presumably nudge \"All\" towards x=num_models\n",
    "tick_positions = np.arange(0, num_models+1, step=2)\n",
    "tick_labels = [\"None\" if tick == 0 else (\"   All\" if tick == tick_positions[-1] else str(tick))\n",
    "               for tick in tick_positions]\n",
    "plt.xticks(tick_positions, tick_labels)\n",
    "plt.ylim(0,0.7)\n",
    "plt.xlim(0,num_models+1)\n",
    "\n",
    "if base_network == \"Res18\":\n",
    "    plt.legend(loc=\"upper left\", fontsize=30, frameon=False)\n",
    "plt.savefig(figure_path + f'{base_network}_hist.png', dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of images (in %) per histogram bin: first bin = no condition correct, last bin = all correct\n",
    "counts_percent = counts[0] / decisions.shape[1] * 100\n",
    "counts_impossible, counts_trivial = counts_percent[0], counts_percent[-1]\n",
    "counts_fluctuating = counts_percent[1:-1].sum()\n",
    "print(f\"Impossible: {counts_impossible:.4}, Trivial: {counts_trivial:.4}, Fluctuation: {counts_fluctuating:.4}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate what error consistency can be expected based on binomial observer model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the submodule explicitly: a bare `import sklearn` does not guarantee sklearn.metrics is loaded\n",
    "import sklearn.metrics\n",
    "\n",
    "# Get indices of items that every condition gets right / wrong in epoch ep, and combine the two\n",
    "all_right = np.where(mean_decisions_correct[:, ep] == 1)[0]\n",
    "all_wrong = np.where(mean_decisions_correct[:, ep] == 0)[0]\n",
    "all_same = np.hstack((all_right, all_wrong))\n",
    "num_varying = mean_decisions_correct.shape[0] - all_same.shape[0]\n",
    "\n",
    "# Remove images which are always answered the same and calculate accuracy for the remaining over all conditions\n",
    "cleaned_acc = np.mean(np.delete(mean_decisions_correct[:, ep], all_same)) if num_varying else 0.0\n",
    "\n",
    "# Randomly sample two networks making right or wrong decisions on the remaining samples.\n",
    "# A dedicated seeded generator keeps the result identical when this cell is re-run on its own.\n",
    "observer_rng = np.random.default_rng(1312)\n",
    "s1 = observer_rng.binomial(1, cleaned_acc, num_varying)\n",
    "s2 = observer_rng.binomial(1, cleaned_acc, num_varying)\n",
    "\n",
    "# Add all same images\n",
    "s1_all = np.hstack((np.ones(all_right.shape[0]), np.zeros(all_wrong.shape[0]), s1))\n",
    "s2_all = np.hstack((np.ones(all_right.shape[0]), np.zeros(all_wrong.shape[0]), s2))\n",
    "\n",
    "# Cohen's kappa between the two simulated observers\n",
    "sklearn.metrics.cohen_kappa_score(s1_all, s2_all)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare export of example ImageNet images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "\n",
    "\n",
    "class ImageFolderWithPaths(datasets.ImageFolder):\n",
    "    \"\"\"Custom dataset that includes image paths. Extends\n",
    "    torchvision.datasets.ImageFolder: every item is the usual\n",
    "    ImageFolder tuple with the image file path appended.\n",
    "    \"\"\"\n",
    "    # override the __getitem__ method that dataloader calls\n",
    "    def __getitem__(self, index):\n",
    "        # this is what ImageFolder normally returns\n",
    "        original_tuple = super().__getitem__(index)\n",
    "        # the image file path\n",
    "        path = self.imgs[index][0]\n",
    "        # make a new tuple that includes original and the path\n",
    "        return original_tuple + (path,)\n",
    "\n",
    "\n",
    "def init_dataset(DSET, NUM, DATA, MODEL, CONDITION, train=True, data_dir=None):\n",
    "    \"\"\"\n",
    "    Build the ImageNet or CIFAR-100 dataset used to export example images.\n",
    "\n",
    "    This local definition shadows the init_dataset imported from `data` in the first cell.\n",
    "    Normalisation is left out (except for CIFAR-100 training data), presumably so that\n",
    "    the tensors can be shown directly with imshow.\n",
    "\n",
    "    :param DSET: Which dataset to use, \"ImageNet\" or \"CIFAR100\"\n",
    "    :param NUM: Model index (not used here)\n",
    "    :param DATA: Whether to use same or different datasets (not used here)\n",
    "    :param MODEL: Not used here\n",
    "    :param CONDITION: Not used here\n",
    "    :param train: Whether to choose train (True) or validation (False) dataset\n",
    "    :param data_dir: Dataset root; defaults to the original cluster location of DSET\n",
    "    :return: Initialized dataset\n",
    "    :raises ValueError: If DSET is unknown or train is not a bool\n",
    "    \"\"\"\n",
    "    # Default dataset locations\n",
    "    default_dirs = {\"ImageNet\": '/scratch_local/datasets/ImageNet2012/',\n",
    "                    \"CIFAR100\": '/home/wichmann/lschulzebuschoff43/cifar100/'}\n",
    "\n",
    "    # Fail loudly instead of hitting an unbound `dataset` variable\n",
    "    if DSET not in default_dirs:\n",
    "        raise ValueError(f\"Unknown dataset {DSET!r}; expected one of {sorted(default_dirs)}\")\n",
    "    if not isinstance(train, bool):\n",
    "        raise ValueError(f\"train must be True or False, got {train!r}\")\n",
    "\n",
    "    # Check whether train or validation dataset is needed\n",
    "    path = os.path.join(data_dir or default_dirs[DSET], 'train' if train else 'val')\n",
    "\n",
    "    if DSET == \"ImageNet\":\n",
    "        if train:\n",
    "            transform = transforms.Compose([\n",
    "                transforms.RandomResizedCrop(224),\n",
    "                transforms.RandomHorizontalFlip(),\n",
    "                transforms.ToTensor(),\n",
    "                #normalize,\n",
    "            ])\n",
    "        else:\n",
    "            transform = transforms.Compose([\n",
    "                transforms.Resize(256),\n",
    "                transforms.CenterCrop(224),\n",
    "                transforms.ToTensor(),\n",
    "                #normalize,\n",
    "            ])\n",
    "        # ImageNet items additionally carry their file path\n",
    "        return ImageFolderWithPaths(path, transform)\n",
    "\n",
    "    # CIFAR100\n",
    "    if train:\n",
    "        # Set normalization parameters\n",
    "        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "        transform = transforms.Compose([\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.RandomCrop(32, 4),\n",
    "            transforms.ToTensor(),\n",
    "            normalize,\n",
    "        ])\n",
    "    else:\n",
    "        transform = transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            #normalize,\n",
    "        ])\n",
    "    return datasets.ImageFolder(path, transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the validation set used for exporting example images. The model/condition arguments are\n",
    "# not needed here, so pass None explicitly instead of relying on the leftover `_` variable.\n",
    "if base_network == \"Res18\":\n",
    "    # ResNet-18 on ImageNet\n",
    "    val_set = init_dataset(\"ImageNet\", None, None, None, None, train=False)\n",
    "elif base_network == \"Res18CIFAR\":\n",
    "    # ResNet-18 on CIFAR-100\n",
    "    val_set = init_dataset(\"CIFAR100\", None, None, None, None, train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With correct=True the decision arrays (and hence mean_order) only contain images without label errors,\n",
    "# in their original order. Map those positions back to validation-set indices so val_set can be indexed.\n",
    "if correct:\n",
    "    clean_order = right_inds[0][mean_order]\n",
    "else:\n",
    "    clean_order = mean_order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Imports and parameters for example image plots\n",
    "from analysis_helpers import plotSingle\n",
    "exp_loc = []\n",
    "num_samples = 50\n",
    "np.random.seed(1312)\n",
    "\n",
    "# Output folder for the exported images; create it so savefig does not fail on a fresh checkout\n",
    "example_dir = f'./example_images/{base_network}/'\n",
    "os.makedirs(example_dir, exist_ok=True)\n",
    "\n",
    "# Draw num_samples random images among the 1000 worst (bad) and 1000 best (good) classified ones\n",
    "bad_inds = np.random.choice(np.flip(clean_order)[0:1000], size=num_samples, replace=False)\n",
    "good_inds = np.random.choice(clean_order[0:1000], size=num_samples, replace=False)\n",
    "\n",
    "for index in range(num_samples):\n",
    "    for prefix, inds in ((\"good\", good_inds), (\"bad\", bad_inds)):\n",
    "\n",
    "        # Image tensor from the validation set, transposed from (C, H, W) to (H, W, C) for imshow\n",
    "        image = np.transpose(val_set[inds[index]][0].cpu().detach().numpy(), (1, 2, 0))\n",
    "\n",
    "        fig = plt.figure(figsize=(16,16))\n",
    "        plt.imshow(image)\n",
    "        plt.axis('off')\n",
    "        plt.savefig(f'{example_dir}{prefix}_{index}.jpg', bbox_inches='tight')\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make decisions plots for the individual conditions and all combinations between them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-condition decisions plots, items ordered by the base condition's mean accuracy.\n",
    "# The label-error overlay needs one row per validation image, so it is only drawn while label-error\n",
    "# images are still in the data (correct=False) and the ident file matches the dataset size.\n",
    "show_label_errors = (not correct) and len(label_error) == decisions_correct.shape[1]\n",
    "\n",
    "for ind, condition in enumerate(conditions):\n",
    "\n",
    "    if show_label_errors:\n",
    "\n",
    "        # Append two zero buffer columns and one label-error column to the ordered decisions\n",
    "        zero_col = np.zeros((len(label_error), 1))\n",
    "        decisions_with_errs = np.concatenate((decisions_correct[ind, base_order, :],\n",
    "                                              zero_col, zero_col, zero_col),\n",
    "                                             axis=1)\n",
    "\n",
    "        # Overlay: NaN over the decision columns, label-error flags in the last column\n",
    "        overlay = np.full(decisions_correct[ind, :, :].shape, np.nan)\n",
    "        overlay_array = np.concatenate((overlay, zero_col, zero_col,\n",
    "                                        label_error[base_order, np.newaxis]),\n",
    "                                       axis=1)\n",
    "    else:\n",
    "\n",
    "        # Ordered decisions only, with an all-NaN overlay\n",
    "        decisions_with_errs = decisions_correct[ind, base_order, :]\n",
    "        overlay_array = np.full(decisions_with_errs.shape, np.nan)\n",
    "\n",
    "    # Plot decisions array\n",
    "    plot_decisions(decisions_with_errs, overlay_array, num_base_epochs,\n",
    "                   figure_path, condition, correct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot, for every pair of conditions, where their per-item decisions agree\n",
    "n_items = len(label_error)\n",
    "for (ind1, cond1), (ind2, cond2) in itertools.combinations(enumerate(conditions), 2):\n",
    "\n",
    "    # Elementwise agreement of the two conditions' decisions, items ordered by base_order\n",
    "    agreement = np.equal(decisions_correct[ind1, base_order, :],\n",
    "                         decisions_correct[ind2, base_order, :])\n",
    "\n",
    "    if correct:\n",
    "        # Pad with three zero columns so the agreement matrix matches the overlay's width;\n",
    "        # the overlay's last column carries the label errors\n",
    "        decisions_with_errs = np.concatenate((agreement, np.zeros((n_items, 3))), axis=1)\n",
    "        # Shape comes from this pair's own index (ind1), not a loop variable of another cell\n",
    "        overlay = np.full(decisions_correct[ind1, :, :].shape, np.nan)\n",
    "        overlay_array = np.concatenate((overlay,\n",
    "                                        np.zeros((n_items, 2)),\n",
    "                                        label_error[base_order, np.newaxis]),\n",
    "                                       axis=1)\n",
    "    else:\n",
    "        # No label-error columns: the overlay is all NaN\n",
    "        decisions_with_errs = agreement\n",
    "        overlay_array = np.full(decisions_with_errs.shape, np.nan)\n",
    "\n",
    "    # Plot decisions array\n",
    "    plot_decisions(decisions_with_errs, overlay_array, num_base_epochs,\n",
    "                   figure_path, f\"comb_{cond1}_{cond2}\", base_network, correct)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Appendix plot: how decisions change from epoch to epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rescale to [0, 1]: decisions_change presumably counts, per epoch, the items whose decision\n",
    "# changed (TODO confirm), so this is the fraction of items whose decision stayed the same\n",
    "num_decision_items = decisions_correct.shape[1]\n",
    "decisions_change_c = (num_decision_items - decisions_change) / num_decision_items\n",
    "\n",
    "# One decisions-change plot covering all conditions\n",
    "plot_decisions_change(decisions_change_c, figure_path, base_network, conditions, num_decision_items)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Remove the Different_data condition before the heat-map analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the Different_data condition from the heat-map analysis.\n",
    "# Rebinding with a filter (instead of list.pop) keeps the cell idempotent: re-running it no longer raises IndexError.\n",
    "conditions = [cond for cond in conditions if not cond.endswith(\"Different_data\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heat-map plots for the appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Pre-allocate heat maps; NaN marks entries that have not been computed\n",
    "n_conds = len(conditions)\n",
    "heat_array = np.full((n_conds, n_conds), np.nan)\n",
    "heat_array_ep1 = np.full((n_conds, n_conds), np.nan)\n",
    "\n",
    "\n",
    "def _ep1_index(condition):\n",
    "    \"\"\"Return the results index compared as \"epoch 1\" for `condition`.\n",
    "\n",
    "    Plus_1ep -> 2, Plus_10ep -> 11, every other condition -> 1. The offset presumably\n",
    "    reflects the extra training epochs of these conditions (TODO confirm).\n",
    "    Indices are resolved per condition, so every pairing (including Plus_1ep vs Plus_10ep\n",
    "    and a Plus condition vs itself) compares each side at its own epoch-1 index.\n",
    "    \"\"\"\n",
    "    if condition.endswith(\"Plus_1ep\"):\n",
    "        return 2\n",
    "    if condition.endswith(\"Plus_10ep\"):\n",
    "        return 11\n",
    "    return 1\n",
    "\n",
    "\n",
    "# Loop over all ordered pairs of conditions\n",
    "for ind_1, condition_1 in enumerate(conditions):\n",
    "\n",
    "    # Load data once per row; conditions with more epochs than base are stored in base file\n",
    "    results_1, val_acc_1, num_epochs_1, num_models_1, num_base_epochs_1 = prep_condition(condition_1,\n",
    "                                                                                         folder_name, base_network)\n",
    "    ep1_ind_1 = _ep1_index(condition_1)\n",
    "\n",
    "    for ind_2, condition_2 in enumerate(conditions):\n",
    "        results_2, val_acc_2, num_epochs_2, num_models_2, num_base_epochs_2 = prep_condition(condition_2,\n",
    "                                                                                             folder_name, base_network)\n",
    "        ep1_ind_2 = _ep1_index(condition_2)\n",
    "\n",
    "        same_condition = condition_1 == condition_2\n",
    "        # With a single model, the self-comparison is the only entry available, so keep it\n",
    "        keep_self_pair = same_condition and num_models_1 == 1 and num_models_2 == 1\n",
    "\n",
    "        print(f\"Condition 1: {condition_1} ({num_epochs_1} epochs, {num_models_1} models), \"\n",
    "              f\"Condition 2: {condition_2} ({num_epochs_2} epochs, {num_models_2} models)\")\n",
    "\n",
    "        comb_array = []      # consistencies at each condition's last epoch\n",
    "        comb_array_ep1 = []  # consistencies at each condition's epoch-1 index\n",
    "\n",
    "        # Calculate error consistency for all models of both conditions\n",
    "        for model_1 in range(num_models_1):\n",
    "            for model_2 in range(num_models_2):\n",
    "\n",
    "                # A model compared with itself has consistency 1, so skip it unless it is the only pair\n",
    "                if same_condition and model_1 == model_2 and not keep_self_pair:\n",
    "                    continue\n",
    "\n",
    "                comb_array.append(calc_econ(results_1, results_2,\n",
    "                                            model_1, model_2,\n",
    "                                            num_epochs_1 - 1, num_epochs_2 - 1,\n",
    "                                            error_inds, correct,\n",
    "                                            rm_inds, rm))\n",
    "                comb_array_ep1.append(calc_econ(results_1, results_2,\n",
    "                                                model_1, model_2,\n",
    "                                                ep1_ind_1, ep1_ind_2,\n",
    "                                                error_inds, correct,\n",
    "                                                rm_inds, rm))\n",
    "\n",
    "        # Mean over error consistencies is the entry for this combination\n",
    "        heat_array[ind_1, ind_2] = np.mean(comb_array)\n",
    "        heat_array_ep1[ind_1, ind_2] = np.mean(comb_array_ep1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map of mean error consistency at each condition's last epoch (heat_array),\n",
    "# saved under epoch num_base_epochs - 1\n",
    "epoch = num_base_epochs - 1\n",
    "# Readable tick labels, e.g. 'Res18_Plus_1ep' -> 'Plus 1ep' (strip drops the leading space)\n",
    "cond_clear = [cond.replace(base_network, \"\").replace(\"_\", \" \").strip() for cond in conditions]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 12))\n",
    "sns.heatmap(heat_array, annot=True, fmt=\".2f\", annot_kws={\"fontsize\": 12}, square=True,\n",
    "            xticklabels=cond_clear, yticklabels=cond_clear, cmap=\"Blues\", cbar_kws={\"shrink\": .60},\n",
    "            vmin=0, vmax=1, ax=ax)\n",
    "fig.tight_layout()\n",
    "\n",
    "# savefig fails if the figure directory does not exist yet\n",
    "os.makedirs(figure_path, exist_ok=True)\n",
    "fig.savefig(figure_path + f'{base_network}_heatmap_ep{epoch}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map of mean error consistency at each condition's epoch-1 index (heat_array_ep1)\n",
    "epoch = 1\n",
    "# Readable tick labels, e.g. 'Res18_Plus_1ep' -> 'Plus 1ep' (strip drops the leading space)\n",
    "cond_clear = [cond.replace(base_network, \"\").replace(\"_\", \" \").strip() for cond in conditions]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 12))\n",
    "sns.heatmap(heat_array_ep1, annot=True, fmt=\".2f\", annot_kws={\"fontsize\": 12}, square=True,\n",
    "            xticklabels=cond_clear, yticklabels=cond_clear, cmap=\"Blues\", cbar_kws={\"shrink\": .60},\n",
    "            vmin=0, vmax=1, ax=ax)\n",
    "fig.tight_layout()\n",
    "\n",
    "# savefig fails if the figure directory does not exist yet\n",
    "os.makedirs(figure_path, exist_ok=True)\n",
    "fig.savefig(figure_path + f'{base_network}_heatmap_ep{epoch}.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
