{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test mutual information estimators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.stats as sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "font = {'family' : 'DejaVu Sans',\n",
    "        'size'   : 18}\n",
    "\n",
    "matplotlib.rc('font', **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import csv\n",
    "\n",
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tneJ2JaEwztO"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "path = os.path.abspath(os.path.join(os.path.abspath(os.getcwd()), \"../../data/\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiments_path = path + \"/mutual_information/synthetic/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Importing the module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mutinfo.estimators.mutual_information as mi_estimators\n",
    "from mutinfo.utils.dependent_norm import multivariate_normal_from_MI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### SETTINGS ###\n",
    "%run ./Settings.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Standard tests with arbitrary mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_compressed_test(mi, n_samples, X_dimension, Y_dimension, X_map=None, Y_map=None,\n",
    "                                   X_compressor=None, Y_compressor=None, verbose=0):\n",
    "    # Generation.\n",
    "    random_variable = multivariate_normal_from_MI(X_dimension, Y_dimension, mi)\n",
    "    X_Y = random_variable.rvs(n_samples)\n",
    "    X = X_Y[:, 0:X_dimension]\n",
    "    Y = X_Y[:, X_dimension:X_dimension + Y_dimension]\n",
    "        \n",
    "    # Mapping application.\n",
    "    if not X_map is None:\n",
    "        X = X_map(X)\n",
    "           \n",
    "    if not Y_map is None:\n",
    "        Y = Y_map(Y)\n",
    "        \n",
    "    # Mutual information estimation.\n",
    "    mi_estimator = mi_estimators.MutualInfoEstimator(entropy_estimator_params=entropy_estimator_params)\n",
    "    mi_estimator.fit(X, Y, verbose=verbose)\n",
    "    mi = mi_estimator.estimate(X, Y, verbose=verbose)\n",
    "    \n",
    "    # Mutual information estimation for compressed representation.\n",
    "    mi_estimator = mi_estimators.LossyMutualInfoEstimator(X_compressor, Y_compressor,\n",
    "                                                          entropy_estimator_params=entropy_estimator_params)\n",
    "    mi_estimator.fit(X, Y, verbose=verbose)\n",
    "    mi_compressed = mi_estimator.estimate(X, Y, verbose=verbose)\n",
    "    \n",
    "    return mi, mi_compressed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_compressed_tests_MI(MI, n_samples, X_dimension, Y_dimension, X_map=None, Y_map=None,\n",
    "                                       X_compressor=None, Y_compressor=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Estimate mutual information for different true values\n",
    "    (transformed normal distribution).\n",
    "    \"\"\"\n",
    "    n_exps = len(MI)\n",
    "    \n",
    "    # Mutual information estimates.\n",
    "    estimated_MI = []\n",
    "    estimated_MI_compressed = []\n",
    "\n",
    "    # Conducting the tests.\n",
    "    for n_exp in range(n_exps):\n",
    "        print(\"\\nn_exp = %d/%d\\n------------\\n\" % (n_exp + 1, n_exps))\n",
    "        mi, compressed_mi = perform_normal_compressed_test(MI[n_exp], n_samples, X_dimension, Y_dimension,\n",
    "                                                           X_map, Y_map, X_compressor, Y_compressor, verbose)\n",
    "        estimated_MI.append(mi)\n",
    "        estimated_MI_compressed.append(compressed_mi)\n",
    "        \n",
    "    return estimated_MI, estimated_MI_compressed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_estimated_compressed_MI(MI, estimated_MI, estimated_MI_compressed, title):\n",
    "    estimated_MI_mean = np.array([estimated_MI[index][0] for index in range(len(estimated_MI))])\n",
    "    estimated_MI_std  = np.array([estimated_MI[index][1] for index in range(len(estimated_MI))])\n",
    "    \n",
    "    estimated_MI_compressed_mean = np.array([estimated_MI_compressed[index][0]\n",
    "                                             for index in range(len(estimated_MI_compressed))])\n",
    "    estimated_MI_compressed_std  = np.array([estimated_MI_compressed[index][1]\n",
    "                                             for index in range(len(estimated_MI_compressed))])\n",
    "    \n",
    "    fig_normal, ax_normal = plt.subplots()\n",
    "\n",
    "    fig_normal.set_figheight(11)\n",
    "    fig_normal.set_figwidth(16)\n",
    "\n",
    "    # Grid.\n",
    "    ax_normal.grid(color='#000000', alpha=0.15, linestyle='-', linewidth=1, which='major')\n",
    "    ax_normal.grid(color='#000000', alpha=0.1, linestyle='-', linewidth=0.5, which='minor')\n",
    "\n",
    "    ax_normal.set_title(title)\n",
    "    ax_normal.set_xlabel(\"$I(X,Y)$\")\n",
    "    ax_normal.set_ylabel(\"$\\\\hat I(X,Y)$\")\n",
    "    \n",
    "    ax_normal.minorticks_on()\n",
    "    \n",
    "    #ax_normal.set_yscale('log')\n",
    "    #ax_normal.set_xscale('log')\n",
    "\n",
    "    ax_normal.plot(MI, MI, label=\"$I(X,Y)$\", color='red')\n",
    "    \n",
    "    ax_normal.plot(MI, estimated_MI_mean, label=\"$\\\\hat I(X,Y)$\")\n",
    "    ax_normal.fill_between(MI, estimated_MI_mean + estimated_MI_std, estimated_MI_mean - estimated_MI_std, alpha=0.2)\n",
    "    \n",
    "    ax_normal.plot(MI, estimated_MI_compressed_mean, label=\"$\\\\hat I_{compr}(X,Y)$\")\n",
    "    ax_normal.fill_between(MI, estimated_MI_compressed_mean + estimated_MI_compressed_std,\n",
    "                           estimated_MI_compressed_mean - estimated_MI_compressed_std, alpha=0.2)\n",
    "\n",
    "    ax_normal.legend(loc='upper left')\n",
    "\n",
    "    ax_normal.set_xlim((0.0, None))\n",
    "    ax_normal.set_ylim((0.0, None))\n",
    "\n",
    "    plt.show();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Global parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The values of mutual information under study.\n",
    "MI = np.linspace(0.0, 10.0, 41)\n",
    "n_exps = len(MI)\n",
    "\n",
    "# Sample size and dimensions of vectors X and Y.\n",
    "n_samples = 5000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Images of rectangles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mutinfo.utils.synthetic import normal_to_rectangle_coords\n",
    "from mutinfo.utils.synthetic import rectangle_coords_to_rectangles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_dimension = 4\n",
    "Y_dimension = 4\n",
    "latent_dimension = 4\n",
    "\n",
    "min_delta = 2\n",
    "img_width = 32\n",
    "img_height = 32\n",
    "\n",
    "experiments_dir = ('rectangles_%dx%d' % (img_width, img_height))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow_array(array):\n",
    "    \"\"\"Display array of pixels.\"\"\"\n",
    "    plt.axis('off')\n",
    "    plt.imshow((255.0 * array).astype(np.uint8), cmap=plt.get_cmap(\"gray\"), vmin=0, vmax=255)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train the autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import multivariate_normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_train_samples = 6000\n",
    "n_test_samples  = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_variable = multivariate_normal()\n",
    "X = random_variable.rvs((n_train_samples + n_test_samples, X_dimension))\n",
    "X = rectangle_coords_to_rectangles(normal_to_rectangle_coords(X, min_delta, img_width, min_delta, img_height), img_width, img_height)\n",
    "X_train = X[0:n_train_samples]\n",
    "X_test  = X[n_train_samples:n_train_samples + n_test_samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "pca = PCA(n_components=latent_dimension).fit(X_train.reshape(X_train.shape[0], -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rectangles_mapping(X):\n",
    "    \"\"\" Map Gaussian vector to the coordinates of the rectangle. \"\"\"\n",
    "    return normal_to_rectangle_coords(X, min_delta, img_width, min_delta, img_height)\n",
    "\n",
    "def rectangles_compressor(X):\n",
    "    \"\"\" Parameters to images, then to latent representations. \"\"\"\n",
    "    images = rectangle_coords_to_rectangles(X, img_width, img_height)\n",
    "    return pca.transform(images.reshape(images.shape[0], -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "estimated_MI, estimated_MI_compressed = perform_normal_compressed_tests_MI(MI,\n",
    "    n_samples, X_dimension, Y_dimension, rectangles_mapping, rectangles_mapping,\n",
    "    rectangles_compressor, rectangles_compressor, verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_estimated_compressed_MI(MI, estimated_MI, estimated_MI_compressed, \"Rectangles\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_estimated_MI(MI, estimated_MI, experiments_dir + '/coordinates')\n",
    "save_estimated_MI(MI, estimated_MI_compressed, experiments_dir + '/compressed/PCA')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"OK\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
