{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b385c7e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "dir2 = os.path.abspath('')\n",
    "dir1 = os.path.dirname(dir2)\n",
    "if not dir1 in sys.path:\n",
    "    sys.path.append(dir1)\n",
    "import json\n",
    "import numpy as np\n",
    "from models.model_utils import get_model\n",
    "from data.data_utils import get_dataloader\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import time\n",
    "from collections import Counter\n",
    "from neuron_affinity import *\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5bf3c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'Cuda {torch.cuda.is_available()}')\n",
    "root = None # the root of all experiments\n",
    "# exp_name = 'Resnet18_CIFAR10_10k_100_COPIES'\n",
    "# group = '10'\n",
    "exp_name = 'Resnet18_CIFAR10_45k_30X_ALL_DIFF'\n",
    "group = '45'\n",
    "exp_dir = root.format(exp_name)\n",
    "print(exp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e79a3b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_1_path = os.path.join(exp_dir, 'model_1')\n",
    "config_path = os.path.join(model_1_path, 'config')\n",
    "print('Loading config...')\n",
    "with open(config_path, 'r') as f:\n",
    "    config = json.load(f)\n",
    "all_dataloaders = get_dataloader(config)\n",
    "_, test_loader, _ = all_dataloaders  # random subset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b609aeec",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = []\n",
    "all_label = []\n",
    "for x, y in test_loader:\n",
    "    all_data.append(x.data.numpy())\n",
    "    all_label.append(y.data.numpy())\n",
    "all_data = np.concatenate(all_data, axis=0).transpose((0, 2, 3, 1))\n",
    "all_label = np.concatenate(all_label, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "758b6a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(all_data.shape)\n",
    "print(all_label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c2b0bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def project_repr(model_repr, n_component=50):\n",
    "#     model_repr_flat = model_repr.reshape((len(model_repr), -1))\n",
    "#     pca = PCA(n_components=n_component)\n",
    "#     pca.fit(model_repr_flat)\n",
    "#     return pca.transform(model_repr_flat)\n",
    "\n",
    "def project_repr(model_repr, n_component=50):\n",
    "    model_repr_flat = model_repr.reshape((len(model_repr), -1))\n",
    "    scaler = StandardScaler()\n",
    "    pca = PCA(n_components=n_component)\n",
    "    pipeline = make_pipeline(scaler, pca)\n",
    "    pipeline.fit(model_repr_flat)\n",
    "    return pipeline.transform(model_repr_flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "854f4581",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_model = 20\n",
    "n_component = 50\n",
    "\n",
    "tic = time.time()\n",
    "per_model_repr = []\n",
    "for i in range(n_model):\n",
    "    model_path = os.path.join(exp_dir, f'model_{i}')\n",
    "    weight_path = os.path.join(model_path, 'weight.pt')\n",
    "    model = get_model(config).cuda()\n",
    "    state_dict = torch.load(weight_path, map_location=torch.device('cpu'))\n",
    "    model.load_state_dict(state_dict)\n",
    "    model_repr = get_all_repr(model, test_loader, 'layer4')\n",
    "    per_model_repr.append(project_repr(model_repr, n_component=n_component))\n",
    "    print('='*60)\n",
    "    print(f'{i} {(time.time()-tic)/(i+1)}  s / model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe4db419",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_per_example = np.max(np.abs(per_model_repr), -1)[:,:,None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8606364e",
   "metadata": {},
   "outputs": [],
   "source": [
    "normalized_per_model_repr = per_model_repr / max_per_example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14bdc420",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10\n",
    "m = 20\n",
    "fig, axs = plt.subplots(m, n+1, figsize=(18, 18))\n",
    "for i, img_idx in enumerate(np.random.choice(np.arange(10000), m)):\n",
    "    axs[i, 0].imshow(all_data[img_idx])\n",
    "    for j in range(n):\n",
    "        axs[i, j + 1].set_xlim([-0.1, 1.])\n",
    "        axs[i, j + 1].set_ylim([0, 30])\n",
    "        _ = axs[i, j + 1].hist(normalized_per_model_repr[j][i], bins=20)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a46a4183",
   "metadata": {},
   "outputs": [],
   "source": [
    "global_percentile = 90\n",
    "global_threshold = np.percentile(normalized_per_model_repr, [global_percentile])[0]\n",
    "print(global_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e3f7ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "activated_neuron_idx = np.where(normalized_per_model_repr > global_threshold)\n",
    "activated_neuron_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12bd01b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "activatd_neuron_count = activated_neuron_idx[0].shape[0]\n",
    "all_neuron_n = np.prod(normalized_per_model_repr.shape)\n",
    "percent_activated = activatd_neuron_count/all_neuron_n\n",
    "print(activatd_neuron_count, all_neuron_n, percent_activated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57107143",
   "metadata": {},
   "outputs": [],
   "source": [
    "neuron_feature_pairing_path = os.path.join(exp_dir, 'neuron_correlation_pca', 'neuron_feature_pairing_p80') + '.npy'\n",
    "neuron_feature_pairing = np.load(neuron_feature_pairing_path)\n",
    "print(neuron_feature_pairing.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d32fc70",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_feature = [[] for i in range(10000)]\n",
    "for i, j, k in zip(activated_neuron_idx[1], activated_neuron_idx[0], activated_neuron_idx[2]):\n",
    "    if neuron_feature_pairing[j, k] > 0:\n",
    "        image_feature[i].append(neuron_feature_pairing[j, k])\n",
    "image_feature = [set(f) for f in image_feature]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c8abf00",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_image_feature_count = [len(f) for f in image_feature]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31bc82b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist(all_image_feature_count, bins=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0da8fabf",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_count_percentile = np.percentile(all_image_feature_count, [5, 95])\n",
    "feature_count_percentile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09cab6ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_fc_idx = np.where(all_image_feature_count < feature_count_percentile[0])[0]\n",
    "large_fc_idx = np.where(all_image_feature_count > feature_count_percentile[1])[0]\n",
    "\n",
    "large_fc_idx = sorted(large_fc_idx, key=lambda x: -all_image_feature_count[x])\n",
    "small_fc_idx = sorted(small_fc_idx, key=lambda x: all_image_feature_count[x])\n",
    "# print(small_fc_idx)\n",
    "# print(large_fc_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f888ed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10\n",
    "m = 10\n",
    "offset = 0\n",
    "fig, axs = plt.subplots(m, n, figsize=(18, 18))\n",
    "for i in range(m):\n",
    "    for j in range(n):\n",
    "        axs[i, j].imshow(all_data[small_fc_idx[offset + i * m + j]])\n",
    "#         axs[i, j].set_title(f\"i {small_fc_idx[i * m + j]}  nf {all_image_feature_count[small_fc_idx[i * m + j]]}\")\n",
    "        axs[i, j].set_title(f\"nf {all_image_feature_count[small_fc_idx[offset + i * m + j]]}\", fontsize=15,  y=-0.25)\n",
    "        axs[i, j].axis('off')\n",
    "fig.suptitle('Example with fewest features', fontsize=20, y=1.0)\n",
    "fig.tight_layout()\n",
    "fig.show()\n",
    "\n",
    "# plt.imshow(all_data[5721])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ff4b296",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.argmax(all_image_feature_count), np.max(all_image_feature_count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4035052",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10\n",
    "m = 10\n",
    "fig, axs = plt.subplots(m, n, figsize=(18, 18))\n",
    "for i in range(m):\n",
    "    for j in range(n):\n",
    "        axs[i, j].imshow(all_data[large_fc_idx[i * m + j]])\n",
    "#         axs[i, j].set_title(f\"i {large_fc_idx[i * m + j]}  nf {all_image_feature_count[large_fc_idx[i * m + j]]}\")\n",
    "        axs[i, j].set_title(f\"nf {all_image_feature_count[large_fc_idx[i * m + j]]}\", fontsize=15,  y=-0.25)\n",
    "        axs[i, j].axis('off')\n",
    "fig.suptitle('Example with most features', fontsize=20, y=1.0)\n",
    "fig.tight_layout()\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb5f53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.listdir(os.path.join(exp_dir, 'neuron_correlation_pca'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b23a6bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gt_confidence = np.load(os.path.join(exp_dir, 'neuron_correlation', 'gt_confidence.npy'))\n",
    "gt_confidence = np.load(os.path.join(exp_dir, 'neuron_correlation_pca', 'gt_confidence.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d529959c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(gt_confidence, all_image_feature_count, s=0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75985b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "root = './plots'\n",
    "sns.set_style(\"whitegrid\")\n",
    "sns.set_palette(\"tab10\")\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# sns.histplot(\n",
    "#     None, x=gt_confidence, y=all_image_feature_count,\n",
    "#     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "#     pthresh=.01, pmax=.7,\n",
    "# )\n",
    "\n",
    "plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "plt.xlabel('confidence', fontsize=24)\n",
    "plt.ylabel('number of features', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "# plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e70d12c9",
   "metadata": {},
   "source": [
    "## Class conditioned analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c8ec58d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class_idx = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41034883",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_idx = np.where(all_label == class_idx)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a09639",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(all_data[image_idx[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d6a4e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_image_features = np.array(image_feature)[image_idx]\n",
    "filtered_image_feature_count = [len(f) for f in filtered_image_features]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ec36c08",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist(filtered_image_feature_count, bins=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6ea51c",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.kdeplot(gt_confidence[image_idx], filtered_image_feature_count, fill=True, cbar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b77b5035",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(10, 3, figsize=(10, 18))\n",
    "for i in range(10):\n",
    "    image_idx = np.where(all_label == i)[0]\n",
    "    filtered_image_features = np.array(image_feature)[image_idx]\n",
    "    filtered_image_feature_count = [len(f) for f in filtered_image_features]\n",
    "    axs[i, 0].imshow(all_data[image_idx[1]])\n",
    "    axs[i, 1].hist(filtered_image_feature_count, bins=20)\n",
    "    sns.kdeplot(gt_confidence[image_idx], filtered_image_feature_count, fill=True, cbar=True, ax=axs[i, 2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d32f993b",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_feature = [[] for i in range(10000)]\n",
    "for i, j, k in zip(activated_neuron_idx[1], activated_neuron_idx[0], activated_neuron_idx[2]):\n",
    "    if neuron_feature_pairing[j, k] > 0:\n",
    "        image_feature[i].append(neuron_feature_pairing[j, k])\n",
    "image_feature = [set(f) for f in image_feature]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97bd0a6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_class = [[] for _ in range(np.max(neuron_feature_pairing)+1)]\n",
    "feature_to_example = [[] for _ in range(np.max(neuron_feature_pairing)+1)]\n",
    "for i in range(10000):\n",
    "    for f in list(image_feature[i]):\n",
    "        feature_to_class[f].append(all_label[i])\n",
    "        feature_to_example[f].append(i)\n",
    "\n",
    "feature_to_class = sorted(feature_to_class, key=lambda x: -len(x))\n",
    "feature_to_example = sorted(feature_to_example, key=lambda x: -len(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfdad364",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_class = np.array(feature_to_class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "847198b4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "class_label = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "# empty = [' '] * 10\n",
    "# class_label = zip(class_label, empty)\n",
    "n = 5\n",
    "m = 6\n",
    "fig, axs = plt.subplots(m, n, figsize=(20, 2.4 * m))\n",
    "cmap = plt.cm.get_cmap('tab20')\n",
    "for i in range(m):\n",
    "    for j in range(n):\n",
    "        fc = feature_to_class[i * n + j + 1]\n",
    "        counter = Counter(fc)\n",
    "#         print(counter)\n",
    "#         print(1/0)\n",
    "        fc = [counter[k] for k in range(10)]\n",
    "#         axs[i, j].hist(fc, bins=20)\n",
    "        axs[i, j].bar(class_label, fc, color=cmap(list(range(10))))\n",
    "        axs[i, j].set_ylim([0, 1100])\n",
    "#         axs[i, j].axis('off')\n",
    "#         axs[i, j].set_title(f'f {i * n + j + 1} nd {len(fc)}')\n",
    "        axs[i, j].set_title(f'feat {i * n + j + 1},   n {len(feature_to_class[i * n + j + 1])}', fontsize=15)\n",
    "#         axs[i, j].tick_params(axis='both', which='both', bottom=True, top=False, left=False, right=False)\n",
    "#         axs[i, j].set_title(f'f {i * n + j + 1}')\n",
    "#         axs[i, j].set_yscale('log')\n",
    "        if i == m-1:\n",
    "            axs[i, j].set_xticklabels(class_label, rotation=50, fontsize=10)\n",
    "        else:\n",
    "            axs[i, j].xaxis.set_visible(False)\n",
    "# plt.title('Class distribution per feature')\n",
    "fig.suptitle('Class distribution per feature', fontsize=20)\n",
    "fig.tight_layout()\n",
    "# for i, fc in enumerate(feature_to_class[1:1+n]):\n",
    "#     axs[i].set_ylim([0, 1000])\n",
    "#     axs[i].hist(fc, bins=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9fa4d79",
   "metadata": {},
   "source": [
    "## Examples in each feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bc6b6df",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(feature_to_example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5742404",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10\n",
    "m = 10\n",
    "fig, axs = plt.subplots(m, n, figsize=(18, 18))\n",
    "for i in range(m):\n",
    "    example = np.random.choice(feature_to_example[offset + i], size=n, replace=False)\n",
    "    for j in range(n):\n",
    "        axs[i, j].imshow(all_data[example[j]])\n",
    "        axs[i, j].set_title(f'feat {i}', fontsize=15,  y=-0.25)\n",
    "        axs[i, j].axis('off')\n",
    "# plt.title('Class distribution per feature')\n",
    "fig.suptitle('Sample images for top features', fontsize=20, y=1.0)\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27772347",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "n = 10\n",
    "m = 10\n",
    "offset = 590\n",
    "fig, axs = plt.subplots(m, n, figsize=(18, 18))\n",
    "for i in range(m):\n",
    "    example = np.random.choice(feature_to_example[offset + i], size=n, replace=False)\n",
    "    for j in range(n):\n",
    "        axs[i, j].imshow(all_data[example[j]])\n",
    "        axs[i, j].set_title(f'feat {offset + i}', fontsize=15,  y=-0.25)\n",
    "        axs[i, j].axis('off')\n",
    "# plt.title('Class distribution per feature')\n",
    "fig.suptitle('Sample images for rare features', fontsize=20, y=1.0)\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb4a2662",
   "metadata": {},
   "source": [
    "## Avg number of feature per data vs number of data point with that feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4236704c",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_features = max([max(list(f)) for f in image_feature])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1daf56e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "990797f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_data_for_feature = [1e-6 for _ in range(n_features+1)]\n",
    "avg_feature_per_data = [0 for _ in range(n_features+1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5448fb66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list(image_feature[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f98ad8fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i_f in image_feature:\n",
    "    i_f = list(i_f)\n",
    "    for f in i_f:\n",
    "        n_data_for_feature[f] += 1\n",
    "        avg_feature_per_data[f] += len(i_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0400d0ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(1, len(avg_feature_per_data)):\n",
    "#     print(i, avg_feature_per_data[i], n_data_for_feature[i])\n",
    "    avg_feature_per_data[i] = avg_feature_per_data[i] / n_data_for_feature[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c3a7233",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ecd3026",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(n_data_for_feature, avg_feature_per_data, s=0.5)\n",
    "# sns.kdeplot(n_data_for_feature, avg_feature_per_data, fill=True, cbar=True)\n",
    "plt.ylabel('avg feat per data')\n",
    "plt.xlabel('n data for feat')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f9962b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.arange(len(n_data_for_feature)), n_data_for_feature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da00ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(10, 2, figsize=(10, 18))\n",
    "for i in range(10):\n",
    "    image_idx = np.where(all_label == i)[0]\n",
    "    filtered_image_features = np.array(image_feature)[image_idx]\n",
    "    n_data_for_feature = [1e-6 for _ in range(n_features+1)]\n",
    "    avg_feature_per_data = [0 for _ in range(n_features+1)]\n",
    "    for i_f in filtered_image_features:\n",
    "        i_f = list(i_f)\n",
    "        for f in i_f:\n",
    "            n_data_for_feature[f] += 1\n",
    "            avg_feature_per_data[f] += len(i_f)\n",
    "    for j in range(1, len(avg_feature_per_data)):\n",
    "        avg_feature_per_data[j] = avg_feature_per_data[j] / n_data_for_feature[j]\n",
    "    axs[i, 0].imshow(all_data[image_idx[1]])\n",
    "    axs[i, 1].scatter(n_data_for_feature, avg_feature_per_data, s=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad9a4571",
   "metadata": {},
   "source": [
    "## Model vs Feature distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ba181d",
   "metadata": {},
   "outputs": [],
   "source": [
    "neuron_feature_pairing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6cb4d42",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_feature = np.max(neuron_feature_pairing) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17ed14c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_model = [[0]*20 for _ in range(n_feature)]\n",
    "for i in range(len(neuron_feature_pairing)):\n",
    "    for j in neuron_feature_pairing[i]:\n",
    "        feature_to_model[j][i] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf23d0b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_model = sorted(feature_to_model, key=lambda a: -sum(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d93c9a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_model = np.array(feature_to_model)\n",
    "print(feature_to_model)\n",
    "print(feature_to_model.shape)\n",
    "\n",
    "# shuffle\n",
    "np.random.shuffle(feature_to_model)\n",
    "feature_to_model = np.array(sorted(feature_to_model, key=lambda a: -sum(a)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cc3e55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 1, figsize=(20, 18))\n",
    "plt.imshow(feature_to_model.T[:,:400])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "674daa63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, axs = plt.subplots(1, 1, figsize=(20, 18))\n",
    "# shuffle_model_idx = np.arange(20)\n",
    "# plt.imshow(feature_to_model.T[shuffle_model_idx,:400])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "892f0cd1",
   "metadata": {},
   "source": [
    "## Using feature-class signature to do prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1edb4954",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_feature_to_signature(feature, n_class=10):\n",
    "    count = Counter(feature)\n",
    "    signature = [count[i] for i in range(n_class)]\n",
    "    return np.array(signature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1573f8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_class_signature = [convert_feature_to_signature(f) for f in feature_to_class]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1588634",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_data_per_feature = [sum(s) for s in feature_class_signature]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec5195ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "# root = './plots'\n",
    "# sns.set_style(\"whitegrid\")\n",
    "# sns.set_palette(\"tab10\")\n",
    "\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# # sns.histplot(\n",
    "# #     None, x=gt_confidence, y=all_image_feature_count,\n",
    "# #     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "# #     pthresh=.01, pmax=.7,\n",
    "# # )\n",
    "\n",
    "# plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "# plt.xlabel('confidence', fontsize=24)\n",
    "# plt.ylabel('number of features', fontsize=24)\n",
    "# _ = plt.xticks(fontsize=22)\n",
    "# _ = plt.yticks(fontsize=22)\n",
    "# # plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# # plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7adc9f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # plt.ticklabel_format(axis='both', style='sci')\n",
    "# plt.ticklabel_format(style='sci', axis='x')\n",
    "# plt.gca().set_yticklabels(['{:.0f}'.format(x) for x in sorted(total_data_per_feature)])\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.plot(np.arange(len(total_data_per_feature))[1:], sorted(total_data_per_feature)[1:], linewidth=4.,)\n",
    "plt.title(f'feature frequency ({group}k)', fontsize=26)\n",
    "plt.xlabel('feature (1e2)', fontsize=24)\n",
    "plt.ylabel('occurrences (1e3)', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "plt.savefig(os.path.join(root, f'{group}_feat_freq.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "141b8b0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_class_signature = np.array(feature_class_signature) + 1e-6\n",
    "normalized_feature_class_signature = feature_class_signature / np.sum(feature_class_signature, axis=1)[:, None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77fa9660",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_v1(data_features, all_feature_signatures):\n",
    "    energy = 0\n",
    "    for f in data_features:\n",
    "        energy += all_feature_signatures[f]\n",
    "    return np.argmax(energy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00bee5a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 100\n",
    "test_image = image_feature[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97fb3beb",
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_v1(test_image, feature_class_signature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff60a471",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_label[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a05b55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = [4, 3]\n",
    "correct = 0\n",
    "n_data = 0\n",
    "for i in range(len(image_feature)):\n",
    "    if all_label[i] in classes:\n",
    "        prediction = predict_v1(image_feature[i], feature_class_signature)\n",
    "        correct += int(prediction == all_label[i])\n",
    "        n_data += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86124b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct / n_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24beb1c5",
   "metadata": {},
   "source": [
    "## Testing sampling models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dac7ebab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_model(feature_energy, n_feature, t=1.0):\n",
    "    p = feature_energy / t\n",
    "    p -= np.max(p)\n",
    "    p = np.exp(p)\n",
    "    p /= np.sum(p)\n",
    "    return np.random.choice(np.arange(len(feature_energy)), size=n_feature, replace=False, p=p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bda16543",
   "metadata": {},
   "outputs": [],
   "source": [
    "energy = np.array(total_data_per_feature)\n",
    "model = sample_model(energy, 50, t=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcfac6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e2387f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "energy[model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884d0139",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = [sample_model(energy, 50, t=1700) for _ in range(20)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d7e0a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_model = [[0]*20 for _ in range(n_feature)]\n",
    "for i, m in enumerate(models):\n",
    "    for j in m:\n",
    "        feature_to_model[j][i] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ade39196",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_model = sorted(feature_to_model, key=lambda a: -sum(a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65b18440",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_model = np.array(feature_to_model)\n",
    "feature_to_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a92ea6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 1, figsize=(20, 18))\n",
    "plt.imshow(feature_to_model.T[:,:400])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b3179ff",
   "metadata": {},
   "source": [
    "## Distribution of features in high confidence points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de083be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "high_confidence_data = np.where(gt_confidence > 0.9)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a770bc5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(high_confidence_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60cc0581",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_feature = np.max(neuron_feature_pairing) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f85a23d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "high_feature_count = [0 for _ in range(n_feature)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c65b29b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in high_confidence_data:\n",
    "    for f in image_feature[i]:\n",
    "        high_feature_count[f] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a2581d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.arange(len(high_feature_count)), high_feature_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f972e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.log(np.arange(len(high_feature_count))), np.log(np.array(sorted(high_feature_count))/sum(high_feature_count)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "410653ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.plot(np.arange(len(feature_count)), np.log(sorted(feature_count)))\n",
    "_ = plt.hist(gt_confidence, bins=30)\n",
    "np.percentile(gt_confidence, [30])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bc8a49a",
   "metadata": {},
   "outputs": [],
   "source": [
    "low_confidence_data = np.where(gt_confidence < 0.9)[0]\n",
    "print(len(low_confidence_data))\n",
    "low_feature_count = [0 for _ in range(n_feature)]\n",
    "for i in low_confidence_data:\n",
    "    for f in image_feature[i]:\n",
    "        low_feature_count[f] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a7bffb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "# root = './plots'\n",
    "# sns.set_style(\"whitegrid\")\n",
    "# sns.set_palette(\"tab10\")\n",
    "\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# # sns.histplot(\n",
    "# #     None, x=gt_confidence, y=all_image_feature_count,\n",
    "# #     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "# #     pthresh=.01, pmax=.7,\n",
    "# # )\n",
    "\n",
    "# plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "# plt.xlabel('confidence', fontsize=24)\n",
    "# plt.ylabel('number of features', fontsize=24)\n",
    "# _ = plt.xticks(fontsize=22)\n",
    "# _ = plt.yticks(fontsize=22)\n",
    "# # plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# # plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5e34a54",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(9, 6))\n",
    "high_prob = np.array(high_feature_count)/sum(high_feature_count)\n",
    "low_prob = np.array(low_feature_count)/sum(low_feature_count)\n",
    "sorted_idx = np.argsort(high_prob)\n",
    "sorted_idx = np.argsort(total_data_per_feature)\n",
    "high_prob = high_prob[sorted_idx]\n",
    "low_prob = low_prob[sorted_idx]\n",
    "plt.plot(np.arange(len(high_prob))[::2], np.log10(high_prob)[::2], label='high confidence', linewidth=3)\n",
    "plt.plot(np.arange(len(low_prob))[::2], np.log10(low_prob)[::2], label='low confidence', linewidth=3)\n",
    "# plt.plot(np.arange(len(high_prob)), high_prob, label='high')\n",
    "# plt.plot(np.arange(len(low_prob)), low_prob, label='low')\n",
    "plt.title(f'log feature density by confidence ({group}k)', fontsize=26)\n",
    "plt.legend(fontsize=22)\n",
    "plt.xlabel('feature (1e2)', fontsize=24)\n",
    "plt.ylabel('log prob (base 10)', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "plt.savefig(os.path.join(root, f'{group}_feat_freq_conf.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4e1cc11",
   "metadata": {},
   "outputs": [],
   "source": [
    "kl = 0\n",
    "for p, q in zip(high_prob, low_prob):\n",
    "    if p == 0:\n",
    "        print(p, q)\n",
    "    else:\n",
    "        kl += p * np.log(p/q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9033ca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5ed2d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.arange(len(low_prob)), (low_prob+1e-6)/(high_prob+1e-6))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8640e24e",
   "metadata": {},
   "source": [
    "## Number of data points with certain features vs. the number of models that learned that feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50836202",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_data_w_feature = [0 for _ in range(n_feature)]\n",
    "n_model_w_feature = [0 for _ in range(n_feature)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "352e6bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(gt_confidence)):\n",
    "    for f in image_feature[i]:\n",
    "        n_data_w_feature[f] += 1\n",
    "\n",
    "for i in range(len(neuron_feature_pairing)):\n",
    "    for j in neuron_feature_pairing[i]:\n",
    "        n_model_w_feature[j] += 1\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34208b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "# root = './plots'\n",
    "# sns.set_style(\"whitegrid\")\n",
    "# sns.set_palette(\"tab10\")\n",
    "\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# # sns.histplot(\n",
    "# #     None, x=gt_confidence, y=all_image_feature_count,\n",
    "# #     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "# #     pthresh=.01, pmax=.7,\n",
    "# # )\n",
    "\n",
    "# plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "# plt.xlabel('confidence', fontsize=24)\n",
    "# plt.ylabel('number of features', fontsize=24)\n",
    "# _ = plt.xticks(fontsize=22)\n",
    "# _ = plt.yticks(fontsize=22)\n",
    "# # plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# # plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77c939f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, n in enumerate(n_data_w_feature):\n",
    "    if n == 0:\n",
    "        n_model_w_feature[i] = 0\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.scatter(n_data_w_feature, n_model_w_feature, alpha=0.9, s=40)\n",
    "plt.xlabel('n_data_w_feature')\n",
    "plt.ylabel('n_model_w_feature')\n",
    "\n",
    "# plt.legend(fontsize=15)\n",
    "plt.title(f'n data vs. n models w/ feat ({group}k)', fontsize=26)\n",
    "plt.xlabel('# of data with a feature (1e3)', fontsize=24)\n",
    "plt.ylabel('# of models with a feature', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "fig.tight_layout()\n",
    "# plt.savefig(os.path.join(root, f'{group}_n_data_n_model_feat.pdf'), bbox_inches=\"tight\")\n",
    "plt.savefig(os.path.join('./plots', f'{group}_n_data_n_model_feat.pdf'), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a36fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.path.join(root, f'{group}_n_data_n_model_feat.pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1b0c4da",
   "metadata": {},
   "source": [
    "## How models make mistakes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dba2bec6",
   "metadata": {},
   "outputs": [],
   "source": [
    "affinity_dir = os.path.join(exp_dir, 'neuron_correlation_pca')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "201454a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader, onehot=True):\n",
    "    all_predictions = []\n",
    "    for x, _ in data_loader:\n",
    "        x = x.cuda()\n",
    "        prediction = model(x)\n",
    "        prediction = prediction.detach().cpu().numpy()\n",
    "        if onehot:\n",
    "            new_prediction = np.zeros(prediction.shape)\n",
    "            prediction = np.argmax(prediction, axis=-1)\n",
    "            new_prediction[np.arange(len(prediction)), prediction] = 1\n",
    "            prediction = new_prediction\n",
    "        all_predictions.append(prediction)\n",
    "    return np.concatenate(all_predictions, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcce0819",
   "metadata": {},
   "outputs": [],
   "source": [
    "# n_model = 30\n",
    "# tic = time.time()\n",
    "# all_prediction = []\n",
    "# for i in range(n_model):\n",
    "#     if i > 0:\n",
    "#         print('='*60)\n",
    "#         print(f'{i} {(time.time()-tic)/i}  s / model')\n",
    "#     model_path = os.path.join(exp_dir, f'model_{i}')\n",
    "#     weight_path = os.path.join(model_path, 'weight.pt')\n",
    "#     model = get_model(config).cuda()\n",
    "#     state_dict = torch.load(weight_path, map_location=torch.device('cpu'))\n",
    "#     model.load_state_dict(state_dict)\n",
    "#     predictions = get_predictions(model, test_loader)\n",
    "#     all_prediction.append(predictions)\n",
    "#     np.save(os.path.join(exp_dir, 'neuron_correlation_pca', f'prediction_{i}'), np.uint8(predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f12c939",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_model = 30\n",
    "tic = time.time()\n",
    "all_prediction = []\n",
    "for i in range(n_model):\n",
    "    if i > 0:\n",
    "        print('='*60)\n",
    "        print(f'{i} {(time.time()-tic)/i}  s / model')\n",
    "    predictions = np.load(os.path.join(exp_dir, 'neuron_correlation_pca', f'prediction_{i}.npy'))\n",
    "    all_prediction.append(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8870d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_prediction = np.array(all_prediction)\n",
    "all_prediction_argmax = np.argmax(all_prediction, axis=-1)\n",
    "all_prediction.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d36aca",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_label = []\n",
    "for _, y in test_loader:\n",
    "    all_label.append(y.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "464d8ee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_label = np.concatenate(all_label)\n",
    "all_label.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12249fdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_error = []\n",
    "for i in range(n_model):\n",
    "    error = np.float32(all_label == all_prediction_argmax[i])\n",
    "    all_error.append(error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "228814c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_error_dict = []\n",
    "for i, err in enumerate(all_error):\n",
    "    err_dict = {}\n",
    "    for j, correct in enumerate(err):\n",
    "        if correct == 0:\n",
    "            err_dict[j] = all_prediction_argmax[i][j]\n",
    "    all_error_dict.append(err_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b8d2c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def same_error_count(model_1_error_dict, model_2_error_dict):\n",
    "    count = 0\n",
    "    total = 0\n",
    "    for k in model_1_error_dict:\n",
    "        if k in model_2_error_dict and model_1_error_dict[k] == model_2_error_dict[k]:\n",
    "            count += 1\n",
    "    total = (len(model_1_error_dict) + len(model_2_error_dict)) / 2\n",
    "    return count / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9e9359",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_pair_error_correlation = [[-1 for _ in range(n_model)] for _ in range(n_model)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f35149",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(n_model):\n",
    "    for j in range(n_model):\n",
    "        all_pair_error_correlation[i][j] = same_error_count(all_error_dict[i], all_error_dict[j])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6dad45a",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_pair_error_correlation = np.array(all_pair_error_correlation)\n",
    "all_pair_error_correlation.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "972537f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(all_pair_error_correlation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ce562af",
   "metadata": {},
   "source": [
    "### similarity measure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f4bb4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "manhattan_cost = np.zeros((50, 50))\n",
    "for i in range(50):\n",
    "    for j in range(50):\n",
    "        manhattan_cost[i, j] = abs(i-j)\n",
    "plt.imshow(manhattan_cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5877c78",
   "metadata": {},
   "outputs": [],
   "source": [
    "corr_mat = np.load(os.path.join(affinity_dir, 'corr_18_10.npy'))\n",
    "corr_mat = np.abs(corr_mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ef0007b",
   "metadata": {},
   "outputs": [],
   "source": [
    "(corr_mat/corr_mat.sum() * manhattan_cost).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcc9a556",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_corr_mat = [[None for _ in range(20)] for _ in range(20)]\n",
    "all_cost = [[None for _ in range(20)] for _ in range(20)]\n",
    "for i in range(20):\n",
    "    for j in range(20):\n",
    "        corr_mat = np.load(os.path.join(affinity_dir, f'corr_{i}_{j}.npy'))\n",
    "        corr_mat = np.abs(corr_mat)\n",
    "#         all_cost[i][j] = (corr_mat/corr_mat.sum() * manhattan_cost).sum()\n",
    "        all_cost[i][j] = np.trace(corr_mat)\n",
    "#         all_cost[i][j] = np.sum(corr_mat)\n",
    "        all_corr_mat[i][j] = corr_mat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ccda196",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_cost = np.array(all_cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0de254ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(all_cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a57512a",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist(all_cost.reshape(-1), bins=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66cc6aae",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(all_cost.max())\n",
    "print(all_cost.min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d490d4f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_sum = [[None for _ in range(20)] for _ in range(20)]\n",
    "for i in range(20):\n",
    "    for j in range(20):\n",
    "        corr_mat = np.load(os.path.join(affinity_dir, f'corr_{i}_{j}.npy'))\n",
    "        corr_mat = np.abs(corr_mat)\n",
    "        all_sum[i][j] = corr_mat.sum()\n",
    "all_sum = np.array(all_sum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7e08f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist(all_sum.reshape(-1), bins=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16172ad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "similarity = []\n",
    "error_corr = []\n",
    "for i in range(20):\n",
    "    for j in range(20):\n",
    "        if i == j: continue\n",
    "        error_corr.append(all_pair_error_correlation[i][j])\n",
    "        similarity.append(all_cost[i][j])\n",
    "similarity = np.array(similarity)\n",
    "error_corr = np.array(error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95516b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "reg = LinearRegression().fit(similarity[:, None], error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd3c6fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = int(np.min(similarity))\n",
    "b = int(np.max(similarity))+2\n",
    "plt.plot(np.arange(a, b), reg.predict(np.arange(a, b)[:, None]), linewidth=2.5, c='C0', linestyle='dashed')\n",
    "plt.scatter(similarity, error_corr, s=15, alpha=0.3)\n",
    "print(a,b)\n",
    "\n",
    "# plt.title('Feature frequency', fontsize=17)\n",
    "plt.xlabel('feature similarity', fontsize=16)\n",
    "plt.ylabel('shared error', fontsize=16)\n",
    "_ = plt.xticks(fontsize=15)\n",
    "_ = plt.yticks(fontsize=15)\n",
    "# plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "reg.score(similarity[:, None], error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76f0de22",
   "metadata": {},
   "outputs": [],
   "source": [
    "neuron_feature_pairing.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de815679",
   "metadata": {},
   "outputs": [],
   "source": [
    "overlap = []\n",
    "error_corr = []\n",
    "for i in range(20):\n",
    "    for j in range(20):\n",
    "        if i == j: continue\n",
    "        l = len(set(neuron_feature_pairing[i]).intersection(neuron_feature_pairing[j]))\n",
    "        error_corr.append(all_pair_error_correlation[i][j])\n",
    "        overlap.append(l)\n",
    "overlap = np.array(overlap)\n",
    "error_corr = np.array(error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d3f90fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "# root = './plots'\n",
    "# sns.set_style(\"whitegrid\")\n",
    "# sns.set_palette(\"tab10\")\n",
    "\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# # sns.histplot(\n",
    "# #     None, x=gt_confidence, y=all_image_feature_count,\n",
    "# #     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "# #     pthresh=.01, pmax=.7,\n",
    "# # )\n",
    "\n",
    "# plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "# plt.xlabel('confidence', fontsize=24)\n",
    "# plt.ylabel('number of features', fontsize=24)\n",
    "# _ = plt.xticks(fontsize=22)\n",
    "# _ = plt.yticks(fontsize=22)\n",
    "# # plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# # plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1bf5304",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "root = './plots'\n",
    "reg = LinearRegression().fit(overlap[:, None], error_corr)\n",
    "a = int(np.min(overlap))\n",
    "b = int(np.max(overlap))+2\n",
    "plt.figure(figsize=(8, 6))\n",
    "# plt.plot(np.arange(a, b), reg.predict(np.arange(a, b)[:, None]), linewidth=2.5, c='C0', linestyle='dashed')\n",
    "# sns.kdeplot(overlap, fill=True, alpha=.35, linewidth=0.0, color='gold')\n",
    "plt.scatter(overlap, error_corr, s=40, alpha=0.6)\n",
    "# ax = sns.boxplot(x=overlap, y=error_corr, data=None)\n",
    "print(a,b)\n",
    "\n",
    "plt.title(f'shared feature vs. shared error ({group}k)', fontsize=26)\n",
    "plt.xlabel('shared features', fontsize=24)\n",
    "plt.ylabel('shared error', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "# plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "reg.score(overlap[:, None], error_corr)\n",
    "plt.savefig(os.path.join(root, f'{group}_shared_feat_shared_err.pdf'), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f97aec13",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.diag(all_corr_mat[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a5d0b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.diag(all_corr_mat[0][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ad00a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde8f90e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c38528",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_files = sorted(os.listdir(base_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ea043c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db83462d",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = np.load(os.path.join(base_dir, 'Y.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b6afb73",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_representations = []\n",
    "all_logits = []\n",
    "all_name = []\n",
    "for f in all_files:\n",
    "    if f == 'Y.npy':\n",
    "        continue\n",
    "    if 'F' in f:\n",
    "        index = f.split('.')[0][1:]\n",
    "        feature_path = os.path.join(base_dir, f)\n",
    "        logits_path = os.path.join(base_dir, f'L{index}.npy')\n",
    "        feat = np.load(feature_path)\n",
    "        if np.isnan(feat.sum()) or np.abs(feat).mean() < 1e-6:\n",
    "            print(f'skipping {f}')\n",
    "            continue\n",
    "        all_representations.append(np.load(feature_path))\n",
    "        all_logits.append(np.load(logits_path))\n",
    "        all_name.append(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff83df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_pca_repr = []\n",
    "for name, logits, r in zip(all_name, all_logits, all_representations):\n",
    "    print(f'======={name}=======')\n",
    "    all_pca_repr.append(project_repr(r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf3f0893",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a5c87a9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12 (main, Mar 26 2022, 15:51:13) \n[Clang 12.0.0 (clang-1200.0.32.29)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
