{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a6e37950",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\h5py\\__init__.py:36: UserWarning: h5py is running against HDF5 1.12.2 when it was built against 1.12.1, this may cause problems\n",
      "  _warn((\"h5py is running against HDF5 {0} when it was built against {1}, \"\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "def theoretic_mutual_information_gaussian(power,noise):\n",
    "    return dim_n*0.5*np.log2(1 + power/noise)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8411b9e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "### variables\n",
    "\n",
    "SAMPLE_SIZE = 3000\n",
    "SIGNAL_POWER = 3\n",
    "SIGNAL_NOISE = 0.2\n",
    "dim_n = 3\n",
    "\n",
    "x_sample = np.random.normal(0., np.sqrt(SIGNAL_POWER), [SAMPLE_SIZE, dim_n])\n",
    "y_sample = x_sample + np.random.normal(0., np.sqrt(SIGNAL_NOISE), [SAMPLE_SIZE, dim_n])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8e6ace1",
   "metadata": {},
   "source": [
    "### MINE estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "182b3c97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "true_mi= 6.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\keras\\initializers\\initializers_v2.py:120: UserWarning: The initializer RandomNormal is unseeded and being called multiple times, which will return identical values  each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n",
      "  warnings.warn(\n",
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\keras\\optimizers\\optimizer_v2\\nadam.py:86: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
      "  super().__init__(name, **kwargs)\n",
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\keras\\optimizers\\optimizer_v2\\adam.py:114: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
      "  super().__init__(name, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MINE:  5.097211143347569\n"
     ]
    }
   ],
   "source": [
    "# MINE - no normalization\n",
    "\n",
    "# X = Gaussian variable with mean = 0 and sd = SIGNAL_POWER\n",
    "# Y = Gaussian variable with mean = 0 and sd = SIGNAL_NOISE\n",
    "\n",
    "from MINE_estimate import MINE_MI\n",
    "\n",
    "\n",
    "\n",
    "true_mi = theoretic_mutual_information_gaussian(SIGNAL_POWER, SIGNAL_NOISE)\n",
    "print('true_mi=',true_mi)\n",
    "\n",
    "#original MINE\n",
    "MI,history_mi = MINE_MI(x_sample,y_sample,total_epochs=50)\n",
    "print(\"MINE: \", MI)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c6a6999",
   "metadata": {},
   "source": [
    "### KSG estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73ff9653",
   "metadata": {},
   "source": [
    "### ksg installation guide:\n",
    "\n",
    "GPU version:https://github.com/pwollstadt/IDTxl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3859892f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KSG-GPU:  [3.69469974]\n"
     ]
    }
   ],
   "source": [
    "# KSG - GPU\n",
    "\n",
    "import idtxl.estimators_opencl as est\n",
    "settings = {'kraskov_k': 3}\n",
    "gpu_est = est.OpenCLKraskovMI(settings = settings)\n",
    "\n",
    "#original KSG\n",
    "MI = gpu_est.estimate(x_sample,y_sample)\n",
    "print(\"KSG-GPU: \", MI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bafb2ee2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KSG-CPU:  3.56018993324321\n",
      "revised_mi-KSG:  3.55878933446243\n"
     ]
    }
   ],
   "source": [
    "# ksg - cpu\n",
    "import knnie\n",
    "\n",
    "#original KSG\n",
    "MI = knnie.kraskov_mi(x_sample,y_sample,k=5)\n",
    "print(\"KSG-CPU: \", MI)\n",
    "\n",
    "# revised KSG\n",
    "MI = knnie.revised_mi(x_sample,y_sample,k=5)\n",
    "print(\"revised_mi-KSG: \", MI)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6b4f0ba",
   "metadata": {},
   "source": [
    "### binning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d733bd1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Binning MI:  0.5849625007211561\n"
     ]
    }
   ],
   "source": [
    "#binning\n",
    "import simplebinmi\n",
    "\n",
    "binxm = simplebinmi.bin_calc_information_new_mod(x_sample, y_sample, 5,5)\n",
    "\n",
    "print( \"Binning MI: \",binxm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2751432c",
   "metadata": {},
   "source": [
    "## normalized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "23d7d462",
   "metadata": {},
   "outputs": [],
   "source": [
    "def global_normalize(data, C):\n",
    "    # Reshape data as before\n",
    "    data = np.reshape(data, (data.shape[0], int(data.size / data.shape[0])))\n",
    "    \n",
    "    # Subtract the mean\n",
    "    means = np.mean(data, axis=0)\n",
    "    data = data - means\n",
    "    \n",
    "    # Compute the normalization factor\n",
    "    norm = np.sqrt(np.mean(np.sum(data ** 2, axis=1)))\n",
    "    \n",
    "    # Normalize the data\n",
    "    data *= C / norm\n",
    "    \n",
    "    return data, norm ** 2\n",
    "\n",
    "def local_normalize(data,C):\n",
    "    data = np.reshape(data, (data.shape[0],int(data.size/data.shape[0])))\n",
    "    \n",
    "    means =  np.mean(data, axis=0) # find the mean for each dimension \n",
    "    data = data - means # data - means for each dimension\n",
    "    \n",
    "    norm = np.tile(np.sqrt(np.mean(data ** 2 ,axis=0)),(data.shape[0],1))\n",
    "#     norm =  np.sqrt(np.mean(np.sum(sqz,axis=1)))\n",
    "    normalized_data = C*data / (norm+(0.0000001))\n",
    "    \n",
    "    return normalized_data,norm**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cd8f24a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\keras\\initializers\\initializers_v2.py:120: UserWarning: The initializer RandomNormal is unseeded and being called multiple times, which will return identical values  each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n",
      "  warnings.warn(\n",
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\keras\\optimizers\\optimizer_v2\\nadam.py:86: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
      "  super().__init__(name, **kwargs)\n",
      "C:\\Users\\mdeep2\\anaconda3\\envs\\py39\\lib\\site-packages\\keras\\optimizers\\optimizer_v2\\adam.py:114: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
      "  super().__init__(name, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.877881668702229\n",
      "5.001485943842631\n"
     ]
    }
   ],
   "source": [
    "#MINE_global\n",
    "normz,energy_z_hat = global_normalize(x_sample,C=1)\n",
    "normx,x_selected = global_normalize(y_sample,C=1)\n",
    "MI,history_mi = MINE_MI(normz,normx,total_epochs=50)\n",
    "print(MI)\n",
    "\n",
    "#MINE_local\n",
    "normz,energy_z_hat = local_normalize(x_sample,C=1)\n",
    "normx,x_selected = local_normalize(y_sample,C=1)\n",
    "MI,history_mi = MINE_MI(normz,normx,total_epochs=50)\n",
    "print(MI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "03f9dcf9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KSG_global\n",
      "3.697744526426307\n",
      "KSG_local\n",
      "[3.69828618]\n"
     ]
    }
   ],
   "source": [
    "print('KSG_global')\n",
    "C_z=np.arange(0.1,2.0,0.1)\n",
    "mi_val_C_z=0\n",
    "for C_zi in C_z:\n",
    "    normz,energy_z_hat = global_normalize(x_sample,C=C_zi)\n",
    "    normx = global_normalize(y_sample,C=1)[0]\n",
    "    mi_val = gpu_est.estimate(normz,normx)\n",
    "    mi_val_C_z=max(mi_val_C_z,mi_val[0])\n",
    "print(mi_val_C_z)\n",
    "\n",
    "print('KSG_local')\n",
    "normz,energy_z_hat = local_normalize(x_sample,C=1)\n",
    "normx,x_selected = local_normalize(y_sample,C=1)\n",
    "mi_val = gpu_est.estimate(normz,normx)\n",
    "print(mi_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e85b8c3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "alpha is:  0.001 MI is:  0.5849625007211561\n",
      "alpha is:  0.00206913808111479 MI is:  0.5849625007211561\n",
      "alpha is:  0.004281332398719396 MI is:  0.5849625007211561\n",
      "alpha is:  0.008858667904100823 MI is:  0.5849625007211561\n",
      "alpha is:  0.018329807108324356 MI is:  0.5849625007211561\n",
      "alpha is:  0.0379269019073225 MI is:  0.5849625007211561\n",
      "alpha is:  0.07847599703514611 MI is:  0.5849625007211561\n",
      "alpha is:  0.1623776739188721 MI is:  0.5849625007211561\n",
      "alpha is:  0.3359818286283781 MI is:  0.5849625007211561\n",
      "alpha is:  0.6951927961775606 MI is:  0.5849625007211561\n",
      "alpha is:  1.438449888287663 MI is:  0.5849625007211561\n",
      "alpha is:  2.976351441631316 MI is:  0.5849625007211561\n",
      "alpha is:  6.158482110660261 MI is:  0.5849625007211561\n",
      "alpha is:  12.742749857031322 MI is:  0.5849625007211561\n",
      "alpha is:  26.366508987303554 MI is:  0.5849625007211561\n",
      "alpha is:  54.555947811685144 MI is:  0.5849625007211561\n",
      "alpha is:  112.88378916846884 MI is:  0.5849625007211561\n",
      "alpha is:  233.57214690901213 MI is:  0.5849625007211561\n",
      "alpha is:  483.2930238571752 MI is:  0.5849625007211561\n",
      "alpha is:  1000.0 MI is:  0.5849625007211561\n"
     ]
    }
   ],
   "source": [
    "#normalization binning\n",
    "### test playground \n",
    "start = 10**(-3)\n",
    "stop = 10**3\n",
    "\n",
    "alpha_array=np.logspace(np.log10(start), np.log10(stop), num=20)\n",
    "\n",
    "binning_normalization_array=[]\n",
    "C_z=[0.6,0.7,0.8,1.0,1.2]\n",
    "for alpha in alpha_array:\n",
    "    mi_val_C_z=0\n",
    "    for C_zi in C_z:\n",
    "        normx,energy_y_hat = global_normalize(x_sample,1)\n",
    "        normy,energy_y_hat = global_normalize(y_sample*alpha,C=C_zi)\n",
    "        binxm = simplebinmi.bin_calc_information_new_mod(normx, normy, 5,5)\n",
    "        mi_val_C_z=max(mi_val_C_z,binxm)\n",
    "    binning_normalization_array.append(mi_val_C_z)\n",
    "    print(\"alpha is: \",alpha, \"MI is: \",binxm)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
