{
 "cells": [
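  {
   "cell_type": "markdown",
   "id": "695715d0",
   "metadata": {},
   "source": [
    "# Two Gaussians Test\n",
    "\n",
    "Test of Hamiltonian Nested Sampling (`HamiltonianNS` from gradNS) on a mixture of two well-separated Gaussians whose densities are summed into a bimodal likelihood. Because the evidence of this likelihood is known analytically, the sampler's $\\log Z$ estimate can be checked against the true value."
   ]
  },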
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "480d0b81",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/pablo/miniconda3/envs/gns/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import torch\n",
    "from gradNS import Param, HamiltonianNS\n",
    "from getdist import plots, MCSamples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b6734adc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of dimensions of the problem\n",
    "ndims = 5\n",
    "\n",
    "# Seed torch and numpy so that Restart & Run All reproduces the sampler run and the plots\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Both components share the isotropic covariance 0.2 * I (std ~0.45 along each axis)\n",
    "cov = torch.diag(0.2 * torch.ones(ndims))\n",
    "\n",
    "# Define the two components, centred at (-1, ..., -1) and (2, ..., 2)\n",
    "mvn1 = torch.distributions.MultivariateNormal(loc=-1 * torch.ones(ndims), covariance_matrix=cov)\n",
    "mvn2 = torch.distributions.MultivariateNormal(loc=2 * torch.ones(ndims), covariance_matrix=cov)\n",
    "\n",
    "\n",
    "# Define the likelihood\n",
    "def get_loglike(theta):\n",
    "    \"\"\"Return log(N(theta; mvn1) + N(theta; mvn2)).\n",
    "\n",
    "    The two normalised Gaussian densities are summed without mixture weights,\n",
    "    so the likelihood integrates to 2 over R^ndims. logsumexp keeps the sum\n",
    "    numerically stable far from both modes.\n",
    "\n",
    "    theta: tensor whose last dimension is ndims; any leading dimensions are\n",
    "    broadcast by MultivariateNormal.log_prob and returned as the output shape.\n",
    "    \"\"\"\n",
    "    return torch.logsumexp(torch.stack([mvn1.log_prob(theta), mvn2.log_prob(theta)]), dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f6722a2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GradNS parameters: one Uniform(-5, 5) prior per dimension\n",
    "params = [\n",
    "    Param(\n",
    "        name=f'p{i}',\n",
    "        prior_type='Uniform',\n",
    "        prior=(-5, 5),\n",
    "        label=f'p_{i}',\n",
    "    )\n",
    "    for i in range(ndims)\n",
    "]\n",
    "\n",
    "# Hamiltonian nested sampler: 25 live points per parameter, with clustering\n",
    "# enabled so the two separated modes can be tracked as distinct clusters\n",
    "ns = HamiltonianNS(\n",
    "    nlive=25 * len(params),\n",
    "    loglike=get_loglike,\n",
    "    params=params,\n",
    "    verbose=True,\n",
    "    clustering=True,\n",
    "    tol=1e-3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "32e41957",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Active: 2 / 62\n",
      "Active: 1 / 62\n",
      "---------------------------------------------\n",
      "logZ = -48.0647, eps = 5.3434e+06, 0.0000e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 1.0000 and logZp = -47.9950\n",
      "---------------------------------------------\n",
      "logZ = -30.4804, eps = 8.2980e+03, 0.0000e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.6000 and logZp = -30.9533\n",
      "Cluster 1 has volume fraction 0.4000 and logZp = -31.3588\n",
      "---------------------------------------------\n",
      "logZ = -21.8932, eps = 4.6384e+02, 0.0000e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.5723 and logZp = -22.2840\n",
      "Cluster 1 has volume fraction 0.4277 and logZp = -22.9199\n",
      "---------------------------------------------\n",
      "logZ = -17.3029, eps = 4.7707e+01, 0.0000e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.6257 and logZp = -17.6875\n",
      "Cluster 1 has volume fraction 0.3743 and logZp = -18.3369\n",
      "---------------------------------------------\n",
      "logZ = -15.0196, eps = 1.9152e+01, 0.0000e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.7130 and logZp = -15.4483\n",
      "Cluster 1 has volume fraction 0.2870 and logZp = -15.9776\n",
      "---------------------------------------------\n",
      "logZ = -13.3266, eps = 8.5644e+00, 0.0000e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.7190 and logZp = -13.6731\n",
      "Cluster 1 has volume fraction 0.2810 and logZp = -14.4319\n",
      "---------------------------------------------\n",
      "logZ = -11.9146, eps = 2.3875e+00, 0.0000e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.6862 and logZp = -12.1920\n",
      "Cluster 1 has volume fraction 0.3138 and logZp = -13.1670\n",
      "Decreasing dt to  0.09000000000000001\n",
      "Decreasing dt to  0.08100000000000002\n",
      "---------------------------------------------\n",
      "logZ = -11.4436, eps = 8.9610e-01, -4.0866e-01\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.5000 and logZp = -11.7123\n",
      "Cluster 1 has volume fraction 0.5000 and logZp = -12.7109\n",
      "Decreasing dt to  0.07290000000000002\n",
      "Decreasing dt to  0.06561000000000002\n",
      "Decreasing dt to  0.05904900000000002\n",
      "---------------------------------------------\n",
      "logZ = -11.2483, eps = 2.8980e-01, -1.0456e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.5237 and logZp = -11.5763\n",
      "Cluster 1 has volume fraction 0.4763 and logZp = -12.3678\n",
      "Decreasing dt to  0.05314410000000002\n",
      "Decreasing dt to  0.04782969000000002\n",
      "Decreasing dt to  0.043046721000000024\n",
      "Decreasing dt to  0.03874204890000002\n",
      "---------------------------------------------\n",
      "logZ = -11.1658, eps = 8.6709e-02, -2.0625e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.4939 and logZp = -11.5111\n",
      "Cluster 1 has volume fraction 0.5061 and logZp = -12.2463\n",
      "Decreasing dt to  0.03486784401000002\n",
      "Decreasing dt to  0.03138105960900001\n",
      "---------------------------------------------\n",
      "logZ = -11.1403, eps = 2.6610e-02, -3.2067e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.4743 and logZp = -11.4921\n",
      "Cluster 1 has volume fraction 0.5257 and logZp = -12.2069\n",
      "Decreasing dt to  0.028242953648100012\n",
      "Decreasing dt to  0.025418658283290013\n",
      "Decreasing dt to  0.022876792454961013\n",
      "---------------------------------------------\n",
      "logZ = -11.1325, eps = 8.3258e-03, -4.3482e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.4430 and logZp = -11.4863\n",
      "Cluster 1 has volume fraction 0.5570 and logZp = -12.1946\n",
      "Decreasing dt to  0.020589113209464913\n",
      "Decreasing dt to  0.01853020188851842\n",
      "Decreasing dt to  0.01667718169966658\n",
      "---------------------------------------------\n",
      "logZ = -11.1301, eps = 2.8568e-03, -5.4656e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.3457 and logZp = -11.4846\n",
      "Cluster 1 has volume fraction 0.6543 and logZp = -12.1909\n",
      "Decreasing dt to  0.015009463529699923\n",
      "Decreasing dt to  0.013508517176729932\n",
      "Decreasing dt to  0.01215766545905694\n",
      "---------------------------------------------\n",
      "logZ = -11.1292, eps = 4.0321e-04, -7.2601e+00\n",
      "---------------------------------------------\n",
      "Cluster 0 has volume fraction 0.5760 and logZp = -11.4841\n",
      "Cluster 1 has volume fraction 0.4240 and logZp = -12.1888\n",
      "---------------------------------------------\n",
      "Nested Sampling completed\n",
      "Run time = 11.705910921096802 seconds\n",
      "Acceptance rate = 0.007731298345638283\n",
      "Number of likelihood evaluations = 249764\n",
      "logZ = -11.1290 +/- 0.3055\n",
      "---------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Run the sampler; with verbose=True it reports the running logZ estimate and per-cluster evidences\n",
    "ns.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "108b24fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The likelihood is the sum of two normalised Gaussians (no 1/2 mixture weights), so it\n",
    "# integrates to 2. Both components sit many standard deviations inside the Uniform(-5, 5)\n",
    "# prior box of volume 10**ndims, so the true evidence is Z = 2 / 10**ndims.\n",
    "# torch.log only accepts tensors (torch.log(2.) raises TypeError), so use numpy in log space.\n",
    "true_logZ = np.log(2.) - len(params) * np.log(10.)\n",
    "print('True logZ = ', true_logZ)\n",
    "print('Number of evaluations', ns.get_like_evals())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40427320",
   "metadata": {},
   "source": [
    "### Plot\n",
    "\n",
    "Compare the true posterior with the one recovered by nested sampling, using a GetDist triangle plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dede2d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:outlier fraction 0.09321120689655173 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removed no burn in\n",
      "Removed no burn in\n"
     ]
    }
   ],
   "source": [
    "# Posterior samples from the nested-sampling run, in GetDist format\n",
    "samples = ns.convert_to_getdist()\n",
    "\n",
    "# Reference samples: with a uniform prior (negligible truncation) and likelihood N1 + N2,\n",
    "# the true posterior is an equal-weight mixture, so draw the same number from each component\n",
    "n_per_component = 5000\n",
    "names = [f'p{i}' for i in range(ndims)]\n",
    "labels = [f'p_{i}' for i in range(ndims)]\n",
    "true_draws = torch.cat([mvn1.sample((n_per_component,)), mvn2.sample((n_per_component,))], dim=0)\n",
    "true_samples = MCSamples(samples=true_draws.numpy(), names=names, labels=labels)\n",
    "\n",
    "# Plot every dimension, derived from ndims so the figure follows the problem size\n",
    "g = plots.get_subplot_plotter()\n",
    "g.triangle_plot([true_samples, samples], names, filled=True, legend_labels=['True', 'GDNest'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed8c276a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:gns]",
   "language": "python",
   "name": "conda-env-gns-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
