{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize selected marginals for different models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import packages\n",
    "import torch\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "import boltzgen as bg\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model that provides the coordinate transform\n",
    "\n",
    "# Root folder of the checkpoint\n",
    "checkpoint_root = 'rnvp/'\n",
    "# Read the configuration file stored under the checkpoint root\n",
    "config = bg.utils.get_config(checkpoint_root + 'config/bm.yaml')\n",
    "# Instantiate the model from the configuration\n",
    "model = bg.BoltzmannGenerator(config)\n",
    "# Run on the GPU only if one is available and it has been enabled here\n",
    "enable_cuda = False\n",
    "use_cuda = torch.cuda.is_available() and enable_cuda\n",
    "device = torch.device('cuda' if use_cuda else 'cpu')\n",
    "# Place the model on the chosen device and switch it to double precision\n",
    "model = model.to(device).double()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the test trajectory\n",
    "test_data_path = 'data/test.h5'\n",
    "test_data = bg.utils.load_traj(test_data_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model samples and pass them, together with the test data,\n",
    "# through the inverse of the model's last flow layer\n",
    "# NOTE(review): the last two prefixes use '_' where the first two use '/'\n",
    "# after the run name - confirm this matches the files on disk\n",
    "prefix = ['samples/alpha_0_no_scale_01/samples_batch_num_0_processID_',\n",
    "          'samples/alpha_0_scale_02/samples_batch_num_0_processID_',\n",
    "          'samples/alpha_0_grid_search_samples_batch_num_0_processID_',\n",
    "          'samples/alpha_0_train_acc_prob_samples_batch_num_0_processID_']\n",
    "\n",
    "n_samples = 1000000  # rows kept per data set\n",
    "batch_size = 1024  # rows written per sample file\n",
    "n_dims = 60  # columns stored per sample\n",
    "# Only the files overlapping the first n_samples rows have to be read\n",
    "n_files = -(-n_samples // batch_size)\n",
    "\n",
    "last_flow = model.flows[-1]\n",
    "\n",
    "# Slot 0 is reserved for the test data, slots 1.. hold the model samples\n",
    "z_np = np.zeros((len(prefix) + 1, n_samples, n_dims))\n",
    "# No gradients are needed here; disabling them also guarantees that the\n",
    "# outputs can be converted to NumPy arrays\n",
    "with torch.no_grad():\n",
    "    for j in range(len(prefix)):\n",
    "        for i in tqdm(range(n_files)):\n",
    "            x_np = np.load(prefix[j] + str(i) + '.npy')\n",
    "            # Keep the input on the same device as the model\n",
    "            x = torch.tensor(x_np, device=device)\n",
    "            z, _ = last_flow.inverse(x)\n",
    "            start = i * batch_size\n",
    "            stop = min(start + batch_size, n_samples)\n",
    "            # The last file is only partially used\n",
    "            z_np[j + 1, start:stop, :] = z.cpu().numpy()[:stop - start]\n",
    "    z, _ = last_flow.inverse(test_data.to(device))\n",
    "# The test data has to provide exactly n_samples rows to fill slot 0\n",
    "z_np[0, :, :] = z.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate every marginal density with a Gaussian KDE on a common grid\n",
    "int_range = [-np.pi, np.pi]\n",
    "npoints = 150\n",
    "x = np.linspace(*int_range, npoints)\n",
    "kde_marg = np.zeros((len(z_np), npoints, 60))\n",
    "for i, z_set in enumerate(z_np):\n",
    "    for j in tqdm(range(60)):\n",
    "        column = z_set[:, j]\n",
    "        # NaN entries are dropped before fitting the estimator\n",
    "        kde = stats.gaussian_kde(column[~np.isnan(column)])\n",
    "        kde_marg[i, :, j] = kde.evaluate(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the selected marginals: one row of panels per entry of ylabel,\n",
    "# one curve per data set in kde_marg\n",
    "ind_marg = np.array([[22, 43, 58], [9, 33, 45], [32, 53, 11], [1, 2, 7]])\n",
    "ylabel = ['Bond angles', 'Bond lengths', 'Dihedral angles', 'Cartesian coordinates']\n",
    "f, ax = plt.subplots(4, 3, figsize=(15, 20), sharex=True)\n",
    "# Line handles of the most recently drawn panel, reused for the shared legend\n",
    "lines = [None] * len(kde_marg)\n",
    "for i in range(ind_marg.shape[0]):\n",
    "    for j in range(ind_marg.shape[1]):\n",
    "        for k in range(len(kde_marg)):\n",
    "            lines[k], = ax[i, j].plot(x, kde_marg[k, :, ind_marg[i, j]])\n",
    "        # The y axis carries no ticks, only the x axis is labelled with values\n",
    "        ax[i, j].set_yticks([])\n",
    "        ax[i, j].tick_params(axis='x', which='both', labelsize=18)\n",
    "        # Label each row once, on its leftmost panel\n",
    "        if j == 0:\n",
    "            ax[i, j].set_ylabel(ylabel[i], fontsize=22)\n",
    "# Raw string so that the backslash of the LaTeX command is not read as an escape\n",
    "f.legend(lines, ['Ground truth', 'maxELT', 'maxELT & SKSD', 'Grid search', r'$\\overline{p}_a=0.65$'],\n",
    "         bbox_to_anchor=(0.905, 0.885), fontsize=16)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Look up which indices belong to the bond, angle and dihedral groups\n",
    "mixed_transform = model.flows[-1].mixed_transform\n",
    "ic_transform = mixed_transform.ic_transform\n",
    "ncarts = mixed_transform.len_cart_inds\n",
    "permute_inv = mixed_transform.permute_inv\n",
    "bond_ind = ic_transform.bond_indices\n",
    "angle_ind = ic_transform.angle_indices\n",
    "dih_ind = ic_transform.dih_indices\n",
    "\n",
    "# Indices 60..65 are spliced in after the first 3 * ncarts - 6 entries,\n",
    "# then the result is indexed with permute_inv\n",
    "n_head = 3 * ncarts - 6\n",
    "ind_perm = np.concatenate([np.arange(n_head), np.arange(60, 66), np.arange(n_head, 60)])\n",
    "ind = ind_perm[permute_inv]\n",
    "\n",
    "for group_ind in (bond_ind, angle_ind, dih_ind):\n",
    "    print(ind[group_ind])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}