{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Pick up edits to imported modules (e.g. the local `src` package) without restarting the kernel\n",
     "%load_ext autoreload\n",
     "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# External imports \n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import random\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, clear_output\n",
    "\n",
    "# Internal imports\n",
    "import sys; sys.path.insert(0, '..')\n",
    "from src import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_DIMS = 1\n",
    "NUM_SAMPLES = 100000\n",
    "BS = 500\n",
    "NUM_EPOCHS = 400\n",
    "SEED = 10\n",
    "LR = 1e-4\n",
    "DROPOUT = 0.20\n",
    "DEVICE = 'cuda:1' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "\n",
    "# Break by changing num datapoints, scales, means, or to 2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting seed for reproducibility\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define model\n",
    "# model_tre1 = RatioCritic1D(dim_input=N_DIMS, dim_output=1, dropout=DROPOUT, tre=True) # CHANGE 1D MODEL FOR TRE\n",
    "model_tre1 = RatioCriticNN1D(dim_input=N_DIMS, dim_output=1, dropout=DROPOUT)\n",
    "# model_tre2 = RatioCritic1D(dim_input=N_DIMS, dim_output=1, dropout=DROPOUT, tre=True)\n",
    "# model_tre3 = RatioCritic1D(dim_input=N_DIMS, dim_output=1, dropout=DROPOUT, tre=True)\n",
    "\n",
    "\n",
    "# model.apply(weights_init)\n",
    "\n",
    "# Define optimizer\n",
    "optim_tre1 = torch.optim.Adam(model_tre1.parameters(), lr=LR)\n",
    "# optim_tre2 = torch.optim.Adam(model_tre2.parameters(), lr=LR)\n",
    "# optim_tre3 = torch.optim.Adam(model_tre3.parameters(), lr=LR)\n",
    "\n",
    "# Define distributions\n",
    "p, q, m = get_dists_1d_tre(mu1=-2., mu2=2., mu3=0, scale_p=0.9, scale_q=1.1, alphas=[0.33, 0.66])\n",
    "\n",
    "# -5, 5, m_var=3.0\n",
    "# -10, 10, m_var=3.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p_samples = p.sample([1000])\n",
    "q_samples = q.sample([1000])\n",
    "print(p_samples.shape)\n",
    "\n",
    "plt.hist(p_samples.numpy(), density=True, histtype='stepfilled')\n",
    "plt.hist(q_samples.numpy(), density=True, histtype='stepfilled')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Inspect the waymark distribution(s) m returned by get_dists_1d_tre (indexed as m[0], m[1] later)\n",
     "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# KL(p || q) between the two endpoint distributions, via torch.distributions\n",
     "torch.distributions.kl.kl_divergence(p, q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Define datasets. Both are built with num_samples=NUM_SAMPLES; the training loop's validation\n",
     "# step only scores one shuffled batch (BS samples) of test_ds at a time.\n",
     "train_ds = DistDataset2Waymark(p, q, m, num_samples=NUM_SAMPLES)\n",
     "test_ds = DistDataset2Waymark(p, q, m, num_samples=NUM_SAMPLES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define dataloader\n",
    "train_dl = DataLoader(train_ds, batch_size=BS, shuffle=True)\n",
    "test_dl = DataLoader(test_ds, batch_size=BS, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up viz\n",
    "fig, [[ax1,ax2,ax3], [ax4, ax5, ax6], [ax7, ax8, ax9]] = plt.subplots(3, 3,figsize=(15,12))\n",
    "\n",
    "line, = ax1.plot([0,1],[0,1])\n",
    "x, y = np.random.random((2, 500))\n",
    "scat1 = ax2.scatter(x,y,label='True p/q',alpha=0.9,s=10.,c='b')\n",
    "scat2 = ax2.scatter(x,y,label='CoB p/q',alpha=0.9,s=10.,c='r')\n",
    "test_line, = ax3.plot([0,1],[0,1])\n",
    "\n",
    "ax1.set_xlabel(\"Iteration\")\n",
    "ax1.set_ylabel(\"Train Loss\")\n",
    "ax1.set_xlim([0,NUM_EPOCHS*NUM_SAMPLES//BS])\n",
    "ax1.set_ylim([0,10])\n",
    "\n",
    "ax2.set_xlabel(\"Samples\")\n",
    "ax2.set_ylabel(\"Log Ratio\")\n",
    "ax2.legend(loc='best')\n",
    "ax2.set_xlim([-6,10])\n",
    "ax2.set_ylim([-1500,5000])\n",
    "\n",
    "ax3.set_xlabel(\"Iteration\")\n",
    "ax3.set_ylabel(\"Test Loss\")\n",
    "ax3.set_xlim([0,NUM_EPOCHS*NUM_SAMPLES//BS])\n",
    "ax3.set_ylim([0,10])\n",
    "\n",
    "line2, = ax4.plot([0,1],[0,1])\n",
    "x, y = np.random.random((2, 500))\n",
    "scat3 = ax5.scatter(x,y,label='True p/q',alpha=0.9,s=10.,c='b')\n",
    "scat4 = ax5.scatter(x,y,label='TRE p/q',alpha=0.9,s=10.,c='r')\n",
    "test_line2, = ax6.plot([0,1],[0,1])\n",
    "\n",
    "ax4.set_xlabel(\"Iteration\")\n",
    "ax4.set_ylabel(\"Train Loss\")\n",
    "ax4.set_xlim([0,NUM_EPOCHS*NUM_SAMPLES//BS])\n",
    "ax4.set_ylim([0,10])\n",
    "\n",
    "ax5.set_xlabel(\"Samples\")\n",
    "ax5.set_ylabel(\"Log Ratio\")\n",
    "ax5.legend(loc='best')\n",
    "ax5.set_xlim([-6,10])\n",
    "ax5.set_ylim([-1500,5000])\n",
    "\n",
    "ax6.set_xlabel(\"Iteration\")\n",
    "ax6.set_ylabel(\"Test Loss\")\n",
    "ax6.set_xlim([0,NUM_EPOCHS*NUM_SAMPLES//BS])\n",
    "ax6.set_ylim([0,10])\n",
    "\n",
    "\n",
    "line3, = ax7.plot([0,1],[0,1])\n",
    "scat5 = ax8.scatter(x,y,label='True p/q',alpha=0.9,s=10.,c='b')\n",
    "scat6 = ax8.scatter(x,y,label='TRE p/q',alpha=0.9,s=10.,c='r')\n",
    "test_line3, = ax9.plot([0,1],[0,1])\n",
    "\n",
    "ax7.set_xlabel(\"Iteration\")\n",
    "ax7.set_ylabel(\"Train Loss\")\n",
    "ax7.set_xlim([0,NUM_EPOCHS*NUM_SAMPLES//BS])\n",
    "ax7.set_ylim([0,10])\n",
    "\n",
    "ax8.set_xlabel(\"Samples\")\n",
    "ax8.set_ylabel(\"Log Ratio\")\n",
    "ax8.legend(loc='best')\n",
    "ax8.set_xlim([-6,10])\n",
    "ax8.set_ylim([-1500,5000])\n",
    "\n",
    "ax9.set_xlabel(\"Iteration\")\n",
    "ax9.set_ylabel(\"Test Loss\")\n",
    "ax9.set_xlim([0,NUM_EPOCHS*NUM_SAMPLES//BS])\n",
    "ax9.set_ylim([0,10])\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "loss_store = []\n",
    "test_loss_store = []\n",
    "\n",
    "loss_store_cob_sep = []\n",
    "test_loss_store_cob_sep = []\n",
    "\n",
    "loss_store_tre = []\n",
    "test_loss_store_tre = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## CONFIRM q_list_test in validation/visualization in Akash's code\n",
    "\n",
    "model_tre1.train()\n",
    "# model_tre2.train()\n",
    "# model_tre3.train()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model_tre1 = model_tre1.to(DEVICE)\n",
    "#     model_tre2 = model_tre2.to(DEVICE)\n",
    "#     model_tre3 = model_tre3.to(DEVICE)\n",
    "    \n",
    "i = 0\n",
    "loss_crit_tre = torch.nn.functional.binary_cross_entropy_with_logits\n",
    "\n",
    "\n",
    "for epoch in trange(NUM_EPOCHS):\n",
    "    for p_batch, q_batch, m_batch1, m_batch2 in iter(train_dl):\n",
    "        model_tre1.train()\n",
    "        \n",
    "        i += 1\n",
    "        optim_tre1.zero_grad()\n",
    "        \n",
    "        # CUDA\n",
    "        if torch.cuda.is_available():\n",
    "            p_batch, q_batch, m_batch1, m_batch2 = p_batch.unsqueeze(1).to(DEVICE), q_batch.unsqueeze(1).to(DEVICE), m_batch1.unsqueeze(1).to(DEVICE), m_batch2.unsqueeze(1).to(DEVICE)\n",
    "        \n",
    "        # TRE\n",
    "        logP1 = model_tre1(p_batch).squeeze()\n",
    "        logQ1 = model_tre1(q_batch).squeeze()\n",
    "        \n",
    "        zero_labels = torch.empty(p_batch.shape[0], device=DEVICE).fill_(1)\n",
    "        one_labels = torch.empty(p_batch.shape[0], device=DEVICE).fill_(0)\n",
    "        \n",
    "        loss_tre1 = loss_crit_tre(logP1, zero_labels.clone()) + loss_crit_tre(logQ1, one_labels.clone())\n",
    "        loss_tre1.backward()\n",
    "        optim_tre1.step()\n",
    "        loss_store.append(loss_tre1.item())\n",
    "        \n",
    "\n",
    "        # Validation/Test\n",
    "        if i % 50 == 0:\n",
    "            model_tre1.eval()\n",
    "#             model_tre2.eval()\n",
    "#             model_tre3.eval()\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                for p_batch, q_batch, m_batch1, m_batch2 in iter(test_dl):\n",
    "                    gt_log_ratio_p_q, _, true_kl_p_q = get_gt_ratio_kl(p, q, p_batch, calc_true_kl=True)\n",
    "\n",
    "                    if torch.cuda.is_available():\n",
    "                        p_batch, q_batch, m_batch1, m_batch2 = p_batch.unsqueeze(1).to(DEVICE), q_batch.unsqueeze(1).to(DEVICE), m_batch1.unsqueeze(1).to(DEVICE), m_batch2.unsqueeze(1).to(DEVICE)\n",
    "                    \n",
    "                    # TRE \n",
    "                    logP1 = model_tre1(p_batch).squeeze()\n",
    "                    logQ1 = model_tre1(q_batch).squeeze()\n",
    "\n",
    "                    zero_labels = torch.empty(p_batch.shape[0], device=DEVICE).fill_(0)\n",
    "                    one_labels = torch.empty(p_batch.shape[0], device=DEVICE).fill_(1)\n",
    "\n",
    "                    test_loss_tre = loss_crit_tre(logP1, zero_labels.clone()) + loss_crit_tre(logQ1, one_labels.clone())\n",
    "                    \n",
    "                    log_ratio_p_q_from_tre = logP1 \n",
    "                    \n",
    "                    # Visualize\n",
    "                    line.set_data(range(len(loss_store)), loss_store)\n",
    "                    ax1.set_xlim( 0, len(loss_store) )\n",
    "                    \n",
    "                    scat1.set_offsets(np.vstack([p_batch.cpu().squeeze(), gt_log_ratio_p_q.cpu().detach()]).T)\n",
    "                    scat2.set_offsets(np.vstack([p_batch.cpu().squeeze(), log_ratio_p_q_from_tre.cpu().detach()]).T)\n",
    "\n",
    "                    ax2.set_xlim( -5., 5. )\n",
    "                    ax2.set_ylim( -50, 50)\n",
    "            \n",
    "                    test_loss_store.append(test_loss_tre.item())\n",
    "                    test_line.set_data(range(len(test_loss_store)), test_loss_store)\n",
    "                    ax3.set_xlim( 0, len(test_loss_store) )\n",
    "                    \n",
    "                    clear_output(wait=True)\n",
    "                    display(fig)\n",
    "                    break\n",
    "\n",
    "            model_tre1.train()\n",
    "#             model_tre2.train()\n",
    "#             model_tre3.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "                    print('iteration: ',i)\n",
    "                    print('KLD: ', true_kl_p_q)\n",
    "                    print('CoB: ', kl_from_cob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# DISABLED: single-panel CoB plot kept for reference. It references names that no other\n",
     "# cell in this notebook defines (m_batch, log_ratio_p_q, log_ratio_p_q_from_cob, kl_from_cob),\n",
     "# so it needs updating before it can be re-enabled.\n",
     "# fig, ax2 = plt.subplots(1, 1,figsize=(6,4))\n",
     "\n",
     "# x, y = np.random.random((2, 500))\n",
     "# scat1 = ax2.scatter(x,y,label='True Log p/q, KL = '+str(np.around(true_kl_p_q.item(),2)),alpha=0.9,s=10.,c='b')\n",
     "# scat2 = ax2.scatter(x,y,label='CoB Log p/q, KL = '+str(np.around(kl_from_cob.item(),2)),alpha=0.9,s=10.,c='r')\n",
     "\n",
     "# scat1.set_offsets(np.vstack([m_batch.cpu().squeeze(), log_ratio_p_q.cpu().detach()]).T)\n",
     "# scat2.set_offsets(np.vstack([m_batch.cpu().squeeze(), log_ratio_p_q_from_cob.cpu().detach()]).T)                    \n",
     "\n",
     "# ax2.set_xlabel(\"Samples\")\n",
     "# ax2.set_ylabel(\"Log Ratio\")\n",
     "# ax2.legend(loc='best')\n",
     "# ax2.set_xlim([-25,25])\n",
     "# ax2.set_ylim([-100000,100000])\n",
     "\n",
     "# plt.tight_layout()\n",
     "# plt.savefig('../plots/cob_mu20.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gt_log_ratios(p, q, m1, m2, samples):\n",
    "    p_lp = p.log_prob(samples.cpu())\n",
    "    q_lp = q.log_prob(samples.cpu())\n",
    "    m1_lp = m1.log_prob(samples.cpu())\n",
    "    m2_lp = m2.log_prob(samples.cpu())\n",
    "    \n",
    "    return p_lp - q_lp, p_lp - m1_lp, m1_lp - m2_lp, m2_lp - q_lp # log p/q, log p/m1, log m1/m2, log m2/q\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#######\n",
    "with torch.no_grad():\n",
    "    model_tre1.eval()\n",
    "    model_tre2.eval()\n",
    "    model_tre3.eval()\n",
    "    for p_batch, q_batch, m_batch1, m_batch2 in iter(test_dl):\n",
    "        if torch.cuda.is_available():\n",
    "            p_batch, q_batch, m_batch1, m_batch2 = p_batch.unsqueeze(1).to(DEVICE), q_batch.unsqueeze(1).to(DEVICE), m_batch1.unsqueeze(1).to(DEVICE), m_batch2.unsqueeze(1).to(DEVICE)\n",
    "\n",
    "        p_batch = m_batch2\n",
    "\n",
    "        logP_M1 = model_tre1(p_batch).cpu()\n",
    "        logM1_M2 = model_tre1(p_batch).cpu()\n",
    "        logM2_Q = model_tre3(p_batch).cpu()\n",
    "\n",
    "        log_ratio_p_q_from_tre = logP_M1 + logM1_M2 + logM2_Q\n",
    "        log_ratio_p_m1_from_tre = logP_M1\n",
    "        log_ratio_m1_m2_from_tre = logM1_M2\n",
    "        log_ratio_m2_q_from_tre = logM2_Q  \n",
    "        \n",
    "        true_log_ratio_p_q, true_log_ratio_p_m1, true_log_ratio_m1_m2, true_log_ratio_m2_q = get_gt_log_ratios(p, q, m[0], m[1], p_batch)\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up viz\n",
    "fig, [ax1,ax2,ax3, ax4] = plt.subplots(1, 4,figsize=(8,3))\n",
    "p_batch = p_batch.cpu()\n",
    "\n",
    "scat1 = ax1.scatter(p_batch,true_log_ratio_p_q,label='True log p/q',alpha=0.9,s=1.,c='b')\n",
    "scat2 = ax1.scatter(p_batch,log_ratio_p_q_from_tre,label='Binary log p/q',alpha=0.9,s=1.,c='r')\n",
    "\n",
    "scat3 = ax2.scatter(p_batch,true_log_ratio_p_m1,label='True log p/m1',alpha=0.9,s=1.,c='b')\n",
    "scat4 = ax2.scatter(p_batch,log_ratio_p_m1_from_tre,label='Binary log p/m1',alpha=0.9,s=1.,c='r')\n",
    "\n",
    "scat5 = ax3.scatter(p_batch,true_log_ratio_m1_m2,label='True log m1/m2',alpha=0.9,s=1.,c='b')\n",
    "scat6 = ax3.scatter(p_batch,log_ratio_m1_m2_from_tre,label='Binary log m1/m2',alpha=0.9,s=1.,c='r')\n",
    "\n",
    "scat7 = ax4.scatter(p_batch,true_log_ratio_m2_q,label='True log m2/q',alpha=0.9,s=1.,c='b')\n",
    "scat8 = ax4.scatter(p_batch,log_ratio_m2_q_from_tre,label='Binary log m2/q',alpha=0.9,s=1.,c='r')\n",
    "\n",
    "ylim = [-10, 10]\n",
    "xlim = [-5, 5]\n",
    "ax1.set_ylabel(\"Log Ratio\")\n",
    "ax1.legend(loc='best')\n",
    "ax1.set_xlim(xlim)\n",
    "ax1.set_ylim(ylim)\n",
    "\n",
    "ax2.set_ylabel(\"Log Ratio\")\n",
    "ax2.legend(loc='best')\n",
    "ax2.set_xlim(xlim)\n",
    "ax2.set_ylim(ylim)\n",
    "\n",
    "ax3.set_ylabel(\"Log Ratio\")\n",
    "ax3.legend(loc='best')\n",
    "ax3.set_xlim(xlim)\n",
    "ax3.set_ylim(ylim)\n",
    "\n",
    "ax4.set_ylabel(\"Log Ratio\")\n",
    "ax4.legend(loc='upper right')\n",
    "ax4.set_xlim(xlim)\n",
    "ax4.set_ylim(ylim)\n",
    "\n",
    "\n",
    "# ax1.get_xaxis().set_visible(False)\n",
    "# ax2.get_xaxis().set_visible(False)\n",
    "# ax3.get_xaxis().set_visible(False)\n",
    "# ax4.get_xaxis().set_visible(False)\n",
    "ax2.get_yaxis().set_visible(False)\n",
    "ax3.get_yaxis().set_visible(False)\n",
    "ax4.get_yaxis().set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(wspace=0, hspace=0.0)\n",
    "\n",
    "plt.savefig('cob_vs_tre_2waymark_from_m2.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Display the waymark distribution(s) m once more (indexed as m[0], m[1] in the evaluation cell)\n",
     "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sr",
   "language": "python",
   "name": "sr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
