{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 256,
   "metadata": {
    "id": "TCP7Y2QklcrQ"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pathlib\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import tensorflow_datasets as tfds\n",
    "from tensorflow.keras.layers import Input\n",
    "from LWTA.base import *\n",
    "from LWTA.base_conv2d import LwtaClassifier as lwta_clf\n",
    "from LWTA.bit_precision import compute_reduced_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "v8x00gbSlcrR"
   },
   "outputs": [],
   "source": [
    "# Optimizer used for every gradient step in this notebook; the same instance is\n",
    "# also tracked by the tf.train.Checkpoint objects created further down.\n",
    "learning_rate = 0.003\n",
    "optimizer = keras.optimizers.SGD(learning_rate=learning_rate)  \n",
    "\n",
    "# Initial interpolation factor of the meta step; the training loop scales it by\n",
    "# (1 - meta_iter / meta_iters), i.e. it decays linearly.\n",
    "meta_step_size = 0.25\n",
    "# Batch sizes passed to Dataset.get_mini_dataset for training and evaluation tasks.\n",
    "inner_batch_size = 25\n",
    "eval_batch_size = 25\n",
    "\n",
    "# meta_iters: upper bound of the outer (meta) loop.\n",
    "# eval_iters / inner_iters: `repetitions` argument of get_mini_dataset for the\n",
    "# evaluation tasks / the training tasks.\n",
    "meta_iters = 60000\n",
    "eval_iters = 5\n",
    "inner_iters = 4\n",
    "\n",
    "# Evaluate every `eval_interval` meta-iterations, print every `report_frequency`\n",
    "# (checked inside the evaluation branch) and save a checkpoint every `checkpoint_freq`.\n",
    "eval_interval = 50\n",
    "report_frequency = 50\n",
    "checkpoint_freq = 1000\n",
    "# Images per class in a training task (train_shots) and in an evaluation task (shots).\n",
    "train_shots = 20\n",
    "shots = 5 # 1 for 1-shot 5-way\n",
    "classes = 5 # 5 for 1-shot 5-way\n",
    "# Flags forwarded to LwtaClassifier as `bma=` and `deterministic=`.\n",
    "BMA = False\n",
    "DETERMINISTIC = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aUycHsQ7lcrR"
   },
   "source": [
    "## Prepare the data\n",
    "\n",
    "The [Omniglot dataset](https://github.com/brendenlake/omniglot/) is a dataset of 1,623\n",
    "characters taken from 50 different alphabets, with 20 examples for each character.\n",
    "The 20 samples for each character were drawn online via Amazon's Mechanical Turk. For the\n",
    "few-shot learning task, `k` samples (or \"shots\") are drawn randomly from `n` randomly-chosen\n",
    "classes. These `n` numerical values are used to create a new set of temporary labels to use\n",
    "to test the model's ability to learn a new task given few examples. In other words, if you\n",
    "are training on 5 classes, your new class labels will be either 0, 1, 2, 3, or 4.\n",
    "Omniglot is a great dataset for this task since there are many different classes to draw\n",
    "from, with a reasonable number of samples for each class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 478,
     "referenced_widgets": [
      "17e640fbeb5f458287117311c12b2bbd",
      "9e6af9e40f1f489fb6e1514f6e9ad8f6",
      "b4540f3e948047b7a18a428860e55652",
      "fbc46482c9eb442a9458f0b619443ef1",
      "e66c0534282e4c55a21bd149fb32aab8",
      "c770b8fe48a44039a1a8486180df589d",
      "04da630f62a040acb30c3ca55b589f5d",
      "b707ee96ad21479fbd0608e8158c5276",
      "203034d88830486c92067b44ea451890",
      "45f90aa79c1d46158f58f5d2682450fd",
      "983c38b96f314039aae4e3be8a452f9e",
      "d8167d1417bd415c8d7c66194cf5f6b7",
      "7c7b964330724bb2b222233f13cb1839",
      "89e979baca37455496ded4f36586ff03",
      "28a53bac50b14a1caf72690a9c2a0a93",
      "3dbbde2390f045eda9f428746ebaa795",
      "550f8801d97c417e8d519f1a6b07146e",
      "0bdba0100ae34522a621c53f28cc3e8d",
      "78e64e6e95f6447d859ae1170365f94c",
      "5118178ff1e14036969c18ad2faf722a",
      "d5e7d7197f57418e82e19a2bee67f3bd",
      "b171d93947a14520907b59f66697574e",
      "7b87243a89ab4f2fa596a09188ebeadf",
      "b7d109ba579741ff8f74e38ce66cda0a",
      "20a7a110d571405da19829ad1727643b",
      "5749fe77e4f948e18033d46beea5718e",
      "5793ada4367f46708da317de1652752f",
      "520eeb9985fb46929f90de6d2f24e999",
      "31ef7325a71b47eab92e3a4dcc0b5318",
      "4505d9dfa62c4826ad96da9ab397e352",
      "469d3773a7114477bbe21e5086c06aec",
      "a4f380d9686946308d483d78a121a3ea",
      "6bf7532e29754f839963c7eb7a58e200",
      "cd945007c0be4a568e77b96944676717",
      "0fd1b28baa684abd8c8e6ff7e24ab19f",
      "8793819ee22c431789fe05792b71ef50",
      "d45c15799fca46239f6797f894a852cd",
      "00eb811349f14c848f0bb2c8e4fab65f",
      "32edaffc9d8d4dfb90b9d99ed640f0e5",
      "d55437c1a0f346b7b0494eaea8988236",
      "a7f926e82ae346928221be0cdc83cb00",
      "86f877a4124c4439a7d4551a488ef4e3",
      "92dbd1a4c73848c3bbeba68ff0c6f31d",
      "9e1f957b1e344d428ed19e26dc3b2f38",
      "3789fd81836e435691fcac8a429c900d",
      "756c0aa0fe9e4fd88c8bc6752c7d5bf1",
      "0622f3b6ac3b4c1a93b8e2e3616e8b00",
      "ff3f698a52ce4422aa0666a67ec97738",
      "c62679f2de2e4e0c954082d6f752160c",
      "cfb9c99287c24d5fb3ff5fca1b3c2f37",
      "79d648274ec14a348109fcd6a66dace2",
      "c1e4612e2d234f3299069f2c292e3ed3",
      "dd248e0814a64ee8a63090fee78e962d",
      "8be07adc4c3c4f1796c6c0d2192e625a",
      "5ad1704befa648a3874ef7174ef3e5ed",
      "bf9a2d6ff3b042d98749cdb2d8f00948",
      "dc0c030fd626423f815f22e91e591030",
      "086d178b0dbe48dd897ee9b125dcd464",
      "de03e722d1a449cb99bab16e69f1e1c3",
      "e932a46875654a3ebfe14802be27418a",
      "51b83ae12ca74a2abde28cdc977b8ed2",
      "3c167a3dde704c79979311b872948e76",
      "622ddc39ade64c77b160cbd66e6752b2",
      "b08991fb355f444f9f8d6be2ea099142",
      "2606945c3d5744408320c95aad7eeb88",
      "12636d6f48804f8d93b21f3ab54bcf6c",
      "56869a904c514f5e8835c517e9245fbf",
      "6df570404ba44b9daad834e269c59c49",
      "5856779e14464d9b867c51eceb5d0dfc",
      "61e50e415fd64b6abb13926f0747aff2",
      "c9e2d236024e4644b3c6607ca127ab21",
      "c61f26881c75493f9d9cdc06628a867d",
      "bf1f44068c034278b6be5e12393898e1",
      "0ca1b2b2206b43e1af3d03da4fa64503",
      "e5479b3849034671a1be5bad552bb299",
      "d62201089e8446a1ae0dd483b9a31686",
      "471de552d73f48f7a244cb2e298f62bc",
      "57bdadfb0c4247b49371f24941acf784",
      "cad7f13a33b84e78b76336792fe46cea",
      "0fc193a8863e479fb87055e664de8b9d",
      "6a4750d06897462eadbdcd673bf21e43",
      "8ccf110725f14ee5b8c58b84f99c39f9",
      "71345d17bc68497f8f4f5b3a3ff125f1",
      "e0b0681af49c463aa93f2ec7526a7b90",
      "ebb18b0a97254ad4857872f885f43d80",
      "1a08f76b599747a1933c758ac0bf397e",
      "06b95880542941b8a819aa28baa2e0e8",
      "a2deeecce6d3496d83a0f42c62aa3f3b"
     ]
    },
    "id": "_cLhxSASlcrR",
    "outputId": "909150ee-074a-448f-943e-ebb7d9900616"
   },
   "outputs": [],
   "source": [
    "class Dataset:\n",
    "    \"\"\"Few-shot task sampler built on top of the Omniglot dataset.\n",
    "\n",
    "    One split of Omniglot is loaded through ``tfds``, every image is\n",
    "    preprocessed once, and the images are grouped by character class so that\n",
    "    small N-way K-shot tasks with fresh temporary labels can be drawn quickly.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, training):\n",
    "        \"\"\"Load and index one Omniglot split.\n",
    "\n",
    "        Args:\n",
    "            training: if truthy the \"train\" split is loaded, otherwise \"test\".\n",
    "        \"\"\"\n",
    "        # Download the tfrecord files containing the omniglot data and convert to a\n",
    "        # dataset.\n",
    "        split = \"train\" if training else \"test\"\n",
    "        ds = tfds.load(\"omniglot\", split=split, as_supervised=True, shuffle_files=False)\n",
    "\n",
    "        # Maps str(class id) -> list of preprocessed images of that class.\n",
    "        self.data = {}\n",
    "\n",
    "        def extraction(image, label):\n",
    "            # Scale pixel values to float32, convert the RGB image to grayscale\n",
    "            # and shrink it to 28x28.\n",
    "            image = tf.image.convert_image_dtype(image, tf.float32)\n",
    "            image = tf.image.rgb_to_grayscale(image)\n",
    "            image = tf.image.resize(image, [28, 28])\n",
    "            return image, label\n",
    "\n",
    "        for image, label in ds.map(extraction):\n",
    "            self.data.setdefault(str(label.numpy()), []).append(image.numpy())\n",
    "\n",
    "        # Build the label list once, after all images are indexed (it used to be\n",
    "        # rebuilt for every image and was never set for an empty split).\n",
    "        self.labels = list(self.data.keys())\n",
    "\n",
    "    @staticmethod\n",
    "    def _draw(population, k):\n",
    "        \"\"\"Draw ``k`` items, without replacement whenever the population allows.\n",
    "\n",
    "        Falls back to sampling with replacement when ``k`` exceeds the\n",
    "        population size, so requests the previous implementation accepted\n",
    "        still succeed.\n",
    "        \"\"\"\n",
    "        if k <= len(population):\n",
    "            return random.sample(population, k)\n",
    "        return random.choices(population, k=k)\n",
    "\n",
    "    def get_mini_dataset(\n",
    "        self, batch_size, repetitions, shots, num_classes, split=False\n",
    "    ):\n",
    "        \"\"\"Sample one ``num_classes``-way ``shots``-shot task.\n",
    "\n",
    "        Args:\n",
    "            batch_size: batch size of the returned ``tf.data.Dataset``.\n",
    "            repetitions: how many times the batched task data is repeated.\n",
    "            shots: number of training images drawn per class.\n",
    "            num_classes: number of classes in the task; class ``i`` of the\n",
    "                sampled subset receives the temporary label ``i``.\n",
    "            split: when True, one additional held-out image per class is\n",
    "                drawn and returned for testing.\n",
    "\n",
    "        Returns:\n",
    "            The shuffled, batched and repeated training dataset, or the tuple\n",
    "            ``(dataset, test_images, test_labels)`` when ``split`` is True.\n",
    "        \"\"\"\n",
    "        temp_labels = np.zeros(shape=(num_classes * shots))\n",
    "        temp_images = np.zeros(shape=(num_classes * shots, 28, 28, 1))\n",
    "        if split:\n",
    "            test_labels = np.zeros(shape=(num_classes))\n",
    "            test_images = np.zeros(shape=(num_classes, 28, 28, 1))\n",
    "\n",
    "        # Pick distinct classes: sampling with replacement could put the same\n",
    "        # character into one task twice under two different temporary labels.\n",
    "        label_subset = self._draw(self.labels, num_classes)\n",
    "        for class_idx, class_obj in enumerate(label_subset):\n",
    "            rows = slice(class_idx * shots, (class_idx + 1) * shots)\n",
    "            # Use enumerated index value as a temporary label for mini-batch in\n",
    "            # few shot learning.\n",
    "            temp_labels[rows] = class_idx\n",
    "            if split:\n",
    "                # Draw shots + 1 distinct images so the held-out test image is\n",
    "                # never also part of the training images of the task.\n",
    "                test_labels[class_idx] = class_idx\n",
    "                images_to_split = self._draw(self.data[class_obj], shots + 1)\n",
    "                test_images[class_idx] = images_to_split[-1]\n",
    "                temp_images[rows] = images_to_split[:-1]\n",
    "            else:\n",
    "                temp_images[rows] = self._draw(self.data[class_obj], shots)\n",
    "\n",
    "        dataset = tf.data.Dataset.from_tensor_slices(\n",
    "            (temp_images.astype(np.float32), temp_labels.astype(np.int32))\n",
    "        )\n",
    "        dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)\n",
    "        if split:\n",
    "            return dataset, test_images, test_labels\n",
    "        return dataset\n",
    "\n",
    "import urllib3\n",
    "\n",
    "# NOTE(review): called without a category, so this presumably silences every\n",
    "# urllib3 warning for the whole kernel session, not only SSL ones - confirm intended.\n",
    "urllib3.disable_warnings()  # Disable SSL warnings that may happen during download.\n",
    "# Build one sampler per Omniglot split (\"train\" and \"test\"); each constructor call\n",
    "# iterates over the whole split once.\n",
    "train_dataset = Dataset(training=True)\n",
    "test_dataset = Dataset(training=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model (with dense_LWTA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#try deterministic=False for point estimates and not distribution estimates\n",
    "# LWTA classifier instance shared by the training loop and the checkpoint below.\n",
    "sb_class = LwtaClassifier(original_dim = [classes,1], tau = 5e-2, bma=BMA,\n",
    "                          deterministic=DETERMINISTIC) \n",
    "                          \n",
    "def train_func(x, train=True, activation=\"lwta\"):\n",
    "    \"\"\"Run ``sb_class`` on ``x``, forwarding ``train`` and ``activation`` unchanged.\n",
    "\n",
    "    Returns whatever ``sb_class`` returns; the training loop unpacks it into\n",
    "    three values and uses only the first one as predictions.\n",
    "    \"\"\"\n",
    "    return sb_class(x, train=train, activation=activation)\n",
    "\n",
    "# Graph-compiled wrapper around train_func; the training loop calls the model through it.\n",
    "train = tf.function(train_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create checkpoint for resuming training\n",
    "ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=sb_class)\n",
    "ckpt_path = f\"./checkpoints_omniglot_lwta_{shots}_shot_{classes}_way\"\n",
    "manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=3)\n",
    "\n",
    "def checkpoint_manager(meta_iter):\n",
    "    \"\"\"Restore, save and advance the training checkpoint for one meta-iteration.\n",
    "\n",
    "    On meta-iteration 0 the latest checkpoint known to ``manager`` (if any) is\n",
    "    restored; on every later multiple of ``checkpoint_freq`` a new checkpoint\n",
    "    is saved. ``ckpt.step`` is incremented by one on every call.\n",
    "\n",
    "    Args:\n",
    "        meta_iter: index of the current meta-iteration.\n",
    "\n",
    "    Returns:\n",
    "        The notebook-level ``ckpt`` object.\n",
    "    \"\"\"\n",
    "    if meta_iter == 0:\n",
    "        # Ask the manager instead of the filesystem: the directory may exist\n",
    "        # without containing a checkpoint, in which case there is nothing to\n",
    "        # restore and \"Restored from None\" must not be reported.\n",
    "        latest = manager.latest_checkpoint\n",
    "        if latest:\n",
    "            ckpt.restore(latest)\n",
    "            print(\"\\nRestored from {}\\n\".format(latest))\n",
    "        else:\n",
    "            print(\"\\nNone checkpoints found => Initializing from scratch.\\n\")\n",
    "    elif meta_iter % checkpoint_freq == 0:\n",
    "        save_path = manager.save()\n",
    "        print(\"\\nSaved checkpoint for step {}: {}\\n\".format(meta_iter, save_path))\n",
    "\n",
    "    ckpt.step.assign_add(1)\n",
    "    return ckpt\n",
    "\n",
    "# save to external file (integer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm: cannot remove 'checkpoints_omniglot_lwta_1_shot_5_way': No such file or directory\r\n"
     ]
    }
   ],
   "source": [
    "# Start from scratch: delete old checkpoint directories. -f keeps the cell quiet\n",
    "# when a directory does not exist (plain `rm -r` reports an error in that case).\n",
    "!rm -rf checkpoints_omniglot_lwta_1_shot_5_way\n",
    "!rm -rf checkpoints_omniglot_lwta_5_shot_5_way"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "None checkpoints found => Initializing from scratch.\n",
      "\n",
      "tf.Tensor(54549, shape=(), dtype=int32)\n",
      "Iter = 0 => train_acc = 20.00% / test_acc = 60.00%\n",
      "\n",
      "#### The first 50 iterations took 19.49 secs ####\n",
      "tf.Tensor(54549, shape=(), dtype=int32)\n",
      "Iter = 50 => train_acc = 20.00% / test_acc = 40.00%\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-18-2512cb53cd08>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     41\u001b[0m             \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mce\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m         \u001b[0mgrads\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtape\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgradient\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msb_class\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainable_weights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     44\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_gradients\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrads\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msb_class\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainable_weights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/backprop.py\u001b[0m in \u001b[0;36mgradient\u001b[0;34m(self, target, sources, output_gradients, unconnected_gradients)\u001b[0m\n\u001b[1;32m   1084\u001b[0m         \u001b[0moutput_gradients\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_gradients\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1085\u001b[0m         \u001b[0msources_raw\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_sources_raw\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1086\u001b[0;31m         unconnected_gradients=unconnected_gradients)\n\u001b[0m\u001b[1;32m   1087\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1088\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_persistent\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/imperative_grad.py\u001b[0m in \u001b[0;36mimperative_grad\u001b[0;34m(tape, target, sources, output_gradients, sources_raw, unconnected_gradients)\u001b[0m\n\u001b[1;32m     75\u001b[0m       \u001b[0moutput_gradients\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m       \u001b[0msources_raw\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 77\u001b[0;31m       compat.as_str(unconnected_gradients.value))\n\u001b[0m",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/backprop.py\u001b[0m in \u001b[0;36m_gradient_function\u001b[0;34m(op_name, attr_tuple, num_inputs, inputs, outputs, out_grads, skip_input_indices, forward_pass_name_scope)\u001b[0m\n\u001b[1;32m    160\u001b[0m       \u001b[0mgradient_name_scope\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mforward_pass_name_scope\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"/\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    161\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgradient_name_scope\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mgrad_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmock_op\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mout_grads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    163\u001b[0m   \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    164\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mgrad_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmock_op\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mout_grads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py\u001b[0m in \u001b[0;36m_MinimumGrad\u001b[0;34m(op, grad)\u001b[0m\n\u001b[1;32m   1552\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_MinimumGrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1553\u001b[0m   \u001b[0;34m\"\"\"Returns grad*(x < y, x >= y) with type of grad.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1554\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0m_MaximumMinimumGrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmath_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mless_equal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1555\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1556\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py\u001b[0m in \u001b[0;36m_MaximumMinimumGrad\u001b[0;34m(op, grad, selector_op)\u001b[0m\n\u001b[1;32m   1516\u001b[0m       \u001b[0;31m# When we want to get gradients for the first input only, and the second\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1517\u001b[0m       \u001b[0;31m# input tensor is a scalar, we can do a much simpler calculation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1518\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0m_MaximumMinimumGradInputOnly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mselector_op\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1519\u001b[0m   \u001b[0;32mexcept\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1520\u001b[0m     \u001b[0;31m# No gradient skipping, so do the full gradient computation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py\u001b[0m in \u001b[0;36m_MaximumMinimumGradInputOnly\u001b[0;34m(op, grad, selector_op)\u001b[0m\n\u001b[1;32m   1500\u001b[0m   \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1501\u001b[0m   \u001b[0mzeros\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marray_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1502\u001b[0;31m   \u001b[0mxmask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mselector_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1503\u001b[0m   \u001b[0mxgrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marray_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere_v2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxmask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzeros\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1504\u001b[0m   \u001b[0mygrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m  \u001b[0;31m# Return None for ygrad since the config allows that.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py\u001b[0m in \u001b[0;36mless_equal\u001b[0;34m(x, y, name)\u001b[0m\n\u001b[1;32m   4914\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4915\u001b[0m       _result = pywrap_tfe.TFE_Py_FastPathExecute(\n\u001b[0;32m-> 4916\u001b[0;31m         _ctx, \"LessEqual\", name, x, y)\n\u001b[0m\u001b[1;32m   4917\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0m_result\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4918\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0m_core\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# History of evaluation accuracies, one entry per evaluation round.\n",
    "training = []\n",
    "testing = []\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "for meta_iter in range(int(ckpt.step), meta_iters):\n",
    "    # Sample a fresh training task for this meta-iteration.\n",
    "    mini_dataset = train_dataset.get_mini_dataset(\n",
    "        inner_batch_size, inner_iters, train_shots, classes\n",
    "    )\n",
    "\n",
    "    ckpt = checkpoint_manager(meta_iter)\n",
    "\n",
    "    # Linearly decay the meta step size over the course of training.\n",
    "    frac_done = meta_iter / meta_iters\n",
    "    cur_meta_step_size = (1 - frac_done) * meta_step_size\n",
    "\n",
    "    # Snapshot the weights before the inner loop. On meta_iter == 0 the snapshot\n",
    "    # is taken after the first forward pass instead (see below).\n",
    "    if meta_iter > 0:\n",
    "        old_vars = sb_class.get_weights()\n",
    "\n",
    "    j = 0\n",
    "    for images, labels in mini_dataset:\n",
    "\n",
    "        with tf.GradientTape() as tape:\n",
    "            # Flatten every image to a vector before feeding the dense model.\n",
    "            x1, x2, x3 = images.shape[0], images.shape[1], images.shape[2]\n",
    "            images = tf.reshape(images, [x1, x2*x3])\n",
    "            #preds, _, _ = sb_class(images, train=True, activation=\"lwta\")\n",
    "            # NOTE(review): on failure the tf.function wrapper is rebuilt and the call\n",
    "            # retried - confirm which failure this is meant to guard against.\n",
    "            try:\n",
    "                preds, _, _ = train(images, train=True, activation=\"lwta\")\n",
    "            except (UnboundLocalError, ValueError):\n",
    "                train = tf.function(train_func)\n",
    "                preds, _, _ = train(images, train=True, activation=\"lwta\")\n",
    "\n",
    "            # First batch of the first meta-iteration: take the initial snapshot here,\n",
    "            # presumably because the model only has weights after its first call.\n",
    "            if (j == 0) and (meta_iter == 0):\n",
    "                old_vars = sb_class.get_weights()\n",
    "            # we optimize the variational lower bound scaled by the number of data\n",
    "            # points (so we can keep our intuitions about hyper-params such as the learning rate)\n",
    "            #kl_loss = sum(sb_class.losses) / (x1 * x2)\n",
    "            ce = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "            #loss = ce + kl_loss\n",
    "            loss = ce\n",
    "\n",
    "        grads = tape.gradient(loss, sb_class.trainable_weights)\n",
    "        optimizer.apply_gradients(zip(grads, sb_class.trainable_weights))\n",
    "\n",
    "        j += 1\n",
    "\n",
    "    new_vars = sb_class.get_weights()\n",
    "\n",
    "    # Perform SGD for the meta step: move the pre-task weights towards the\n",
    "    # task-adapted weights by the current meta step size.\n",
    "    for var in range(len(new_vars)):\n",
    "        new_vars[var] = old_vars[var] + (\n",
    "            (new_vars[var] - old_vars[var]) * cur_meta_step_size\n",
    "        )\n",
    "    # After the meta-learning step, reload the newly-trained weights into the model.\n",
    "    sb_class.set_weights(new_vars)\n",
    "\n",
    "    if meta_iter == 50:\n",
    "        print(\"\\n#### The first 50 iterations took {:.2f} secs ####\".format(time.time()-start))\n",
    "\n",
    "    # Evaluation loop\n",
    "    if meta_iter % eval_interval == 0:\n",
    "        accuracies = []\n",
    "        for dataset in (train_dataset, test_dataset):\n",
    "            # Sample a mini dataset from the full dataset.\n",
    "            train_set, test_images, test_labels = dataset.get_mini_dataset(\n",
    "                eval_batch_size, eval_iters, shots, classes, split=True\n",
    "            )\n",
    "            old_vars = sb_class.get_weights()\n",
    "            # Train on the samples and get the resulting accuracies.\n",
    "            for images, labels in train_set:\n",
    "\n",
    "                with tf.GradientTape() as tape:\n",
    "                    x1, x2, x3 = images.shape[0], images.shape[1], images.shape[2]\n",
    "                    images = tf.reshape(images, [x1, x2*x3])\n",
    "                    try:\n",
    "                        preds, _, _ = train(images, train=True, activation=\"lwta\")\n",
    "                    except (UnboundLocalError, ValueError):\n",
    "                        train = tf.function(train_func)\n",
    "                        preds, _, _ = train(images, train=True, activation=\"lwta\")\n",
    "                    #kl_loss = sum(sb_class.losses) / (x1 * x2)\n",
    "                    ce = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "                    #loss = ce + kl_loss\n",
    "                    loss = ce\n",
    "\n",
    "                grads = tape.gradient(loss, sb_class.trainable_weights)\n",
    "                optimizer.apply_gradients(zip(grads, sb_class.trainable_weights))\n",
    "\n",
    "            x1, x2, x3 = test_images.shape[0], test_images.shape[1], test_images.shape[2]\n",
    "\n",
    "            test_images = tf.reshape(test_images, [x1, x2*x3])\n",
    "            # (original note) for bma=True use instead:\n",
    "            # test_preds, _, _ = sb_class(test_images, train=False, activation=\"lwta\")\n",
    "            try:\n",
    "                test_preds, _, _ = train(test_images, train=False, activation=\"lwta\")\n",
    "            except (UnboundLocalError, ValueError):\n",
    "                train = tf.function(train_func)\n",
    "                test_preds, _, _ = train(test_images, train=False, activation=\"lwta\")\n",
    "\n",
    "            # Predicted class per test image: reduce over the class axis (last axis).\n",
    "            # The default axis=0 would reduce over the samples instead and only ran\n",
    "            # without a shape error because the test batch holds exactly `classes` images.\n",
    "            test_preds = tf.argmax(test_preds, axis=-1).numpy()\n",
    "            num_correct = (test_preds == test_labels).sum()\n",
    "\n",
    "            # Reset the weights after getting the evaluation accuracies.\n",
    "            sb_class.set_weights(old_vars)\n",
    "            accuracies.append(num_correct / classes)\n",
    "\n",
    "        training.append(accuracies[0])\n",
    "        testing.append(accuracies[1])\n",
    "\n",
    "        if meta_iter % report_frequency == 0:\n",
    "            # total num of params\n",
    "            print(tf.reduce_sum([tf.reduce_prod(v.shape) for v in sb_class.get_weights()]))\n",
    "\n",
    "            # The reported accuracies are the mean over all evaluation rounds so far.\n",
    "            print(\"Iter = %d => train_acc = %.2f%% / test_acc = %.2f%%\" % (meta_iter,\n",
    "                  100*np.mean(training),100*np.mean(testing)))\n",
    "\n",
    "end = time.time()\n",
    "print(\"The training took {:.2f} secs\".format(end-start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final training accuracy = 80.0 %%\n",
      "Final testing accuracy = 66.0 %%\n"
     ]
    }
   ],
   "source": [
    "# Mean accuracy over all recorded evaluation rounds. str.format does not treat\n",
    "# '%%' as an escape, so a single '%' is used to print one percent sign.\n",
    "print(\"Final training accuracy = {:.1f} %\".format(100*np.mean(training)))\n",
    "print(\"Final testing accuracy = {:.1f} %\".format(100*np.mean(testing)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n7B7K-ODlcrT"
   },
   "source": [
    "## Train the model (without LWTA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def conv_bn(x):\n",
    "    \"\"\"Apply Conv2D(64 filters, 3x3, stride 2, same padding) -> BatchNorm -> ReLU to ``x``.\"\"\"\n",
    "    convolved = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding=\"same\")(x)\n",
    "    normalized = layers.BatchNormalization()(convolved)\n",
    "    return layers.ReLU()(normalized)\n",
    "\n",
    "# Baseline network: four stacked conv_bn stages on a 28x28x1 input, flattened and\n",
    "# followed by a softmax layer with one unit per class.\n",
    "inputs = layers.Input(shape=(28, 28, 1))\n",
    "x = conv_bn(conv_bn(conv_bn(conv_bn(inputs))))\n",
    "x = layers.Flatten()(x)\n",
    "outputs = layers.Dense(classes, activation=\"softmax\")(x)\n",
    "model = keras.Model(inputs=inputs, outputs=outputs)\n",
    "model.compile()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_func(x):\n",
    "    \"\"\"Run the baseline ``model`` on ``x`` and return its output.\"\"\"\n",
    "    return model(x)\n",
    "\n",
    "# Rebinds the notebook-level name `train`, replacing the LWTA wrapper defined\n",
    "# earlier in the notebook with a graph-compiled wrapper around the baseline model.\n",
    "train = tf.function(model_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create checkpoint for resuming training\n",
    "ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)\n",
    "# NOTE(review): this is the same directory name the LWTA section uses, so the\n",
    "# checkpoints of both models end up in one place - confirm a separate path is not intended.\n",
    "ckpt_path = f\"./checkpoints_omniglot_lwta_{shots}_shot_{classes}_way\"\n",
    "manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=3)\n",
    "\n",
    "def checkpoint_manager(meta_iter):\n",
    "    \"\"\"Restore, save and advance the training checkpoint for one meta-iteration.\n",
    "\n",
    "    On meta-iteration 0 the latest checkpoint known to ``manager`` (if any) is\n",
    "    restored; on every later multiple of ``checkpoint_freq`` a new checkpoint\n",
    "    is saved. ``ckpt.step`` is incremented by one on every call.\n",
    "\n",
    "    Args:\n",
    "        meta_iter: index of the current meta-iteration.\n",
    "\n",
    "    Returns:\n",
    "        The notebook-level ``ckpt`` object.\n",
    "    \"\"\"\n",
    "    if meta_iter == 0:\n",
    "        # Ask the manager instead of the filesystem: the directory may exist\n",
    "        # without containing a checkpoint, in which case there is nothing to\n",
    "        # restore and \"Restored from None\" must not be reported.\n",
    "        latest = manager.latest_checkpoint\n",
    "        if latest:\n",
    "            ckpt.restore(latest)\n",
    "            print(\"\\nRestored from {}\\n\".format(latest))\n",
    "        else:\n",
    "            print(\"\\nNone checkpoints found => Initializing from scratch.\\n\")\n",
    "    elif meta_iter % checkpoint_freq == 0:\n",
    "        save_path = manager.save()\n",
    "        print(\"\\nSaved checkpoint for step {}: {}\\n\".format(meta_iter, save_path))\n",
    "\n",
    "    ckpt.step.assign_add(1)\n",
    "    return ckpt\n",
    "\n",
    "# save to external file (integer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm: cannot remove 'checkpoints_omniglot_lwta_1_shot_5_way': No such file or directory\n",
      "rm: cannot remove 'checkpoints_omniglot_lwta_5_shot_5_way': No such file or directory\n"
     ]
    }
   ],
   "source": [
    "# Remove checkpoints of previous runs; -f keeps rm quiet when a directory is absent.\n",
    "!rm -rf checkpoints_omniglot_lwta_1_shot_5_way\n",
    "!rm -rf checkpoints_omniglot_lwta_5_shot_5_way"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "id": "3PRTNTwBlcrT",
    "outputId": "b53c2c93-7bb3-4390-f453-081ad15da2a1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "None checkpoints found => Initializing from scratch.\n",
      "\n",
      "Iter = 0 => train_acc = 60.00% / test_acc = 20.00%\n",
      "\n",
      "#### The first 50 iterations took 29.22 secs ####\n",
      "Iter = 50 => train_acc = 50.20% / test_acc = 46.27%\n",
      "Iter = 100 => train_acc = 58.61% / test_acc = 55.25%\n",
      "Iter = 150 => train_acc = 66.23% / test_acc = 64.77%\n",
      "Iter = 200 => train_acc = 71.14% / test_acc = 68.86%\n",
      "Iter = 250 => train_acc = 74.34% / test_acc = 72.75%\n",
      "Iter = 300 => train_acc = 75.81% / test_acc = 74.82%\n",
      "Iter = 350 => train_acc = 77.21% / test_acc = 76.41%\n",
      "Iter = 400 => train_acc = 77.76% / test_acc = 77.36%\n",
      "Iter = 450 => train_acc = 78.94% / test_acc = 77.83%\n",
      "Iter = 500 => train_acc = 79.84% / test_acc = 78.84%\n",
      "Iter = 550 => train_acc = 80.44% / test_acc = 79.17%\n",
      "Iter = 600 => train_acc = 81.13% / test_acc = 79.83%\n",
      "Iter = 650 => train_acc = 81.41% / test_acc = 79.88%\n",
      "Iter = 700 => train_acc = 81.88% / test_acc = 79.91%\n",
      "Iter = 750 => train_acc = 82.24% / test_acc = 80.37%\n",
      "Iter = 800 => train_acc = 82.25% / test_acc = 80.75%\n",
      "Iter = 850 => train_acc = 82.82% / test_acc = 80.85%\n",
      "Iter = 900 => train_acc = 83.09% / test_acc = 81.15%\n",
      "Iter = 950 => train_acc = 83.24% / test_acc = 81.43%\n",
      "\n",
      "Saved checkpoint for step 1000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-1\n",
      "\n",
      "Iter = 1000 => train_acc = 83.50% / test_acc = 81.76%\n",
      "Iter = 1050 => train_acc = 83.82% / test_acc = 81.98%\n",
      "Iter = 1100 => train_acc = 84.12% / test_acc = 82.00%\n",
      "Iter = 1150 => train_acc = 84.54% / test_acc = 82.24%\n",
      "Iter = 1200 => train_acc = 84.61% / test_acc = 82.23%\n",
      "Iter = 1250 => train_acc = 84.89% / test_acc = 82.08%\n",
      "Iter = 1300 => train_acc = 85.07% / test_acc = 82.17%\n",
      "Iter = 1350 => train_acc = 85.39% / test_acc = 82.34%\n",
      "Iter = 1400 => train_acc = 85.44% / test_acc = 82.46%\n",
      "Iter = 1450 => train_acc = 85.55% / test_acc = 82.63%\n",
      "Iter = 1500 => train_acc = 85.57% / test_acc = 82.70%\n",
      "Iter = 1550 => train_acc = 85.70% / test_acc = 82.89%\n",
      "Iter = 1600 => train_acc = 85.83% / test_acc = 82.94%\n",
      "Iter = 1650 => train_acc = 86.11% / test_acc = 82.98%\n",
      "Iter = 1700 => train_acc = 86.22% / test_acc = 83.07%\n",
      "Iter = 1750 => train_acc = 86.36% / test_acc = 83.20%\n",
      "Iter = 1800 => train_acc = 86.51% / test_acc = 83.30%\n",
      "Iter = 1850 => train_acc = 86.49% / test_acc = 83.39%\n",
      "Iter = 1900 => train_acc = 86.69% / test_acc = 83.49%\n",
      "Iter = 1950 => train_acc = 86.66% / test_acc = 83.45%\n",
      "\n",
      "Saved checkpoint for step 2000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-2\n",
      "\n",
      "Iter = 2000 => train_acc = 86.72% / test_acc = 83.57%\n",
      "Iter = 2050 => train_acc = 86.77% / test_acc = 83.56%\n",
      "Iter = 2100 => train_acc = 86.76% / test_acc = 83.68%\n",
      "Iter = 2150 => train_acc = 86.85% / test_acc = 83.75%\n",
      "Iter = 2200 => train_acc = 86.93% / test_acc = 83.73%\n",
      "Iter = 2250 => train_acc = 86.97% / test_acc = 83.75%\n",
      "Iter = 2300 => train_acc = 86.94% / test_acc = 83.77%\n",
      "Iter = 2350 => train_acc = 87.01% / test_acc = 83.79%\n",
      "Iter = 2400 => train_acc = 87.12% / test_acc = 83.85%\n",
      "Iter = 2450 => train_acc = 87.26% / test_acc = 83.93%\n",
      "Iter = 2500 => train_acc = 87.29% / test_acc = 83.97%\n",
      "Iter = 2550 => train_acc = 87.35% / test_acc = 84.12%\n",
      "Iter = 2600 => train_acc = 87.37% / test_acc = 84.14%\n",
      "Iter = 2650 => train_acc = 87.37% / test_acc = 84.10%\n",
      "Iter = 2700 => train_acc = 87.40% / test_acc = 84.15%\n",
      "Iter = 2750 => train_acc = 87.41% / test_acc = 84.27%\n",
      "Iter = 2800 => train_acc = 87.44% / test_acc = 84.30%\n",
      "Iter = 3050 => train_acc = 87.67% / test_acc = 84.41%\n",
      "Iter = 3100 => train_acc = 87.76% / test_acc = 84.46%\n",
      "Iter = 3150 => train_acc = 87.74% / test_acc = 84.49%\n",
      "Iter = 3200 => train_acc = 87.76% / test_acc = 84.55%\n",
      "Iter = 3250 => train_acc = 87.77% / test_acc = 84.54%\n",
      "Iter = 3300 => train_acc = 87.80% / test_acc = 84.60%\n",
      "Iter = 3350 => train_acc = 87.85% / test_acc = 84.60%\n",
      "Iter = 3400 => train_acc = 87.95% / test_acc = 84.67%\n",
      "Iter = 3450 => train_acc = 87.98% / test_acc = 84.71%\n",
      "Iter = 3500 => train_acc = 88.01% / test_acc = 84.75%\n",
      "Iter = 3550 => train_acc = 88.02% / test_acc = 84.79%\n",
      "Iter = 3600 => train_acc = 88.14% / test_acc = 84.81%\n",
      "Iter = 3650 => train_acc = 88.11% / test_acc = 84.82%\n",
      "Iter = 3700 => train_acc = 88.16% / test_acc = 84.87%\n",
      "Iter = 3750 => train_acc = 88.25% / test_acc = 84.92%\n",
      "Iter = 3800 => train_acc = 88.32% / test_acc = 84.97%\n",
      "Iter = 3850 => train_acc = 88.36% / test_acc = 84.91%\n",
      "Iter = 3900 => train_acc = 88.39% / test_acc = 84.93%\n",
      "Iter = 3950 => train_acc = 88.45% / test_acc = 84.92%\n",
      "\n",
      "Saved checkpoint for step 4000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-4\n",
      "\n",
      "Iter = 4000 => train_acc = 88.44% / test_acc = 84.98%\n",
      "Iter = 4050 => train_acc = 88.49% / test_acc = 85.06%\n",
      "Iter = 4100 => train_acc = 88.55% / test_acc = 85.11%\n",
      "Iter = 4150 => train_acc = 88.61% / test_acc = 85.18%\n",
      "Iter = 4200 => train_acc = 88.66% / test_acc = 85.24%\n",
      "Iter = 4250 => train_acc = 88.69% / test_acc = 85.26%\n",
      "Iter = 4300 => train_acc = 88.72% / test_acc = 85.29%\n",
      "Iter = 4350 => train_acc = 88.77% / test_acc = 85.37%\n",
      "Iter = 4400 => train_acc = 88.80% / test_acc = 85.40%\n",
      "Iter = 4450 => train_acc = 88.87% / test_acc = 85.39%\n",
      "Iter = 4500 => train_acc = 88.85% / test_acc = 85.43%\n",
      "Iter = 4550 => train_acc = 88.91% / test_acc = 85.49%\n",
      "Iter = 4600 => train_acc = 88.95% / test_acc = 85.51%\n",
      "Iter = 4650 => train_acc = 89.00% / test_acc = 85.57%\n",
      "Iter = 4700 => train_acc = 89.02% / test_acc = 85.63%\n",
      "Iter = 4750 => train_acc = 89.06% / test_acc = 85.67%\n",
      "Iter = 4800 => train_acc = 89.07% / test_acc = 85.66%\n",
      "Iter = 4850 => train_acc = 89.10% / test_acc = 85.69%\n",
      "Iter = 4900 => train_acc = 89.15% / test_acc = 85.70%\n",
      "Iter = 4950 => train_acc = 89.17% / test_acc = 85.73%\n",
      "\n",
      "Saved checkpoint for step 5000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-5\n",
      "\n",
      "Iter = 5000 => train_acc = 89.21% / test_acc = 85.72%\n",
      "Iter = 5050 => train_acc = 89.25% / test_acc = 85.73%\n",
      "Iter = 5100 => train_acc = 89.26% / test_acc = 85.81%\n",
      "Iter = 5150 => train_acc = 89.30% / test_acc = 85.86%\n",
      "Iter = 5200 => train_acc = 89.31% / test_acc = 85.90%\n",
      "Iter = 5250 => train_acc = 89.35% / test_acc = 85.94%\n",
      "Iter = 5300 => train_acc = 89.36% / test_acc = 86.00%\n",
      "Iter = 5350 => train_acc = 89.38% / test_acc = 86.01%\n",
      "Iter = 5400 => train_acc = 89.42% / test_acc = 86.06%\n",
      "Iter = 5450 => train_acc = 89.42% / test_acc = 86.08%\n",
      "Iter = 5500 => train_acc = 89.44% / test_acc = 86.11%\n",
      "Iter = 5550 => train_acc = 89.48% / test_acc = 86.12%\n",
      "Iter = 5600 => train_acc = 89.53% / test_acc = 86.18%\n",
      "Iter = 5650 => train_acc = 89.56% / test_acc = 86.20%\n",
      "Iter = 5700 => train_acc = 89.60% / test_acc = 86.24%\n",
      "Iter = 5750 => train_acc = 89.60% / test_acc = 86.30%\n",
      "Iter = 5800 => train_acc = 89.62% / test_acc = 86.29%\n",
      "Iter = 5850 => train_acc = 89.63% / test_acc = 86.34%\n",
      "Iter = 5900 => train_acc = 89.67% / test_acc = 86.37%\n",
      "Iter = 5950 => train_acc = 89.70% / test_acc = 86.43%\n",
      "\n",
      "Saved checkpoint for step 6000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-6\n",
      "\n",
      "Iter = 6000 => train_acc = 89.74% / test_acc = 86.45%\n",
      "Iter = 6050 => train_acc = 89.76% / test_acc = 86.48%\n",
      "Iter = 6100 => train_acc = 89.82% / test_acc = 86.51%\n",
      "Iter = 6150 => train_acc = 89.84% / test_acc = 86.51%\n",
      "Iter = 6200 => train_acc = 89.88% / test_acc = 86.56%\n",
      "Iter = 6250 => train_acc = 89.90% / test_acc = 86.57%\n",
      "Iter = 6300 => train_acc = 89.95% / test_acc = 86.61%\n",
      "Iter = 6350 => train_acc = 89.99% / test_acc = 86.65%\n",
      "Iter = 6400 => train_acc = 90.02% / test_acc = 86.66%\n",
      "Iter = 6450 => train_acc = 90.04% / test_acc = 86.66%\n",
      "Iter = 6500 => train_acc = 90.10% / test_acc = 86.71%\n",
      "Iter = 6550 => train_acc = 90.14% / test_acc = 86.73%\n",
      "Iter = 6600 => train_acc = 90.14% / test_acc = 86.75%\n",
      "Iter = 6650 => train_acc = 90.19% / test_acc = 86.77%\n",
      "Iter = 6700 => train_acc = 90.23% / test_acc = 86.82%\n",
      "Iter = 6750 => train_acc = 90.24% / test_acc = 86.86%\n",
      "Iter = 6800 => train_acc = 90.25% / test_acc = 86.88%\n",
      "Iter = 6850 => train_acc = 90.28% / test_acc = 86.95%\n",
      "Iter = 6900 => train_acc = 90.31% / test_acc = 86.98%\n",
      "Iter = 6950 => train_acc = 90.33% / test_acc = 87.02%\n",
      "\n",
      "Saved checkpoint for step 7000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-7\n",
      "\n",
      "Iter = 7000 => train_acc = 90.35% / test_acc = 87.05%\n",
      "Iter = 7050 => train_acc = 90.38% / test_acc = 87.09%\n",
      "Iter = 7100 => train_acc = 90.40% / test_acc = 87.11%\n",
      "Iter = 7150 => train_acc = 90.44% / test_acc = 87.14%\n",
      "Iter = 7200 => train_acc = 90.48% / test_acc = 87.18%\n",
      "Iter = 7250 => train_acc = 90.51% / test_acc = 87.20%\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter = 7300 => train_acc = 90.54% / test_acc = 87.19%\n",
      "Iter = 7350 => train_acc = 90.55% / test_acc = 87.21%\n",
      "Iter = 7400 => train_acc = 90.58% / test_acc = 87.25%\n",
      "Iter = 7450 => train_acc = 90.62% / test_acc = 87.26%\n",
      "Iter = 7500 => train_acc = 90.64% / test_acc = 87.27%\n",
      "Iter = 7550 => train_acc = 90.67% / test_acc = 87.29%\n",
      "Iter = 7600 => train_acc = 90.70% / test_acc = 87.30%\n",
      "Iter = 7650 => train_acc = 90.73% / test_acc = 87.28%\n",
      "Iter = 7700 => train_acc = 90.74% / test_acc = 87.27%\n",
      "Iter = 7750 => train_acc = 90.77% / test_acc = 87.26%\n",
      "Iter = 7800 => train_acc = 90.79% / test_acc = 87.29%\n",
      "Iter = 7850 => train_acc = 90.80% / test_acc = 87.30%\n",
      "Iter = 7900 => train_acc = 90.82% / test_acc = 87.33%\n",
      "Iter = 7950 => train_acc = 90.85% / test_acc = 87.34%\n",
      "\n",
      "Saved checkpoint for step 8000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-8\n",
      "\n",
      "Iter = 8000 => train_acc = 90.88% / test_acc = 87.35%\n",
      "Iter = 8050 => train_acc = 90.90% / test_acc = 87.39%\n",
      "Iter = 8100 => train_acc = 90.94% / test_acc = 87.41%\n",
      "Iter = 8150 => train_acc = 90.95% / test_acc = 87.45%\n",
      "Iter = 8200 => train_acc = 90.97% / test_acc = 87.46%\n",
      "Iter = 8250 => train_acc = 90.99% / test_acc = 87.49%\n",
      "Iter = 8300 => train_acc = 91.03% / test_acc = 87.52%\n",
      "Iter = 8350 => train_acc = 91.05% / test_acc = 87.56%\n",
      "Iter = 8400 => train_acc = 91.08% / test_acc = 87.58%\n",
      "Iter = 8450 => train_acc = 91.10% / test_acc = 87.60%\n",
      "Iter = 8500 => train_acc = 91.13% / test_acc = 87.63%\n",
      "Iter = 8550 => train_acc = 91.15% / test_acc = 87.68%\n",
      "Iter = 8600 => train_acc = 91.17% / test_acc = 87.69%\n",
      "Iter = 8650 => train_acc = 91.20% / test_acc = 87.72%\n",
      "Iter = 8700 => train_acc = 91.21% / test_acc = 87.74%\n",
      "Iter = 8750 => train_acc = 91.25% / test_acc = 87.74%\n",
      "Iter = 8800 => train_acc = 91.26% / test_acc = 87.76%\n",
      "Iter = 8850 => train_acc = 91.28% / test_acc = 87.78%\n",
      "Iter = 8900 => train_acc = 91.31% / test_acc = 87.77%\n",
      "Iter = 8950 => train_acc = 91.34% / test_acc = 87.81%\n",
      "\n",
      "Saved checkpoint for step 9000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-9\n",
      "\n",
      "Iter = 9000 => train_acc = 91.37% / test_acc = 87.85%\n",
      "Iter = 9050 => train_acc = 91.39% / test_acc = 87.85%\n",
      "Iter = 9100 => train_acc = 91.40% / test_acc = 87.89%\n",
      "Iter = 9150 => train_acc = 91.43% / test_acc = 87.91%\n",
      "Iter = 9200 => train_acc = 91.46% / test_acc = 87.92%\n",
      "Iter = 9250 => train_acc = 91.49% / test_acc = 87.94%\n",
      "Iter = 9300 => train_acc = 91.51% / test_acc = 87.95%\n",
      "Iter = 9350 => train_acc = 91.53% / test_acc = 87.96%\n",
      "Iter = 9400 => train_acc = 91.53% / test_acc = 87.95%\n",
      "Iter = 9450 => train_acc = 91.55% / test_acc = 87.97%\n",
      "Iter = 9500 => train_acc = 91.57% / test_acc = 88.01%\n",
      "Iter = 9550 => train_acc = 91.59% / test_acc = 88.06%\n",
      "Iter = 9600 => train_acc = 91.62% / test_acc = 88.07%\n",
      "Iter = 9650 => train_acc = 91.64% / test_acc = 88.10%\n",
      "Iter = 9700 => train_acc = 91.65% / test_acc = 88.12%\n",
      "Iter = 9750 => train_acc = 91.68% / test_acc = 88.13%\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-43-35bee3fa0099>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[0;31m#             preds = model(images)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m                 \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     29\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mUnboundLocalError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m                 \u001b[0mtrain\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    826\u001b[0m     \u001b[0mtracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    827\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 828\u001b[0;31m       \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    829\u001b[0m       \u001b[0mcompiler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"xla\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_experimental_compile\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"nonXla\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    830\u001b[0m       \u001b[0mnew_tracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    860\u001b[0m       \u001b[0;31m# In this case we have not created variables on the first call. So we can\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    861\u001b[0m       \u001b[0;31m# run the first trace but we should fail if variables are created.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 862\u001b[0;31m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stateful_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    863\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_created_variables\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    864\u001b[0m         raise ValueError(\"Creating variables on a non-first call to a function\"\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   2941\u001b[0m        filtered_flat_args) = self._maybe_define_function(args, kwargs)\n\u001b[1;32m   2942\u001b[0m     return graph_function._call_flat(\n\u001b[0;32m-> 2943\u001b[0;31m         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access\n\u001b[0m\u001b[1;32m   2944\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2945\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_call_flat\u001b[0;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[1;32m   1925\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mexecuting_eagerly\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1926\u001b[0m       flat_outputs = forward_function.call(\n\u001b[0;32m-> 1927\u001b[0;31m           ctx, args_with_tangents, cancellation_manager=cancellation_manager)\n\u001b[0m\u001b[1;32m   1928\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1929\u001b[0m       with default_graph._override_gradient_function(  # pylint: disable=protected-access\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36mcall\u001b[0;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[1;32m    558\u001b[0m               \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    559\u001b[0m               \u001b[0mattrs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattrs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 560\u001b[0;31m               ctx=ctx)\n\u001b[0m\u001b[1;32m    561\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    562\u001b[0m           outputs = execute.execute_with_cancellation(\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m     58\u001b[0m     \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0;32m---> 60\u001b[0;31m                                         inputs, attrs, num_outputs)\n\u001b[0m\u001b[1;32m     61\u001b[0m   \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "training = []\n",
    "testing = []\n",
    "num_params = 0\n",
    "train_acc = []\n",
    "test_acc = []\n",
    "train_loss = []\n",
    "\n",
    "# The metric files below are opened in append mode, which does not create a\n",
    "# missing parent directory, so make sure it exists first.\n",
    "os.makedirs('print_metrics_folder', exist_ok=True)\n",
    "\n",
    "\n",
    "def _forward(images):\n",
    "    \"\"\"Return the model predictions for `images` via the compiled `train` function.\n",
    "\n",
    "    If the call raises UnboundLocalError or ValueError, the module-level `train`\n",
    "    wrapper is rebuilt from `model_func` and the call is retried once.\n",
    "    \"\"\"\n",
    "    global train\n",
    "    try:\n",
    "        return train(images)\n",
    "    except (UnboundLocalError, ValueError):\n",
    "        train = tf.function(model_func)\n",
    "        return train(images)\n",
    "\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "for meta_iter in range(int(ckpt.step), meta_iters):\n",
    "    # Get a sample from the full dataset.\n",
    "    mini_dataset = train_dataset.get_mini_dataset(\n",
    "        inner_batch_size, inner_iters, train_shots, classes\n",
    "    )\n",
    "\n",
    "    checkpoint_manager(meta_iter)\n",
    "\n",
    "    # The meta step size decays linearly towards zero over the run.\n",
    "    frac_done = meta_iter / meta_iters\n",
    "    cur_meta_step_size = (1 - frac_done) * meta_step_size\n",
    "    # Temporarily save the weights from the model.\n",
    "    old_vars = model.get_weights()\n",
    "\n",
    "    # Inner loop: plain gradient steps on the sampled mini dataset.\n",
    "    for images, labels in mini_dataset:\n",
    "        with tf.GradientTape() as tape:\n",
    "            preds = _forward(images)\n",
    "            loss = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "        grads = tape.gradient(loss, model.trainable_weights)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "    new_vars = model.get_weights()\n",
    "    # Perform SGD for the meta step: move the old weights a fraction of the way\n",
    "    # towards the weights reached by the inner loop.\n",
    "    for var in range(len(new_vars)):\n",
    "        new_vars[var] = old_vars[var] + (\n",
    "            (new_vars[var] - old_vars[var]) * cur_meta_step_size\n",
    "        )\n",
    "    # After the meta-learning step, reload the newly-trained weights into the model.\n",
    "    model.set_weights(new_vars)\n",
    "    if meta_iter == 50:\n",
    "        print(\"\\n#### The first 50 iterations took {:.2f} secs ####\".format(time.time()-start))\n",
    "\n",
    "    # Evaluation runs on every meta-iteration, so the running averages and the\n",
    "    # metric files receive one entry per iteration.\n",
    "    accuracies = []\n",
    "    losses = []\n",
    "    for dataset in (train_dataset, test_dataset):\n",
    "        # Sample a mini dataset from the full dataset.\n",
    "        train_set, test_images, test_labels = dataset.get_mini_dataset(\n",
    "            eval_batch_size, eval_iters, shots, classes, split=True\n",
    "        )\n",
    "        old_vars = model.get_weights()\n",
    "        # Train on the samples and get the resulting accuracies.\n",
    "        for images, labels in train_set:\n",
    "            with tf.GradientTape() as tape:\n",
    "                preds = _forward(images)\n",
    "                loss = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "            grads = tape.gradient(loss, model.trainable_weights)\n",
    "            optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "\n",
    "        test_preds = _forward(test_images)\n",
    "        # Keep one loss per evaluated dataset; a single variable would be\n",
    "        # overwritten by the second (test_dataset) pass.\n",
    "        losses.append(tf.reduce_mean(\n",
    "            keras.losses.sparse_categorical_crossentropy(test_labels, test_preds)\n",
    "        ))\n",
    "        # Pick the most likely class of every sample: tf.argmax reduces over\n",
    "        # axis 0 by default, which is the sample axis, not the class axis.\n",
    "        test_preds = tf.argmax(test_preds, axis=-1).numpy()\n",
    "\n",
    "        num_correct = (test_preds == test_labels).sum()\n",
    "\n",
    "        # Reset the weights after getting the evaluation accuracies.\n",
    "        model.set_weights(old_vars)\n",
    "        # NOTE(review): dividing by `classes` assumes one held-out image per\n",
    "        # class in the split - confirm against get_mini_dataset.\n",
    "        accuracies.append(num_correct / classes)\n",
    "\n",
    "    training.append(accuracies[0])\n",
    "    testing.append(accuracies[1])\n",
    "    train_acc.append(100*np.mean(training))\n",
    "    test_acc.append(100*np.mean(testing))\n",
    "    # Loss of the episode drawn from train_dataset (first entry of `losses`).\n",
    "    episode_train_loss = losses[0]\n",
    "    train_loss.append(episode_train_loss)\n",
    "\n",
    "    with open('print_metrics_folder/train_acc_reptile_no_lwta.txt', 'a+') as file:\n",
    "        file.write(\"%f\\n\" % (100.0*np.mean(training)))\n",
    "\n",
    "    with open('print_metrics_folder/train_loss_reptile_no_lwta.txt', 'a+') as file:\n",
    "        file.write(\"%f\\n\" % (episode_train_loss.numpy()))\n",
    "\n",
    "    with open('print_metrics_folder/test_acc_reptile_no_lwta.txt', 'a+') as file:\n",
    "        file.write(\"%f\\n\" % (100.0*np.mean(testing)))\n",
    "\n",
    "    if meta_iter % report_frequency == 0:\n",
    "        print(\"Iter = %d => train_acc = %.2f%% / test_acc = %.2f%%\" % (meta_iter,\n",
    "              100*np.mean(training),100*np.mean(testing)))\n",
    "\n",
    "end = time.time()\n",
    "print(\"The training took {:.2f} secs\".format(end-start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "id": "wQ60Xj9ufvWd",
    "outputId": "03a7b218-decf-43b2-e7b7-0a0aab628395"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final training accuracy = 83.8 %%\n",
      "Final testing accuracy = 68.6 %%\n"
     ]
    }
   ],
   "source": [
    "# str.format does not treat '%' specially, so a single '%' prints one percent sign.\n",
    "print(\"Final training accuracy = {:.1f} %\".format(100*np.mean(training)))\n",
    "print(\"Final testing accuracy = {:.1f} %\".format(100*np.mean(testing)))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "reptile_omniglot.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "00eb811349f14c848f0bb2c8e4fab65f": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "04da630f62a040acb30c3ca55b589f5d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "0622f3b6ac3b4c1a93b8e2e3616e8b00": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "06b95880542941b8a819aa28baa2e0e8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "086d178b0dbe48dd897ee9b125dcd464": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0bdba0100ae34522a621c53f28cc3e8d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0ca1b2b2206b43e1af3d03da4fa64503": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0fc193a8863e479fb87055e664de8b9d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0fd1b28baa684abd8c8e6ff7e24ab19f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "danger",
      "description": " 75%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_00eb811349f14c848f0bb2c8e4fab65f",
      "max": 19280,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_d45c15799fca46239f6797f894a852cd",
      "value": 14528
     }
    },
    "12636d6f48804f8d93b21f3ab54bcf6c": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "17e640fbeb5f458287117311c12b2bbd": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_b4540f3e948047b7a18a428860e55652",
       "IPY_MODEL_fbc46482c9eb442a9458f0b619443ef1"
      ],
      "layout": "IPY_MODEL_9e6af9e40f1f489fb6e1514f6e9ad8f6"
     }
    },
    "1a08f76b599747a1933c758ac0bf397e": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "203034d88830486c92067b44ea451890": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_983c38b96f314039aae4e3be8a452f9e",
       "IPY_MODEL_d8167d1417bd415c8d7c66194cf5f6b7"
      ],
      "layout": "IPY_MODEL_45f90aa79c1d46158f58f5d2682450fd"
     }
    },
    "20a7a110d571405da19829ad1727643b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_5793ada4367f46708da317de1652752f",
       "IPY_MODEL_520eeb9985fb46929f90de6d2f24e999"
      ],
      "layout": "IPY_MODEL_5749fe77e4f948e18033d46beea5718e"
     }
    },
    "2606945c3d5744408320c95aad7eeb88": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_56869a904c514f5e8835c517e9245fbf",
       "IPY_MODEL_6df570404ba44b9daad834e269c59c49"
      ],
      "layout": "IPY_MODEL_12636d6f48804f8d93b21f3ab54bcf6c"
     }
    },
    "28a53bac50b14a1caf72690a9c2a0a93": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "31ef7325a71b47eab92e3a4dcc0b5318": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "32edaffc9d8d4dfb90b9d99ed640f0e5": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "3789fd81836e435691fcac8a429c900d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "3c167a3dde704c79979311b872948e76": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "3dbbde2390f045eda9f428746ebaa795": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "4505d9dfa62c4826ad96da9ab397e352": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "45f90aa79c1d46158f58f5d2682450fd": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "469d3773a7114477bbe21e5086c06aec": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "471de552d73f48f7a244cb2e298f62bc": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "5118178ff1e14036969c18ad2faf722a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b7d109ba579741ff8f74e38ce66cda0a",
      "placeholder": "​",
      "style": "IPY_MODEL_7b87243a89ab4f2fa596a09188ebeadf",
      "value": " 4/4 [00:13&lt;00:00,  3.40s/ file]"
     }
    },
    "51b83ae12ca74a2abde28cdc977b8ed2": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "520eeb9985fb46929f90de6d2f24e999": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a4f380d9686946308d483d78a121a3ea",
      "placeholder": "​",
      "style": "IPY_MODEL_469d3773a7114477bbe21e5086c06aec",
      "value": " 19280/0 [00:08&lt;00:00, 2315.00 examples/s]"
     }
    },
    "550f8801d97c417e8d519f1a6b07146e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_78e64e6e95f6447d859ae1170365f94c",
       "IPY_MODEL_5118178ff1e14036969c18ad2faf722a"
      ],
      "layout": "IPY_MODEL_0bdba0100ae34522a621c53f28cc3e8d"
     }
    },
    "56869a904c514f5e8835c517e9245fbf": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "danger",
      "description": "  0%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_61e50e415fd64b6abb13926f0747aff2",
      "max": 2720,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_5856779e14464d9b867c51eceb5d0dfc",
      "value": 0
     }
    },
    "5749fe77e4f948e18033d46beea5718e": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5793ada4367f46708da317de1652752f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_4505d9dfa62c4826ad96da9ab397e352",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_31ef7325a71b47eab92e3a4dcc0b5318",
      "value": 1
     }
    },
    "57bdadfb0c4247b49371f24941acf784": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5856779e14464d9b867c51eceb5d0dfc": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "5ad1704befa648a3874ef7174ef3e5ed": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "61e50e415fd64b6abb13926f0747aff2": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "622ddc39ade64c77b160cbd66e6752b2": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "6a4750d06897462eadbdcd673bf21e43": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_71345d17bc68497f8f4f5b3a3ff125f1",
       "IPY_MODEL_e0b0681af49c463aa93f2ec7526a7b90"
      ],
      "layout": "IPY_MODEL_8ccf110725f14ee5b8c58b84f99c39f9"
     }
    },
    "6bf7532e29754f839963c7eb7a58e200": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_0fd1b28baa684abd8c8e6ff7e24ab19f",
       "IPY_MODEL_8793819ee22c431789fe05792b71ef50"
      ],
      "layout": "IPY_MODEL_cd945007c0be4a568e77b96944676717"
     }
    },
    "6df570404ba44b9daad834e269c59c49": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c61f26881c75493f9d9cdc06628a867d",
      "placeholder": "​",
      "style": "IPY_MODEL_c9e2d236024e4644b3c6607ca127ab21",
      "value": " 0/2720 [00:00&lt;?, ? examples/s]"
     }
    },
    "71345d17bc68497f8f4f5b3a3ff125f1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "danger",
      "description": "  0%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_1a08f76b599747a1933c758ac0bf397e",
      "max": 3120,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_ebb18b0a97254ad4857872f885f43d80",
      "value": 0
     }
    },
    "756c0aa0fe9e4fd88c8bc6752c7d5bf1": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "78e64e6e95f6447d859ae1170365f94c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Extraction completed...: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b171d93947a14520907b59f66697574e",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_d5e7d7197f57418e82e19a2bee67f3bd",
      "value": 1
     }
    },
    "79d648274ec14a348109fcd6a66dace2": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "danger",
      "description": "  0%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_8be07adc4c3c4f1796c6c0d2192e625a",
      "max": 13180,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_dd248e0814a64ee8a63090fee78e962d",
      "value": 0
     }
    },
    "7b87243a89ab4f2fa596a09188ebeadf": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "7c7b964330724bb2b222233f13cb1839": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "86f877a4124c4439a7d4551a488ef4e3": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8793819ee22c431789fe05792b71ef50": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_d55437c1a0f346b7b0494eaea8988236",
      "placeholder": "​",
      "style": "IPY_MODEL_32edaffc9d8d4dfb90b9d99ed640f0e5",
      "value": " 14528/19280 [00:00&lt;00:00, 145072.26 examples/s]"
     }
    },
    "89e979baca37455496ded4f36586ff03": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8be07adc4c3c4f1796c6c0d2192e625a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8ccf110725f14ee5b8c58b84f99c39f9": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "92dbd1a4c73848c3bbeba68ff0c6f31d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_756c0aa0fe9e4fd88c8bc6752c7d5bf1",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_3789fd81836e435691fcac8a429c900d",
      "value": 1
     }
    },
    "983c38b96f314039aae4e3be8a452f9e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Dl Size...: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_89e979baca37455496ded4f36586ff03",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_7c7b964330724bb2b222233f13cb1839",
      "value": 1
     }
    },
    "9e1f957b1e344d428ed19e26dc3b2f38": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_ff3f698a52ce4422aa0666a67ec97738",
      "placeholder": "​",
      "style": "IPY_MODEL_0622f3b6ac3b4c1a93b8e2e3616e8b00",
      "value": " 13180/0 [00:05&lt;00:00, 2293.03 examples/s]"
     }
    },
    "9e6af9e40f1f489fb6e1514f6e9ad8f6": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a2deeecce6d3496d83a0f42c62aa3f3b": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a4f380d9686946308d483d78a121a3ea": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a7f926e82ae346928221be0cdc83cb00": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_92dbd1a4c73848c3bbeba68ff0c6f31d",
       "IPY_MODEL_9e1f957b1e344d428ed19e26dc3b2f38"
      ],
      "layout": "IPY_MODEL_86f877a4124c4439a7d4551a488ef4e3"
     }
    },
    "b08991fb355f444f9f8d6be2ea099142": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b171d93947a14520907b59f66697574e": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b4540f3e948047b7a18a428860e55652": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Dl Completed...: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c770b8fe48a44039a1a8486180df589d",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_e66c0534282e4c55a21bd149fb32aab8",
      "value": 1
     }
    },
    "b707ee96ad21479fbd0608e8158c5276": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b7d109ba579741ff8f74e38ce66cda0a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "bf1f44068c034278b6be5e12393898e1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_e5479b3849034671a1be5bad552bb299",
       "IPY_MODEL_d62201089e8446a1ae0dd483b9a31686"
      ],
      "layout": "IPY_MODEL_0ca1b2b2206b43e1af3d03da4fa64503"
     }
    },
    "bf9a2d6ff3b042d98749cdb2d8f00948": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c1e4612e2d234f3299069f2c292e3ed3": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_bf9a2d6ff3b042d98749cdb2d8f00948",
      "placeholder": "​",
      "style": "IPY_MODEL_5ad1704befa648a3874ef7174ef3e5ed",
      "value": " 0/13180 [00:00&lt;?, ? examples/s]"
     }
    },
    "c61f26881c75493f9d9cdc06628a867d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c62679f2de2e4e0c954082d6f752160c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_79d648274ec14a348109fcd6a66dace2",
       "IPY_MODEL_c1e4612e2d234f3299069f2c292e3ed3"
      ],
      "layout": "IPY_MODEL_cfb9c99287c24d5fb3ff5fca1b3c2f37"
     }
    },
    "c770b8fe48a44039a1a8486180df589d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c9e2d236024e4644b3c6607ca127ab21": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "cad7f13a33b84e78b76336792fe46cea": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "cd945007c0be4a568e77b96944676717": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "cfb9c99287c24d5fb3ff5fca1b3c2f37": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d45c15799fca46239f6797f894a852cd": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "d55437c1a0f346b7b0494eaea8988236": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d5e7d7197f57418e82e19a2bee67f3bd": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "d62201089e8446a1ae0dd483b9a31686": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_0fc193a8863e479fb87055e664de8b9d",
      "placeholder": "​",
      "style": "IPY_MODEL_cad7f13a33b84e78b76336792fe46cea",
      "value": " 3120/0 [00:01&lt;00:00, 2339.33 examples/s]"
     }
    },
    "d8167d1417bd415c8d7c66194cf5f6b7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3dbbde2390f045eda9f428746ebaa795",
      "placeholder": "​",
      "style": "IPY_MODEL_28a53bac50b14a1caf72690a9c2a0a93",
      "value": " 17/17 [00:13&lt;00:00,  1.25 MiB/s]"
     }
    },
    "dc0c030fd626423f815f22e91e591030": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_de03e722d1a449cb99bab16e69f1e1c3",
       "IPY_MODEL_e932a46875654a3ebfe14802be27418a"
      ],
      "layout": "IPY_MODEL_086d178b0dbe48dd897ee9b125dcd464"
     }
    },
    "dd248e0814a64ee8a63090fee78e962d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "de03e722d1a449cb99bab16e69f1e1c3": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3c167a3dde704c79979311b872948e76",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_51b83ae12ca74a2abde28cdc977b8ed2",
      "value": 1
     }
    },
    "e0b0681af49c463aa93f2ec7526a7b90": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a2deeecce6d3496d83a0f42c62aa3f3b",
      "placeholder": "​",
      "style": "IPY_MODEL_06b95880542941b8a819aa28baa2e0e8",
      "value": " 0/3120 [00:00&lt;?, ? examples/s]"
     }
    },
    "e5479b3849034671a1be5bad552bb299": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_57bdadfb0c4247b49371f24941acf784",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_471de552d73f48f7a244cb2e298f62bc",
      "value": 1
     }
    },
    "e66c0534282e4c55a21bd149fb32aab8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "e932a46875654a3ebfe14802be27418a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b08991fb355f444f9f8d6be2ea099142",
      "placeholder": "​",
      "style": "IPY_MODEL_622ddc39ade64c77b160cbd66e6752b2",
      "value": " 2720/0 [00:01&lt;00:00, 2275.61 examples/s]"
     }
    },
    "ebb18b0a97254ad4857872f885f43d80": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "fbc46482c9eb442a9458f0b619443ef1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b707ee96ad21479fbd0608e8158c5276",
      "placeholder": "​",
      "style": "IPY_MODEL_04da630f62a040acb30c3ca55b589f5d",
      "value": " 4/4 [00:13&lt;00:00,  3.42s/ url]"
     }
    },
    "ff3f698a52ce4422aa0666a67ec97738": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
