{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ea56929",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import math\n",
    "import random\n",
    "from collections import OrderedDict\n",
    "from sklearn.covariance import ledoit_wolf\n",
    "\n",
    "from EDGE_4_3_1 import EDGE\n",
    "from npeet.entropy_estimators import midd\n",
    "from pomegranate import *\n",
    "from tqdm import *\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use('TkAgg')\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import seaborn as sns\n",
    "sns.set_style('darkgrid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "8ef5a4f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "ip_data = \"IP/mnist_gaus0.01_dropout\"\n",
    "repr_data = \"representations/mnist_gaus0.01_dropout\"\n",
    "\n",
    "netw = \"LeNet\"\n",
    "\n",
    "p = 0.01\n",
    "drp_noise = p/(1-p)\n",
    "\n",
    "GMM_MEANS_NUM = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "107698d9",
   "metadata": {},
   "source": [
    "### Draw for information dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8b09f09a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def drawIP(mi_xz, mi_zy, title):\n",
    "    gs = gridspec.GridSpec(4,2)\n",
    "\n",
    "    COLORBAR_MAX_EPOCHS=100\n",
    "    sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=COLORBAR_MAX_EPOCHS))\n",
    "\n",
    "    n_epoch = len(list(mi_xz.keys()))\n",
    "    PLOT_LAYERS = [0]\n",
    "    for epoch in range(n_epoch):\n",
    "        c = sm.to_rgba(epoch)\n",
    "        # we saved optimized value (with information dropout), need to add -0.5*log(2*pi*e)-log(c),\n",
    "        # where c is defining the log-uniform distribution of the ReLU prior\n",
    "        xmvals = mi_xz[epoch] #- 0.5*np.log(2*math.pi*math.e)\n",
    "        # we saved crossentropy value - lower bound on MI is -crossentropy + H(Y)\n",
    "        ymvals = -mi_zy[epoch] + np.log(10)\n",
    "        #plt.plot(xmvals, ymvals, c=c, alpha=0.5, zorder=1)\n",
    "        plt.scatter(xmvals, ymvals, s=20, facecolors=[c for _ in PLOT_LAYERS], edgecolor='none', zorder=2) \n",
    "    #plt.ylim([1, 3.5])\n",
    "    #plt.xlim([4, 14])\n",
    "    plt.xlabel('I(X;Z)')\n",
    "    plt.ylabel('I(Y;Z)')\n",
    "    plt.title(title)\n",
    "    plt.colorbar(sm, label='Epoch')\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5adbdd08",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_mi_xz = pickle.load(open(os.path.join(ip_data, \"val_mi_xz\"), \"rb\"))\n",
    "val_mi_zy = pickle.load(open(os.path.join(ip_data, \"val_mi_zy\"), \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "501a56ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_9356\\794901088.py:23: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "drawIP(val_mi_xz, val_mi_zy, 'Information dropout, '+netw+' (validation)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d7dbc5bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mi_xz = pickle.load(open(os.path.join(ip_data, \"train_mi_xz\"), \"rb\"))\n",
    "train_mi_zy = pickle.load(open(os.path.join(ip_data, \"train_mi_zy\"), \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7cdd8597",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_9356\\794901088.py:23: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "drawIP(train_mi_xz, train_mi_zy, 'Information dropout, '+netw+' (training)')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e47848d6",
   "metadata": {},
   "source": [
    "### Draw the estimations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "0c19158a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def drawIP(mi_xz, mi_zy, title, crossentropy_zy=True):\n",
    "    gs = gridspec.GridSpec(4,2)\n",
    "\n",
    "    COLORBAR_MAX_EPOCHS=list(mi_xz.keys())[-1]\n",
    "    sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=COLORBAR_MAX_EPOCHS))\n",
    "\n",
    "    PLOT_LAYERS = [0]\n",
    "    for epoch in list(mi_xz.keys()):\n",
    "        c = sm.to_rgba(epoch)\n",
    "        xmvals = mi_xz[epoch]\n",
    "        if crossentropy_zy:\n",
    "            # we saved crossentropy value - lower bound on MI is -crossentropy + H(Y)\n",
    "            ymvals = -mi_zy[epoch] + np.log(10)\n",
    "        else:\n",
    "            ymvals = mi_zy[epoch]\n",
    "        #plt.plot(xmvals, ymvals, c=c, alpha=0.5, zorder=1)\n",
    "        plt.scatter(xmvals, ymvals, s=20, facecolors=[c for _ in PLOT_LAYERS], edgecolor='none', zorder=2) \n",
    "    #plt.ylim([1, 3.5])\n",
    "    #plt.xlim([4, 14])\n",
    "    ax = plt.gca()\n",
    "    ax.get_yaxis().get_major_formatter().set_useOffset(False)\n",
    "    plt.xlabel('I(X;Z)')\n",
    "    plt.ylabel('I(Y;Z)')\n",
    "    plt.title(title)\n",
    "    plt.colorbar(sm, label='Epoch')\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "2dd69578",
   "metadata": {},
   "outputs": [],
   "source": [
    "def repr_entropy(nonoise_reprs, reprs, ratio_points=GMM_MEANS_NUM):\n",
    "    dists = []\n",
    "    used_points = nonoise_reprs[:ratio_points]\n",
    "    for i in tqdm(range(len(used_points))):\n",
    "        dists.append(MultivariateGaussianDistribution(used_points[i], drp_noise*np.diag(abs(used_points[i]))))\n",
    "    gmm_Z = GeneralMixtureModel(dists, weights=np.full(len(dists), (1.0 / len(dists))))\n",
    "    log_probs = []\n",
    "    for i in tqdm(range(int(len(reprs)/10000))):\n",
    "        log_probs += gmm_Z.log_probability(reprs[i*10000:(i+1)*10000], n_jobs=10).tolist()\n",
    "    return (-1.0/len(log_probs))*np.array(log_probs).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "75bd211d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gaussian_noise_mi(reprs, nonoise_reprs, noise_variance):\n",
    "    f_dim = reprs.shape[1]\n",
    "    h_z = repr_entropy(nonoise_reprs, reprs)\n",
    "    # our innovative way to compute conditional entropy\n",
    "    h_zGivx = 0\n",
    "    h_z_part = math.sqrt(2*math.pi*math.e)*noise_variance\n",
    "    for i in range(f_dim):\n",
    "        # we have simple multiplication of 1-dim gaussian (noise) by a constant (current value of activations)\n",
    "        h_zGivx += (1.0/len(nonoise_reprs)) * np.sum(np.log(h_z_part * np.fabs(nonoise_reprs[:,i] + 1e-6)))\n",
    "    return h_z - h_zGivx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "04b5d82a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "no_noise_test_representations_0.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2158.61it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:35<00:00,  5.85s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6552.7256324171085 1.8349455338717808\n",
      "no_noise_test_representations_10.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2145.07it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:36<00:00,  6.15s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9897.009905030416 2.1450754855149476\n",
      "no_noise_test_representations_12.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1797.77it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.23s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9638.021771622296 2.16814411801173\n",
      "no_noise_test_representations_14.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2163.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:36<00:00,  6.14s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12645.10562281068 2.169160628440161\n",
      "no_noise_test_representations_16.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2179.35it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.24s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8510.551695229862 2.1569782666964343\n",
      "no_noise_test_representations_18.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1888.27it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9511.112347284174 2.148296844323239\n",
      "no_noise_test_representations_2.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2203.70it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.46s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9413.98508263498 2.0386409141403115\n",
      "no_noise_test_representations_20.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2114.06it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.31s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8530.485248268327 2.184753410001183\n",
      "no_noise_test_representations_22.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1793.67it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.21s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9043.432107388218 2.15717987446149\n",
      "no_noise_test_representations_24.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1746.82it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.52s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7826.270144763852 2.1661948166846834\n",
      "no_noise_test_representations_26.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2124.06it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.27s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7379.000885706164 2.1705070156048825\n",
      "no_noise_test_representations_28.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2124.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.52s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7618.429769351432 2.185627232355862\n",
      "no_noise_test_representations_30.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2152.37it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:36<00:00,  6.16s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8654.362369512137 2.185907113671954\n",
      "no_noise_test_representations_32.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2079.85it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:41<00:00,  6.95s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8217.480648950646 2.191298202677733\n",
      "no_noise_test_representations_34.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2071.21it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.18s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6879.732549734872 2.1903774482684426\n",
      "no_noise_test_representations_36.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1990.11it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8174.965738071465 2.1933536933868876\n",
      "no_noise_test_representations_38.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2142.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.46s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8512.797569410539 2.1925592619683325\n",
      "no_noise_test_representations_4.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2171.07it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.58s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9768.07601757588 2.0951339239698363\n",
      "no_noise_test_representations_40.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2153.00it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6062.517840101901 2.1965425797282796\n",
      "no_noise_test_representations_42.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2165.59it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.51s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7318.860208443209 2.2040500256356284\n",
      "no_noise_test_representations_44.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1985.85it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:42<00:00,  7.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6860.779247979853 2.193020346185131\n",
      "no_noise_test_representations_46.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2159.55it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.59s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6206.538736436517 2.2030625514271396\n",
      "no_noise_test_representations_48.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2174.98it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.38s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5599.302421966053 2.2229229454170776\n",
      "no_noise_test_representations_50.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2165.38it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.37s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6567.660512131639 2.1980080642669084\n",
      "no_noise_test_representations_52.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2156.68it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.38s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7543.477525836822 2.1949792710622287\n",
      "no_noise_test_representations_54.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2142.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.20s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6237.31268212148 2.210491158072064\n",
      "no_noise_test_representations_56.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2130.32it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6508.350330860947 2.2254862660573553\n",
      "no_noise_test_representations_58.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2148.58it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.40s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5835.53626840891 2.2287490918225403\n",
      "no_noise_test_representations_6.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2050.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.56s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8325.832122325957 2.0937912705806356\n",
      "no_noise_test_representations_60.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2102.03it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6121.773001697249 2.21743447626044\n",
      "no_noise_test_representations_62.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2100.53it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7232.594886986705 2.219318454801334\n",
      "no_noise_test_representations_64.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1836.41it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.21s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6223.264783431213 2.228174735451959\n",
      "no_noise_test_representations_66.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1881.13it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.55s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5952.674966050544 2.2274753413613544\n",
      "no_noise_test_representations_68.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2110.77it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.52s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5626.239685980283 2.2210771218066885\n",
      "no_noise_test_representations_70.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2069.16it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6031.045245148782 2.2233776857193988\n",
      "no_noise_test_representations_72.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2070.39it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.58s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5027.027919300329 2.224315905760861\n",
      "no_noise_test_representations_74.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2093.26it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6440.645186807353 2.232012098600411\n",
      "no_noise_test_representations_76.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2081.76it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.39s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6780.076489218579 2.2150559780984205\n",
      "no_noise_test_representations_78.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2096.14it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.18s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5733.592319281267 2.2188648956091974\n",
      "no_noise_test_representations_8.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2049.07it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9952.112427727232 2.1247172484588117\n",
      "no_noise_test_representations_80.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1846.07it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5678.862017821036 2.207293747697578\n",
      "no_noise_test_representations_82.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2097.64it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.20s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5569.653783780771 2.21045000780618\n",
      "no_noise_test_representations_84.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2089.75it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6024.521630025616 2.20827886953753\n",
      "no_noise_test_representations_86.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2080.20it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.30s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5707.749597756782 2.196758872640734\n",
      "no_noise_test_representations_88.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2050.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.62s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4766.34695179581 2.2103414374288164\n",
      "no_noise_test_representations_90.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2077.54it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:36<00:00,  6.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5761.293209649606 2.1960345531770633\n",
      "no_noise_test_representations_92.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2058.63it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.35s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5429.164857318686 2.2210791238932073\n",
      "no_noise_test_representations_94.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2029.24it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.61s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5118.811382299242 2.202248122245023\n",
      "no_noise_test_representations_96.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2067.36it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5510.475972596236 2.2221096902742663\n",
      "no_noise_test_representations_98.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1953.66it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.21s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6172.313670287455 2.1901725838503756\n",
      "no_noise_train_representations_0.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2045.16it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.32s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "765.2429357790527 1.7550380143575146\n",
      "no_noise_train_representations_10.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2035.77it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.53s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1302.219940251472 2.1115103079052924\n",
      "no_noise_train_representations_12.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2074.86it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1323.02248125623 2.1230508898938485\n",
      "no_noise_train_representations_14.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2056.33it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.51s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1387.319543501721 2.120746627180056\n",
      "no_noise_train_representations_16.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2054.82it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.18s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1212.5572792428043 2.1192156526232724\n",
      "no_noise_train_representations_18.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1840.71it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.57s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1253.3322427032394 2.131261045569559\n",
      "no_noise_train_representations_2.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2034.71it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.26s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1230.8148010924538 2.0137796002848534\n",
      "no_noise_train_representations_20.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2021.36it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.18s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1188.7705802018052 2.139155399932017\n",
      "no_noise_train_representations_22.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2063.91it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.53s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1207.9707651817619 2.129021013358318\n",
      "no_noise_train_representations_24.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2054.65it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:36<00:00,  6.16s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1147.9954763119956 2.156057658034439\n",
      "no_noise_train_representations_26.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2003.00it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.54s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1079.206570822371 2.152809286736298\n",
      "no_noise_train_representations_28.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1981.38it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.32s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1100.8969003581624 2.1657275120502746\n",
      "no_noise_train_representations_30.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1579.01it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.68s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1061.313472725125 2.166080649863046\n",
      "no_noise_train_representations_32.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2029.70it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.55s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1052.1720770142074 2.167136565402701\n",
      "no_noise_train_representations_34.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2051.34it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.25s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1007.1519448137717 2.169047278966043\n",
      "no_noise_train_representations_36.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1944.57it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.55s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "954.0170382092496 2.167226983030087\n",
      "no_noise_train_representations_38.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2029.69it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.74s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "936.1257573383166 2.1629016039541265\n",
      "no_noise_train_representations_4.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2016.37it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.81s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1275.7009620478707 2.052155883798908\n",
      "no_noise_train_representations_40.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2037.22it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "948.0626752884336 2.191921651463983\n",
      "no_noise_train_representations_42.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2007.08it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.28s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "970.9002495377649 2.1830271127864385\n",
      "no_noise_train_representations_44.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2051.97it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.24s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "947.478477290598 2.1941091994382096\n",
      "no_noise_train_representations_46.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1973.65it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.58s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "892.2021684574438 2.1922248891368614\n",
      "no_noise_train_representations_48.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2045.59it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:36<00:00,  6.15s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "891.5626155047307 2.1976885243987443\n",
      "no_noise_train_representations_50.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2042.10it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.27s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "871.9058796970835 2.189227286019139\n",
      "no_noise_train_representations_52.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2028.50it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.50s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "908.3925051581372 2.18925293793273\n",
      "no_noise_train_representations_54.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1950.72it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.26s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "884.5214931018435 2.196576628484499\n",
      "no_noise_train_representations_56.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1888.14it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.63s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "864.389622055704 2.2015345499997374\n",
      "no_noise_train_representations_58.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2049.61it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.18s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "855.6515503037344 2.2142779031751383\n",
      "no_noise_train_representations_6.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2051.17it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.32s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1354.879773086965 2.0754239928804195\n",
      "no_noise_train_representations_60.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2076.03it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "845.2095087481057 2.203043192261557\n",
      "no_noise_train_representations_62.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2041.94it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.64s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "830.7770401300015 2.200743516594155\n",
      "no_noise_train_representations_64.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1867.09it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.32s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "855.447728016167 2.2146868336637993\n",
      "no_noise_train_representations_66.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2015.31it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.67s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "828.0048102933472 2.2112257981633285\n",
      "no_noise_train_representations_68.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2033.81it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.67s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "797.5256748777275 2.2141674409409298\n",
      "no_noise_train_representations_70.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2021.36it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.43s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "805.1338629066474 2.2107933161821944\n",
      "no_noise_train_representations_72.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1991.91it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.63s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "775.0990639540632 2.201197569025572\n",
      "no_noise_train_representations_74.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2008.76it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.17s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "779.2774829511388 2.208912755870914\n",
      "no_noise_train_representations_76.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1972.47it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.71s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "780.819348820806 2.198990241138283\n",
      "no_noise_train_representations_78.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1802.83it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.32s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "768.7934209808061 2.209910600633745\n",
      "no_noise_train_representations_8.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1969.35it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.67s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1337.6835634697286 2.095985239997198\n",
      "no_noise_train_representations_80.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2001.79it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.35s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "746.9662139280085 2.198313036960756\n",
      "no_noise_train_representations_82.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1874.14it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.68s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "770.8340350678709 2.2020564896922004\n",
      "no_noise_train_representations_84.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1833.03it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.64s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "800.3951301944428 2.198312948769389\n",
      "no_noise_train_representations_86.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1985.48it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.32s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "746.9101594799245 2.192546495296737\n",
      "no_noise_train_representations_88.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1958.24it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:40<00:00,  6.69s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "739.0691061393946 2.1997486881764448\n",
      "no_noise_train_representations_90.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1869.18it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:38<00:00,  6.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "756.3214494966719 2.1945588418119666\n",
      "no_noise_train_representations_92.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2003.76it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.32s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "745.729678569222 2.205826565260473\n",
      "no_noise_train_representations_94.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2010.96it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.55s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "744.150127799356 2.190154074791944\n",
      "no_noise_train_representations_96.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2003.26it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:37<00:00,  6.33s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "734.1606463497576 2.207857556091989\n",
      "no_noise_train_representations_98.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2009.36it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:39<00:00,  6.64s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "738.4712043075414 2.18191994935051\n"
     ]
    }
   ],
   "source": [
    "if os.path.exists(os.path.join(repr_data, \"test_comp_mi_xz\")):\n",
    "    test_comp_mi_xz = pickle.load(open(os.path.join(repr_data, \"test_comp_mi_xz\"), \"rb\"))\n",
    "    test_comp_mi_zy = pickle.load(open(os.path.join(repr_data, \"test_comp_mi_zy\"), \"rb\"))\n",
    "    train_comp_mi_xz = pickle.load(open(os.path.join(repr_data, \"train_comp_mi_xz\"), \"rb\"))\n",
    "    train_comp_mi_zy = pickle.load(open(os.path.join(repr_data, \"train_comp_mi_zy\"), \"rb\"))\n",
    "else:    \n",
    "    test_comp_mi_xz = {}\n",
    "    test_comp_mi_zy = {}\n",
    "    train_comp_mi_xz = {}\n",
    "    train_comp_mi_zy = {}\n",
    "\n",
    "    test_labels = np.load(os.path.join(repr_data, \"test_labels.npy\"))\n",
    "    train_labels = np.load(os.path.join(repr_data, \"train_labels.npy\"))\n",
    "    \n",
    "    train_repeat = 1\n",
    "    test_repeat = 6\n",
    "\n",
    "    for f in os.listdir(repr_data):\n",
    "        if \"test_representations\" in f:\n",
    "            print(f)\n",
    "            epoch = int(f.split(\".\")[0].split(\"_\")[-1])\n",
    "            nonoise_reprs = np.load(os.path.join(repr_data, f), allow_pickle=True)\n",
    "            nonoise_reprs, ind = np.unique(nonoise_reprs, axis=0, return_index=True)\n",
    "            reprs = []\n",
    "            for nr in nonoise_reprs:\n",
    "                for i in range(test_repeat):\n",
    "                    epsilon = np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1\n",
    "                    reprs.append(nr*epsilon)\n",
    "            reprs = np.array(reprs)\n",
    "            test_comp_mi_xz[epoch] = gaussian_noise_mi(reprs, nonoise_reprs, drp_noise)\n",
    "            test_comp_mi_zy[epoch] = EDGE(reprs, np.repeat(np.array(test_labels[ind]), test_repeat))\n",
    "            print(test_comp_mi_xz[epoch], test_comp_mi_zy[epoch])\n",
    "\n",
    "        if \"train_representations\" in f:\n",
    "            print(f)\n",
    "            epoch = int(f.split(\".\")[0].split(\"_\")[-1])\n",
    "            nonoise_reprs = np.load(os.path.join(repr_data, f), allow_pickle=True)\n",
    "            reprs = []\n",
    "            for nr in nonoise_reprs:\n",
    "                for i in range(train_repeat):\n",
    "                    epsilon = np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1\n",
    "                    reprs.append(nr*epsilon)\n",
    "            reprs = np.array(reprs)\n",
    "            train_comp_mi_xz[epoch] = gaussian_noise_mi(reprs, nonoise_reprs, drp_noise)\n",
    "            train_comp_mi_zy[epoch] = EDGE(reprs, np.repeat(train_labels, train_repeat))\n",
    "            print(train_comp_mi_xz[epoch], train_comp_mi_zy[epoch])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "365dfdf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists(os.path.join(repr_data, \"test_comp_mi_xz\")):\n",
    "    pickle.dump(test_comp_mi_xz, open(os.path.join(repr_data, \"test_comp_mi_xz\"), \"wb\"))\n",
    "    pickle.dump(test_comp_mi_zy, open(os.path.join(repr_data, \"test_comp_mi_zy\"), \"wb\"))\n",
    "    pickle.dump(train_comp_mi_xz, open(os.path.join(repr_data, \"train_comp_mi_xz\"), \"wb\"))\n",
    "    pickle.dump(train_comp_mi_zy, open(os.path.join(repr_data, \"train_comp_mi_zy\"), \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "c6a9f42b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_19284\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "od_test_comp_mi_xz = OrderedDict(sorted(test_comp_mi_xz.items()))\n",
    "od_test_comp_mi_zy = OrderedDict(sorted(test_comp_mi_zy.items()))\n",
    "drawIP(od_test_comp_mi_xz, od_test_comp_mi_zy, 'Gaussian dropout, '+netw+' (validation)', crossentropy_zy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "d70fa71a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_19284\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "drawIP(od_test_comp_mi_xz, val_mi_zy, 'Gaussian dropout, '+netw+' (validation)', crossentropy_zy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "e3b5975b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_19284\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "od_train_comp_mi_xz = OrderedDict(sorted(train_comp_mi_xz.items()))\n",
    "od_train_comp_mi_zy = OrderedDict(sorted(train_comp_mi_zy.items()))\n",
    "drawIP(od_train_comp_mi_xz, od_train_comp_mi_zy, 'Gaussian dropout, '+netw+' (training)', crossentropy_zy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "205cbd8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_19284\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "drawIP(od_train_comp_mi_xz, train_mi_zy, 'Gaussian dropout, '+netw+' (training)', crossentropy_zy=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4689f1a",
   "metadata": {},
   "source": [
    "### Binning IP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "d6d40e35",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_bins(min_bound, max_bound, num_of_bins=None, bin_size=None):\n",
    "    if bin_size is not None:\n",
    "        bins = np.arange(min_bound, max_bound, bin_size, dtype='float32')\n",
    "    elif num_of_bins is not None:\n",
    "        bins = np.linspace(min_bound, max_bound, num_of_bins, dtype='float32')\n",
    "    else:\n",
    "        print(\"Computation error; set either bin size or number of bins to a value\")\n",
    "        return None\n",
    "    return bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "fd9b9b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def double_bin_calc_information(inputdata, layerdata, num_of_bins=None, bin_size=None):\n",
    "    bins_inp = create_bins(inputdata.min(), inputdata.max(), num_of_bins=num_of_bins, bin_size=bin_size)\n",
    "    digitized_inp = bins_inp[np.digitize(np.squeeze(inputdata.reshape(1, -1)), bins_inp) - 1].reshape(len(inputdata), -1)\n",
    "\n",
    "    bins_rep = create_bins(layerdata.min(), layerdata.max(), num_of_bins=num_of_bins, bin_size=bin_size)\n",
    "    digitized_rep = bins_rep[np.digitize(np.squeeze(layerdata.reshape(1, -1)), bins_rep) - 1].reshape(len(layerdata), -1)\n",
    "\n",
    "    return midd(digitized_inp, digitized_rep)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "847efa11",
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(os.path.join(repr_data, \"test_bin_mi_xz\")):\n",
    "    test_bin_mi_xz = pickle.load(open(os.path.join(repr_data, \"test_bin_mi_xz\"), \"rb\"))\n",
    "    test_bin_mi_zy = pickle.load(open(os.path.join(repr_data, \"test_bin_mi_zy\"), \"rb\"))\n",
    "    train_bin_mi_xz = pickle.load(open(os.path.join(repr_data, \"train_bin_mi_xz\"), \"rb\"))\n",
    "    train_bin_mi_zy = pickle.load(open(os.path.join(repr_data, \"train_bin_mi_zy\"), \"rb\"))\n",
    "else:\n",
    "    test_bin_mi_xz = {}\n",
    "    test_bin_mi_zy = {}\n",
    "    train_bin_mi_xz = {}\n",
    "    train_bin_mi_zy = {}\n",
    "\n",
    "    test_inputs = np.load(os.path.join(repr_data, \"test_inputs.npy\"))\n",
    "    test_inputs = test_inputs.reshape(test_inputs.shape[0], -1)\n",
    "    test_labels = np.load(os.path.join(repr_data, \"test_labels.npy\"))\n",
    "    train_inputs = np.load(os.path.join(repr_data, \"train_inputs.npy\"))\n",
    "    train_inputs = train_inputs.reshape(train_inputs.shape[0], -1)\n",
    "    train_labels = np.load(os.path.join(repr_data, \"train_labels.npy\"))\n",
    "    \n",
    "    train_repeat = 1\n",
    "    test_repeat = 6\n",
    "\n",
    "    for f in os.listdir(repr_data):\n",
    "        if \"test_representations\" in f:\n",
    "            print(f)\n",
    "            epoch = int(f.split(\".\")[0].split(\"_\")[-1])            \n",
    "            nonoise_reprs = np.load(os.path.join(repr_data, f))\n",
    "            nonoise_reprs, ind = np.unique(nonoise_reprs, axis=0, return_index=True)\n",
    "            reprs = []\n",
    "            for nr in nonoise_reprs:\n",
    "                for i in range(test_repeat):\n",
    "                    epsilon = np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1\n",
    "                    reprs.append(nr*epsilon)\n",
    "            reprs = np.array(reprs)\n",
    "            test_bin_mi_xz[epoch] = double_bin_calc_information(np.repeat(np.array(test_inputs[ind]), test_repeat, axis=0), \n",
    "                                                                reprs, num_of_bins=3)\n",
    "            test_bin_mi_zy[epoch] = double_bin_calc_information(np.repeat(np.array(test_labels[ind]), test_repeat), \n",
    "                                                                reprs, num_of_bins=3)\n",
    "            print(test_bin_mi_xz[epoch], test_bin_mi_zy[epoch])\n",
    "\n",
    "        if \"train_representations\" in f:\n",
    "            print(f)\n",
    "            epoch = int(f.split(\".\")[0].split(\"_\")[-1])\n",
    "            nonoise_reprs = np.load(os.path.join(repr_data, f))\n",
    "            reprs = []\n",
    "            for nr in nonoise_reprs:\n",
    "                for i in range(train_repeat):\n",
    "                    epsilon = np.random.randn(nonoise_reprs.shape[1]) * drp_noise + 1\n",
    "                    reprs.append(nr*epsilon)\n",
    "            reprs = np.array(reprs)\n",
    "            train_bin_mi_xz[epoch] = double_bin_calc_information(np.repeat(np.array(train_inputs), train_repeat, axis=0), \n",
    "                                                                 reprs, num_of_bins=3)\n",
    "            train_bin_mi_zy[epoch] = double_bin_calc_information(np.repeat(np.array(train_labels), train_repeat, axis=0), \n",
    "                                                                 reprs, num_of_bins=3)\n",
    "            print(train_bin_mi_xz[epoch], train_bin_mi_zy[epoch])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "adc4da59",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists(os.path.join(repr_data, \"test_bin_mi_xz\")):\n",
    "    pickle.dump(test_bin_mi_xz, open(os.path.join(repr_data, \"test_bin_mi_xz\"), \"wb\"))\n",
    "    pickle.dump(test_bin_mi_zy, open(os.path.join(repr_data, \"test_bin_mi_zy\"), \"wb\"))\n",
    "    pickle.dump(train_bin_mi_xz, open(os.path.join(repr_data, \"train_bin_mi_xz\"), \"wb\"))\n",
    "    pickle.dump(train_bin_mi_zy, open(os.path.join(repr_data, \"train_bin_mi_zy\"), \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "1a37585a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_19284\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "od_test_bin_mi_xz = OrderedDict(sorted(test_bin_mi_xz.items()))\n",
    "od_test_bin_mi_zy = OrderedDict(sorted(test_bin_mi_zy.items()))\n",
    "drawIP(od_test_bin_mi_xz, od_test_bin_mi_zy, 'Gaussian dropout + binning, '+netw+' (validation)', crossentropy_zy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d0218f6e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\adylo\\AppData\\Local\\Temp\\ipykernel_19284\\1700152655.py:25: MatplotlibDeprecationWarning: Auto-removal of grids by pcolor() and pcolormesh() is deprecated since 3.5 and will be removed two minor releases later; please call grid(False) first.\n",
      "  plt.colorbar(sm, label='Epoch')\n"
     ]
    }
   ],
   "source": [
    "od_train_bin_mi_xz = OrderedDict(sorted(train_bin_mi_xz.items()))\n",
    "od_train_bin_mi_zy = OrderedDict(sorted(train_bin_mi_zy.items()))\n",
    "drawIP(od_train_bin_mi_xz, od_train_bin_mi_zy, 'Gaussian dropout + binning, '+netw+' (training)', crossentropy_zy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c636ecea",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
