{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "593fe148-16e1-4da5-91b4-ee8506473b05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torch.distributions import MultivariateNormal\n",
    "import numpy as np\n",
    "from all_estimators import *\n",
    "np.random.seed(42)\n",
    "import random \n",
    "random.seed(42)\n",
    "import argparse\n",
    "from scipy import stats\n",
    "from mi_utils import *\n",
    "from estimator_lib import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c2a091df-0625-43f3-8235-0d3aa6d41ba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "params_dict = {\n",
    "  \"concat_self\": [20],\n",
    "  \"randmat\": [],\n",
    "   \"cube\": [],\n",
    "   \"concat_self_noisy\": [20,0.2],\n",
    "   \"sigmoid\": [],\n",
    "    \"scale\": [],\n",
    "}\n",
    "# estimators = [KSG_est,KSG_local_est,KSG_global_est_infnorm,revised_KSG_est]\n",
    "estimators = [mine_est,mine_est_local,mine_est_global]\n",
    "names = ['MINE','MINE-Local','MINE-Global-Corrected']\n",
    "# names = ['KSG','KSG-Local','KSG-Global-$L_{\\infty}$','KSG-Revised']\n",
    "output_list = [[] for x in estimators]\n",
    "true_mi_list = [] \n",
    "scale_range = np.concatenate((np.logspace(-2,0.0,10),np.logspace(0.0,1.0,10)))\n",
    "# scale_range = np.logspace(-2,3.0,10)\n",
    "\n",
    "mean_list = [[] for x in scale_range]\n",
    "uppers = [[] for x in scale_range]\n",
    "lowers = [[] for x in scale_range]\n",
    "\n",
    "trials = 20\n",
    "rho = 0.5\n",
    "N = 1000\n",
    "dim = 2\n",
    "transforms_x = ['scale']\n",
    "transforms_y = ['none']\n",
    "\n",
    "#dim = 2, 10, 50\n",
    "# N = 1000, 10000, 50000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6922976e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4384104013442993\n",
      "0\n",
      "[0.0005108829936943948, 0.6974136888980865, 0.6787741392850876]\n",
      "1\n",
      "[0.004036318126600236, 0.7076906263828278, 0.7126722633838654]\n",
      "2\n",
      "[0.006566408067010343, 0.6947802364826202, 0.7032470852136612]\n",
      "3\n",
      "[0.014469075109809637, 0.7097485423088074, 0.6929738789796829]\n",
      "4\n",
      "[0.04109436050057411, 0.7231288105249405, 0.7006772667169571]\n",
      "5\n",
      "[0.09014369174838066, 0.7025577664375305, 0.6826566725969314]\n",
      "6\n",
      "[0.20969093292951585, 0.7237301647663117, 0.6757097482681275]\n",
      "7\n",
      "[0.3587577499449253, 0.6678581446409225, 0.6769532352685929]\n",
      "8\n",
      "[0.5665666177868843, 0.6887319475412369, 0.6952549159526825]\n",
      "9\n",
      "[0.6982424020767212, 0.6753963276743888, 0.701740962266922]\n",
      "10\n",
      "[0.749360366165638, 0.7111811548471451, 0.649298021197319]\n",
      "11\n",
      "[0.7371199056506157, 0.7068823546171188, 0.7160366386175155]\n",
      "12\n",
      "[0.7640770733356476, 0.7202428728342056, 0.743590134382248]\n",
      "13\n",
      "[0.700137397646904, 0.72283253967762, 0.6664838999509811]\n",
      "14\n",
      "[0.7004671230912208, 0.7132000029087067, 0.673131063580513]\n",
      "15\n",
      "[0.4688950747251511, 0.702847683429718, 0.6970310509204865]\n",
      "16\n",
      "[0.4404496982693672, 0.7004514068365097, 0.6799500316381455]\n",
      "17\n",
      "[0.27446614652872087, 0.6786351591348648, 0.71733009070158]\n",
      "18\n",
      "[0.141736201941967, 0.7193383753299714, 0.6707206130027771]\n",
      "19\n",
      "[-0.012490976601839066, 0.6625185638666153, 0.7207889437675477]\n",
      "1.4384104013442993\n",
      "0\n",
      "[0.0005108829936943948, 0.6974136888980865, 0.6787741392850876, 0.19478082284331322, 1.3420642375946046, 1.3054374873638153]\n",
      "1\n",
      "[0.004036318126600236, 0.7076906263828278, 0.7126722633838654, 0.4351948484778404, 1.3404280722141266, 1.3607611179351806]\n",
      "2\n",
      "[0.006566408067010343, 0.6947802364826202, 0.7032470852136612, 0.7387288421392441, 1.3215947687625884, 1.3077039420604706]\n",
      "3\n",
      "[0.014469075109809637, 0.7097485423088074, 0.6929738789796829, 0.937425148487091, 1.3399163782596588, 1.3609737694263457]\n",
      "4\n",
      "[0.04109436050057411, 0.7231288105249405, 0.7006772667169571, 1.1029184311628342, 1.360612839460373, 1.3799600422382354]\n",
      "5\n",
      "[0.09014369174838066, 0.7025577664375305, 0.6826566725969314, 1.1860462307929993, 1.335409152507782, 1.3692770063877107]\n",
      "6\n",
      "[0.20969093292951585, 0.7237301647663117, 0.6757097482681275, 1.2317445158958436, 1.3320486426353455, 1.3363616049289704]\n",
      "7\n",
      "[0.3587577499449253, 0.6678581446409225, 0.6769532352685929, 1.3366783142089844, 1.4168724179267884, 1.3146124184131622]\n",
      "8\n",
      "[0.5665666177868843, 0.6887319475412369, 0.6952549159526825, 1.381117480993271, 1.2980871796607971, 1.3641252517700195]\n",
      "9\n",
      "[0.6982424020767212, 0.6753963276743888, 0.701740962266922, 1.3006316304206849, 1.3004609763622283, 1.381750762462616]\n",
      "10\n",
      "[0.749360366165638, 0.7111811548471451, 0.649298021197319, 1.3655699610710144, 1.3306861758232116, 1.4135067641735077]\n",
      "11\n",
      "[0.7371199056506157, 0.7068823546171188, 0.7160366386175155, 1.3514846563339233, 1.2947657972574234, 1.3218959748744965]\n",
      "12\n",
      "[0.7640770733356476, 0.7202428728342056, 0.743590134382248, 1.3530965387821197, 1.3292499244213105, 1.2011914074420929]\n",
      "13\n",
      "[0.700137397646904, 0.72283253967762, 0.6664838999509811, 1.302545577287674, 1.3900613248348237, 1.336340320110321]\n",
      "14\n",
      "[0.7004671230912208, 0.7132000029087067, 0.673131063580513, 1.3094679236412048, 1.3564161002635955, 1.2988875389099122]\n",
      "15\n",
      "[0.4688950747251511, 0.702847683429718, 0.6970310509204865, 1.3063791275024415, 1.3113515317440032, 1.3374859631061553]\n",
      "16\n",
      "[0.4404496982693672, 0.7004514068365097, 0.6799500316381455, 1.3515120506286622, 1.318716937303543, 1.342572909593582]\n",
      "17\n",
      "[0.27446614652872087, 0.6786351591348648, 0.71733009070158, 1.2559903740882874, 1.2893251180648804, 1.4018982529640198]\n",
      "18\n",
      "[0.141736201941967, 0.7193383753299714, 0.6707206130027771, 1.2060055881738663, 1.2980189859867095, 1.3094409704208374]\n",
      "19\n",
      "[-0.012490976601839066, 0.6625185638666153, 0.7207889437675477, 1.1193231165409088, 1.2474555969238281, 1.3204926669597625]\n",
      "1.4384104013442993\n",
      "0\n",
      "[0.0005108829936943948, 0.6974136888980865, 0.6787741392850876, 0.19478082284331322, 1.3420642375946046, 1.3054374873638153, 1.0433630704879762, 1.1206342697143554, 1.1967750549316407]\n",
      "1\n",
      "[0.004036318126600236, 0.7076906263828278, 0.7126722633838654, 0.4351948484778404, 1.3404280722141266, 1.3607611179351806, 1.1553805679082871, 1.2655469238758088, 1.199320062994957]\n",
      "2\n",
      "[0.006566408067010343, 0.6947802364826202, 0.7032470852136612, 0.7387288421392441, 1.3215947687625884, 1.3077039420604706, 1.1851089954376222, 1.233668076992035, 1.1499457895755767]\n",
      "3\n",
      "[0.014469075109809637, 0.7097485423088074, 0.6929738789796829, 0.937425148487091, 1.3399163782596588, 1.3609737694263457, 1.1922241866588592, 1.2314191311597824, 1.197758686542511]\n",
      "4\n",
      "[0.04109436050057411, 0.7231288105249405, 0.7006772667169571, 1.1029184311628342, 1.360612839460373, 1.3799600422382354, 1.1326319813728332, 1.185894101858139, 1.2311300098896027]\n",
      "5\n",
      "[0.09014369174838066, 0.7025577664375305, 0.6826566725969314, 1.1860462307929993, 1.335409152507782, 1.3692770063877107, 1.1523614048957824, 1.2285716772079467, 1.239572387933731]\n",
      "6\n",
      "[0.20969093292951585, 0.7237301647663117, 0.6757097482681275, 1.2317445158958436, 1.3320486426353455, 1.3363616049289704, 1.1224298894405365, 1.2391408294439317, 1.2185503870248795]\n",
      "7\n",
      "[0.3587577499449253, 0.6678581446409225, 0.6769532352685929, 1.3366783142089844, 1.4168724179267884, 1.3146124184131622, 1.2000299334526061, 1.217231035232544, 1.133901944756508]\n",
      "8\n",
      "[0.5665666177868843, 0.6887319475412369, 0.6952549159526825, 1.381117480993271, 1.2980871796607971, 1.3641252517700195, 1.245841532945633, 1.2451247483491898, 1.1128733038902283]\n",
      "9\n",
      "[0.6982424020767212, 0.6753963276743888, 0.701740962266922, 1.3006316304206849, 1.3004609763622283, 1.381750762462616, 1.1838146418333053, 1.1854516744613648, 1.243287044763565]\n",
      "10\n",
      "[0.749360366165638, 0.7111811548471451, 0.649298021197319, 1.3655699610710144, 1.3306861758232116, 1.4135067641735077, 1.1793749690055848, 1.2152680307626724, 1.2852642118930817]\n",
      "11\n",
      "[0.7371199056506157, 0.7068823546171188, 0.7160366386175155, 1.3514846563339233, 1.2947657972574234, 1.3218959748744965, 1.2022022902965546, 1.1450546324253081, 1.1342185407876968]\n",
      "12\n",
      "[0.7640770733356476, 0.7202428728342056, 0.743590134382248, 1.3530965387821197, 1.3292499244213105, 1.2011914074420929, 1.1512509375810622, 1.2511638820171356, 1.232151210308075]\n",
      "13\n",
      "[0.700137397646904, 0.72283253967762, 0.6664838999509811, 1.302545577287674, 1.3900613248348237, 1.336340320110321, 1.1993412375450134, 1.213584566116333, 1.237987232208252]\n",
      "14\n",
      "[0.7004671230912208, 0.7132000029087067, 0.673131063580513, 1.3094679236412048, 1.3564161002635955, 1.2988875389099122, 1.229577249288559, 1.1830426424741745, 1.2278582513332368]\n",
      "15\n",
      "[0.4688950747251511, 0.702847683429718, 0.6970310509204865, 1.3063791275024415, 1.3113515317440032, 1.3374859631061553, 1.2677002847194672, 1.1983258813619613, 1.2021334648132325]\n",
      "16\n",
      "[0.4404496982693672, 0.7004514068365097, 0.6799500316381455, 1.3515120506286622, 1.318716937303543, 1.342572909593582, 1.1657847255468368, 1.2400318682193756, 1.174932387471199]\n",
      "17\n",
      "[0.27446614652872087, 0.6786351591348648, 0.71733009070158, 1.2559903740882874, 1.2893251180648804, 1.4018982529640198, 1.151725023984909, 1.1741406917572021, 1.1861270725727082]\n",
      "18\n",
      "[0.141736201941967, 0.7193383753299714, 0.6707206130027771, 1.2060055881738663, 1.2980189859867095, 1.3094409704208374, 1.1711544692516327, 1.2086597740650178, 1.1670764356851577]\n",
      "19\n",
      "[-0.012490976601839066, 0.6625185638666153, 0.7207889437675477, 1.1193231165409088, 1.2474555969238281, 1.3204926669597625, 1.1907365560531615, 1.1861796259880066, 1.196921855211258]\n"
     ]
    }
   ],
   "source": [
    "### loop \n",
    "\n",
    "mean_list_10 = [[] for x in scale_range]\n",
    "std_list_10 = [[] for x in scale_range]\n",
    "\n",
    "for dim in [10]:\n",
    "    \n",
    "    for N in [1000,10000,50000]:\n",
    "        dataset_list = []\n",
    "        params_dict['scale']= 1.0\n",
    "\n",
    "        for i in range(trials):    \n",
    "                dataset_list.append(MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y))\n",
    "\n",
    "        print(dataset_list[0].true_mi)\n",
    "        \n",
    "        for k in range(len(scale_range)):\n",
    "            print(k)\n",
    "            params_dict['scale']= scale_range[k]\n",
    "            output_list = [[] for x in estimators]\n",
    "            for i in range(trials):    \n",
    "                # dataset = MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y)\n",
    "                # print(\"True MI:\", dataset.true_mi)\n",
    "                # print('LNC:',estimators[-1](dataset.x,dataset.y))\n",
    "                for temp in range(len(estimators)):\n",
    "                    E = estimators[temp](dataset_list[i].x*scale_range[k],dataset_list[i].y)\n",
    "                    output_list[temp].append(E)\n",
    "\n",
    "            for temp in range(len(estimators)):\n",
    "                mean_list_10[k].append(np.mean(output_list[temp]))\n",
    "                std_list_10[k].append(np.std(output_list[temp]))\n",
    "                diff_arr = np.array(output_list[temp])-np.mean(output_list[temp])\n",
    "            print(mean_list_10[k])\n",
    "                # uppers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr>=0).astype('float')))\n",
    "                # lowers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr<=0).astype('float')))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f5e17d9-e441-4cbb-9bc8-cb99d850da93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.192051887512207\n",
      "0\n",
      "[0.0005524483160115779, 2.1994708359241484, 2.198544454574585]\n",
      "1\n",
      "[0.006663919717539102, 2.213234269618988, 2.1553693056106566]\n",
      "2\n",
      "[0.01738930030260235, 2.1947508335113524, 2.194417119026184]\n",
      "3\n",
      "[0.05409082630649209, 2.1426228284835815, 2.2010505378246306]\n",
      "4\n",
      "[0.13593760803341864, 2.1589997351169585, 2.191972005367279]\n",
      "5\n",
      "[0.3247593820095062, 2.194976818561554, 2.2069728851318358]\n",
      "6\n",
      "[0.6647567093372345, 2.2324655055999756, 2.171830475330353]\n",
      "7\n",
      "[1.132021853327751, 2.217390888929367, 2.1930198669433594]\n",
      "8\n",
      "[1.6182075202465058, 2.1143530547618865, 2.1005638360977175]\n",
      "9\n",
      "[2.1947466313838957, 2.185455822944641, 2.196187174320221]\n",
      "10\n",
      "[2.1470451772212984, 2.186776477098465, 2.1175334095954894]\n",
      "11\n",
      "[2.3170649290084837, 2.113155537843704, 2.1967615365982054]\n",
      "12\n",
      "[2.5572704553604124, 2.166416335105896, 2.197433817386627]\n",
      "13\n",
      "[2.573277533054352, 2.0695204973220824, 2.133054810762405]\n",
      "14\n",
      "[2.4522926449775695, 2.1983224749565125, 2.176388603448868]\n",
      "15\n",
      "[2.2699213862419128, 2.1138658046722414, 2.175078499317169]\n",
      "16\n",
      "[1.8826300263404847, 2.146153801679611, 2.0667574882507322]\n",
      "17\n",
      "[1.1328931629657746, 2.19552076458931, 2.2068256199359895]\n",
      "18\n",
      "[0.781866067647934, 2.1585026681423187, 2.190373086929321]\n",
      "19\n",
      "[0.3050607919692993, 2.2113704144954682, 2.177988660335541]\n",
      "7.192051887512207\n",
      "0\n",
      "[0.0005524483160115779, 2.1994708359241484, 2.198544454574585, 0.42723288834095, 4.254185903072357, 4.2124763369560245]\n",
      "1\n",
      "[0.006663919717539102, 2.213234269618988, 2.1553693056106566, 0.78475062251091, 4.201100540161133, 4.066039836406707]\n",
      "2\n",
      "[0.01738930030260235, 2.1947508335113524, 2.194417119026184, 1.3294026255607605, 4.110822403430939, 4.202534341812134]\n",
      "3\n",
      "[0.05409082630649209, 2.1426228284835815, 2.2010505378246306, 1.7153466105461121, 4.466654455661773, 4.283091533184051]\n",
      "4\n",
      "[0.13593760803341864, 2.1589997351169585, 2.191972005367279, 2.112933838367462, 4.42188413143158, 4.258798861503601]\n",
      "5\n",
      "[0.3247593820095062, 2.194976818561554, 2.2069728851318358, 2.5408183336257935, 4.347014796733856, 4.252500104904175]\n",
      "6\n",
      "[0.6647567093372345, 2.2324655055999756, 2.171830475330353, 3.0618698835372924, 4.229629254341125, 4.405560803413391]\n",
      "7\n",
      "[1.132021853327751, 2.217390888929367, 2.1930198669433594, 3.4977553606033327, 4.324663710594177, 4.239643931388855]\n",
      "8\n",
      "[1.6182075202465058, 2.1143530547618865, 2.1005638360977175, 3.983257067203522, 4.484812498092651, 4.154430425167083]\n",
      "9\n",
      "[2.1947466313838957, 2.185455822944641, 2.196187174320221, 4.106134223937988, 4.442376124858856, 4.458394002914429]\n",
      "10\n",
      "[2.1470451772212984, 2.186776477098465, 2.1175334095954894, 4.218863141536713, 4.2146224021911625, 3.9830248236656187]\n",
      "11\n",
      "[2.3170649290084837, 2.113155537843704, 2.1967615365982054, 4.451828563213349, 4.398105716705322, 4.3463167428970335]\n",
      "12\n",
      "[2.5572704553604124, 2.166416335105896, 2.197433817386627, 4.550199365615844, 4.331343007087708, 4.389489924907684]\n",
      "13\n",
      "[2.573277533054352, 2.0695204973220824, 2.133054810762405, 4.490142726898194, 4.303083324432373, 4.093412637710571]\n",
      "14\n",
      "[2.4522926449775695, 2.1983224749565125, 2.176388603448868, 4.580692970752716, 4.4838922142982485, 4.481097316741943]\n",
      "15\n",
      "[2.2699213862419128, 2.1138658046722414, 2.175078499317169, 4.253705155849457, 4.364210593700409, 4.205152642726898]\n",
      "16\n",
      "[1.8826300263404847, 2.146153801679611, 2.0667574882507322, 3.5008045345544816, 4.395767998695374, 4.299373066425323]\n",
      "17\n",
      "[1.1328931629657746, 2.19552076458931, 2.2068256199359895, 2.4505982518196108, 4.210524392127991, 4.509518480300903]\n",
      "18\n",
      "[0.781866067647934, 2.1585026681423187, 2.190373086929321, 2.4233878195285796, 4.352754306793213, 4.400678539276123]\n",
      "19\n",
      "[0.3050607919692993, 2.2113704144954682, 2.177988660335541, 1.0220942378044129, 4.300051534175873, 4.378615927696228]\n",
      "7.192051887512207\n",
      "0\n",
      "[0.0005524483160115779, 2.1994708359241484, 2.198544454574585, 0.42723288834095, 4.254185903072357, 4.2124763369560245, 1.2768879055976867, 1.7606780886650086, 1.8344384670257567]\n",
      "1\n",
      "[0.006663919717539102, 2.213234269618988, 2.1553693056106566, 0.78475062251091, 4.201100540161133, 4.066039836406707, 1.3317480206489563, 1.7160346448421477, 1.770302176475525]\n",
      "2\n",
      "[0.01738930030260235, 2.1947508335113524, 2.194417119026184, 1.3294026255607605, 4.110822403430939, 4.202534341812134, 1.492863792181015, 1.7379432380199433, 1.8663526356220246]\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "\n",
    "mean_list_50 = [[] for x in scale_range]\n",
    "std_list_50 = [[] for x in scale_range]\n",
    "\n",
    "for dim in [50]:\n",
    "    \n",
    "    for N in [1000,10000,50000]:\n",
    "        dataset_list = []\n",
    "        params_dict['scale']= 1.0\n",
    "\n",
    "        for i in range(trials):    \n",
    "                dataset_list.append(MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y))\n",
    "\n",
    "        print(dataset_list[0].true_mi)\n",
    "        \n",
    "        for k in range(len(scale_range)):\n",
    "            print(k)\n",
    "            params_dict['scale']= scale_range[k]\n",
    "            output_list = [[] for x in estimators]\n",
    "            for i in range(trials):    \n",
    "                # dataset = MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y)\n",
    "                # print(\"True MI:\", dataset.true_mi)\n",
    "                # print('LNC:',estimators[-1](dataset.x,dataset.y))\n",
    "                for temp in range(len(estimators)):\n",
    "                    E = estimators[temp](dataset_list[i].x*scale_range[k],dataset_list[i].y)\n",
    "                    output_list[temp].append(E)\n",
    "\n",
    "            for temp in range(len(estimators)):\n",
    "                mean_list_50[k].append(np.mean(output_list[temp]))\n",
    "                std_list_50[k].append(np.std(output_list[temp]))\n",
    "                diff_arr = np.array(output_list[temp])-np.mean(output_list[temp])\n",
    "            print(mean_list_50[k])\n",
    "                # uppers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr>=0).astype('float')))\n",
    "                # lowers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr<=0).astype('float')))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0b3b1e8-6545-4b07-a62c-d4cf040f5a2e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.192051887512207\n",
      "0\n",
      "[1.252772444486618, 1.811663407087326, 1.9263504922389985]\n",
      "1\n",
      "[1.2994191467761993, 1.7982805013656615, 1.8302257716655732]\n",
      "2\n"
     ]
    }
   ],
   "source": [
    "\n",
    "mean_list_50 = [[] for x in scale_range]\n",
    "std_list_50 = [[] for x in scale_range]\n",
    "\n",
    "for dim in [50]:\n",
    "    \n",
    "    for N in [50000]:\n",
    "        dataset_list = []\n",
    "        params_dict['scale']= 1.0\n",
    "\n",
    "        for i in range(trials):    \n",
    "                dataset_list.append(MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y))\n",
    "\n",
    "        print(dataset_list[0].true_mi)\n",
    "        \n",
    "        for k in range(len(scale_range)):\n",
    "            print(k)\n",
    "            params_dict['scale']= scale_range[k]\n",
    "            output_list = [[] for x in estimators]\n",
    "            for i in range(trials):    \n",
    "                # dataset = MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y)\n",
    "                # print(\"True MI:\", dataset.true_mi)\n",
    "                # print('LNC:',estimators[-1](dataset.x,dataset.y))\n",
    "                for temp in range(len(estimators)):\n",
    "                    E = estimators[temp](dataset_list[i].x*scale_range[k],dataset_list[i].y)\n",
    "                    output_list[temp].append(E)\n",
    "\n",
    "            for temp in range(len(estimators)):\n",
    "                mean_list_50[k].append(np.mean(output_list[temp]))\n",
    "                std_list_50[k].append(np.std(output_list[temp]))\n",
    "                diff_arr = np.array(output_list[temp])-np.mean(output_list[temp])\n",
    "            print(mean_list_50[k])\n",
    "                # uppers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr>=0).astype('float')))\n",
    "                # lowers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr<=0).astype('float')))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01843f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5799ac5e-636c-4153-8e0a-d0c8f766f80f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_list = []\n",
    "params_dict['scale']= 1.0\n",
    "\n",
    "for i in range(trials):    \n",
    "        dataset_list.append(MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y))\n",
    "\n",
    "print(dataset_list[0].true_mi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea6bd437-acc7-4ceb-a76f-3942f6531225",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in range(len(scale_range)):\n",
    "    print(k)\n",
    "    params_dict['scale']= scale_range[k]\n",
    "    output_list = [[] for x in estimators]\n",
    "    for i in range(trials):    \n",
    "        # dataset = MultivariateNormalDataset(N, dim, rho,params_dict,transforms_x,transforms_y)\n",
    "        # print(\"True MI:\", dataset.true_mi)\n",
    "        # print('LNC:',estimators[-1](dataset.x,dataset.y))\n",
    "        for temp in range(len(estimators)):\n",
    "            E = estimators[temp](dataset_list[i].x*scale_range[k],dataset_list[i].y)\n",
    "            output_list[temp].append(E)\n",
    "    \n",
    "    for temp in range(len(estimators)):\n",
    "        mean_list[k].append(np.mean(output_list[temp]))\n",
    "        diff_arr = np.array(output_list[temp])-np.mean(output_list[temp])\n",
    "    print(mean_list[k])\n",
    "        # uppers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr>=0).astype('float')))\n",
    "        # lowers[k].append(np.average(np.array(output_list[temp]),weights=(diff_arr<=0).astype('float')))\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de9d631-b57c-450c-b763-1d6d9ff7c1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import matplotlib\n",
    "import scipy\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "matplotlib.rcParams.update({'font.size': 17})\n",
    "\n",
    "mean_list = np.array(mean_list)\n",
    "uppers = np.array(uppers)\n",
    "lowers = np.array(lowers)\n",
    "for temp in range(len(estimators)):\n",
    "    X = np.log10(scale_range)\n",
    "    Y = mean_list[:,temp]\n",
    "    # X = np.delete(X,10)\n",
    "    # Y = np.delete(Y,10)\n",
    "    \n",
    "    filt = scipy.signal.savgol_filter(Y, 10, 3)\n",
    "    plt.plot(X,filt,label=names[temp])\n",
    "    # ups = uppers[:,temp] - mean_list[:,temp]\n",
    "    # downs = mean_list[:,temp] - lowers[:,temp]\n",
    "    # plt.errorbar(np.log10(scale_range),mean_list[:,temp], yerr=[downs,ups], capsize=5,  ecolor = \"black\")\n",
    "    # plt.fill_between(np.log10(scale_range),lowers[:,temp],uppers[:,temp])\n",
    "plt.plot(np.log10(scale_range),dataset_list[0].true_mi*np.ones_like(scale_range),linestyle='--',color='black',label='True MI')\n",
    "plt.legend(fontsize=14,handlelength=1,framealpha=0)\n",
    "plt.grid(linestyle='--')\n",
    "plt.xlabel('$\\log_{10}(\\eta)$ (Scaling Factor)',fontsize=24)\n",
    "plt.ylabel('Average MI Estimates',fontsize=24)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('KSG_Scale_10000m_d10.png',bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9675cccd-0b67-4ce9-b710-53b8ba391ce8",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.true_mi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "4b5fd4e0-d37b-4cba-ad3f-dc0c8be05e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "names = ['KSG','KSG-Local','KSG-Global-$L_{\\infty}$','BI-KSG']\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
