{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0ea56929",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import math\n",
    "import random\n",
    "from sklearn.covariance import ledoit_wolf\n",
    "from collections import OrderedDict\n",
    "\n",
    "from EDGE_4_3_1 import EDGE\n",
    "from npeet.entropy_estimators import midd\n",
    "from pomegranate import *\n",
    "from tqdm import *\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use('TkAgg')\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import seaborn as sns\n",
    "sns.set_style('darkgrid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8ef5a4f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "ip_data = \"IP/resnet18_gaussianDrp_0.4_128\"\n",
    "repr_data = \"representations/resnet18_gaussianDrp_0.4_128\"\n",
    "\n",
    "p = 0.4\n",
    "drp_noise = p/(1-p)\n",
    "\n",
    "GMM_MEANS_NUM = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c33fe42",
   "metadata": {},
   "source": [
    "### Draw for information dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc1ce2da",
   "metadata": {},
   "outputs": [],
   "source": [
    "def drawIP(mi_xz, mi_zy, title):\n",
    "    gs = gridspec.GridSpec(4,2)\n",
    "\n",
    "    COLORBAR_MAX_EPOCHS=200\n",
    "    sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=COLORBAR_MAX_EPOCHS))\n",
    "\n",
    "    n_epoch = len(list(mi_xz.keys()))\n",
    "    PLOT_LAYERS = [0]\n",
    "    for epoch in range(n_epoch):\n",
    "        c = sm.to_rgba(epoch)\n",
    "        # we saved optimized value (with information dropout), need to add -0.5*log(2*pi*e)-log(c),\n",
    "        # where c is defining the log-uniform distribution of the ReLU prior\n",
    "        # for Softplus no additional values are needed\n",
    "        xmvals = mi_xz[epoch] #- 0.5*np.log(2*math.pi*math.e)\n",
    "        # we saved crossentropy value - lower bound on MI is -crossentropy + H(Y)\n",
    "        ymvals = -mi_zy[epoch] + np.log(10)\n",
    "        #plt.plot(xmvals, ymvals, c=c, alpha=0.5, zorder=1)\n",
    "        plt.scatter(xmvals, ymvals, s=20, facecolors=[c for _ in PLOT_LAYERS], edgecolor='none', zorder=2) \n",
    "    #plt.ylim([1, 3.5])\n",
    "    #plt.xlim([4, 14])\n",
    "    plt.xlabel('I(X;Z)')\n",
    "    plt.ylabel('I(Y;Z)')\n",
    "    plt.title(title)\n",
    "    plt.colorbar(sm, label='Epoch')\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5adbdd08",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_mi_xz = pickle.load(open(os.path.join(ip_data, \"val_mi_xz\"), \"rb\"))\n",
    "val_mi_zy = pickle.load(open(os.path.join(ip_data, \"val_mi_zy\"), \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "501a56ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "drawIP(val_mi_xz, val_mi_zy, 'Information dropout, fullCNN (validation)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d7dbc5bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mi_xz = pickle.load(open(os.path.join(ip_data, \"train_mi_xz\"), \"rb\"))\n",
    "train_mi_zy = pickle.load(open(os.path.join(ip_data, \"train_mi_zy\"), \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cdd8597",
   "metadata": {},
   "outputs": [],
   "source": [
    "drawIP(train_mi_xz, train_mi_zy, 'Information dropout, fullCNN (training)')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a362203f",
   "metadata": {},
   "source": [
    "### Draw the estimations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ef9aceba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def drawIP(mi_xz, mi_zy, title, crossentropy_zy=True):\n",
    "    gs = gridspec.GridSpec(4,2)\n",
    "\n",
    "    COLORBAR_MAX_EPOCHS=list(mi_xz.keys())[-1]\n",
    "    sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=COLORBAR_MAX_EPOCHS))\n",
    "\n",
    "    PLOT_LAYERS = [0]\n",
    "    for epoch in list(mi_xz.keys()):\n",
    "        c = sm.to_rgba(epoch)\n",
    "        xmvals = mi_xz[epoch]\n",
    "        if crossentropy_zy:\n",
    "            # we saved crossentropy value - lower bound on MI is -crossentropy + H(Y)\n",
    "            ymvals = -mi_zy[epoch] + np.log(10)\n",
    "        else:\n",
    "            ymvals = mi_zy[epoch]\n",
    "        #plt.plot(xmvals, ymvals, c=c, alpha=0.5, zorder=1)\n",
    "        plt.scatter(xmvals, ymvals, s=20, facecolors=[c for _ in PLOT_LAYERS], edgecolor='none', zorder=2) \n",
    "    #plt.ylim([1, 3.5])\n",
    "    #plt.xlim([4, 14])\n",
    "    ax = plt.gca()\n",
    "    ax.get_yaxis().get_major_formatter().set_useOffset(False)\n",
    "    plt.xlabel('I(X;Z)')\n",
    "    plt.ylabel('I(Y;Z)')\n",
    "    plt.title(title)\n",
    "    plt.colorbar(sm, label='Epoch')\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "77cd7203",
   "metadata": {},
   "outputs": [],
   "source": [
    "def repr_entropy(nonoise_reprs, reprs, ratio_points=GMM_MEANS_NUM):\n",
    "    dists = []\n",
    "    used_points = nonoise_reprs[:ratio_points]\n",
    "    for i in tqdm(range(len(used_points))):\n",
    "        dists.append(MultivariateGaussianDistribution(used_points[i], drp_noise*np.diag(abs(used_points[i])+1e-6)))\n",
    "    gmm_Z = GeneralMixtureModel(dists, weights=np.full(len(dists), (1.0 / len(dists))))\n",
    "    log_probs = []\n",
    "    for i in tqdm(range(int(len(reprs)/10000))):\n",
    "        log_probs += gmm_Z.log_probability(reprs[i*10000:(i+1)*10000], n_jobs=10).tolist()\n",
    "    return (-1.0/len(log_probs))*np.array(log_probs).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "05d7ea1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gaussian_noise_mi(reprs, nonoise_reprs, noise_variance):\n",
    "    f_dim = reprs.shape[1]\n",
    "    h_z = repr_entropy(nonoise_reprs, reprs)\n",
    "    # our innovative way to compute conditional entropy\n",
    "    h_zGivx = 0\n",
    "    h_z_part = math.sqrt(2*math.pi*math.e)*noise_variance\n",
    "    for i in range(f_dim):\n",
    "        # we have simple multiplication of 1-dim gaussian (noise) by a constant (current value of activations)\n",
    "        h_zGivx += (1.0/len(nonoise_reprs)) * np.sum(np.log(h_z_part * np.fabs(nonoise_reprs[:,i] + 1e-6)))\n",
    "    return h_z - h_zGivx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b3f3ce25",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "no_noise_test_representations_0.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 313.70it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:15<00:00, 39.09s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "145.78097473680145 0.5349238187767641\n",
      "no_noise_test_representations_10.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 414.19it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:14<00:00, 38.85s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "150.49760122390464 1.034895448319866\n",
      "no_noise_test_representations_100.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 414.92it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:29<00:00, 41.89s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "122.78102312958187 1.2000258775348513\n",
      "no_noise_test_representations_110.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 228.11it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:29<00:00, 41.88s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "122.20478322098502 1.2273809184959514\n",
      "no_noise_test_representations_120.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 338.56it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:28<00:00, 41.69s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "115.88270685702405 1.320031601958499\n",
      "no_noise_test_representations_130.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 365.30it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:15<00:00, 39.10s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "115.44234737416024 1.4086633684200713\n",
      "no_noise_test_representations_140.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 248.70it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:17<00:00, 39.56s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "117.6808072613065 1.4312062266523915\n",
      "no_noise_test_representations_150.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 388.03it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:18<00:00, 39.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "115.5671212202713 1.548509968452891\n",
      "no_noise_test_representations_160.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 397.37it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:22<00:00, 40.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "109.04248186205102 1.6219957139399255\n",
      "no_noise_test_representations_170.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 315.24it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:15<00:00, 39.07s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "105.62074564672717 1.7679464647929057\n",
      "no_noise_test_representations_180.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:18<00:00, 53.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:36<00:00, 43.37s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "103.33960122863873 1.8300629964069943\n",
      "no_noise_test_representations_190.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:30<00:00, 33.10it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:48<00:00, 45.61s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "102.97705990849948 1.7994550458499048\n",
      "no_noise_test_representations_20.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:23<00:00, 42.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:41<00:00, 44.30s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "141.74487713747942 0.9477216529042782\n",
      "no_noise_test_representations_30.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:21<00:00, 47.05it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:29<00:00, 41.83s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "132.6842352985086 0.9388893294678111\n",
      "no_noise_test_representations_40.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:24<00:00, 41.55it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:38<00:00, 43.64s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "130.39855134927578 1.0999507450589512\n",
      "no_noise_test_representations_50.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:21<00:00, 45.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:28<00:00, 41.65s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "128.51534357833913 1.0430509807647688\n",
      "no_noise_test_representations_60.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:23<00:00, 42.87it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:10<00:00, 62.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "128.43510411639636 0.8849044501927816\n",
      "no_noise_test_representations_70.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 223.47it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [04:51<00:00, 58.38s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "133.0610803923917 1.0848827936822376\n",
      "no_noise_test_representations_80.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 180.77it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [04:55<00:00, 59.05s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125.94017737009148 1.0122584040896112\n",
      "no_noise_test_representations_90.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 278.07it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [04:48<00:00, 57.74s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "131.04493166685873 1.1012039710996526\n",
      "no_noise_train_representations_0.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 254.17it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [04:44<00:00, 56.90s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "98.24783263699624 0.5047185548102955\n",
      "no_noise_train_representations_10.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 256.65it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:23<00:00, 64.70s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "87.63718226927847 1.0102431081391083\n",
      "no_noise_train_representations_100.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:25<00:00, 38.81it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:38<00:00, 67.78s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "67.56664334970095 1.3688819450470187\n",
      "no_noise_train_representations_110.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:28<00:00, 35.67it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:41<00:00, 68.31s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "67.12136088360703 1.431032091486069\n",
      "no_noise_train_representations_120.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:23<00:00, 42.67it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [11:07:33<00:00, 8010.78s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "63.40405359074117 1.504722834678933\n",
      "no_noise_train_representations_130.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:19<00:00, 50.19it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:22<00:00, 40.54s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "62.62778490709811 1.5772015460670799\n",
      "no_noise_train_representations_140.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:40<00:00, 24.46it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:35<00:00, 43.05s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "58.34350602710032 1.6288770441639908\n",
      "no_noise_train_representations_150.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:19<00:00, 52.21it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:22<00:00, 40.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "54.04893780320366 1.761413352733196\n",
      "no_noise_train_representations_160.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:19<00:00, 51.29it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:18<00:00, 39.69s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "48.11701669436282 1.9023068168611095\n",
      "no_noise_train_representations_170.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 108.17it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:05<00:00, 37.19s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "42.043736651175735 1.9711575133081054\n",
      "no_noise_train_representations_180.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 438.75it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:05<00:00, 37.08s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38.077366070027914 2.0344430282928805\n",
      "no_noise_train_representations_190.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 416.73it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:05<00:00, 37.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37.39177260567688 2.044754266297073\n",
      "no_noise_train_representations_20.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 435.37it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:07<00:00, 37.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "84.2160140759103 0.9034409790761091\n",
      "no_noise_train_representations_30.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 441.60it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:06<00:00, 37.21s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "82.47849346139193 0.9640060298051374\n",
      "no_noise_train_representations_40.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 346.36it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:06<00:00, 37.22s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "78.72475015158213 1.0750756928314484\n",
      "no_noise_train_representations_50.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 437.90it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [04:57<00:00, 59.48s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "77.51924125602947 1.1739109501927332\n",
      "no_noise_train_representations_60.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 167.03it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:31<00:00, 66.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "77.31438615663474 1.1303660642905267\n",
      "no_noise_train_representations_70.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 194.20it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:32<00:00, 66.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "73.64878980216301 0.9864260626029985\n",
      "no_noise_train_representations_80.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 171.51it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:28<00:00, 65.79s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "74.23106935503856 1.0653742458740594\n",
      "no_noise_train_representations_90.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 217.56it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [05:25<00:00, 65.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "72.78494622940937 1.1681694033777041\n"
     ]
    }
   ],
   "source": [
    "if os.path.exists(os.path.join(repr_data, \"test_comp_mi_xz\")):\n",
    "    test_comp_mi_xz = pickle.load(open(os.path.join(repr_data, \"test_comp_mi_xz\"), \"rb\"))\n",
    "    test_comp_mi_zy = pickle.load(open(os.path.join(repr_data, \"test_comp_mi_zy\"), \"rb\"))\n",
    "    train_comp_mi_xz = pickle.load(open(os.path.join(repr_data, \"train_comp_mi_xz\"), \"rb\"))\n",
    "    train_comp_mi_zy = pickle.load(open(os.path.join(repr_data, \"train_comp_mi_zy\"), \"rb\"))\n",
    "else:    \n",
    "    test_comp_mi_xz = {}\n",
    "    test_comp_mi_zy = {}\n",
    "    train_comp_mi_xz = {}\n",
    "    train_comp_mi_zy = {}\n",
    "\n",
    "    test_labels = np.load(os.path.join(repr_data, \"test_labels.npy\"))\n",
    "    train_labels = np.load(os.path.join(repr_data, \"train_labels.npy\"))\n",
    "    \n",
    "    train_repeat = 1\n",
    "    test_repeat = 5\n",
    "\n",
    "    for f in os.listdir(repr_data):\n",
    "        if \"test_representations\" in f:\n",
    "            print(f)\n",
    "            epoch = int(f.split(\".\")[0].split(\"_\")[-1])\n",
    "            nonoise_reprs = np.load(os.path.join(repr_data, f), allow_pickle=True)\n",
    "            nonoise_reprs, ind = np.unique(nonoise_reprs, axis=0, return_index=True)\n",
    "            reprs = []\n",
    "            for nr in nonoise_reprs:\n",
    "                for i in range(test_repeat):\n",
    "                    epsilon = np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1\n",
    "                    reprs.append(nr*epsilon)\n",
    "            reprs = np.array(reprs)\n",
    "            test_comp_mi_xz[epoch] = gaussian_noise_mi(reprs, nonoise_reprs, drp_noise)\n",
    "            test_comp_mi_zy[epoch] = EDGE(reprs, np.repeat(np.array(test_labels[ind]), test_repeat))\n",
    "            print(test_comp_mi_xz[epoch], test_comp_mi_zy[epoch])\n",
    "\n",
    "        if \"train_representations\" in f:\n",
    "            print(f)\n",
    "            epoch = int(f.split(\".\")[0].split(\"_\")[-1])\n",
    "            nonoise_reprs = np.load(os.path.join(repr_data, f), allow_pickle=True)\n",
    "            reprs = []\n",
    "            for nr in nonoise_reprs:\n",
    "                for i in range(train_repeat):\n",
    "                    epsilon = np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1\n",
    "                    reprs.append(nr*epsilon)\n",
    "            reprs = np.array(reprs)\n",
    "            train_comp_mi_xz[epoch] = gaussian_noise_mi(reprs, nonoise_reprs, drp_noise)\n",
    "            train_comp_mi_zy[epoch] = EDGE(reprs, np.repeat(train_labels, train_repeat))\n",
    "            print(train_comp_mi_xz[epoch], train_comp_mi_zy[epoch])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "444ca7d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists(os.path.join(repr_data, \"test_comp_mi_xz\")):\n",
    "    pickle.dump(test_comp_mi_xz, open(os.path.join(repr_data, \"test_comp_mi_xz\"), \"wb\"))\n",
    "    pickle.dump(test_comp_mi_zy, open(os.path.join(repr_data, \"test_comp_mi_zy\"), \"wb\"))\n",
    "    pickle.dump(train_comp_mi_xz, open(os.path.join(repr_data, \"train_comp_mi_xz\"), \"wb\"))\n",
    "    pickle.dump(train_comp_mi_zy, open(os.path.join(repr_data, \"train_comp_mi_zy\"), \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c5313140",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_9868\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "od_test_comp_mi_xz = OrderedDict(sorted(test_comp_mi_xz.items()))\n",
    "od_test_comp_mi_zy = OrderedDict(sorted(test_comp_mi_zy.items()))\n",
    "drawIP(od_test_comp_mi_xz, od_test_comp_mi_zy, 'Gaussian dropout, ResNet18 (validation)', crossentropy_zy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8fb88dd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_9868\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "drawIP(od_test_comp_mi_xz, val_mi_zy, 'Gaussian dropout, ResNet18 (validation)', crossentropy_zy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "04c92799",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_9868\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "od_train_comp_mi_xz = OrderedDict(sorted(train_comp_mi_xz.items()))\n",
    "od_train_comp_mi_zy = OrderedDict(sorted(train_comp_mi_zy.items()))\n",
    "drawIP(od_train_comp_mi_xz, od_train_comp_mi_zy, 'Gaussian dropout, ResNet18 (training)', crossentropy_zy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "6a559016",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_9868\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "drawIP(od_train_comp_mi_xz, train_mi_zy, 'Gaussian dropout, ResNet18 (training)', crossentropy_zy=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eef4fa06",
   "metadata": {},
   "source": [
    "### Binning IP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "cf226e69",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_bins(min_bound, max_bound, num_of_bins=None, bin_size=None):\n",
    "    if bin_size is not None:\n",
    "        bins = np.arange(min_bound, max_bound, bin_size, dtype='float32')\n",
    "    elif num_of_bins is not None:\n",
    "        bins = np.linspace(min_bound, max_bound, num_of_bins, dtype='float32')\n",
    "    else:\n",
    "        print(\"Computation error; set either bin size or number of bins to a value\")\n",
    "        return None\n",
    "    return bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "2d312a50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def double_bin_calc_information(inputdata, layerdata, num_of_bins=None, bin_size=None):\n",
    "    bins_inp = create_bins(inputdata.min(), inputdata.max(), num_of_bins=num_of_bins, bin_size=bin_size)\n",
    "    digitized_inp = bins_inp[np.digitize(np.squeeze(inputdata.reshape(1, -1)), bins_inp) - 1].reshape(len(inputdata), -1)\n",
    "\n",
    "    bins_rep = create_bins(layerdata.min(), layerdata.max(), num_of_bins=num_of_bins, bin_size=bin_size)\n",
    "    digitized_rep = bins_rep[np.digitize(np.squeeze(layerdata.reshape(1, -1)), bins_rep) - 1].reshape(len(layerdata), -1)\n",
    "\n",
    "    # measure information in nats\n",
    "    return midd(digitized_inp, digitized_rep, base=np.exp(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "6e6bd1dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# do I need noise for binning IP?\n",
    "if os.path.exists(os.path.join(repr_data, \"test_bin_mi_xz\")):\n",
    "    test_bin_mi_xz = pickle.load(open(os.path.join(repr_data, \"test_bin_mi_xz\"), \"rb\"))\n",
    "    test_bin_mi_zy = pickle.load(open(os.path.join(repr_data, \"test_bin_mi_zy\"), \"rb\"))\n",
    "    train_bin_mi_xz = pickle.load(open(os.path.join(repr_data, \"train_bin_mi_xz\"), \"rb\"))\n",
    "    train_bin_mi_zy = pickle.load(open(os.path.join(repr_data, \"train_bin_mi_zy\"), \"rb\"))\n",
    "else:\n",
    "    test_bin_mi_xz = {}\n",
    "    test_bin_mi_zy = {}\n",
    "    train_bin_mi_xz = {}\n",
    "    train_bin_mi_zy = {}\n",
    "\n",
    "    test_inputs = np.load(os.path.join(repr_data, \"test_inputs.npy\"))\n",
    "    test_inputs = test_inputs.reshape(test_inputs.shape[0], -1)\n",
    "    test_labels = np.load(os.path.join(repr_data, \"test_labels.npy\"))\n",
    "    train_inputs = np.load(os.path.join(repr_data, \"train_inputs.npy\"))\n",
    "    train_inputs = train_inputs.reshape(train_inputs.shape[0], -1)\n",
    "    train_labels = np.load(os.path.join(repr_data, \"train_labels.npy\"))\n",
    "\n",
    "    train_repeat = 1\n",
    "    test_repeat = 5\n",
    "    \n",
    "    for f in os.listdir(repr_data):\n",
    "        if \"test_representations\" in f:\n",
    "            print(f)\n",
    "            epoch = int(f.split(\".\")[0].split(\"_\")[-1])            \n",
    "            nonoise_reprs = np.load(os.path.join(repr_data, f))\n",
    "            nonoise_reprs, ind = np.unique(nonoise_reprs, axis=0, return_index=True)\n",
    "            reprs = []\n",
    "            for nr in nonoise_reprs:\n",
    "                for i in range(test_repeat):\n",
    "                    epsilon = np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1\n",
    "                    reprs.append(nr*epsilon)\n",
    "            reprs = np.array(reprs)\n",
    "            test_bin_mi_xz[epoch] = double_bin_calc_information(np.repeat(np.array(test_inputs[ind]), test_repeat, axis=0), \n",
    "                                                                reprs, num_of_bins=10)\n",
    "            test_bin_mi_zy[epoch] = double_bin_calc_information(np.repeat(np.array(test_labels[ind]), test_repeat), reprs, \n",
    "                                                                num_of_bins=10)\n",
    "            print(test_bin_mi_xz[epoch], test_bin_mi_zy[epoch])\n",
    "\n",
    "        if \"train_representations\" in f:\n",
    "            print(f)\n",
    "            epoch = int(f.split(\".\")[0].split(\"_\")[-1])\n",
    "            nonoise_reprs = np.load(os.path.join(repr_data, f))\n",
    "            reprs = []\n",
    "            for nr in nonoise_reprs:\n",
    "                for i in range(train_repeat):\n",
    "                    epsilon = np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1\n",
    "                    reprs.append(nr*epsilon)\n",
    "            reprs = np.array(reprs)\n",
    "            train_bin_mi_xz[epoch] = double_bin_calc_information(np.repeat(np.array(train_inputs), train_repeat, axis=0), \n",
    "                                                                 reprs, num_of_bins=10)\n",
    "            train_bin_mi_zy[epoch] = double_bin_calc_information(np.repeat(np.array(train_labels), train_repeat), \n",
    "                                                                 reprs, num_of_bins=10)\n",
    "            print(train_bin_mi_xz[epoch], train_bin_mi_zy[epoch])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "10f7e5bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists(os.path.join(repr_data, \"test_bin_mi_xz\")):\n",
    "    pickle.dump(test_bin_mi_xz, open(os.path.join(repr_data, \"test_bin_mi_xz\"), \"wb\"))\n",
    "    pickle.dump(test_bin_mi_zy, open(os.path.join(repr_data, \"test_bin_mi_zy\"), \"wb\"))\n",
    "    pickle.dump(train_bin_mi_xz, open(os.path.join(repr_data, \"train_bin_mi_xz\"), \"wb\"))\n",
    "    pickle.dump(train_bin_mi_zy, open(os.path.join(repr_data, \"train_bin_mi_zy\"), \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "ffbaa049",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_17232\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "od_test_bin_mi_xz = OrderedDict(sorted(test_bin_mi_xz.items()))\n",
    "od_test_bin_mi_zy = OrderedDict(sorted(test_bin_mi_zy.items()))\n",
    "drawIP(od_test_bin_mi_xz, od_test_bin_mi_zy, 'Gaussian dropout + binning, ResNet18 (validation)', crossentropy_zy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "5bad92ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_17232\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "od_train_bin_mi_xz = OrderedDict(sorted(train_bin_mi_xz.items()))\n",
    "od_train_bin_mi_zy = OrderedDict(sorted(train_bin_mi_zy.items()))\n",
    "drawIP(od_train_bin_mi_xz, od_train_bin_mi_zy, 'Gaussian dropout + binning, ResNet18 (training)', crossentropy_zy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e886b8d8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
