{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "58eb201a-6e0c-4b8d-ae9a-4416a3405168",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e7a51eb2-a916-4481-8917-2b13b8e23097",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import pandas as pd\n",
    "import powerlaw\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.distributions as dist\n",
    "import torch.optim as optim\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm.auto import tqdm, trange\n",
    "\n",
    "import beanmachine.ppl as bm\n",
    "import beanmachine.ppl.experimental.gg_algebra as gga\n",
    "import flowtorch.bijectors\n",
    "import flowtorch.distributions\n",
    "import flowtorch.parameters\n",
    "from beanmachine.ppl.experimental.vi.variational_world import VariationalWorld\n",
    "from beanmachine.ppl.world import World\n",
    "\n",
    "sns.set_style(\"darkgrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6319749-22d4-488f-8209-ff5aad2eaa68",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Background"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0d12f96-5fd9-4124-acf2-3769176e0b44",
   "metadata": {},
   "source": [
    "The generalized Gamma family is a three-parameter $(\\nu, \\sigma, \\rho)$ family of equivalence classes of distributions whose PDFs satisfy\n",
    "\n",
    "$$p(x) \\sim C x^{\\nu} \\exp(-\\sigma x^\\rho)$$\n",
    "\n",
    "Application of the generalized Gamma algebra enables computation of an upper bound on $(\\nu, \\sigma, \\rho)$ for a target random variable of interest. In this section, we show how this auxiliary information can be leveraged to inform the design of inference algorithms.\n",
    "\n",
    "For each equivalence class $(\\nu, \\sigma, \\rho)$,\n",
    "we choose a representative $q_{\\nu,\\sigma,\\rho}(x)$ given by a pushforward of a Gamma distribution for $\\rho > 0$. For $\\rho=0$, we use the StudentT distribution.\n",
    "\n",
    "A normalizing flow is then composed on top of this representative to correct the bulk of the distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d4153131-db9b-4302-8951-2974c917f985",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Symmetrize(dist.Distribution):\n",
    "    def __init__(self, q):\n",
    "        self.q = q\n",
    "        self._batch_shape = q._batch_shape\n",
    "        self._event_shape = q._event_shape\n",
    "\n",
    "    def log_prob(self, x):\n",
    "        return self.q.log_prob(x.abs()) - torch.log(torch.tensor(2.0))\n",
    "\n",
    "    def sample(self, shape=torch.Size()):\n",
    "        x = self.q.sample(shape)\n",
    "        return (2 * (torch.rand(x.shape) > 0.5) - 1) * x\n",
    "\n",
    "    def rsample(self, shape=torch.Size()):\n",
    "        x = self.q.rsample(shape)\n",
    "        return (2 * (torch.rand(x.shape) > 0.5) - 1) * x\n",
    "\n",
    "    def expand(self, shape):\n",
    "        self.q = q.expand(shape)\n",
    "        self._batch_shape = q._batch_shape\n",
    "        self._event_shape = q._event_shape\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"Symmetrize({self.q})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fbd1d173-ba5c-41d8-9b33-4a83b2dff17f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_de(flow, target, N=1000, n_iter=1000, lr=5e-2):\n",
    "    optimizer = optim.Adam(flow.parameters(), lr=lr)\n",
    "    lls = torch.zeros(n_iter)\n",
    "    for i in range(n_iter):\n",
    "        optimizer.zero_grad()\n",
    "        x = target.sample((N, 1))\n",
    "        log_q = flow.log_prob(x).squeeze()\n",
    "        loss = -log_q.sum()\n",
    "        lls[i] = -loss.item()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    return lls\n",
    "\n",
    "\n",
    "def train_vi(flow, target, N=1000, n_iter=1000, lr=5e-2):\n",
    "    optimizer = optim.Adam(flow.parameters(), lr=lr)\n",
    "    elbos = torch.zeros(n_iter)\n",
    "    lmls = torch.zeros(n_iter)\n",
    "    for i in range(n_iter):\n",
    "        optimizer.zero_grad()\n",
    "        x = flow.rsample((N, 1))\n",
    "        log_p = target.log_prob(x).squeeze()\n",
    "        log_q = flow.log_prob(x).squeeze()\n",
    "        loss = -(log_p - log_q).mean()\n",
    "        elbos[i] = -loss.item()\n",
    "        lmls[i] = torch.logsumexp(log_p - log_q, 0).item()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    fit = powerlaw.Fit((log_p - log_q).exp().detach(), verbose=False)\n",
    "    khat = 1 / (fit.alpha - 1)\n",
    "    return elbos, lmls, khat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "0c517186-7314-46f4-994c-aaece3ee8947",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "96b2469b375144769fb10fb14520fb47",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "target:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1626474ef10d48da9b3cb72afc280f54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "trial:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n",
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n",
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n",
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06f2ec3a91c6447fbe0c0e5c27a0b0d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "trial:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n",
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2bbbdf22d21f492ab3dde721380bcef5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "trial:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e2537ec5896438da15c9a9d02fa1070",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "trial:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "edf89a18080e47bc9767703a72ab262d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "trial:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/gga/lib/python3.10/site-packages/powerlaw.py:1188: RuntimeWarning: overflow encountered in double_scalars\n",
      "  return (self.alpha-1) * self.xmin**(self.alpha-1)\n",
      "/opt/conda/envs/gga/lib/python3.10/site-packages/powerlaw.py:835: RuntimeWarning: invalid value encountered in multiply\n",
      "  likelihoods = f*C\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Less than 2 unique data values left after xmin and xmax options! Cannot fit. Returning nans.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xmin progress: 99%\r"
     ]
    }
   ],
   "source": [
    "num_trials = 100\n",
    "N = 1000  # sample batch size\n",
    "\n",
    "\n",
    "def _normal_base():\n",
    "    \"\"\"Standard-normal base distribution with one event dimension.\"\"\"\n",
    "    return dist.Independent(dist.Normal(torch.zeros(1), torch.ones(1)), 1)\n",
    "\n",
    "\n",
    "def _gga_base(gga_tail):\n",
    "    \"\"\"Symmetrized base distribution built from ``gga.make_ggdist(gga_tail)``.\"\"\"\n",
    "    return dist.Independent(\n",
    "        Symmetrize(gga.make_ggdist(gga_tail).expand((1,))),\n",
    "        1,\n",
    "    )\n",
    "\n",
    "\n",
    "def _affine_spline():\n",
    "    \"\"\"Affine bijector composed with a spline bijector.\"\"\"\n",
    "    return flowtorch.bijectors.Compose(\n",
    "        [\n",
    "            flowtorch.bijectors.Affine(),\n",
    "            flowtorch.bijectors.Spline(bound=5.0, count_bins=10),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "\n",
    "def _make_methods(gga_tail):\n",
    "    \"\"\"Build fresh ``(method_name, flow)`` pairs for one trial.\n",
    "\n",
    "    The ``\"target\"`` entry carries no flow; it marks the reference row whose\n",
    "    tail exponent is estimated directly from target samples.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        (\"target\", None),\n",
    "        (\n",
    "            \"advi\",\n",
    "            flowtorch.distributions.Flow(\n",
    "                _normal_base(), flowtorch.bijectors.Affine()\n",
    "            ),\n",
    "        ),\n",
    "        (\n",
    "            \"gga\",\n",
    "            flowtorch.distributions.Flow(\n",
    "                _gga_base(gga_tail), flowtorch.bijectors.Affine()\n",
    "            ),\n",
    "        ),\n",
    "        (\n",
    "            \"advi_flow\",\n",
    "            flowtorch.distributions.Flow(_normal_base(), _affine_spline()),\n",
    "        ),\n",
    "        (\n",
    "            \"gga_flow\",\n",
    "            flowtorch.distributions.Flow(_gga_base(gga_tail), _affine_spline()),\n",
    "        ),\n",
    "    ]\n",
    "\n",
    "\n",
    "# Each entry: (display name, target distribution, tail descriptor from gga).\n",
    "targets = [\n",
    "    (\"Normal\", dist.Normal(loc=1.0, scale=1.0), gga.normal(1.0, 1.0)),\n",
    "    (\"Chi2\", Symmetrize(dist.Chi2(df=3.0)), gga.chi2(3)),\n",
    "    (\"StudentT\", dist.StudentT(df=2.0, loc=3.0, scale=1.0), gga.student(2)),\n",
    "    (\"Cauchy\", dist.Cauchy(loc=1.0, scale=1.0), gga.cauchy()),\n",
    "    (\n",
    "        \"InverseGamma\",\n",
    "        gga.make_positive(\n",
    "            dist.TransformedDistribution(\n",
    "                dist.Exponential(3.0),\n",
    "                [dist.transforms.PowerTransform(exponent=-1.0)],\n",
    "            )\n",
    "        ),\n",
    "        gga.exponential(3.0) ** -1,\n",
    "    ),\n",
    "]\n",
    "\n",
    "# One dict per (target, trial, method); converted to a DataFrame once at the end.\n",
    "records = []\n",
    "for target_name, target, gga_tail in tqdm(\n",
    "    targets,\n",
    "    desc=\"target\",\n",
    "    position=0,\n",
    "    leave=True,\n",
    "):\n",
    "    for trial in trange(num_trials, position=1, leave=False, desc=\"trial\"):\n",
    "        for method_name, flow in _make_methods(gga_tail):\n",
    "            if flow is None:\n",
    "                # Reference row: fit the tail exponent on samples of the target.\n",
    "                q_tail = powerlaw.Fit(\n",
    "                    target.sample((N,)).squeeze().abs(), verbose=False\n",
    "                )\n",
    "                records.append(\n",
    "                    {\n",
    "                        \"trial\": trial,\n",
    "                        \"target\": target_name,\n",
    "                        \"method\": method_name,\n",
    "                        \"alpha_q\": q_tail.alpha,\n",
    "                    }\n",
    "                )\n",
    "                continue\n",
    "\n",
    "            # Density estimation first; the tail exponent is fitted on samples\n",
    "            # of the flow right after this stage.\n",
    "            lls = train_de(flow, target, N, n_iter=500, lr=5e-3)\n",
    "            q_tail = powerlaw.Fit(flow.sample((N,)).squeeze().abs(), verbose=False)\n",
    "\n",
    "            # VI continues from the parameters left by the density-estimation stage.\n",
    "            elbos, lmls, khat = train_vi(flow, target, N, n_iter=500, lr=5e-3)\n",
    "\n",
    "            records.append(\n",
    "                {\n",
    "                    \"trial\": trial,\n",
    "                    \"target\": target_name,\n",
    "                    \"method\": method_name,\n",
    "                    \"LL\": lls[-1].item(),\n",
    "                    \"ELBO\": elbos[-1].item(),\n",
    "                    \"LML\": lmls[-1].item(),\n",
    "                    \"alpha_q\": q_tail.alpha,\n",
    "                    \"khat\": khat,\n",
    "                }\n",
    "            )\n",
    "\n",
    "df = pd.DataFrame(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e187e1d2-0498-404b-8f2b-859485667d70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_9d3de_row2_col0, #T_9d3de_row2_col4, #T_9d3de_row2_col6, #T_9d3de_row2_col8, #T_9d3de_row3_col2, #T_9d3de_row3_col8, #T_9d3de_row4_col0, #T_9d3de_row7_col0, #T_9d3de_row7_col2, #T_9d3de_row7_col4, #T_9d3de_row7_col8, #T_9d3de_row8_col6, #T_9d3de_row8_col8, #T_9d3de_row9_col0, #T_9d3de_row10_col6, #T_9d3de_row12_col0, #T_9d3de_row12_col4, #T_9d3de_row13_col2, #T_9d3de_row14_col0, #T_9d3de_row15_col0, #T_9d3de_row15_col4, #T_9d3de_row15_col8, #T_9d3de_row16_col6, #T_9d3de_row16_col8, #T_9d3de_row17_col8, #T_9d3de_row18_col2, #T_9d3de_row18_col8, #T_9d3de_row19_col0, #T_9d3de_row21_col2, #T_9d3de_row22_col4, #T_9d3de_row22_col8, #T_9d3de_row23_col0, #T_9d3de_row23_col6, #T_9d3de_row23_col8, #T_9d3de_row24_col0 {\n",
       "  background-color: lightgreen;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_9d3de\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_9d3de_level0_col0\" class=\"col_heading level0 col0\" colspan=\"2\">alpha_q</th>\n",
       "      <th id=\"T_9d3de_level0_col2\" class=\"col_heading level0 col2\" colspan=\"2\">LL</th>\n",
       "      <th id=\"T_9d3de_level0_col4\" class=\"col_heading level0 col4\" colspan=\"2\">ELBO</th>\n",
       "      <th id=\"T_9d3de_level0_col6\" class=\"col_heading level0 col6\" colspan=\"2\">LML</th>\n",
       "      <th id=\"T_9d3de_level0_col8\" class=\"col_heading level0 col8\" colspan=\"2\">khat</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank\" >&nbsp;</th>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_9d3de_level1_col0\" class=\"col_heading level1 col0\" >mean</th>\n",
       "      <th id=\"T_9d3de_level1_col1\" class=\"col_heading level1 col1\" >std</th>\n",
       "      <th id=\"T_9d3de_level1_col2\" class=\"col_heading level1 col2\" >mean</th>\n",
       "      <th id=\"T_9d3de_level1_col3\" class=\"col_heading level1 col3\" >std</th>\n",
       "      <th id=\"T_9d3de_level1_col4\" class=\"col_heading level1 col4\" >mean</th>\n",
       "      <th id=\"T_9d3de_level1_col5\" class=\"col_heading level1 col5\" >std</th>\n",
       "      <th id=\"T_9d3de_level1_col6\" class=\"col_heading level1 col6\" >mean</th>\n",
       "      <th id=\"T_9d3de_level1_col7\" class=\"col_heading level1 col7\" >std</th>\n",
       "      <th id=\"T_9d3de_level1_col8\" class=\"col_heading level1 col8\" >mean</th>\n",
       "      <th id=\"T_9d3de_level1_col9\" class=\"col_heading level1 col9\" >std</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"index_name level0\" >target</th>\n",
       "      <th class=\"index_name level1\" >method</th>\n",
       "      <th class=\"blank col0\" >&nbsp;</th>\n",
       "      <th class=\"blank col1\" >&nbsp;</th>\n",
       "      <th class=\"blank col2\" >&nbsp;</th>\n",
       "      <th class=\"blank col3\" >&nbsp;</th>\n",
       "      <th class=\"blank col4\" >&nbsp;</th>\n",
       "      <th class=\"blank col5\" >&nbsp;</th>\n",
       "      <th class=\"blank col6\" >&nbsp;</th>\n",
       "      <th class=\"blank col7\" >&nbsp;</th>\n",
       "      <th class=\"blank col8\" >&nbsp;</th>\n",
       "      <th class=\"blank col9\" >&nbsp;</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level0_row0\" class=\"row_heading level0 row0\" rowspan=\"5\">Cauchy</th>\n",
       "      <th id=\"T_9d3de_level1_row0\" class=\"row_heading level1 row0\" >advi</th>\n",
       "      <td id=\"T_9d3de_row0_col0\" class=\"data row0 col0\" >7.719938</td>\n",
       "      <td id=\"T_9d3de_row0_col1\" class=\"data row0 col1\" >2.461630</td>\n",
       "      <td id=\"T_9d3de_row0_col2\" class=\"data row0 col2\" >-13897743.715625</td>\n",
       "      <td id=\"T_9d3de_row0_col3\" class=\"data row0 col3\" >62249288.106553</td>\n",
       "      <td id=\"T_9d3de_row0_col4\" class=\"data row0 col4\" >-0.187442</td>\n",
       "      <td id=\"T_9d3de_row0_col5\" class=\"data row0 col5\" >0.010712</td>\n",
       "      <td id=\"T_9d3de_row0_col6\" class=\"data row0 col6\" >6.799048</td>\n",
       "      <td id=\"T_9d3de_row0_col7\" class=\"data row0 col7\" >0.030728</td>\n",
       "      <td id=\"T_9d3de_row0_col8\" class=\"data row0 col8\" >0.463318</td>\n",
       "      <td id=\"T_9d3de_row0_col9\" class=\"data row0 col9\" >0.125606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row1\" class=\"row_heading level1 row1\" >advi_flow</th>\n",
       "      <td id=\"T_9d3de_row1_col0\" class=\"data row1 col0\" >7.070086</td>\n",
       "      <td id=\"T_9d3de_row1_col1\" class=\"data row1 col1\" >6.568098</td>\n",
       "      <td id=\"T_9d3de_row1_col2\" class=\"data row1 col2\" >-53026542479.916092</td>\n",
       "      <td id=\"T_9d3de_row1_col3\" class=\"data row1 col3\" >264830808888.565887</td>\n",
       "      <td id=\"T_9d3de_row1_col4\" class=\"data row1 col4\" >-0.103038</td>\n",
       "      <td id=\"T_9d3de_row1_col5\" class=\"data row1 col5\" >0.028352</td>\n",
       "      <td id=\"T_9d3de_row1_col6\" class=\"data row1 col6\" >6.854181</td>\n",
       "      <td id=\"T_9d3de_row1_col7\" class=\"data row1 col7\" >0.153292</td>\n",
       "      <td id=\"T_9d3de_row1_col8\" class=\"data row1 col8\" >0.346866</td>\n",
       "      <td id=\"T_9d3de_row1_col9\" class=\"data row1 col9\" >0.432249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row2\" class=\"row_heading level1 row2\" >gga</th>\n",
       "      <td id=\"T_9d3de_row2_col0\" class=\"data row2 col0\" >2.051607</td>\n",
       "      <td id=\"T_9d3de_row2_col1\" class=\"data row2 col1\" >0.064027</td>\n",
       "      <td id=\"T_9d3de_row2_col2\" class=\"data row2 col2\" >-3926.671270</td>\n",
       "      <td id=\"T_9d3de_row2_col3\" class=\"data row2 col3\" >56.261533</td>\n",
       "      <td id=\"T_9d3de_row2_col4\" class=\"data row2 col4\" >1.386296</td>\n",
       "      <td id=\"T_9d3de_row2_col5\" class=\"data row2 col5\" >0.000271</td>\n",
       "      <td id=\"T_9d3de_row2_col6\" class=\"data row2 col6\" >8.294094</td>\n",
       "      <td id=\"T_9d3de_row2_col7\" class=\"data row2 col7\" >0.000283</td>\n",
       "      <td id=\"T_9d3de_row2_col8\" class=\"data row2 col8\" >0.011495</td>\n",
       "      <td id=\"T_9d3de_row2_col9\" class=\"data row2 col9\" >0.006257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row3\" class=\"row_heading level1 row3\" >gga_flow</th>\n",
       "      <td id=\"T_9d3de_row3_col0\" class=\"data row3 col0\" >2.047628</td>\n",
       "      <td id=\"T_9d3de_row3_col1\" class=\"data row3 col1\" >0.066727</td>\n",
       "      <td id=\"T_9d3de_row3_col2\" class=\"data row3 col2\" >-3918.399521</td>\n",
       "      <td id=\"T_9d3de_row3_col3\" class=\"data row3 col3\" >54.791068</td>\n",
       "      <td id=\"T_9d3de_row3_col4\" class=\"data row3 col4\" >1.385121</td>\n",
       "      <td id=\"T_9d3de_row3_col5\" class=\"data row3 col5\" >0.001526</td>\n",
       "      <td id=\"T_9d3de_row3_col6\" class=\"data row3 col6\" >8.293885</td>\n",
       "      <td id=\"T_9d3de_row3_col7\" class=\"data row3 col7\" >0.001533</td>\n",
       "      <td id=\"T_9d3de_row3_col8\" class=\"data row3 col8\" >0.033853</td>\n",
       "      <td id=\"T_9d3de_row3_col9\" class=\"data row3 col9\" >0.010293</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row4\" class=\"row_heading level1 row4\" >target</th>\n",
       "      <td id=\"T_9d3de_row4_col0\" class=\"data row4 col0\" >2.052146</td>\n",
       "      <td id=\"T_9d3de_row4_col1\" class=\"data row4 col1\" >0.075131</td>\n",
       "      <td id=\"T_9d3de_row4_col2\" class=\"data row4 col2\" >nan</td>\n",
       "      <td id=\"T_9d3de_row4_col3\" class=\"data row4 col3\" >nan</td>\n",
       "      <td id=\"T_9d3de_row4_col4\" class=\"data row4 col4\" >nan</td>\n",
       "      <td id=\"T_9d3de_row4_col5\" class=\"data row4 col5\" >nan</td>\n",
       "      <td id=\"T_9d3de_row4_col6\" class=\"data row4 col6\" >nan</td>\n",
       "      <td id=\"T_9d3de_row4_col7\" class=\"data row4 col7\" >nan</td>\n",
       "      <td id=\"T_9d3de_row4_col8\" class=\"data row4 col8\" >nan</td>\n",
       "      <td id=\"T_9d3de_row4_col9\" class=\"data row4 col9\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level0_row5\" class=\"row_heading level0 row5\" rowspan=\"5\">Chi2</th>\n",
       "      <th id=\"T_9d3de_level1_row5\" class=\"row_heading level1 row5\" >advi</th>\n",
       "      <td id=\"T_9d3de_row5_col0\" class=\"data row5 col0\" >6.833164</td>\n",
       "      <td id=\"T_9d3de_row5_col1\" class=\"data row5 col1\" >2.401425</td>\n",
       "      <td id=\"T_9d3de_row5_col2\" class=\"data row5 col2\" >-2844.304668</td>\n",
       "      <td id=\"T_9d3de_row5_col3\" class=\"data row5 col3\" >38.093052</td>\n",
       "      <td id=\"T_9d3de_row5_col4\" class=\"data row5 col4\" >-0.024114</td>\n",
       "      <td id=\"T_9d3de_row5_col5\" class=\"data row5 col5\" >0.007210</td>\n",
       "      <td id=\"T_9d3de_row5_col6\" class=\"data row5 col6\" >6.905364</td>\n",
       "      <td id=\"T_9d3de_row5_col7\" class=\"data row5 col7\" >0.006591</td>\n",
       "      <td id=\"T_9d3de_row5_col8\" class=\"data row5 col8\" >0.255072</td>\n",
       "      <td id=\"T_9d3de_row5_col9\" class=\"data row5 col9\" >0.093515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row6\" class=\"row_heading level1 row6\" >advi_flow</th>\n",
       "      <td id=\"T_9d3de_row6_col0\" class=\"data row6 col0\" >6.386599</td>\n",
       "      <td id=\"T_9d3de_row6_col1\" class=\"data row6 col1\" >0.884774</td>\n",
       "      <td id=\"T_9d3de_row6_col2\" class=\"data row6 col2\" >-2875.030059</td>\n",
       "      <td id=\"T_9d3de_row6_col3\" class=\"data row6 col3\" >55.151399</td>\n",
       "      <td id=\"T_9d3de_row6_col4\" class=\"data row6 col4\" >-0.046185</td>\n",
       "      <td id=\"T_9d3de_row6_col5\" class=\"data row6 col5\" >0.033645</td>\n",
       "      <td id=\"T_9d3de_row6_col6\" class=\"data row6 col6\" >6.906127</td>\n",
       "      <td id=\"T_9d3de_row6_col7\" class=\"data row6 col7\" >0.009802</td>\n",
       "      <td id=\"T_9d3de_row6_col8\" class=\"data row6 col8\" >0.225182</td>\n",
       "      <td id=\"T_9d3de_row6_col9\" class=\"data row6 col9\" >0.119412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row7\" class=\"row_heading level1 row7\" >gga</th>\n",
       "      <td id=\"T_9d3de_row7_col0\" class=\"data row7 col0\" >5.509403</td>\n",
       "      <td id=\"T_9d3de_row7_col1\" class=\"data row7 col1\" >1.197632</td>\n",
       "      <td id=\"T_9d3de_row7_col2\" class=\"data row7 col2\" >-2754.785361</td>\n",
       "      <td id=\"T_9d3de_row7_col3\" class=\"data row7 col3\" >26.013949</td>\n",
       "      <td id=\"T_9d3de_row7_col4\" class=\"data row7 col4\" >-0.002228</td>\n",
       "      <td id=\"T_9d3de_row7_col5\" class=\"data row7 col5\" >0.003437</td>\n",
       "      <td id=\"T_9d3de_row7_col6\" class=\"data row7 col6\" >6.907670</td>\n",
       "      <td id=\"T_9d3de_row7_col7\" class=\"data row7 col7\" >0.001647</td>\n",
       "      <td id=\"T_9d3de_row7_col8\" class=\"data row7 col8\" >0.074997</td>\n",
       "      <td id=\"T_9d3de_row7_col9\" class=\"data row7 col9\" >0.070365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row8\" class=\"row_heading level1 row8\" >gga_flow</th>\n",
       "      <td id=\"T_9d3de_row8_col0\" class=\"data row8 col0\" >5.241649</td>\n",
       "      <td id=\"T_9d3de_row8_col1\" class=\"data row8 col1\" >1.556106</td>\n",
       "      <td id=\"T_9d3de_row8_col2\" class=\"data row8 col2\" >-2777.530225</td>\n",
       "      <td id=\"T_9d3de_row8_col3\" class=\"data row8 col3\" >44.213584</td>\n",
       "      <td id=\"T_9d3de_row8_col4\" class=\"data row8 col4\" >-0.030770</td>\n",
       "      <td id=\"T_9d3de_row8_col5\" class=\"data row8 col5\" >0.030821</td>\n",
       "      <td id=\"T_9d3de_row8_col6\" class=\"data row8 col6\" >6.908253</td>\n",
       "      <td id=\"T_9d3de_row8_col7\" class=\"data row8 col7\" >0.006677</td>\n",
       "      <td id=\"T_9d3de_row8_col8\" class=\"data row8 col8\" >0.136386</td>\n",
       "      <td id=\"T_9d3de_row8_col9\" class=\"data row8 col9\" >0.103602</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row9\" class=\"row_heading level1 row9\" >target</th>\n",
       "      <td id=\"T_9d3de_row9_col0\" class=\"data row9 col0\" >5.458437</td>\n",
       "      <td id=\"T_9d3de_row9_col1\" class=\"data row9 col1\" >1.386474</td>\n",
       "      <td id=\"T_9d3de_row9_col2\" class=\"data row9 col2\" >nan</td>\n",
       "      <td id=\"T_9d3de_row9_col3\" class=\"data row9 col3\" >nan</td>\n",
       "      <td id=\"T_9d3de_row9_col4\" class=\"data row9 col4\" >nan</td>\n",
       "      <td id=\"T_9d3de_row9_col5\" class=\"data row9 col5\" >nan</td>\n",
       "      <td id=\"T_9d3de_row9_col6\" class=\"data row9 col6\" >nan</td>\n",
       "      <td id=\"T_9d3de_row9_col7\" class=\"data row9 col7\" >nan</td>\n",
       "      <td id=\"T_9d3de_row9_col8\" class=\"data row9 col8\" >nan</td>\n",
       "      <td id=\"T_9d3de_row9_col9\" class=\"data row9 col9\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level0_row10\" class=\"row_heading level0 row10\" rowspan=\"5\">InverseGamma</th>\n",
       "      <th id=\"T_9d3de_level1_row10\" class=\"row_heading level1 row10\" >advi</th>\n",
       "      <td id=\"T_9d3de_row10_col0\" class=\"data row10 col0\" >7.306839</td>\n",
       "      <td id=\"T_9d3de_row10_col1\" class=\"data row10 col1\" >1.730365</td>\n",
       "      <td id=\"T_9d3de_row10_col2\" class=\"data row10 col2\" >-144297343.462500</td>\n",
       "      <td id=\"T_9d3de_row10_col3\" class=\"data row10 col3\" >622376889.709392</td>\n",
       "      <td id=\"T_9d3de_row10_col4\" class=\"data row10 col4\" >-0.626661</td>\n",
       "      <td id=\"T_9d3de_row10_col5\" class=\"data row10 col5\" >6.528261</td>\n",
       "      <td id=\"T_9d3de_row10_col6\" class=\"data row10 col6\" >1977.128919</td>\n",
       "      <td id=\"T_9d3de_row10_col7\" class=\"data row10 col7\" >3880.206355</td>\n",
       "      <td id=\"T_9d3de_row10_col8\" class=\"data row10 col8\" >12.502074</td>\n",
       "      <td id=\"T_9d3de_row10_col9\" class=\"data row10 col9\" >3.379253</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row11\" class=\"row_heading level1 row11\" >advi_flow</th>\n",
       "      <td id=\"T_9d3de_row11_col0\" class=\"data row11 col0\" >26.654599</td>\n",
       "      <td id=\"T_9d3de_row11_col1\" class=\"data row11 col1\" >39.127147</td>\n",
       "      <td id=\"T_9d3de_row11_col2\" class=\"data row11 col2\" >-4276563562.820000</td>\n",
       "      <td id=\"T_9d3de_row11_col3\" class=\"data row11 col3\" >20533001026.572933</td>\n",
       "      <td id=\"T_9d3de_row11_col4\" class=\"data row11 col4\" >-1.531380</td>\n",
       "      <td id=\"T_9d3de_row11_col5\" class=\"data row11 col5\" >0.100064</td>\n",
       "      <td id=\"T_9d3de_row11_col6\" class=\"data row11 col6\" >10.656070</td>\n",
       "      <td id=\"T_9d3de_row11_col7\" class=\"data row11 col7\" >23.305825</td>\n",
       "      <td id=\"T_9d3de_row11_col8\" class=\"data row11 col8\" >0.627758</td>\n",
       "      <td id=\"T_9d3de_row11_col9\" class=\"data row11 col9\" >0.553000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row12\" class=\"row_heading level1 row12\" >gga</th>\n",
       "      <td id=\"T_9d3de_row12_col0\" class=\"data row12 col0\" >1.930956</td>\n",
       "      <td id=\"T_9d3de_row12_col1\" class=\"data row12 col1\" >0.092449</td>\n",
       "      <td id=\"T_9d3de_row12_col2\" class=\"data row12 col2\" >-3950.807988</td>\n",
       "      <td id=\"T_9d3de_row12_col3\" class=\"data row12 col3\" >53.523370</td>\n",
       "      <td id=\"T_9d3de_row12_col4\" class=\"data row12 col4\" >0.441069</td>\n",
       "      <td id=\"T_9d3de_row12_col5\" class=\"data row12 col5\" >4.222806</td>\n",
       "      <td id=\"T_9d3de_row12_col6\" class=\"data row12 col6\" >953.613516</td>\n",
       "      <td id=\"T_9d3de_row12_col7\" class=\"data row12 col7\" >1610.192384</td>\n",
       "      <td id=\"T_9d3de_row12_col8\" class=\"data row12 col8\" >11.193838</td>\n",
       "      <td id=\"T_9d3de_row12_col9\" class=\"data row12 col9\" >3.170087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row13\" class=\"row_heading level1 row13\" >gga_flow</th>\n",
       "      <td id=\"T_9d3de_row13_col0\" class=\"data row13 col0\" >1.938580</td>\n",
       "      <td id=\"T_9d3de_row13_col1\" class=\"data row13 col1\" >0.092345</td>\n",
       "      <td id=\"T_9d3de_row13_col2\" class=\"data row13 col2\" >-3932.185518</td>\n",
       "      <td id=\"T_9d3de_row13_col3\" class=\"data row13 col3\" >46.659545</td>\n",
       "      <td id=\"T_9d3de_row13_col4\" class=\"data row13 col4\" >-0.141594</td>\n",
       "      <td id=\"T_9d3de_row13_col5\" class=\"data row13 col5\" >0.897872</td>\n",
       "      <td id=\"T_9d3de_row13_col6\" class=\"data row13 col6\" >157.159852</td>\n",
       "      <td id=\"T_9d3de_row13_col7\" class=\"data row13 col7\" >155.667465</td>\n",
       "      <td id=\"T_9d3de_row13_col8\" class=\"data row13 col8\" >5.670329</td>\n",
       "      <td id=\"T_9d3de_row13_col9\" class=\"data row13 col9\" >5.705646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row14\" class=\"row_heading level1 row14\" >target</th>\n",
       "      <td id=\"T_9d3de_row14_col0\" class=\"data row14 col0\" >1.916284</td>\n",
       "      <td id=\"T_9d3de_row14_col1\" class=\"data row14 col1\" >0.079041</td>\n",
       "      <td id=\"T_9d3de_row14_col2\" class=\"data row14 col2\" >nan</td>\n",
       "      <td id=\"T_9d3de_row14_col3\" class=\"data row14 col3\" >nan</td>\n",
       "      <td id=\"T_9d3de_row14_col4\" class=\"data row14 col4\" >nan</td>\n",
       "      <td id=\"T_9d3de_row14_col5\" class=\"data row14 col5\" >nan</td>\n",
       "      <td id=\"T_9d3de_row14_col6\" class=\"data row14 col6\" >nan</td>\n",
       "      <td id=\"T_9d3de_row14_col7\" class=\"data row14 col7\" >nan</td>\n",
       "      <td id=\"T_9d3de_row14_col8\" class=\"data row14 col8\" >nan</td>\n",
       "      <td id=\"T_9d3de_row14_col9\" class=\"data row14 col9\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level0_row15\" class=\"row_heading level0 row15\" rowspan=\"5\">Normal</th>\n",
       "      <th id=\"T_9d3de_level1_row15\" class=\"row_heading level1 row15\" >advi</th>\n",
       "      <td id=\"T_9d3de_row15_col0\" class=\"data row15 col0\" >8.381567</td>\n",
       "      <td id=\"T_9d3de_row15_col1\" class=\"data row15 col1\" >3.471436</td>\n",
       "      <td id=\"T_9d3de_row15_col2\" class=\"data row15 col2\" >-1424.329668</td>\n",
       "      <td id=\"T_9d3de_row15_col3\" class=\"data row15 col3\" >19.327368</td>\n",
       "      <td id=\"T_9d3de_row15_col4\" class=\"data row15 col4\" >-0.000155</td>\n",
       "      <td id=\"T_9d3de_row15_col5\" class=\"data row15 col5\" >0.000538</td>\n",
       "      <td id=\"T_9d3de_row15_col6\" class=\"data row15 col6\" >6.907674</td>\n",
       "      <td id=\"T_9d3de_row15_col7\" class=\"data row15 col7\" >0.000502</td>\n",
       "      <td id=\"T_9d3de_row15_col8\" class=\"data row15 col8\" >0.005543</td>\n",
       "      <td id=\"T_9d3de_row15_col9\" class=\"data row15 col9\" >0.008168</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row16\" class=\"row_heading level1 row16\" >advi_flow</th>\n",
       "      <td id=\"T_9d3de_row16_col0\" class=\"data row16 col0\" >8.808349</td>\n",
       "      <td id=\"T_9d3de_row16_col1\" class=\"data row16 col1\" >4.638696</td>\n",
       "      <td id=\"T_9d3de_row16_col2\" class=\"data row16 col2\" >-1418.917900</td>\n",
       "      <td id=\"T_9d3de_row16_col3\" class=\"data row16 col3\" >19.030661</td>\n",
       "      <td id=\"T_9d3de_row16_col4\" class=\"data row16 col4\" >-0.000377</td>\n",
       "      <td id=\"T_9d3de_row16_col5\" class=\"data row16 col5\" >0.001300</td>\n",
       "      <td id=\"T_9d3de_row16_col6\" class=\"data row16 col6\" >6.908054</td>\n",
       "      <td id=\"T_9d3de_row16_col7\" class=\"data row16 col7\" >0.001270</td>\n",
       "      <td id=\"T_9d3de_row16_col8\" class=\"data row16 col8\" >0.021956</td>\n",
       "      <td id=\"T_9d3de_row16_col9\" class=\"data row16 col9\" >0.016775</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row17\" class=\"row_heading level1 row17\" >gga</th>\n",
       "      <td id=\"T_9d3de_row17_col0\" class=\"data row17 col0\" >8.772573</td>\n",
       "      <td id=\"T_9d3de_row17_col1\" class=\"data row17 col1\" >2.820500</td>\n",
       "      <td id=\"T_9d3de_row17_col2\" class=\"data row17 col2\" >-1418.871104</td>\n",
       "      <td id=\"T_9d3de_row17_col3\" class=\"data row17 col3\" >21.452965</td>\n",
       "      <td id=\"T_9d3de_row17_col4\" class=\"data row17 col4\" >-0.000168</td>\n",
       "      <td id=\"T_9d3de_row17_col5\" class=\"data row17 col5\" >0.000570</td>\n",
       "      <td id=\"T_9d3de_row17_col6\" class=\"data row17 col6\" >6.907696</td>\n",
       "      <td id=\"T_9d3de_row17_col7\" class=\"data row17 col7\" >0.000549</td>\n",
       "      <td id=\"T_9d3de_row17_col8\" class=\"data row17 col8\" >0.007343</td>\n",
       "      <td id=\"T_9d3de_row17_col9\" class=\"data row17 col9\" >0.007484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row18\" class=\"row_heading level1 row18\" >gga_flow</th>\n",
       "      <td id=\"T_9d3de_row18_col0\" class=\"data row18 col0\" >8.232020</td>\n",
       "      <td id=\"T_9d3de_row18_col1\" class=\"data row18 col1\" >4.008672</td>\n",
       "      <td id=\"T_9d3de_row18_col2\" class=\"data row18 col2\" >-1414.655181</td>\n",
       "      <td id=\"T_9d3de_row18_col3\" class=\"data row18 col3\" >24.053614</td>\n",
       "      <td id=\"T_9d3de_row18_col4\" class=\"data row18 col4\" >-0.000710</td>\n",
       "      <td id=\"T_9d3de_row18_col5\" class=\"data row18 col5\" >0.001028</td>\n",
       "      <td id=\"T_9d3de_row18_col6\" class=\"data row18 col6\" >6.907587</td>\n",
       "      <td id=\"T_9d3de_row18_col7\" class=\"data row18 col7\" >0.000941</td>\n",
       "      <td id=\"T_9d3de_row18_col8\" class=\"data row18 col8\" >0.016847</td>\n",
       "      <td id=\"T_9d3de_row18_col9\" class=\"data row18 col9\" >0.014443</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row19\" class=\"row_heading level1 row19\" >target</th>\n",
       "      <td id=\"T_9d3de_row19_col0\" class=\"data row19 col0\" >8.475122</td>\n",
       "      <td id=\"T_9d3de_row19_col1\" class=\"data row19 col1\" >3.077439</td>\n",
       "      <td id=\"T_9d3de_row19_col2\" class=\"data row19 col2\" >nan</td>\n",
       "      <td id=\"T_9d3de_row19_col3\" class=\"data row19 col3\" >nan</td>\n",
       "      <td id=\"T_9d3de_row19_col4\" class=\"data row19 col4\" >nan</td>\n",
       "      <td id=\"T_9d3de_row19_col5\" class=\"data row19 col5\" >nan</td>\n",
       "      <td id=\"T_9d3de_row19_col6\" class=\"data row19 col6\" >nan</td>\n",
       "      <td id=\"T_9d3de_row19_col7\" class=\"data row19 col7\" >nan</td>\n",
       "      <td id=\"T_9d3de_row19_col8\" class=\"data row19 col8\" >nan</td>\n",
       "      <td id=\"T_9d3de_row19_col9\" class=\"data row19 col9\" >nan</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level0_row20\" class=\"row_heading level0 row20\" rowspan=\"5\">StudentT</th>\n",
       "      <th id=\"T_9d3de_level1_row20\" class=\"row_heading level1 row20\" >advi</th>\n",
       "      <td id=\"T_9d3de_row20_col0\" class=\"data row20 col0\" >7.670487</td>\n",
       "      <td id=\"T_9d3de_row20_col1\" class=\"data row20 col1\" >2.257001</td>\n",
       "      <td id=\"T_9d3de_row20_col2\" class=\"data row20 col2\" >-2968.608047</td>\n",
       "      <td id=\"T_9d3de_row20_col3\" class=\"data row20 col3\" >469.677964</td>\n",
       "      <td id=\"T_9d3de_row20_col4\" class=\"data row20 col4\" >-0.072461</td>\n",
       "      <td id=\"T_9d3de_row20_col5\" class=\"data row20 col5\" >0.009904</td>\n",
       "      <td id=\"T_9d3de_row20_col6\" class=\"data row20 col6\" >6.895080</td>\n",
       "      <td id=\"T_9d3de_row20_col7\" class=\"data row20 col7\" >0.057844</td>\n",
       "      <td id=\"T_9d3de_row20_col8\" class=\"data row20 col8\" >0.525770</td>\n",
       "      <td id=\"T_9d3de_row20_col9\" class=\"data row20 col9\" >0.172238</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row21\" class=\"row_heading level1 row21\" >advi_flow</th>\n",
       "      <td id=\"T_9d3de_row21_col0\" class=\"data row21 col0\" >12.820791</td>\n",
       "      <td id=\"T_9d3de_row21_col1\" class=\"data row21 col1\" >11.201760</td>\n",
       "      <td id=\"T_9d3de_row21_col2\" class=\"data row21 col2\" >-2686.158994</td>\n",
       "      <td id=\"T_9d3de_row21_col3\" class=\"data row21 col3\" >644.280795</td>\n",
       "      <td id=\"T_9d3de_row21_col4\" class=\"data row21 col4\" >-0.017273</td>\n",
       "      <td id=\"T_9d3de_row21_col5\" class=\"data row21 col5\" >0.002529</td>\n",
       "      <td id=\"T_9d3de_row21_col6\" class=\"data row21 col6\" >6.900762</td>\n",
       "      <td id=\"T_9d3de_row21_col7\" class=\"data row21 col7\" >0.010287</td>\n",
       "      <td id=\"T_9d3de_row21_col8\" class=\"data row21 col8\" >0.210844</td>\n",
       "      <td id=\"T_9d3de_row21_col9\" class=\"data row21 col9\" >0.263434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row22\" class=\"row_heading level1 row22\" >gga</th>\n",
       "      <td id=\"T_9d3de_row22_col0\" class=\"data row22 col0\" >3.128015</td>\n",
       "      <td id=\"T_9d3de_row22_col1\" class=\"data row22 col1\" >0.159919</td>\n",
       "      <td id=\"T_9d3de_row22_col2\" class=\"data row22 col2\" >-3612.014688</td>\n",
       "      <td id=\"T_9d3de_row22_col3\" class=\"data row22 col3\" >28.071365</td>\n",
       "      <td id=\"T_9d3de_row22_col4\" class=\"data row22 col4\" >1.386293</td>\n",
       "      <td id=\"T_9d3de_row22_col5\" class=\"data row22 col5\" >0.000116</td>\n",
       "      <td id=\"T_9d3de_row22_col6\" class=\"data row22 col6\" >8.294061</td>\n",
       "      <td id=\"T_9d3de_row22_col7\" class=\"data row22 col7\" >0.000117</td>\n",
       "      <td id=\"T_9d3de_row22_col8\" class=\"data row22 col8\" >0.002202</td>\n",
       "      <td id=\"T_9d3de_row22_col9\" class=\"data row22 col9\" >0.003228</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row23\" class=\"row_heading level1 row23\" >gga_flow</th>\n",
       "      <td id=\"T_9d3de_row23_col0\" class=\"data row23 col0\" >3.301435</td>\n",
       "      <td id=\"T_9d3de_row23_col1\" class=\"data row23 col1\" >0.449897</td>\n",
       "      <td id=\"T_9d3de_row23_col2\" class=\"data row23 col2\" >-3380.473779</td>\n",
       "      <td id=\"T_9d3de_row23_col3\" class=\"data row23 col3\" >42.439409</td>\n",
       "      <td id=\"T_9d3de_row23_col4\" class=\"data row23 col4\" >1.378258</td>\n",
       "      <td id=\"T_9d3de_row23_col5\" class=\"data row23 col5\" >0.005217</td>\n",
       "      <td id=\"T_9d3de_row23_col6\" class=\"data row23 col6\" >8.294360</td>\n",
       "      <td id=\"T_9d3de_row23_col7\" class=\"data row23 col7\" >0.005155</td>\n",
       "      <td id=\"T_9d3de_row23_col8\" class=\"data row23 col8\" >0.124299</td>\n",
       "      <td id=\"T_9d3de_row23_col9\" class=\"data row23 col9\" >0.063872</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_9d3de_level1_row24\" class=\"row_heading level1 row24\" >target</th>\n",
       "      <td id=\"T_9d3de_row24_col0\" class=\"data row24 col0\" >4.261231</td>\n",
       "      <td id=\"T_9d3de_row24_col1\" class=\"data row24 col1\" >0.294786</td>\n",
       "      <td id=\"T_9d3de_row24_col2\" class=\"data row24 col2\" >nan</td>\n",
       "      <td id=\"T_9d3de_row24_col3\" class=\"data row24 col3\" >nan</td>\n",
       "      <td id=\"T_9d3de_row24_col4\" class=\"data row24 col4\" >nan</td>\n",
       "      <td id=\"T_9d3de_row24_col5\" class=\"data row24 col5\" >nan</td>\n",
       "      <td id=\"T_9d3de_row24_col6\" class=\"data row24 col6\" >nan</td>\n",
       "      <td id=\"T_9d3de_row24_col7\" class=\"data row24 col7\" >nan</td>\n",
       "      <td id=\"T_9d3de_row24_col8\" class=\"data row24 col8\" >nan</td>\n",
       "      <td id=\"T_9d3de_row24_col9\" class=\"data row24 col9\" >nan</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7fddf26e3fa0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "def highlight_max_by_target(s):\n",
    "    if s.name[1] == \"std\":\n",
    "        return\n",
    "    if s.name[0] in [\"LL\", \"ELBO\", \"LML\"]:\n",
    "        is_large = s.groupby(\"target\").max().values\n",
    "        return [\"background-color: lightgreen\" if v in is_large else \"\" for v in s]\n",
    "    elif s.name[0] == \"alpha_q\":\n",
    "        df = pd.DataFrame(s)\n",
    "        df = df.xs(\"target\", level=\"method\").join(df, lsuffix=\"_target\")\n",
    "        is_close = (\n",
    "            (df[\"alpha_q\"] - df[\"alpha_q_target\"])[\"mean\"]\n",
    "            .abs()\n",
    "            .groupby(\"target\")\n",
    "            .nsmallest(2)\n",
    "            .index.droplevel(0)\n",
    "        )\n",
    "        return [\n",
    "            \"background-color: lightgreen\"\n",
    "            if any(map(lambda x: x == df.index[i], is_close.values))\n",
    "            else \"\"\n",
    "            for i, _ in enumerate(s)\n",
    "        ]\n",
    "    elif s.name[0] == \"khat\":\n",
    "        return [\"background-color: lightgreen\" if v < 0.2 else \"\" for v in s]\n",
    "    return [\"\" for v in s]\n",
    "\n",
    "\n",
    "# Mean and std of every metric per (target, method), styled for display.\n",
    "# String aliases are used because passing numpy callables to aggregate is\n",
    "# deprecated in recent pandas; the resulting column labels are unchanged.\n",
    "df.loc[:, ~df.columns.isin([\"trial\"])].groupby([\"target\", \"method\"]).aggregate(\n",
    "    [\"mean\", \"std\"]\n",
    ").style.apply(highlight_max_by_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "fa981dde-331e-408b-96ce-8b270d3fc58b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pd.set_option(\"display.float_format\", \"{:.5f}\".format)\n",
    "\n",
    "# LaTeX labels are written as raw strings. In an ordinary literal a backslash\n",
    "# followed by \"a\" is the BEL control character (corrupting the alpha macro)\n",
    "# and a doubled backslash collapses to one, losing the LaTeX line break.\n",
    "NEG_CROSS_ENTROPY_LABEL = \"-H(p,q)\"\n",
    "ALPHA_HAT_LABEL = r\"$\\hat{\\alpha}$\"\n",
    "KHAT_LABEL = r\"$\\hat{k}$\"\n",
    "BOLD = \"bfseries: ;\"\n",
    "\n",
    "cols = [\"LL\", \"alpha_q\", \"ELBO\", \"LML\", \"khat\"]\n",
    "new_df = (\n",
    "    df.loc[:, ~df.columns.isin([\"trial\"])]\n",
    "    .groupby([\"target\", \"method\"])\n",
    "    .aggregate([\"mean\", \"std\"])[cols]\n",
    "    .dropna()\n",
    ")\n",
    "\n",
    "# Collapse each (mean, std) pair into a single \"mean (std)\" string.\n",
    "new_df = pd.DataFrame(\n",
    "    {\n",
    "        c: new_df.xs(c, axis=1).apply(\n",
    "            lambda x: f\"{x['mean']:.2g} ({x['std']:.2g})\",\n",
    "            axis=1,\n",
    "        )\n",
    "        for c in cols\n",
    "    }\n",
    ")\n",
    "new_df = new_df.reset_index().pivot(index=\"method\", columns=[\"target\"])\n",
    "\n",
    "new_df.index = new_df.index.set_names([\"Method\"])\n",
    "new_df.columns = new_df.columns.set_names([\"Metric\", \"Target\"]).set_levels(\n",
    "    [\n",
    "        r\"\\shortstack[l]{Cauchy \\\\($\\alpha=2$)}\",\n",
    "        r\"\\shortstack[l]{Chi2 \\\\($\\alpha=\\infty$)}\",\n",
    "        r\"\\shortstack[l]{IG \\\\($\\alpha=2$)}\",\n",
    "        r\"\\shortstack[l]{Normal \\\\($\\alpha=\\infty$)}\",\n",
    "        r\"\\shortstack[l]{StudentT \\\\($\\alpha=3$)}\",\n",
    "    ],\n",
    "    level=1,\n",
    ")\n",
    "new_df.columns = new_df.columns.set_levels(\n",
    "    [NEG_CROSS_ENTROPY_LABEL, ALPHA_HAT_LABEL, \"ELBO\", \"IWAE\", KHAT_LABEL],\n",
    "    level=0,\n",
    ")\n",
    "new_df = new_df.rename(\n",
    "    index={\n",
    "        \"advi\": \"Normal Affine\",\n",
    "        \"advi_flow\": \"Normal Flow\",\n",
    "        \"gga\": \"GGA Affine\",\n",
    "        \"gga_flow\": \"GGA Flow\",\n",
    "    }\n",
    ").sort_index(axis=1, level=1).reorder_levels(order=[1, 0], axis=1).T\n",
    "\n",
    "\n",
    "def highlight_max(x):\n",
    "    \"\"\"Return Styler CSS (``bfseries``) marking the best entries of one table row.\n",
    "\n",
    "    ``x`` is a row of ``new_df``: its name is ``(target label, metric label)``\n",
    "    and its values are ``\"mean (std)\"`` strings indexed by method.\n",
    "\n",
    "    * ELBO / IWAE / -H(p,q): bold the method with the largest mean.\n",
    "    * k-hat: bold every method whose mean is below 0.2.\n",
    "    * alpha-hat: bold the method whose mean is closest to the tail index\n",
    "      parsed from the target label; for an infinite tail index that is the\n",
    "      largest mean.\n",
    "\n",
    "    Returns an object array holding the CSS string or None for each method.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the metric label is none of the above.\n",
    "    \"\"\"\n",
    "    metric = x.name[1]\n",
    "    # Leading number of each \"mean (std)\" cell.\n",
    "    x_num = x.apply(lambda y: float(y.split(\" \")[0]))\n",
    "    if metric in [\"ELBO\", \"IWAE\", NEG_CROSS_ENTROPY_LABEL]:\n",
    "        return np.where(x_num.idxmax() == x.index.values, BOLD, None)\n",
    "    if metric == KHAT_LABEL:\n",
    "        return np.where(x_num < 0.2, BOLD, None)\n",
    "    if metric != ALPHA_HAT_LABEL:\n",
    "        raise ValueError(f\"highlight_max: unexpected metric label {metric!r}\")\n",
    "\n",
    "    # Tail index parsed from the \"alpha=...\" part of the target label; a\n",
    "    # non-numeric value (the infinity macro) leaves it at infinity.\n",
    "    nu = float(\"Inf\")\n",
    "    try:\n",
    "        nu = float(x.name[0].split(\"=\")[1].split(\"$\")[0])\n",
    "    except ValueError:\n",
    "        pass\n",
    "    if np.isinf(nu):\n",
    "        # Every |estimate - inf| is inf, so idxmin would always pick the first\n",
    "        # method; the estimate closest to infinity is simply the largest one.\n",
    "        best = x_num.idxmax()\n",
    "    else:\n",
    "        best = (x_num - nu).abs().idxmin()\n",
    "    return np.where(best == x.index.values, BOLD, None)\n",
    "\n",
    "\n",
    "# Density-estimation metrics go to one table, the VI metrics to the other.\n",
    "metric_labels = new_df.index.get_level_values(1)\n",
    "is_density_metric = metric_labels.isin([NEG_CROSS_ENTROPY_LABEL, ALPHA_HAT_LABEL])\n",
    "\n",
    "new_df.loc[is_density_metric].style.apply(highlight_max, axis=1).to_latex(\n",
    "    buf=\"de_table.tex\",\n",
    "    hrules=True,\n",
    "    # environment='longtable',\n",
    ")\n",
    "new_df.loc[~is_density_metric].style.apply(highlight_max, axis=1).to_latex(\n",
    "    buf=\"vi_table.tex\",\n",
    "    hrules=True,\n",
    "    # environment='longtable',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "id": "f61293da-f239-4677-b5e2-6e766c14de18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, ax = plt.subplots(1, 3, figsize=(12, 6))\n",
    "# palette = sns.color_palette()\n",
    "# legends = []\n",
    "#     sns.lineplot(ax=ax[0], data=losses, color=palette[i]).set(\n",
    "#         yscale=\"symlog\"\n",
    "#     )\n",
    "#     sns.histplot(\n",
    "#         ax=ax[1],\n",
    "#         data=flow.sample((N,)).detach().numpy().squeeze(),\n",
    "#         binwidth=0.5,\n",
    "#         stat=\"probability\",\n",
    "#         color=palette[i],\n",
    "#     ).set(xlim=[-20, 20])\n",
    "#     sns.kdeplot(\n",
    "#         ax=ax[2],\n",
    "#         data=flow.sample((N,)).detach().numpy().squeeze(),\n",
    "#         color=palette[i],\n",
    "#     ).set(xscale=\"symlog\", yscale=\"log\", xlim=[-1e2, 1e2], ylim=[1e-5, 1e0])\n",
    "#     legends.append(f\"{method_name}, $\\\\hat{{ \\\\alpha }}$={tail.alpha:.2f}\")\n",
    "# ax[2].legend(legends)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "83b589253c6ae2165fd99d3b5e434b8a0ff74c98e791d87ced25152a201010fd"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('gga')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
