{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and general settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import os\n",
    "import sys\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import itertools\n",
    "from analysis_helpers import load_data, prep_condition, order_econs\n",
    "from analysis_helpers import main_plot, three_plots, heat_plot, plot_decisions, plot_decisions_change, main_plot_both\n",
    "from analysis_helpers import calc_econ\n",
    "from data import init_dataset\n",
    "import seaborn as sns\n",
    "from sklearn.metrics import cohen_kappa_score\n",
    "\n",
    "# Keep local packages updated\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the per-model result folders and output directory for figures\n",
    "folder_name = \"./results/Sota/\"\n",
    "figure_path = \"./figures/\"\n",
    "\n",
    "# Every sub-directory of folder_name is treated as one model (condition).\n",
    "# os.listdir returns entries in arbitrary order and also lists stray files,\n",
    "# so keep directories only and sort them to get a reproducible model order.\n",
    "conditions = sorted(\n",
    "    entry for entry in os.listdir(folder_name)\n",
    "    if os.path.isdir(os.path.join(folder_name, entry))\n",
    ")\n",
    "print(f\"{len(conditions)} models found: {conditions}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set flag whether to correct for label errors\n",
    "correct = False\n",
    "\n",
    "# Set flag whether to remove impossibles and trivials for heatmap analysis\n",
    "rm = False\n",
    "\n",
    "# Load the label-error indicator (third column of the stored array) as integers.\n",
    "# The builtin int is used because the np.int alias was removed in NumPy 1.24.\n",
    "# NOTE(review): allow_pickle=True can execute arbitrary code while loading --\n",
    "# only use it with a trusted file.\n",
    "label_error = np.load(\"imagenet_val_ident.npy\", allow_pickle=True)[:, 2].astype(int)\n",
    "error_inds = np.where(label_error == 1)\n",
    "right_inds = np.where(label_error == 0)\n",
    "\n",
    "# Pre-allocate decisions_correct array, indexed as (model, image, epoch)\n",
    "decisions_correct = np.zeros((len(conditions), 50000, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loop through models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared 3x4 figure for the per-model class-accuracy panels; the last panel is removed.\n",
    "# plt.subplots creates the figure itself, so no separate plt.figure call is made\n",
    "# (that would only leave an additional empty figure open).\n",
    "fig, ax = plt.subplots(nrows=3, ncols=4, sharex=True, sharey=True, figsize=(40, 40))\n",
    "fig.text(0.5, 0.04, 'Class', ha='center')\n",
    "fig.text(0.04, 0.5, 'Accuracy', va='center', rotation='vertical')\n",
    "fig.delaxes(ax[2][3])\n",
    "\n",
    "for cond_ind, condition in enumerate(conditions):\n",
    "\n",
    "    # Special case for clip: results are read from a two-column CSV\n",
    "    if condition == \"CLIP\":\n",
    "\n",
    "        # Load data and wrap it so it can be indexed like the load_data output\n",
    "        results = np.genfromtxt(folder_name + condition + \"/NUM1/RESULTS_EP0.csv\", delimiter=',', dtype=\"|S10\")\n",
    "        results = [[[results[:,0], results[:,1]]]]\n",
    "\n",
    "        # Per-image correctness: first entry equals second entry\n",
    "        equal_answers = (results[0][0][0] == results[0][0][1])\n",
    "\n",
    "        # Print model name and accuracy\n",
    "        print(f'{condition}, Accuracy: {np.mean(equal_answers):.2f}')\n",
    "\n",
    "        # Get class accuracies; classes are the unique values of the second entry\n",
    "        # (presumably the ground-truth labels -- TODO confirm)\n",
    "        categories = np.unique(results[0][0][1])\n",
    "        num_classes = len(categories)\n",
    "        class_accuracies = np.zeros(num_classes)\n",
    "        for ind, category in enumerate(categories):\n",
    "            class_accuracies[ind] = np.mean(equal_answers[np.where(results[0][0][1] == category)[0]])\n",
    "\n",
    "        # Prepare data for histogram\n",
    "        decisions_correct[cond_ind, :, 0] = equal_answers\n",
    "\n",
    "    else:\n",
    "\n",
    "        # Load data\n",
    "        results, val_acc = load_data(folder_name + condition + \"/\")\n",
    "\n",
    "        # Print model name and accuracy\n",
    "        print(f'{condition}, Accuracy: {np.equal(results[0][0][0], results[0][0][1]).float().mean():.2f}')\n",
    "\n",
    "        # Get class accuracies; classes are addressed by integer index here\n",
    "        num_classes = len(np.unique(results[0][0][1]))\n",
    "        class_accuracies = np.zeros(num_classes)\n",
    "        equal_answers = np.equal(results[0][0][0], results[0][0][1]).numpy()\n",
    "        for ind in range(num_classes):\n",
    "            class_accuracies[ind] = np.mean(equal_answers[np.where(results[0][0][1] == ind)[0]])\n",
    "\n",
    "        # Prepare data for histogram\n",
    "        decisions_correct[cond_ind, :, 0] = np.equal(np.array(results[0][0][0]), np.array(results[0][0][1]))\n",
    "\n",
    "#     # Make class accuracy plot\n",
    "#     fig.add_subplot(3, 4, cond_ind+1)\n",
    "#     plt.plot(np.flip(np.sort(class_accuracies)), linewidth=4);\n",
    "#     sns.set_context(\"paper\", font_scale=4.0)\n",
    "#     sns.set_style(\"white\")\n",
    "#     sns.color_palette(\"viridis\", as_cmap=True)\n",
    "#     sns.despine(right=True, top=True, offset=10, trim=False)\n",
    "#     plt.title(f\"{condition}\")    \n",
    "#     plt.ylim(0,1.00)\n",
    "#     plt.xlim(0,1000)\n",
    "\n",
    "\n",
    "# # Save combined figure\n",
    "# plt.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)\n",
    "# plt.savefig(figure_path + 'all_class_accs.png', dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make Histogram plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set epoch for which to make histogram\n",
    "ep = 0\n",
    "num_models = len(conditions)\n",
    "\n",
    "# Fraction of models that classify each image correctly\n",
    "mean_decisions_correct = np.mean(decisions_correct, axis=0)\n",
    "\n",
    "# Binomial observer: for every image, draw how many of num_models models are correct\n",
    "# when each is correct with the overall mean accuracy.\n",
    "# The seed is passed to default_rng directly: np.random.seed only seeds the legacy\n",
    "# global generator and has no effect on a Generator, so without it the draws\n",
    "# differ on every run. The legacy seed call is kept for code that still uses np.random.\n",
    "from numpy.random import default_rng\n",
    "np.random.seed(1312)\n",
    "rng = default_rng(1312)\n",
    "s = rng.binomial(num_models, np.mean(mean_decisions_correct[:, ep]), mean_decisions_correct.shape[0])\n",
    "\n",
    "# Weights for normalisation of heatmap\n",
    "weights = np.ones_like(mean_decisions_correct[:, ep]) / float(len(mean_decisions_correct[:, ep]))\n",
    "\n",
    "# Get counts: one unit-wide bin per possible number of correct models (0..num_models)\n",
    "counts_data = np.histogram(mean_decisions_correct[:, ep] * num_models, bins=np.arange(0, num_models+2))\n",
    "counts_model = np.histogram(s, bins=np.arange(0, num_models+2))\n",
    "\n",
    "# Init histogram figure and plot the observed distribution\n",
    "plt.figure(figsize=(10,10))\n",
    "sns.histplot(data=mean_decisions_correct[:, ep] * num_models, bins=num_models+1, \n",
    "             color= 'blue', label='Images from validation set', stat=\"probability\")\n",
    "\n",
    "# Overlay the binomial observer distribution\n",
    "sns.histplot(data=s, alpha=0.5, bins=num_models+1, \n",
    "             color='green', label='Binomial distribution', stat=\"probability\")\n",
    "\n",
    "# Plot settings\n",
    "sns.set_context(\"paper\", font_scale=3.0)\n",
    "sns.set_style(\"white\")\n",
    "sns.color_palette(\"viridis\", as_cmap=True)\n",
    "sns.despine(right=True, top=True, offset=20, trim=False)\n",
    "plt.xlabel(\"Number of models that classify image correctly\", fontsize=30, labelpad=10)\n",
    "plt.ylabel(\"Fraction of images\", fontsize=30, labelpad=10)\n",
    "xlabels = np.concatenate((np.array([\"None\"]), np.arange(1, num_models).astype(str), np.array([\"All\"])))\n",
    "plt.xticks(np.arange(0, num_models+1, step=1), xlabels)\n",
    "plt.ylim(0,0.7)\n",
    "plt.xlim(0,num_models+1)\n",
    "\n",
    "plt.savefig(figure_path + 'SOTA_hist.png', dpi=300, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.histogram returns (counts, bin_edges); report only the counts,\n",
    "# binomial observer first, then the observed data\n",
    "for label, hist in ((\"Expected\", counts_model), (\"Observed\", counts_data)):\n",
    "    print(f\"{label} counts: {hist[0]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate what error consistency can be expected based on binomial observer model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn\n",
    "\n",
    "# Images that every model gets right (mean == 1) or every model gets wrong (mean == 0)\n",
    "all_right = np.flatnonzero(mean_decisions_correct[:, ep] == 1)\n",
    "all_wrong = np.flatnonzero(mean_decisions_correct[:, ep] == 0)\n",
    "all_same = np.concatenate((all_right, all_wrong))\n",
    "\n",
    "# Mean accuracy over the images on which the models do not all agree\n",
    "cleaned_acc = np.delete(mean_decisions_correct[:, ep], all_same).mean()\n",
    "\n",
    "# Two simulated observers: independent right/wrong draws on the remaining images\n",
    "num_remaining = mean_decisions_correct.shape[0] - all_same.shape[0]\n",
    "s1 = rng.binomial(1, cleaned_acc, num_remaining)\n",
    "s2 = rng.binomial(1, cleaned_acc, num_remaining)\n",
    "\n",
    "# Both observers share the always-right (1) and always-wrong (0) images\n",
    "shared_part = np.concatenate((np.ones(all_right.shape[0]), np.zeros(all_wrong.shape[0])))\n",
    "s1_all = np.concatenate((shared_part, s1))\n",
    "s2_all = np.concatenate((shared_part, s2))\n",
    "\n",
    "# Kappa between the two simulated observers\n",
    "sklearn.metrics.cohen_kappa_score(s1_all, s2_all)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Remove impossibles and trivials for heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toggle: True removes impossibles and trivials before the heatmap is computed\n",
    "rm = True # or False\n",
    "\n",
    "num_base_epochs = 1\n",
    "\n",
    "# Fraction of models that classify each image correctly\n",
    "mean_decisions_correct = decisions_correct.mean(axis=0)\n",
    "\n",
    "# Image indices ordered by that fraction (last column), highest first\n",
    "mean_order = mean_decisions_correct[:, -1].argsort()[::-1]\n",
    "\n",
    "# Distinct fractions that occur; np.unique returns them in ascending order\n",
    "base_column = mean_decisions_correct[:, num_base_epochs-1]\n",
    "entries = np.unique(base_column)\n",
    "\n",
    "# Collect the images at the two highest and the two lowest fractions\n",
    "if rm:\n",
    "    rm_inds = np.concatenate([np.where(base_column == entries[pos])[0] for pos in (-1, -2, 0, 1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-allocate arrays for heat map, filled with NaN\n",
    "heat_array = np.full((len(conditions), len(conditions)), np.nan)\n",
    "heat_array_ep1 = np.full((len(conditions), len(conditions)), np.nan)\n",
    "\n",
    "\n",
    "def load_correct(condition):\n",
    "    \"\"\"Return the per-image correctness array (first vs. second result entry) of one condition.\"\"\"\n",
    "    if condition == \"CLIP\":\n",
    "        # CLIP results are stored as a two-column CSV instead of the load_data format\n",
    "        raw = np.genfromtxt(folder_name + condition + \"/NUM1/RESULTS_EP0.csv\", delimiter=',', dtype=\"|S10\")\n",
    "        return np.array(raw[:, 0]) == np.array(raw[:, 1])\n",
    "    results, _ = load_data(folder_name + condition + \"/\")\n",
    "    return np.array(results[0][0][0]) == np.array(results[0][0][1])\n",
    "\n",
    "\n",
    "# Load every condition once up front instead of re-reading both files for every pair\n",
    "correct_by_condition = {condition: load_correct(condition) for condition in conditions}\n",
    "\n",
    "# Loop through all pairs of conditions to fill the heat map\n",
    "for ind_1, condition_1 in enumerate(conditions):\n",
    "    for ind_2, condition_2 in enumerate(conditions):\n",
    "\n",
    "        print(f\"Condition 1: {condition_1}, Condition 2: {condition_2}\")\n",
    "\n",
    "        # Per-image correctness of both models\n",
    "        ep_1 = correct_by_condition[condition_1]\n",
    "        ep_2 = correct_by_condition[condition_2]\n",
    "\n",
    "        if rm:\n",
    "            # np.delete returns a copy, so the cached arrays stay untouched\n",
    "            print(\"Removing impossibles and trivials items\")\n",
    "            ep_1 = np.delete(ep_1, rm_inds)\n",
    "            ep_2 = np.delete(ep_2, rm_inds)\n",
    "\n",
    "        # Error consistency: Cohen's kappa between the two correctness vectors\n",
    "        heat_array[ind_1, ind_2] = cohen_kappa_score(ep_1, ep_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make Heatmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Format conditions\n",
    "epoch = num_base_epochs - 1\n",
    "plt.figure(figsize=(12, 12))\n",
    "\n",
    "# Tick labels: folder names with underscores shown as spaces.\n",
    "# No base-network prefix is stripped because `base_network` is not defined\n",
    "# anywhere in this notebook (referencing it raised a NameError).\n",
    "cond_clear = [cond.replace(\"_\", \" \") for cond in conditions]\n",
    "\n",
    "# File-name suffix: number of accuracy levels removed at each end of the scale.\n",
    "# With rm set, the two highest and the two lowest levels are removed, otherwise none.\n",
    "# `rm_num` was previously undefined, so it is derived from the rm flag here.\n",
    "rm_num = 2 if rm else 0\n",
    "\n",
    "# Make and save heatmap\n",
    "sns.set(font_scale=2.0)\n",
    "sns.heatmap(heat_array, annot=True, fmt = \".2f\", annot_kws={\"fontsize\":12}, square=True,\n",
    "            xticklabels=cond_clear, yticklabels=cond_clear, cmap=\"Blues\", cbar_kws={\"shrink\": .60},\n",
    "            vmin=0, vmax=1)\n",
    "plt.tight_layout()\n",
    "plt.savefig(figure_path + f'Heatmap_removed{rm_num}.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
