{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "695715d0",
   "metadata": {},
   "source": [
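    "# Gaussian Test\n",
    "\n",
    "Comparison of slice, ellipsoidal and Hamiltonian Nested Sampling on a single unit Gaussian likelihood with uniform priors, tracking the estimated log-evidence and the number of likelihood evaluations as the dimensionality grows."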
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "480d0b81",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/pablo/miniconda3/envs/gns/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import torch\n",
    "from gradNS import Param, NestedSampler, EllipsoidalNS, SliceNS, DynamicNestedSampler, HamiltonianNS\n",
    "from getdist import plots, MCSamples\n",
    "\n",
    "# --- plotting --- \n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "# Global matplotlib styling shared by every figure below.\n",
    "# The text.usetex and axes.xmargin overrides stay switched off.\n",
    "mpl.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 10,\n",
    "    'axes.linewidth': 1.5,\n",
    "    'xtick.labelsize': 'x-large',\n",
    "    'xtick.major.size': 5,\n",
    "    'xtick.major.width': 1.5,\n",
    "    'ytick.labelsize': 'x-large',\n",
    "    'ytick.major.size': 5,\n",
    "    'ytick.major.width': 1.5,\n",
    "    'legend.frameon': False,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b6734adc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_gaussian(ndims = 2, sampler = 'nested', nlive_per_dim = 25):\n",
    "    \"\"\"Run one nested-sampling pass on a unit Gaussian and summarise the evidence.\n",
    "\n",
    "    The likelihood is the log-density of a zero-mean, identity-covariance\n",
    "    MultivariateNormal in `ndims` dimensions. Every parameter gets a\n",
    "    Uniform prior on (-5, 5).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ndims : int\n",
    "        Number of dimensions (and therefore of parameters).\n",
    "    sampler : str\n",
    "        One of 'base', 'ellipsoidal', 'slice', 'dynamic' or 'hamiltonian'.\n",
    "        The default 'nested' is accepted as an alias of 'base'.\n",
    "    nlive_per_dim : int\n",
    "        Live points per dimension; the sampler is built with\n",
    "        nlive = nlive_per_dim * ndims.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (ns.get_mean_logZ(), ns.get_var_logZ()**0.5, ns.get_like_evals())\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `sampler` is not one of the recognised names.\n",
    "    \"\"\"\n",
    "    # sampler name -> (class, keyword arguments that differ between samplers)\n",
    "    samplers = {\n",
    "        'base': (NestedSampler, dict(clustering=False)),\n",
    "        'ellipsoidal': (EllipsoidalNS, dict(eff=0.1, clustering=False)),\n",
    "        'slice': (SliceNS, dict(clustering=False)),\n",
    "        'dynamic': (DynamicNestedSampler, dict(clustering=True)),\n",
    "        'hamiltonian': (HamiltonianNS, dict(clustering=False)),\n",
    "    }\n",
    "    # NOTE(review): the default 'nested' used to match no branch, leaving `ns`\n",
    "    # unbound and crashing at ns.run(). It is mapped to the base sampler here -- confirm.\n",
    "    samplers['nested'] = samplers['base']\n",
    "\n",
    "    if sampler not in samplers:\n",
    "        raise ValueError(\n",
    "            f'Unknown sampler {sampler!r}; expected one of {sorted(samplers)}')\n",
    "    sampler_cls, sampler_kwargs = samplers[sampler]\n",
    "\n",
    "    # Unit Gaussian likelihood: zero mean, identity Cholesky factor.\n",
    "    mvn = torch.distributions.MultivariateNormal(loc=torch.zeros(ndims),\n",
    "                                                 scale_tril=torch.eye(ndims))\n",
    "\n",
    "    params = [\n",
    "        Param(name=f'p{i}',\n",
    "              prior_type='Uniform',\n",
    "              prior=(-5, 5),\n",
    "              label=f'p_{i}')\n",
    "        for i in range(ndims)\n",
    "    ]\n",
    "\n",
    "    ns = sampler_cls(\n",
    "        nlive=nlive_per_dim * ndims,\n",
    "        loglike=mvn.log_prob,\n",
    "        params=params,\n",
    "        verbose=False,\n",
    "        **sampler_kwargs)\n",
    "\n",
    "    ns.run()\n",
    "    return ns.get_mean_logZ(), ns.get_var_logZ()**0.5, ns.get_like_evals()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f6722a2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 -4.6051702\n",
      "slice -4.518281 0.18777265 6979.0\n",
      "ellipsoidal -4.174598 0.17821755 1252.0\n",
      "hamiltonian -4.897506 0.22904915 197150.0\n",
      "4 -9.2103405\n",
      "slice -9.278921 0.19322771 44524.0\n",
      "ellipsoidal -9.130712 0.1896862 18849.0\n",
      "hamiltonian -8.799771 0.22000718 616700.0\n",
      "8 -18.420681\n",
      "slice -18.980032 0.19513647 305475.0\n",
      "ellipsoidal -18.135433 0.18963271 1169792.0\n",
      "hamiltonian -17.441557 0.22068945 1859706.0\n",
      "16 -36.841362\n",
      "slice -37.908226 0.1837465 2022049.0\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [3]\u001b[0m, in \u001b[0;36m<cell line: 14>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     17\u001b[0m logZ_slice[i], err_logZ_slice[i], like_evals_slice[i] \u001b[38;5;241m=\u001b[39m sample_gaussian(ndims \u001b[38;5;241m=\u001b[39m d, sampler \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mslice\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     18\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mslice\u001b[39m\u001b[38;5;124m\"\u001b[39m, logZ_slice[i], err_logZ_slice[i], like_evals_slice[i])\n\u001b[0;32m---> 19\u001b[0m logZ_ellipsoidal[i], err_logZ_ellipsoidal[i], like_evals_ellipsoidal[i] \u001b[38;5;241m=\u001b[39m \u001b[43msample_gaussian\u001b[49m\u001b[43m(\u001b[49m\u001b[43mndims\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msampler\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mellipsoidal\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     20\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mellipsoidal\u001b[39m\u001b[38;5;124m\"\u001b[39m, logZ_ellipsoidal[i], err_logZ_ellipsoidal[i], like_evals_ellipsoidal[i])\n\u001b[1;32m     21\u001b[0m logZ_hamiltonian[i], err_logZ_hamiltonian[i], like_evals_hamiltonian[i] \u001b[38;5;241m=\u001b[39m sample_gaussian(ndims \u001b[38;5;241m=\u001b[39m d, sampler \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhamiltonian\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "Input \u001b[0;32mIn [2]\u001b[0m, in \u001b[0;36msample_gaussian\u001b[0;34m(ndims, sampler)\u001b[0m\n\u001b[1;32m     46\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m sampler \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhamiltonian\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m     47\u001b[0m     ns \u001b[38;5;241m=\u001b[39m HamiltonianNS(\n\u001b[1;32m     48\u001b[0m         nlive\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m25\u001b[39m \u001b[38;5;241m*\u001b[39m ndims,\n\u001b[1;32m     49\u001b[0m         loglike\u001b[38;5;241m=\u001b[39mmvn\u001b[38;5;241m.\u001b[39mlog_prob,\n\u001b[1;32m     50\u001b[0m         params\u001b[38;5;241m=\u001b[39mparams,\n\u001b[1;32m     51\u001b[0m         clustering\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m     52\u001b[0m         verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 54\u001b[0m \u001b[43mns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     55\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ns\u001b[38;5;241m.\u001b[39mget_mean_logZ(), ns\u001b[38;5;241m.\u001b[39mget_var_logZ()\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m0.5\u001b[39m, ns\u001b[38;5;241m.\u001b[39mget_like_evals()\n",
      "File \u001b[0;32m~/Code/GradNS/gradNS/nested_sampling.py:424\u001b[0m, in \u001b[0;36mNestedSampler.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    422\u001b[0m nsteps \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m    423\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_clusters \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m max_epsilon \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtol):\n\u001b[0;32m--> 424\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmove_one_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    425\u001b[0m     epsilon \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_epsilon()\n\u001b[1;32m    426\u001b[0m     max_epsilon \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39msum(epsilon) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclustering \u001b[38;5;28;01melse\u001b[39;00m epsilon\n",
      "File \u001b[0;32m~/Code/GradNS/gradNS/nested_sampling.py:357\u001b[0m, in \u001b[0;36mNestedSampler.move_one_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    355\u001b[0m \u001b[38;5;124;03m\"\"\" Find highest log like, get rid of that point, and sample a new one \"\"\"\u001b[39;00m\n\u001b[1;32m    356\u001b[0m sample \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkill_point()\n\u001b[0;32m--> 357\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmin_logL\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_logL\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Code/GradNS/gradNS/nested_sampling.py:348\u001b[0m, in \u001b[0;36mNestedSampler.add_point\u001b[0;34m(self, min_logL)\u001b[0m\n\u001b[1;32m    336\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    337\u001b[0m \u001b[38;5;124;03mAdd a new point to the live points, sampling from the prior until finding a point with higher likelihood\u001b[39;00m\n\u001b[1;32m    338\u001b[0m \u001b[38;5;124;03mthan a given value\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    345\u001b[0m \u001b[38;5;124;03m-------\u001b[39;00m\n\u001b[1;32m    346\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    347\u001b[0m \u001b[38;5;66;03m# Add a new sample\u001b[39;00m\n\u001b[0;32m--> 348\u001b[0m newsample \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfind_new_sample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmin_logL\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    349\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m newsample\u001b[38;5;241m.\u001b[39mget_logL() \u001b[38;5;241m>\u001b[39m min_logL, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNew sample has lower likelihood than old one\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    351\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlive_points\u001b[38;5;241m.\u001b[39madd_nspoint(newsample)\n",
      "File \u001b[0;32m~/Code/GradNS/gradNS/ellipsoidal.py:71\u001b[0m, in \u001b[0;36mEllipsoidalNS.find_new_sample\u001b[0;34m(self, min_like)\u001b[0m\n\u001b[1;32m     69\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m newlike \u001b[38;5;241m<\u001b[39m min_like:\n\u001b[1;32m     70\u001b[0m     values \u001b[38;5;241m=\u001b[39m mvn\u001b[38;5;241m.\u001b[39msample()\n\u001b[0;32m---> 71\u001b[0m     newlike \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloglike\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     72\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlike_evals \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     74\u001b[0m sample \u001b[38;5;241m=\u001b[39m NSPoints(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnparams)\n",
      "File \u001b[0;32m~/miniconda3/envs/gns/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py:216\u001b[0m, in \u001b[0;36mMultivariateNormal.log_prob\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m    214\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_sample(value)\n\u001b[1;32m    215\u001b[0m diff \u001b[38;5;241m=\u001b[39m value \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mloc\n\u001b[0;32m--> 216\u001b[0m M \u001b[38;5;241m=\u001b[39m \u001b[43m_batch_mahalanobis\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_unbroadcasted_scale_tril\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdiff\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    217\u001b[0m half_log_det \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_unbroadcasted_scale_tril\u001b[38;5;241m.\u001b[39mdiagonal(dim1\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, dim2\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mlog()\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m    218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m0.5\u001b[39m \u001b[38;5;241m*\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_event_shape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m*\u001b[39m math\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m math\u001b[38;5;241m.\u001b[39mpi) \u001b[38;5;241m+\u001b[39m M) \u001b[38;5;241m-\u001b[39m half_log_det\n",
      "File \u001b[0;32m~/miniconda3/envs/gns/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py:54\u001b[0m, in \u001b[0;36m_batch_mahalanobis\u001b[0;34m(bL, bx)\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[38;5;66;03m# Permute bx to make it have shape (..., 1, j, i, 1, n)\u001b[39;00m\n\u001b[1;32m     50\u001b[0m permute_dims \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(outer_batch_dims)) \u001b[38;5;241m+\u001b[39m\n\u001b[1;32m     51\u001b[0m                 \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(outer_batch_dims, new_batch_dims, \u001b[38;5;241m2\u001b[39m)) \u001b[38;5;241m+\u001b[39m\n\u001b[1;32m     52\u001b[0m                 \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(outer_batch_dims \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m, new_batch_dims, \u001b[38;5;241m2\u001b[39m)) \u001b[38;5;241m+\u001b[39m\n\u001b[1;32m     53\u001b[0m                 [new_batch_dims])\n\u001b[0;32m---> 54\u001b[0m bx \u001b[38;5;241m=\u001b[39m \u001b[43mbx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpermute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpermute_dims\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     56\u001b[0m flat_L \u001b[38;5;241m=\u001b[39m bL\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, n, n)  \u001b[38;5;66;03m# shape = b x n x n\u001b[39;00m\n\u001b[1;32m     57\u001b[0m flat_x \u001b[38;5;241m=\u001b[39m bx\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, flat_L\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m0\u001b[39m), n)  \u001b[38;5;66;03m# shape = c x b x n\u001b[39;00m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Seed the global RNGs so that Restart & Run All reproduces the numbers below.\n",
    "# NOTE(review): assumes the gradNS samplers draw from torch's / numpy's global RNG -- confirm.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "dims = [2, 4, 8, 16]\n",
    "n_dims = len(dims)\n",
    "\n",
    "# Everything starts as NaN so that an interrupted run leaves visibly missing\n",
    "# entries (skipped by matplotlib) instead of zeros that look like results.\n",
    "# float64 keeps likelihood-evaluation counts exact well beyond float32's 2**24.\n",
    "logZ_true = np.full(n_dims, np.nan)\n",
    "\n",
    "logZ_ellipsoidal = np.full(n_dims, np.nan)\n",
    "logZ_slice = np.full(n_dims, np.nan)\n",
    "logZ_hamiltonian = np.full(n_dims, np.nan)\n",
    "err_logZ_ellipsoidal = np.full(n_dims, np.nan)\n",
    "err_logZ_slice = np.full(n_dims, np.nan)\n",
    "err_logZ_hamiltonian = np.full(n_dims, np.nan)\n",
    "like_evals_ellipsoidal = np.full(n_dims, np.nan)\n",
    "like_evals_slice = np.full(n_dims, np.nan)\n",
    "like_evals_hamiltonian = np.full(n_dims, np.nan)\n",
    "\n",
    "# sampler name -> (logZ, err_logZ, like_evals) arrays, in the order they are run\n",
    "results = {\n",
    "    'slice': (logZ_slice, err_logZ_slice, like_evals_slice),\n",
    "    'ellipsoidal': (logZ_ellipsoidal, err_logZ_ellipsoidal, like_evals_ellipsoidal),\n",
    "    'hamiltonian': (logZ_hamiltonian, err_logZ_hamiltonian, like_evals_hamiltonian),\n",
    "}\n",
    "\n",
    "for i, d in enumerate(dims):\n",
    "    # Reference value log(10**-d): the uniform prior on (-5, 5)**d has density 10**-d.\n",
    "    # Written as -d*log(10) so that 10.**d is never formed explicitly.\n",
    "    logZ_true[i] = -d * np.log(10.)\n",
    "    print(d, logZ_true[i])\n",
    "    for name, (logZ, err_logZ, like_evals) in results.items():\n",
    "        logZ[i], err_logZ[i], like_evals[i] = sample_gaussian(ndims = d, sampler = name)\n",
    "        print(name, logZ[i], err_logZ[i], like_evals[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d1c9773",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Likelihood evaluations versus dimensionality, one curve per sampler.\n",
    "# Figure size is given in centimetres and converted to inches for matplotlib.\n",
    "x_size, y_size = 12, 6\n",
    "fontsize = 8\n",
    "fig, ax = plt.subplots(figsize=(x_size / 2.54, y_size / 2.54))\n",
    "\n",
    "curves = [\n",
    "    (like_evals_hamiltonian, '-x', 'Hamiltonian'),\n",
    "    (like_evals_ellipsoidal, '-o', 'Ellipsoidal'),\n",
    "    (like_evals_slice, '-*', 'Slice'),\n",
    "]\n",
    "for evals, fmt, curve_label in curves:\n",
    "    ax.plot(dims, evals, fmt, label=curve_label)\n",
    "\n",
    "# Log-log axes, with one tick per sampled dimension on the x axis.\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "ax.set_xlabel(\"# of Dimensions\", fontsize=fontsize)\n",
    "ax.set_ylabel(r\"# of Likelihood Evaluations\", fontsize=fontsize)\n",
    "ax.set_xticks(dims)\n",
    "ax.set_xticklabels([str(d) for d in dims], fontsize=fontsize)\n",
    "ax.legend(fontsize=fontsize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cd429b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.asarray(dims, dtype=str)  # dimension list as an array of strings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6a249c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:gns]",
   "language": "python",
   "name": "conda-env-gns-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
