{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# KL divergence of model with respect to the empirical distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import packages\n",
    "import torch\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "import boltzgen as bg\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify checkpoint root\n",
    "# The config, loss log and checkpoint paths used later in this notebook are\n",
    "# built by plain string concatenation (e.g. checkpoint_root + 'config/bm.yaml'),\n",
    "# so the root has to end with a path separator.\n",
    "checkpoint_root = 'rnvp/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load config\n",
    "# Path of the bm.yaml config file, relative to the checkpoint root\n",
    "config_path = checkpoint_root + 'config/bm.yaml'\n",
    "config = bg.utils.get_config(config_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "# Training and test trajectories are read from their .h5 files\n",
    "train_path, test_path = 'data/train.h5', 'data/test.h5'\n",
    "training_data = bg.utils.load_traj(train_path)\n",
    "test_data = bg.utils.load_traj(test_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup model\n",
    "# The GPU is used only when it is both available and explicitly enabled\n",
    "# through the flag below\n",
    "enable_cuda = False\n",
    "use_gpu = torch.cuda.is_available() and enable_cuda\n",
    "device = torch.device('cuda' if use_gpu else 'cpu')\n",
    "\n",
    "# Build the generator from the config, move it to the chosen device and\n",
    "# switch it to double precision\n",
    "model = bg.BoltzmannGenerator(config).to(device).double()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot loss\n",
    "loss = np.loadtxt(checkpoint_root + 'log/loss.csv')\n",
    "\n",
    "# One large figure with the logged loss values drawn as individual points,\n",
    "# with the y axis clipped to a fixed window\n",
    "fig, ax = plt.subplots(figsize=(15, 10))\n",
    "ax.plot(loss, '.')\n",
    "ax.set_ylim(-190, -160)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load checkpoint\n",
    "# Checkpoint file to load into the model, relative to the checkpoint root\n",
    "checkpoint_path = checkpoint_root + 'checkpoints/model_30000.pt'\n",
    "model.load(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw samples\n",
    "\n",
    "# Stride: every nth test frame is used and 1000 // nth sample batches are drawn\n",
    "nth = 1\n",
    "n_batches = 1000 // nth\n",
    "batch_size = 1000\n",
    "\n",
    "model.eval()\n",
    "\n",
    "# Per-batch arrays are collected in lists and concatenated once at the end.\n",
    "# This avoids re-copying the growing arrays on every iteration and removes the\n",
    "# need to hard-code the number of coordinates of the samples.\n",
    "x_batches, z_batches, log_p_batches, log_q_batches = [], [], [], []\n",
    "\n",
    "# Nothing is differentiated in this notebook, so no autograd graph is needed\n",
    "with torch.no_grad():\n",
    "    for _batch in tqdm(range(n_batches)):\n",
    "        x, log_q = model.sample(batch_size)\n",
    "        log_p = model.p.log_prob(x)\n",
    "        # Apply the inverse of the last flow layer to the samples\n",
    "        z, _ = model.flows[-1].inverse(x)\n",
    "\n",
    "        x_batches.append(x.detach().cpu().numpy())\n",
    "        z_batches.append(z.detach().cpu().numpy())\n",
    "        log_p_batches.append(log_p.detach().cpu().numpy())\n",
    "        log_q_batches.append(log_q.detach().cpu().numpy())\n",
    "\n",
    "    # Reference data, processed the same way as the samples\n",
    "    # (use training_data[::nth] here to compare against the training set instead)\n",
    "    z_d = test_data[::nth].double().to(device)\n",
    "    log_p_d = model.p.log_prob(z_d)\n",
    "    z_d, _ = model.flows[-1].inverse(z_d)\n",
    "\n",
    "x_np = np.concatenate(x_batches)\n",
    "z_np = np.concatenate(z_batches)\n",
    "log_p_np = np.concatenate(log_p_batches)\n",
    "log_q_np = np.concatenate(log_q_batches)\n",
    "\n",
    "z_d_np = z_d.detach().cpu().numpy()\n",
    "log_p_d_np = log_p_d.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use histogram to compute KLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate density\n",
    "nbins = 200\n",
    "hist_range = [-5, 5]\n",
    "ndims = z_np.shape[1]\n",
    "\n",
    "\n",
    "def _density_hist(samples):\n",
    "    \"\"\"Return the density-normalised histogram of a 1-D array over hist_range.\"\"\"\n",
    "    density, _edges = np.histogram(samples, nbins, range=hist_range, density=True)\n",
    "    return density\n",
    "\n",
    "\n",
    "# Column d holds the histogram of coordinate d\n",
    "hists_train = np.zeros((nbins, ndims))\n",
    "hists_gen = np.zeros((nbins, ndims))\n",
    "\n",
    "for dim in range(ndims):\n",
    "    hists_train[:, dim] = _density_hist(z_d_np[:, dim])\n",
    "    hists_gen[:, dim] = _density_hist(z_np[:, dim])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Histogram values belong to the bin centres, so the x axis is derived from\n",
    "# the same range and bin count that were used for the density estimate\n",
    "bin_edges = np.linspace(hist_range[0], hist_range[1], nbins + 1)\n",
    "bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])\n",
    "\n",
    "for i in range(ndims):\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(bin_centres, hists_train[:, i], label='data')\n",
    "    ax.plot(bin_centres, hists_gen[:, i], label='generated')\n",
    "    ax.set_title(f'Dimension {i}')\n",
    "    ax.set_ylabel('density')\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute KLD\n",
    "# eps keeps the logarithm finite where a histogram bin is empty\n",
    "eps = 1e-10\n",
    "log_ratio = np.log((hists_train + eps) / (hists_gen + eps))\n",
    "\n",
    "# Sum over the bins of every dimension, then scale by the bin width\n",
    "kld_unscaled = np.sum(hists_train * log_ratio, axis=0)\n",
    "hist_span = hist_range[1] - hist_range[0]\n",
    "kld = kld_unscaled * hist_span / nbins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Median of the per-dimension KLD values\n",
    "kld_median = np.median(kld)\n",
    "kld_median"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split KLD into groups\n",
    "mixed = model.flows[-1].mixed_transform\n",
    "ncarts = mixed.len_cart_inds\n",
    "permute_inv = mixed.permute_inv\n",
    "bond_ind = mixed.ic_transform.bond_indices\n",
    "angle_ind = mixed.ic_transform.angle_indices\n",
    "dih_ind = mixed.ic_transform.dih_indices\n",
    "\n",
    "# The first 3 * ncarts - 6 entries form the Cartesian group\n",
    "n_cart_dof = 3 * ncarts - 6\n",
    "kld_cart = kld[:n_cart_dof]\n",
    "\n",
    "# Six zeros are inserted after the Cartesian entries and the result is\n",
    "# reordered with permute_inv before the bond / angle / dihedral indices apply\n",
    "kld_padded = np.concatenate([kld_cart, np.zeros(6), kld[n_cart_dof:]])\n",
    "kld_ = kld_padded[permute_inv]\n",
    "\n",
    "kld_bond = kld_[bond_ind]\n",
    "kld_angle = kld_[angle_ind]\n",
    "kld_dih = kld_[dih_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print resulting KLDs\n",
    "# For every group: the sorted per-dimension values, their mean and median\n",
    "kld_groups = [\n",
    "    ('Cartesian coordinates', kld_cart),\n",
    "    ('Bond lengths', kld_bond),\n",
    "    ('Bond angles', kld_angle),\n",
    "    ('Dihedral angles', kld_dih),\n",
    "]\n",
    "\n",
    "for group_idx, (group_name, group_kld) in enumerate(kld_groups):\n",
    "    # Every group after the first is preceded by two blank lines\n",
    "    header = group_name if group_idx == 0 else '\\n\\n' + group_name\n",
    "    print(header)\n",
    "    print(np.sort(group_kld))\n",
    "    print('Mean:', np.mean(group_kld))\n",
    "    print('Median:', np.median(group_kld))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Histograms of the groups\n",
    "n_cart_dof = 3 * ncarts - 6\n",
    "hists_train_cart = hists_train[:, :n_cart_dof]\n",
    "\n",
    "# Six all-zero columns are inserted after the Cartesian ones and the columns\n",
    "# are reordered with permute_inv before the group indices are applied\n",
    "hists_train_ = np.concatenate([hists_train[:, :n_cart_dof], np.zeros((nbins, 6)),\n",
    "                               hists_train[:, n_cart_dof:]], axis=1)\n",
    "hists_train_ = hists_train_[:, permute_inv]\n",
    "hists_train_bond = hists_train_[:, bond_ind]\n",
    "hists_train_angle = hists_train_[:, angle_ind]\n",
    "hists_train_dih = hists_train_[:, dih_ind]\n",
    "\n",
    "# Histogram values belong to the bin centres, so the x axis is derived from\n",
    "# the same range and bin count that were used for the density estimate\n",
    "bin_edges = np.linspace(hist_range[0], hist_range[1], nbins + 1)\n",
    "bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])\n",
    "\n",
    "hist_groups = [\n",
    "    ('Cartesian', hists_train_cart),\n",
    "    ('Bond', hists_train_bond),\n",
    "    ('Angle', hists_train_angle),\n",
    "    ('Dihedral', hists_train_dih),\n",
    "]\n",
    "\n",
    "for group_name, hists in hist_groups:\n",
    "    for i in range(hists.shape[1]):\n",
    "        fig, ax = plt.subplots()\n",
    "        ax.plot(bin_centres, hists[:, i])\n",
    "        ax.set_title(f'{group_name} {i}')\n",
    "        ax.set_ylabel('density')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use Gaussian KDE to compute KLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate density\n",
    "ndims = z_np.shape[1]\n",
    "\n",
    "# One Gaussian KDE per dimension, for the data and for the generated samples\n",
    "kde_train = [stats.gaussian_kde(z_d_np[:, dim]) for dim in range(ndims)]\n",
    "kde_gen = [stats.gaussian_kde(z_np[:, dim]) for dim in range(ndims)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Plot the two KDEs of every dimension on a common grid\n",
    "x = np.linspace(-5, 5, 200)\n",
    "for i in range(ndims):\n",
    "    print(i)\n",
    "    density_train = kde_train[i].pdf(x)\n",
    "    density_gen = kde_gen[i].pdf(x)\n",
    "    plt.plot(x, density_train)\n",
    "    plt.plot(x, density_gen)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute KLD\n",
    "# eps keeps the logarithm finite where a density evaluates to zero\n",
    "eps = 1e-10\n",
    "int_range = [-5, 5]\n",
    "npoints = 1000\n",
    "\n",
    "kld = np.zeros(ndims)\n",
    "# retstep returns the actual grid spacing, (b - a) / (npoints - 1), which is\n",
    "# the weight the numerical integration has to use\n",
    "x, dx = np.linspace(int_range[0], int_range[1], npoints, retstep=True)\n",
    "\n",
    "for i in tqdm(range(ndims)):\n",
    "    # Evaluate each KDE only once per dimension\n",
    "    p_data = kde_train[i].pdf(x)\n",
    "    p_gen = kde_gen[i].pdf(x)\n",
    "    integrand = p_data * np.log((p_data + eps) / (p_gen + eps))\n",
    "    # Trapezoidal rule on the uniform grid: end points carry half weight\n",
    "    kld[i] = dx * (np.sum(integrand) - 0.5 * (integrand[0] + integrand[-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split KLD into groups\n",
    "mixed = model.flows[-1].mixed_transform\n",
    "ncarts = mixed.len_cart_inds\n",
    "permute_inv = mixed.permute_inv\n",
    "bond_ind = mixed.ic_transform.bond_indices\n",
    "angle_ind = mixed.ic_transform.angle_indices\n",
    "dih_ind = mixed.ic_transform.dih_indices\n",
    "\n",
    "# The first 3 * ncarts - 6 entries form the Cartesian group\n",
    "n_cart_dof = 3 * ncarts - 6\n",
    "kld_cart = kld[:n_cart_dof]\n",
    "\n",
    "# Six zeros are inserted after the Cartesian entries and the result is\n",
    "# reordered with permute_inv before the bond / angle / dihedral indices apply\n",
    "kld_padded = np.concatenate([kld_cart, np.zeros(6), kld[n_cart_dof:]])\n",
    "kld_ = kld_padded[permute_inv]\n",
    "\n",
    "kld_bond = kld_[bond_ind]\n",
    "kld_angle = kld_[angle_ind]\n",
    "kld_dih = kld_[dih_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print resulting KLDs\n",
    "# For every group: the sorted per-dimension values, their mean and median\n",
    "kld_groups = [\n",
    "    ('Cartesian coordinates', kld_cart),\n",
    "    ('Bond lengths', kld_bond),\n",
    "    ('Bond angles', kld_angle),\n",
    "    ('Dihedral angles', kld_dih),\n",
    "]\n",
    "\n",
    "for group_idx, (group_name, group_kld) in enumerate(kld_groups):\n",
    "    # Every group after the first is preceded by two blank lines\n",
    "    header = group_name if group_idx == 0 else '\\n\\n' + group_name\n",
    "    print(header)\n",
    "    print(np.sort(group_kld))\n",
    "    print('Mean:', np.mean(group_kld))\n",
    "    print('Median:', np.median(group_kld))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}