{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "58eb201a-6e0c-4b8d-ae9a-4416a3405168",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e7a51eb2-a916-4481-8917-2b13b8e23097",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import powerlaw\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.distributions as dist\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"pgf\")\n",
    "matplotlib.rcParams.update({\n",
    "    \"pgf.texsystem\": \"pdflatex\",\n",
    "    'font.family': 'serif',\n",
    "    'text.usetex': True,\n",
    "    'pgf.rcfonts': False,\n",
    "})\n",
    "\n",
    "import beanmachine.ppl as bm\n",
    "import beanmachine.ppl.experimental.gg_algebra as gga\n",
    "from beanmachine.ppl.experimental.vi.variational_world import VariationalWorld\n",
    "from beanmachine.ppl.world import World\n",
    "\n",
    "sns.set_style(\"darkgrid\")\n",
    "sns.set_context(\"paper\", font_scale = 1.)\n",
    "\n",
    "TEXTWIDTH = 362.12\n",
    "\n",
    "def set_size(width, fraction=1):\n",
    "    \"\"\" Set aesthetic figure dimensions to avoid scaling in latex.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    width: float\n",
    "            Width in pts\n",
    "    fraction: float\n",
    "            Fraction of the width which you wish the figure to occupy\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig_dim: tuple\n",
    "            Dimensions of figure in inches\n",
    "    \"\"\"\n",
    "    # Width of figure\n",
    "    fig_width_pt = width * fraction\n",
    "\n",
    "    # Convert from pt to inches\n",
    "    inches_per_pt = 1 / 72.27\n",
    "\n",
    "    # Golden ratio to set aesthetic figure height\n",
    "    golden_ratio = (5 ** 0.5 - 1) / 2\n",
    "\n",
    "    # Figure width in inches\n",
    "    fig_width_in = fig_width_pt * inches_per_pt\n",
    "    # Figure height in inches\n",
    "    fig_height_in = fig_width_in * golden_ratio\n",
    "\n",
    "    return fig_width_in, fig_height_in"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da9650c9-96cd-4067-92f5-55f35912a1db",
   "metadata": {},
   "source": [
    "The empirical sampling density for $x^\\top A y$ all iid $N(0,1)$\n",
    "with $x \\in \\mathbb{R}^k$. This is a 3rd degree polynomial of normal RVs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "85ca245a-d983-4c2d-a733-6daa37388e01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "c x^2.333 exp(-1.5 * x^0.6667)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[Text(0.5, 1.0, 'k=10')]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def make_samples(N):\n",
    "    samples = torch.zeros(N)\n",
    "    for j in range(N):\n",
    "        samples[j] = abs(\n",
    "            np.random.randn(1, k) @ np.random.randn(k, k) @ np.random.randn(k, 1)\n",
    "        ).item()\n",
    "    return samples\n",
    "\n",
    "\n",
    "N = 5000\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 3), sharex=False, sharey=True)\n",
    "axs = [ax]\n",
    "i = 0\n",
    "k = 10\n",
    "samples = make_samples(N)\n",
    "sns.histplot(ax=axs[i], data=samples, stat=\"density\")\n",
    "\n",
    "y = gga.gauss_ens(k, 1)\n",
    "z = gga.gauss_ens(k, 1)\n",
    "A = gga.gauss_ens(k, k)\n",
    "tail = (y.T @ A @ z).item()\n",
    "print(tail)\n",
    "q = gga.make_ggdist(tail)\n",
    "xs = torch.logspace(-25, np.log10(samples.max()), steps=1000)\n",
    "sns.lineplot(\n",
    "    ax=axs[i],\n",
    "    data=pd.DataFrame(\n",
    "        {\n",
    "            \"x\": xs,\n",
    "            \"q(x)\": q.log_prob(xs).exp(),\n",
    "        }\n",
    "    ),\n",
    "    x=\"x\",\n",
    "    y=\"q(x)\",\n",
    "    color=\"y\",\n",
    ")\n",
    "axs[i].set(title=f\"k={k}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c51f8442-e23b-4b2b-a147-b3c26f85223c",
   "metadata": {},
   "source": [
    "While our theory guarantees matching tails here, our GGA fit (yellow) is a bit off in the bulk.\n",
    "\n",
    "Let's see if we can fix it with a flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "25550c17-7377-4a04-838f-0cf45234053d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x7f9cd1538160>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "log_q = dist.TransformedDistribution(\n",
    "    q,\n",
    "    [dist.transforms.ExpTransform().inv],\n",
    ")\n",
    "base_dist = log_q\n",
    "sns.displot(data=log_q.sample((1000,)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a54db829-5f05-4743-b722-2e3e33f6746d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/gga/lib/python3.10/site-packages/flowtorch/parameters/dense_autoregressive.py:71: UserWarning: DenseAutoregressive input_dim = 1. Consider using an affine transformation instead.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92aaef1396384bbf8854c6435491858a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5001 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "import flowtorch.bijectors\n",
    "import flowtorch.distributions\n",
    "import flowtorch.parameters\n",
    "\n",
    "bijectors = flowtorch.bijectors.SplineAutoregressive(bound=10.0, count_bins=10)\n",
    "flow = flowtorch.distributions.Flow(base_dist, bijectors)\n",
    "opt = torch.optim.Adam(flow.parameters(), lr=1e-3)\n",
    "for idx in (pbar := tqdm(range(5001))):\n",
    "    opt.zero_grad()\n",
    "\n",
    "    # train an approximation to log(Y) \\in R\n",
    "    # y = make_samples(1).abs()\n",
    "    y = make_samples(1).log()\n",
    "    loss = -flow.log_prob(y).mean()\n",
    "\n",
    "    if idx % 100 == 0:\n",
    "        pbar.set_description(f\"epoch: {idx}, loss: {loss}\")\n",
    "        # flow_push = dist.TransformedDistribution(flow, [dist.transforms.ExpTransform()])\n",
    "        # flow_push = dist.TransformedDistribution(flow, [gga.AbsTransform()])\n",
    "        # sns.displot(data=flow_push.sample((10,)))\n",
    "        # plt.show()\n",
    "\n",
    "    loss.backward()\n",
    "    opt.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32a64ba9-89eb-4324-b010-d33e2999d466",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(\n",
    "    2,\n",
    "    1,\n",
    "    figsize=set_size(TEXTWIDTH * 0.8),\n",
    ")\n",
    "\n",
    "\n",
    "# xs = torch.linspace(-5, 50, steps=100)\n",
    "\n",
    "# sns.lineplot(\n",
    "#     ax=axes[0],\n",
    "#     data=pd.DataFrame({\"x\": xs, \"y\": flow.bijector.forward(xs).detach()}),\n",
    "#     x=\"x\",\n",
    "#     y=\"y\",\n",
    "# )\n",
    "# axes[0].set(title=\"Learned flow correction\", yscale='log' \n",
    "\n",
    "\n",
    "xs = q.sample((5000,))\n",
    "sns.kdeplot(ax=axes[0], data=xs)\n",
    "sns.kdeplot(ax=axes[1], data=xs)\n",
    "flow_push = dist.TransformedDistribution(flow, [dist.transforms.ExpTransform()])\n",
    "xs = flow_push.sample((5000,))\n",
    "sns.kdeplot(ax=axes[0], data=xs)\n",
    "sns.kdeplot(ax=axes[1], data=xs)\n",
    "xs = make_samples(5000)\n",
    "sns.kdeplot(ax=axes[0], data=xs)\n",
    "sns.kdeplot(ax=axes[1], data=xs)\n",
    "axes[0].legend([\"q(x)\", \"flow(q)(x)\", \"target\"])\n",
    "axes[0].set(xlim=[0, xs.max()], title='Before/after flow correction')\n",
    "#axes[1].legend([\"q(x)\", \"flow(q)(x)\", \"target\"])\n",
    "axes[1].set(xlim=[1, 100], xscale='symlog', yscale='log', title='Tails preserved by Lipschitz mappings', ylim=[1e-5, 1e1])\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('lip-nf.pgf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb92743",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "83b589253c6ae2165fd99d3b5e434b8a0ff74c98e791d87ced25152a201010fd"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('gga')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
