{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# I. Validating Class"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Random verification of n-dimensional Gaussian models in IMQ kernel and KGM kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stein_pi_thinning.target import PiTargetAuto, PiTargetIMQ, PiTargetKGM, PiTargetCentKGM\n",
    "\n",
    "import numpy as np\n",
    "from scipy.stats import wishart\n",
    "from jax import numpy as jnp\n",
    "\n",
    "np.random.seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_test: [ 0.47143516 -1.19097569  1.43270697 -0.3126519 ]\n",
      "mean  : [-0.72058873  0.88716294  0.85958841 -0.6365235 ]\n",
      "cov   : [[11.68150068  0.05364739 -7.66509583  3.39029399]\n",
      " [ 0.05364739 12.34410368  4.00531349  3.36496236]\n",
      " [-7.66509583  4.00531349 25.21091916 -9.90588931]\n",
      " [ 3.39029399  3.36496236 -9.90588931 10.75391052]]\n"
     ]
    }
   ],
   "source": [
    "dim = 4\n",
    "\n",
    "x_test = np.random.normal(size=dim)\n",
    "mean = np.random.normal(size=dim)\n",
    "cov = wishart.rvs(dim + 10, np.eye(dim), size=1)\n",
    "\n",
    "print(\"x_test:\", x_test)\n",
    "print(\"mean  :\", mean)\n",
    "print(\"cov   :\", cov)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -7.153352187026759\n",
      "glog-pdf  : [-0.17921809  0.33648779 -0.25339807 -0.31232095]\n",
      "hlog-pdf  : [[-0.10860438  0.013307   -0.03654333 -0.0035867 ]\n",
      " [ 0.013307   -0.12428333  0.05865073  0.08871951]\n",
      " [-0.03654333  0.05865073 -0.09838846 -0.09746129]\n",
      " [-0.0035867   0.08871951 -0.09746129 -0.20939533]]\n",
      "log-q-pdf : -7.235925371128592\n",
      "glog-q-pdf: [-0.13873336  0.23413024 -0.15708047 -0.17007607]\n"
     ]
    }
   ],
   "source": [
    "# Test manual IMQ class\n",
    "log_p = lambda x:-np.log(2*np.pi) - 0.5*np.log(np.linalg.det(cov)) - 0.5*(x-mean)@np.linalg.inv(cov)@(x-mean)\n",
    "grad_log_p = lambda x: -np.linalg.inv(cov)@(x-mean)\n",
    "hess_log_p = lambda x: -np.linalg.inv(cov)\n",
    "linv = np.linalg.inv(cov)\n",
    "\n",
    "test_manual_imq = PiTargetIMQ(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv)\n",
    "\n",
    "print(\"log-pdf   :\", test_manual_imq.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_manual_imq.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_manual_imq.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_manual_imq.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_manual_imq.grad_log_q(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -7.1533523\n",
      "glog-pdf  : [-0.17921811  0.33648777 -0.2533981  -0.31232098]\n",
      "hlog-pdf  : [[-0.10860439  0.013307   -0.03654334 -0.0035867 ]\n",
      " [ 0.013307   -0.12428331  0.05865073  0.08871952]\n",
      " [-0.03654334  0.05865073 -0.09838847 -0.09746131]\n",
      " [-0.0035867   0.08871952 -0.09746131 -0.20939535]]\n",
      "log-q-pdf : -7.235925\n",
      "glog-q-pdf: [-0.13873339  0.23413023 -0.15708049 -0.17007609]\n"
     ]
    }
   ],
   "source": [
    "# Test auto IMQ class\n",
    "log_p = lambda x:-jnp.log(2*jnp.pi) - 0.5*jnp.log(jnp.linalg.det(cov)) - 0.5*(x-mean)@jnp.linalg.inv(cov)@(x-mean)\n",
    "linv = jnp.linalg.inv(cov)\n",
    "imq_kernel = lambda x, y: (1 + (x - y)@linv@(x - y))**(-0.5)\n",
    "\n",
    "test_auto_imq = PiTargetAuto(log_p=log_p, base_kernel=imq_kernel)\n",
    "\n",
    "print(\"log-pdf   :\", test_auto_imq.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_auto_imq.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_auto_imq.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_auto_imq.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_auto_imq.grad_log_q(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -7.153352187026759\n",
      "glog-pdf  : [-0.17921809  0.33648779 -0.25339807 -0.31232095]\n",
      "hlog-pdf  : [[-0.10860438  0.013307   -0.03654333 -0.0035867 ]\n",
      " [ 0.013307   -0.12428333  0.05865073  0.08871951]\n",
      " [-0.03654333  0.05865073 -0.09838846 -0.09746129]\n",
      " [-0.0035867   0.08871951 -0.09746129 -0.20939533]]\n",
      "log-q-pdf : -6.832792089900934\n",
      "glog-q-pdf: [-0.06523543  0.12608668 -0.05548275 -0.1164513 ]\n"
     ]
    }
   ],
   "source": [
    "# Test manual KGM class\n",
    "log_p = lambda x:-np.log(2*np.pi) - 0.5*np.log(np.linalg.det(cov)) - 0.5*(x-mean)@np.linalg.inv(cov)@(x-mean)\n",
    "grad_log_p = lambda x: -np.linalg.inv(cov)@(x-mean)\n",
    "hess_log_p = lambda x: -np.linalg.inv(cov)\n",
    "linv = np.linalg.inv(cov)\n",
    "\n",
    "test_manual_kgm = PiTargetKGM(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv, s=3.0)\n",
    "\n",
    "print(\"log-pdf   :\", test_manual_kgm.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_manual_kgm.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_manual_kgm.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_manual_kgm.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_manual_kgm.grad_log_q(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -7.1533523\n",
      "glog-pdf  : [-0.17921811  0.33648777 -0.2533981  -0.31232098]\n",
      "hlog-pdf  : [[-0.10860439  0.013307   -0.03654334 -0.0035867 ]\n",
      " [ 0.013307   -0.12428331  0.05865073  0.08871952]\n",
      " [-0.03654334  0.05865073 -0.09838847 -0.09746131]\n",
      " [-0.0035867   0.08871952 -0.09746131 -0.20939535]]\n",
      "log-q-pdf : -6.8327923\n",
      "glog-q-pdf: [-0.06523542  0.12608667 -0.05548273 -0.11645128]\n"
     ]
    }
   ],
   "source": [
    "# Test auto KGM class\n",
    "s = 3.0\n",
    "linv = jnp.linalg.inv(cov)\n",
    "log_p = lambda x:-jnp.log(2*jnp.pi) - 0.5*jnp.log(jnp.linalg.det(cov)) - 0.5*(x-mean)@jnp.linalg.inv(cov)@(x-mean)\n",
    "\n",
    "kgm_kernel = lambda x, y: (1 + x@linv@x)**((s-1)/2) *\\\n",
    "        (1 + y@linv@y)**((s-1)/2) *\\\n",
    "        (1 + (x-y)@linv@(x-y))**(-0.5) +\\\n",
    "        (1 + x@linv@y )/( jnp.sqrt(1+x@linv@x) * jnp.sqrt(1+y@linv@y) )\n",
    "\n",
    "test_auto_kgm = PiTargetAuto(log_p=log_p, base_kernel=kgm_kernel)\n",
    "\n",
    "print(\"log-pdf   :\", test_auto_kgm.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_auto_kgm.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_auto_kgm.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_auto_kgm.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_auto_kgm.grad_log_q(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -7.153352187026759\n",
      "glog-pdf  : [-0.17921809  0.33648779 -0.25339807 -0.31232095]\n",
      "hlog-pdf  : [[-0.10860438  0.013307   -0.03654333 -0.0035867 ]\n",
      " [ 0.013307   -0.12428333  0.05865073  0.08871951]\n",
      " [-0.03654333  0.05865073 -0.09838846 -0.09746129]\n",
      " [-0.0035867   0.08871951 -0.09746129 -0.20939533]]\n",
      "log-q-pdf : -6.600675379440154\n",
      "glog-q-pdf: [-0.02890971  0.04828226 -0.03198163 -0.03406257]\n"
     ]
    }
   ],
   "source": [
    "# Test manual Centralized KGM class\n",
    "log_p = lambda x:-np.log(2*np.pi) - 0.5*np.log(np.linalg.det(cov)) - 0.5*(x-mean)@np.linalg.inv(cov)@(x-mean)\n",
    "grad_log_p = lambda x: -np.linalg.inv(cov)@(x-mean)\n",
    "hess_log_p = lambda x: -np.linalg.inv(cov)\n",
    "linv = np.linalg.inv(cov)\n",
    "\n",
    "test_manual_centkgm = PiTargetCentKGM(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv, s=3.0, x_map=mean)\n",
    "\n",
    "print(\"log-pdf   :\", test_manual_centkgm.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_manual_centkgm.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_manual_centkgm.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_manual_centkgm.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_manual_centkgm.grad_log_q(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -7.1533523\n",
      "glog-pdf  : [-0.17921811  0.33648777 -0.2533981  -0.31232098]\n",
      "hlog-pdf  : [[-0.10860439  0.013307   -0.03654334 -0.0035867 ]\n",
      " [ 0.013307   -0.12428331  0.05865073  0.08871952]\n",
      " [-0.03654334  0.05865073 -0.09838847 -0.09746131]\n",
      " [-0.0035867   0.08871952 -0.09746131 -0.20939535]]\n",
      "log-q-pdf : -6.6006756\n",
      "glog-q-pdf: [-0.02890968  0.04828227 -0.03198163 -0.03406259]\n"
     ]
    }
   ],
   "source": [
    "# Test auto Centralized KGM class\n",
    "s = 3.0\n",
    "x_map = mean\n",
    "linv = jnp.linalg.inv(cov)\n",
    "log_p = lambda x:-jnp.log(2*jnp.pi) - 0.5*jnp.log(jnp.linalg.det(cov)) - 0.5*(x-mean)@jnp.linalg.inv(cov)@(x-mean)\n",
    "\n",
    "centkgm_kernel = lambda x, y: (1 + (x-x_map)@linv@(x-x_map))**((s-1)/2) * (1 + (y-x_map)@linv@(y-x_map))**((s-1)/2)\\\n",
    "      * ((1 + (x-y)@linv@(x-y).T )**(-0.5) +\\\n",
    "            (1 + (x-x_map)@linv@(y-x_map).T)/( (1+(x-x_map)@linv@(x-x_map).T)**(s/2) * (1+(y-x_map)@linv@(y-x_map).T)**(s/2) ))\n",
    "\n",
    "test_auto_centkgm = PiTargetAuto(log_p=log_p, base_kernel=centkgm_kernel)\n",
    "\n",
    "print(\"log-pdf   :\", test_auto_centkgm.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_auto_centkgm.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_auto_centkgm.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_auto_centkgm.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_auto_centkgm.grad_log_q(x_test))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# II. Verifying Non-Python Objects"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Verifying the Availability of Black-Box Log-Posterior Distributions and their Gradients Compiled by Stan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "from stein_pi_thinning.util import flat\n",
    "import bridgestan as bs\n",
    "from posteriordb import PosteriorDatabase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load DataBase Locally\n",
    "pdb_path = os.path.join(os.getcwd(), \"posteriordb/posterior_database\")\n",
    "my_pdb = PosteriorDatabase(pdb_path)\n",
    "\n",
    "# Load Dataset\n",
    "posterior = my_pdb.posterior(\"eight_schools-eight_schools_centered\")\n",
    "stan = posterior.model.stan_code_file_path()\n",
    "data = json.dumps(posterior.data.values())\n",
    "model = bs.StanModel.from_stan_file(stan, data)\n",
    "\n",
    "# Gold Standard\n",
    "gs = posterior.reference_draws()\n",
    "df = pd.DataFrame(gs)\n",
    "gs_chains = np.zeros((sum(flat(posterior.information['dimensions'].values())),\\\n",
    "                       posterior.reference_draws_info()['diagnostics']['ndraws']))\n",
    "for i in range(len(df.keys())):\n",
    "    s = []\n",
    "    for j in range(len(df[df.keys()[i]])):\n",
    "        s += df[df.keys()[i]][j]\n",
    "    gs_chains[i, :] = s\n",
    "linv = np.linalg.inv(np.cov(gs_chains))\n",
    "\n",
    "# Extract log-P-pdf and its gradient\n",
    "log_p = model.log_density\n",
    "grad_log_p = lambda x: model.log_density_gradient(x)[1]\n",
    "hess_log_p = lambda x: model.log_density_hessian(x)[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -9.01998466982852\n",
      "glog-pdf  : [-0.17749739 -0.40994467 -0.89068497  0.09812123 -0.56727961  0.7459595\n",
      "  0.07248543 -0.70806775  2.31392932 -2.28824916]\n",
      "hlog-pdf  : [[-0.51366436  0.          0.          0.          0.          0.\n",
      "   0.          0.          0.50921991  0.60216437]\n",
      " [ 0.         -0.51921991  0.          0.          0.          0.\n",
      "   0.          0.          0.50921991  0.96882056]\n",
      " [ 0.          0.         -0.51312616  0.          0.          0.\n",
      "   0.          0.          0.50921991  1.74763437]\n",
      " [ 0.          0.          0.         -0.51748438  0.          0.\n",
      "   0.          0.          0.50921991 -0.07278286]\n",
      " [ 0.          0.          0.          0.         -0.52156559  0.\n",
      "   0.          0.          0.50921991  1.09318752]\n",
      " [ 0.          0.          0.          0.          0.         -0.51748438\n",
      "   0.          0.          0.50921991 -1.44535656]\n",
      " [ 0.          0.          0.          0.          0.          0.\n",
      "  -0.51921991  0.          0.50921991  0.2186913 ]\n",
      " [ 0.          0.          0.          0.          0.          0.\n",
      "   0.         -0.51230633  0.50921991  1.48367273]\n",
      " [ 0.50921991  0.50921991  0.50921991  0.50921991  0.50921991  0.50921991\n",
      "   0.50921991  0.50921991 -4.1137593  -4.59603143]\n",
      " [ 0.60216437  0.96882056  1.74763437 -0.07278286  1.09318752 -1.44535656\n",
      "   0.2186913   1.48367273 -4.59603143 -9.98492908]]\n",
      "log-q-pdf : -7.694975864096433\n",
      "glog-q-pdf: [-0.1851579  -0.46828406 -1.05767961  0.18954709 -0.63985924  1.03559699\n",
      "  0.11771857 -0.83905042  2.31834682 -1.76456858]\n"
     ]
    }
   ],
   "source": [
    "# Test manual IMQ class\n",
    "dim = model.param_num()\n",
    "x_test = np.random.normal(size=dim)\n",
    "\n",
    "test_manual_imq = PiTargetIMQ(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv)\n",
    "\n",
    "print(\"log-pdf   :\", test_manual_imq.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_manual_imq.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_manual_imq.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_manual_imq.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_manual_imq.grad_log_q(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -13.198300114983885\n",
      "glog-pdf  : [-2.8959757  -2.94113106 -2.46003717  0.6744793  -0.16450099  1.24965883\n",
      " -2.2049274  -7.1567994  16.32757001 19.82672475]\n",
      "hlog-pdf  : [[ -3.10906611   0.           0.           0.           0.\n",
      "    0.           0.           0.           3.10462166   6.03152848]\n",
      " [  0.          -3.11462166   0.           0.           0.\n",
      "    0.           0.           0.           3.10462166   6.02134335]\n",
      " [  0.           0.          -3.10852791   0.           0.\n",
      "    0.           0.           0.           3.10462166   4.88988905]\n",
      " [  0.           0.           0.          -3.11288613   0.\n",
      "    0.           0.           0.           3.10462166  -1.23123808]\n",
      " [  0.           0.           0.           0.          -3.11696734\n",
      "    0.           0.           0.           3.10462166   0.30123129]\n",
      " [  0.           0.           0.           0.           0.\n",
      "   -3.11288613   0.           0.           3.10462166  -2.47745328]\n",
      " [  0.           0.           0.           0.           0.\n",
      "    0.          -3.11462166   0.           3.10462166   4.75302132]\n",
      " [  0.           0.           0.           0.           0.\n",
      "    0.           0.          -3.10770808   3.10462166  14.37291385]\n",
      " [  3.10462166   3.10462166   3.10462166   3.10462166   3.10462166\n",
      "    3.10462166   3.10462166   3.10462166 -24.87697332 -32.66123599]\n",
      " [  6.03152848   6.02134335   4.88988905  -1.23123808   0.30123129\n",
      "   -2.47745328   4.75302132  14.37291385 -32.66123599 -53.75456334]]\n",
      "log-q-pdf : -9.23378165032704\n",
      "glog-q-pdf: [-2.58160335 -2.62896216 -2.22107511  0.68994554 -0.09988177  1.21637149\n",
      " -1.92417505 -6.55483477 14.56349312 17.22832009]\n"
     ]
    }
   ],
   "source": [
    "# Test manual KGM class\n",
    "dim = model.param_num()\n",
    "x_test = np.random.normal(size=dim)\n",
    "\n",
    "test_manual_kgm = PiTargetKGM(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv, s=3.0)\n",
    "\n",
    "print(\"log-pdf   :\", test_manual_kgm.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_manual_kgm.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_manual_kgm.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_manual_kgm.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_manual_kgm.grad_log_q(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-pdf   : -125.29449874971975\n",
      "glog-pdf  : [  14.62998483   52.65407133   10.6785506    31.39132912   17.61179453\n",
      "   14.83556516    1.7150771    11.31376635 -154.37989176  255.32649628]\n",
      "hlog-pdf  : [[ -18.02692961    0.            0.            0.            0.\n",
      "     0.            0.            0.           18.02248516  -29.01140202]\n",
      " [   0.          -18.03248516    0.            0.            0.\n",
      "     0.            0.            0.           18.02248516 -105.10664311]\n",
      " [   0.            0.          -18.02639141    0.            0.\n",
      "     0.            0.            0.           18.02248516  -21.38247457]\n",
      " [   0.            0.            0.          -18.03074963    0.\n",
      "     0.            0.            0.           18.02248516  -62.65212674]\n",
      " [   0.            0.            0.            0.          -18.03483084\n",
      "     0.            0.            0.           18.02248516  -35.24490276]\n",
      " [   0.            0.            0.            0.            0.\n",
      "   -18.03074963    0.            0.           18.02248516  -29.6549037 ]\n",
      " [   0.            0.            0.            0.            0.\n",
      "     0.          -18.03248516    0.           18.02248516   -3.08526248]\n",
      " [   0.            0.            0.            0.            0.\n",
      "     0.            0.          -18.02557158   18.02248516  -22.55478745]\n",
      " [  18.02248516   18.02248516   18.02248516   18.02248516   18.02248516\n",
      "    18.02248516   18.02248516   18.02248516 -144.21988131  308.69250283]\n",
      " [ -29.01140202 -105.10664311  -21.38247457  -62.65212674  -35.24490276\n",
      "   -29.6549037    -3.08526248  -22.55478745  308.69250283 -524.67068925]]\n",
      "log-q-pdf : -118.0406392250471\n",
      "glog-q-pdf: [  14.50730839   52.22811112   10.57431753   31.12636937   17.4382916\n",
      "   14.69027827    1.65631925   11.20697352 -153.16534745  253.20043008]\n"
     ]
    }
   ],
   "source": [
    "# Test manual Centralized KGM class\n",
    "dim = model.param_num()\n",
    "x_test = np.random.normal(size=dim)\n",
    "x_map = model.param_unconstrain(np.mean(gs_chains, axis=1))\n",
    "\n",
    "test_manual_centkgm = PiTargetCentKGM(log_p=log_p, grad_log_p=grad_log_p, hess_log_p=hess_log_p, linv=linv, s=3.0, x_map=x_map)\n",
    "\n",
    "print(\"log-pdf   :\", test_manual_centkgm.log_p(x_test))\n",
    "print(\"glog-pdf  :\", test_manual_centkgm.grad_log_p(x_test))\n",
    "print(\"hlog-pdf  :\", test_manual_centkgm.hess_log_p(x_test))\n",
    "print(\"log-q-pdf :\", test_manual_centkgm.log_q(x_test))\n",
    "print(\"glog-q-pdf:\", test_manual_centkgm.grad_log_q(x_test))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "stein-q-thinning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
