{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "695715d0",
   "metadata": {},
   "source": [
    "# Two Gaussians Test\n",
    "\n",
    "Test of Hamiltonian Nested Sampling on a mixture of two Gaussians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "480d0b81",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/pablo/miniconda3/envs/gns/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import torch\n",
    "from gradNS import Param, NestedSampler, EllipsoidalNS, SliceNS, DynamicNestedSampler, HamiltonianNS\n",
    "from getdist import plots, MCSamples\n",
    "\n",
    "# --- plotting --- \n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "#mpl.rcParams['text.usetex'] = True\n",
    "mpl.rcParams['font.family'] = 'serif'\n",
    "mpl.rcParams['font.size'] = 10\n",
    "mpl.rcParams['axes.linewidth'] = 1.5\n",
    "#mpl.rcParams['axes.xmargin'] = 1\n",
    "mpl.rcParams['xtick.labelsize'] = 'x-large'\n",
    "mpl.rcParams['xtick.major.size'] = 5\n",
    "mpl.rcParams['xtick.major.width'] = 1.5\n",
    "mpl.rcParams['ytick.labelsize'] = 'x-large'\n",
    "mpl.rcParams['ytick.major.size'] = 5\n",
    "mpl.rcParams['ytick.major.width'] = 1.5\n",
    "mpl.rcParams['legend.frameon'] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b6734adc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_gaussian(ndims = 2, sampler = 'nested'):\n",
    "    #assert sampler in ['nested', 'multinest', 'polychord', 'gannest']\n",
    "    mvn = torch.distributions.MultivariateNormal(loc=torch.zeros(ndims),\n",
    "                                             scale_tril=torch.diag(\n",
    "                                                 torch.ones(ndims)))\n",
    "\n",
    "    \n",
    "    params = []\n",
    "    for i in range(ndims):\n",
    "        p = Param(\n",
    "                name=f'p{i}',\n",
    "                prior_type='Uniform',\n",
    "                prior=(-5, 5),\n",
    "                label=f'p_{i}')\n",
    "        params.append(p)\n",
    "        \n",
    "    if sampler == 'base':\n",
    "        ns = NestedSampler(\n",
    "            nlive=25 * ndims,\n",
    "            loglike=mvn.log_prob,\n",
    "            params=params,\n",
    "            clustering=False,\n",
    "            verbose=False)\n",
    "    elif sampler == 'ellipsoidal':\n",
    "        ns = EllipsoidalNS(\n",
    "            nlive=25 * ndims,\n",
    "            loglike=mvn.log_prob,\n",
    "            params=params,\n",
    "            eff=0.1,\n",
    "            clustering=False,\n",
    "            verbose=False)\n",
    "    elif sampler == 'slice':\n",
    "        ns = SliceNS(\n",
    "            nlive=25 * ndims,\n",
    "            loglike=mvn.log_prob,\n",
    "            params=params,\n",
    "            clustering=False,\n",
    "            verbose=False)\n",
    "    elif sampler == 'dynamic':\n",
    "        ns = DynamicNestedSampler(\n",
    "                    nlive=25 * ndims,\n",
    "                    loglike=mvn.log_prob,\n",
    "                    params=params,\n",
    "                    clustering=True,\n",
    "                    verbose=False)\n",
    "    elif sampler == 'hamiltonian':\n",
    "        ns = HamiltonianNS(\n",
    "            nlive=25 * ndims,\n",
    "            loglike=mvn.log_prob,\n",
    "            params=params,\n",
    "            clustering=False,\n",
    "            verbose=False)\n",
    "        \n",
    "    ns.run()\n",
    "    return ns.get_mean_logZ(), ns.get_var_logZ()**0.5, ns.get_like_evals()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6722a2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 -4.6051702\n",
      "slice -4.518281 0.18777265 6979.0\n",
      "ellipsoidal -4.174598 0.17821755 1252.0\n",
      "hamiltonian -4.897506 0.22904915 197150.0\n",
      "4 -9.2103405\n",
      "slice -9.278921 0.19322771 44524.0\n",
      "ellipsoidal -9.130712 0.1896862 18849.0\n",
      "hamiltonian -8.799771 0.22000718 616700.0\n",
      "8 -18.420681\n",
      "slice -18.980032 0.19513647 305475.0\n",
      "ellipsoidal -18.135433 0.18963271 1169792.0\n",
      "hamiltonian -17.441557 0.22068945 1859706.0\n",
      "16 -36.841362\n",
      "slice -37.908226 0.1837465 2022049.0\n"
     ]
    }
   ],
   "source": [
    "dims = [2, 4, 8, 16]\n",
    "logZ_true = np.zeros_like(dims, dtype=np.float32)\n",
    "\n",
    "logZ_ellipsoidal = np.zeros_like(dims, dtype=np.float32)\n",
    "logZ_slice = np.zeros_like(dims, dtype=np.float32)\n",
    "logZ_hamiltonian = np.zeros_like(dims, dtype=np.float32)\n",
    "err_logZ_ellipsoidal = np.zeros_like(dims, dtype=np.float32)\n",
    "err_logZ_slice = np.zeros_like(dims, dtype=np.float32)\n",
    "err_logZ_hamiltonian = np.zeros_like(dims, dtype=np.float32)\n",
    "like_evals_ellipsoidal = np.zeros_like(dims, dtype=np.float32)\n",
    "like_evals_slice = np.zeros_like(dims, dtype=np.float32)\n",
    "like_evals_hamiltonian = np.zeros_like(dims, dtype=np.float32)\n",
    "\n",
    "for i, d in enumerate(dims): \n",
    "    logZ_true[i] = np.log(1 / 10.**d)\n",
    "    print(d, logZ_true[i])\n",
    "    logZ_slice[i], err_logZ_slice[i], like_evals_slice[i] = sample_gaussian(ndims = d, sampler = 'slice')\n",
    "    print(\"slice\", logZ_slice[i], err_logZ_slice[i], like_evals_slice[i])\n",
    "    logZ_ellipsoidal[i], err_logZ_ellipsoidal[i], like_evals_ellipsoidal[i] = sample_gaussian(ndims = d, sampler = 'ellipsoidal')\n",
    "    print(\"ellipsoidal\", logZ_ellipsoidal[i], err_logZ_ellipsoidal[i], like_evals_ellipsoidal[i])\n",
    "    logZ_hamiltonian[i], err_logZ_hamiltonian[i], like_evals_hamiltonian[i] = sample_gaussian(ndims = d, sampler = 'hamiltonian')\n",
    "    print(\"hamiltonian\", logZ_hamiltonian[i], err_logZ_hamiltonian[i], like_evals_hamiltonian[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71fc2d93",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_size=12\n",
    "y_size=6\n",
    "fontsize = 8\n",
    "fig, ax = plt.subplots(figsize=(x_size/2.54,y_size/2.54))\n",
    "ax.plot(dims, like_evals_hamiltonian, '-x', label='Hamiltonian')\n",
    "ax.plot(dims, like_evals_ellipsoidal, '-o', label='Ellipsoidal')\n",
    "ax.plot(dims, like_evals_slice, '-*', label='Slice')\n",
    "ax.set_yscale('log')\n",
    "ax.set_xscale('log')\n",
    "ax.set_xlabel(\"# of Dimensions\", fontsize=fontsize)\n",
    "ax.set_ylabel(r\"# of Likelihood Evaluations\", fontsize=fontsize)\n",
    "ax.set_xticks(dims)\n",
    "ax.set_xticklabels(np.array(dims, dtype='str'), fontsize=fontsize)\n",
    "#ax.set_yticks([1e4, 1e5, 1e6, 1e7, 1e8])\n",
    "#ax.set_yticklabels(['$10^4$', '$10^5$', '$10^6$', '$10^7$', '$10^8$'], fontsize=fontsize)\n",
    "plt.legend(fontsize=fontsize)\n",
    "#plt.savefig(\"/Users/pablo/Desktop/ns_comparison\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd5ce9e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.array(dims, dtype='str')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "345ee0db",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:gns]",
   "language": "python",
   "name": "conda-env-gns-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
