{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from estimate_entropy import meta_main, get_args\n",
    "from tools import plot_entropy, plot_pdf\n",
    "from os.path import join\n",
    "\n",
    "import sys\n",
    "sys.path.append('./doe')\n",
    "\n",
    "from doe.main import get_args, meta_main\n",
    "from tools import plot_mi_run, plot_mi_bias_var_mse, collect_mi\n",
    "from mi_estimators import main\n",
    "\n",
    "cuda = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MI estimation of Gaussian random variables in Section 3.2.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Gaussian\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Cubed Y\n",
    "main(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MI estimation of uniformly distributed random variables in Section 3.2.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Start from whatever get_args([]) returns (presumably the parser defaults)\n",
    "# and override the fields below; each field is set exactly once.\n",
    "args = get_args([])\n",
    "args.dim = 20\n",
    "args.I = 0.1\n",
    "args.I_max = 0.5\n",
    "args.N = 128\n",
    "args.steps = 20000\n",
    "args.nruns = 1\n",
    "args.c = 1\n",
    "args.layers = 1\n",
    "args.pickle = 'model.p'\n",
    "args.cuda = cuda  # flag defined in the first code cell\n",
    "args.run_id = 1\n",
    "args.pdf = 'uniform'\n",
    "\n",
    "# Run the experiment; the returned value is used as the log directory below.\n",
    "log_dir = meta_main(args)\n",
    "\n",
    "# Collect the best run for each method\n",
    "collect_mi(join(log_dir, 'summary.json'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File read by both plots (presumably written by collect_mi in the previous cell).\n",
    "collected_path = join(log_dir, 'collected.json')\n",
    "\n",
    "# Plot training run\n",
    "plot_mi_run(collected_path, show=True, save=False)\n",
    "# Plot bias, MSE, variance\n",
    "plot_mi_bias_var_mse(collected_path, show=True, save=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
