{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test mutual information estimators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.stats as sps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "font = {'family' : 'DejaVu Sans',\n",
    "        'size'   : 18}\n",
    "\n",
    "matplotlib.rc('font', **font)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import csv\n",
    "\n",
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tneJ2JaEwztO"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "path = os.path.abspath(os.path.join(os.path.abspath(os.getcwd()), \"../../data/\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiments_path = path + \"/mutual_information/synthetic/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Importing the module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mutinfo.estimators.mutual_information as mi_estimators\n",
    "from mutinfo.utils.dependent_norm import multivariate_normal_from_MI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### SETTINGS ###\n",
    "%run ./Settings.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Standard tests with arbitrary mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_compressed_test(mi, n_samples, X_dimension, Y_dimension, X_map=None, Y_map=None,\n",
    "                                   X_compressor=None, Y_compressor=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Run one synthetic test: sample (X, Y), map it, and estimate mutual information twice.\n",
    "\n",
    "    The sample is drawn from `multivariate_normal_from_MI(X_dimension, Y_dimension, mi)`\n",
    "    (presumably a joint normal vector whose mutual information equals `mi`; confirm in\n",
    "    `mutinfo.utils.dependent_norm`). The first `X_dimension` columns form X, the next\n",
    "    `Y_dimension` columns form Y.\n",
    "\n",
    "    Parameters:\n",
    "        mi: value forwarded to `multivariate_normal_from_MI`.\n",
    "        n_samples: number of samples requested from the generated distribution.\n",
    "        X_dimension, Y_dimension: numbers of columns taken for X and for Y.\n",
    "        X_map, Y_map: optional callables applied to X and Y before any estimation.\n",
    "        X_compressor, Y_compressor: forwarded to `LossyMutualInfoEstimator`.\n",
    "        verbose: forwarded to every `fit` and `estimate` call.\n",
    "\n",
    "    Returns:\n",
    "        A pair: the result of `MutualInfoEstimator.estimate` and the result of\n",
    "        `LossyMutualInfoEstimator.estimate`, both computed on the mapped samples.\n",
    "\n",
    "    Note: reads the notebook-level `entropy_estimator_params`, which is not defined in\n",
    "    this cell (presumably it is provided by `%run ./Settings.ipynb`).\n",
    "    \"\"\"\n",
    "    # Generation.\n",
    "    random_variable = multivariate_normal_from_MI(X_dimension, Y_dimension, mi)\n",
    "    X_Y = random_variable.rvs(n_samples)\n",
    "    X = X_Y[:, :X_dimension]\n",
    "    Y = X_Y[:, X_dimension:X_dimension + Y_dimension]\n",
    "\n",
    "    # Mapping application.\n",
    "    if X_map is not None:\n",
    "        X = X_map(X)\n",
    "\n",
    "    if Y_map is not None:\n",
    "        Y = Y_map(Y)\n",
    "\n",
    "    # Mutual information estimation. The result gets its own name so that the\n",
    "    # input parameter `mi` is not overwritten.\n",
    "    plain_estimator = mi_estimators.MutualInfoEstimator(entropy_estimator_params=entropy_estimator_params)\n",
    "    plain_estimator.fit(X, Y, verbose=verbose)\n",
    "    mi_estimate = plain_estimator.estimate(X, Y, verbose=verbose)\n",
    "\n",
    "    # Mutual information estimation for compressed representation.\n",
    "    lossy_estimator = mi_estimators.LossyMutualInfoEstimator(X_compressor, Y_compressor,\n",
    "                                                             entropy_estimator_params=entropy_estimator_params)\n",
    "    lossy_estimator.fit(X, Y, verbose=verbose)\n",
    "    mi_compressed_estimate = lossy_estimator.estimate(X, Y, verbose=verbose)\n",
    "\n",
    "    return mi_estimate, mi_compressed_estimate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perform_normal_compressed_tests_MI(MI, n_samples, X_dimension, Y_dimension, X_map=None, Y_map=None,\n",
    "                                       X_compressor=None, Y_compressor=None, verbose=0):\n",
    "    \"\"\"\n",
    "    Run `perform_normal_compressed_test` once for every value in `MI`.\n",
    "\n",
    "    A progress header is printed before each run. All remaining arguments are\n",
    "    forwarded unchanged. Returns two lists with one entry per value of `MI`:\n",
    "    the plain estimates and the estimates for the compressed representation.\n",
    "    \"\"\"\n",
    "    total_runs = len(MI)\n",
    "    plain_estimates, compressed_estimates = [], []\n",
    "\n",
    "    for run_number, true_mi in enumerate(MI, start=1):\n",
    "        print(\"\\nn_exp = %d/%d\\n------------\\n\" % (run_number, total_runs))\n",
    "        plain, compressed = perform_normal_compressed_test(true_mi, n_samples, X_dimension, Y_dimension,\n",
    "                                                           X_map, Y_map, X_compressor, Y_compressor, verbose)\n",
    "        plain_estimates.append(plain)\n",
    "        compressed_estimates.append(compressed)\n",
    "\n",
    "    return plain_estimates, compressed_estimates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_estimated_compressed_MI(MI, estimated_MI, estimated_MI_compressed, title):\n",
    "    \"\"\"\n",
    "    Plot plain and compressed mutual information estimates against the true values.\n",
    "\n",
    "    Every entry of `estimated_MI` and `estimated_MI_compressed` is indexed as a pair:\n",
    "    item 0 is drawn as a curve and item 1 as a shaded band around it (the local names\n",
    "    treat them as mean and standard deviation). The diagonal `MI` vs. `MI` is drawn in red.\n",
    "    \"\"\"\n",
    "    def split_mean_std(estimates):\n",
    "        # Separate the per-experiment pairs into an array of item 0 and an array of item 1.\n",
    "        means = np.array([estimate[0] for estimate in estimates])\n",
    "        stds = np.array([estimate[1] for estimate in estimates])\n",
    "        return means, stds\n",
    "\n",
    "    plain_mean, plain_std = split_mean_std(estimated_MI)\n",
    "    compressed_mean, compressed_std = split_mean_std(estimated_MI_compressed)\n",
    "\n",
    "    figure, axes = plt.subplots()\n",
    "    figure.set_size_inches(16, 11)\n",
    "\n",
    "    # Grid: stronger lines for major ticks, fainter ones for minor ticks.\n",
    "    grid_styles = {\n",
    "        'major': dict(alpha=0.15, linewidth=1),\n",
    "        'minor': dict(alpha=0.1, linewidth=0.5),\n",
    "    }\n",
    "    for which, style in grid_styles.items():\n",
    "        axes.grid(color='#000000', linestyle='-', which=which, **style)\n",
    "\n",
    "    axes.set_title(title)\n",
    "    axes.set_xlabel(\"$I(X,Y)$\")\n",
    "    axes.set_ylabel(\"$\\\\hat I(X,Y)$\")\n",
    "\n",
    "    axes.minorticks_on()\n",
    "\n",
    "    # Disabled: logarithmic axes.\n",
    "    #axes.set_yscale('log')\n",
    "    #axes.set_xscale('log')\n",
    "\n",
    "    # Reference line: the true value plotted against itself.\n",
    "    axes.plot(MI, MI, label=\"$I(X,Y)$\", color='red')\n",
    "\n",
    "    curves = (\n",
    "        (\"$\\\\hat I(X,Y)$\", plain_mean, plain_std),\n",
    "        (\"$\\\\hat I_{compr}(X,Y)$\", compressed_mean, compressed_std),\n",
    "    )\n",
    "    for label, mean, std in curves:\n",
    "        axes.plot(MI, mean, label=label)\n",
    "        axes.fill_between(MI, mean + std, mean - std, alpha=0.2)\n",
    "\n",
    "    axes.legend(loc='upper left')\n",
    "\n",
    "    axes.set_xlim((0.0, None))\n",
    "    axes.set_ylim((0.0, None))\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Global parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The values of mutual information under study.\n",
    "MI = np.linspace(0.0, 4.0, 41)\n",
    "n_exps = len(MI)\n",
    "\n",
    "# Sample size (the dimensions of X and Y are set in the experiment section below).\n",
    "n_samples = 5000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Anti-PCA synthetic embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mutinfo.utils.synthetic import normal_to_uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_dimension = 1\n",
    "Y_dimension = 1\n",
    "latent_dimension = 2\n",
    "\n",
    "embedding_dimension = 32\n",
    "\n",
    "experiments_dir = ('anti_PCA_%d' % (embedding_dimension))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate the training and test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import multivariate_normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_train_samples = 6000\n",
    "n_test_samples  = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def latent_transform(xi):\n",
    "    \"\"\"\n",
    "    Rescale `xi` and append one freshly drawn noise column.\n",
    "\n",
    "    `xi` is passed through `normal_to_uniform` (presumably the standard normal CDF;\n",
    "    confirm in `mutinfo.utils.synthetic`) and rescaled as `2 * u - 1`. A second\n",
    "    column Z is produced the same way from new standard normal draws that do not\n",
    "    depend on `xi`.\n",
    "\n",
    "    Parameters:\n",
    "        xi: 2-D array; its number of rows sets the number of noise draws.\n",
    "\n",
    "    Returns:\n",
    "        The rescaled `xi` with the noise column appended along axis 1.\n",
    "    \"\"\"\n",
    "    n_rows = xi.shape[0]\n",
    "\n",
    "    # The frozen default `multivariate_normal` squeezes its output: a request for a\n",
    "    # single row comes back as a 0-d scalar, which cannot be indexed with `[:, None]`.\n",
    "    # `np.atleast_1d` restores the 1-D shape in that case and is a no-op otherwise.\n",
    "    noise = np.atleast_1d(multivariate_normal().rvs((n_rows, 1)))\n",
    "\n",
    "    Z = 2.0 * normal_to_uniform(noise)[:, None] - 1.0\n",
    "    X = 2.0 * normal_to_uniform(xi) - 1.0\n",
    "\n",
    "    return np.concatenate((X, Z), axis=1)\n",
    "\n",
    "\n",
    "# Row vector 0, 1, ..., embedding_dimension - 1; one coefficient per output column.\n",
    "koeffs = np.arange(embedding_dimension).reshape(1, -1)\n",
    "\n",
    "def embedding_transform(latent_X):\n",
    "    \"\"\"\n",
    "    Embed two latent coordinates into `koeffs.shape[1]` columns.\n",
    "\n",
    "    Column k is `max(a_k, 0.1 * a_k * k)` with `a_k = latent_X[:, 0] + latent_X[:, 1] * k`.\n",
    "    Columns 0 and 1 are then overwritten with the two latent coordinates themselves.\n",
    "    \"\"\"\n",
    "    first = latent_X[:, 0]\n",
    "    second = latent_X[:, 1]\n",
    "\n",
    "    mixed = first[:, None] + second[:, None] * koeffs\n",
    "    embedded = np.maximum(mixed, 0.1 * mixed * koeffs)\n",
    "\n",
    "    # Keep the raw latent coordinates in the first two columns.\n",
    "    embedded[:, 0] = first\n",
    "    embedded[:, 1] = second\n",
    "\n",
    "    return embedded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_total_samples = n_train_samples + n_test_samples\n",
    "random_variable = multivariate_normal()\n",
    "\n",
    "# Draw all samples at once, turn them into a column, and push them through the\n",
    "# latent and embedding transforms.\n",
    "normal_draws = random_variable.rvs((n_total_samples, X_dimension))[:, None]\n",
    "latent_X = latent_transform(normal_draws)\n",
    "X = embedding_transform(latent_X)\n",
    "\n",
    "# The first n_train_samples rows are for training, the following n_test_samples rows for testing.\n",
    "X_train = X[:n_train_samples]\n",
    "X_test = X[n_train_samples:n_total_samples]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "pca = PCA(n_components=latent_dimension).fit(X_train.reshape(X_train.shape[0], -1))\n",
    "np.mean((pca.inverse_transform(pca.transform(X_test)) - X_test)**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PCA_latent = pca.transform(X_test)\n",
    "plt.scatter(PCA_latent[:,0], PCA_latent[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def pca_compressor(x):\n",
    "    \"\"\"Embed latent samples with `embedding_transform`, then project them with the fitted `pca`.\"\"\"\n",
    "    return pca.transform(embedding_transform(x))\n",
    "\n",
    "\n",
    "# The same mapping and the same compressor are used for X and for Y.\n",
    "estimated_MI, estimated_MI_compressed = perform_normal_compressed_tests_MI(\n",
    "    MI, n_samples, X_dimension, Y_dimension,\n",
    "    latent_transform, latent_transform,\n",
    "    pca_compressor, pca_compressor,\n",
    "    verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_estimated_compressed_MI(MI, estimated_MI, estimated_MI_compressed, \"Anti-PCA\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_estimated_MI(MI, estimated_MI, experiments_dir + '/coordinates')\n",
    "save_estimated_MI(MI, estimated_MI_compressed, experiments_dir + '/compressed/PCA')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "#device = \"cpu\"\n",
    "print(\"Device: \" + device)\n",
    "print(f\"Devices count: {torch.cuda.device_count()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mutinfo.torch.layers import AdditiveGaussianNoise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = torch.nn.Sequential(\n",
    "    #torch.nn.Linear(embedding_dimension, embedding_dimension),\n",
    "    #torch.nn.LeakyReLU(),\n",
    "    torch.nn.Linear(embedding_dimension, latent_dimension),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder = torch.nn.Sequential(\n",
    "    AdditiveGaussianNoise(0.01, relative_scale=True, enabled_on_inference=False),\n",
    "    torch.nn.Linear(latent_dimension, embedding_dimension),\n",
    "    torch.nn.LeakyReLU(),\n",
    "    torch.nn.Linear(embedding_dimension, embedding_dimension),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "autoencoder = torch.nn.Sequential(encoder, decoder).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optim = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)\n",
    "loss = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = torch.utils.data.TensorDataset(torch.tensor(X_train, dtype=torch.float32, device=device))\n",
    "dataloader = torch.utils.data.DataLoader(dataset, batch_size=1000, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import trange\n",
    "\n",
    "n_epochs = 2000\n",
    "\n",
    "# Fit the autoencoder to reconstruct its own input. Every 100 epochs the loss of\n",
    "# the last processed batch is printed.\n",
    "for epoch in trange(n_epochs):\n",
    "    for (batch,) in dataloader:\n",
    "        optim.zero_grad()\n",
    "\n",
    "        reconstruction = autoencoder(batch)\n",
    "        batch_loss = loss(batch, reconstruction)\n",
    "        batch_loss.backward()\n",
    "\n",
    "        optim.step()\n",
    "\n",
    "    if not epoch % 100:\n",
    "        print(batch_loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = encoder.eval().double()\n",
    "decoder = decoder.eval().double()\n",
    "autoencoder = autoencoder.eval().double()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = torch.tensor(X_train, dtype=torch.float64, device=device)\n",
    "x_test = torch.tensor(X_test, dtype=torch.float64, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.mean((autoencoder(x_test) - x_test)**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.mean((autoencoder(x_test)[:,0] - x_test[:,0])**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AE_latent = encoder(x_test).detach().cpu().numpy()\n",
    "plt.scatter(AE_latent[:,0], AE_latent[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(X_test[:,0], AE_latent[:,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "def encoder_compressor(x):\n",
    "    \"\"\"Embed latent samples with `embedding_transform` and return the encoder output as a NumPy array.\"\"\"\n",
    "    embedded = torch.tensor(embedding_transform(x), dtype=torch.float64, device=device)\n",
    "    return encoder(embedded).detach().cpu().numpy()\n",
    "\n",
    "\n",
    "# Disabled alternative compressor: the first two coordinates of the full autoencoder output.\n",
    "#lambda x : autoencoder(torch.tensor(embedding_transform(x), dtype=torch.float64, device=device)).detach().cpu().numpy()[:,:2]\n",
    "\n",
    "# The same mapping and the same compressor are used for X and for Y.\n",
    "estimated_MI, estimated_MI_compressed = perform_normal_compressed_tests_MI(\n",
    "    MI, n_samples, X_dimension, Y_dimension,\n",
    "    latent_transform, latent_transform,\n",
    "    encoder_compressor, encoder_compressor,\n",
    "    verbose=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_estimated_compressed_MI(MI, estimated_MI, estimated_MI_compressed, \"Anti-PCA\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_estimated_MI(MI, estimated_MI, experiments_dir + '/coordinates')\n",
    "save_estimated_MI(MI, estimated_MI_compressed, experiments_dir + '/compressed')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
