{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "import pandas\n",
    "import numpy as onp\n",
    "import jax.numpy as np\n",
    "\n",
    "path = './data/kin8nm.csv'\n",
    "data = pandas.read_csv(path, header=None).values\n",
    "\n",
    "X_full = data[:, :-1]\n",
    "Y_full = data[:, -1:]\n",
    "\n",
    "\n",
    "N = X_full.shape[0]\n",
    "n = int(N * 0.9)\n",
    "ind = onp.arange(N)\n",
    "\n",
    "onp.random.shuffle(ind)\n",
    "train_ind = ind[:n]\n",
    "test_ind = ind[n:]\n",
    "\n",
    "X = X_full[train_ind]\n",
    "Xs = X_full[test_ind]\n",
    "Y = Y_full[train_ind]\n",
    "Ys = Y_full[test_ind]\n",
    "\n",
    "X_mean = np.mean(X, 0)\n",
    "X_std = np.std(X, 0)\n",
    "X = (X - X_mean) / X_std\n",
    "Xs = (Xs - X_mean) / X_std\n",
    "Y_mean = np.mean(Y, 0)\n",
    "Y_std = np.std(Y, 0)\n",
    "Y = (Y - Y_mean) / Y_std\n",
    "Ys = (Ys - Y_mean) / Y_std\n",
    "\n",
    "class D:\n",
    "    X_train, Y_train, X_test, Y_test, X_mean, Y_mean, X_std, Y_std = X, Y, Xs, Ys, X_mean, Y_mean, X_std, Y_std\n",
    "    D = X_train.shape[0]\n",
    "\n",
    "from utils import bbData\n",
    "d = bbData(D(), minibatch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "arg_dict = {'step_size': 0.001, \\\n",
    "            'sgd': False, 'alpha': 0.1, \\\n",
    "            'invsigma': 100., \\\n",
    "            'a': 1.0, 'mu': 10.0}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import jax.numpy as np\n",
    "import numpy as onp\n",
    "from jax import random\n",
    "\n",
    "\n",
    "from inference import Inference\n",
    "from utils import bbData\n",
    "from jax.scipy.special import logsumexp\n",
    "from bnn_stax import *\n",
    "\n",
    "key = random.PRNGKey(0)\n",
    "k1, k2, k3 = random.split(key, 3)\n",
    "sizes = [d.d.D] + [50] + [d.d.Y_test.shape[1]]\n",
    "from jax.nn.initializers import normal, glorot_normal\n",
    "model = BnnModel_stax_regression(d, n_layers=2, optim='sgd', stein_kernel='se', activation='relu')\n",
    "width = 50\n",
    "fd = model.func_dict\n",
    "fd['LD_stein'] = LD_stein(**arg_dict)\n",
    "fd['NHT_2_stein'] = NHT_2_stein(**arg_dict)\n",
    "fd['HMC_stein'] = HMC_stein(**arg_dict)\n",
    "model.stein_training(n_iter=5000, n_particles=20, key=k1, lr=arg_dict['step_size'], window=10, method=('NHT_2', 'stein'), sgd=arg_dict['sgd'], split=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train log-likelihood: 0.19\n",
      "Test log-likelihood: 0.60\n"
     ]
    }
   ],
   "source": [
    "print('Train log-likelihood: ' + \"%.3f\" % model.test_trace[-1][1])\n",
    "print('Test log-likelihood: ' + \"%.3f\" % model.test_trace[-1][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# plot trace of test log-likelihood\n",
    "plt.plot(np.linspace(0, 5000, len(model.test_trace)), [x[1] for x in model.test_trace])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:gp]",
   "language": "python",
   "name": "conda-env-gp-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
