{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically reload imported modules (e.g. the local `src` package) when their source changes\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# External imports \n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms as tf\n",
    "import random\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, clear_output\n",
    "import os\n",
    "\n",
    "# Internal imports\n",
    "import sys; sys.path.insert(0, '..')\n",
    "from src import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration\n",
    "NUM_INPUT_CHANNELS = 1\n",
    "IMG_SIZE = 84  # spatial size used for the mixing-mask shape; presumably the dataset image size (TODO confirm)\n",
    "NUM_CHARS = 9\n",
    "NUM_CLASSES = 300\n",
    "N_DIMS = 1\n",
    "DIM_OUTPUT = 10  # critic classes: p (0), q (1) and DIM_OUTPUT - 2 'm' classes\n",
    "BS = 500\n",
    "NUM_EPOCHS = 2000\n",
    "SEED = 21\n",
    "LR = 1e-3\n",
    "DROPOUT = 0.20\n",
    "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Seed before building the masks so ALPHAS is reproducible should get_dim_mix_masks draw random numbers\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "ALPHAS = get_dim_mix_masks((NUM_INPUT_CHANNELS, IMG_SIZE, IMG_SIZE))\n",
    "\n",
    "SAVE_INT = 20  # save a checkpoint every SAVE_INT epochs\n",
    "ANNEAL = False  # cosine learning-rate annealing on/off\n",
    "SAME_ENC = True\n",
    "SAVE_DIR = os.path.join(os.getcwd(), '../results/omniglot_ch{}_K{}_lr{}_anneal{}_bs{}_final{}_sameenc{}'.format(NUM_CHARS, DIM_OUTPUT-2, LR, ANNEAL, BS, NUM_CLASSES, SAME_ENC))\n",
    "\n",
    "# Creating save directory (no-op if it already exists)\n",
    "os.makedirs(SAVE_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the masks returned by get_dim_mix_masks (passed to the model as `alphas`)\n",
    "ALPHAS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the mask collection; presumably (num_masks, C, H, W) - TODO confirm against get_dim_mix_masks\n",
    "ALPHAS.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed Python's, PyTorch's and NumPy's global RNGs (in that order) for reproducibility\n",
    "for seed_fn in (random.seed, torch.manual_seed, np.random.seed):\n",
    "    seed_fn(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define model. Its DIM_OUTPUT logits are trained in the loop below as a classifier over\n",
    "# p (label 0), q (label 1) and DIM_OUTPUT - 2 'm' classes.\n",
    "model = RatioCriticImageBilinearCoB(dim_input=N_DIMS, dim_output=DIM_OUTPUT, num_input_channels=NUM_INPUT_CHANNELS, dropout=DROPOUT, alphas=ALPHAS, num_classes=NUM_CLASSES, same_enc=SAME_ENC)\n",
    "\n",
    "# Define optimizer\n",
    "optim = torch.optim.Adam(model.parameters(), lr=LR)\n",
    "\n",
    "# Optional cosine LR annealing, stepped once per training iteration.\n",
    "# TODO(review): 50000 // BS presumably stands for steps per epoch, but 50000 is hard-coded\n",
    "# rather than taken from the dataset; the period should be len(train_dl) * NUM_EPOCHS.\n",
    "# verbose=True is omitted because it prints a line on every scheduler step.\n",
    "if ANNEAL:\n",
    "    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 50000 // BS * NUM_EPOCHS)\n",
    "\n",
    "# To resume from a checkpoint written by the training loop:\n",
    "# checkpoint = torch.load(os.path.join(SAVE_DIR, 'models/model_epoch260.pth'))\n",
    "# model.load_state_dict(checkpoint['model_state_dict'])\n",
    "# optim.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "# START_EPOCH = checkpoint['epoch']  # note: the training loop always starts from epoch 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No input transforms (alternative: tf.Compose([tf.ToTensor()]))\n",
    "transforms = None\n",
    "\n",
    "\n",
    "def make_paired_split(path):\n",
    "    \"\"\"Load the .npz file at `path` twice and pair the two copies.\n",
    "\n",
    "    Returns (first SpatialOmniDataset, second SpatialOmniDataset, PairedSpatialOmniDataset).\n",
    "    \"\"\"\n",
    "    first, second = (SpatialOmniDataset(path, transforms=transforms) for _ in range(2))\n",
    "    return first, second, PairedSpatialOmniDataset(first, second)\n",
    "\n",
    "\n",
    "# Define datasets: each split pairs two copies of the same file\n",
    "train_ds1, train_ds2, train_ds = make_paired_split('../data/omniglot/multiomniglot_trn_{}.npz'.format(NUM_CHARS))\n",
    "val_ds1, val_ds2, val_ds = make_paired_split('../data/omniglot/multiomniglot_val_{}.npz'.format(NUM_CHARS))\n",
    "test_ds1, test_ds2, test_ds = make_paired_split('../data/omniglot/multiomniglot_tst_{}.npz'.format(NUM_CHARS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of examples in the paired training set\n",
    "NUM_SAMPLES = len(train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define dataloaders. The test loader is shuffled too, so each periodic evaluation\n",
    "# (which reads only its first batch) sees a different random batch.\n",
    "train_dl = DataLoader(dataset=train_ds, shuffle=True, batch_size=BS)\n",
    "test_dl = DataLoader(dataset=test_ds, shuffle=True, batch_size=2 * BS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize one sample from the joint (u, v), the marginal (u, v_q) and each mixed pair\n",
    "# (u, v_m) returned by model.test_forward.\n",
    "SAMPLE_IDX = 11  # index of the displayed sample within the batch\n",
    "\n",
    "\n",
    "def to_hwc(img):\n",
    "    \"\"\"Convert a channel-first (C, H, W) image tensor to an (H, W, C) numpy array for imshow.\"\"\"\n",
    "    return img.cpu().numpy().transpose(1, 2, 0)\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    (u, v, _), (_, v_q, _) = next(iter(train_dl))\n",
    "    # The training cell moves the model to DEVICE; send the inputs to wherever the\n",
    "    # model currently lives so this cell works both before and after training.\n",
    "    model_device = next(model.parameters()).device\n",
    "    u, v, v_q = u.to(model_device), v.to(model_device), v_q.to(model_device)\n",
    "\n",
    "    v_ms = model.test_forward(u, v, u, v_q)\n",
    "\n",
    "    # One row per pair: (title prefix, u images, v images)\n",
    "    rows = [('joint', u, v), ('marginal', u, v_q)]\n",
    "    rows += [('m{}'.format(k), u, v_m) for k, v_m in enumerate(v_ms, start=1)]\n",
    "\n",
    "    sample_fig, sample_axes = plt.subplots(len(rows), 2, figsize=(20, 20), squeeze=False)\n",
    "    for (name, u_imgs, v_imgs), (ax_u, ax_v) in zip(rows, sample_axes):\n",
    "        ax_u.set_title('{} - u'.format(name))\n",
    "        ax_u.imshow(to_hwc(u_imgs[SAMPLE_IDX]))\n",
    "        ax_v.set_title('{} - v'.format(name))\n",
    "        ax_v.imshow(to_hwc(v_imgs[SAMPLE_IDX]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the live training dashboard; the training cell below redraws these artists.\n",
    "fig, [[ax1, ax2, ax3, ax4], [ax5, ax6, ax7, ax8], [ax9, ax10, ax11, ax12], [ax13, ax14, ax15, ax16]] = plt.subplots(4, 4, figsize=(20, 15))\n",
    "m_axes = [ax9, ax10, ax11, ax12, ax13, ax14, ax15, ax16]  # test losses of m1..m8 (ax7 is unused)\n",
    "NUM_ITERS = NUM_EPOCHS * (NUM_SAMPLES // BS)\n",
    "\n",
    "line, = ax1.plot([0, 1], [0, 1])\n",
    "test_line, = ax4.plot([0, 1], [0, 1], label='Test Loss')\n",
    "test_line1, = ax4.plot([0, 1], [0, 1], label='Train Loss')\n",
    "test_line_p, = ax5.plot([0, 1], [0, 1], label='Test Loss P')\n",
    "test_line_q, = ax6.plot([0, 1], [0, 1], label='Test Loss Q')\n",
    "# One curve per m class, each labelled with its own index (M1..M8)\n",
    "(test_line_m1, test_line_m2, test_line_m3, test_line_m4,\n",
    " test_line_m5, test_line_m6, test_line_m7, test_line_m8) = [\n",
    "    ax.plot([0, 1], [0, 1], label='Test Loss M{}'.format(k))[0] for k, ax in enumerate(m_axes, start=1)\n",
    "]\n",
    "\n",
    "kld_line, = ax3.plot([0], [0], label='Estimated KLD')\n",
    "# Empty scatter placeholders; the training cell fills in the offsets. Starting empty avoids\n",
    "# drawing placeholder points from (and advancing) the seeded global NumPy RNG.\n",
    "scat = ax2.scatter([], [], label='Log ratio CoB (p samples)', alpha=0.9, s=10., c='r')\n",
    "scat8 = ax8.scatter([], [], label='Log ratio CoB (q samples)', alpha=0.9, s=10., c='r')\n",
    "\n",
    "ax1.set(xlabel='Iteration', ylabel='Train Loss', xlim=[0, NUM_ITERS], ylim=[0, 10])\n",
    "\n",
    "ax2.set(xlabel='Data Point #', ylabel='Log Ratio CoB (p samples)')\n",
    "ax2.legend(loc='best')\n",
    "ax8.set(xlabel='Data Point #', ylabel='Log Ratio CoB (q samples)')\n",
    "ax8.legend(loc='best')\n",
    "\n",
    "ax3.set_ylabel('KLD')\n",
    "ax3.legend(loc='best')\n",
    "\n",
    "# Loss curves recorded once per evaluation: (axis, y-label, upper y-limit)\n",
    "eval_axes = [(ax4, 'Test Loss', 10), (ax5, 'Test Loss p', 5), (ax6, 'Test Loss q', 5)]\n",
    "eval_axes += [(ax, 'Test Loss m{}'.format(k), 5) for k, ax in enumerate(m_axes, start=1)]\n",
    "for ax, ylabel, ymax in eval_axes:\n",
    "    ax.set(xlabel='Evaluation #', ylabel=ylabel, xlim=[0, NUM_ITERS], ylim=[0, ymax])\n",
    "    ax.legend(loc='best')\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Metric histories appended to by the training cell\n",
    "loss_store = []\n",
    "kld_store = []\n",
    "test_loss_store = []\n",
    "test_loss_store1 = []\n",
    "test_loss_store_p = []\n",
    "test_loss_store_q = []\n",
    "test_loss_store_m1 = []\n",
    "test_loss_store_m2 = []\n",
    "test_loss_store_m3 = []\n",
    "test_loss_store_m4 = []\n",
    "test_loss_store_m5 = []\n",
    "test_loss_store_m6 = []\n",
    "test_loss_store_m7 = []\n",
    "test_loss_store_m8 = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the ratio critic as a DIM_OUTPUT-way classifier: joint pairs (u, v) -> class 0,\n",
    "# marginal pairs (u, v_q) -> class 1 and the k-th 'm' output -> class k + 2.\n",
    "EVAL_INT = 100  # evaluate on one test batch every EVAL_INT iterations\n",
    "\n",
    "loss_crit = torch.nn.CrossEntropyLoss()\n",
    "model_save_dir = os.path.join(SAVE_DIR, 'models')\n",
    "figs_save_dir = os.path.join(SAVE_DIR, 'figs')\n",
    "os.makedirs(model_save_dir, exist_ok=True)\n",
    "os.makedirs(figs_save_dir, exist_ok=True)\n",
    "\n",
    "# Dashboard curves, axes and histories of the per-class test losses, ordered p, q, m1..m8\n",
    "# to match the order of the terms returned by _class_losses.\n",
    "class_lines = [test_line_p, test_line_q, test_line_m1, test_line_m2, test_line_m3,\n",
    "               test_line_m4, test_line_m5, test_line_m6, test_line_m7, test_line_m8]\n",
    "class_axes = [ax5, ax6, ax9, ax10, ax11, ax12, ax13, ax14, ax15, ax16]\n",
    "class_stores = [test_loss_store_p, test_loss_store_q, test_loss_store_m1, test_loss_store_m2,\n",
    "                test_loss_store_m3, test_loss_store_m4, test_loss_store_m5, test_loss_store_m6,\n",
    "                test_loss_store_m7, test_loss_store_m8]\n",
    "\n",
    "\n",
    "def _class_losses(logP, logQ, logMs):\n",
    "    \"\"\"Return the cross-entropy terms [p, q, m1, m2, ...] for one batch of critic outputs.\n",
    "\n",
    "    Every row of logP is labelled 0, every row of logQ 1 and every row of logMs[k] k + 2.\n",
    "    \"\"\"\n",
    "    batch_size = logP.shape[0]\n",
    "    return [loss_crit(logits, torch.full((batch_size,), label, dtype=torch.long, device=logits.device))\n",
    "            for label, logits in enumerate([logP, logQ, *logMs])]\n",
    "\n",
    "\n",
    "def _save_checkpoint(path, epoch, kl_est):\n",
    "    \"\"\"Save model/optimizer state, the epoch and the latest KL estimate (None before any evaluation).\"\"\"\n",
    "    torch.save({\n",
    "        'epoch': epoch,\n",
    "        'model_state_dict': model.state_dict(),\n",
    "        'optimizer_state_dict': optim.state_dict(),\n",
    "        'kl_est': kl_est,\n",
    "    }, path)\n",
    "\n",
    "\n",
    "def _evaluate_and_plot(train_loss):\n",
    "    \"\"\"Evaluate on one random test batch, record the metrics and redraw the dashboard.\n",
    "\n",
    "    Args:\n",
    "        train_loss: loss of the most recent training batch (plotted as 'Train Loss').\n",
    "\n",
    "    Returns:\n",
    "        (total test loss, KL estimate = mean of logP[:, 0] - logP[:, 1]) as Python floats.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        (u, v, _), (_, v_q, _) = next(iter(test_dl))\n",
    "        u, v, v_q = u.to(DEVICE), v.to(DEVICE), v_q.to(DEVICE)\n",
    "        logP, logQ, *logMs = model(u, v, None, v_q)\n",
    "        terms = _class_losses(logP, logQ, logMs)\n",
    "        test_loss = sum(terms).item()\n",
    "        ratio_p = logP[:, 0] - logP[:, 1]\n",
    "        kl_est = ratio_p.mean().item()\n",
    "        log_ratio_p = ratio_p.cpu().numpy()\n",
    "        log_ratio_q = (logQ[:, 0] - logQ[:, 1]).cpu().numpy()\n",
    "    model.train()\n",
    "\n",
    "    kld_store.append(kl_est)\n",
    "    test_loss_store.append(test_loss)\n",
    "    test_loss_store1.append(train_loss)\n",
    "    for store, term in zip(class_stores, terms):\n",
    "        store.append(term.item())\n",
    "\n",
    "    line.set_data(range(len(loss_store)), loss_store)\n",
    "    ax1.set_xlim(0, len(loss_store))\n",
    "\n",
    "    for scatter, ax, values in ((scat, ax2, log_ratio_p), (scat8, ax8, log_ratio_q)):\n",
    "        scatter.set_offsets(np.column_stack([np.arange(len(values)), values]))\n",
    "        ax.set_xlim(0, len(values))\n",
    "        ax.set_ylim(-5, 20)\n",
    "\n",
    "    kld_line.set_data(range(len(kld_store)), kld_store)\n",
    "    ax3.set_xlim(0, len(kld_store))\n",
    "    ax3.set_ylim(min(kld_store), 50)\n",
    "\n",
    "    test_line.set_data(range(len(test_loss_store)), test_loss_store)\n",
    "    test_line1.set_data(range(len(test_loss_store1)), test_loss_store1)\n",
    "    ax4.set_xlim(0, len(test_loss_store))\n",
    "\n",
    "    for curve, ax, store in zip(class_lines, class_axes, class_stores):\n",
    "        curve.set_data(range(len(store)), store)\n",
    "        ax.set_xlim(0, len(store))\n",
    "\n",
    "    clear_output(wait=True)\n",
    "    display(fig)\n",
    "    print('Estimated KL: {}'.format(kl_est))\n",
    "    return test_loss, kl_est\n",
    "\n",
    "\n",
    "model = model.to(DEVICE)\n",
    "model.train()\n",
    "\n",
    "if ANNEAL:\n",
    "    # Cosine period = total number of optimizer steps, taken from the actual loader length\n",
    "    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, len(train_dl) * NUM_EPOCHS)\n",
    "\n",
    "i = 0\n",
    "best_test_loss = float('inf')\n",
    "kl_from_cob = None  # latest KL estimate; None until the first evaluation\n",
    "for epoch in trange(0, NUM_EPOCHS):\n",
    "    for (u, v, _), (_, v_q, _) in train_dl:\n",
    "        i += 1\n",
    "        u, v, v_q = u.to(DEVICE), v.to(DEVICE), v_q.to(DEVICE)\n",
    "\n",
    "        optim.zero_grad()\n",
    "        logP, logQ, *logMs = model(u, v, None, v_q)\n",
    "        loss = sum(_class_losses(logP, logQ, logMs))\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        loss_store.append(loss.item())\n",
    "        if ANNEAL:\n",
    "            scheduler.step()\n",
    "\n",
    "        if i % EVAL_INT == 0:\n",
    "            test_loss, kl_from_cob = _evaluate_and_plot(loss.item())\n",
    "            # Save the best checkpoint right after evaluating, so the stored weights are the\n",
    "            # ones that achieved test_loss (not those at the end of the epoch).\n",
    "            # NOTE(review): model selection uses the test split; val_ds is built but appears\n",
    "            # unused - consider evaluating on a validation loader instead.\n",
    "            if test_loss < best_test_loss:\n",
    "                best_test_loss = test_loss\n",
    "                print('Saving best checkpoint')\n",
    "                _save_checkpoint(os.path.join(model_save_dir, 'model_best.pth'), epoch, kl_from_cob)\n",
    "\n",
    "    # Periodic checkpoint + dashboard snapshot; kl_est is None if no evaluation has run yet\n",
    "    if epoch % SAVE_INT == 0:\n",
    "        print('Saving checkpoint')\n",
    "        _save_checkpoint(os.path.join(model_save_dir, 'model_epoch{}.pth'.format(epoch)), epoch, kl_from_cob)\n",
    "        fig.savefig(os.path.join(figs_save_dir, 'figs_epoch{}.png'.format(epoch)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variant: mixing p and q (rather than u and v), with fixed dimension-wise mixing and sampling for m"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch1.8",
   "language": "python",
   "name": "torch1.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
