{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9f2382f0-45ed-4bd2-a5f5-5fd0ac22ffa4",
   "metadata": {},
   "source": [
    "# Information plane experiments (DNN classifier, CIFAR10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "930c0071-8e91-4dd8-8770-f0410a4abc9e",
   "metadata": {},
   "source": [
    "## Preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66cfe11a-c587-4a03-9b39-ece161a4eec6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9474809c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d17453c",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "#device = \"cpu\"\n",
    "print(\"Device: \" + device)\n",
    "print(f\"Devices count: {torch.cuda.device_count()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7ff6767-2f72-4d40-ac04-13d9d7805b5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "path = Path(\"../../data/\").resolve()\n",
    "experiments_path = path / \"mutual_information/CIFAR10/\"\n",
    "models_path = experiments_path / \"models/\"\n",
    "results_path = experiments_path / \"resuts/\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dca8d6a-e730-477c-a685-36c5f605657a",
   "metadata": {},
   "source": [
    "### Global settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cbab7db-8246-407f-b49d-b567f5c2f5b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Autoencoder for inputs.\n",
    "X_latent_dim = 10             # Input dimension after compression.\n",
    "X_autoencoder_n_epochs = 1500 # Number of epochs to train the autoencoder.\n",
    "load_X_autoencoder = True     # Reload weights of the autoencoder.\n",
    "\n",
    "# Autoencoder for layers.\n",
    "L_latent_dim = 4              # Layer dimension after compression.\n",
    "L_autoencoder_n_epochs = 100  # Number of epochs to train the autoencoder.\n",
    "\n",
    "# Classifier.\n",
    "classifier_lr = 1e-4      # Classifier learning rate.\n",
    "classifier_n_epochs = 50 # Number of epochs to train the classifier.\n",
    "sigma = 1e-3              # Noise-to-signal ratio."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3d13b8b-47c5-4717-a88b-486a7774dfac",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92085d19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.datasets import CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d2b9cef-60aa-4908-a6f3-6d9159ae3575",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c302959",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = CIFAR10(root=\"./.cache\", download=True, transform=image_transform)\n",
    "test_dataset = CIFAR10(root=\"./.cache\", download=True, transform=image_transform, train=False)\n",
    "eval_dataset = CIFAR10(root=\"./.cache\", download=True, transform=image_transform, train=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fadba18e-0127-4a0c-8b85-44210fb87beb",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size_train = 1024\n",
    "batch_size_test  = 2048"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57289577",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)\n",
    "test_dataloader  = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False)\n",
    "eval_dataloader  = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size_train, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e12c498e-aef3-4e76-9dbc-c6e3fa0a9286",
   "metadata": {},
   "source": [
    "### Visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfd26651-06de-47da-baea-2fa2ab29b923",
   "metadata": {},
   "outputs": [],
   "source": [
    "from misc.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8209603-91f0-4901-a14b-bddbe597d487",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_images(*split_lists([(train_dataset[index][0], f\"label: {train_dataset[index][1]}\") for index in range(6)]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcbb8830-49c2-4b4b-a260-8ea89dbed238",
   "metadata": {},
   "source": [
    "## Autoencoder for inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8467ca1a-fdc5-46de-b445-b09654dc53d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mutinfo.torch.datasets import AutoencoderDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e878632-32ab-4aca-b0fe-a44de3a15b3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset_autoencoder = AutoencoderDataset(train_dataset)\n",
    "test_dataset_autoencoder = AutoencoderDataset(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdc2a5b6-eef6-4c64-9102-f1dc44d0f017",
   "metadata": {},
   "outputs": [],
   "source": [
    "autoencoder_batch_size_train = 1024\n",
    "autoencoder_batch_size_test  = 2048"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f53fd8f-b2ed-42d4-9758-af7ed2de56dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader_autoencoder = torch.utils.data.DataLoader(train_dataset_autoencoder, batch_size=autoencoder_batch_size_train, shuffle=True)\n",
    "test_dataloader_autoencoder  = torch.utils.data.DataLoader(test_dataset_autoencoder, batch_size=autoencoder_batch_size_test, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "669fde90-9c9e-4e60-b6c5-e54be14cd320",
   "metadata": {},
   "outputs": [],
   "source": [
    "from misc.autoencoder import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2275175-a09d-42a8-b5a6-e12c76b7e647",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from IPython.display import clear_output\n",
    "\n",
    "def autoencoder_callback(autoencoder, autoencoder_metrics=None):\n",
    "    \"\"\"\n",
    "    Redraw the progress of an autoencoder in the cell output.\n",
    "\n",
    "    Clears the output, shows three randomly chosen test images together with\n",
    "    their reconstructions and, if histories are given, one plot per metric.\n",
    "    The train/eval mode the model had on entry is restored before returning.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    autoencoder\n",
    "        Model to visualise; called on a single-image batch moved to `device`.\n",
    "    autoencoder_metrics : dict, optional\n",
    "        Maps a metric name to its history (a sequence of values).\n",
    "    \"\"\"\n",
    "\n",
    "    clear_output(True)\n",
    "\n",
    "    # Remember the current mode so that it can be restored afterwards.\n",
    "    previous_mode = autoencoder.training\n",
    "    autoencoder.eval()\n",
    "\n",
    "    # Original images first, their reconstructions after them.\n",
    "    with torch.no_grad():\n",
    "        originals = [item[0] for item in random.choices(test_dataset_autoencoder, k=3)]\n",
    "        reconstructions = []\n",
    "        for image in originals:\n",
    "            batch = image[None,:].to(device)\n",
    "            reconstructions.append(autoencoder(batch).cpu().detach()[0])\n",
    "        show_images(originals + reconstructions)\n",
    "\n",
    "    # One subplot per metric, ordered by metric name.\n",
    "    if autoencoder_metrics is not None:\n",
    "        n_plots = len(autoencoder_metrics)\n",
    "        plt.figure(figsize=(12,4))\n",
    "        for position, name in enumerate(sorted(autoencoder_metrics), start=1):\n",
    "            history = autoencoder_metrics[name]\n",
    "            plt.subplot(1, n_plots, position)\n",
    "            plt.title(name)\n",
    "            plt.plot(range(1, len(history) + 1), history)\n",
    "            plt.grid()\n",
    "\n",
    "        plt.show();\n",
    "\n",
    "    autoencoder.train(previous_mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55a5c4d4-22a3-4fc5-8704-eafe1ae62f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_autoencoder = Autoencoder(CIFAR10_ConvEncoder(latent_dim=X_latent_dim), CIFAR10_ConvDecoder(latent_dim=X_latent_dim)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c431489-4b38-4c78-8ff1-13dd1dea372f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Weights are stored per latent dimension, so runs with different X_latent_dim do not overwrite each other.\n",
    "X_autoencoder_path = models_path / \"autoencoders/\"\n",
    "encoder_path = X_autoencoder_path / f\"X_encoder_{X_latent_dim}.pt\"\n",
    "decoder_path = X_autoencoder_path / f\"X_decoder_{X_latent_dim}.pt\"\n",
    "\n",
    "if load_X_autoencoder:\n",
    "    try:\n",
    "        # map_location allows weights saved on a GPU to be restored on a CPU-only machine.\n",
    "        X_autoencoder.encoder.load_state_dict(torch.load(encoder_path, map_location=device))\n",
    "        X_autoencoder.decoder.load_state_dict(torch.load(decoder_path, map_location=device))\n",
    "    except Exception as error:\n",
    "        # Catch Exception instead of using a bare except, so that KeyboardInterrupt\n",
    "        # still stops the cell, and report why loading failed.\n",
    "        print(\"The autoencoder is not found or cannot be loaded.\")\n",
    "        print(f\"Reason: {error!r}\")\n",
    "        # Fall back to training from scratch below.\n",
    "        load_X_autoencoder = False\n",
    "    else:\n",
    "        # Show reconstructions only after a successful load: a plotting failure\n",
    "        # must not trigger a retraining that overwrites the saved weights.\n",
    "        autoencoder_callback(X_autoencoder)\n",
    "\n",
    "if not load_X_autoencoder:\n",
    "    results = train_autoencoder(X_autoencoder, train_dataloader_autoencoder, test_dataloader_autoencoder, torch.nn.L1Loss(),\n",
    "                                device, n_epochs=X_autoencoder_n_epochs, callback=autoencoder_callback, lr=1e-3)\n",
    "\n",
    "    # Save the encoder and the decoder separately so that each can be reloaded on its own.\n",
    "    os.makedirs(X_autoencoder_path, exist_ok=True)\n",
    "    torch.save(X_autoencoder.encoder.state_dict(), encoder_path)\n",
    "    torch.save(X_autoencoder.decoder.state_dict(), decoder_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69fac603-5414-4ec6-bc79-89c388939e33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch off the `agn` module of the autoencoder for inference\n",
    "# (presumably additive Gaussian noise; confirm in misc.autoencoder).\n",
    "X_autoencoder.agn.enabled_on_inference = False\n",
    "\n",
    "# The compressed inputs must be row-aligned with the layer outputs and targets that\n",
    "# train_classifier() collects from eval_dataloader, so the same unshuffled loader is used.\n",
    "# train_dataloader has shuffle=True and would pair every layer output with a random input.\n",
    "X_compressed = get_outputs(X_autoencoder.encoder, eval_dataloader, device).numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a361159-c7a6-4543-8b28-83e36ff729c3",
   "metadata": {},
   "source": [
    "## Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b471984e-55bf-479d-8154-ad4d253f2161",
   "metadata": {},
   "source": [
    "### Filter for plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d23090-9bd3-424c-8b45-05cd48552958",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.signal import butter, filtfilt, savgol_filter\n",
    "from misc.nonuniform_savgol_filter import *\n",
    "\n",
    "def filter_data(x: np.ndarray, errorbars: bool=True) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Smooth a series with a Savitzky-Golay filter followed by a forward-backward\n",
    "    Butterworth low-pass filter.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array_like\n",
    "        Input data. If `errorbars` is True, a sequence of items whose first\n",
    "        element is the value; the remaining elements are ignored.\n",
    "    errorbars : bool\n",
    "        Whether the items of `x` carry errorbars next to the value.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        Filtered data of the same length as the input.\n",
    "    \"\"\"\n",
    "\n",
    "    if errorbars:\n",
    "        # Keep the values only and drop the errorbars.\n",
    "        x = np.array([item[0] for item in x])\n",
    "    else:\n",
    "        # np.array is a function, not a type, so the former `type(x) is not np.array`\n",
    "        # check was always true; np.asarray converts only when it is needed.\n",
    "        x = np.asarray(x)\n",
    "\n",
    "    # Savitzky-Golay filter; the window and the order shrink for short series.\n",
    "    window_length = min(10, len(x))\n",
    "    polyorder = min(4, window_length-1)\n",
    "\n",
    "    y = savgol_filter(x, window_length, polyorder)\n",
    "\n",
    "    # scipy.signal.filtfilt with an 8th-order Butterworth filter;\n",
    "    # padlen has to stay below the length of the series.\n",
    "    b, a = butter(8, 0.125)\n",
    "    padlen = min(5, len(x)-1)\n",
    "\n",
    "    y = filtfilt(b, a, y, padlen=padlen)\n",
    "\n",
    "    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14514a2d-c57b-4f89-8d20-51899645cd3a",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad3a5971-2367-4010-b416-f967a9c6aae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from misc.classifier import *\n",
    "from tqdm import tqdm, trange\n",
    "from sklearn.decomposition import PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9deec626-9d22-495c-89d5-786b2b5d694f",
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier = CIFAR10_Classifier(sigma=sigma).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5e9831c-8734-4a63-b8b3-3446896fdc64",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mutinfo.estimators.mutual_information as mi_estimators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c86c49-449a-4148-b50e-d967c8559530",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training options.\n",
    "classifier_loss = torch.nn.NLLLoss()\n",
    "classifier_opt = torch.optim.Adam(classifier.parameters(), lr=classifier_lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eb45bcf-0cc0-464d-9dd3-db09b1806d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mutual information estimator options.\n",
    "\n",
    "entropy_estimator_params = \\\n",
    "{\n",
    "    'method': \"KL\",\n",
    "    'functional_params': {'n_jobs': 16, \"k_neighbours\": 5}\n",
    "}\n",
    "\n",
    "compression = 'pca' # 'autoencoders', 'first_coords'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f872b3b3-cdf3-4a55-9f6a-4057cc4695bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "def train_classifier(classifier, classifier_loss, classifier_opt,\n",
    "                     train_dataloader, test_dataloader, eval_dataloader,\n",
    "                     X_compressed, entropy_estimator_params,\n",
    "                     compression='pca', n_epochs: int=10,\n",
    "                     filter_data: callable=None):\n",
    "    \"\"\"\n",
    "    Train the classifier and estimate the mutual information of its layers after every epoch.\n",
    "\n",
    "    Each epoch consists of one pass over `train_dataloader`, an evaluation of loss and\n",
    "    ROC AUC on the train and test dataloaders, and an estimation of I(X;L) and I(L;Y)\n",
    "    for every layer returned by `get_layers` on `eval_dataloader`. The metric plots and\n",
    "    the information plane are redrawn at the end of the epoch.\n",
    "\n",
    "    Besides its arguments the function reads the notebook globals `device`,\n",
    "    `L_latent_dim`, `L_autoencoder_n_epochs`, `batch_size_train` and `batch_size_test`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    classifier\n",
    "        Model to train.\n",
    "    classifier_loss\n",
    "        Loss called as `classifier_loss(predictions, targets)`.\n",
    "    classifier_opt\n",
    "        Optimizer stepped once per training batch.\n",
    "    train_dataloader\n",
    "        Batches used for training and for the train metrics.\n",
    "    test_dataloader\n",
    "        Batches used for the test metrics.\n",
    "    eval_dataloader\n",
    "        Batches used to collect the layer outputs; its dataset must expose `targets`.\n",
    "        The sample order has to match the rows of `X_compressed`.\n",
    "    X_compressed : np.ndarray\n",
    "        Compressed inputs, one row per sample of `eval_dataloader`.\n",
    "    entropy_estimator_params : dict\n",
    "        Passed to `mi_estimators.MutualInfoEstimator`.\n",
    "    compression : str\n",
    "        How the layer outputs are compressed: 'first_coords', 'pca' or 'autoencoders'.\n",
    "    n_epochs : int\n",
    "        Number of training epochs.\n",
    "    filter_data : callable, optional\n",
    "        Smoothing function applied to every MI history before plotting.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Keys \"metrics\", \"MI_X_L\", \"MI_L_Y\", \"filtered_MI_X_L\" and \"filtered_MI_L_Y\".\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `compression` is not one of the supported modes.\n",
    "    \"\"\"\n",
    "\n",
    "    # Fail before any training is done instead of silently reusing a stale L_compressed.\n",
    "    supported_compressions = ('first_coords', 'pca', 'autoencoders')\n",
    "    if compression not in supported_compressions:\n",
    "        raise ValueError(f\"Unknown compression {compression!r}; expected one of {supported_compressions}.\")\n",
    "\n",
    "    classifier_metrics = {\n",
    "        \"train_loss\" : [],\n",
    "        \"test_loss\" : [],\n",
    "        \"train_roc_auc\" : [],\n",
    "        \"test_roc_auc\" : []\n",
    "    }\n",
    "\n",
    "    # Autoencoders (one per layer, reused and trained further on every epoch).\n",
    "    L_autoencoders = dict()\n",
    "\n",
    "    # Mutual information (one history per layer).\n",
    "    MI_X_L = defaultdict(list)\n",
    "    MI_L_Y = defaultdict(list)\n",
    "    filtered_MI_X_L = None\n",
    "    filtered_MI_L_Y = None\n",
    "\n",
    "    # Targets.\n",
    "    targets = np.array(eval_dataloader.dataset.targets)\n",
    "\n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "        # Training step.\n",
    "        print(f\"Epoch №{epoch}\")\n",
    "        for x, y in tqdm(train_dataloader):\n",
    "            classifier_opt.zero_grad()\n",
    "            y_pred = classifier(x.to(device))\n",
    "            _loss = classifier_loss(y_pred, y.to(device))\n",
    "            _loss.backward()\n",
    "            classifier_opt.step()\n",
    "\n",
    "        # Metrics.\n",
    "        print(\"Calculating metrics\")\n",
    "        train_loss, train_roc_auc = evaluate_classifier(classifier, train_dataloader, classifier_loss, device)\n",
    "        classifier_metrics[\"train_loss\"].append(train_loss)\n",
    "        classifier_metrics[\"train_roc_auc\"].append(train_roc_auc)\n",
    "\n",
    "        test_loss, test_roc_auc = evaluate_classifier(classifier, test_dataloader, classifier_loss, device)\n",
    "        classifier_metrics[\"test_loss\"].append(test_loss)\n",
    "        classifier_metrics[\"test_roc_auc\"].append(test_roc_auc)\n",
    "\n",
    "        # Layers.\n",
    "        print(\"Aquiring outputs of the layers\")\n",
    "        eval_outputs = get_layers(classifier, eval_dataloader, device)\n",
    "        if compression == 'autoencoders':\n",
    "            # Only the layer autoencoders need the train/test outputs, so the two\n",
    "            # extra passes are skipped for the other compression modes.\n",
    "            train_outputs = get_layers(classifier, train_dataloader, device)\n",
    "            test_outputs = get_layers(classifier, test_dataloader, device)\n",
    "\n",
    "        # Mutual information.\n",
    "        for layer_name in eval_outputs.keys():\n",
    "            # Integer division: the value is used as a slice bound and as the number of\n",
    "            # PCA components, both of which require an int.\n",
    "            layer_dim = torch.numel(eval_outputs[layer_name]) // eval_outputs[layer_name].shape[0]\n",
    "            this_L_latent_dim = min(L_latent_dim, layer_dim)\n",
    "\n",
    "            if compression == 'first_coords':\n",
    "                L_compressed = eval_outputs[layer_name].numpy()\n",
    "                L_compressed = np.reshape(L_compressed, (L_compressed.shape[0], -1))\n",
    "                L_compressed = L_compressed[:,:this_L_latent_dim]\n",
    "\n",
    "            elif compression == 'pca':\n",
    "                L_compressed = eval_outputs[layer_name].numpy()\n",
    "                L_compressed = np.reshape(L_compressed, (L_compressed.shape[0], -1))\n",
    "                L_compressed = PCA(n_components=this_L_latent_dim).fit_transform(L_compressed)\n",
    "\n",
    "            elif compression == 'autoencoders':\n",
    "                print(f\"Training an autoencoder for the layer {layer_name}\")\n",
    "                # Datasets.\n",
    "                train_layer = train_outputs[layer_name]\n",
    "                test_layer  = test_outputs[layer_name]\n",
    "                eval_layer  = eval_outputs[layer_name]\n",
    "\n",
    "                L_train_dataset = torch.utils.data.TensorDataset(train_layer, train_layer)\n",
    "                L_test_dataset  = torch.utils.data.TensorDataset(test_layer, test_layer)\n",
    "                L_eval_dataset  = torch.utils.data.TensorDataset(eval_layer, eval_layer)\n",
    "\n",
    "                L_train_dataloader = torch.utils.data.DataLoader(L_train_dataset, batch_size=batch_size_train,\n",
    "                                                                 shuffle=True)\n",
    "                L_test_dataloader  = torch.utils.data.DataLoader(L_test_dataset, batch_size=batch_size_test,\n",
    "                                                                 shuffle=False)\n",
    "                L_eval_dataloader  = torch.utils.data.DataLoader(L_eval_dataset, batch_size=batch_size_test,\n",
    "                                                                 shuffle=False)\n",
    "\n",
    "                # Autoencoder.\n",
    "                if layer_name in L_autoencoders.keys():\n",
    "                    L_autoencoder = L_autoencoders[layer_name]\n",
    "                else:\n",
    "                    print(f\"Could not find an autoencoder for the layer {layer_name}.\")\n",
    "                    L_dim = train_layer.shape[1]\n",
    "                    L_autoencoder = Autoencoder(DenseEncoder(input_dim=L_dim, latent_dim=this_L_latent_dim),\n",
    "                                                DenseDecoder(latent_dim=this_L_latent_dim, output_dim=L_dim)).to(device)\n",
    "\n",
    "                # Training.\n",
    "                L_results = train_autoencoder(L_autoencoder, L_train_dataloader, L_test_dataloader, torch.nn.MSELoss(),\n",
    "                                    device, n_epochs=L_autoencoder_n_epochs)\n",
    "                L_autoencoders[layer_name] = L_autoencoder\n",
    "\n",
    "                # PCA reconstruction error serves as a baseline for the autoencoder.\n",
    "                _baseline_PCA = PCA(n_components=this_L_latent_dim).fit(np.reshape(train_layer, (train_layer.shape[0], -1)))\n",
    "                _baseline_layer = _baseline_PCA.inverse_transform(_baseline_PCA.transform(test_layer))\n",
    "                baseline_loss = float(torch.nn.functional.mse_loss(test_layer, torch.tensor(_baseline_layer)))\n",
    "\n",
    "                print(f\"Train loss: {L_results['train_loss'][-1]:.2e}; test loss: {L_results['test_loss'][-1]:.2e}\")\n",
    "                print(f\"Better then PCA: {baseline_loss:.2e} / {L_results['test_loss'][-1]:.2e} = {baseline_loss / L_results['test_loss'][-1]:.2f}\")\n",
    "\n",
    "                L_compressed = get_outputs(L_autoencoder.encoder, L_eval_dataloader, device).numpy()\n",
    "\n",
    "            print(f\"Estimating MI for the layer {layer_name}\")\n",
    "            # (X,L)\n",
    "            print(\"I(X;L)\")\n",
    "            X_L_mi_estimator = mi_estimators.MutualInfoEstimator(entropy_estimator_params=entropy_estimator_params)\n",
    "            X_L_mi_estimator.fit(X_compressed, L_compressed, verbose=0)\n",
    "            MI_X_L[layer_name].append(X_L_mi_estimator.estimate(X_compressed, L_compressed, verbose=0))\n",
    "\n",
    "            # (L,Y)\n",
    "            print(\"I(L;Y)\")\n",
    "            L_Y_mi_estimator = mi_estimators.MutualInfoEstimator(Y_is_discrete=True,\n",
    "                                                                 entropy_estimator_params=entropy_estimator_params)\n",
    "            L_Y_mi_estimator.fit(L_compressed, targets, verbose=0)\n",
    "            MI_L_Y[layer_name].append(L_Y_mi_estimator.estimate(L_compressed, targets, verbose=0))\n",
    "\n",
    "        # Plots.\n",
    "        ## Metrics.\n",
    "        clear_output(True)\n",
    "        plt.figure(figsize=(18,4))\n",
    "        for index, (name, history) in enumerate(sorted(classifier_metrics.items())):\n",
    "            plt.subplot(1, len(classifier_metrics), index + 1)\n",
    "            plt.title(name)\n",
    "            plt.plot(range(1, len(history) + 1), history)\n",
    "            plt.grid()\n",
    "\n",
    "        plt.show();\n",
    "\n",
    "        ## MI plane.\n",
    "        if filter_data is not None:\n",
    "            filtered_MI_X_L = {layer_name: filter_data(values) for layer_name, values in MI_X_L.items()}\n",
    "            filtered_MI_L_Y = {layer_name: filter_data(values) for layer_name, values in MI_L_Y.items()}\n",
    "\n",
    "        plot_MI_planes(MI_X_L, MI_L_Y, filtered_MI_X_L, filtered_MI_L_Y)\n",
    "\n",
    "    return {\"metrics\": classifier_metrics, \"MI_X_L\": MI_X_L, \"MI_L_Y\": MI_L_Y, \"filtered_MI_X_L\": filtered_MI_X_L, \"filtered_MI_L_Y\": filtered_MI_L_Y}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e69833bf-d95d-4493-af36-36683d603cf4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "results = train_classifier(classifier, classifier_loss, classifier_opt,\n",
    "                           train_dataloader, test_dataloader, eval_dataloader,\n",
    "                           X_compressed, entropy_estimator_params,\n",
    "                           compression, n_epochs=classifier_n_epochs,\n",
    "                           filter_data=filter_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "069fe77c-c493-4e08-9ea9-cab719a9ef0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "results[\"filtered_MI_X_L\"] = {layer_name: filter_data(values) for layer_name, values in results[\"MI_X_L\"].items()}\n",
    "results[\"filtered_MI_L_Y\"] = {layer_name: filter_data(values) for layer_name, values in results[\"MI_L_Y\"].items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26375890-04fa-4782-8664-c0375525e83e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_MI_planes(results[\"MI_X_L\"], results[\"MI_L_Y\"], results[\"filtered_MI_X_L\"], results[\"filtered_MI_L_Y\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef7f9af6-3016-480b-a172-094dbf11d916",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saving all the results and settings.\n",
    "\n",
    "settings = {\n",
    "    # Autoencoder for inputs.\n",
    "    \"X_latent_dim\": X_latent_dim,\n",
    "    \"X_autoencoder_n_epochs\": X_autoencoder_n_epochs,\n",
    "    \"load_X_autoencoder\": load_X_autoencoder,\n",
    "    \n",
    "    # Autoencoder for layers.\n",
    "    \"L_latent_dim\": L_latent_dim,\n",
    "    \"L_autoencoder_n_epochs\": L_autoencoder_n_epochs,\n",
    "    \n",
    "    # Classifier.\n",
    "    \"classifier_lr\": classifier_lr,\n",
    "    \"classifier_n_epochs\": classifier_n_epochs,\n",
    "    \"sigma\": sigma,\n",
    "    \n",
    "    # Batch size.\n",
    "    \"batch_size_train\": batch_size_train,\n",
    "    \"batch_size_test\": batch_size_test,\n",
    "    \n",
    "    # Mutual information estimator.\n",
    "    \"entropy_estimator_params\": entropy_estimator_params,\n",
    "    \"compression\": compression,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd2612d-4c55-434f-9ef7-27e94f8b0734",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_results(results, settings, results_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd91813-2d4d-45f5-b80f-fc43689b615e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
