{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test mutual information estimators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.stats as sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "font = {'family' : 'DejaVu Sans',\n",
    "        'size'   : 18}\n",
    "\n",
    "matplotlib.rc('font', **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import csv\n",
    "\n",
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "path = os.path.abspath(os.path.join(os.path.abspath(os.getcwd()), \"../../data/\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiments_path = path + \"/mutual_information/synthetic/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Importing the module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mutinfo.estimators.mutual_information as mi_estimators\n",
    "from mutinfo.utils.dependent_norm import multivariate_normal_from_MI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### SETTINGS ###\n",
    "%run ./Settings.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Standard tests with arbitrary mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_test(mi, n_samples, X_dimension, Y_dimension, X_map=None, Y_map=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Run a single estimation experiment on a jointly normal pair (X, Y).\n",
    "\n",
    "    A joint sample of size `n_samples` is drawn from the distribution returned by\n",
    "    `multivariate_normal_from_MI(X_dimension, Y_dimension, mi)` and split column-wise:\n",
    "    the first `X_dimension` columns form X, the next `Y_dimension` columns form Y.\n",
    "    `X_map` / `Y_map`, when given, are applied to X / Y before estimation.\n",
    "\n",
    "    The estimator is configured with the notebook-level `entropy_estimator_params`,\n",
    "    fitted on (X, Y) and evaluated on the same sample.  The value returned by\n",
    "    `MutualInfoEstimator.estimate` is passed back unchanged.\n",
    "    \"\"\"\n",
    "    # Draw the joint sample and split it into the two vectors.\n",
    "    joint_sample = multivariate_normal_from_MI(X_dimension, Y_dimension, mi).rvs(n_samples)\n",
    "    X = joint_sample[:, :X_dimension]\n",
    "    Y = joint_sample[:, X_dimension:X_dimension + Y_dimension]\n",
    "\n",
    "    # Optional mappings of each vector.\n",
    "    if X_map is not None:\n",
    "        X = X_map(X)\n",
    "    if Y_map is not None:\n",
    "        Y = Y_map(Y)\n",
    "\n",
    "    # Fit the estimator, then evaluate it on the same data.\n",
    "    estimator = mi_estimators.MutualInfoEstimator(entropy_estimator_params=entropy_estimator_params)\n",
    "    estimator.fit(X, Y, verbose=verbose)\n",
    "    return estimator.estimate(X, Y, verbose=verbose)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continuous case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_tests_MI(MI, n_samples, X_dimension, Y_dimension, X_map=None, Y_map=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Run `perform_normal_test` once for every true mutual information value in `MI`\n",
    "    (transformed normal distribution) and return the collected estimates as an array.\n",
    "\n",
    "    A progress banner is printed before each experiment.\n",
    "    \"\"\"\n",
    "    total = len(MI)\n",
    "    results = []\n",
    "\n",
    "    for position in range(total):\n",
    "        # Progress banner for the current experiment.\n",
    "        print(\"\\nn_exp = %d/%d\\n------------\\n\" % (position + 1, total))\n",
    "        results.append(\n",
    "            perform_normal_test(MI[position], n_samples, X_dimension, Y_dimension,\n",
    "                                X_map, Y_map, verbose)\n",
    "        )\n",
    "\n",
    "    return np.asarray(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_estimated_MI(MI, estimated_MI, title, Bandwidth=None, bandwidth_scale=10.0):\n",
    "    \"\"\"\n",
    "    Plot estimated mutual information against the true values `MI`.\n",
    "\n",
    "    `estimated_MI` is indexed as a 2-D array: column 0 is drawn as the estimate and\n",
    "    column 1 as the half-width of the shaded band around it.  The identity line\n",
    "    (true value) is drawn in red.  If `Bandwidth` is given, it is plotted as well,\n",
    "    multiplied by `bandwidth_scale`.\n",
    "    \"\"\"\n",
    "    mean = estimated_MI[:, 0]\n",
    "    std = estimated_MI[:, 1]\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    fig.set_figheight(11)\n",
    "    fig.set_figwidth(16)\n",
    "\n",
    "    # Major and minor grid lines.\n",
    "    for which, alpha, linewidth in (('major', 0.15, 1), ('minor', 0.1, 0.5)):\n",
    "        ax.grid(color='#000000', alpha=alpha, linestyle='-', linewidth=linewidth, which=which)\n",
    "\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(\"$I(X,Y)$\")\n",
    "    ax.set_ylabel(\"$\\\\hat I(X,Y)$\")\n",
    "    ax.minorticks_on()\n",
    "\n",
    "    # Uncomment for logarithmic axes.\n",
    "    #ax.set_yscale('log')\n",
    "    #ax.set_xscale('log')\n",
    "\n",
    "    # True value, estimate, and the band of one `std` around the estimate.\n",
    "    ax.plot(MI, MI, label=\"$I(X,Y)$\", color='red')\n",
    "    ax.plot(MI, mean, label=\"$\\\\hat I(X,Y)$\")\n",
    "    ax.fill_between(MI, mean + std, mean - std, alpha=0.2)\n",
    "\n",
    "    if Bandwidth is not None:\n",
    "        ax.plot(MI, Bandwidth * bandwidth_scale, label=\"bandwidth\")\n",
    "\n",
    "    ax.legend(loc='upper left')\n",
    "\n",
    "    # Both axes start at zero; the upper limits are left automatic.\n",
    "    ax.set_xlim((0.0, None))\n",
    "    ax.set_ylim((0.0, None))\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Global parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The values of mutual information under study.\n",
    "#MI = [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0, 3.0, 5.0, 6.0, 8.0, 10.0]\n",
    "#MI = [0.0, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0]\n",
    "MI = np.linspace(0.0, 10.0, 5)\n",
    "#MI = [0.0, 2.0, 5.0]\n",
    "n_exps = len(MI)\n",
    "\n",
    "# Sample size and dimensions of vectors X and Y.\n",
    "n_samples = 60*1024#20000\n",
    "X_dimension = 10\n",
    "Y_dimension = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gaussian random vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Mutual information estimate.\n",
    "estimated_MI = perform_normal_tests_MI(MI, n_samples, X_dimension, Y_dimension, verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_estimated_MI(MI, estimated_MI, \"Gaussian vectors\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_estimated_MI(MI, estimated_MI, 'normal')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Uniformly distributed vectors\n",
    "\n",
    "Apply the Gaussian cumulative distribution function to each component of a normal random vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mutinfo.utils.synthetic import normal_to_uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mutinfo.utils.matrices import get_scaling_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _uniform_pp():\n",
    "    \"\"\"\n",
    "    Visual sanity check of the uniform mapping on a 1+1-dimensional sample.\n",
    "\n",
    "    Draws 1000 points from `multivariate_normal_from_MI(1, 1, mutual_information=2.0)`,\n",
    "    applies `normal_to_uniform` to X and to Y, prints the sample covariance matrix of\n",
    "    the mapped pair and shows its pairplot.\n",
    "    \"\"\"\n",
    "    _X_Y = multivariate_normal_from_MI(1, 1, mutual_information=2.0).rvs(1000)\n",
    "    _X = normal_to_uniform(_X_Y[:, 0:1])\n",
    "    _Y = normal_to_uniform(_X_Y[:, 1:2])\n",
    "    _X_Y = np.concatenate([_X, _Y], axis=1)\n",
    "    #M = get_scaling_matrix(np.cov(_X_Y, rowvar=False))\n",
    "    #_X_Y = _X_Y @ M\n",
    "    print(np.cov(_X_Y, rowvar=False))\n",
    "\n",
    "    # `fill` is the supported KDE shading option; `shade` is deprecated in seaborn.\n",
    "    pp = sns.pairplot(pd.DataFrame(_X_Y), height=2.0, aspect=1.6,\n",
    "                      plot_kws=dict(edgecolor=\"k\", linewidth=0.0, alpha=0.05, size=0.01, s=0.01),\n",
    "                      diag_kind=\"kde\", diag_kws=dict(fill=True))\n",
    "\n",
    "    # `PairGrid.figure` is the supported accessor; `PairGrid.fig` is deprecated.\n",
    "    fig = pp.figure\n",
    "    fig.subplots_adjust(top=0.93, wspace=0.3)\n",
    "    fig.suptitle(\"Pairplot\", fontsize=14)\n",
    "\n",
    "_uniform_pp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import norm\n",
    "\n",
    "def _test(X):\n",
    "    \"\"\"\n",
    "    Rank-based Gaussianization of every column of a 2-D array `X`.\n",
    "\n",
    "    `X` is first passed through `normal_to_uniform`.  Each column is then replaced by\n",
    "    the standard normal quantiles of its ordinal ranks divided by `n + 1`, where `n`\n",
    "    is the number of rows.  Returns an array of the same shape.\n",
    "    \"\"\"\n",
    "    n, _ = X.shape  # also enforces that X is two-dimensional\n",
    "    X = normal_to_uniform(X)\n",
    "\n",
    "    # Ordinal ranks 1..n within every column (argsort of argsort).\n",
    "    ranks = np.argsort(np.argsort(X, axis=0), axis=0) + 1\n",
    "\n",
    "    # Positions rank / (n + 1) lie strictly inside (0, 1) and are symmetric about 0.5,\n",
    "    # so the resulting normal scores are centred at zero (dividing by n + 2 is not).\n",
    "    return norm.ppf(ranks / (n + 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Mutual information estimate.\n",
    "estimated_MI = perform_normal_tests_MI(MI, n_samples, X_dimension, Y_dimension,\n",
    "                                       X_map=normal_to_uniform, Y_map=normal_to_uniform, verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_estimated_MI(MI, estimated_MI, \"Uniform distribution\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_estimated_MI(MI, estimated_MI, 'uniform')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rings\n",
    "\n",
    "We obtain uniform distributions according to the previous section.\n",
    "Then we apply the following transformation:\n",
    "\n",
    "$$\n",
    "\\begin{cases}\n",
    "x' = [R \\cdot x + r \\cdot (1 - x)] \\cdot \\cos(2 \\pi y) \\\\\n",
    "y' = [R \\cdot x + r \\cdot (1 - x)] \\cdot \\sin(2 \\pi y) \\\\\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "Both vectors are required to have dimension $ 2 $."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r = 1.0\n",
    "R = 2.0\n",
    "\n",
    "def ring_mapping(X):\n",
    "    \"\"\"\n",
    "    Gaussian vector to a ring.\n",
    "\n",
    "    `X` must be a 2-D array with exactly two columns.  It is first passed through\n",
    "    `normal_to_uniform`.  The first column then sets the radius, interpolated linearly\n",
    "    between the module-level radii `r` and `R`, and the second column sets the polar\n",
    "    angle `2 * pi * X[:, 1]`.  Returns the Cartesian coordinates as an array of the\n",
    "    same shape as `X`.\n",
    "    \"\"\"\n",
    "    assert len(X.shape) == 2\n",
    "    assert X.shape[1] == 2\n",
    "\n",
    "    X = normal_to_uniform(X)\n",
    "\n",
    "    # Whole-column arithmetic instead of a Python-level loop over the rows.\n",
    "    rho = R * X[:, 0] + r * (1.0 - X[:, 0])\n",
    "    phi = 2.0 * np.pi * X[:, 1]\n",
    "\n",
    "    new_X = np.zeros_like(X)\n",
    "    new_X[:, 0] = rho * np.cos(phi)\n",
    "    new_X[:, 1] = rho * np.sin(phi)\n",
    "\n",
    "    return new_X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _rings_pp():\n",
    "    \"\"\"\n",
    "    Visual sanity check of the ring mapping on a 2+2-dimensional sample.\n",
    "\n",
    "    Draws 10000 points from `multivariate_normal_from_MI(2, 2, mutual_information=10.0)`,\n",
    "    applies `ring_mapping` to X and to Y and shows the pairplot of the four resulting\n",
    "    coordinates.\n",
    "    \"\"\"\n",
    "    _X_Y = multivariate_normal_from_MI(2, 2, mutual_information=10.0).rvs(10000)\n",
    "    _X = ring_mapping(_X_Y[:, 0:2])\n",
    "    _Y = ring_mapping(_X_Y[:, 2:4])\n",
    "    _X_Y = np.concatenate([_X, _Y], axis=1)\n",
    "\n",
    "    # `fill` is the supported KDE shading option; `shade` is deprecated in seaborn.\n",
    "    pp = sns.pairplot(pd.DataFrame(_X_Y), height=2.0, aspect=1.6,\n",
    "                      plot_kws=dict(edgecolor=\"k\", linewidth=0.0, alpha=0.05, size=0.01, s=0.01),\n",
    "                      diag_kind=\"kde\", diag_kws=dict(fill=True))\n",
    "\n",
    "    # `PairGrid.figure` is the supported accessor; `PairGrid.fig` is deprecated.\n",
    "    fig = pp.figure\n",
    "    fig.subplots_adjust(top=0.93, wspace=0.3)\n",
    "    fig.suptitle(\"Pairplot\", fontsize=14)\n",
    "\n",
    "_rings_pp()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# Mutual information estimation.\n",
    "estimated_MI = perform_normal_tests_MI(MI, n_samples, 2, 2,\n",
    "                                       X_map=ring_mapping, Y_map=ring_mapping, verbose=10)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "plot_estimated_MI(MI, estimated_MI, \"Rings\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Continuous-discrete case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_uniform_discrete_test(n_labels, n_samples, X_dimension, X_map=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Run a single continuous-discrete estimation experiment.\n",
    "\n",
    "    X has `X_dimension` columns, each filled with `n_samples` draws from\n",
    "    `scipy.stats.uniform(scale=1.0)`.  The discrete label is\n",
    "    `Y = floor(X[:, 0] * n_labels)`, computed before the optional `X_map` is applied\n",
    "    to X.  The estimator is built with `Y_is_discrete=True` and the notebook-level\n",
    "    `entropy_estimator_params`, fitted on (X, Y) and evaluated on the same sample.\n",
    "    \"\"\"\n",
    "    # Generation: fill X one column at a time.\n",
    "    unit_uniform = sps.uniform(scale=1.0)\n",
    "    X = np.zeros(shape=(n_samples, X_dimension))\n",
    "    for column in range(X_dimension):\n",
    "        X[:, column] = unit_uniform.rvs(size=n_samples)\n",
    "\n",
    "    # The discrete RV is obtained from the first component of the continuous RV.\n",
    "    Y = np.floor(X[:, 0] * n_labels).astype(int)\n",
    "\n",
    "    # Optional mapping of the continuous vector only.\n",
    "    if X_map is not None:\n",
    "        X = X_map(X)\n",
    "\n",
    "    # Mutual information estimation.\n",
    "    estimator = mi_estimators.MutualInfoEstimator(\n",
    "        Y_is_discrete=True,\n",
    "        entropy_estimator_params=entropy_estimator_params,\n",
    "    )\n",
    "    estimator.fit(X, Y, verbose=verbose)\n",
    "    return estimator.estimate(X, Y, verbose=verbose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_uniform_discrete_test_MI(N_labels, n_samples, X_dimension, X_map=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Run `perform_uniform_discrete_test` once for every label count in `N_labels`\n",
    "    (uniform distribution).\n",
    "\n",
    "    Returns a pair of arrays: the reference values `log(N_labels[i])` used as the true\n",
    "    mutual information, and the corresponding estimates.  A progress banner is printed\n",
    "    before each experiment.\n",
    "    \"\"\"\n",
    "    total = len(N_labels)\n",
    "\n",
    "    # Reference (true) values: logarithm of the number of labels.\n",
    "    true_MI = np.array([np.log(N_labels[position]) for position in range(total)])\n",
    "\n",
    "    results = []\n",
    "    for position in range(total):\n",
    "        print(\"\\nn_exp = %d/%d\\n------------\\n\" % (position + 1, total))\n",
    "        estimate = perform_uniform_discrete_test(N_labels[position], n_samples, X_dimension, X_map, verbose)\n",
    "        results.append(estimate)\n",
    "\n",
    "    return np.asarray(true_MI), np.asarray(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Global parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of classes.\n",
    "# `min_in_group` bounds the largest label count below: n_samples / min_in_group.\n",
    "if method == 'KDE':\n",
    "    min_in_group = 10\n",
    "elif method == 'KL':\n",
    "    min_in_group = 5 * k_neighbours\n",
    "else:\n",
    "    # Fail early with a clear message instead of hitting a NameError below\n",
    "    # (or silently reusing a stale `min_in_group` left in the kernel by an earlier run).\n",
    "    raise ValueError(\"Unsupported method %r: expected 'KDE' or 'KL'\" % (method,))\n",
    "\n",
    "# Powers of two: 1, 2, 4, ... (exponents 0 .. floor(log2(n_samples / min_in_group)) - 1).\n",
    "N_labels = [2**n for n in range(int(np.floor(np.log2(n_samples / min_in_group))))]\n",
    "print(N_labels[-1])\n",
    "n_exps = len(N_labels)\n",
    "\n",
    "# Sample size and dimensions of vectors X and Y.\n",
    "#X_dimension = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "MI, estimated_MI = perform_uniform_discrete_test_MI(N_labels, n_samples, X_dimension, verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_estimated_MI(MI, estimated_MI, \"Uniform distribution with discrete label\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_estimated_MI(MI, estimated_MI, 'uniform_discrete')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Varying dimensionality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_tests_dim(mi, n_samples, dimensions, X_map=None, Y_map=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Run `perform_normal_test` with a fixed true mutual information `mi` for every\n",
    "    entry of `dimensions`; X and Y both get the same dimensionality in each run.\n",
    "\n",
    "    Returns the estimates as a plain list, one entry per dimensionality.  A progress\n",
    "    banner is printed before each experiment.\n",
    "    \"\"\"\n",
    "    total = len(dimensions)\n",
    "    results = []\n",
    "\n",
    "    # Conducting the tests.\n",
    "    for position in range(total):\n",
    "        print(\"\\nn_exp = %d/%d\\n------------\\n\" % (position + 1, total))\n",
    "        dimension = dimensions[position]\n",
    "        results.append(\n",
    "            perform_normal_test(mi, n_samples, dimension, dimension,\n",
    "                                X_map, Y_map, verbose)\n",
    "        )\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_estimated_dim(dimensions, mi, estimated_MI, title):\n",
    "    \"\"\"\n",
    "    Plot estimated mutual information against the dimensionality of X and Y.\n",
    "\n",
    "    Each entry of `estimated_MI` is indexed as a pair: element 0 is drawn as the\n",
    "    estimate and element 1 as the half-width of the shaded band around it.  The true\n",
    "    value `mi` is drawn as a horizontal red line.\n",
    "    \"\"\"\n",
    "    mean = np.array([estimate[0] for estimate in estimated_MI])\n",
    "    std = np.array([estimate[1] for estimate in estimated_MI])\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    fig.set_figheight(11)\n",
    "    fig.set_figwidth(16)\n",
    "\n",
    "    # Major and minor grid lines.\n",
    "    for which, alpha, linewidth in (('major', 0.15, 1), ('minor', 0.1, 0.5)):\n",
    "        ax.grid(color='#000000', alpha=alpha, linestyle='-', linewidth=linewidth, which=which)\n",
    "\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(\"Dimensionality of $ X $ and $ Y $\")\n",
    "    ax.set_ylabel(\"$\\\\hat I(X,Y)$\")\n",
    "    ax.minorticks_on()\n",
    "\n",
    "    # Uncomment for logarithmic axes.\n",
    "    #ax.set_yscale('log')\n",
    "    #ax.set_xscale('log')\n",
    "\n",
    "    # Constant true value, estimate, and the band of one `std` around the estimate.\n",
    "    ax.plot(dimensions, np.ones_like(dimensions) * mi, label=\"$I(X,Y)$\", color='red')\n",
    "    ax.plot(dimensions, mean, label=\"$\\\\hat I(X,Y)$\")\n",
    "    ax.fill_between(dimensions, mean + std, mean - std, alpha=0.2)\n",
    "\n",
    "    ax.legend(loc='upper left')\n",
    "\n",
    "    # Both axes start at zero; the upper limits are left automatic.\n",
    "    ax.set_xlim((0.0, None))\n",
    "    ax.set_ylim((0.0, None))\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Global parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dimensions = [1, 2, 3, 4, 5, 6, 8, 10, 12, 14, 16, 20, 25, 30, 40]\n",
    "#dimensions = [1, 2, 4, 6, 8, 12, 16, 20, 30, 40]\n",
    "mi = 2.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gaussian random vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Mutual information estimation.\n",
    "#estimated_MI = perform_normal_tests_dim(mi, n_samples, dimensions, verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#plot_estimated_dim(dimensions, mi, estimated_MI, \"Gaussian vectors\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
