{
 "cells": [
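  {
   "cell_type": "markdown",
   "id": "695715d0",
   "metadata": {},
   "source": [
    "# Testing the evidence estimate across dimensions\n",
    "\n",
    "Runs a gradNS nested sampler on a unit Gaussian likelihood with a uniform prior on $[-5, 5]$ for every parameter, repeats each run several times, and compares the mean estimated $\\log Z$ with the analytic value $\\log Z \\approx -D \\log 10$ for dimensions $D = 4, 8, \\dots, 128$."
   ]
  },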
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "480d0b81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import torch\n",
    "from gradNS import Param, NestedSampler, EllipsoidalNS, SliceNS, DynamicNestedSampler, HamiltonianNS\n",
    "from getdist import plots, MCSamples\n",
    "import time\n",
    "\n",
    "# --- plotting --- \n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Figure style shared by every plot in this notebook.\n",
    "# Add 'text.usetex': True to render all text with LaTeX (requires a TeX installation).\n",
    "mpl.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 10,\n",
    "    'axes.linewidth': 1.5,\n",
    "    'xtick.labelsize': 'x-large',\n",
    "    'xtick.major.size': 5,\n",
    "    'xtick.major.width': 1.5,\n",
    "    'ytick.labelsize': 'x-large',\n",
    "    'ytick.major.size': 5,\n",
    "    'ytick.major.width': 1.5,\n",
    "    'legend.frameon': False,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6734adc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_gaussian(ndims=2, sampler='nested', sigma=0., nlive=None):\n",
    "    \"\"\"Run one gradNS nested sampler on a standard-normal likelihood.\n",
    "\n",
    "    The likelihood is an isotropic, unit-variance Gaussian centred on the\n",
    "    origin and every parameter gets a Uniform(-5, 5) prior, so the true\n",
    "    evidence is approximately logZ = -ndims * log(10).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ndims : int\n",
    "        Number of dimensions.\n",
    "    sampler : str\n",
    "        'base' (alias 'nested'), 'ellipsoidal', 'slice', 'dynamic' or\n",
    "        'hamiltonian'.\n",
    "    sigma : float\n",
    "        Not used by this function; kept so existing calls keep working.\n",
    "    nlive : int or None\n",
    "        Number of live points. None keeps the per-sampler default\n",
    "        (200 for 'hamiltonian', 25 * ndims otherwise).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (mean logZ, standard deviation of logZ, number of likelihood\n",
    "        evaluations), as reported by the sampler.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `sampler` is not one of the names above.\n",
    "    \"\"\"\n",
    "    mvn = torch.distributions.MultivariateNormal(\n",
    "        loc=torch.zeros(ndims), scale_tril=torch.eye(ndims))\n",
    "\n",
    "    params = [\n",
    "        Param(name=f'p{i}', prior_type='Uniform', prior=(-5, 5), label=f'p_{i}')\n",
    "        for i in range(ndims)\n",
    "    ]\n",
    "\n",
    "    # Sampler class plus the options that differ from the shared defaults below.\n",
    "    samplers = {\n",
    "        'base': (NestedSampler, {}),\n",
    "        'nested': (NestedSampler, {}),  # alias so the default argument works\n",
    "        'ellipsoidal': (EllipsoidalNS, {'eff': 0.1}),\n",
    "        'slice': (SliceNS, {}),\n",
    "        'dynamic': (DynamicNestedSampler, {'clustering': True}),\n",
    "        'hamiltonian': (HamiltonianNS, {'tol': 1e-2}),\n",
    "    }\n",
    "    if sampler not in samplers:\n",
    "        raise ValueError(f'Unknown sampler {sampler!r}; expected one of {sorted(samplers)}')\n",
    "    sampler_cls, extra_kwargs = samplers[sampler]\n",
    "\n",
    "    if nlive is None:\n",
    "        nlive = 200 if sampler == 'hamiltonian' else 25 * ndims\n",
    "\n",
    "    kwargs = dict(nlive=nlive, loglike=mvn.log_prob, params=params,\n",
    "                  clustering=False, verbose=False)\n",
    "    kwargs.update(extra_kwargs)\n",
    "    ns = sampler_cls(**kwargs)\n",
    "    ns.run()\n",
    "    return ns.get_mean_logZ(), ns.get_var_logZ()**0.5, ns.get_like_evals()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6722a2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scan configuration: dimensions, sampler under test, runs per dimension.\n",
    "dims = [4, 8, 16, 32, 64, 128]\n",
    "SAMPLER = 'hamiltonian'\n",
    "NUM_REPEATS = 10\n",
    "SEED = 0\n",
    "# Width of the Uniform(-5, 5) prior used for every parameter in sample_gaussian.\n",
    "PRIOR_WIDTH = 10.\n",
    "\n",
    "logZ_true = np.zeros_like(dims, dtype=np.float32)\n",
    "logZ = np.zeros_like(dims, dtype=np.float32)\n",
    "err_logZ = np.zeros_like(dims, dtype=np.float32)\n",
    "like_evals = np.zeros_like(dims, dtype=np.int32)\n",
    "times = np.zeros_like(dims, dtype=np.float32)\n",
    "\n",
    "# Seed the global RNGs so Restart & Run All is reproducible\n",
    "# (assumes gradNS draws from torch's / numpy's global generators -- TODO confirm).\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "for i, d in enumerate(dims):\n",
    "    # The unit Gaussian has negligible mass outside the prior box, so\n",
    "    # logZ_true = log(1 / PRIOR_WIDTH**d). Evaluate it in log space because\n",
    "    # PRIOR_WIDTH**d raises OverflowError once d exceeds 308.\n",
    "    logZ_true[i] = -d * np.log(PRIOR_WIDTH)\n",
    "    print(f\"D = {d}, logZ = {logZ_true[i]:.4f}\")\n",
    "\n",
    "    logZ_local = np.zeros(NUM_REPEATS, dtype=np.float32)\n",
    "    err_logZ_local = np.zeros(NUM_REPEATS, dtype=np.float32)\n",
    "    like_evals_local = np.zeros(NUM_REPEATS, dtype=np.int32)\n",
    "    times_local = np.zeros(NUM_REPEATS, dtype=np.float32)\n",
    "    for j in range(NUM_REPEATS):\n",
    "        # perf_counter is monotonic and high-resolution, unlike time.time.\n",
    "        start_time = time.perf_counter()\n",
    "        logZ_local[j], err_logZ_local[j], like_evals_local[j] = sample_gaussian(\n",
    "            ndims=d, sampler=SAMPLER)\n",
    "        times_local[j] = time.perf_counter() - start_time\n",
    "\n",
    "    logZ[i] = np.mean(logZ_local)\n",
    "    # Error on the mean propagated from the per-run error bars, added in\n",
    "    # quadrature to the run-to-run scatter of the estimates.\n",
    "    err_propagated = np.sqrt(np.sum(err_logZ_local**2)) / NUM_REPEATS\n",
    "    err_logZ[i] = np.sqrt(err_propagated**2 + np.std(logZ_local)**2)\n",
    "    like_evals[i] = np.mean(like_evals_local)\n",
    "    times[i] = np.mean(times_local)\n",
    "\n",
    "    print(f\"logZ = {logZ[i]:.4f} +/- {err_logZ[i]:.4f}, like_evals = {like_evals[i]}, time = {times[i]:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd5ce9e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure size is specified in cm and converted to inches for matplotlib.\n",
    "x_size = 12\n",
    "y_size = 6\n",
    "fontsize = 14\n",
    "fig, ax = plt.subplots(figsize=(x_size / 2.54, y_size / 2.54))\n",
    "# Plot the bias of the estimate; an unbiased sampler scatters around zero.\n",
    "ax.errorbar(dims, logZ - logZ_true, yerr=err_logZ, fmt='x', label=r'$\\sigma = 0$')\n",
    "ax.axhline(0, ls='--', color='k')\n",
    "ax.set_xscale('log')\n",
    "ax.set_xlabel(\"# of Dimensions\", fontsize=fontsize)\n",
    "ax.set_ylabel(r\"$\\log Z - \\log Z_{\\mathrm{true}}$\", fontsize=fontsize)\n",
    "ax.set_xticks(dims)\n",
    "ax.set_xticklabels([str(d) for d in dims], fontsize=fontsize)\n",
    "ax.legend(fontsize=fontsize);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7f9adc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the scan as rows of (D, logZ, err_logZ, like_evals, time).\n",
    "# Off by default so re-running the notebook never overwrites saved results;\n",
    "# np.save does not create missing directories, so ./data must already exist.\n",
    "SAVE_RESULTS = False\n",
    "\n",
    "if SAVE_RESULTS:\n",
    "    lz = np.stack([dims, logZ, err_logZ, like_evals, times], axis=1)\n",
    "    np.save('./data/hamiltonian_logZ_v4.npy', lz)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:gns]",
   "language": "python",
   "name": "conda-env-gns-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
