{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "593fe148-16e1-4da5-91b4-ee8506473b05",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'all_estimators'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[1], line 5\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdistributions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m MultivariateNormal\n\u001b[0;32m      4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mall_estimators\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[0;32m      6\u001b[0m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mseed(\u001b[38;5;241m42\u001b[39m)\n\u001b[0;32m      7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mrandom\u001b[39;00m \n",
      "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'all_estimators'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torch.distributions import MultivariateNormal\n",
    "import numpy as np\n",
    "from all_estimators import *\n",
    "np.random.seed(42)\n",
    "import random \n",
    "random.seed(42)\n",
    "import argparse\n",
    "from scipy import stats\n",
    "from mi_utils import *\n",
    "from estimator_lib import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c2a091df-0625-43f3-8235-0d3aa6d41ba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "params_dict = {\n",
    "  \"concat_self\": [20],\n",
    "  \"randmat\": [],\n",
    "   \"cube\": [],\n",
    "   \"concat_self_noisy\": [20,0.2],\n",
    "   \"sigmoid\": [],\n",
    "    \"scale\": [],\n",
    "}\n",
    "estimators = [KSG_est,KSG_local_est,KSG_global_est_infnorm,revised_KSG_est]\n",
    "# estimators = [mine_est,mine_est_local,mine_est_global]\n",
    "# names = ['MINE','MINE-Local','MINE-Global-Corrected']\n",
    "names = ['KSG','KSG-Local','KSG-Global-$L_{\\infty}$','KSG-Revised']\n",
    "output_list = [[] for x in estimators]\n",
    "true_mi_list = [] \n",
    "# scale_range = np.concatenate((np.logspace(-2,0.0,10),np.logspace(0.0,1.0,10)))\n",
    "scale_range = np.logspace(-2,3.0,10)\n",
    "\n",
    "mean_list = [[] for x in scale_range]\n",
    "uppers = [[] for x in scale_range]\n",
    "lowers = [[] for x in scale_range]\n",
    "\n",
    "trials = 20\n",
    "rho = 0.5\n",
    "N = 10000\n",
    "dim = 10\n",
    "transforms_x = ['scale']\n",
    "transforms_y = ['none']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5799ac5e-636c-4153-8e0a-d0c8f766f80f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4384104013442993\n"
     ]
    }
   ],
   "source": [
    "dataset_list = []\n",
    "params_dict['scale']= 1.0\n",
    "\n",
    "for i in range(trials):    \n",
    "        dataset_list.append(MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y))\n",
    "\n",
    "print(dataset_list[0].true_mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "ea6bd437-acc7-4ceb-a76f-3942f6531225",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "[1.5987211554602255e-15, 0.78917489512348, 0.7904170836006545, -0.1757279485693221]\n",
      "1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for k in range(len(scale_range)):\n",
    "    print(k)\n",
    "    params_dict['scale']= scale_range[k]\n",
    "    output_list = [[] for x in estimators]\n",
    "    for i in range(trials):    \n",
    "        # dataset = MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y)\n",
    "        # print(\"True MI:\", dataset.true_mi)\n",
    "        # print('LNC:',estimators[-1](dataset.x,dataset.y))\n",
    "        for temp in range(len(estimators)):\n",
    "            E = estimators[temp](dataset_list[i].x*scale_range[k],dataset_list[i].y)\n",
    "            output_list[temp].append(E)\n",
    "    \n",
    "    for temp in range(len(estimators)):\n",
    "        mean_list[k].append(np.mean(output_list[temp]))\n",
    "        diff_arr = np.array(output_list[temp])-np.mean(output_list[temp])\n",
    "    print(mean_list[k])\n",
    "        # uppers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr>=0).astype('float')))\n",
    "        # lowers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr<=0).astype('float')))\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de9d631-b57c-450c-b763-1d6d9ff7c1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import matplotlib\n",
    "import scipy\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "matplotlib.rcParams.update({'font.size': 17})\n",
    "\n",
    "mean_list = np.array(mean_list)\n",
    "uppers = np.array(uppers)\n",
    "lowers = np.array(lowers)\n",
    "for temp in range(len(estimators)):\n",
    "    X = np.log10(scale_range)\n",
    "    Y = mean_list[:,temp]\n",
    "    # X = np.delete(X,10)\n",
    "    # Y = np.delete(Y,10)\n",
    "    \n",
    "    filt = scipy.signal.savgol_filter(Y, 10, 3)\n",
    "    plt.plot(X,filt,label=names[temp])\n",
    "    # ups = uppers[:,temp] - mean_list[:,temp]\n",
    "    # downs = mean_list[:,temp] - lowers[:,temp]\n",
    "    # plt.errorbar(np.log10(scale_range),mean_list[:,temp], yerr=[downs,ups], capsize=5,  ecolor = \"black\")\n",
    "    # plt.fill_between(np.log10(scale_range),lowers[:,temp],uppers[:,temp])\n",
    "plt.plot(np.log10(scale_range),dataset_list[0].true_mi*np.ones_like(scale_range),linestyle='--',color='black',label='True MI')\n",
    "plt.legend(fontsize=14,handlelength=1,framealpha=0)\n",
    "plt.grid(linestyle='--')\n",
    "plt.xlabel('$\\log_{10}(\\eta)$ (Scaling Factor)',fontsize=24)\n",
    "plt.ylabel('Average MI Estimates',fontsize=24)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('KSG_Scale_10000m_d10.png',bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9675cccd-0b67-4ce9-b710-53b8ba391ce8",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.true_mi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "4b5fd4e0-d37b-4cba-ad3f-dc0c8be05e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "names = ['KSG','KSG-Local','KSG-Global-$L_{\\infty}$','BI-KSG']\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
