{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Internal imports\n",
    "import sys; sys.path.insert(0, '..')\n",
    "from src import *\n",
    "\n",
    "import torch\n",
    "import torch.distributions as D\n",
    "from scipy.linalg import block_diag\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "p1 = D.normal.Normal(-5, 1)\n",
    "q1 = D.normal.Normal(5, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(50.)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Closed-form KL(p1 || q1)\n",
    "D.kl_divergence(p1, q1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "p2 = D.normal.Normal(0, 1e-22)\n",
    "q2 = D.normal.Normal(0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(50.1665)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Closed-form KL(p2 || q2)\n",
    "D.kl_divergence(p2, q2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "p3 = D.normal.Normal(-1.5, 0.5)\n",
    "q3 = D.normal.Normal(1, 0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(50.8069)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Closed-form KL(p3 || q3)\n",
    "D.kl_divergence(p3, q3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAAEYCAYAAACgIGhkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAASg0lEQVR4nO3df5RcZX3H8fcHjKxIYoRNC7JLJhxBEEVMGhQhlDaCyI8IBgU54aDVoqhUbaBFsRV65EitClRATg9ogwTwRyCIBxH5EUs8IZgVhShGrA1hTdBkMZKQBBL59o+5ezod5slmdmfuvTP5vM6Zc7L3x/N8J7vz2ed57sxdRQRmZo3sUnQBZlZeDggzS3JAmFmSA8LMkhwQZpbkgDCzJAeEtZWk/5T02aLrsNFxQHQISSslbZa0sebxqmzfbpI+J2lVdszjki6QpJrzF0naImmDpGckDUi6UNJuI/T7Fkn3Zef9UdIdkl7b7udr5eCA6CwnR8QeNY/V2fZvATOBE4DxwFnAOcCVded/NCLGA/sAc4EzgDtrg6SWpCOAu4HbgVcBU4CfAT+StH9rn5qVkQOiw0maCRwHzI6I5RGxLSIeBOYAH5H06vpzIuLZiFgEzAKOAE5MNP954IaIuDIiNkTE0xHxaeBB4OKs/2MkDUqaK+n3ktZIel+i1uWSTq75epykdZIOG+3zt/ZyQHS+Y4GlEfFk7caIWAoMUh1ZNBQRq4BlwIz6fZJ2B95CdXRS75tZv8P2Bl4B7Au8H7ha0isbnHcD1eAadgKwJiJ+mqrRiuWA6CwLJa3PHguzbb3AmsTxa7L927Ma2LPB9j2p/nw0aru+3a3Av0TE1oi4E9gIvKbBeTcCJ0iakH19FvD1EeqzAjkgOsspETExe5ySbVtHdU2hkX2y/duzL/B0g+1/AF5ItF3f7lBEbKv5ehOwR/1J2ZrJj4DZkiYCbwfmj1CfFcgB0fnuAd4kqb92o6TDgX7gvtSJ2TnTgAfq90XEs8AS4F0NTn03cO8o651HdZrxLmBJRPx2lO1YDl5SdAE2NhFxj6R7gQXZ4uAvgelUh+5fiYjH68/J1hemA5cDDwF3Jpq/EPi+pF8CX6P68zKX6sLm9FGWvBC4BvhzqougVmIeQXSH2cD9wF1U5/83AtcD59Udd5WkDcDvgCuABcDxEfFCo0YjYjHwNuCdVNcdngDeCBzVKHh2RERszvqdAtw6mjYsP/INYyxvkv4ZODAi5ox4sBXKUwzLlaQ9qV4KPavoWmxknmJYbiT9LfAk8L2I+K+i67GReYphZkkeQZhZUlNrEL29vVGpVNpUipkVYWBgYF1ETGq0r6mAqFQqLFu2rDVVmVkpSHoitc9TDDNLckCYWZIDwsyS/EYpsxbYunUrg4ODbNmypehSknp6eujr62PcuHE7fI4DwqwFBgcHGT9+PJVKhcQd/AoVEQwNDTE4OMiUKVN2+DxPMcxaYMuWLey1116lDAcASey1115Nj3AcEGYtUtZwGDaa+hwQZpbkNQizNtAlrR1NxGeK+cyURxBmluSAMOsCK1eu5KCDDuLss8/m0EMP5bTTTmPTpk1jbtcBYdYlVqxYwTnnnMMjjzzChAkTuOaaa8bcpgPCrEv09/dz5JFHAjBnzhwWL1485jYdEGZdov4yZisuuzogzLrEqlWrWLJkCQA333wzRx111Jjb9GVOszYo4rLkwQcfzLx58/jgBz/IAQccwLnnnjvmNh0QZl1il1124dprr21tmy1tzcy6igPCrAtUKhWWL1/e8nYdEGaW5IAwsyQHhJklOSDMLMmXOc3a4aYW3zzmTH/c28xKxiMIsy5x6aWXcsMNN9Df38+kSZOYNm0a559//pjadECYdYGBgQFuueUWHn74YbZt28bUqVOZNm3amNt1QJh1gQceeIBTTz2V3XffHYBZs2a1pF2vQZh1iXbcVdsBYdYFjj76aG677TY2b97Mhg0buOOO
O1rSrqcYZu2Q82XJqVOncvrpp3PYYYcxefJkZsyY0ZJ2PYIw6xIXXXQRK1as4O6772a//fZrSZsOCDNL8hTDrAtdfPHFLWnHIwizFoko5u3QO2o09TkgzFqgp6eHoaGh0oZERDA0NERPT09T53mKYdYCfX19DA4Osnbt2qJLSerp6aGvr6+pcxwQZi0wbtw4pkyZUnQZLecphpklOSDMLMkBYWZJDggzS3JAmFmSA8LMkhwQZpbkgDCzJAeEmSV1zTspdUn6dlvxmXK+P96s7DyCMLMkB4SZJTkgzCypa9YgzArTzN/hLOhvbI5WRwXE9hYizVqu1X+AtwN1VECM1liCxVdArKV2NHRKMtLwGoSZJZVuBOFphFl5eARhZkmlG0GYGaW5MtK2gPBUwUrJVyaa4hHECPwZD9uZeQ3CzJI8grDusDNPHdr43ooxBcTOvs7g6Yd1OzXztwQlrQWeaF85ZlaAyRExqdGOpgLCzHYuXqQ0syQHhJklOSCs7SS9V9Liouuw5jkgOoCklZLeWrftRS+6bNujkjZJekrSVyRNrNl/saStkjZkj19JukrSPiP03ydpvqQhSc9KekjSSS19klZKDoguIWku8K/ABcArgDcDk4EfSHppzaHfiIjxwJ7AqcDewEAqJCTtCSwGngcOAXqBy4GbJJ3WpqdjJeGA6AKSJgCXAOdFxF0RsTUiVgLvphoSc+rPyY75OXA6sBaYm2j+E8BG4P0R8VREbI6Im4FLgS9KUlZDSPqQpMcl/UHS1cP76mq9WtIX67bdIenjo3z61kYOiO7wFqAHuLV2Y0RsBL4HHJs6MSL+BNwOzEgcciywICJeqNv+TWA/4MCabScB04E3UA2ntzVobx7wHkm7AEjqBWYCN6dqtOI4IDrHQknrhx/ANTX7eoF1EbGtwXlrsv3bs5rqlKOR3qyNRu0O7x92WUSsj4hVwP3AYfUnRcRDwB+phgLAGcCiiPjdCDVaARwQneOUiJg4/AA+XLNvHdArqdFb5/fJ9m/PvsDTiX3rsjYatTu8f9hTNf/eBOyRaHMe/zftmQN8fYT6rCAOiO6wBHgOeGftRkkvB94O3Js6MRvqnww8kDjkHmD28JSgxruBJ4FfjaLeG4F3SHoDcDCwcBRtWA4cEF0gIv5IdZHyy5KOlzROUgX4FjBIg9/Q2TEHU5377w18KdH85cAE4HpJe0vqkfQe4CLgghjFe/UjYhD4cVbXgojY3Gwblg8HRJeIiM8DnwK+ADwDLKX6G35mRDxXc+jpkjYC64HvAEPAtIhYnWh3CDiK6iLoL7Lj/x44KyK+MYaS5wGvx9OLUvOHtawQko6mOtWoNLhCYiXhEYTlTtI44GPAdQ6HcnNAWK6ydY/1VK+CXFFoMTYiTzHMLMkjCDNLckCYWVJTN63t7e2NSqXSplLMrAgDAwPrUvekbCogKpUKy5Yta01VZlYKkpI3ovYUw8ySHBBmluSAMLMkB4SZJTkgzCzJAWFmSQ4IM0tyQJhZkgPCzJIcEGaW5IAwsyQHhJklOSDMLMkBYWZJDggzS3JAmFmSA8LMkhwQZpbkgDCzJAeEmSU5IMwsyQFhZkkOCDNLckCYWZIDwsySHBBmluSAMLMkB4SZJTkgzCzJAWFmSQ4IM0tyQJhZkgPCzJIcEGaW5IAwsyQHhJklOSDMLMkBYWZJDggzS3JAmFmSA8LMkhwQZpbkgDCzJAeEmSU5IMwsyQFhZkkOCDNLckCYWZIDwsySHBBmlvSSoguwErlJzR1/ZrSnDisNjyDMLMkjiG7X7KjArIZHEGaW5BGEjZ7XLLqeRxBmluSAMLMkTzE6jRcdLUcOCMtPM+Hm9YpS8BTDzJI8gigDTxuspBwQVk6+hFoKO29AtHM+7BGBdQlF7PgPv6S1wBPtK8fMCjA5IiY12tFUQJjZzsVXMcwsyQFhZkkOCDNLckCUiKSVkjZL2ljzeFW2bzdJn5O0KjvmcUkXSFLN+YskbZG0QdIz
kgYkXShptzbUukjSB+q2HSNpsG7bSZIekvSspCFJ8yX11ex/r6Q/1Tzf/5H0NUkHtrpma54DonxOjog9ah6rs+3fAmYCJwDjgbOAc4Ar687/aESMB/YB5gJnAHfWBklK9gJf1KLngaTTgJuyGnuBQ4DngMWSXllz6JKI2AN4BfBWYDMwIOl1rarFRscB0QEkzQSOA2ZHxPKI2BYRDwJzgI9IenX9ORHxbEQsAmYBRwAn5lyzgC8Cn42I+RGxOSKeAj4AbAQ+UX9ORPwpIv47Ij4M/BC4OM+a7cUcEJ3hWGBpRDxZuzEilgKDVEcWDUXEKmAZMKOtFb7Ya4D9qI58aut5AVhA9Tltz63kX7PVcUCUz0JJ67PHwmxbL7AmcfyabP/2rAb2bFF9tf69ptb1wHdr9g3X1KjuImu2JjggyueUiJiYPU7Jtq2juqbQyD7Z/u3ZF3i60Y5sEbP2BX5U3Yt+e/6uptaJwEk1+4ZralT3mGq2/DggOsM9wJsk9ddulHQ40A/clzoxO2ca8ECj/RFxWd0LfHHdi360VlCd/ryrrp5dgNnAvSOcf2qqZsuPA6IDRMQ9VF9QCyQdImlXSW8G5gNfiYjH68+RtLukvwRuBx4C7sy55gDOBz4t6UxJL5O0N3AdMAG4vEHNu0qaIunLwDHAJXnWbC/mgOgcs4H7gbuoXgW4EbgeOK/uuKskbQB+B1xBdUHw+GxxMFcR8Q2ql2M/QXVK8QvgZcCRETFUc+gRkjYCzwCLqAbI9Ih4NN+KrZ4/rGVmSR5BmFmSA8LMkhwQZpbkgDCzpKbuSdnb2xuVSqVNpZhZEQYGBtalbjnXVEBUKhWWLVvWmqrMrBQkJe8z6ymGmSU5IMwsyQFhZkkOCDNLckCYWZIDwsySHBBmluSAMLMkB4SZJTkgzCzJAWFmSQ4IM0tq6sNaZjvsppq/9Hemb2vYqTyCMLMkB4SZJTkgzCzJAWFmSQ4IM0vyVQxrndorF9YVPIIwsyQHhJklOSDMLMkBYWZJDggzS3JAmFmSA8LMkhwQZpbkgDCzJAeEmSU5IMwsyQFhZkn+sJa1n28/17E8gjCzJAeEmSU5IMwsyQFhZkkOCDNLckCYWZIvc3YoXfLi+z/GZ3wJ0VrLAWFj4xvVdjUHRAdoNFowy4PXIMwsyQFhZkkOCDNLckCYWZIXKUvGC5JWJh5BmFmSA8LMkjzF6CL10xO/s9LGyiMIM0tyQJhZkqcYBfIVCys7jyDMLMkjCMuX73DdURwQXcz3jLCx8hTDzJI8grDm+AYxOxWPIMwsyQFhZkmeYuTE73mwTuSA2Mn4yoY1wwFhI2vXwqTfE1F6Dog26LTphEcVlqKIHf9BkLQWeKJ95dALrGtj+67BNXRS/3nVMDkiJjXa0VRAtJukZRHxF67BNZShhqL7L0MNvsxpZkkOCDNLKltA/EfRBeAahrmG4vuHgmso1RqEmZVL2UYQZlYiDggzSyplQEg6T9IKST+X9PkC6zhfUkjqLaDvf5P0S0mPSLpN0sSc+j0++7//taQL8+izrv9+SfdLeiz7/n8s7xpqatlV0sOSvltQ/xMlfTv7OXhM0hF511C6gJD0V8A7gEMj4hDgCwXV0Q8cC6wqon/gB8DrIuJQ4FfAJ9vdoaRdgauBtwOvBd4j6bXt7rfONmBuRBwMvBn4SAE1DPsY8FhBfQNcCdwVEQcBbyiiltIFBHAucFlEPAcQEb8vqI7LgX8AClnFjYi7I2Jb9uWDQF8O3R4O/DoifhMRzwO3UA3r3ETEmoj4SfbvDVRfFPvmWQOApD7gROC6vPvO+p8AHA1cDxARz0fE+rzrKGNAHAjMkLRU0g8lTc+7AEmzgN9GxM/y7jvhb4Dv5dDPvsCTNV8PUsCLc5ikCvBGYGkB3V9B9RfECwX0DbA/sBb4WjbNuU7Sy/MuopAPa0m6B9i7wa6LqNb0SqrDy+nANyXtHy2+HjtCDZ8Cjmtlf83WEBG3Z8dcRHXYPb/d
9QCNPmVWyAhK0h7AAuDjEfFMzn2fBPw+IgYkHZNn3zVeAkwFzouIpZKuBC4E/invInIXEW9N7ZN0LnBrFggPSXqB6gdW1uZRg6TXA1OAn0mC6tD+J5IOj4in8qihppazgZOAma0OyIRBoL/m6z5gdQ79/j+SxlENh/kRcWve/QNHArMknQD0ABMk3RgRc3KsYRAYjIjh0dO3qQZErso4xVgI/DWApAOBl5LjJ+oi4tGI+LOIqEREheo3amqrw2Ekko4H/hGYFRGbcur2x8ABkqZIeilwBvCdnPoGQNVUvh54LCK+lGffwyLikxHRl33/zwDuyzkcyH7enpT0mmzTTOAXedYA5bwfxFeBr0paDjwPnJ3Tb8+yuQrYDfhBNpJ5MCI+1M4OI2KbpI8C3wd2Bb4aET9vZ58NHAmcBTwq6afZtk9FxJ0511EG5wHzs7D+DfC+vAvwW63NLKmMUwwzKwkHhJklOSDMLMkBYWZJDggzS3JAmFmSA8LMkv4X5EUnw5YGPx4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x288 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot each (p, q) pair as overlaid sample histograms and save the figure to disk.\n",
    "# Seed torch's global RNG so the histograms (and fod_hod.png) are reproducible across runs.\n",
    "torch.manual_seed(0)\n",
    "# One sample count for every distribution so all histograms are equally noisy.\n",
    "N_SAMPLES = 1000\n",
    "\n",
    "\n",
    "def plot_pair(ax, p, q, title, n_samples=N_SAMPLES):\n",
    "    \"\"\"Overlay density histograms of `n_samples` draws from p (green) and q (orange) on `ax`.\"\"\"\n",
    "    hist_kwargs = dict(density=True, histtype='stepfilled')\n",
    "    # zorder=2 draws p above q, so a narrow p stays visible over a wide q;\n",
    "    # p is added first so the legend lists p before q.\n",
    "    ax.hist(p.sample([n_samples]).numpy(), color='green', label='p', zorder=2, **hist_kwargs)\n",
    "    ax.hist(q.sample([n_samples]).numpy(), color='orange', label='q', **hist_kwargs)\n",
    "    # Same fixed limits on every panel; very peaked densities (e.g. p2) are clipped at y=2.\n",
    "    ax.set_xlim(-7.5, 7.5)\n",
    "    ax.set_ylim(0, 2)\n",
    "    ax.set_title(title)\n",
    "    ax.get_yaxis().set_visible(False)\n",
    "    return ax\n",
    "\n",
    "\n",
    "pairs = [\n",
    "    (p1, q1, 'FOD Only'),\n",
    "    (p2, q2, 'HOD Only'),\n",
    "    (p3, q3, 'FOD + HOD'),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(len(pairs), 1, figsize=(4, 4))\n",
    "for ax, (p, q, title) in zip(axes, pairs):\n",
    "    plot_pair(ax, p, q, title)\n",
    "# Panels share x-limits, so only the bottom one keeps its x-axis.\n",
    "for ax in axes[:-1]:\n",
    "    ax.get_xaxis().set_visible(False)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.subplots_adjust(wspace=0, hspace=0.3)\n",
    "axes[0].legend()\n",
    "fig.savefig('fod_hod.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch1.8",
   "language": "python",
   "name": "torch1.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
