{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test mutual information estimators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "#os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow.compat.v2 as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow_addons as tfa\n",
    "\n",
    "tfds.disable_progress_bar()\n",
    "tf.enable_v2_behavior()\n",
    "\n",
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)\n",
    "\n",
    "print(tf.__version__)\n",
    "print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))\n",
    "tf.config.experimental.list_physical_devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.stats as sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "font = {'family' : 'DejaVu Sans',\n",
    "        'size'   : 18}\n",
    "\n",
    "matplotlib.rc('font', **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import csv\n",
    "\n",
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tneJ2JaEwztO"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "path = os.path.abspath(os.path.join(os.path.abspath(os.getcwd()), \"../../data/\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiments_path = path + \"/mutual_information/synthetic/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Importing the module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mutinfo.estimators.mutual_information as mi_estimators\n",
    "from mutinfo.utils.dependent_norm import multivariate_normal_from_MI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### SETTINGS ###\n",
    "%run ./Settings.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Standard tests with arbitrary mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_compressed_test(mi, n_samples, X_dimension, Y_dimension, X_map=None, Y_map=None,\n",
    "                                   X_compressor=None, Y_compressor=None, verbose=0):\n",
    "    # Generation.\n",
    "    random_variable = multivariate_normal_from_MI(X_dimension, Y_dimension, mi)\n",
    "    X_Y = random_variable.rvs(n_samples)\n",
    "    X = X_Y[:, 0:X_dimension]\n",
    "    Y = X_Y[:, X_dimension:X_dimension + Y_dimension]\n",
    "        \n",
    "    # Mapping application.\n",
    "    if not X_map is None:\n",
    "        X = X_map(X)\n",
    "           \n",
    "    if not Y_map is None:\n",
    "        Y = Y_map(Y)\n",
    "        \n",
    "    # Mutual information estimation.\n",
    "    mi_estimator = mi_estimators.MutualInfoEstimator(entropy_estimator_params=entropy_estimator_params)\n",
    "    mi_estimator.fit(X, Y, verbose=verbose)\n",
    "    mi = mi_estimator.estimate(X, Y, verbose=verbose)\n",
    "    \n",
    "    # Mutual information estimation for compressed representation.\n",
    "    mi_estimator = mi_estimators.LossyMutualInfoEstimator(X_compressor, Y_compressor,\n",
    "                                                          entropy_estimator_params=entropy_estimator_params)\n",
    "    mi_estimator.fit(X, Y, verbose=verbose)\n",
    "    mi_compressed = mi_estimator.estimate(X, Y, verbose=verbose)\n",
    "    \n",
    "    return mi, mi_compressed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_compressed_tests_MI(MI, n_samples, X_dimension, Y_dimension, X_map=None, Y_map=None,\n",
    "                                       X_compressor=None, Y_compressor=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Estimate mutual information for different true values\n",
    "    (transformed normal distribution).\n",
    "    \"\"\"\n",
    "    n_exps = len(MI)\n",
    "    \n",
    "    # Mutual information estimates.\n",
    "    estimated_MI = []\n",
    "    estimated_MI_compressed = []\n",
    "\n",
    "    # Conducting the tests.\n",
    "    for n_exp in range(n_exps):\n",
    "        print(\"\\nn_exp = %d/%d\\n------------\\n\" % (n_exp + 1, n_exps))\n",
    "        mi, compressed_mi = perform_normal_compressed_test(MI[n_exp], n_samples, X_dimension, Y_dimension,\n",
    "                                                           X_map, Y_map, X_compressor, Y_compressor, verbose)\n",
    "        estimated_MI.append(mi)\n",
    "        estimated_MI_compressed.append(compressed_mi)\n",
    "        \n",
    "    return estimated_MI, estimated_MI_compressed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_estimated_compressed_MI(MI, estimated_MI, estimated_MI_compressed, title):\n",
    "    estimated_MI_mean = np.array([estimated_MI[index][0] for index in range(len(estimated_MI))])\n",
    "    estimated_MI_std  = np.array([estimated_MI[index][1] for index in range(len(estimated_MI))])\n",
    "    \n",
    "    estimated_MI_compressed_mean = np.array([estimated_MI_compressed[index][0]\n",
    "                                             for index in range(len(estimated_MI_compressed))])\n",
    "    estimated_MI_compressed_std  = np.array([estimated_MI_compressed[index][1]\n",
    "                                             for index in range(len(estimated_MI_compressed))])\n",
    "    \n",
    "    fig_normal, ax_normal = plt.subplots()\n",
    "\n",
    "    fig_normal.set_figheight(11)\n",
    "    fig_normal.set_figwidth(16)\n",
    "\n",
    "    # Grid.\n",
    "    ax_normal.grid(color='#000000', alpha=0.15, linestyle='-', linewidth=1, which='major')\n",
    "    ax_normal.grid(color='#000000', alpha=0.1, linestyle='-', linewidth=0.5, which='minor')\n",
    "\n",
    "    ax_normal.set_title(title)\n",
    "    ax_normal.set_xlabel(\"$I(X,Y)$\")\n",
    "    ax_normal.set_ylabel(\"$\\\\hat I(X,Y)$\")\n",
    "    \n",
    "    ax_normal.minorticks_on()\n",
    "    \n",
    "    #ax_normal.set_yscale('log')\n",
    "    #ax_normal.set_xscale('log')\n",
    "\n",
    "    ax_normal.plot(MI, MI, label=\"$I(X,Y)$\", color='red')\n",
    "    \n",
    "    ax_normal.plot(MI, estimated_MI_mean, label=\"$\\\\hat I(X,Y)$\")\n",
    "    ax_normal.fill_between(MI, estimated_MI_mean + estimated_MI_std, estimated_MI_mean - estimated_MI_std, alpha=0.2)\n",
    "    \n",
    "    ax_normal.plot(MI, estimated_MI_compressed_mean, label=\"$\\\\hat I_{compr}(X,Y)$\")\n",
    "    ax_normal.fill_between(MI, estimated_MI_compressed_mean + estimated_MI_compressed_std,\n",
    "                           estimated_MI_compressed_mean - estimated_MI_compressed_std, alpha=0.2)\n",
    "\n",
    "    ax_normal.legend(loc='upper left')\n",
    "\n",
    "    ax_normal.set_xlim((0.0, None))\n",
    "    ax_normal.set_ylim((0.0, None))\n",
    "\n",
    "    plt.show();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Global parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The values of mutual information under study.\n",
    "MI = np.linspace(0.0, 10.0, 41)\n",
    "n_exps = len(MI)\n",
    "\n",
    "# Sample size and dimensions of vectors X and Y.\n",
    "n_samples = 5000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Images of correlated gaussians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mutinfo.utils.synthetic import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_dimension = 5\n",
    "Y_dimension = 5\n",
    "latent_dimension = 5\n",
    "\n",
    "img_width = 32\n",
    "img_height = 32\n",
    "\n",
    "experiments_dir = ('gaussian_correlated_%dx%d' % (img_width, img_height))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow_array(array):\n",
    "    \"\"\"Display array of pixels.\"\"\"\n",
    "    plt.axis('off')\n",
    "    plt.imshow((255.0 * array).astype(np.uint8), cmap=plt.get_cmap(\"gray\"), vmin=0, vmax=255)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train the autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import multivariate_normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_train_samples = 6000\n",
    "n_test_samples  = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_variable = multivariate_normal()\n",
    "X = random_variable.rvs((n_train_samples + n_test_samples, X_dimension))\n",
    "X = normal_to_uniform(X)\n",
    "\n",
    "distribution = lambda X, Y, params : np.exp(\n",
    "    -10.0 * (\n",
    "        (1.0 + params[:,2,None,None])**2 * (X - params[:,0,None,None])**2 +\n",
    "        (1.0 + params[:,3,None,None])**2 * (Y - params[:,1,None,None])**2 +\n",
    "        (-1.0 + 2.0 * params[:,4,None,None])*(1.0 + params[:,2,None,None])*(1.0 + params[:,3,None,None]) * \n",
    "        (X - params[:,0,None,None])*(Y - params[:,1,None,None])\n",
    "    )\n",
    ")\n",
    "\n",
    "X = params_to_2d_distribution(X, distribution, img_width, img_height)\n",
    "X = np.expand_dims(X, axis=-1)\n",
    "X_train = X[0:n_train_samples]\n",
    "X_test  = X[n_train_samples:n_train_samples + n_test_samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_dataset = tf.data.Dataset.from_tensor_slices(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "augmentator = tf.keras.Sequential([\n",
    "    tf.keras.layers.Input((img_width, img_height, 1)),\n",
    "    #tf.keras.layers.RandomTranslation(\n",
    "    #    height_factor=(-0.2, 0.2), width_factor=(-0.2, 0.2), fill_mode=\"constant\"\n",
    "    #),\n",
    "    tf.keras.layers.RandomZoom(\n",
    "        height_factor=(-0.2, 0.0), width_factor=(-0.2, 0.0), fill_mode=\"constant\"\n",
    "    )\n",
    "])\n",
    "augmentator.compile()\n",
    "\n",
    "def augment(sample):\n",
    "    sample = augmentator(sample, training=True)\n",
    "    return sample, sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imshow_array(augment(X[0][None,])[0].numpy()[0,:,:,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_augmented_dataset = X_dataset.shuffle(10000).batch(5000).map(augment, num_parallel_calls=tf.data.AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cnn_autoencoder(shape_input, dimension):\n",
    "    # Weight initialization.\n",
    "    init = tf.keras.initializers.RandomNormal(stddev=1e-1)\n",
    "\n",
    "    # Input data.\n",
    "    input_layer = tf.keras.layers.Input(shape_input)\n",
    "    next_layer = input_layer\n",
    "    \n",
    "    # 1 block of layers.\n",
    "    next_layer = tf.keras.layers.GaussianNoise(0.1)(next_layer)\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    next_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(next_layer)\n",
    "    #next_layer = tf.keras.layers.Dropout(rate=0.2)(next_layer)\n",
    "\n",
    "    # 2 block of layers.\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    next_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(next_layer)\n",
    "    #next_layer = tf.keras.layers.Dropout(rate=0.1)(next_layer)\n",
    "    \n",
    "    # 3 block of layers.\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    next_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(next_layer)\n",
    "    #next_layer = tf.keras.layers.Dropout(rate=0.2)(next_layer)\n",
    "    \n",
    "    # 4 block of layers.\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    next_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(next_layer)\n",
    "    #next_layer = tf.keras.layers.Dropout(rate=0.2)(next_layer)\n",
    "    \n",
    "    # 5 block of layers.\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    next_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(next_layer)\n",
    "    #next_layer = tf.keras.layers.Dropout(rate=0.2)(next_layer)\n",
    "    \n",
    "    # 6 block of layers.\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    next_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same')(next_layer)\n",
    "    #next_layer = tf.keras.layers.Dropout(rate=0.2)(next_layer)\n",
    "\n",
    "    # Bottleneck.\n",
    "    next_layer = tf.keras.layers.Flatten()(next_layer)\n",
    "    next_layer = tf.keras.layers.Dense(dimension, kernel_initializer=init)(next_layer)\n",
    "    #next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    bottleneck = tf.keras.layers.Activation('tanh')(next_layer)\n",
    "\n",
    "    # Encoder model.\n",
    "    encoder = tf.keras.Model(input_layer, bottleneck)\n",
    "\n",
    "    # Decoder model begins.\n",
    "    input_code_layer = tf.keras.layers.Input((dimension))\n",
    "    next_layer = input_code_layer\n",
    "    next_layer = tf.keras.layers.GaussianNoise(0.02)(next_layer)\n",
    "    \n",
    "    # 6 block of layers.\n",
    "    #tfa.layers.SpectralNormalization()\n",
    "    next_layer = tf.keras.layers.Dense(1*1*8, kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.Reshape((1, 1, 8))(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    \n",
    "    # 5 block of layers.\n",
    "    next_layer = tf.keras.layers.UpSampling2D(size=(2, 2))(next_layer)\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    \n",
    "    # 4 block of layers.\n",
    "    next_layer = tf.keras.layers.UpSampling2D(size=(2, 2))(next_layer)\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "    \n",
    "    # 3 block of layers.\n",
    "    next_layer = tf.keras.layers.UpSampling2D(size=(2, 2))(next_layer)\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "\n",
    "    # 2 block of layers.\n",
    "    next_layer = tf.keras.layers.UpSampling2D(size=(2, 2))(next_layer)\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "\n",
    "    # 1 block of layers.\n",
    "    next_layer = tf.keras.layers.UpSampling2D(size=(2, 2))(next_layer)\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.LeakyReLU(alpha=0.2)(next_layer)\n",
    "\n",
    "    # 0 block of layers.\n",
    "    next_layer = tf.keras.layers.Conv2D(filters=1, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_initializer=init)(next_layer)\n",
    "    #next_layer = tf.keras.layers.BatchNormalization()(next_layer)\n",
    "    next_layer = tf.keras.layers.Activation('sigmoid')(next_layer)\n",
    "\n",
    "    output_layer = next_layer\n",
    "\n",
    "    # Model.\n",
    "    decoder = tf.keras.models.Model(input_code_layer, output_layer) # Decoder.\n",
    "    autoencoder = tf.keras.Sequential([encoder, decoder])\n",
    "\n",
    "    # Compile the model.\n",
    "    opt = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
    "    autoencoder.compile(loss='mae', optimizer=opt)\n",
    "    return encoder, decoder, autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_autoencoder = True\n",
    "models_path_ = experiments_path + experiments_dir + \"/models/autoencoder/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if load_autoencoder:\n",
    "    encoder = tf.keras.models.load_model(models_path_ + \"encoder.h5\")\n",
    "    decoder = tf.keras.models.load_model(models_path_ + \"decoder.h5\")\n",
    "    autoencoder = tf.keras.Sequential([encoder, decoder])\n",
    "    autoencoder.compile(loss='mae', optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if not load_autoencoder:\n",
    "    encoder, decoder, autoencoder = cnn_autoencoder((img_width, img_height, 1), latent_dimension)\n",
    "    \n",
    "    class CustomCallback(tf.keras.callbacks.Callback):\n",
    "        def on_epoch_end(self, epoch, logs=None):\n",
    "            fig, ax = plt.subplots(2, 2)\n",
    "            fig.set_figheight(8)\n",
    "            fig.set_figwidth(8)\n",
    "            \n",
    "            ax[0][0].axis('off')\n",
    "            ax[0][1].axis('off')\n",
    "            ax[1][0].axis('off')\n",
    "            ax[1][1].axis('off')\n",
    "            \n",
    "            ax[0][0].imshow(X_test[0], cmap=plt.get_cmap(\"gray\"), vmin=0.0, vmax=1.0)\n",
    "            ax[0][1].imshow(autoencoder(X_test[0:1]).numpy()[0,:,:,0],\n",
    "                         cmap=plt.get_cmap(\"gray\"), vmin=0.0, vmax=1.0)\n",
    "            \n",
    "            sample = next(iter(X_augmented_dataset))[0]\n",
    "            \n",
    "            ax[1][0].imshow(sample.numpy()[0,:], cmap=plt.get_cmap(\"gray\"), vmin=0.0, vmax=1.0)\n",
    "            ax[1][1].imshow(autoencoder(sample).numpy()[0,:,:,0],\n",
    "                         cmap=plt.get_cmap(\"gray\"), vmin=0.0, vmax=1.0)\n",
    "            plt.show();\n",
    "    \n",
    "    autoencoder.fit(\n",
    "        X_augmented_dataset,\n",
    "        epochs=2,\n",
    "        validation_data=(X_test, X_test),\n",
    "        callbacks=[CustomCallback()],\n",
    "    )\n",
    "    \n",
    "    # Save the models.\n",
    "    os.makedirs(models_path_, exist_ok=True)\n",
    "    autoencoder.save(models_path_ + \"autoencoder.h5\")\n",
    "    encoder.save(models_path_ + \"encoder.h5\")\n",
    "    decoder.save(models_path_ + \"decoder.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "autoencoder.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gaussian_mapping(X):\n",
    "    \"\"\" Map Gaussian vector to the coordiantes of the mode (center of the plot) and covariance matrix. \"\"\"\n",
    "    return normal_to_uniform(X)\n",
    "\n",
    "def gaussian_compressor(X):\n",
    "    \"\"\" Parameters to images, then to latent representations. \"\"\"\n",
    "    return encoder(np.expand_dims(params_to_2d_distribution(X, distribution, img_width, img_height),\n",
    "                                  axis=-1)).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "estimated_MI, estimated_MI_compressed = perform_normal_compressed_tests_MI(MI,\n",
    "    n_samples, X_dimension, Y_dimension, gaussian_mapping, gaussian_mapping,\n",
    "    gaussian_compressor, gaussian_compressor, verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_estimated_compressed_MI(MI, estimated_MI, estimated_MI_compressed, \"Correlated 2D Gaussians\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_estimated_MI(MI, estimated_MI, experiments_dir + '/parameters')\n",
    "save_estimated_MI(MI, estimated_MI_compressed, experiments_dir + '/compressed')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"OK\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
