{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook bootstrap: autoreload, JAX host-device count, and the project import path.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import multiprocessing\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Expose one XLA host (CPU) device per core so that several MCMC chains can run in parallel.\n",
    "# This is done in the very first cell, before anything imports or initialises JAX: setting\n",
    "# XLA_FLAGS after JAX has already been used leaves the device count at 1, and numpyro then\n",
    "# falls back to drawing the chains sequentially.\n",
    "os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count={}\".format(\n",
    "    multiprocessing.cpu_count()\n",
    ")\n",
    "\n",
    "# Remove every sys.path entry that mentions 'bayes_gsl'\n",
    "# (presumably an older checkout of this project -- confirm before deleting this line).\n",
    "sys.path = [x for x in sys.path if 'bayes_gsl' not in x]\n",
    "\n",
    "# Repository root, i.e. the directory that CONTAINS `src/` (not `src/` itself).\n",
    "# Set the GSL_BNN_ROOT environment variable to override it without editing this cell.\n",
    "new_path = os.environ.get('GSL_BNN_ROOT', '/Users/maxw/projects/gsl-bnn/')\n",
    "if new_path not in sys.path:\n",
    "    sys.path.append(new_path)\n",
    "\n",
    "# Sanity check: the project package must be importable from the path added above.\n",
    "from src.models import dpg_bnn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- standard library ---\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "\n",
    "# --- numerical / plotting stack ---\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib_inline\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "from scipy.linalg import block_diag\n",
    "\n",
    "# --- JAX / NumPyro ---\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax.random\n",
    "from jax import random as jax_random\n",
    "from jax.random import PRNGKey\n",
    "\n",
    "import numpyro\n",
    "from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, autoguide, init_to_value\n",
    "\n",
    "# --- project modules ---\n",
    "from src.utils import adj2vec, degrees_from_upper_tri, edge_density, vec2adj\n",
    "from src.iterates_and_unroll import unroll_dpg\n",
    "from src.models import dpg_bnn\n",
    "from src.config import w_init_scale, lam_init_scale, altered_prior\n",
    "from src.metrics import compute_metrics\n",
    "\n",
    "# Plot styling. Inline figures switch to SVG only when NUMPYRO_SPHINXBUILD is set\n",
    "# (presumably the NumPyro documentation build -- confirm).\n",
    "plt.style.use(\"bmh\")\n",
    "if os.environ.get(\"NUMPYRO_SPHINXBUILD\") is not None:\n",
    "    matplotlib_inline.backend_inline.set_matplotlib_formats(\"svg\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading Data: /Users/maxw/projects/gsl-bnn/data/synthetic/RG_N=20_r=0.5_dim=2.pt\n"
     ]
    }
   ],
   "source": [
    "from src import SYNTHETIC_DATA_ROOT\n",
    "\n",
    "graph_distribution = SYNTHETIC_DATA_ROOT + f\"RG_N={20}_r={0.5}_dim={2}.pt\"\n",
    "num_signals = float('inf') # analytic Euclidean distance matrix\n",
    "print(f\"Loading Data: {graph_distribution}\")\n",
    "\n",
    "# Use a context manager so the file handle is closed even if unpickling fails.\n",
    "# NOTE: pickle can execute arbitrary code while loading -- only open dataset files you trust.\n",
    "with open(graph_distribution, \"rb\") as data_file:\n",
    "    data_dict = pickle.load(data_file)\n",
    "\n",
    "# data is pairs of adjacency matrices and euclidean distance matrices\n",
    "adjacencies = data_dict['adjacencies'].astype(np.float32)\n",
    "# 'expected' selects the analytic distances; a finite num_signals selects the entry keyed by str(num_signals)\n",
    "dataset_key = 'expected' if num_signals == float('inf') else str(num_signals)\n",
    "euclidean_distance_matrices = data_dict[dataset_key]\n",
    "\n",
    "# convert to vectors\n",
    "adjacencies = adj2vec(adjacencies)\n",
    "euclidean_distance_matrices = adj2vec(euclidean_distance_matrices)\n",
    "num_edges = adjacencies.shape[-1]\n",
    "\n",
    "# concatenate for easier processing: the first num_edges columns are distances, the rest adjacency entries\n",
    "data = np.concatenate([euclidean_distance_matrices, adjacencies], axis=1)\n",
    "\n",
    "# predetermine train/val/test split (the test split is every row after the train and val rows)\n",
    "num_train, num_val, num_test = 50, 50, 100\n",
    "train, val, test = data[:num_train], data[num_train:num_train + num_val], data[num_train + num_val:]\n",
    "data =  {\"train\": (train[:, :num_edges], train[:, num_edges:]), \n",
    "         \"val\": (val[:, :num_edges], val[:, num_edges:]),\n",
    "         \"test\": (test[:, :num_edges], test[:, num_edges:])}\n",
    "\n",
    "# unpack data\n",
    "x_total, y_total = data['train']\n",
    "# recover the node count n by inverting num_edges = n * (n - 1) / 2\n",
    "n, num_edges = int(0.5*(np.sqrt(8 * x_total.shape[-1] + 1) + 1)), x_total.shape[-1]\n",
    "\n",
    "# keep only the first num_train_samples_ training pairs, as JAX arrays\n",
    "num_train_samples_ = 50\n",
    "x_total, y_total = jnp.array(x_total), jnp.array(y_total)\n",
    "x, y = x_total[:num_train_samples_], y_total[:num_train_samples_]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure DPG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `depth` is handed to the model as-is (presumably the number of unrolled DPG iterations -- confirm).\n",
    "depth = 30\n",
    "\n",
    "num_train_samples_ = 50\n",
    "\n",
    "# Constant initial values, one row per training sample:\n",
    "# w_init has one column per edge, lam_init one column per node.\n",
    "w_init = w_init_scale * jnp.ones((num_train_samples_, num_edges))\n",
    "lam_init = lam_init_scale * jnp.ones((num_train_samples_, n))\n",
    "\n",
    "S = jnp.array(degrees_from_upper_tri(n))\n",
    "model = dpg_bnn.model\n",
    "\n",
    "# Keyword arguments forwarded to the model when MCMC runs.\n",
    "model_args = dict(\n",
    "    x=x, y=y,\n",
    "    depth=depth,\n",
    "    w_init=w_init, lam_init=lam_init,\n",
    "    S=S,\n",
    "    dummy=False,\n",
    "    prior_settings=altered_prior,  # priors for model parameters\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of CPU cores: 1\n"
     ]
    }
   ],
   "source": [
    "# Ask XLA for one host (CPU) device per core so that multiple MCMC chains can run in parallel.\n",
    "# Blackjax tutorial: https://blackjax-devs.github.io/blackjax/examples/howto_sample_multiple_chains.html\n",
    "# NOTE(review): the flag is only honoured if it is in the environment before JAX initialises its\n",
    "# backend (safest: before jax is imported). JAX arrays already exist by the time this cell runs,\n",
    "# so the count printed below is the one the sampler will actually get.\n",
    "num_cores = multiprocessing.cpu_count()\n",
    "os.environ[\"XLA_FLAGS\"] = f\"--xla_force_host_platform_device_count={num_cores}\"\n",
    "\n",
    "# Despite the label, the value shown is jax.local_device_count(), i.e. the number of JAX devices.\n",
    "print(\"Number of CPU cores:\", jax.local_device_count())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "********** Running Inference: depth=30, num_train_samples=50 **********\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/lk/qxmfz0hx2ggd8m2__sd2ks9h0000gn/T/ipykernel_51396/4019214288.py:8: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.\n",
      "  mcmc = MCMC(kernel,\n",
      "sample: 100%|██████████| 200/200 [00:14<00:00, 13.90it/s, 15 steps of size 4.99e-03. acc. prob=0.98] \n",
      "sample: 100%|██████████| 200/200 [00:11<00:00, 16.85it/s, 31 steps of size 3.24e-03. acc. prob=0.97]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time taken for inference using 2 with 100 warmup samples and 100 samples: 27.661314249038696\n",
      "\n",
      "                mean       std    median      5.0%     95.0%     n_eff     r_hat\n",
      "         b     11.93      0.68     11.93     10.86     13.13     55.33      0.99\n",
      "     delta    128.03      5.25    127.77    117.39    135.51     53.56      1.00\n",
      "     theta      0.14      0.00      0.14      0.14      0.15     74.00      1.00\n",
      "\n",
      "Number of divergences: 0\n",
      "^^********** Finished  **********^^\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(f'\\n\\n********** Running Inference: depth={depth}, num_train_samples={num_train_samples_} **********')\n",
    "\n",
    "# Root PRNG key with a fixed seed; a sub-key is split off for the sampler so the root key stays unused.\n",
    "rng_key = jax_random.PRNGKey(0)\n",
    "rng_key, rng_key_ = jax_random.split(rng_key)\n",
    "\n",
    "# Sampler budget (num_chains could instead follow jax.local_device_count()).\n",
    "num_chains = 2\n",
    "num_warmup_samples = 100\n",
    "num_samples = 100\n",
    "\n",
    "# NUTS kernel using forward-mode differentiation, wrapped in an MCMC driver.\n",
    "kernel = NUTS(model, forward_mode_differentiation=True)\n",
    "mcmc = MCMC(\n",
    "    kernel,\n",
    "    num_warmup=num_warmup_samples,\n",
    "    num_samples=num_samples,\n",
    "    num_chains=num_chains,\n",
    "    chain_method='parallel',\n",
    "    progress_bar=True,\n",
    ")\n",
    "\n",
    "# Wall-clock timing around the full run (warmup + sampling for all chains).\n",
    "start_time = time.time()\n",
    "mcmc.run(rng_key_, **model_args)\n",
    "end_time = time.time()\n",
    "print(f\"Time taken for inference using {num_chains} with {num_warmup_samples} warmup samples and {num_samples} samples: {end_time - start_time}\")\n",
    "\n",
    "mcmc.print_summary()\n",
    "print(\"^^********** Finished  **********^^\\n\\n\")\n",
    "\n",
    "# Posterior draws reused by the evaluation cells below.\n",
    "samples = mcmc.get_samples()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "samples[theta].shape: (200,)\n",
      "samples[delta].shape: (200,)\n",
      "samples[b].shape: (200,)\n",
      "x_test.shape: (100, 190)\n",
      "w_test.shape: (100, 190)\n",
      "lam_test.shape: (100, 20)\n",
      "S.shape: (20, 190)\n"
     ]
    }
   ],
   "source": [
    "num_test_samples = 100\n",
    "x_total_test, y_total_test = data['test']\n",
    "\n",
    "# Edge count from the feature width; node count by inverting num_edges = n * (n - 1) / 2.\n",
    "num_edges = x_total_test.shape[-1]\n",
    "n = int(0.5*(np.sqrt(8 * num_edges + 1) + 1))\n",
    "\n",
    "# Move the test split to JAX arrays and keep the first num_test_samples rows.\n",
    "x_total_test = jnp.array(x_total_test)\n",
    "y_total_test = jnp.array(y_total_test)\n",
    "x_test = x_total_test[:num_test_samples]\n",
    "y_test = y_total_test[:num_test_samples]\n",
    "\n",
    "# Same constant initialisation as for training, sized for the test batch.\n",
    "w_test = w_init_scale * jnp.ones((num_test_samples, num_edges))\n",
    "lam_test = lam_init_scale * jnp.ones((num_test_samples, n))\n",
    "\n",
    "# print the shapes of all the inputs to the model\n",
    "for _site in (\"theta\", \"delta\", \"b\"):\n",
    "    print(f'samples[{_site}].shape: {samples[_site].shape}')\n",
    "for _label, _array in ((\"x_test\", x_test), (\"w_test\", w_test), (\"lam_test\", lam_test), (\"S\", S)):\n",
    "    print(f'{_label}.shape: {_array.shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callable that the next cell applies to the posterior samples and the test inputs.\n",
    "# NOTE(review): presumably `forward_pass_vmap()` returns a jax.vmap-wrapped forward pass -- confirm in src/models/dpg_bnn.py.\n",
    "dpg_bnn_forward_pass = dpg_bnn.forward_pass_vmap()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the forward pass on the test inputs using the posterior draws.\n",
    "# Arguments are positional, in the order: theta, delta, b, x, w, lam, depth, S.\n",
    "posterior_draws = (samples['theta'], samples['delta'], samples['b'])\n",
    "test_inputs = (x_test, w_test, lam_test)\n",
    "edge_logits = dpg_bnn_forward_pass(*posterior_draws, *test_inputs, depth, S)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Error: 0.01363 \\pm 0.01274\n",
      "Test NLL: 10.161 \\pm 9.809\n",
      "Test BS: 0.01129 \\pm 0.00940\n",
      "Test ECE:0.00246\n"
     ]
    }
   ],
   "source": [
    "from src import NUM_BINS\n",
    "\n",
    "# Score the predicted edge logits against the test labels.\n",
    "metrics_dict = compute_metrics(edge_logits, y_test, NUM_BINS)\n",
    "calibration_dict = metrics_dict['calibration_dict']\n",
    "\n",
    "# Report mean \\pm std for each metric. The backslash is escaped because a bare \"\\p\" is an invalid\n",
    "# escape sequence in a normal string literal (a warning today, slated to become an error);\n",
    "# the printed text is unchanged.\n",
    "print(f'Test Error: {1 - metrics_dict[\"accuracies\"].mean():.5f} \\\\pm {metrics_dict[\"accuracies\"].std():.5f}')\n",
    "print(f'Test NLL: {-1 * metrics_dict[\"log_likelihoods\"].mean():.3f} \\\\pm {metrics_dict[\"log_likelihoods\"].std():.3f}')\n",
    "print(f'Test BS: {metrics_dict[\"brier_scores\"].mean():.5f} \\\\pm {metrics_dict[\"brier_scores\"].std():.5f}')\n",
    "print(f'Test ECE:{calibration_dict[\"ece\"]:.5f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gsl-bnn-mac-m2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
