{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "TCP7Y2QklcrQ"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from skimage.io import imread_collection\n",
    "import pathlib\n",
    "import random\n",
    "import time\n",
    "import PIL\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import tensorflow_datasets as tfds\n",
    "from tensorflow.keras.regularizers import l1,l2\n",
    "from LWTA.base import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "v8x00gbSlcrR"
   },
   "outputs": [],
   "source": [
    "# Hyper-parameters of the few-shot meta-learning experiment.\n",
    "\n",
    "# Seed Python, NumPy and TensorFlow so that episode sampling and shuffling are\n",
    "# repeatable after Restart Kernel -> Run All.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "# Optimizer used for the gradient steps inside every episode.\n",
    "learning_rate = 0.003\n",
    "optimizer = keras.optimizers.SGD(learning_rate=learning_rate)\n",
    "\n",
    "# Initial factor applied to the (new - old) weight difference in the meta step.\n",
    "meta_step_size = 0.25\n",
    "inner_batch_size = 25\n",
    "eval_batch_size = 25\n",
    "\n",
    "meta_iters = 200000  # number of meta iterations\n",
    "eval_iters = 5  # repetitions of the episode data while fine-tuning for evaluation\n",
    "inner_iters = 4  # repetitions of the episode data during meta-training\n",
    "\n",
    "eval_interval = 50  # evaluate every this many meta iterations\n",
    "report_frequency = 50  # print accuracies every this many meta iterations\n",
    "checkpoint_freq = 1000  # save a checkpoint every this many meta iterations\n",
    "train_shots = 20  # images per class in a meta-training episode\n",
    "shots = 5 # the k for => k-shot N-way\n",
    "classes = 5 # the N for => k-shot N-way\n",
    "LWTA = True\n",
    "BMA = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(38400, 12000)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect the file name of every image found below the train and test folders.\n",
    "# os.walk also visits sub-directories; only the bare file names are kept.\n",
    "train_images_names = [\n",
    "    filename\n",
    "    for _root, _dirs, files in os.walk(\"mini-Imagenet/train\")\n",
    "    for filename in files\n",
    "]\n",
    "\n",
    "test_images_names = [\n",
    "    filename\n",
    "    for _root, _dirs, files in os.walk(\"mini-Imagenet/test\")\n",
    "    for filename in files\n",
    "]\n",
    "\n",
    "len(train_images_names), len(test_images_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([615, 295, 153, 41], [938, 660, 1271, 1044])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# create labels\n",
    "def class_label_from_filename(filename):\n",
    "    \"\"\"Return an integer class label parsed from an image file name.\n",
    "\n",
    "    The earlier slice ``filename[9:-4]`` parsed the digits that follow the\n",
    "    9-character prefix, i.e. a per-image counter, so the first files of a\n",
    "    single folder received different labels. The label is now read from the\n",
    "    prefix itself so that all images of one class share one label.\n",
    "\n",
    "    NOTE(review): assumes names look like ``n01532829<counter>.jpg`` (the\n",
    "    letter ``n`` followed by an 8-digit synset id) -- confirm against the\n",
    "    files on disk. A name that does not match raises ValueError.\n",
    "    \"\"\"\n",
    "    return int(filename[1:9])\n",
    "\n",
    "\n",
    "train_labels = [class_label_from_filename(name) for name in train_images_names]\n",
    "test_labels = [class_label_from_filename(name) for name in test_images_names]\n",
    "\n",
    "train_labels[0:4], test_labels[0:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_filenames = tf.constant(train_images_names)\n",
    "# train_labels = tf.constant(train_labels)\n",
    "# test_filenames = tf.constant(test_images_names)\n",
    "# test_labels = tf.constant(test_labels)\n",
    "\n",
    "# train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))\n",
    "# test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))\n",
    "\n",
    "# # def im_file_to_tensor(file, label):\n",
    "# #     def _im_file_to_tensor(file, label):\n",
    "# #         path = f\"/mini-Imagenet/train/n01532829/{file.numpy().decode()}\"\n",
    "# #         im = tf.image.decode_jpeg(tf.io.read_file(path), channels=3)\n",
    "# #         im = tf.cast(image_decoded, tf.float32) / 255.0\n",
    "# #         return im, label\n",
    "# #     return tf.py_function(_im_file_to_tensor, \n",
    "# #                           inp=(file, label), \n",
    "# #                           Tout=(tf.float32, tf.uint8))\n",
    "\n",
    "# # train_dataset = train_dataset.map(im_file_to_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# train_ds_paths = tf.data.Dataset.list_files(str('mini-Imagenet/train/'+'*/*'))\n",
    "# test_ds_paths = tf.data.Dataset.list_files(str('mini-Imagenet/test/'+'*/*'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def process_path(file_path):\n",
    "#     label = tf.strings.split(file_path, os.sep)[-2]\n",
    "#     return tf.io.read_file(file_path), label\n",
    "\n",
    "# train_ds = train_ds_paths.map(process_path)\n",
    "# train_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_data = []\n",
    "\n",
    "# for x in train_ds_paths:\n",
    "#     an_image = PIL.Image.open(x.numpy().decode(\"utf-8\") )\n",
    "\n",
    "#     image_sequence = an_image.getdata()\n",
    "#     image_array = np.array(image_sequence)\n",
    "#     dim_1 = dim_2 = int(np.sqrt(image_array.shape[0]))\n",
    "#     image_array = image_array.reshape((dim_1, dim_2, channels))\n",
    "#     train_data.append(image_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_data = []\n",
    "\n",
    "# for x in test_ds_paths:\n",
    "#     an_image = PIL.Image.open(x.numpy().decode(\"utf-8\") )\n",
    "\n",
    "#     image_sequence = an_image.getdata()\n",
    "#     image_array = np.array(image_sequence)\n",
    "#     dim_1 = dim_2 = int(np.sqrt(image_array.shape[0]))\n",
    "#     image_array = image_array.reshape((dim_1, dim_2, channels))\n",
    "#     test_data.append(image_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The decoded images were cached once with np.save (train_data.npy and\n",
    "# test_data.npy); reload the arrays from those files instead of decoding\n",
    "# every image again.\n",
    "train_data = np.load('train_data.npy')\n",
    "test_data = np.load('test_data.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the pixels as 8-bit unsigned integers before wrapping them in tf.data.\n",
    "train_data, test_data = (\n",
    "    split.astype(np.uint8) for split in (train_data, test_data)\n",
    ")\n",
    "\n",
    "# Each dataset element pairs one image with the label at the same position.\n",
    "train_dataset_demo = tf.data.Dataset.from_tensor_slices((train_data, train_labels))\n",
    "test_dataset_demo = tf.data.Dataset.from_tensor_slices((test_data, test_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 478,
     "referenced_widgets": [
      "17e640fbeb5f458287117311c12b2bbd",
      "9e6af9e40f1f489fb6e1514f6e9ad8f6",
      "b4540f3e948047b7a18a428860e55652",
      "fbc46482c9eb442a9458f0b619443ef1",
      "e66c0534282e4c55a21bd149fb32aab8",
      "c770b8fe48a44039a1a8486180df589d",
      "04da630f62a040acb30c3ca55b589f5d",
      "b707ee96ad21479fbd0608e8158c5276",
      "203034d88830486c92067b44ea451890",
      "45f90aa79c1d46158f58f5d2682450fd",
      "983c38b96f314039aae4e3be8a452f9e",
      "d8167d1417bd415c8d7c66194cf5f6b7",
      "7c7b964330724bb2b222233f13cb1839",
      "89e979baca37455496ded4f36586ff03",
      "28a53bac50b14a1caf72690a9c2a0a93",
      "3dbbde2390f045eda9f428746ebaa795",
      "550f8801d97c417e8d519f1a6b07146e",
      "0bdba0100ae34522a621c53f28cc3e8d",
      "78e64e6e95f6447d859ae1170365f94c",
      "5118178ff1e14036969c18ad2faf722a",
      "d5e7d7197f57418e82e19a2bee67f3bd",
      "b171d93947a14520907b59f66697574e",
      "7b87243a89ab4f2fa596a09188ebeadf",
      "b7d109ba579741ff8f74e38ce66cda0a",
      "20a7a110d571405da19829ad1727643b",
      "5749fe77e4f948e18033d46beea5718e",
      "5793ada4367f46708da317de1652752f",
      "520eeb9985fb46929f90de6d2f24e999",
      "31ef7325a71b47eab92e3a4dcc0b5318",
      "4505d9dfa62c4826ad96da9ab397e352",
      "469d3773a7114477bbe21e5086c06aec",
      "a4f380d9686946308d483d78a121a3ea",
      "6bf7532e29754f839963c7eb7a58e200",
      "cd945007c0be4a568e77b96944676717",
      "0fd1b28baa684abd8c8e6ff7e24ab19f",
      "8793819ee22c431789fe05792b71ef50",
      "d45c15799fca46239f6797f894a852cd",
      "00eb811349f14c848f0bb2c8e4fab65f",
      "32edaffc9d8d4dfb90b9d99ed640f0e5",
      "d55437c1a0f346b7b0494eaea8988236",
      "a7f926e82ae346928221be0cdc83cb00",
      "86f877a4124c4439a7d4551a488ef4e3",
      "92dbd1a4c73848c3bbeba68ff0c6f31d",
      "9e1f957b1e344d428ed19e26dc3b2f38",
      "3789fd81836e435691fcac8a429c900d",
      "756c0aa0fe9e4fd88c8bc6752c7d5bf1",
      "0622f3b6ac3b4c1a93b8e2e3616e8b00",
      "ff3f698a52ce4422aa0666a67ec97738",
      "c62679f2de2e4e0c954082d6f752160c",
      "cfb9c99287c24d5fb3ff5fca1b3c2f37",
      "79d648274ec14a348109fcd6a66dace2",
      "c1e4612e2d234f3299069f2c292e3ed3",
      "dd248e0814a64ee8a63090fee78e962d",
      "8be07adc4c3c4f1796c6c0d2192e625a",
      "5ad1704befa648a3874ef7174ef3e5ed",
      "bf9a2d6ff3b042d98749cdb2d8f00948",
      "dc0c030fd626423f815f22e91e591030",
      "086d178b0dbe48dd897ee9b125dcd464",
      "de03e722d1a449cb99bab16e69f1e1c3",
      "e932a46875654a3ebfe14802be27418a",
      "51b83ae12ca74a2abde28cdc977b8ed2",
      "3c167a3dde704c79979311b872948e76",
      "622ddc39ade64c77b160cbd66e6752b2",
      "b08991fb355f444f9f8d6be2ea099142",
      "2606945c3d5744408320c95aad7eeb88",
      "12636d6f48804f8d93b21f3ab54bcf6c",
      "56869a904c514f5e8835c517e9245fbf",
      "6df570404ba44b9daad834e269c59c49",
      "5856779e14464d9b867c51eceb5d0dfc",
      "61e50e415fd64b6abb13926f0747aff2",
      "c9e2d236024e4644b3c6607ca127ab21",
      "c61f26881c75493f9d9cdc06628a867d",
      "bf1f44068c034278b6be5e12393898e1",
      "0ca1b2b2206b43e1af3d03da4fa64503",
      "e5479b3849034671a1be5bad552bb299",
      "d62201089e8446a1ae0dd483b9a31686",
      "471de552d73f48f7a244cb2e298f62bc",
      "57bdadfb0c4247b49371f24941acf784",
      "cad7f13a33b84e78b76336792fe46cea",
      "0fc193a8863e479fb87055e664de8b9d",
      "6a4750d06897462eadbdcd673bf21e43",
      "8ccf110725f14ee5b8c58b84f99c39f9",
      "71345d17bc68497f8f4f5b3a3ff125f1",
      "e0b0681af49c463aa93f2ec7526a7b90",
      "ebb18b0a97254ad4857872f885f43d80",
      "1a08f76b599747a1933c758ac0bf397e",
      "06b95880542941b8a819aa28baa2e0e8",
      "a2deeecce6d3496d83a0f42c62aa3f3b"
     ]
    },
    "id": "_cLhxSASlcrR",
    "outputId": "909150ee-074a-448f-943e-ebb7d9900616"
   },
   "outputs": [],
   "source": [
    "class Dataset:\n",
    "    \"\"\"Few-shot episode sampler over the in-memory Mini-Imagenet splits.\n",
    "\n",
    "    All images of a split are grouped by label once, so that k-shot N-way mini\n",
    "    datasets can be sampled quickly and relabelled with fresh 0..N-1 labels.\n",
    "    \"\"\"\n",
    "\n",
    "    # Shape of every stored image after resizing: height, width, channels.\n",
    "    image_shape = (84, 84, 3)\n",
    "\n",
    "    def __init__(self, training):\n",
    "        \"\"\"Group the images of one split by label.\n",
    "\n",
    "        Args:\n",
    "            training: when truthy read `train_dataset_demo`, otherwise\n",
    "                `test_dataset_demo`.\n",
    "        \"\"\"\n",
    "        ds = train_dataset_demo if training else test_dataset_demo\n",
    "\n",
    "        def extraction(image, label):\n",
    "            # Convert the pixels to float32 and resize the (colour) image to the\n",
    "            # stored height and width; the label is passed through unchanged.\n",
    "            image = tf.image.convert_image_dtype(image, tf.float32)\n",
    "            image = tf.image.resize(image, list(self.image_shape[:2]))\n",
    "            return image, label\n",
    "\n",
    "        # Maps str(label) -> list of image arrays carrying that label.\n",
    "        self.data = {}\n",
    "        for image, label in ds.map(extraction):\n",
    "            self.data.setdefault(str(label.numpy()), []).append(image.numpy())\n",
    "        # Built once after the loop (it used to be rebuilt for every image) and\n",
    "        # therefore also defined when the split is empty.\n",
    "        self.labels = list(self.data.keys())\n",
    "\n",
    "    @staticmethod\n",
    "    def _draw(population, k):\n",
    "        \"\"\"Pick k items, without replacement whenever the population allows it.\"\"\"\n",
    "        if k <= len(population):\n",
    "            return random.sample(population, k)\n",
    "        # Too few items for distinct picks: fall back to sampling with replacement.\n",
    "        return random.choices(population, k=k)\n",
    "\n",
    "    def get_mini_dataset(\n",
    "        self, batch_size, repetitions, shots, num_classes, split=False\n",
    "    ):\n",
    "        \"\"\"Sample one few-shot episode.\n",
    "\n",
    "        Args:\n",
    "            batch_size: batch size of the returned tf.data.Dataset.\n",
    "            repetitions: how many times the batched episode is repeated.\n",
    "            shots: number of images drawn per class.\n",
    "            num_classes: number of classes in the episode.\n",
    "            split: when True, one extra image per class is held out.\n",
    "\n",
    "        Returns:\n",
    "            The episode dataset of (float32 images, int32 labels) batches. When\n",
    "            `split` is True, a tuple (dataset, test_images, test_labels) whose\n",
    "            arrays hold the held-out image and episode label of every class.\n",
    "        \"\"\"\n",
    "        temp_labels = np.zeros(shape=(num_classes * shots))\n",
    "        temp_images = np.zeros(shape=(num_classes * shots,) + self.image_shape)\n",
    "        if split:\n",
    "            test_labels = np.zeros(shape=(num_classes))\n",
    "            test_images = np.zeros(shape=(num_classes,) + self.image_shape)\n",
    "\n",
    "        # Distinct classes: drawing with replacement could put the same class\n",
    "        # into one episode twice under two different temporary labels.\n",
    "        label_subset = self._draw(self.labels, num_classes)\n",
    "        for class_idx, class_obj in enumerate(label_subset):\n",
    "            # Rows of the episode arrays that belong to this class.\n",
    "            rows = slice(class_idx * shots, (class_idx + 1) * shots)\n",
    "            # Use enumerated index value as a temporary label for mini-batch in\n",
    "            # few shot learning.\n",
    "            temp_labels[rows] = class_idx\n",
    "            if split:\n",
    "                # Draw one extra image; distinct picks keep the held-out image\n",
    "                # out of the images used for training on this episode.\n",
    "                test_labels[class_idx] = class_idx\n",
    "                images_to_split = self._draw(self.data[class_obj], shots + 1)\n",
    "                test_images[class_idx] = images_to_split[-1]\n",
    "                temp_images[rows] = images_to_split[:-1]\n",
    "            else:\n",
    "                temp_images[rows] = self._draw(self.data[class_obj], shots)\n",
    "\n",
    "        dataset = tf.data.Dataset.from_tensor_slices(\n",
    "            (temp_images.astype(np.float32), temp_labels.astype(np.int32))\n",
    "        )\n",
    "        dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)\n",
    "        if split:\n",
    "            return dataset, test_images, test_labels\n",
    "        return dataset\n",
    "\n",
    "import urllib3\n",
    "\n",
    "# Kept from the original cell; the datasets below are built from arrays that\n",
    "# are already in memory.\n",
    "urllib3.disable_warnings()\n",
    "train_dataset = Dataset(training=True)\n",
    "test_dataset = Dataset(training=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ma-HzYXYlcrS"
   },
   "source": [
    "## Train the model (with dense_LWTA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classifier used for the LWTA experiment (LwtaClassifier presumably comes\n",
    "# from the star import of LWTA.base).\n",
    "sb_class = LwtaClassifier(\n",
    "    original_dim=[classes, 1],\n",
    "    tau=5e-2,\n",
    "    bma=BMA,\n",
    "    dataset=\"Imagenet\",\n",
    "    deterministic=False,\n",
    ")\n",
    "\n",
    "\n",
    "def train_func(x, train=True, activation=\"lwta\"):\n",
    "    \"\"\"Call `sb_class` on `x`, forwarding `train` and `activation` unchanged.\"\"\"\n",
    "    return sb_class(x, train=train, activation=activation)\n",
    "\n",
    "\n",
    "# tf.function wrapper around the forward pass; the training loop calls this.\n",
    "train = tf.function(train_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "UTGH7yOilcrT"
   },
   "outputs": [],
   "source": [
    "# Create checkpoint for resuming training \n",
    "ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=sb_class)\n",
    "ckpt_path = f\"./checkpoints_miniImagenet_lwta_{shots}_shot_{classes}_way\"\n",
    "manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=3)\n",
    "\n",
    "def checkpoint_manager(meta_iter):  \n",
    "    if meta_iter == 0:\n",
    "        if os.path.isdir(ckpt_path):\n",
    "            ckpt.restore(manager.latest_checkpoint)\n",
    "            print(\"\\nRestored from {}\\n\".format(manager.latest_checkpoint))\n",
    "        else:\n",
    "            print(\"\\nNone checkpoints found => Initializing from scratch.\\n\")\n",
    "    elif meta_iter % checkpoint_freq == 0:\n",
    "        save_path = manager.save()\n",
    "        print(\"\\nSaved checkpoint for step {}: {}\\n\".format(meta_iter, save_path))\n",
    "        \n",
    "    ckpt.step.assign_add(1)\n",
    "    return ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm: cannot remove 'checkpoints_miniImagenet_1_shot_5_way': No such file or directory\n",
      "rm: cannot remove 'checkpoints_miniImagenet_5_shot_5_way': No such file or directory\n"
     ]
    }
   ],
   "source": [
    "# Remove old checkpoint folders; -f keeps the cell quiet when they are absent.\n",
    "!rm -rf checkpoints_miniImagenet_1_shot_5_way\n",
    "!rm -rf checkpoints_miniImagenet_5_shot_5_way"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(382261, shape=(), dtype=int32)\n"
     ]
    },
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-22-bf7333e08628>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     57\u001b[0m     \u001b[0msb_class\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     58\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce_sum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce_prod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msb_class\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainable_weights\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m     \u001b[0;32massert\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     60\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mmeta_iter\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     61\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\n#### The first 50 iterations took {:.2f} secs ####\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Meta-training loop for the LWTA classifier with periodic few-shot evaluation.\n",
    "training = []  # accuracies measured on episodes drawn from train_dataset\n",
    "testing = []  # accuracies measured on episodes drawn from test_dataset\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "for meta_iter in range(int(ckpt.step), meta_iters):\n",
    "\n",
    "    # Get a sample from the full dataset.\n",
    "    mini_dataset = train_dataset.get_mini_dataset(\n",
    "        inner_batch_size, inner_iters, train_shots, classes\n",
    "    )\n",
    "\n",
    "    ckpt = checkpoint_manager(meta_iter)\n",
    "\n",
    "    # The meta step size shrinks linearly with the fraction of iterations done.\n",
    "    frac_done = meta_iter / meta_iters\n",
    "    cur_meta_step_size = (1 - frac_done) * meta_step_size\n",
    "\n",
    "    # Snapshot the weights before the inner loop (iteration 0 is handled below).\n",
    "    if meta_iter > 0:\n",
    "        old_vars = sb_class.get_weights()\n",
    "\n",
    "    j = 0\n",
    "    for images, labels in mini_dataset:\n",
    "\n",
    "        with tf.GradientTape() as tape:\n",
    "            # Flatten every image into one feature vector per sample.\n",
    "            images = tf.reshape(images, [images.shape[0], -1])\n",
    "            try:\n",
    "                preds, _, _ = train(images, train=True, activation=\"lwta\")\n",
    "            except (UnboundLocalError, ValueError):\n",
    "                # Re-create the tf.function wrapper and retry the call once.\n",
    "                train = tf.function(train_func)\n",
    "                preds, _, _ = train(images, train=True, activation=\"lwta\")\n",
    "\n",
    "            # On the very first batch the snapshot is taken after the forward\n",
    "            # pass -- presumably because the weights only exist once the model\n",
    "            # has been called; TODO confirm.\n",
    "            if (j == 0) and (meta_iter == 0):\n",
    "                old_vars = sb_class.get_weights()\n",
    "            # Only the cross-entropy is optimised; the KL term\n",
    "            # (sum(sb_class.losses)) is currently left out of the loss.\n",
    "            ce = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "            loss = ce\n",
    "\n",
    "        grads = tape.gradient(loss, sb_class.trainable_weights)\n",
    "        optimizer.apply_gradients(zip(grads, sb_class.trainable_weights))\n",
    "\n",
    "        j += 1\n",
    "\n",
    "    new_vars = sb_class.get_weights()\n",
    "\n",
    "    # Perform SGD for the meta step: move the old weights towards the newly\n",
    "    # trained ones by the current meta step size.\n",
    "    for var in range(len(new_vars)):\n",
    "        new_vars[var] = old_vars[var] + (\n",
    "            (new_vars[var] - old_vars[var]) * cur_meta_step_size\n",
    "        )\n",
    "    # After the meta-learning step, reload the newly-trained weights into the model.\n",
    "    sb_class.set_weights(new_vars)\n",
    "\n",
    "    if meta_iter == 50:\n",
    "        print(\"\\n#### The first 50 iterations took {:.2f} secs ####\".format(time.time()-start))\n",
    "\n",
    "    # Evaluation loop\n",
    "    if meta_iter % eval_interval == 0:\n",
    "        accuracies = []\n",
    "        for dataset in (train_dataset, test_dataset):\n",
    "            # Sample a mini dataset from the full dataset. The held-out arrays\n",
    "            # get their own names so the notebook-level `test_labels` list is\n",
    "            # not overwritten by this loop.\n",
    "            train_set, eval_images, eval_labels = dataset.get_mini_dataset(\n",
    "                eval_batch_size, eval_iters, shots, classes, split=True\n",
    "            )\n",
    "            old_vars = sb_class.get_weights()\n",
    "            # Train on the samples and get the resulting accuracies.\n",
    "            for images, labels in train_set:\n",
    "\n",
    "                with tf.GradientTape() as tape:\n",
    "                    images = tf.reshape(images, [images.shape[0], -1])\n",
    "                    try:\n",
    "                        preds, _, _ = train(images, train=True, activation=\"lwta\")\n",
    "                    except (UnboundLocalError, ValueError):\n",
    "                        train = tf.function(train_func)\n",
    "                        preds, _, _ = train(images, train=True, activation=\"lwta\")\n",
    "                    ce = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "                    loss = ce\n",
    "\n",
    "                grads = tape.gradient(loss, sb_class.trainable_weights)\n",
    "                optimizer.apply_gradients(zip(grads, sb_class.trainable_weights))\n",
    "\n",
    "            eval_images = tf.reshape(eval_images, [eval_images.shape[0], -1])\n",
    "            try:\n",
    "                test_preds, _, _ = train(eval_images, train=False, activation=\"lwta\")\n",
    "            except (UnboundLocalError, ValueError):\n",
    "                train = tf.function(train_func)\n",
    "                test_preds, _, _ = train(eval_images, train=False, activation=\"lwta\")\n",
    "\n",
    "            # One predicted class per image: reduce over the class axis. The\n",
    "            # default axis=0 would reduce over the images instead.\n",
    "            test_preds = tf.argmax(test_preds, axis=1).numpy()\n",
    "            num_correct = (test_preds == eval_labels).sum()\n",
    "\n",
    "            # Reset the weights after getting the evaluation accuracies.\n",
    "            sb_class.set_weights(old_vars)\n",
    "            accuracies.append(num_correct / classes)\n",
    "\n",
    "        training.append(accuracies[0])\n",
    "        testing.append(accuracies[1])\n",
    "\n",
    "        # The printed values are running means over all evaluations so far.\n",
    "        if meta_iter % report_frequency == 0:\n",
    "            print(\"Iter = %d => train_acc = %.2f%% / test_acc = %.2f%%\" % (meta_iter,\n",
    "                  100*np.mean(training),100*np.mean(testing)))\n",
    "\n",
    "end = time.time()\n",
    "print(\"The training took {:.2f} secs\".format(end-start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final training accuracy = 31.9 %%\n",
      "Final testing accuracy = 53.6 %%\n"
     ]
    }
   ],
   "source": [
    "# str.format does not treat '%' specially, so a single '%' prints the percent\n",
    "# sign ('%%' was printed literally).\n",
    "print(\"Final training accuracy = {:.1f} %\".format(100*np.mean(training)))\n",
    "print(\"Final testing accuracy = {:.1f} %\".format(100*np.mean(testing)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n7B7K-ODlcrT"
   },
   "source": [
    "## Train the model (without LWTA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_bn(x):\n",
    "    \"\"\"Apply one Conv2D -> BatchNormalization -> ReLU -> MaxPool2D stage to `x`.\n",
    "\n",
    "    The convolution has 64 filters of size 3 with 'same' padding and l2(0.01)\n",
    "    kernel and activity regularizers; the pooling uses a 2x2 window, stride 2\n",
    "    and 'same' padding. Returns the output tensor of the pooling layer.\n",
    "    \"\"\"\n",
    "    stages = (\n",
    "        layers.Conv2D(\n",
    "            filters=64,\n",
    "            kernel_size=3,\n",
    "            padding='same',\n",
    "            kernel_regularizer=l2(0.01),\n",
    "            activity_regularizer=l2(0.01),\n",
    "        ),\n",
    "        layers.BatchNormalization(),\n",
    "        layers.ReLU(),\n",
    "        layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same'),\n",
    "    )\n",
    "    for stage in stages:\n",
    "        x = stage(x)\n",
    "    return x\n",
    "\n",
    "# Baseline model: four conv_bn stages on 84x84x3 inputs, then a softmax layer\n",
    "# with one unit per class.\n",
    "inputs = layers.Input(shape=(84, 84, 3))\n",
    "\n",
    "x = inputs\n",
    "for _stage_idx in range(4):\n",
    "    x = conv_bn(x)\n",
    "x = layers.Flatten()(x)\n",
    "outputs = layers.Dense(classes, activation=\"softmax\")(x)\n",
    "\n",
    "model = keras.Model(inputs=inputs, outputs=outputs)\n",
    "model.compile()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_func(x):\n",
    "    \"\"\"Run the Keras `model` on the batch `x` and return its output.\"\"\"\n",
    "    return model(x)\n",
    "\n",
    "\n",
    "# Rebinds `train` to a tf.function wrapper around the plain (non-LWTA) model.\n",
    "train = tf.function(model_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create checkpoint for resuming training \n",
    "ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)\n",
    "manager = tf.train.CheckpointManager(ckpt, './checkpoints_miniImagenet', max_to_keep=3)\n",
    "\n",
    "def checkpoint_manager(meta_iter):    \n",
    "    if meta_iter == 0:\n",
    "        if os.path.isdir(\"checkpoints_miniImagenet\"):\n",
    "            \n",
    "            ckpt.restore(manager.latest_checkpoint)\n",
    "            print(\"\\nRestored from {}\\n\".format(manager.latest_checkpoint))\n",
    "        else:\n",
    "            print(\"\\nNone checkpoints found => Initializing from scratch.\\n\")\n",
    "    elif meta_iter % checkpoint_freq == 0:\n",
    "        save_path = manager.save()\n",
    "        print(\"\\nSaved checkpoint for step {}: {}\\n\".format(meta_iter, save_path))\n",
    "    \n",
    "    ckpt.step.assign_add(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm: cannot remove 'checkpoints_miniImagenet_1_shot_5_way': No such file or directory\n",
      "rm: cannot remove 'checkpoints_miniImagenet_5_shot_5_way': No such file or directory\n"
     ]
    }
   ],
   "source": [
    "# Remove old checkpoint folders; -f keeps the cell quiet when they are absent.\n",
    "!rm -rf checkpoints_miniImagenet_1_shot_5_way\n",
    "!rm -rf checkpoints_miniImagenet_5_shot_5_way"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "id": "3PRTNTwBlcrT",
    "outputId": "b53c2c93-7bb3-4390-f453-081ad15da2a1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "None checkpoints found => Initializing from scratch.\n",
      "\n",
      "Iter = 0 => train_acc = 0.00% / test_acc = 80.00%\n",
      "\n",
      "#### The first 50 iterations took 52.51 secs ####\n",
      "Iter = 50 => train_acc = 27.06% / test_acc = 43.14%\n",
      "Iter = 100 => train_acc = 28.12% / test_acc = 48.51%\n",
      "Iter = 150 => train_acc = 28.87% / test_acc = 52.32%\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-12-458f0f08bd15>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     22\u001b[0m \u001b[0;31m#             preds = model(images)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m                 \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     25\u001b[0m             \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mUnboundLocalError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m                 \u001b[0mtrain\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    826\u001b[0m     \u001b[0mtracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    827\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 828\u001b[0;31m       \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    829\u001b[0m       \u001b[0mcompiler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"xla\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_experimental_compile\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"nonXla\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    830\u001b[0m       \u001b[0mnew_tracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    860\u001b[0m       \u001b[0;31m# In this case we have not created variables on the first call. So we can\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    861\u001b[0m       \u001b[0;31m# run the first trace but we should fail if variables are created.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 862\u001b[0;31m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stateful_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    863\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_created_variables\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    864\u001b[0m         raise ValueError(\"Creating variables on a non-first call to a function\"\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   2941\u001b[0m        filtered_flat_args) = self._maybe_define_function(args, kwargs)\n\u001b[1;32m   2942\u001b[0m     return graph_function._call_flat(\n\u001b[0;32m-> 2943\u001b[0;31m         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access\n\u001b[0m\u001b[1;32m   2944\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2945\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_call_flat\u001b[0;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[1;32m   1931\u001b[0m            \"StatefulPartitionedCall\": self._get_gradient_function()}):\n\u001b[1;32m   1932\u001b[0m         \u001b[0mflat_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mforward_function\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs_with_tangents\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1933\u001b[0;31m     \u001b[0mforward_backward\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_outputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1934\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_build_call_outputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_outputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1935\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36mrecord\u001b[0;34m(self, flat_outputs)\u001b[0m\n\u001b[1;32m   1457\u001b[0m       \u001b[0;31m# tape is watching.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1458\u001b[0m       self._functions.record(\n\u001b[0;32m-> 1459\u001b[0;31m           flat_outputs, self._inference_args, self._input_tangents)\n\u001b[0m\u001b[1;32m   1460\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1461\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36mrecord\u001b[0;34m(self, flat_outputs, inference_args, input_tangents)\u001b[0m\n\u001b[1;32m   1301\u001b[0m     \"\"\"\n\u001b[1;32m   1302\u001b[0m     backward_function, to_record = self._wrap_backward_function(\n\u001b[0;32m-> 1303\u001b[0;31m         self._forward_graph, self._backward, flat_outputs)\n\u001b[0m\u001b[1;32m   1304\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forwardprop_output_indices\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1305\u001b[0m       tape.record_operation_backprop_only(\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_wrap_backward_function\u001b[0;34m(self, forward_graph, backward, outputs)\u001b[0m\n\u001b[1;32m   1243\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mtrainable_recorded_outputs\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mbackward_function_inputs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1244\u001b[0m         \u001b[0mrecorded_outputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1245\u001b[0;31m       \u001b[0;32mif\u001b[0m \u001b[0mbackprop_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIsTrainable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1246\u001b[0m         \u001b[0mtrainable_recorded_outputs\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1247\u001b[0m       \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/backprop_util.py\u001b[0m in \u001b[0;36mIsTrainable\u001b[0;34m(tensor_or_dtype)\u001b[0m\n\u001b[1;32m     28\u001b[0m   \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m     \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtensor_or_dtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m   \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     31\u001b[0m   return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,\n\u001b[1;32m     32\u001b[0m                               \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomplex64\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomplex128\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py\u001b[0m in \u001b[0;36mas_dtype\u001b[0;34m(type_value)\u001b[0m\n\u001b[1;32m    624\u001b[0m   \"\"\"\n\u001b[1;32m    625\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mDType\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 626\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_INTERN_TABLE\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtype_value\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_datatype_enum\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    628\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "training = []\n",
    "testing = []\n",
    "\n",
    "# Build the compiled forward pass once. At cell scope an undefined name raises\n",
    "# NameError, which the former `except (UnboundLocalError, ValueError)` guards\n",
    "# never caught, so the loop could not start on a fresh kernel.\n",
    "train = tf.function(model_func)\n",
    "\n",
    "\n",
    "def forward_pass(images):\n",
    "    \"\"\"Return `model_func` predictions for `images` via the compiled `train`.\n",
    "\n",
    "    If the compiled call raises ValueError (TensorFlow raises one when\n",
    "    variables are created on a non-first call), `model_func` is re-wrapped\n",
    "    in a fresh `tf.function` and the call is retried once.\n",
    "    \"\"\"\n",
    "    global train\n",
    "    try:\n",
    "        return train(images)\n",
    "    except ValueError:\n",
    "        train = tf.function(model_func)\n",
    "        return train(images)\n",
    "\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "for meta_iter in range(int(ckpt.step), meta_iters):\n",
    "    # Get a sample from the full dataset.\n",
    "    mini_dataset = train_dataset.get_mini_dataset(\n",
    "        inner_batch_size, inner_iters, train_shots, classes\n",
    "    )\n",
    "\n",
    "    checkpoint_manager(meta_iter)\n",
    "\n",
    "    # The meta step size decays linearly with the fraction of iterations done.\n",
    "    frac_done = meta_iter / meta_iters\n",
    "    cur_meta_step_size = (1 - frac_done) * meta_step_size\n",
    "    # Temporarily save the weights from the model.\n",
    "    old_vars = model.get_weights()\n",
    "\n",
    "    # Inner loop: plain optimizer steps on the sampled mini dataset.\n",
    "    for images, labels in mini_dataset:\n",
    "        with tf.GradientTape() as tape:\n",
    "            preds = forward_pass(images)\n",
    "            loss = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "        grads = tape.gradient(loss, model.trainable_weights)\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "    new_vars = model.get_weights()\n",
    "    # Perform SGD for the meta step: move the saved weights a fraction of the\n",
    "    # way towards the weights reached by the inner loop.\n",
    "    for var in range(len(new_vars)):\n",
    "        new_vars[var] = old_vars[var] + (\n",
    "            (new_vars[var] - old_vars[var]) * cur_meta_step_size\n",
    "        )\n",
    "    # After the meta-learning step, reload the newly-trained weights into the model.\n",
    "    model.set_weights(new_vars)\n",
    "    if meta_iter == 50:\n",
    "        print(\"\\n#### The first 50 iterations took {:.2f} secs ####\".format(time.time()-start))\n",
    "\n",
    "    # Evaluation loop\n",
    "    if meta_iter % eval_interval == 0:\n",
    "        accuracies = []\n",
    "        for dataset in (train_dataset, test_dataset):\n",
    "            # Sample a mini dataset from the full dataset.\n",
    "            train_set, test_images, test_labels = dataset.get_mini_dataset(\n",
    "                eval_batch_size, eval_iters, shots, classes, split=True\n",
    "            )\n",
    "            old_vars = model.get_weights()\n",
    "            # Train on the samples and get the resulting accuracies.\n",
    "            for images, labels in train_set:\n",
    "                with tf.GradientTape() as tape:\n",
    "                    preds = forward_pass(images)\n",
    "                    loss = keras.losses.sparse_categorical_crossentropy(labels, preds)\n",
    "                grads = tape.gradient(loss, model.trainable_weights)\n",
    "                optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "\n",
    "            # Predicted class per test image: argmax over the class (last) axis.\n",
    "            # The default axis=0 would instead reduce over the sample axis.\n",
    "            test_preds = tf.argmax(forward_pass(test_images), axis=-1).numpy()\n",
    "            num_correct = (test_preds == test_labels).sum()\n",
    "\n",
    "            # Reset the weights after getting the evaluation accuracies.\n",
    "            model.set_weights(old_vars)\n",
    "            # Fraction of the held-out samples that were classified correctly.\n",
    "            # NOTE(review): equals the former `num_correct / classes` when the\n",
    "            # split yields one test sample per class - confirm in get_mini_dataset.\n",
    "            accuracies.append(num_correct / len(test_labels))\n",
    "\n",
    "        training.append(accuracies[0])\n",
    "        testing.append(accuracies[1])\n",
    "\n",
    "        if meta_iter % report_frequency == 0:\n",
    "            print(\"Iter = %d => train_acc = %.2f%% / test_acc = %.2f%%\" % (meta_iter,\n",
    "                  100*np.mean(training),100*np.mean(testing)))\n",
    "\n",
    "end = time.time()\n",
    "print(\"The training took {:.2f} secs\".format(end-start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wQ60Xj9ufvWd",
    "outputId": "03a7b218-decf-43b2-e7b7-0a0aab628395"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final training accuracy = 23.4 %%\n",
      "Final testing accuracy = 31.4 %%\n"
     ]
    }
   ],
   "source": [
    "# Report the mean of the accuracies collected in `training` and `testing`.\n",
    "# str.format does not treat '%' specially, so a single '%' prints one percent sign\n",
    "# (the former '%%' was emitted verbatim as two).\n",
    "print(\"Final training accuracy = {:.1f} %\".format(100*np.mean(training)))\n",
    "print(\"Final testing accuracy = {:.1f} %\".format(100*np.mean(testing)))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "reptile_omniglot.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "00eb811349f14c848f0bb2c8e4fab65f": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "04da630f62a040acb30c3ca55b589f5d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "0622f3b6ac3b4c1a93b8e2e3616e8b00": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "06b95880542941b8a819aa28baa2e0e8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "086d178b0dbe48dd897ee9b125dcd464": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0bdba0100ae34522a621c53f28cc3e8d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0ca1b2b2206b43e1af3d03da4fa64503": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0fc193a8863e479fb87055e664de8b9d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0fd1b28baa684abd8c8e6ff7e24ab19f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "danger",
      "description": " 75%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_00eb811349f14c848f0bb2c8e4fab65f",
      "max": 19280,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_d45c15799fca46239f6797f894a852cd",
      "value": 14528
     }
    },
    "12636d6f48804f8d93b21f3ab54bcf6c": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "17e640fbeb5f458287117311c12b2bbd": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_b4540f3e948047b7a18a428860e55652",
       "IPY_MODEL_fbc46482c9eb442a9458f0b619443ef1"
      ],
      "layout": "IPY_MODEL_9e6af9e40f1f489fb6e1514f6e9ad8f6"
     }
    },
    "1a08f76b599747a1933c758ac0bf397e": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "203034d88830486c92067b44ea451890": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_983c38b96f314039aae4e3be8a452f9e",
       "IPY_MODEL_d8167d1417bd415c8d7c66194cf5f6b7"
      ],
      "layout": "IPY_MODEL_45f90aa79c1d46158f58f5d2682450fd"
     }
    },
    "20a7a110d571405da19829ad1727643b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_5793ada4367f46708da317de1652752f",
       "IPY_MODEL_520eeb9985fb46929f90de6d2f24e999"
      ],
      "layout": "IPY_MODEL_5749fe77e4f948e18033d46beea5718e"
     }
    },
    "2606945c3d5744408320c95aad7eeb88": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_56869a904c514f5e8835c517e9245fbf",
       "IPY_MODEL_6df570404ba44b9daad834e269c59c49"
      ],
      "layout": "IPY_MODEL_12636d6f48804f8d93b21f3ab54bcf6c"
     }
    },
    "28a53bac50b14a1caf72690a9c2a0a93": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "31ef7325a71b47eab92e3a4dcc0b5318": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "32edaffc9d8d4dfb90b9d99ed640f0e5": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "3789fd81836e435691fcac8a429c900d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "3c167a3dde704c79979311b872948e76": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "3dbbde2390f045eda9f428746ebaa795": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "4505d9dfa62c4826ad96da9ab397e352": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "45f90aa79c1d46158f58f5d2682450fd": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "469d3773a7114477bbe21e5086c06aec": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "471de552d73f48f7a244cb2e298f62bc": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "5118178ff1e14036969c18ad2faf722a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b7d109ba579741ff8f74e38ce66cda0a",
      "placeholder": "​",
      "style": "IPY_MODEL_7b87243a89ab4f2fa596a09188ebeadf",
      "value": " 4/4 [00:13&lt;00:00,  3.40s/ file]"
     }
    },
    "51b83ae12ca74a2abde28cdc977b8ed2": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "520eeb9985fb46929f90de6d2f24e999": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a4f380d9686946308d483d78a121a3ea",
      "placeholder": "​",
      "style": "IPY_MODEL_469d3773a7114477bbe21e5086c06aec",
      "value": " 19280/0 [00:08&lt;00:00, 2315.00 examples/s]"
     }
    },
    "550f8801d97c417e8d519f1a6b07146e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_78e64e6e95f6447d859ae1170365f94c",
       "IPY_MODEL_5118178ff1e14036969c18ad2faf722a"
      ],
      "layout": "IPY_MODEL_0bdba0100ae34522a621c53f28cc3e8d"
     }
    },
    "56869a904c514f5e8835c517e9245fbf": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "danger",
      "description": "  0%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_61e50e415fd64b6abb13926f0747aff2",
      "max": 2720,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_5856779e14464d9b867c51eceb5d0dfc",
      "value": 0
     }
    },
    "5749fe77e4f948e18033d46beea5718e": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5793ada4367f46708da317de1652752f": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_4505d9dfa62c4826ad96da9ab397e352",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_31ef7325a71b47eab92e3a4dcc0b5318",
      "value": 1
     }
    },
    "57bdadfb0c4247b49371f24941acf784": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5856779e14464d9b867c51eceb5d0dfc": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "5ad1704befa648a3874ef7174ef3e5ed": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "61e50e415fd64b6abb13926f0747aff2": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "622ddc39ade64c77b160cbd66e6752b2": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "6a4750d06897462eadbdcd673bf21e43": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_71345d17bc68497f8f4f5b3a3ff125f1",
       "IPY_MODEL_e0b0681af49c463aa93f2ec7526a7b90"
      ],
      "layout": "IPY_MODEL_8ccf110725f14ee5b8c58b84f99c39f9"
     }
    },
    "6bf7532e29754f839963c7eb7a58e200": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_0fd1b28baa684abd8c8e6ff7e24ab19f",
       "IPY_MODEL_8793819ee22c431789fe05792b71ef50"
      ],
      "layout": "IPY_MODEL_cd945007c0be4a568e77b96944676717"
     }
    },
    "6df570404ba44b9daad834e269c59c49": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c61f26881c75493f9d9cdc06628a867d",
      "placeholder": "​",
      "style": "IPY_MODEL_c9e2d236024e4644b3c6607ca127ab21",
      "value": " 0/2720 [00:00&lt;?, ? examples/s]"
     }
    },
    "71345d17bc68497f8f4f5b3a3ff125f1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "danger",
      "description": "  0%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_1a08f76b599747a1933c758ac0bf397e",
      "max": 3120,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_ebb18b0a97254ad4857872f885f43d80",
      "value": 0
     }
    },
    "756c0aa0fe9e4fd88c8bc6752c7d5bf1": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "78e64e6e95f6447d859ae1170365f94c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Extraction completed...: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b171d93947a14520907b59f66697574e",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_d5e7d7197f57418e82e19a2bee67f3bd",
      "value": 1
     }
    },
    "79d648274ec14a348109fcd6a66dace2": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "danger",
      "description": "  0%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_8be07adc4c3c4f1796c6c0d2192e625a",
      "max": 13180,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_dd248e0814a64ee8a63090fee78e962d",
      "value": 0
     }
    },
    "7b87243a89ab4f2fa596a09188ebeadf": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "7c7b964330724bb2b222233f13cb1839": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "86f877a4124c4439a7d4551a488ef4e3": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8793819ee22c431789fe05792b71ef50": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_d55437c1a0f346b7b0494eaea8988236",
      "placeholder": "​",
      "style": "IPY_MODEL_32edaffc9d8d4dfb90b9d99ed640f0e5",
      "value": " 14528/19280 [00:00&lt;00:00, 145072.26 examples/s]"
     }
    },
    "89e979baca37455496ded4f36586ff03": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8be07adc4c3c4f1796c6c0d2192e625a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "8ccf110725f14ee5b8c58b84f99c39f9": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "92dbd1a4c73848c3bbeba68ff0c6f31d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_756c0aa0fe9e4fd88c8bc6752c7d5bf1",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_3789fd81836e435691fcac8a429c900d",
      "value": 1
     }
    },
    "983c38b96f314039aae4e3be8a452f9e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Dl Size...: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_89e979baca37455496ded4f36586ff03",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_7c7b964330724bb2b222233f13cb1839",
      "value": 1
     }
    },
    "9e1f957b1e344d428ed19e26dc3b2f38": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_ff3f698a52ce4422aa0666a67ec97738",
      "placeholder": "​",
      "style": "IPY_MODEL_0622f3b6ac3b4c1a93b8e2e3616e8b00",
      "value": " 13180/0 [00:05&lt;00:00, 2293.03 examples/s]"
     }
    },
    "9e6af9e40f1f489fb6e1514f6e9ad8f6": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a2deeecce6d3496d83a0f42c62aa3f3b": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a4f380d9686946308d483d78a121a3ea": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a7f926e82ae346928221be0cdc83cb00": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_92dbd1a4c73848c3bbeba68ff0c6f31d",
       "IPY_MODEL_9e1f957b1e344d428ed19e26dc3b2f38"
      ],
      "layout": "IPY_MODEL_86f877a4124c4439a7d4551a488ef4e3"
     }
    },
    "b08991fb355f444f9f8d6be2ea099142": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b171d93947a14520907b59f66697574e": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b4540f3e948047b7a18a428860e55652": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Dl Completed...: 100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c770b8fe48a44039a1a8486180df589d",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_e66c0534282e4c55a21bd149fb32aab8",
      "value": 1
     }
    },
    "b707ee96ad21479fbd0608e8158c5276": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b7d109ba579741ff8f74e38ce66cda0a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "bf1f44068c034278b6be5e12393898e1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_e5479b3849034671a1be5bad552bb299",
       "IPY_MODEL_d62201089e8446a1ae0dd483b9a31686"
      ],
      "layout": "IPY_MODEL_0ca1b2b2206b43e1af3d03da4fa64503"
     }
    },
    "bf9a2d6ff3b042d98749cdb2d8f00948": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c1e4612e2d234f3299069f2c292e3ed3": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_bf9a2d6ff3b042d98749cdb2d8f00948",
      "placeholder": "​",
      "style": "IPY_MODEL_5ad1704befa648a3874ef7174ef3e5ed",
      "value": " 0/13180 [00:00&lt;?, ? examples/s]"
     }
    },
    "c61f26881c75493f9d9cdc06628a867d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c62679f2de2e4e0c954082d6f752160c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_79d648274ec14a348109fcd6a66dace2",
       "IPY_MODEL_c1e4612e2d234f3299069f2c292e3ed3"
      ],
      "layout": "IPY_MODEL_cfb9c99287c24d5fb3ff5fca1b3c2f37"
     }
    },
    "c770b8fe48a44039a1a8486180df589d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c9e2d236024e4644b3c6607ca127ab21": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "cad7f13a33b84e78b76336792fe46cea": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "cd945007c0be4a568e77b96944676717": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "cfb9c99287c24d5fb3ff5fca1b3c2f37": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d45c15799fca46239f6797f894a852cd": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "d55437c1a0f346b7b0494eaea8988236": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d5e7d7197f57418e82e19a2bee67f3bd": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "d62201089e8446a1ae0dd483b9a31686": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_0fc193a8863e479fb87055e664de8b9d",
      "placeholder": "​",
      "style": "IPY_MODEL_cad7f13a33b84e78b76336792fe46cea",
      "value": " 3120/0 [00:01&lt;00:00, 2339.33 examples/s]"
     }
    },
    "d8167d1417bd415c8d7c66194cf5f6b7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3dbbde2390f045eda9f428746ebaa795",
      "placeholder": "​",
      "style": "IPY_MODEL_28a53bac50b14a1caf72690a9c2a0a93",
      "value": " 17/17 [00:13&lt;00:00,  1.25 MiB/s]"
     }
    },
    "dc0c030fd626423f815f22e91e591030": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_de03e722d1a449cb99bab16e69f1e1c3",
       "IPY_MODEL_e932a46875654a3ebfe14802be27418a"
      ],
      "layout": "IPY_MODEL_086d178b0dbe48dd897ee9b125dcd464"
     }
    },
    "dd248e0814a64ee8a63090fee78e962d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "de03e722d1a449cb99bab16e69f1e1c3": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3c167a3dde704c79979311b872948e76",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_51b83ae12ca74a2abde28cdc977b8ed2",
      "value": 1
     }
    },
    "e0b0681af49c463aa93f2ec7526a7b90": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a2deeecce6d3496d83a0f42c62aa3f3b",
      "placeholder": "​",
      "style": "IPY_MODEL_06b95880542941b8a819aa28baa2e0e8",
      "value": " 0/3120 [00:00&lt;?, ? examples/s]"
     }
    },
    "e5479b3849034671a1be5bad552bb299": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "info",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_57bdadfb0c4247b49371f24941acf784",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_471de552d73f48f7a244cb2e298f62bc",
      "value": 1
     }
    },
    "e66c0534282e4c55a21bd149fb32aab8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "e932a46875654a3ebfe14802be27418a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b08991fb355f444f9f8d6be2ea099142",
      "placeholder": "​",
      "style": "IPY_MODEL_622ddc39ade64c77b160cbd66e6752b2",
      "value": " 2720/0 [00:01&lt;00:00, 2275.61 examples/s]"
     }
    },
    "ebb18b0a97254ad4857872f885f43d80": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "fbc46482c9eb442a9458f0b619443ef1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b707ee96ad21479fbd0608e8158c5276",
      "placeholder": "​",
      "style": "IPY_MODEL_04da630f62a040acb30c3ca55b589f5d",
      "value": " 4/4 [00:13&lt;00:00,  3.42s/ url]"
     }
    },
    "ff3f698a52ce4422aa0666a67ec97738": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
