{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# External imports \n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms as tf\n",
    "import random\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, clear_output\n",
    "\n",
    "# Internal imports\n",
    "import sys; sys.path.insert(0, '..')\n",
    "from src import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_img 3\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters for backbone/encoder\n",
    "NUM_INPUT_CHANNELS = 1\n",
    "IMG_SIZE = 84  # presumably the (square) image side length -- confirm against the dataset\n",
    "NUM_CHARS = 9\n",
    "NUM_CLASSES = 300\n",
    "N_DIMS = 1\n",
    "DIM_OUTPUT = 10\n",
    "DROPOUT = 0.20\n",
    "\n",
    "# Hyperparameters for linear classifier training\n",
    "BS = 500\n",
    "NUM_EPOCHS = 100\n",
    "SEED = 21\n",
    "LR = 1e-3\n",
    "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "MAX_NUM_CLASSES = 55  # shared logit width used by mask_extra_dims / one-hot encoding\n",
    "\n",
    "# Mixing masks handed to the model; the shape tuple is presumably (channels, height, width) -- TODO confirm\n",
    "ALPHAS = get_dim_mix_masks((NUM_INPUT_CHANNELS, IMG_SIZE, IMG_SIZE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every random number generator in use (Python, PyTorch, NumPy) so runs are reproducible\n",
    "for seed_rng in (random.seed, torch.manual_seed, np.random.seed):\n",
    "    seed_rng(SEED)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Alphas set to tensor([[[[[0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           ...,\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "         [[[1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           ...,\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.]]]],\n",
      "\n",
      "\n",
      "\n",
      "        [[[[0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           ...,\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "         [[[1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           ...,\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.]]]],\n",
      "\n",
      "\n",
      "\n",
      "        [[[[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           ...,\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "         [[[1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           ...,\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.]]]],\n",
      "\n",
      "\n",
      "\n",
      "        ...,\n",
      "\n",
      "\n",
      "\n",
      "        [[[[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           ...,\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "         [[[1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           ...,\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.]]]],\n",
      "\n",
      "\n",
      "\n",
      "        [[[[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           ...,\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "         [[[1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           ...,\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.]]]],\n",
      "\n",
      "\n",
      "\n",
      "        [[[[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "           ...,\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.],\n",
      "           [0., 0., 0.,  ..., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "         [[[1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           [1., 1., 1.,  ..., 1., 1., 1.],\n",
      "           ...,\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.],\n",
      "           [1., 1., 1.,  ..., 0., 0., 0.]]]]])\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '../results/omniglot_ch9_lr0.001_annealFalse_bs500_final300_sameencTrue/models/model_best.pth'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-8398e3f7d276>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;31m# Load model weights\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mcheckpoint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'../results/omniglot_ch9_lr0.001_annealFalse_bs500_final300_sameencTrue/models/model_best.pth'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcheckpoint\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'model_state_dict'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcheckpoint\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'kl_est'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m    577\u001b[0m         \u001b[0mpickle_load_args\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'encoding'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    578\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 579\u001b[0;31m     \u001b[0;32mwith\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    580\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0m_is_zipfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    581\u001b[0m             \u001b[0;31m# The zipfile reader is going to advance the current file position.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m    228\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    229\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0m_is_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    231\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    232\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;34m'w'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/ananconda3/envs/torch1.8/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m    209\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_opener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    210\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 211\u001b[0;31m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_open_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    212\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    213\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../results/omniglot_ch9_lr0.001_annealFalse_bs500_final300_sameencTrue/models/model_best.pth'"
     ]
    }
   ],
   "source": [
    "# Define model\n",
    "model = RatioCriticImageBilinearCoB(dim_input=N_DIMS, dim_output=DIM_OUTPUT, num_input_channels=1, dropout=DROPOUT, alphas=ALPHAS, num_classes=NUM_CLASSES)\n",
    "\n",
    "# Load model weights\n",
    "checkpoint = torch.load('../results/omniglot_ch9_K8_lr0.001_annealFalse_bs500_final300_sameencTrue/models/model_best.pth')\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "print(checkpoint['kl_est'])\n",
    "# Define backbone from model\n",
    "backbone = model.g"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define transforms (none are applied; tf.Compose([tf.ToTensor()]) is the commented-out alternative)\n",
    "transforms = None\n",
    "\n",
    "# Build one SpatialOmniDataset per split, in train / val / test order\n",
    "train_ds, val_ds, test_ds = (\n",
    "    SpatialOmniDataset(path_template.format(NUM_CHARS), transforms=transforms)\n",
    "    for path_template in (\n",
    "        '../data/omniglot/multiomniglot_trn_{}.npz',\n",
    "        '../data/omniglot/multiomniglot_val_{}.npz',\n",
    "        '../data/omniglot/multiomniglot_tst_{}.npz',\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of training examples\n",
    "NUM_SAMPLES = len(train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define dataloaders: only the training split is reshuffled each epoch; the test split uses doubled batches\n",
    "train_dl = DataLoader(dataset=train_ds, batch_size=BS, shuffle=True)\n",
    "val_dl = DataLoader(dataset=val_ds, batch_size=BS, shuffle=False)\n",
    "test_dl = DataLoader(dataset=test_ds, batch_size=2 * BS, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For masking out logits / over-parametrization:\n",
    "# count the distinct class ids appearing in each label column (one column per classification problem)\n",
    "labels = train_ds.labels\n",
    "label_shape = labels.shape\n",
    "num_classes_per_problem = np.array([\n",
    "    len(np.unique(labels[:, col])) for col in range(label_shape[1])\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct classes in each classification problem (label column)\n",
    "num_classes_per_problem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One linear classifier per classification problem, each sized to that problem's class count.\n",
    "# Classifiers are moved to DEVICE *before* their optimizers are built, as the torch.optim docs recommend.\n",
    "clfs = [torch.nn.Linear(backbone.fc_out_features, int(out_dim)).to(DEVICE) for out_dim in num_classes_per_problem]\n",
    "\n",
    "# One Adam optimizer per classifier\n",
    "optims = [torch.optim.Adam(clf.parameters(), lr=LR) for clf in clfs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up viz: training curves are logged once per iteration, validation curves once per epoch\n",
    "fig, [ax1, ax2, ax3, ax4] = plt.subplots(1, 4, figsize=(20, 5))\n",
    "\n",
    "# Placeholder data; the training loops replace it via set_data\n",
    "train_loss_plt, = ax1.plot([0, 1], [0, 1], label='Train Loss')\n",
    "test_loss_plt, = ax2.plot([0, 1], [0, 1], label='Validation Loss')\n",
    "train_acc_plt, = ax3.plot([0, 1], [0, 1], label='Train Accuracy')\n",
    "test_acc_plt, = ax4.plot([0, 1], [0, 1], label='Validation Accuracy')\n",
    "\n",
    "# len(train_dl) counts the final partial batch, unlike NUM_SAMPLES // BS\n",
    "n_train_iters = NUM_EPOCHS * len(train_dl)\n",
    "for ax, x_label, y_label, x_max, y_max in [\n",
    "    (ax1, 'Iteration', 'Train Loss', n_train_iters, 20),\n",
    "    (ax2, 'Epoch', 'Validation Loss', NUM_EPOCHS, 20),\n",
    "    (ax3, 'Iteration', 'Train Accuracy', n_train_iters, 1),\n",
    "    (ax4, 'Epoch', 'Validation Accuracy', NUM_EPOCHS, 1),\n",
    "]:\n",
    "    ax.set_xlabel(x_label)\n",
    "    ax.set_ylabel(y_label)\n",
    "    ax.set_xlim([0, x_max])\n",
    "    ax.set_ylim([0, y_max])\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def models_train(models):\n",
    "    \"\"\"Put every module in `models` into training mode.\"\"\"\n",
    "    for m in models:\n",
    "        m.train()\n",
    "\n",
    "def models_eval(models):\n",
    "    \"\"\"Put every module in `models` into evaluation mode.\"\"\"\n",
    "    for m in models:\n",
    "        m.eval()\n",
    "\n",
    "# Training functions\n",
    "def accuracy_fn1(log_probs, true_labels):\n",
    "    \"\"\"Mean accuracy over several classification problems.\n",
    "\n",
    "    `log_probs` is a sequence of per-problem score tensors (only the argmax over the last dim is used,\n",
    "    so raw logits work too) and `true_labels` the matching sequence of integer target tensors.\n",
    "    Returns a 0-dim tensor: per-problem accuracies averaged with equal weight per problem.\n",
    "    \"\"\"\n",
    "    per_problem_acc = [\n",
    "        torch.sum(torch.argmax(scores, dim=-1) == targets) / float(targets.numel())\n",
    "        for scores, targets in zip(log_probs, true_labels)\n",
    "    ]\n",
    "    return sum(per_problem_acc, 0.0) / float(len(log_probs))\n",
    "\n",
    "def accuracy_fn(log_probs, true_labels):\n",
    "    \"\"\"Accuracy of a single shared-width output head, masked per problem with `mask_extra_dims`.\n",
    "\n",
    "    Returns the fraction of entries in `true_labels` matched by the argmax over each problem's valid classes.\n",
    "    \"\"\"\n",
    "    masked_log_probs = log_probs\n",
    "    if NUM_CHARS > 1:\n",
    "        masked_log_probs = mask_extra_dims(log_probs, zero_mask_output=False)\n",
    "\n",
    "    predictions = torch.argmax(masked_log_probs, dim=-1)  # (n, num_classification_problems)\n",
    "    return torch.sum(predictions == true_labels) / torch.numel(true_labels)\n",
    "\n",
    "def nll_loss(log_probs, one_hot_labels):\n",
    "    \"\"\"Average negative log-likelihood of one-hot targets under the (masked, renormalised) scores.\n",
    "\n",
    "    Infinite entries are zeroed after masking so padding classes contribute nothing to the sum.\n",
    "    \"\"\"\n",
    "    masked_log_probs = log_probs\n",
    "    if NUM_CHARS > 1:\n",
    "        masked_log_probs = mask_extra_dims(log_probs, zero_mask_output=True)\n",
    "\n",
    "    log_likelihood = torch.sum(one_hot_labels * masked_log_probs, dim=-1)\n",
    "    return -torch.mean(log_likelihood)\n",
    "\n",
    "def mask_extra_dims(log_probs, zero_mask_output):\n",
    "    \"\"\"For each classification problem, we are outputting the same number of logits, despite the fact\n",
    "    that different problems have different number of classes. Thus, for each problem,\n",
    "    we mask out unnecessary logits (and renormalize). This leads to (increased) over-parameterisation\n",
    "    but given that the usual softmax cross-entropy loss is overparameterised to start with, it's\n",
    "    probably not a big deal.\n",
    "\n",
    "    Args:\n",
    "        log_probs: scores with MAX_NUM_CLASSES entries in the last dim, broadcast against every problem\n",
    "            (presumably shape (n, MAX_NUM_CLASSES) -- confirm against the caller).\n",
    "        zero_mask_output: if True, infinite entries (the padding) are replaced by 0 instead of staying -inf.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (n, NUM_CHARS, MAX_NUM_CLASSES) with a log-softmax over each problem's valid classes.\n",
    "    \"\"\"\n",
    "    device = log_probs.device\n",
    "    n_valid = torch.as_tensor(np.asarray(num_classes_per_problem[:NUM_CHARS]), device=device)\n",
    "    # valid[i, c] is True iff class c exists for problem i; shape (NUM_CHARS, MAX_NUM_CLASSES)\n",
    "    valid = torch.arange(MAX_NUM_CLASSES, device=device).unsqueeze(0) < n_valid.unsqueeze(1)\n",
    "\n",
    "    # Select with torch.where instead of multiplying by a 0/1 mask: 0 * inf would give NaN and\n",
    "    # poison the whole row's logsumexp.\n",
    "    neg_inf = torch.tensor(-np.inf, dtype=log_probs.dtype, device=device)\n",
    "    masked_log_probs = torch.where(valid, log_probs.unsqueeze(1), neg_inf)  # (n, num_classification_problems, max_num_classes)\n",
    "    # renormalise the non-masked entries; padding stays at -inf\n",
    "    masked_log_probs = masked_log_probs - torch.logsumexp(masked_log_probs, dim=-1, keepdim=True)\n",
    "\n",
    "    if zero_mask_output:\n",
    "        masked_log_probs = masked_log_probs.masked_fill(torch.isinf(masked_log_probs), 0.0)\n",
    "\n",
    "    return masked_log_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the linear probes with Adam on top of the frozen backbone\n",
    "backbone.eval()  # frozen: eval mode here, and its forward pass runs under no_grad below\n",
    "\n",
    "train_loss_store = []  # one entry per training iteration\n",
    "test_loss_store = []   # one entry per epoch, computed on the validation split\n",
    "train_acc_store = []   # one entry per training iteration\n",
    "test_acc_store = []    # one entry per epoch, computed on the validation split\n",
    "\n",
    "# .to(DEVICE) is a no-op when DEVICE is 'cpu'\n",
    "backbone = backbone.to(DEVICE)\n",
    "clfs = [clf.to(DEVICE) for clf in clfs]\n",
    "\n",
    "# The loss is the sum of one cross-entropy term per classification problem\n",
    "loss_crit = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "for epoch in trange(NUM_EPOCHS):\n",
    "    models_train(clfs)\n",
    "    for u, _, labels in train_dl:\n",
    "        u, labels = u.to(DEVICE), labels.to(DEVICE)\n",
    "        # Split the label grid into one target vector per classification problem\n",
    "        labels = labels.unbind(dim=-1)\n",
    "\n",
    "        for optim in optims:\n",
    "            optim.zero_grad()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            g_u = backbone(u)\n",
    "\n",
    "        outputs = [clf(g_u) for clf in clfs]\n",
    "        loss = sum(loss_crit(out, y) for out, y in zip(outputs, labels))\n",
    "        loss.backward()\n",
    "\n",
    "        for optim in optims:\n",
    "            optim.step()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            accuracy = accuracy_fn1(outputs, labels)\n",
    "        train_loss_store.append(loss.item())\n",
    "        train_acc_store.append(accuracy.item())\n",
    "\n",
    "    # Validation: batch metrics are weighted by batch size so a smaller last batch does not skew the mean\n",
    "    models_eval(clfs)\n",
    "    val_loss_sum, val_acc_sum, n_val = 0.0, 0.0, 0\n",
    "    with torch.no_grad():\n",
    "        for u, _, labels in val_dl:\n",
    "            u, labels = u.to(DEVICE), labels.to(DEVICE)\n",
    "            batch_size = labels.shape[0]\n",
    "            labels = labels.unbind(dim=-1)\n",
    "            g_u = backbone(u)\n",
    "            outputs = [clf(g_u) for clf in clfs]\n",
    "            loss = sum(loss_crit(out, y) for out, y in zip(outputs, labels))\n",
    "            accuracy = accuracy_fn1(outputs, labels)\n",
    "            val_loss_sum += loss.item() * batch_size\n",
    "            val_acc_sum += accuracy.item() * batch_size\n",
    "            n_val += batch_size\n",
    "    test_loss_store.append(val_loss_sum / n_val)\n",
    "    test_acc_store.append(val_acc_sum / n_val)\n",
    "\n",
    "    # Refresh the live plots\n",
    "    train_loss_plt.set_data(range(len(train_loss_store)), train_loss_store)\n",
    "    ax1.set_xlim(0, len(train_loss_store))\n",
    "    train_acc_plt.set_data(range(len(train_acc_store)), train_acc_store)\n",
    "    ax3.set_xlim(0, len(train_acc_store))\n",
    "    ax3.set_ylim(0, 1)\n",
    "    test_loss_plt.set_data(range(len(test_loss_store)), test_loss_store)\n",
    "    ax2.set_xlim(0, len(test_loss_store))\n",
    "    test_acc_plt.set_data(range(len(test_acc_store)), test_acc_store)\n",
    "    ax4.set_xlim(0, len(test_acc_store))\n",
    "    ax4.set_ylim(0, 1)\n",
    "\n",
    "    clear_output(wait=True)\n",
    "    display(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last per-iteration training accuracy and last per-epoch validation accuracy\n",
    "print(train_acc_store[-1], test_acc_store[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final evaluation of the trained linear probes on the held-out test split.\n",
    "# Batch losses/accuracies are weighted by batch size so a smaller final batch does not skew the averages.\n",
    "models_eval(clfs)\n",
    "test_loss_sum, test_acc_sum, n_test = 0.0, 0.0, 0\n",
    "with torch.no_grad():\n",
    "    for u, _, labels in test_dl:\n",
    "        u, labels = u.to(DEVICE), labels.to(DEVICE)\n",
    "        batch_size = labels.shape[0]\n",
    "        # Split the label grid into one target vector per classification problem\n",
    "        labels = labels.unbind(dim=-1)\n",
    "        g_u = backbone(u)\n",
    "        outputs = [clf(g_u) for clf in clfs]\n",
    "        loss = sum(loss_crit(out, y) for out, y in zip(outputs, labels))\n",
    "        accuracy = accuracy_fn1(outputs, labels)\n",
    "\n",
    "        test_loss_sum += loss.item() * batch_size\n",
    "        test_acc_sum += accuracy.item() * batch_size\n",
    "        n_test += batch_size\n",
    "\n",
    "print(test_loss_sum / n_test)\n",
    "print(test_acc_sum / n_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### LBFGS #####\n",
    "# Alternative to the Adam cell: fit a *fresh* set of per-problem linear probes with LBFGS, using the\n",
    "# same summed cross-entropy loss so the two optimisers can be compared.\n",
    "backbone.eval()\n",
    "backbone = backbone.to(DEVICE)\n",
    "\n",
    "lbfgs_clfs = [torch.nn.Linear(backbone.fc_out_features, int(out_dim)).to(DEVICE) for out_dim in num_classes_per_problem]\n",
    "# NOTE(review): LR is the Adam learning rate; LBFGS is usually run with a larger step size -- tune if needed\n",
    "lbfgs_optim = torch.optim.LBFGS([p for clf in lbfgs_clfs for p in clf.parameters()], lr=LR)\n",
    "loss_crit = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "train_loss_store = []\n",
    "test_loss_store = []\n",
    "train_acc_store = []\n",
    "test_acc_store = []\n",
    "\n",
    "for epoch in trange(NUM_EPOCHS):\n",
    "    models_train(lbfgs_clfs)\n",
    "    for u, _, labels in train_dl:\n",
    "        u, labels = u.to(DEVICE), labels.to(DEVICE)\n",
    "        labels = labels.unbind(dim=-1)  # one target vector per classification problem\n",
    "        # The backbone is frozen, so compute its features once per batch instead of on every closure call\n",
    "        with torch.no_grad():\n",
    "            g_u = backbone(u)\n",
    "\n",
    "        def closure():\n",
    "            # LBFGS may re-evaluate the loss several times within a single step\n",
    "            lbfgs_optim.zero_grad()\n",
    "            closure_outputs = [clf(g_u) for clf in lbfgs_clfs]\n",
    "            closure_loss = sum(loss_crit(out, y) for out, y in zip(closure_outputs, labels))\n",
    "            closure_loss.backward()\n",
    "            return closure_loss\n",
    "\n",
    "        lbfgs_optim.step(closure)\n",
    "\n",
    "        # Record metrics of the updated parameters on this batch\n",
    "        with torch.no_grad():\n",
    "            outputs = [clf(g_u) for clf in lbfgs_clfs]\n",
    "            loss = sum(loss_crit(out, y) for out, y in zip(outputs, labels))\n",
    "            accuracy = accuracy_fn1(outputs, labels)\n",
    "        train_loss_store.append(loss.item())\n",
    "        train_acc_store.append(accuracy.item())\n",
    "\n",
    "    # Validation against val_dl's own labels; batch metrics weighted by batch size\n",
    "    models_eval(lbfgs_clfs)\n",
    "    val_loss_sum, val_acc_sum, n_val = 0.0, 0.0, 0\n",
    "    with torch.no_grad():\n",
    "        for u, _, labels in val_dl:\n",
    "            u, labels = u.to(DEVICE), labels.to(DEVICE)\n",
    "            batch_size = labels.shape[0]\n",
    "            labels = labels.unbind(dim=-1)\n",
    "            g_u = backbone(u)\n",
    "            outputs = [clf(g_u) for clf in lbfgs_clfs]\n",
    "            loss = sum(loss_crit(out, y) for out, y in zip(outputs, labels))\n",
    "            accuracy = accuracy_fn1(outputs, labels)\n",
    "            val_loss_sum += loss.item() * batch_size\n",
    "            val_acc_sum += accuracy.item() * batch_size\n",
    "            n_val += batch_size\n",
    "    test_loss_store.append(val_loss_sum / n_val)\n",
    "    test_acc_store.append(val_acc_sum / n_val)\n",
    "\n",
    "    # Refresh the live plots\n",
    "    train_loss_plt.set_data(range(len(train_loss_store)), train_loss_store)\n",
    "    ax1.set_xlim(0, len(train_loss_store))\n",
    "    train_acc_plt.set_data(range(len(train_acc_store)), train_acc_store)\n",
    "    ax3.set_xlim(0, len(train_acc_store))\n",
    "    ax3.set_ylim(0, 1)\n",
    "    test_loss_plt.set_data(range(len(test_loss_store)), test_loss_store)\n",
    "    ax2.set_xlim(0, len(test_loss_store))\n",
    "    test_acc_plt.set_data(range(len(test_acc_store)), test_acc_store)\n",
    "    ax4.set_xlim(0, len(test_acc_store))\n",
    "    ax4.set_ylim(0, 1)\n",
    "\n",
    "    clear_output(wait=True)\n",
    "    display(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch1.8",
   "language": "python",
   "name": "torch1.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
