{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b385c7e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "dir2 = os.path.abspath('')\n",
    "dir1 = os.path.dirname(dir2)\n",
    "if not dir1 in sys.path:\n",
    "    sys.path.append(dir1)\n",
    "import json\n",
    "import numpy as np\n",
    "from models.model_utils import get_model\n",
    "from data.data_utils import get_dataloader\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import time\n",
    "from collections import Counter\n",
    "from neuron_affinity import *\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5bf3c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report CUDA availability and set the experiment directory read by later cells.\n",
    "print(f'Cuda {torch.cuda.is_available()}')\n",
    "exp_dir = None  # path to the checkpoints; placeholder, replace with a real path before running on\n",
    "print(exp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffea073e",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir(exp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d084360",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir(os.path.join(exp_dir, 'model_15'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e9bca4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir(os.path.join(exp_dir, 'neuron_correlation_pca'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b716be9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir(os.path.join(exp_dir, 'neuron_correlation_pca', 'ep_1_'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8defc7ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count checkpoints per run. Only `model_*` directories are scanned so that analysis\n",
    "# outputs stored alongside them (e.g. neuron_correlation_pca/) cannot skew the counts,\n",
    "# and plain files in exp_dir no longer make os.listdir fail.\n",
    "ckpt_counts = []\n",
    "for m in sorted(os.listdir(exp_dir)):\n",
    "    model_dir = os.path.join(exp_dir, m)\n",
    "    if not (m.startswith('model_') and os.path.isdir(model_dir)):\n",
    "        continue\n",
    "    ckpt_counts.append(sum('ep' in n for n in os.listdir(model_dir)))\n",
    "# Both values fall back to 0 when no model directory is found.\n",
    "max_ckpt_n = max(ckpt_counts, default=0)\n",
    "min_ckpt_n = min(ckpt_counts, default=0)\n",
    "\n",
    "# Checkpoint names of one reference run, ordered by their first purely numeric '_'-separated token.\n",
    "ckpt_name = [n for n in os.listdir(os.path.join(exp_dir, 'model_15')) if 'ep' in n]\n",
    "ckpt_name = sorted(ckpt_name, key=lambda n: [int(s) for s in n.split('_') if s.isdigit()][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1470563",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check of the sort key used for checkpoint names: split on '_' and keep\n",
    "# only the purely numeric tokens. Note that this rebinds the notebook-level name `n`.\n",
    "n = 'ep_30_weight.pt'\n",
    "[int(s) for s in n.split('_') if s.isdigit()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed30a053",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the range of checkpoint counts and the checkpoint names ordered by epoch number.\n",
    "print(max_ckpt_n, min_ckpt_n)\n",
    "print(sorted(ckpt_name, key=lambda n: [int(s) for s in n.split('_') if s.isdigit()][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e79a3b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the data loaders from the JSON config stored with model_1 and keep the\n",
    "# second of the three loaders returned by get_dataloader as `test_loader`.\n",
    "model_1_path = os.path.join(exp_dir, 'model_1')\n",
    "config_path = os.path.join(model_1_path, 'config')\n",
    "print('Loading config...')\n",
    "with open(config_path, 'r') as f:\n",
    "    config = json.load(f)\n",
    "all_dataloaders = get_dataloader(config)\n",
    "_, test_loader, _ = all_dataloaders  # random subset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b609aeec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the whole test loader as numpy arrays (one pass over the loader).\n",
    "test_batches = [(x.data.numpy(), y.data.numpy()) for x, y in test_loader]\n",
    "# Move axis 1 to the end so the images can be passed to plt.imshow\n",
    "# (presumably channels-first -> channels-last; confirm against the dataset).\n",
    "all_data = np.concatenate([inputs for inputs, _ in test_batches], axis=0).transpose((0, 2, 3, 1))\n",
    "all_label = np.concatenate([labels for _, labels in test_batches], axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "758b6a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(all_data.shape)\n",
    "print(all_label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c2b0bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def project_repr(model_repr, n_component=50):\n",
    "    \"\"\"Standardise flattened representations and project them with PCA.\n",
    "\n",
    "    Args:\n",
    "        model_repr: array with one entry per example along the first axis;\n",
    "            all remaining axes are flattened into a single feature vector.\n",
    "        n_component: number of principal components to keep.\n",
    "\n",
    "    Returns:\n",
    "        The transformed data: one row per example, n_component columns.\n",
    "    \"\"\"\n",
    "    flat_repr = model_repr.reshape((len(model_repr), -1))\n",
    "    # Scale every feature before PCA; the pipeline is fitted and applied on the same data.\n",
    "    pipeline = make_pipeline(StandardScaler(), PCA(n_components=n_component))\n",
    "    pipeline.fit(flat_repr)\n",
    "    return pipeline.transform(flat_repr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b9f2c5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the cached PCA-projected representations.\n",
    "data_repr_path = os.path.join(exp_dir, 'data_repr_pca')\n",
    "try:\n",
    "    os.mkdir(data_repr_path)\n",
    "except FileExistsError:\n",
    "    # Only an already-existing directory is expected; any other failure\n",
    "    # (missing parent, permissions, ...) should surface instead of being reported as 'exists'.\n",
    "    print(f'{data_repr_path} exists')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "854f4581",
   "metadata": {},
   "outputs": [],
   "source": [
    "# n_model = 20\n",
    "# n_component = 50\n",
    "\n",
    "# tic = time.time()\n",
    "# per_ep_all_repr = []\n",
    "# n = 1\n",
    "# for ep_n in ckpt_name:\n",
    "#     per_model_repr = []\n",
    "#     save_name = os.path.join(data_repr_path, f'{ep_n}_per_model_repr.npy')\n",
    "#     if os.path.exists(save_name):\n",
    "#         print(f'{save_name} exists')\n",
    "#         continue\n",
    "#     for i in range(n_model):\n",
    "#         model_path = os.path.join(exp_dir, f'model_{i}')\n",
    "#         weight_path = os.path.join(model_path, ep_n)\n",
    "#         model = get_model(config).cuda()\n",
    "#         state_dict = torch.load(weight_path, map_location=torch.device('cpu'))\n",
    "#         model.load_state_dict(state_dict)\n",
    "#         model_repr = get_all_repr(model, test_loader, 'layer4')\n",
    "#         per_model_repr.append(project_repr(model_repr, n_component=n_component))\n",
    "#         print('='*60)\n",
    "#         print(f'{ep_n} {i} {(time.time()-tic)/(n)}  s / model')\n",
    "#         n += 1\n",
    "#     np.save(save_name, np.array(per_model_repr))\n",
    "#     per_ep_all_repr.append(per_model_repr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "847f2e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.listdir(data_repr_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0505b1e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every cached representation file, keyed by the integer that follows the first '_'\n",
    "# in its file name (the epoch number for names such as 'ep_<N>_...').\n",
    "repr_file_names = sorted(os.listdir(data_repr_path), key=lambda name: int(name.split('_')[1]))\n",
    "per_ep_all_repr = {}\n",
    "for n in repr_file_names:\n",
    "    print(n)\n",
    "    r = np.load(os.path.join(data_repr_path, n))\n",
    "    per_ep_all_repr[int(n.split('_')[1])] = r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baa09ba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch whose cached representations are analysed in the following cells.\n",
    "ep = 29"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe4db419",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest absolute value along the last axis, kept as a trailing length-1 axis for broadcasting.\n",
    "# Presumably the array is (model, example, component) -- TODO confirm against the cached files.\n",
    "max_per_example = np.max(np.abs(per_ep_all_repr[ep]), -1)[:,:,None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8606364e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale each vector along the last axis by its largest absolute value, so values lie in [-1, 1].\n",
    "# NOTE(review): an all-zero vector would divide by zero here.\n",
    "normalized_per_model_repr = per_ep_all_repr[ep] / max_per_example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14bdc420",
   "metadata": {},
   "outputs": [],
   "source": [
    "# how much each image (y-axis) activates the features of each model (x-axis)\n",
    "# Each row shows one randomly drawn image followed by one histogram per model of that\n",
    "# image's normalised activations.\n",
    "n = 10  # number of models shown (columns after the image)\n",
    "m = 20  # number of images shown (rows)\n",
    "fig, axs = plt.subplots(m, n+1, figsize=(18, 18))\n",
    "# Draw from the actual number of loaded images instead of a hard-coded 10000.\n",
    "for i, img_idx in enumerate(np.random.choice(len(all_data), m)):\n",
    "    axs[i, 0].imshow(all_data[img_idx])\n",
    "    for j in range(n):\n",
    "        axs[i, j + 1].set_xlim([-0.1, 1.])\n",
    "        axs[i, j + 1].set_ylim([0, 30])\n",
    "        # Index with img_idx (not the row counter i) so the histogram belongs to the image shown.\n",
    "        _ = axs[i, j + 1].hist(normalized_per_model_repr[j][img_idx], bins=20)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a46a4183",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One activation threshold for all entries: the given percentile of every normalised value.\n",
    "global_percentile = 90\n",
    "global_threshold = np.percentile(normalized_per_model_repr, global_percentile)\n",
    "print(global_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f87228f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "normalized_per_model_repr.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e3f7ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions of all values above the global threshold; np.where returns one index array per axis.\n",
    "activated_neuron_idx = np.where(normalized_per_model_repr > global_threshold)\n",
    "activated_neuron_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12bd01b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of all entries that exceed the threshold.\n",
    "# (The `activatd_` spelling is kept as-is because the name is a notebook-level variable.)\n",
    "activatd_neuron_count = activated_neuron_idx[0].shape[0]\n",
    "all_neuron_n = np.prod(normalized_per_model_repr.shape)\n",
    "percent_activated = activatd_neuron_count/all_neuron_n\n",
    "print(activatd_neuron_count, all_neuron_n, percent_activated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57107143",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the neuron -> feature pairing saved for the selected epoch.\n",
    "neuron_feature_pairing_path = os.path.join(\n",
    "    exp_dir, 'neuron_correlation_pca', f'ep_{ep}_', 'neuron_feature_pairing_p80.npy'\n",
    ")\n",
    "neuron_feature_pairing = np.load(neuron_feature_pairing_path)\n",
    "print(neuron_feature_pairing.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04fcd3c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): activated_neuron_idx is a tuple of index arrays (one per axis), so this is the\n",
    "# maximum over the indices of ALL axes together, not of one particular axis.\n",
    "np.max(activated_neuron_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d32fc70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build, for every image, the set of feature ids its above-threshold neurons are paired with.\n",
    "# The image count is the length of axis 1 (the axis used as image index below). Taking\n",
    "# np.max over the whole index tuple mixed all three axes and under-counted whenever the\n",
    "# last images had no activation above the threshold.\n",
    "n_image = normalized_per_model_repr.shape[1]\n",
    "image_feature = [[] for i in range(n_image)]\n",
    "# i: image index (axis 1), j: model index (axis 0), k: component index (axis 2).\n",
    "for i, j, k in zip(activated_neuron_idx[1], activated_neuron_idx[0], activated_neuron_idx[2]):\n",
    "    # Ids <= 0 are skipped -- presumably they mark neurons without a paired feature; TODO confirm.\n",
    "    if neuron_feature_pairing[j, k] > 0:\n",
    "        image_feature[i].append(neuron_feature_pairing[j, k])\n",
    "image_feature = [set(f) for f in image_feature]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c8abf00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct features per image, in image order.\n",
    "all_image_feature_count = list(map(len, image_feature))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31bc82b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist(all_image_feature_count, bins=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0da8fabf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5th and 95th percentile of the per-image feature counts, used below to pick the extremes.\n",
    "feature_count_percentile = np.percentile(all_image_feature_count, [5, 95])\n",
    "feature_count_percentile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09cab6ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image indices strictly below the 5th / strictly above the 95th percentile of the feature count.\n",
    "small_fc_idx = np.flatnonzero(np.asarray(all_image_feature_count) < feature_count_percentile[0])\n",
    "large_fc_idx = np.flatnonzero(np.asarray(all_image_feature_count) > feature_count_percentile[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee36ed84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order the selected images by feature count: ascending for the small group,\n",
    "# descending (reversed ascending sort) for the large group.\n",
    "small_fc_idx = sorted(small_fc_idx, key=all_image_feature_count.__getitem__)\n",
    "large_fc_idx = sorted(large_fc_idx, key=all_image_feature_count.__getitem__)[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f888ed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of the images with the fewest features, titled with their feature count (nf).\n",
    "n = 10  # columns\n",
    "m = 10  # rows\n",
    "fig, axs = plt.subplots(m, n, figsize=(18, 18))\n",
    "for i in range(m):\n",
    "    for j in range(n):\n",
    "        axs[i, j].axis('off')\n",
    "        # Row-major position in the grid; the stride is the number of columns.\n",
    "        rank = i * n + j\n",
    "        if rank >= len(small_fc_idx):\n",
    "            # Fewer selected images than grid cells: leave the remaining axes empty.\n",
    "            continue\n",
    "        img_idx = small_fc_idx[rank]\n",
    "        axs[i, j].imshow(all_data[img_idx])\n",
    "        axs[i, j].set_title(f'nf {all_image_feature_count[img_idx]}')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be36bc5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ff4b296",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the image with the most features, and that feature count.\n",
    "print(np.argmax(all_image_feature_count), np.max(all_image_feature_count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4035052",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of the images with the most features, titled with their feature count (nf).\n",
    "n = 10  # columns\n",
    "m = 10  # rows\n",
    "fig, axs = plt.subplots(m, n, figsize=(18, 18))\n",
    "for i in range(m):\n",
    "    for j in range(n):\n",
    "        axs[i, j].axis('off')\n",
    "        # Row-major position in the grid; the stride is the number of columns.\n",
    "        rank = i * n + j\n",
    "        if rank >= len(large_fc_idx):\n",
    "            # Fewer selected images than grid cells: leave the remaining axes empty.\n",
    "            continue\n",
    "        img_idx = large_fc_idx[rank]\n",
    "        axs[i, j].imshow(all_data[img_idx])\n",
    "        axs[i, j].set_title(f'nf {all_image_feature_count[img_idx]}')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3bcfc44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# len(class_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f7a171",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per class: the images of that class with the fewest features.\n",
    "n_data_per_class = 10\n",
    "n_class = 10\n",
    "# Rows are classes and columns are images, matching the axs[k, i] indexing below.\n",
    "fig, axs = plt.subplots(n_class, n_data_per_class, figsize=(18, 18))\n",
    "for ax in axs.ravel():\n",
    "    ax.axis('off')\n",
    "for k in range(n_class):\n",
    "    class_idx = np.where(all_label == k)[0]\n",
    "    smallest_idx_class_k = sorted(class_idx, key=lambda idx: all_image_feature_count[idx])\n",
    "    # Slicing tolerates classes with fewer than n_data_per_class images.\n",
    "    for i, img_idx in enumerate(smallest_idx_class_k[:n_data_per_class]):\n",
    "        axs[k, i].imshow(all_data[img_idx])\n",
    "        axs[k, i].set_title(f'nf {all_image_feature_count[img_idx]}')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64adbeae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each of the 10 classes, stack the n_per_row images with the fewest features.\n",
    "# np.stack requires every class to contribute the same number of images.\n",
    "n_per_row = 15\n",
    "class_images = []\n",
    "for k in range(10):\n",
    "    class_idx = np.where(all_label == k)[0]\n",
    "    smallest_idx_class_k = sorted(class_idx, key=lambda k: all_image_feature_count[k])\n",
    "    class_images.append(all_data[smallest_idx_class_k[:n_per_row]])\n",
    "class_images = np.stack(class_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77ae329b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mosaic: one horizontal strip of 8 images per selected class, with a 2-pixel band of\n",
    "# ones after every strip.\n",
    "print(class_images.shape)\n",
    "fig = plt.figure(figsize=(20, 40))\n",
    "selected_classes = [0, 1, 3, 5, 7]\n",
    "all_class_image = []\n",
    "for k in selected_classes:\n",
    "    # Join the first 8 images of the class side by side (along the width axis).\n",
    "    all_image_k = np.concatenate(class_images[k][:8], axis=1)\n",
    "    separator = np.ones((2,) + all_image_k.shape[1:])\n",
    "    all_class_image.extend([all_image_k, separator])\n",
    "all_class_image = np.concatenate(all_class_image, axis=0)\n",
    "plt.axis('off')\n",
    "plt.imshow(all_class_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcf0edb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per class: the images of that class with the most features.\n",
    "n_data_per_class = 10\n",
    "n_class = 10\n",
    "# Rows are classes and columns are images, matching the axs[k, i] indexing below.\n",
    "fig, axs = plt.subplots(n_class, n_data_per_class, figsize=(18, 18))\n",
    "for ax in axs.ravel():\n",
    "    ax.axis('off')\n",
    "for k in range(n_class):\n",
    "    class_idx = np.where(all_label == k)[0]\n",
    "    # Reversed ascending sort, i.e. largest feature count first.\n",
    "    largest_idx_class_k = sorted(class_idx, key=lambda idx: all_image_feature_count[idx])[::-1]\n",
    "    # Slicing tolerates classes with fewer than n_data_per_class images.\n",
    "    for i, img_idx in enumerate(largest_idx_class_k[:n_data_per_class]):\n",
    "        axs[k, i].imshow(all_data[img_idx])\n",
    "        axs[k, i].set_title(f'nf {all_image_feature_count[img_idx]}')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a71a123e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each of the 10 classes, stack the n_per_row images with the most features ...\n",
    "n_per_row = 15\n",
    "class_images = []\n",
    "for k in range(10):\n",
    "    class_idx = np.where(all_label == k)[0]\n",
    "    largest_idx_class_k = sorted(class_idx, key=all_image_feature_count.__getitem__)[::-1]\n",
    "    class_images.append(all_data[largest_idx_class_k[:n_per_row]])\n",
    "class_images = np.stack(class_images)\n",
    "\n",
    "# ... and draw the mosaic for `selected_classes` (defined in an earlier cell): one strip of\n",
    "# 8 images per class, with a 2-pixel band of ones after every strip.\n",
    "print(class_images.shape)\n",
    "fig = plt.figure(figsize=(20, 40))\n",
    "all_class_image = []\n",
    "for k in selected_classes:\n",
    "    all_image_k = np.concatenate(class_images[k][:8], axis=1)\n",
    "    separator = np.ones((2,) + all_image_k.shape[1:])\n",
    "    all_class_image.extend([all_image_k, separator])\n",
    "all_class_image = np.concatenate(all_class_image, axis=0)\n",
    "plt.axis('off')\n",
    "plt.imshow(all_class_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "155b0fbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "\n",
    "# Hashable per-image record: dataset index, label, feature ids (as a tuple) and their count.\n",
    "Image = namedtuple('Image', ['dataset_idx', 'label', 'feature', 'feature_count'])\n",
    "all_image_tuple = [\n",
    "    Image(dataset_idx=i, label=all_label[i], feature=tuple(features), feature_count=len(features))\n",
    "    for i, features in enumerate(image_feature)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dcb052a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# len(all_image_tuple)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7d576e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def feature_similarity(feature_1, feature_2):\n",
    "    inter = set(feature_1).intersection(set(feature_2))\n",
    "    avg = (len(feature_1) + len(feature_2))/2\n",
    "    return len(inter) / avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f854c66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For 100 randomly drawn query images, score the similarity to every image in the dataset.\n",
    "# Each entry of all_similarity maps an Image record to its similarity with the query.\n",
    "all_similarity = []\n",
    "# Sample from the actual number of images instead of a hard-coded 10000.\n",
    "for i in np.random.randint(len(all_image_tuple), size=100):\n",
    "    t_1 = all_image_tuple[i]\n",
    "    single_similarity = {\n",
    "        t_2: feature_similarity(t_1.feature, t_2.feature) for t_2 in all_image_tuple\n",
    "    }\n",
    "    all_similarity.append(single_similarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68774772",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per query (entries 50..59 of all_similarity): its 10 highest-scoring images,\n",
    "# best first; the first title of every row is set in bold.\n",
    "n = 10\n",
    "fig, axs = plt.subplots(n, 10, figsize=(18, 18 * n / 10))\n",
    "\n",
    "for i in range(n):\n",
    "    all_neighbors_dict = all_similarity[i+50]\n",
    "    # Reversed ascending sort, i.e. highest similarity first.\n",
    "    all_neighbors = sorted(all_neighbors_dict, key=all_neighbors_dict.get)[::-1]\n",
    "    for j, t in enumerate(all_neighbors[:10]):\n",
    "        bold = None if j != 0 else 'bold'\n",
    "        axs[i, j].imshow(all_data[t.dataset_idx])\n",
    "        axs[i, j].set_title(f's = {all_neighbors_dict[t]:.4f}', fontsize=13,  fontweight=bold)\n",
    "        axs[i, j].axis('off')\n",
    "\n",
    "fig.suptitle('Image Similarity vis Feature', fontsize=15)\n",
    "fig.subplots_adjust(top=0.95)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb5f53b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.listdir(os.path.join(exp_dir, 'neuron_correlation_pca'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e5eaa6f",
   "metadata": {},
   "source": [
    "## Feature density vs confidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b23a6bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the confidences saved for the selected epoch.\n",
    "# Presumably one ground-truth-class confidence per test image -- confirm against the producer.\n",
    "# gt_confidence = np.load(os.path.join(exp_dir, 'neuron_correlation', 'gt_confidence.npy'))\n",
    "gt_confidence = np.load(os.path.join(exp_dir, 'neuron_correlation_pca', f'ep_{ep}_', 'gt_confidence.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d529959c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw scatter of confidence (x) against per-image feature count (y).\n",
    "plt.scatter(gt_confidence, all_image_feature_count, s=0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75985b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D density of confidence vs. per-image feature count, saved as a PDF.\n",
    "root = './plots'\n",
    "# Make sure the output directory exists, otherwise savefig fails at the very end.\n",
    "os.makedirs(root, exist_ok=True)\n",
    "sns.set_style('whitegrid')\n",
    "sns.set_palette('tab10')\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "# x/y are passed by keyword: a positional second argument is deprecated in seaborn 0.11\n",
    "# and rejected by later releases.\n",
    "sns.kdeplot(x=gt_confidence, y=all_image_feature_count, fill=True, cbar=False, color='aliceblue')\n",
    "\n",
    "plt.title('conf vs feature num', fontsize=26)\n",
    "plt.xlabel('confidence', fontsize=24)\n",
    "plt.ylabel('features num', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "plt.savefig(os.path.join(root, 'svhn_conf_feat_count.pdf'), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "938aeb5b",
   "metadata": {},
   "source": [
    "# Whole trajectory confidence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0613e3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_single_ep_feature_count(ep, global_percentile=90, image_n=10000):\n",
    "    \"\"\"Compute per-image feature counts for one epoch and load its saved confidences.\n",
    "\n",
    "    Reads the notebook-level `per_ep_all_repr` and `exp_dir`, and loads\n",
    "    'neuron_feature_pairing_p80.npy' and 'gt_confidence.npy' from the epoch's\n",
    "    'neuron_correlation_pca/ep_<ep>_' directory.\n",
    "\n",
    "    Args:\n",
    "        ep: key into per_ep_all_repr; also used to build the directory name.\n",
    "        global_percentile: percentile of all normalised values used as activation threshold.\n",
    "        image_n: number of per-image slots to allocate; must exceed the largest image index.\n",
    "\n",
    "    Returns:\n",
    "        (gt_confidence, counts): the loaded confidence array and a list with the number\n",
    "        of distinct positive feature ids of every image.\n",
    "    \"\"\"\n",
    "    ep_repr = per_ep_all_repr[ep]\n",
    "    # Scale every vector along the last axis by its largest absolute value.\n",
    "    normalized = ep_repr / np.max(np.abs(ep_repr), -1)[:, :, None]\n",
    "    threshold = np.percentile(normalized, [global_percentile])[0]\n",
    "    activated = np.where(normalized > threshold)\n",
    "\n",
    "    ep_dir = os.path.join(exp_dir, 'neuron_correlation_pca', f'ep_{ep}_')\n",
    "    pairing = np.load(os.path.join(ep_dir, 'neuron_feature_pairing_p80') + '.npy')\n",
    "\n",
    "    features_per_image = [set() for _ in range(image_n)]\n",
    "    # Axis 1 indexes the image, axis 0 the model and axis 2 the component.\n",
    "    for image_id, model_id, unit_id in zip(activated[1], activated[0], activated[2]):\n",
    "        feature_id = pairing[model_id, unit_id]\n",
    "        if feature_id > 0:\n",
    "            features_per_image[image_id].add(feature_id)\n",
    "\n",
    "    counts = [len(features) for features in features_per_image]\n",
    "    confidence = np.load(os.path.join(ep_dir, 'gt_confidence.npy'))\n",
    "    return confidence, counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66dceb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect confidences and per-image feature counts for epochs 0..29.\n",
    "# NOTE(review): the loop variable rebinds the notebook-level `ep`; it ends at 29.\n",
    "ep_confidence, ep_feature_count = {}, {}\n",
    "for ep in np.arange(30):\n",
    "    conf, f_count = get_single_ep_feature_count(ep, image_n=n_image)\n",
    "    ep_confidence[ep] = conf\n",
    "    ep_feature_count[ep] = f_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f654ed26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One density plot (confidence vs. feature count) per epoch on a 5 x 6 grid.\n",
    "n = 6  # columns\n",
    "m = 5  # rows\n",
    "fig, axs = plt.subplots(m, n, figsize=(18, 18))\n",
    "sns.set_style('whitegrid')\n",
    "sns.set_palette('tab10')\n",
    "for i in range(m):\n",
    "    for j in range(n):\n",
    "        # Row-major epoch index: the stride is the number of columns. Using the row count\n",
    "        # instead repeated epochs 5/10/15/20 and never showed epochs 25-29.\n",
    "        idx = i * n + j\n",
    "        conf, f_count = ep_confidence[idx], ep_feature_count[idx]\n",
    "        # x/y by keyword: a positional second argument is deprecated in seaborn 0.11\n",
    "        # and rejected by later releases.\n",
    "        sns.kdeplot(x=conf, y=f_count, fill=True, cbar=False, color='aliceblue', ax=axs[i, j])\n",
    "        axs[i, j].set_title(f'{idx}')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e70d12c9",
   "metadata": {},
   "source": [
    "\n",
    "# Class conditioned analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c8ec58d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class label inspected in the next cells.\n",
    "# NOTE(review): this rebinds `class_idx`, which earlier cells used for arrays of image indices.\n",
    "class_idx = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41034883",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of all images whose label equals the selected class.\n",
    "image_idx = np.where(all_label == class_idx)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a09639",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show one example of the class (the second match, index 1).\n",
    "plt.imshow(all_data[image_idx[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d6a4e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature sets and feature counts restricted to the images of the selected class.\n",
    "filtered_image_features = np.array(image_feature)[image_idx]\n",
    "filtered_image_feature_count = [len(f) for f in filtered_image_features]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ec36c08",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist(filtered_image_feature_count, bins=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6ea51c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Density of confidence vs. feature count for the selected class only.\n",
    "# x/y by keyword: a positional second argument is deprecated in seaborn 0.11\n",
    "# and rejected by later releases.\n",
    "sns.kdeplot(x=gt_confidence[image_idx], y=filtered_image_feature_count, fill=True, cbar=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b77b5035",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per class: an example image, the histogram of feature counts, and the\n",
    "# confidence vs. feature-count density.\n",
    "fig, axs = plt.subplots(10, 3, figsize=(10, 18))\n",
    "for i in range(10):\n",
    "    image_idx = np.where(all_label == i)[0]\n",
    "    filtered_image_features = np.array(image_feature)[image_idx]\n",
    "    filtered_image_feature_count = [len(f) for f in filtered_image_features]\n",
    "    axs[i, 0].imshow(all_data[image_idx[1]])\n",
    "    axs[i, 1].hist(filtered_image_feature_count, bins=20)\n",
    "    # x/y by keyword: a positional second argument is deprecated in seaborn 0.11\n",
    "    # and rejected by later releases.\n",
    "    sns.kdeplot(x=gt_confidence[image_idx], y=filtered_image_feature_count, fill=True, cbar=True, ax=axs[i, 2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d32f993b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the per-image feature sets from the notebook-level activation indices and pairing.\n",
    "# The image count is the length of axis 1 (the axis used as image index below). Taking\n",
    "# np.max over the whole index tuple mixed all three axes and under-counted whenever the\n",
    "# last images had no activation above the threshold.\n",
    "n_image = normalized_per_model_repr.shape[1]\n",
    "image_feature = [[] for i in range(n_image)]\n",
    "# i: image index (axis 1), j: model index (axis 0), k: component index (axis 2).\n",
    "for i, j, k in zip(activated_neuron_idx[1], activated_neuron_idx[0], activated_neuron_idx[2]):\n",
    "    # Ids <= 0 are skipped -- presumably they mark neurons without a paired feature; TODO confirm.\n",
    "    if neuron_feature_pairing[j, k] > 0:\n",
    "        image_feature[i].append(neuron_feature_pairing[j, k])\n",
    "image_feature = [set(f) for f in image_feature]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97bd0a6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every feature id, collect the labels of the images that contain it\n",
    "# (one entry per image, so each list length is the number of images with that feature).\n",
    "feature_to_class = [[] for _ in range(np.max(neuron_feature_pairing)+1)]\n",
    "for i in range(n_image):\n",
    "    for f in list(image_feature[i]):\n",
    "        feature_to_class[f].append(all_label[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfdad364",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the variable-length label lists in a 1-D object array. Calling np.array on a\n",
    "# ragged list raises ValueError on NumPy >= 1.24, so the array is filled element by element.\n",
    "labels_per_feature = list(feature_to_class)\n",
    "feature_to_class = np.empty(len(labels_per_feature), dtype=object)\n",
    "for feature_id, labels in enumerate(labels_per_feature):\n",
    "    feature_to_class[feature_id] = labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "847198b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label histograms for feature ids 1 .. m*n, laid out row by row.\n",
    "# Each title shows the feature id (f) and the number of images holding it (nd).\n",
    "n = 8\n",
    "m = 20\n",
    "fig, axs = plt.subplots(m, n, figsize=(18, 2 * m))\n",
    "for feature_id, ax in enumerate(axs.ravel(), start=1):\n",
    "    fc = feature_to_class[feature_id]\n",
    "    ax.hist(fc, bins=20)\n",
    "    ax.set_ylim([0, 1000])\n",
    "    ax.set_title(f'f {feature_id} nd {len(fc)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb4a2662",
   "metadata": {},
   "source": [
    "## Avg number of features per data point vs number of data points with that feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4236704c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest feature id that occurs in any image. Images with an empty feature set are\n",
    "# skipped (max() of an empty set raises ValueError); 0 is used when no image has a feature.\n",
    "n_features = max((max(f) for f in image_feature if f), default=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1daf56e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "990797f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature accumulators, indexed by feature id.\n",
    "# The 1e-6 start value keeps the later division finite for features that no image contains.\n",
    "n_data_for_feature = [1e-6 for _ in range(n_features+1)]\n",
    "avg_feature_per_data = [0 for _ in range(n_features+1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5448fb66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list(image_feature[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f98ad8fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each feature: count the images containing it and sum the feature-set sizes of those images.\n",
    "for i_f in image_feature:\n",
    "    i_f = list(i_f)\n",
    "    for f in i_f:\n",
    "        n_data_for_feature[f] += 1\n",
    "        avg_feature_per_data[f] += len(i_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0400d0ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the summed set sizes into a mean per feature (index 0 is skipped).\n",
    "# NOTE(review): the list is updated in place, so re-running this cell divides again.\n",
    "for i in range(1, len(avg_feature_per_data)):\n",
    "#     print(i, avg_feature_per_data[i], n_data_for_feature[i])\n",
    "    avg_feature_per_data[i] = avg_feature_per_data[i] / n_data_for_feature[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c3a7233",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ecd3026",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One point per feature id: number of images with the feature (x) against the mean\n",
    "# feature-set size of those images (y).\n",
    "plt.scatter(n_data_for_feature, avg_feature_per_data, s=0.5)\n",
    "# sns.kdeplot(n_data_for_feature, avg_feature_per_data, fill=True, cbar=True)\n",
    "plt.ylabel('avg feat per data')\n",
    "plt.xlabel('n data for feat')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f9962b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.arange(len(n_data_for_feature)), n_data_for_feature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da00ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same per-feature statistics as above, computed separately for each of the 10 classes.\n",
    "# One row per class: an example image and the scatter of image count vs. mean set size.\n",
    "# NOTE(review): the notebook-level n_data_for_feature / avg_feature_per_data are rebound\n",
    "# on every iteration and afterwards hold the values of the last class only.\n",
    "fig, axs = plt.subplots(10, 2, figsize=(10, 18))\n",
    "for i in range(10):\n",
    "    image_idx = np.where(all_label == i)[0]\n",
    "    filtered_image_features = np.array(image_feature)[image_idx]\n",
    "    n_data_for_feature = [1e-6 for _ in range(n_features+1)]\n",
    "    avg_feature_per_data = [0 for _ in range(n_features+1)]\n",
    "    for i_f in filtered_image_features:\n",
    "        i_f = list(i_f)\n",
    "        for f in i_f:\n",
    "            n_data_for_feature[f] += 1\n",
    "            avg_feature_per_data[f] += len(i_f)\n",
    "    for j in range(1, len(avg_feature_per_data)):\n",
    "        avg_feature_per_data[j] = avg_feature_per_data[j] / n_data_for_feature[j]\n",
    "    axs[i, 0].imshow(all_data[image_idx[1]])\n",
    "    axs[i, 1].scatter(n_data_for_feature, avg_feature_per_data, s=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad9a4571",
   "metadata": {},
   "source": [
    "## Model vs Feature distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ba181d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display neuron_feature_pairing for inspection.\n",
    "neuron_feature_pairing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6cb4d42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature ids run from 0 up to the largest id in the pairing, hence max + 1.\n",
    "n_feature = np.max(neuron_feature_pairing) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17ed14c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary feature-by-model membership: entry [f][m] is 1 when feature id f occurs in row m of\n",
    "# neuron_feature_pairing (rows are presumably one per model -- TODO confirm).\n",
    "# The width follows the pairing instead of a fixed 20, which raised IndexError for more than 20 rows.\n",
    "n_paired_models = len(neuron_feature_pairing)\n",
    "feature_to_model = [[0] * n_paired_models for _ in range(n_feature)]\n",
    "for model_idx, paired_features in enumerate(neuron_feature_pairing):\n",
    "    for feature_id in paired_features:\n",
    "        feature_to_model[feature_id][model_idx] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf23d0b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Most widely shared features first; ties keep their current relative order (stable sort).\n",
    "feature_to_model = sorted(feature_to_model, key=sum, reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d93c9a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_to_model = np.array(feature_to_model)\n",
    "print(feature_to_model)\n",
    "print(feature_to_model.shape)\n",
    "\n",
    "# Shuffle the rows first so features with equal model counts land in random relative order,\n",
    "# then stably re-sort by descending model count. The shuffle uses the unseeded global RNG.\n",
    "np.random.shuffle(feature_to_model)\n",
    "descending_order = np.argsort(-feature_to_model.sum(axis=1), kind='stable')\n",
    "feature_to_model = feature_to_model[descending_order]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cc3e55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Membership image: rows are the model columns of feature_to_model, columns its first 400 features.\n",
    "fig, axs = plt.subplots(1, 1, figsize=(20, 18))\n",
    "plt.imshow(feature_to_model.T[:,:400])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "674daa63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, axs = plt.subplots(1, 1, figsize=(20, 18))\n",
    "# shuffle_model_idx = np.arange(20)\n",
    "# plt.imshow(feature_to_model.T[shuffle_model_idx,:400])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "892f0cd1",
   "metadata": {},
   "source": [
    "## Using feature-class signature to do prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1edb4954",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_feature_to_signature(feature, n_class=10):\n",
    "    count = Counter(feature)\n",
    "    signature = [count[i] for i in range(n_class)]\n",
    "    return np.array(signature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1573f8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One class-count signature per entry of feature_to_class\n",
    "# (presumably the labels of the images containing each feature -- TODO confirm).\n",
    "feature_class_signature = [convert_feature_to_signature(f) for f in feature_to_class]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1588634",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total count of each signature, summed over all classes.\n",
    "total_data_per_feature = [sum(s) for s in feature_class_signature]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec5195ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "# root = './plots'\n",
    "# sns.set_style(\"whitegrid\")\n",
    "# sns.set_palette(\"tab10\")\n",
    "\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# # sns.histplot(\n",
    "# #     None, x=gt_confidence, y=all_image_feature_count,\n",
    "# #     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "# #     pthresh=.01, pmax=.7,\n",
    "# # )\n",
    "\n",
    "# plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "# plt.xlabel('confidence', fontsize=24)\n",
    "# plt.ylabel('number of features', fontsize=24)\n",
    "# _ = plt.xticks(fontsize=22)\n",
    "# _ = plt.yticks(fontsize=22)\n",
    "# # plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# # plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7adc9f8",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Sorted feature-occurrence curve; the [1:] slices drop the smallest entry.\n",
    "plt.figure(figsize=(8, 6))\n",
    "feature_rank = np.arange(len(total_data_per_feature))\n",
    "sorted_totals = sorted(total_data_per_feature)\n",
    "plt.plot(feature_rank[1:], sorted_totals[1:], linewidth=4.,)\n",
    "plt.title('feature frequency', fontsize=26)\n",
    "plt.xlabel('feature (1e2)', fontsize=24)\n",
    "plt.ylabel('occurrences (1e3)', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# NOTE(review): `root` is not defined in this cell; it must already exist in the kernel -- confirm.\n",
    "plt.savefig(os.path.join(root, 'svhn_feat_freq.pdf'), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "141b8b0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smooth the counts with a small constant, then scale each row to sum to 1.\n",
    "# Re-running this cell adds the constant again because feature_class_signature is overwritten.\n",
    "feature_class_signature = np.array(feature_class_signature) + 1e-6\n",
    "row_totals = feature_class_signature.sum(axis=1, keepdims=True)\n",
    "normalized_feature_class_signature = feature_class_signature / row_totals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77fa9660",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_v1(data_features, all_feature_signatures):\n",
    "    energy = 0\n",
    "    for f in data_features:\n",
    "        energy += all_feature_signatures[f]\n",
    "    return np.argmax(energy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00bee5a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick one data point (index 100) and take its feature set as a spot check.\n",
    "idx = 100\n",
    "test_image = image_feature[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97fb3beb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction for the spot-check point; note this passes feature_class_signature,\n",
    "# not normalized_feature_class_signature.\n",
    "predict_v1(test_image, feature_class_signature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff60a471",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth label of the spot-check point, for comparison.\n",
    "all_label[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a05b55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count predict_v1 hits over the images whose true label is in `classes`.\n",
    "classes = [4, 3]\n",
    "correct, n_data = 0, 0\n",
    "for i, features in enumerate(image_feature):\n",
    "    label = all_label[i]\n",
    "    if label not in classes:\n",
    "        continue\n",
    "    prediction = predict_v1(features, feature_class_signature)\n",
    "    correct += int(prediction == label)\n",
    "    n_data += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86124b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy on the selected classes; raises ZeroDivisionError if no image matched.\n",
    "correct / n_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d00da314",
   "metadata": {},
   "source": [
    "## Get feature frequency for whole epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfb786f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_feature_to_signature(feature, n_class=10):\n",
    "    count = Counter(feature)\n",
    "    signature = [count[i] for i in range(n_class)]\n",
    "    return np.array(signature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09003fa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def feature_freq_ep(ep, global_percentile=90, image_n=10000):\n",
    "    \"\"\"Compute per-feature class signatures and occurrence totals for epoch ``ep``.\n",
    "\n",
    "    Reads the globals ``per_ep_all_repr``, ``exp_dir`` and ``all_label`` and loads that epoch's\n",
    "    ``neuron_feature_pairing_p80.npy`` file from disk.\n",
    "\n",
    "    Args:\n",
    "        ep: key into ``per_ep_all_repr``; also used in the pairing file path.\n",
    "        global_percentile: percentile of all normalised activations used as the activation threshold.\n",
    "        image_n: number of images; every index along axis 1 of the activations must be below it.\n",
    "\n",
    "    Returns:\n",
    "        ``(feature_class_signature, total_data_per_feature)``: one class-count array per feature id\n",
    "        and the sum of each of those arrays.\n",
    "    \"\"\"\n",
    "    ep_repr = per_ep_all_repr[ep]\n",
    "    # Scale every activation vector (last axis) by its largest absolute value.\n",
    "    peak = np.abs(ep_repr).max(axis=-1)[:, :, None]\n",
    "    scaled_repr = ep_repr / peak\n",
    "    threshold = np.percentile(scaled_repr, [global_percentile])[0]\n",
    "    # Axis 1 indexes images; axes 0 and 2 index the rows / columns of neuron_feature_pairing.\n",
    "    row_idx, image_idx, col_idx = np.where(scaled_repr > threshold)\n",
    "\n",
    "    # NOTE(review): the file name is fixed to 'p80' whatever global_percentile is -- confirm this is intended.\n",
    "    pairing_dir = os.path.join(exp_dir, 'neuron_correlation_pca', f'ep_{ep}_')\n",
    "    neuron_feature_pairing = np.load(os.path.join(pairing_dir, 'neuron_feature_pairing_p80') + '.npy')\n",
    "\n",
    "    # Distinct positive feature ids seen in each image; ids that are not positive are dropped.\n",
    "    image_feature = [set() for _ in range(image_n)]\n",
    "    for img, row, col in zip(image_idx, row_idx, col_idx):\n",
    "        feature_id = neuron_feature_pairing[row, col]\n",
    "        if feature_id > 0:\n",
    "            image_feature[img].add(feature_id)\n",
    "\n",
    "    # For every feature id, the labels of the images that contain it.\n",
    "    feature_to_class = [[] for _ in range(np.max(neuron_feature_pairing) + 1)]\n",
    "    for img in range(image_n):\n",
    "        for feature_id in image_feature[img]:\n",
    "            feature_to_class[feature_id].append(all_label[img])\n",
    "\n",
    "    feature_class_signature = [convert_feature_to_signature(labels) for labels in feature_to_class]\n",
    "    total_data_per_feature = [sum(signature) for signature in feature_class_signature]\n",
    "    return feature_class_signature, total_data_per_feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90bd07e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the config stored with model_1, build its dataloaders and collect the test loader into arrays.\n",
    "model_1_path = os.path.join(exp_dir, 'model_1')\n",
    "config_path = os.path.join(model_1_path, 'config')\n",
    "print('Loading config...')\n",
    "with open(config_path, 'r') as f:\n",
    "    config = json.load(f)\n",
    "all_dataloaders = get_dataloader(config)\n",
    "_, test_loader, _ = all_dataloaders  # random subset\n",
    "data_batches, label_batches = [], []\n",
    "for x, y in test_loader:\n",
    "    data_batches.append(x.data.numpy())\n",
    "    label_batches.append(y.data.numpy())\n",
    "# Move axis 1 to the end; presumably (N, C, H, W) to (N, H, W, C) for imshow -- TODO confirm.\n",
    "all_data = np.concatenate(data_batches, axis=0).transpose((0, 2, 3, 1))\n",
    "all_label = np.concatenate(label_batches, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f381cbc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0433d58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run feature_freq_ep for epochs 0-29 and keep both results keyed by epoch.\n",
    "# The loop also rebinds the globals feature_class_signature / total_data_per_feature.\n",
    "ep_feature_class_signature, ep_total_data_per_feature = {}, {}\n",
    "tic = time.time()\n",
    "for ep in np.arange(30):\n",
    "    feature_class_signature, total_data_per_feature = feature_freq_ep(ep, image_n=n_image)\n",
    "    ep_feature_class_signature[ep] = feature_class_signature\n",
    "    ep_total_data_per_feature[ep] = total_data_per_feature\n",
    "    mean_seconds = (time.time() - tic) / (ep + 1)  # running average per epoch\n",
    "    print(f'ep {ep}  time {mean_seconds}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5263d7f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted log10 feature-occurrence curves for the initial epoch and every 4th epoch after it.\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "cut_off = 30  # the `cut_off` smallest entries of every sorted curve are not drawn\n",
    "plt.figure(figsize=(8, 6))\n",
    "eps = np.arange(1, 30, 4)\n",
    "colors = plt.cm.viridis(np.linspace(0, 1, len(eps)))\n",
    "\n",
    "# (epoch, legend label, colour) per curve; epoch 0 is drawn in red and labelled 'Init'.\n",
    "curves = [(0, 'Init', 'r')] + [(ep, f'ep {ep}', colors[i]) for i, ep in enumerate(eps)]\n",
    "for ep, curve_label, curve_color in curves:\n",
    "    # These two names stay bound to the last plotted epoch after the loop, exactly as before;\n",
    "    # cells further down read total_data_per_feature.\n",
    "    feature_class_signature, total_data_per_feature = ep_feature_class_signature[ep], ep_total_data_per_feature[ep]\n",
    "    log_counts = np.log10(np.array(sorted(total_data_per_feature)[cut_off:]) + 1)\n",
    "    plt.plot(np.arange(len(total_data_per_feature))[cut_off:], log_counts, linewidth=4, label=curve_label, c=curve_color)\n",
    "plt.legend(fontsize=15)\n",
    "plt.title('feature frequency over training', fontsize=26)\n",
    "plt.xlabel('feature (1e2)', fontsize=24)\n",
    "plt.ylabel('occurrences (base 10)', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "\n",
    "root = './plots'\n",
    "os.makedirs(root, exist_ok=True)  # savefig raises FileNotFoundError when the directory is missing\n",
    "plt.savefig(os.path.join(root, 'svhn_feat_freq_over_training.pdf'), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f103600d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "24beb1c5",
   "metadata": {},
   "source": [
    "## Testing sampling models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dac7ebab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_model(feature_energy, n_feature, t=1.0):\n",
    "    p = feature_energy / t\n",
    "    p -= np.max(p)\n",
    "    p = np.exp(p)\n",
    "    p /= np.sum(p)\n",
    "    return np.random.choice(np.arange(len(feature_energy)), size=n_feature, replace=False, p=p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bda16543",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the per-feature occurrence totals as energies and sample 50 features at temperature 10.\n",
    "# NOTE(review): total_data_per_feature is whatever an earlier cell last bound it to -- confirm which epoch.\n",
    "energy = np.array(total_data_per_feature)\n",
    "model = sample_model(energy, 50, t=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcfac6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampled feature indices.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e2387f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Energies of the sampled features.\n",
    "energy[model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884d0139",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 20 independent samples of 50 features each, at temperature 1700.\n",
    "models = [sample_model(energy, 50, t=1700) for _ in range(20)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d7e0a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-by-model membership for the sampled models: entry [f][m] is 1 when sample m drew feature f.\n",
    "# The width follows `models` instead of a fixed 20, so changing the number of samples cannot overflow a row.\n",
    "# NOTE(review): rows are sized by the global n_feature while the sampled ids index `energy`;\n",
    "# this assumes len(energy) does not exceed n_feature -- confirm.\n",
    "feature_to_model = [[0] * len(models) for _ in range(n_feature)]\n",
    "for model_idx, sampled_features in enumerate(models):\n",
    "    for feature_id in sampled_features:\n",
    "        feature_to_model[feature_id][model_idx] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ade39196",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features drawn by the most samples first; ties keep their current relative order (stable sort).\n",
    "feature_to_model = sorted(feature_to_model, key=sum, reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65b18440",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the nested list to an array and display it.\n",
    "feature_to_model = np.array(feature_to_model)\n",
    "feature_to_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a92ea6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Membership image for the sampled models: one row per sample, first 400 features of the sorted matrix.\n",
    "fig, axs = plt.subplots(1, 1, figsize=(20, 18))\n",
    "plt.imshow(feature_to_model.T[:,:400])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b3179ff",
   "metadata": {},
   "source": [
    "## Distribution of features in high confidence points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de083be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indices of the data points whose gt_confidence exceeds 0.9.\n",
    "high_confidence_data = np.where(gt_confidence > 0.9)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a770bc5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of high-confidence data points.\n",
    "len(high_confidence_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60cc0581",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of feature ids: 0 up to the largest id in neuron_feature_pairing.\n",
    "n_feature = np.max(neuron_feature_pairing) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f85a23d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature occurrence counter for the high-confidence points.\n",
    "high_feature_count = [0 for _ in range(n_feature)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c65b29b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count, per feature, the high-confidence points that contain it.\n",
    "# The counter is created in a separate cell, so re-running only this cell adds to the old counts.\n",
    "for data_idx in high_confidence_data:\n",
    "    for feature_id in image_feature[data_idx]:\n",
    "        high_feature_count[feature_id] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a2581d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# High-confidence occurrence count per feature id.\n",
    "plt.plot(np.arange(len(high_feature_count)), high_feature_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f972e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log-log view: log of the position vs. log of the sorted, normalised counts.\n",
    "# Position 0 and any zero count evaluate to log(0) = -inf.\n",
    "plt.plot(np.log(np.arange(len(high_feature_count))), np.log(np.array(sorted(high_feature_count))/sum(high_feature_count)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "410653ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of gt_confidence; the cell output is its 30th percentile.\n",
    "_ = plt.hist(gt_confidence, bins=30)\n",
    "np.percentile(gt_confidence, [30])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bc8a49a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same per-feature counting for the points with gt_confidence below 0.9\n",
    "# (points at exactly 0.9 fall in neither group).\n",
    "low_confidence_data = np.where(gt_confidence < 0.9)[0]\n",
    "print(len(low_confidence_data))\n",
    "low_feature_count = [0 for _ in range(n_feature)]\n",
    "for data_idx in low_confidence_data:\n",
    "    for feature_id in image_feature[data_idx]:\n",
    "        low_feature_count[feature_id] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a7bffb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "# root = './plots'\n",
    "# sns.set_style(\"whitegrid\")\n",
    "# sns.set_palette(\"tab10\")\n",
    "\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# # sns.histplot(\n",
    "# #     None, x=gt_confidence, y=all_image_feature_count,\n",
    "# #     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "# #     pthresh=.01, pmax=.7,\n",
    "# # )\n",
    "\n",
    "# plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "# plt.xlabel('confidence', fontsize=24)\n",
    "# plt.ylabel('number of features', fontsize=24)\n",
    "# _ = plt.xticks(fontsize=22)\n",
    "# _ = plt.yticks(fontsize=22)\n",
    "# # plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# # plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5e34a54",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(9, 6))\n",
    "# Order features by total_data_per_feature, then normalise both count vectors to probabilities.\n",
    "sorted_idx = np.argsort(total_data_per_feature)\n",
    "high_prob = (np.array(high_feature_count) / sum(high_feature_count))[sorted_idx]\n",
    "low_prob = (np.array(low_feature_count) / sum(low_feature_count))[sorted_idx]\n",
    "every_other = slice(None, None, 2)  # only every second feature is drawn\n",
    "plt.plot(np.arange(len(high_prob))[every_other], np.log10(high_prob)[every_other], label='high confidence', linewidth=3)\n",
    "plt.plot(np.arange(len(low_prob))[every_other], np.log10(low_prob)[every_other], label='low confidence', linewidth=3)\n",
    "plt.title('log feature density by confidence', fontsize=26)\n",
    "plt.legend(fontsize=22)\n",
    "plt.xlabel('feature (1e2)', fontsize=24)\n",
    "plt.ylabel('log prob (base 10)', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# NOTE(review): `root` is not defined in this cell; it must already exist in the kernel -- confirm.\n",
    "plt.savefig(os.path.join(root, 'svhn_feat_freq_conf.pdf'), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4e1cc11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# KL(high || low) over the two feature distributions. Terms where high_prob is 0 add nothing\n",
    "# and are printed instead.\n",
    "kl = 0\n",
    "for p, q in zip(high_prob, low_prob):\n",
    "    if p == 0:\n",
    "        print(p, q)\n",
    "        continue\n",
    "    kl += p * np.log(p / q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9033ca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# KL value accumulated in the previous cell.\n",
    "kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5ed2d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio of low- to high-confidence probability per feature, with 1e-6 added to both terms.\n",
    "plt.plot(np.arange(len(low_prob)), (low_prob+1e-6)/(high_prob+1e-6))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8640e24e",
   "metadata": {},
   "source": [
    "## Number of data points with certain features vs. the number of models that learned that feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50836202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature counters, filled in by the next cell.\n",
    "n_data_w_feature = [0 for _ in range(n_feature)]\n",
    "n_model_w_feature = [0 for _ in range(n_feature)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "352e6bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of data points in which each feature occurs.\n",
    "for i in range(len(gt_confidence)):\n",
    "    for f in image_feature[i]:\n",
    "        n_data_w_feature[f] += 1\n",
    "\n",
    "# Number of models that have each feature. A feature id is counted once per row of\n",
    "# neuron_feature_pairing even when several entries of that row map to it, so the value is a\n",
    "# model count rather than a neuron count (rows are presumably one per model -- TODO confirm).\n",
    "for i in range(len(neuron_feature_pairing)):\n",
    "    for j in np.unique(neuron_feature_pairing[i]):\n",
    "        n_model_w_feature[j] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34208b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "# root = './plots'\n",
    "# sns.set_style(\"whitegrid\")\n",
    "# sns.set_palette(\"tab10\")\n",
    "\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# # sns.histplot(\n",
    "# #     None, x=gt_confidence, y=all_image_feature_count,\n",
    "# #     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "# #     pthresh=.01, pmax=.7,\n",
    "# # )\n",
    "\n",
    "# plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "# plt.xlabel('confidence', fontsize=24)\n",
    "# plt.ylabel('number of features', fontsize=24)\n",
    "# _ = plt.xticks(fontsize=22)\n",
    "# _ = plt.yticks(fontsize=22)\n",
    "# # plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# # plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77c939f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features that occur in no data point get their model count cleared.\n",
    "for feature_id, n_points in enumerate(n_data_w_feature):\n",
    "    if n_points == 0:\n",
    "        n_model_w_feature[feature_id] = 0\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.scatter(n_data_w_feature, n_model_w_feature, alpha=0.9, s=40)\n",
    "plt.title('n data vs. n models w/ feat', fontsize=26)\n",
    "plt.xlabel('# of data with a feature (1e3)', fontsize=24)\n",
    "plt.ylabel('# of models with a feature', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# NOTE(review): `root` is not defined in this cell; it must already exist in the kernel -- confirm.\n",
    "plt.savefig(os.path.join(root, 'svhn_n_data_n_model_feat.pdf'), bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a36fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shell escape: print the kernel's working directory.\n",
    "! pwd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1b0c4da",
   "metadata": {},
   "source": [
    "## How models make mistakes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dba2bec6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory the corr_{i}_{j}.npy files are read from further down.\n",
    "affinity_dir = os.path.join(exp_dir, 'neuron_correlation_pca')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "201454a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_predictions(model, data_loader, onehot=True):\n",
    "    all_predictions = []\n",
    "    for x, _ in data_loader:\n",
    "        x = x.cuda()\n",
    "        prediction = model(x)\n",
    "        prediction = prediction.detach().cpu().numpy()\n",
    "        if onehot:\n",
    "            new_prediction = np.zeros(prediction.shape)\n",
    "            prediction = np.argmax(prediction, axis=-1)\n",
    "            new_prediction[np.arange(len(prediction)), prediction] = 1\n",
    "            prediction = new_prediction\n",
    "        all_predictions.append(prediction)\n",
    "    return np.concatenate(all_predictions, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcce0819",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the ep_29 checkpoint of every model on test_loader and save the one-hot predictions.\n",
    "n_model = 30\n",
    "tic = time.time()\n",
    "all_prediction = []\n",
    "for i in range(n_model):\n",
    "    if i > 0:\n",
    "        seconds_per_model = (time.time() - tic) / i\n",
    "        print('=' * 60)\n",
    "        print(f'{i} {seconds_per_model}  s / model')\n",
    "    model_path = os.path.join(exp_dir, f'model_{i}')\n",
    "    weight_path = os.path.join(model_path, 'ep_29_weight.pt')\n",
    "    model = get_model(config).cuda()\n",
    "    state_dict = torch.load(weight_path, map_location=torch.device('cpu'))\n",
    "    model.load_state_dict(state_dict)\n",
    "    predictions = get_predictions(model, test_loader)\n",
    "    all_prediction.append(predictions)\n",
    "    prediction_file = os.path.join(exp_dir, 'neuron_correlation_pca', f'prediction_{i}')\n",
    "    np.save(prediction_file, np.uint8(predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8870d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-model predictions into one array and take the arg-max over the last axis\n",
    "# to get class ids; the cell output is the stacked shape.\n",
    "all_prediction = np.array(all_prediction)\n",
    "all_prediction_argmax = np.argmax(all_prediction, axis=-1)\n",
    "all_prediction.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d36aca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the labels in a separate pass over test_loader.\n",
    "# NOTE(review): this only lines up with the predictions gathered above if test_loader yields\n",
    "# the same order on every pass (no shuffling) -- confirm.\n",
    "all_label = []\n",
    "for _, y in test_loader:\n",
    "    all_label.append(y.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "464d8ee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join the per-batch labels into one array and show its shape.\n",
    "all_label = np.concatenate(all_label)\n",
    "all_label.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12249fdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Despite the name, each array holds 1.0 where the model is correct and 0.0 where it is wrong.\n",
    "all_error = [np.float32(all_label == all_prediction_argmax[i]) for i in range(n_model)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "228814c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per model: {data index: predicted class} for every misclassified data point.\n",
    "all_error_dict = [\n",
    "    {j: all_prediction_argmax[i][j] for j, is_correct in enumerate(err) if is_correct == 0}\n",
    "    for i, err in enumerate(all_error)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b8d2c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def same_error_count(model_1_error_dict, model_2_error_dict):\n",
    "    count = 0\n",
    "    total = 0\n",
    "    for k in model_1_error_dict:\n",
    "        if k in model_2_error_dict and model_1_error_dict[k] == model_2_error_dict[k]:\n",
    "            count += 1\n",
    "    total = (len(model_1_error_dict) + len(model_2_error_dict)) / 2\n",
    "    return count / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9e9359",
   "metadata": {},
   "outputs": [],
   "source": [
    "# n_model x n_model table initialised to -1; the next cell overwrites every entry.\n",
    "all_pair_error_correlation = [[-1 for _ in range(n_model)] for _ in range(n_model)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f35149",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared-error fraction for every ordered pair of models (diagonal included).\n",
    "for i in range(n_model):\n",
    "    row = all_pair_error_correlation[i]\n",
    "    for j in range(n_model):\n",
    "        row[j] = same_error_count(all_error_dict[i], all_error_dict[j])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6dad45a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nested list to 2-D array; the cell output is its shape.\n",
    "all_pair_error_correlation = np.array(all_pair_error_correlation)\n",
    "all_pair_error_correlation.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "972537f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map of the pairwise shared-error fractions.\n",
    "plt.imshow(all_pair_error_correlation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ce562af",
   "metadata": {},
   "source": [
    "### similarity measure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f4bb4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# |i - j| for every pair of the 50 positions, as a float matrix.\n",
    "positions = np.arange(50)\n",
    "manhattan_cost = np.abs(positions[:, None] - positions[None, :]).astype(np.float64)\n",
    "plt.imshow(manhattan_cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5877c78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one example matrix (the 18 / 10 pair) and take absolute values.\n",
    "corr_mat = np.load(os.path.join(affinity_dir, 'corr_18_10.npy'))\n",
    "corr_mat = np.abs(corr_mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ef0007b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean |i - j| weighted by the normalised matrix: small when the mass sits near the diagonal.\n",
    "# corr_mat is presumably 50 x 50 to match manhattan_cost -- TODO confirm.\n",
    "(corr_mat/corr_mat.sum() * manhattan_cost).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcc9a556",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every (i, j) in 20 x 20: load corr_{i}_{j}.npy, keep its absolute value and record its trace.\n",
    "all_corr_mat = [[None] * 20 for _ in range(20)]\n",
    "all_cost = [[None] * 20 for _ in range(20)]\n",
    "for i in range(20):\n",
    "    for j in range(20):\n",
    "        corr_file = os.path.join(affinity_dir, f'corr_{i}_{j}.npy')\n",
    "        corr_mat = np.abs(np.load(corr_file))\n",
    "        all_corr_mat[i][j] = corr_mat\n",
    "        all_cost[i][j] = np.trace(corr_mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ccda196",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nested list to 2-D array.\n",
    "all_cost = np.array(all_cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0de254ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map of the pairwise trace values.\n",
    "plt.imshow(all_cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a57512a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of the pairwise trace values.\n",
    "_ = plt.hist(all_cost.reshape(-1), bins=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66cc6aae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Range of the pairwise trace values.\n",
    "print(all_cost.max())\n",
    "print(all_cost.min())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d490d4f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sum of absolute entries of corr_{i}_{j}.npy for every (i, j) in 20 x 20.\n",
    "all_sum = np.array([\n",
    "    [np.abs(np.load(os.path.join(affinity_dir, f'corr_{i}_{j}.npy'))).sum() for j in range(20)]\n",
    "    for i in range(20)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7e08f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of the per-pair absolute-correlation totals, flattened to 1-D.\n",
    "_ = plt.hist(all_sum.reshape(-1), bins=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16172ad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each off-diagonal score in `all_cost` with the matching entry of\n",
    "# `all_pair_error_correlation`; the i == j entries are left out.\n",
    "similarity = []\n",
    "error_corr = []\n",
    "for i, j in np.ndindex(20, 20):\n",
    "    if i != j:\n",
    "        error_corr.append(all_pair_error_correlation[i][j])\n",
    "        similarity.append(all_cost[i][j])\n",
    "similarity = np.array(similarity)\n",
    "error_corr = np.array(error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95516b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Least-squares fit of error correlation on similarity; the 1-D predictor is\n",
    "# reshaped to a single-column 2-D array for the fit.\n",
    "reg = LinearRegression()\n",
    "reg.fit(similarity.reshape(-1, 1), error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd3c6fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the fitted line over the per-pair scatter; the value of the cell is the\n",
    "# score reported by `reg` on the same data.\n",
    "a = int(similarity.min())\n",
    "b = int(similarity.max()) + 2\n",
    "# The line is evaluated on the integer grid a, a+1, ..., b-1.\n",
    "plt.plot(\n",
    "    np.arange(a, b),\n",
    "    reg.predict(np.arange(a, b).reshape(-1, 1)),\n",
    "    lw=2.5, color='C0', ls='--',\n",
    ")\n",
    "plt.scatter(similarity, error_corr, s=15, alpha=0.3)\n",
    "print(a,b)\n",
    "\n",
    "plt.ylabel('shared error', fontsize=16)\n",
    "plt.xlabel('feature similarity', fontsize=16)\n",
    "_ = plt.xticks(fontsize=15)\n",
    "_ = plt.yticks(fontsize=15)\n",
    "reg.score(similarity.reshape(-1, 1), error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76f0de22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the dimensions of `neuron_feature_pairing` (assigned outside the cells in this section).\n",
    "neuron_feature_pairing.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de815679",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every ordered pair of distinct indices (i, j), count how many values the\n",
    "# rows `neuron_feature_pairing[i]` and `neuron_feature_pairing[j]` have in common,\n",
    "# and record the matching entry of `all_pair_error_correlation`.\n",
    "# The per-row sets are built once up front instead of once per inner iteration.\n",
    "feature_sets = [set(neuron_feature_pairing[k]) for k in range(20)]\n",
    "overlap = []\n",
    "error_corr = []\n",
    "for i in range(20):\n",
    "    for j in range(20):\n",
    "        if i == j: continue\n",
    "        error_corr.append(all_pair_error_correlation[i][j])\n",
    "        overlap.append(len(feature_sets[i] & feature_sets[j]))\n",
    "overlap = np.array(overlap)\n",
    "error_corr = np.array(error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d3f90fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.set_style(\"darkgrid\")\n",
    "# root = './plots'\n",
    "# sns.set_style(\"whitegrid\")\n",
    "# sns.set_palette(\"tab10\")\n",
    "\n",
    "# plt.figure(figsize=(8, 6))\n",
    "# sns.kdeplot(gt_confidence, all_image_feature_count, fill=True, cbar=False, color='aliceblue',)\n",
    "# # sns.histplot(\n",
    "# #     None, x=gt_confidence, y=all_image_feature_count,\n",
    "# #     bins=100, discrete=(False, True), log_scale=(False, False),\n",
    "# #     pthresh=.01, pmax=.7,\n",
    "# # )\n",
    "\n",
    "# plt.title(f'conf vs. feature count ({group}k)', fontsize=26)\n",
    "# plt.xlabel('confidence', fontsize=24)\n",
    "# plt.ylabel('number of features', fontsize=24)\n",
    "# _ = plt.xticks(fontsize=22)\n",
    "# _ = plt.yticks(fontsize=22)\n",
    "# # plt.ticklabel_format(axis=\"x\", style=\"sci\", scilimits=(0,0))\n",
    "# # plt.ticklabel_format(axis=\"y\", style=\"sci\", scilimits=(0,0))\n",
    "# plt.savefig(os.path.join(root, f'{group}_conf_feat_count.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1bf5304",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Fit shared error against the shared-feature count, scatter every pair, save the\n",
    "# figure, and show the fit's score as the cell output.\n",
    "reg = LinearRegression().fit(overlap[:, None], error_corr)\n",
    "a = int(np.min(overlap))\n",
    "b = int(np.max(overlap))+2\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.scatter(overlap, error_corr, s=40, alpha=0.6)\n",
    "print(a,b)\n",
    "\n",
    "plt.title('shared feature vs. shared error', fontsize=26)\n",
    "plt.xlabel('shared features', fontsize=24)\n",
    "plt.ylabel('shared error', fontsize=24)\n",
    "_ = plt.xticks(fontsize=22)\n",
    "_ = plt.yticks(fontsize=22)\n",
    "# NOTE(review): `root` is only assigned in a commented-out line in this part of the\n",
    "# notebook; confirm it is defined by an earlier cell before running.\n",
    "plt.savefig(os.path.join(root, 'svhn_shared_feat_shared_err.pdf'), bbox_inches='tight')\n",
    "# Kept as the last expression so the notebook displays the score; placed mid-cell\n",
    "# its value was computed and then thrown away.\n",
    "reg.score(overlap[:, None], error_corr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f97aec13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagonal of the absolute correlation matrix stored for index pair (0, 0).\n",
    "np.diag(all_corr_mat[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a5d0b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagonal of the absolute correlation matrix stored for index pair (0, 1).\n",
    "np.diag(all_corr_mat[0][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31ad00a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde8f90e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder: set this to the directory the cells below read from. They expect it to\n",
    "# contain 'Y.npy' plus per-model 'F<index>.npy' / 'L<index>.npy' files.\n",
    "base_dir = None # path to features of different models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c38528",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory listing of `base_dir`, sorted by name so the iteration order below is deterministic.\n",
    "all_files = sorted(os.listdir(base_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ea043c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db83462d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the array stored in 'Y.npy' under `base_dir` (presumably the ground-truth labels; confirm).\n",
    "labels = np.load(os.path.join(base_dir, 'Y.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b6afb73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the feature array, the matching logits array and the file name for every\n",
    "# feature file in `all_files`, skipping feature files that look degenerate.\n",
    "all_representations = []\n",
    "all_logits = []\n",
    "all_name = []\n",
    "for f in all_files:\n",
    "    # Only names containing 'F' are treated as feature files; this also excludes 'Y.npy'.\n",
    "    if 'F' not in f:\n",
    "        continue\n",
    "    # Everything between the first character and the first '.' is the index\n",
    "    # shared with the logits file 'L<index>.npy'.\n",
    "    index = f.split('.')[0][1:]\n",
    "    feature_path = os.path.join(base_dir, f)\n",
    "    logits_path = os.path.join(base_dir, f'L{index}.npy')\n",
    "    feat = np.load(feature_path)\n",
    "    # Skip features whose sum is NaN or whose mean magnitude is below 1e-6.\n",
    "    if np.isnan(feat.sum()) or np.abs(feat).mean() < 1e-6:\n",
    "        print(f'skipping {f}')\n",
    "        continue\n",
    "    # Reuse the array already loaded for the check instead of reading the file again.\n",
    "    all_representations.append(feat)\n",
    "    all_logits.append(np.load(logits_path))\n",
    "    all_name.append(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff83df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply `project_repr` to every collected representation, printing the source file\n",
    "# name first. NOTE(review): `project_repr` is not defined in this part of the\n",
    "# notebook; presumably it comes from the `neuron_affinity` star import. Confirm.\n",
    "all_pca_repr = []\n",
    "for name, r in zip(all_name, all_representations):\n",
    "    print(f'======={name}=======')\n",
    "    all_pca_repr.append(project_repr(r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf3f0893",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a5c87a9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
