{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "593fe148-16e1-4da5-91b4-ee8506473b05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import argparse  # NOTE(review): not used by the visible cells\n",
    "import random\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import torch\n",
    "from scipy import stats  # NOTE(review): not used by the visible cells\n",
    "from torch.distributions import MultivariateNormal  # NOTE(review): not used by the visible cells\n",
    "from torch.utils.data import Dataset  # NOTE(review): not used by the visible cells\n",
    "\n",
    "# Project-local estimators and data utilities. Star imports hide where names\n",
    "# such as MultivariateNormalDataset and mine_est come from; their relative\n",
    "# order is preserved because a later module can shadow names from an earlier one.\n",
    "from all_estimators import *\n",
    "from mi_utils import *\n",
    "from estimator_lib import *\n",
    "\n",
    "# Seed every RNG the experiment uses, after all imports, so module-level code\n",
    "# in the imported modules cannot disturb the seeded state. torch is seeded too\n",
    "# because the MINE estimators are presumably torch-based (random network init).\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "c2a091df-0625-43f3-8235-0d3aa6d41ba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data-generation options forwarded to MultivariateNormalDataset.\n",
    "# NOTE(review): what each entry means (e.g. [20, 0.8] for concat_self_noisy)\n",
    "# is defined in the project data utilities; confirm there.\n",
    "params_dict = {\n",
    "    'concat_self': [20],\n",
    "    'randmat': [],\n",
    "    'cube': [],\n",
    "    'concat_self_noisy': [20, 0.8],\n",
    "    'sigmoid': [],\n",
    "    'scale': [],\n",
    "}\n",
    "\n",
    "# Estimators under comparison and their legend labels (kept index-aligned).\n",
    "estimators = [mine_est, mine_est_global, mine_est_global_nocorrection]\n",
    "names = ['MINE', 'MINE-Global-Corrected', 'MINE-Global']\n",
    "assert len(names) == len(estimators), 'each estimator needs exactly one legend label'\n",
    "\n",
    "# Experiment grid: dimensions d = 1..8 with `trials` datasets per dimension;\n",
    "# N is the first MultivariateNormalDataset argument (presumably the sample count).\n",
    "dim_range = np.arange(1, 9, 1)\n",
    "trials = 10\n",
    "N = 1000\n",
    "transforms_x = ['none']\n",
    "transforms_y = ['none']\n",
    "\n",
    "# Per-dimension result containers (one inner list per entry of dim_range).\n",
    "mean_list = [[] for _ in dim_range]\n",
    "bias_prop = [[] for _ in dim_range]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "54fa6801-5b7a-4fca-b23e-bf1df65d84a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n"
     ]
    }
   ],
   "source": [
    "# Bias sweep over dimension. The result containers are rebuilt here so the\n",
    "# cell is idempotent: re-running it no longer appends extra columns to lists\n",
    "# left over from a previous run (or fails on arrays made by a later cell).\n",
    "mean_list = [[] for _ in dim_range]\n",
    "bias_prop = [[] for _ in dim_range]\n",
    "\n",
    "for k, d in enumerate(dim_range):\n",
    "    print(f'd = {d} ({k + 1}/{len(dim_range)})')\n",
    "    output_list = [[] for _ in estimators]\n",
    "    true_mi_list = []\n",
    "    for _ in range(trials):\n",
    "        # Fresh rho (presumably the correlation) per trial, uniform in [0, 0.8).\n",
    "        rho = np.random.rand() * 0.8\n",
    "        dataset = MultivariateNormalDataset(N, d, rho, params_dict, transforms_x, transforms_y)\n",
    "        true_mi_list.append(dataset.true_mi)\n",
    "        for est_idx, estimator in enumerate(estimators):\n",
    "            output_list[est_idx].append(estimator(dataset.x, dataset.y))\n",
    "\n",
    "    mean_true_mi = np.mean(true_mi_list)\n",
    "    for estimates in output_list:\n",
    "        mean_estimate = np.mean(estimates)\n",
    "        mean_list[k].append(mean_estimate)\n",
    "        # Sign convention: true MI minus mean estimate (positive = underestimation).\n",
    "        bias_prop[k].append(mean_true_mi - mean_estimate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82341d04-6117-4792-9d58-6e83cb3506df",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import scipy.signal  # explicit submodule import; also binds `scipy`\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Global figure style. `axes.labelsize` must be set *before* the Axes is\n",
    "# created: matplotlib reads it when it builds the axis-label Text objects, so\n",
    "# setting it after plt.subplots() leaves this figure's labels at the default\n",
    "# size on a first run (and only takes effect on re-runs).\n",
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "matplotlib.rcParams.update({'font.size': 17})\n",
    "plt.rc('axes', labelsize=24)\n",
    "\n",
    "# Savitzky-Golay smoothing applied to each bias curve.\n",
    "SAVGOL_WINDOW = 8\n",
    "SAVGOL_POLYORDER = 2\n",
    "# Optional output path for the figure; None skips saving.\n",
    "SAVE_PATH = None  # e.g. 'MINE_bias_with_local_and_noisecat.png'\n",
    "\n",
    "# Convert to an array locally instead of rebinding `bias_prop`, so the\n",
    "# experiment cell (which appends to it) can still be re-run after this one.\n",
    "bias_arr = np.array(bias_prop)  # rows follow dim_range, columns follow estimators\n",
    "X = np.asarray(dim_range)\n",
    "# mode='interp' (the default) needs window <= number of points and\n",
    "# window > polyorder; clamp so shorter dimension sweeps still plot.\n",
    "window = min(SAVGOL_WINDOW, len(X))\n",
    "\n",
    "fig, axs = plt.subplots(1, 1, figsize=(5, 5))\n",
    "for temp in range(len(estimators)):\n",
    "    Y = bias_arr[:, temp]\n",
    "    filt = scipy.signal.savgol_filter(Y, window, SAVGOL_POLYORDER) if window > SAVGOL_POLYORDER else Y\n",
    "    # bias_prop holds (true MI - mean estimate); negate so positive = overestimation.\n",
    "    axs.plot(X, -filt, label=names[temp])\n",
    "axs.axhline(0, linestyle='--', color='black', label='Zero Bias')\n",
    "axs.legend(fontsize=14, handlelength=1, framealpha=0)\n",
    "axs.grid(linestyle='--')\n",
    "axs.set(xlabel='$d$', ylabel='Bias')\n",
    "axs.set_xticks(X[1::2])  # every other dimension (2, 4, 6, 8 for d = 1..8)\n",
    "fig.tight_layout()\n",
    "\n",
    "if SAVE_PATH:\n",
    "    fig.savefig(SAVE_PATH, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9675cccd-0b67-4ce9-b710-53b8ba391ce8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.32037492, 0.27718593, 0.32457326, 0.26692341],\n",
       "       [0.401359  , 0.39604954, 0.38851777, 0.32949089],\n",
       "       [0.43766331, 0.41995806, 0.42557957, 0.39892336],\n",
       "       [0.4240395 , 0.44522429, 0.44619207, 0.39639274],\n",
       "       [0.37318656, 0.42975107, 0.4014616 , 0.38064523],\n",
       "       [0.38945666, 0.37389135, 0.40625595, 0.39524089],\n",
       "       [0.43888758, 0.42329823, 0.4279525 , 0.4105435 ],\n",
       "       [0.43149513, 0.50539668, 0.37986643, 0.40320094],\n",
       "       [0.47367049, 0.42421194, 0.43044537, 0.34754334],\n",
       "       [0.36991796, 0.45938089, 0.40764806, 0.37687842]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean MI estimate per dimension: rows follow dim_range, columns follow estimators.\n",
    "np.asarray(mean_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4b5fd4e0-d37b-4cba-ad3f-dc0c8be05e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Legend labels for a 4-estimator configuration that also included a\n",
    "# 'MINE-Local' estimator (presumably the run behind the 4-column mean_list\n",
    "# output shown above). Bound to a separate name so running this cell does not\n",
    "# clobber `names`, which must stay aligned with the 3-entry `estimators` list.\n",
    "names_with_local = ['MINE', 'MINE-Local', 'MINE-Global-Corrected', 'MINE-Global']\n",
    "assert len(names) == len(estimators), 'legend labels out of sync with estimators'"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
