{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "jOxhrofLejiY"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6sBAOOgeeZmA"
      },
      "outputs": [],
      "source": [
        "!pip install colorama"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HfDRnCoPeo9I"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import math\n",
        "import json\n",
        "import random as rnd\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader, Dataset, random_split\n",
        "import matplotlib.pyplot as plt\n",
        "import pandas as  pd\n",
        "import torchvision.utils as vision_utils\n",
        "from PIL import Image\n",
        "import torchvision\n",
        "from colorama import Fore, Back, Style\n",
        "from matplotlib.ticker import NullFormatter\n",
        "import sys\n",
        "sys.path.insert(1, os.path.join(sys.path[0], \"..\"))\n",
        "\n",
        "DEVICE = torch.device('cuda')"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "gAfByEeYek9B"
      },
      "source": [
        "# Build MC-Dominoes dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def plot_samples(dataset, nrow=13, figsize=(10,7)):\n",
        "  \"\"\"Display the first `nrow` images of `dataset` in a single-row grid.\n",
        "\n",
        "  Args:\n",
        "    dataset: either an object exposing a non-empty `.tensors` sequence (its first\n",
        "      entry is used as the image batch) or the image batch itself.\n",
        "    nrow: number of images to show; also the number of grid columns.\n",
        "    figsize: matplotlib figure size in inches.\n",
        "  \"\"\"\n",
        "  # Explicit attribute lookup instead of nested bare `except:` blocks, which also\n",
        "  # swallowed unrelated errors (including KeyboardInterrupt).\n",
        "  tensors = getattr(dataset, 'tensors', None)\n",
        "  if tensors is not None and len(tensors) > 0:\n",
        "    X = tensors[0]\n",
        "  else:\n",
        "    X = dataset\n",
        "  plt.figure(figsize=figsize, dpi=130)\n",
        "  grid_img = vision_utils.make_grid(X[:nrow].cpu(), nrow=nrow, normalize=True, padding=1)\n",
        "  # Move the first axis of the grid last before handing it to imshow.\n",
        "  plt.imshow(grid_img.permute(1, 2, 0), interpolation='nearest')\n",
        "  # Hide tick marks and tick labels on both axes.\n",
        "  plt.tick_params(axis='both', which='both', length=0)\n",
        "  ax = plt.gca()\n",
        "  ax.xaxis.set_major_formatter(NullFormatter())\n",
        "  ax.yaxis.set_major_formatter(NullFormatter())\n",
        "  plt.show()\n",
        "\n",
        "\n",
        "def keep_only_lbls(dataset, lbls):\n",
        "  lbls = {lbl: i for i, lbl in enumerate(lbls)}\n",
        "  final_X, final_Y = [], []\n",
        "  for x, y in dataset:\n",
        "    if y in lbls:\n",
        "      final_X.append(x)\n",
        "      final_Y.append(lbls[y])\n",
        "  X = torch.stack(final_X)\n",
        "  Y = torch.tensor(final_Y).float().view(-1,1)\n",
        "  return X, Y\n",
        "\n",
        "\n",
        "def format_mnist(imgs):\n",
        "  imgs = np.stack([np.pad(imgs[i][0], 2, constant_values=0)[None,:] for i in range(len(imgs))])\n",
        "  imgs = np.repeat(imgs, 3, axis=1)\n",
        "  return torch.tensor(imgs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "57663c6514c546f08f23e6962b040a27",
            "c59d0e94b22045deb2146cafc5253757",
            "6d91e5a5ac7b405abb724eac3875bd26",
            "f741d9f146484900be3ac5779b1f76b3",
            "25dcd06a0db041d4bcd2d370709810a5",
            "fe9ff4d95e1a4f6b939510c11c4f58f2",
            "7aef4900d2284e0c9b00326ac5d9c479",
            "edb8b9e215a648439c6c4face3ada113",
            "50c323e671af4e279880557b3e461618",
            "f5f2039571f940cf8a13ac3de59e9426",
            "8bd6858e05c24935b352907ba363f438",
            "3c7baf6fe7c54a66baa6ca1c6b162774",
            "42f7a54652bc4e12818caa0ac51688db",
            "c2522ec6f48f4c35aff0c2f90037175a",
            "78890e24670f4f5ba9982708475d2a4a",
            "11e0abca12ea40c98aecc6f73d825bcf",
            "e3139d6003834907838bccef67f7471a",
            "c91214aeb3ae455a96f61b09e5a4abb0",
            "d5867d70563e45e7bf9e52d5903ff657",
            "fc31d40bf2114c80a4be84eed637cb9e",
            "129f62f0bef9443c9e143935531f0fad",
            "500a9b18a3fa425da61080412bc53115",
            "7645ccf2a7d24715a129f0c4831d404b",
            "43899976ba064c39a360f5f980b0caf4",
            "02f58aca7dee4a3a9f631a6a87c7668b",
            "2d931ca4fc49498fbd793abd1ed7b7f5",
            "3308ab945b3248f4a94f3a50fa0ad00a",
            "0ace65dc15f0472fb47ab899a0551e9a",
            "bbcc01bcba3f4e61b00795d95ae3a5d8",
            "d222845c70134afc8e569a413b8b07af",
            "d4fabb860e8849288979fe79786d6f70",
            "c32a83d1b93d4cd3acd1cac758ab82f9",
            "9e413d8fe45945219c88e333d3a822f6",
            "cb2ad3d2884a42569edef29b9c18423e",
            "37fb9b7807dc49f99fd2a37f4a61ca05",
            "ebf03580c1a8499ab713b07301ca830f",
            "866d4df1f14147ab9b55906ddb8de51a",
            "cdaed5fd140a4f72ab03092786eabe2b",
            "889b16931d6d43afa0d7fde84f27e929",
            "a2fa795380574174a28294214fab29bc",
            "45230628dc164fc5b6a6b736f9568903",
            "f3f84fbeb6f342bd9192ddc6f7863b13",
            "02bf2addd49248bc865b8515d2427739",
            "663ce22902af4f4c94aedbdb9c551558",
            "7bccf661f3904cfe965eb9a40b2d075d",
            "0b9611c5b47d418d80e031710bbc3a97",
            "68dc44b7570e455e940d8cab6e7d4154",
            "a7d3d59a874a4041927ec17107f7425b",
            "6a8fd1f5529249cc9c4bf047e408b8ea",
            "7e9f2a6fb4f14f1fab863a1ea655b639",
            "1ffdca0097e74db18017a7987f46f46c",
            "7a9fe76e5f6f422380dd626086483613",
            "7f071b80e275427abc41b4b43f3a6bfa",
            "1db0d5ee53bd47e7bd2870aa6c8639b1",
            "8f13ee26b3cd4a9aae22ea4480981c7a",
            "08195e51bf244c0b9d5c83e2d2b3b2c2",
            "f3c008785b334c93bb744c0dfa032b3b",
            "2e6efdb8a1bc4818aee537182a04dac3",
            "4e02d6ee3e994276941bded69012d632",
            "7a772f9fc8124493b99e7c203a2fd535",
            "f91e1ae924a9402983ffac448b714fae",
            "a2de69155b7d4adf881bb5aa777b01ef",
            "7baa1e51abc3434b9eebfbf975d50188",
            "fb5229eab1cb452c831e5229e9ed5ac3",
            "061c6ab8faf443dfa6e63f4162328d57",
            "fd85711851c340f787245de139b7bd19"
          ]
        },
        "id": "8MCgoYp5ki-V",
        "outputId": "b1e4a7e7-da4f-4a17-a073-0bc823522958"
      },
      "outputs": [],
      "source": [
        "def build_mc_dataset(mnist_data, cifar_data, randomize_m=False, randomize_c=False):\n",
        "  \"\"\"Build a binary dataset of MNIST images stacked on top of CIFAR images.\n",
        "\n",
        "  Label 0 pairs MNIST digit 0 with CIFAR label 1; label 1 pairs MNIST digit 1 with\n",
        "  CIFAR label 9. With `randomize_m` (resp. `randomize_c`) the MNIST (resp. CIFAR)\n",
        "  halves are permuted across the whole dataset before pairing. The returned\n",
        "  `TensorDataset` lives on `DEVICE` and has float labels of shape (N, 1).\n",
        "  \"\"\"\n",
        "  def _shuffled(t):\n",
        "    return t[torch.randperm(len(t))]\n",
        "\n",
        "  # MNIST halves: padded and replicated to 3 channels by format_mnist.\n",
        "  mnist_zeros, _ = keep_only_lbls(mnist_data, lbls=[0])\n",
        "  mnist_ones, _ = keep_only_lbls(mnist_data, lbls=[1])\n",
        "  mnist_zeros = format_mnist(mnist_zeros.view(-1, 1, 28, 28))\n",
        "  mnist_ones = format_mnist(mnist_ones.view(-1, 1, 28, 28))\n",
        "  mnist_zeros = _shuffled(mnist_zeros)\n",
        "  mnist_ones = _shuffled(mnist_ones)\n",
        "\n",
        "  # CIFAR halves.\n",
        "  cifar_lbl1, _ = keep_only_lbls(cifar_data, lbls=[1])\n",
        "  cifar_lbl9, _ = keep_only_lbls(cifar_data, lbls=[9])\n",
        "  cifar_lbl1 = _shuffled(cifar_lbl1)\n",
        "  cifar_lbl9 = _shuffled(cifar_lbl9)\n",
        "\n",
        "  # Truncate each pairing to the smaller of its two sources.\n",
        "  n_neg = min(len(mnist_zeros), len(cifar_lbl1))\n",
        "  n_pos = min(len(mnist_ones), len(cifar_lbl9))\n",
        "  top = torch.cat((mnist_zeros[:n_neg], mnist_ones[:n_pos]), dim=0)\n",
        "  bottom = torch.cat((cifar_lbl1[:n_neg], cifar_lbl9[:n_pos]), dim=0)\n",
        "  if randomize_m:\n",
        "    top = _shuffled(top)\n",
        "  if randomize_c:\n",
        "    bottom = _shuffled(bottom)\n",
        "\n",
        "  # Join the two halves along dim 2 and shuffle images and labels together.\n",
        "  images = torch.cat((top, bottom), dim=2)\n",
        "  labels = torch.cat((torch.zeros((n_neg,)), torch.ones((n_pos,))))\n",
        "  order = torch.randperm(len(images))\n",
        "  images = images[order]\n",
        "  labels = labels[order].float().view(-1,1)\n",
        "  return torch.utils.data.TensorDataset(images.to(DEVICE), labels.to(DEVICE))\n",
        "\n",
        "\n",
        "# ToTensor is the only transform applied to the raw images.\n",
        "transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])\n",
        "\n",
        "# Training splits (downloaded into ./data on first use).\n",
        "mnist_train = torchvision.datasets.MNIST('./data/mnist/', train=True, download=True, transform=transform)\n",
        "cifar_train = torchvision.datasets.CIFAR10('./data/cifar10/', train=True, download=True, transform=transform)\n",
        "# Seeded (42) split of each training set into perturbation-base / train / valid subsets.\n",
        "mnist_perturb_base, mnist_train, mnist_valid = random_split(mnist_train, [10000, 45000, 5000], generator=torch.Generator().manual_seed(42))\n",
        "cifar_perturb_base, cifar_train, cifar_valid = random_split(cifar_train, [10000, 35000, 5000], generator=torch.Generator().manual_seed(42))\n",
        "\n",
        "# Held-out splits (train=False).\n",
        "mnist_test = torchvision.datasets.MNIST('./data/mnist/', train=False, download=True, transform=transform)\n",
        "cifar_test = torchvision.datasets.CIFAR10('./data/cifar10/', train=False, download=True, transform=transform)\n",
        "\n",
        "\n",
        "# Training / valid / test datasets\n",
        "data_train = build_mc_dataset(mnist_train, cifar_train)\n",
        "data_valid = build_mc_dataset(mnist_valid, cifar_valid)\n",
        "data_test = build_mc_dataset(mnist_test, cifar_test)\n",
        "\n",
        "# Only the training loader reshuffles; evaluation loaders keep a fixed order.\n",
        "train_dl = torch.utils.data.DataLoader(data_train, batch_size=256, shuffle=True)\n",
        "valid_dl = torch.utils.data.DataLoader(data_valid, batch_size=1024, shuffle=False)\n",
        "test_dl = torch.utils.data.DataLoader(data_test, batch_size=1024, shuffle=False)\n",
        "\n",
        "\n",
        "# MNIST randomized test / valid datasets\n",
        "data_test_rm = build_mc_dataset(mnist_test, cifar_test, randomize_m=True, randomize_c=False)\n",
        "data_valid_rm = build_mc_dataset(mnist_valid, cifar_valid, randomize_m=True, randomize_c=False)\n",
        "\n",
        "test_rm_dl = torch.utils.data.DataLoader(data_test_rm, batch_size=1024, shuffle=False)\n",
        "valid_rm_dl = torch.utils.data.DataLoader(data_valid_rm, batch_size=1024, shuffle=False)\n",
        "\n",
        "# CIFAR-10 randomized test / valid datasets\n",
        "data_test_rc = build_mc_dataset(mnist_test, cifar_test, randomize_m=False, randomize_c=True)\n",
        "data_valid_rc = build_mc_dataset(mnist_valid, cifar_valid, randomize_m=False, randomize_c=True)\n",
        "\n",
        "test_rc_dl = torch.utils.data.DataLoader(data_test_rc, batch_size=1024, shuffle=False)\n",
        "valid_rc_dl = torch.utils.data.DataLoader(data_valid_rc, batch_size=1024, shuffle=False)\n",
        "\n",
        "# Report dataset sizes, then preview one row of samples from each variant.\n",
        "print(f\"Train length: {len(train_dl.dataset)}\")\n",
        "print(f\"Test length: {len(test_dl.dataset)}\")\n",
        "print(f\"Test length randomized mnist: {len(test_rm_dl.dataset)}\")\n",
        "print(f\"Test length randomized cifar10: {len(test_rc_dl.dataset)}\")\n",
        "\n",
        "print(\"Non-randomized train dataset:\")\n",
        "plot_samples(data_train)\n",
        "\n",
        "print(\"Non-randomized test dataset:\")\n",
        "plot_samples(data_test)\n",
        "\n",
        "print(\"MNIST-randomized test dataset:\")\n",
        "plot_samples(data_test_rm)\n",
        "\n",
        "print(\"CIFAR10-randomized test dataset:\")\n",
        "plot_samples(data_test_rc)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def build_mf_dataset(mnist_data, fashm_data, randomize_m=False, randomize_f=False):\n",
        "  \"\"\"Build a binary dataset of MNIST images stacked on top of Fashion-MNIST images.\n",
        "\n",
        "  Label 0 pairs MNIST digit 0 with Fashion-MNIST label 4; label 1 pairs MNIST digit 1\n",
        "  with Fashion-MNIST label 3. With `randomize_m` (resp. `randomize_f`) the MNIST\n",
        "  (resp. Fashion-MNIST) halves are permuted across the whole dataset before pairing.\n",
        "  The returned `TensorDataset` lives on `DEVICE` and has float labels of shape (N, 1).\n",
        "  \"\"\"\n",
        "  def _shuffled(t):\n",
        "    return t[torch.randperm(len(t))]\n",
        "\n",
        "  # MNIST halves (used as-is, no padding or channel replication here).\n",
        "  mnist_zeros, _ = keep_only_lbls(mnist_data, lbls=[0])\n",
        "  mnist_ones, _ = keep_only_lbls(mnist_data, lbls=[1])\n",
        "  mnist_zeros = _shuffled(mnist_zeros)\n",
        "  mnist_ones = _shuffled(mnist_ones)\n",
        "\n",
        "  # Fashion-MNIST halves.\n",
        "  fashm_lbl4, _ = keep_only_lbls(fashm_data, lbls=[4])\n",
        "  fashm_lbl3, _ = keep_only_lbls(fashm_data, lbls=[3])\n",
        "  fashm_lbl4 = _shuffled(fashm_lbl4)\n",
        "  fashm_lbl3 = _shuffled(fashm_lbl3)\n",
        "\n",
        "  # Truncate each pairing to the smaller of its two sources.\n",
        "  n_neg = min(len(mnist_zeros), len(fashm_lbl4))\n",
        "  n_pos = min(len(mnist_ones), len(fashm_lbl3))\n",
        "  top = torch.cat((mnist_zeros[:n_neg], mnist_ones[:n_pos]), dim=0)\n",
        "  bottom = torch.cat((fashm_lbl4[:n_neg], fashm_lbl3[:n_pos]), dim=0)\n",
        "  if randomize_m:\n",
        "    top = _shuffled(top)\n",
        "  if randomize_f:\n",
        "    bottom = _shuffled(bottom)\n",
        "\n",
        "  # Join the two halves along dim 2 and shuffle images and labels together.\n",
        "  images = torch.cat((top, bottom), dim=2)\n",
        "  labels = torch.cat((torch.zeros((n_neg,)), torch.ones((n_pos,))))\n",
        "  order = torch.randperm(len(images))\n",
        "  images = images[order]\n",
        "  labels = labels[order].float().view(-1,1)\n",
        "  return torch.utils.data.TensorDataset(images.to(DEVICE), labels.to(DEVICE))\n",
        "\n",
        "\n",
        "# NOTE: this cell rebinds data_train / train_dl / valid_dl / test_dl / *_rm_dl, so\n",
        "# whichever dataset cell ran last determines what those names refer to.\n",
        "# ToTensor is the only transform applied to the raw images.\n",
        "transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])\n",
        "\n",
        "# Training splits (downloaded into ./data on first use).\n",
        "mnist_train = torchvision.datasets.MNIST('./data/mnist/', train=True, download=True, transform=transform)\n",
        "fashm_train = torchvision.datasets.FashionMNIST('./data/FashionMNIST/', train=True, download=True, transform=transform)\n",
        "# Seeded (42) split of each training set into perturbation-base / train / valid subsets.\n",
        "mnist_perturb_base, mnist_train, mnist_valid = random_split(mnist_train, [10000, 45000, 5000], generator=torch.Generator().manual_seed(42))\n",
        "fashm_perturb_base, fashm_train, fashm_valid = random_split(fashm_train, [10000, 45000, 5000], generator=torch.Generator().manual_seed(42))\n",
        "\n",
        "# Held-out splits (train=False).\n",
        "mnist_test = torchvision.datasets.MNIST('./data/mnist/', train=False, download=True, transform=transform)\n",
        "fashm_test = torchvision.datasets.FashionMNIST('./data/FashionMNIST/', train=False, download=True, transform=transform)\n",
        "\n",
        "# Training / valid / test datasets\n",
        "data_train = build_mf_dataset(mnist_train, fashm_train)\n",
        "data_valid = build_mf_dataset(mnist_valid, fashm_valid)\n",
        "data_test = build_mf_dataset(mnist_test, fashm_test)\n",
        "\n",
        "# Only the training loader reshuffles; evaluation loaders keep a fixed order.\n",
        "train_dl = torch.utils.data.DataLoader(data_train, batch_size=256, shuffle=True)\n",
        "valid_dl = torch.utils.data.DataLoader(data_valid, batch_size=1024, shuffle=False)\n",
        "test_dl = torch.utils.data.DataLoader(data_test, batch_size=1024, shuffle=False)\n",
        "\n",
        "\n",
        "# MNIST randomized test / valid datasets\n",
        "data_test_rm = build_mf_dataset(mnist_test, fashm_test, randomize_m=True, randomize_f=False)\n",
        "data_valid_rm = build_mf_dataset(mnist_valid, fashm_valid, randomize_m=True, randomize_f=False)\n",
        "\n",
        "test_rm_dl = torch.utils.data.DataLoader(data_test_rm, batch_size=1024, shuffle=False)\n",
        "valid_rm_dl = torch.utils.data.DataLoader(data_valid_rm, batch_size=1024, shuffle=False)\n",
        "\n",
        "# F-MNIST randomized test / valid datasets\n",
        "data_test_rf = build_mf_dataset(mnist_test, fashm_test, randomize_m=False, randomize_f=True)\n",
        "data_valid_rf = build_mf_dataset(mnist_valid, fashm_valid, randomize_m=False, randomize_f=True)\n",
        "\n",
        "test_rf_dl = torch.utils.data.DataLoader(data_test_rf, batch_size=1024, shuffle=False)\n",
        "valid_rf_dl = torch.utils.data.DataLoader(data_valid_rf, batch_size=1024, shuffle=False)\n",
        "\n",
        "\n",
        "# Report dataset sizes, then preview one row of samples from each variant.\n",
        "# The labels below name F-MNIST: this cell has no CIFAR-10 data.\n",
        "print(f\"Train length: {len(train_dl.dataset)}\")\n",
        "print(f\"Test length: {len(test_dl.dataset)}\")\n",
        "print(f\"Test length randomized mnist: {len(test_rm_dl.dataset)}\")\n",
        "print(f\"Test length randomized F-MNIST: {len(test_rf_dl.dataset)}\")\n",
        "\n",
        "print(\"Non-randomized train dataset:\")\n",
        "plot_samples(data_train)\n",
        "\n",
        "print(\"Non-randomized test dataset:\")\n",
        "plot_samples(data_test)\n",
        "\n",
        "print(\"MNIST-randomized test dataset:\")\n",
        "plot_samples(data_test_rm)\n",
        "\n",
        "print(\"F-MNIST-randomized test dataset:\")\n",
        "plot_samples(data_test_rf)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "j3IAtYwOdowN"
      },
      "source": [
        "# Utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bmj1GdVFdo5G"
      },
      "outputs": [],
      "source": [
        "@torch.no_grad()\n",
        "def get_acc(model, dl, idx=None):\n",
        "  if idx is None:  # Then it's D-BAT\n",
        "    model.eval()\n",
        "    acc = []\n",
        "    for X, y in dl:\n",
        "      acc.append((torch.sigmoid(model(X)) > 0.5) == y)\n",
        "    acc = torch.cat(acc)\n",
        "    acc = torch.sum(acc)/len(acc)\n",
        "    model.train()\n",
        "    return acc.item()\n",
        "  else:  # Then it's DivDis\n",
        "    model.eval()\n",
        "    acc = []\n",
        "    for X, y in dl:\n",
        "      acc.append((torch.sigmoid(model(X))[:, idx][:, None] > 0.5) == y)\n",
        "    acc = torch.cat(acc)\n",
        "    acc = torch.sum(acc)/len(acc)\n",
        "    model.train()\n",
        "    return acc.item()\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def get_ens_acc(ensemble, dl):\n",
        "  if isinstance(ensemble, nn.Module):  # Then DivDis\n",
        "    model = ensemble\n",
        "    model.eval()\n",
        "    acc = []\n",
        "    for X, y in dl:\n",
        "      probs = torch.sigmoid(model(X)).mean(dim=0)\n",
        "      acc.append((probs > 0.5) == y)\n",
        "    acc = torch.cat(acc)\n",
        "    acc = torch.sum(acc)/len(acc)\n",
        "    model.train()\n",
        "    return acc.item()\n",
        "  else:  # Then D-BAT\n",
        "    for model in ensemble:\n",
        "      model.eval()\n",
        "    acc = []\n",
        "    for X, y in dl:\n",
        "      probs = [torch.sigmoid(model(X)) for model in ensemble]\n",
        "      probs = torch.stack(probs).mean(dim=0)\n",
        "      acc.append((probs > 0.5) == y)\n",
        "    acc = torch.cat(acc)\n",
        "    acc = torch.sum(acc)/len(acc)\n",
        "    for model in ensemble:\n",
        "      model.train()\n",
        "    return acc.item()\n",
        "\n",
        "\n",
        "def dl_to_sampler(dl):\n",
        "  dl_iter = iter(dl)\n",
        "  def sample():\n",
        "    nonlocal dl_iter\n",
        "    try:\n",
        "      return next(dl_iter)\n",
        "    except StopIteration:\n",
        "      dl_iter = iter(dl)\n",
        "      return next(dl_iter)\n",
        "  return sample\n",
        "\n",
        "\n",
        "def print_stats(stats):\n",
        "  \"\"\"Plot the logged training curves of every model in `stats` on five side-by-side axes.\n",
        "\n",
        "  Only entries whose key starts with 'm' are plotted. Each entry maps 'loss',\n",
        "  'adv-loss', 'acc', 'rm-acc' and 'rc-acc' to lists of (iteration, value) pairs.\n",
        "  \"\"\"\n",
        "  fig, axes = plt.subplots(1,5,figsize=(16,3), dpi=110)\n",
        "\n",
        "  # One (title, stats key) pair per axis, left to right.\n",
        "  panels = (\n",
        "    (\"ERM loss\", 'loss'),\n",
        "    (\"Adv Loss\", 'adv-loss'),\n",
        "    (\"Acc\", 'acc'),\n",
        "    (\"Randomized MNIST Acc\", 'rm-acc'),\n",
        "    (\"Randomized CIFAR Acc\", 'rc-acc'),\n",
        "  )\n",
        "  for ax, (title, _) in zip(axes, panels):\n",
        "    ax.grid()\n",
        "    ax.set_title(title)\n",
        "    ax.set_xlabel(\"iterations\")\n",
        "\n",
        "  for m_id, m_stats in stats.items():\n",
        "    if m_id[0] != 'm':\n",
        "      continue\n",
        "    # The x values of every curve come from the 'loss' log.\n",
        "    itrs = [point[0] for point in m_stats['loss']]\n",
        "    for ax, (_, key) in zip(axes, panels):\n",
        "      ax.plot(itrs, [point[1] for point in m_stats[key]], label=m_id)\n",
        "\n",
        "  # The three accuracy panels share a fixed y-range.\n",
        "  for ax in axes[2:]:\n",
        "    ax.set_ylim(0.45, 1.05)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "ExUrHLZkgDFQ"
      },
      "source": [
        "# Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CMbmcEciJija"
      },
      "outputs": [],
      "source": [
        "class LeNet(nn.Module):\n",
        "\n",
        "    def __init__(self, num_classes=10, dropout_p=0.0) -> None:\n",
        "        super().__init__()\n",
        "        self.droput_p = dropout_p\n",
        "        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, padding=2)\n",
        "        self.conv2 = nn.Conv2d(32, 56, kernel_size=5)\n",
        "        self.fc1 = nn.Linear(2016, 512)\n",
        "        self.fc2 = nn.Linear(512, 256)\n",
        "        self.fc3 = nn.Linear(256, num_classes)\n",
        "        self.relu = nn.ReLU()\n",
        "        self.avgpool_2 = nn.AvgPool2d(kernel_size=2)\n",
        "        self.avgpool_3 = nn.AvgPool2d(kernel_size=3)\n",
        "\n",
        "    def forward(self, x: torch.Tensor, dropout=True) -> torch.Tensor:\n",
        "        x = self.relu(self.conv1(x))\n",
        "        x = F.dropout(x, p=self.droput_p, training=dropout)\n",
        "        x = self.avgpool_2(x)\n",
        "        x = self.relu(self.conv2(x))\n",
        "        x = F.dropout(x, p=self.droput_p, training=dropout)\n",
        "        x = self.avgpool_3(x)\n",
        "        x = torch.flatten(x, start_dim=1)\n",
        "        x = self.fc1(x)\n",
        "        x = self.relu(x)\n",
        "        x = F.dropout(x, p=self.droput_p, training=dropout)\n",
        "        x = self.fc2(x)\n",
        "        x = self.relu(x)\n",
        "        x = F.dropout(x, p=self.droput_p, training=dropout)\n",
        "        x = self.fc3(x)\n",
        "        return x\n",
        "\n",
        "    \n",
        "def set_train_mode(models):\n",
        "  for m in models:\n",
        "    m.train()\n",
        "\n",
        "\n",
        "def set_eval_mode(models):\n",
        "  for m in models:\n",
        "    m.eval()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "35E1d3WEgMs7"
      },
      "source": [
        "# Training code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QzZ5mq_ZGDF6"
      },
      "outputs": [],
      "source": [
        "def sequential_train_dbat(num_models, train_dl, valid_dl, valid_rm_dl, valid_rc_dl, test_dl, test_rm_dl,\n",
        "                     test_rc_dl, perturb_dl, alpha=10, max_epoch=100, opt='SGD',\n",
        "                     use_diversity_reg=True, reg_model_weights=None, lr_max=0.2, weight_decay=1e-5, use_scheduler=True):\n",
        "  \"\"\"Train `num_models` single-logit LeNets one after another (D-BAT).\n",
        "\n",
        "  Every model minimises binary cross-entropy on `train_dl`. From the second model on\n",
        "  (when `use_diversity_reg` is set) a disagreement term, computed on batches drawn\n",
        "  from `perturb_dl` against all previously trained models, is added with weight `alpha`.\n",
        "\n",
        "  Args:\n",
        "    num_models: number of models to train.\n",
        "    train_dl: loader of (x, y) training batches.\n",
        "    valid_dl, valid_rm_dl, valid_rc_dl: validation loaders evaluated every 200 steps.\n",
        "    test_dl, test_rm_dl, test_rc_dl: test loaders evaluated once each model is trained.\n",
        "    perturb_dl: loader of (x_tilde, _) batches used for the diversity term.\n",
        "    alpha: weight of the diversity term in the total loss.\n",
        "    max_epoch: number of passes over `train_dl` per model.\n",
        "    opt: 'SGD' selects SGD with momentum 0.9; any other value selects Adam.\n",
        "    use_diversity_reg: when False every model is trained with the ERM loss only.\n",
        "    reg_model_weights: per-previous-model weights of the diversity term (default: all 1.0).\n",
        "    lr_max: learning rate (peak of the cyclic schedule when `use_scheduler` is set).\n",
        "    weight_decay: weight decay passed to the optimizer.\n",
        "    use_scheduler: use a triangular CyclicLR rising from 0 to `lr_max` over the first\n",
        "      half of training and falling back over the second half.\n",
        "\n",
        "  Returns:\n",
        "    dict mapping \"m<k>\" to the logged (step, value) curves ('loss', 'adv-loss', 'acc',\n",
        "    'rm-acc', 'rc-acc') and the final test metrics of the k-th model.\n",
        "  \"\"\"\n",
        "  models = [LeNet(num_classes=1).to(DEVICE) for _ in range(num_models)]\n",
        "  set_train_mode(models)\n",
        "\n",
        "  stats = {f\"m{i+1}\": {\"acc\": [], \"rm-acc\": [], \"rc-acc\": [], \"loss\": [], \"adv-loss\": []} for i in range(len(models))}\n",
        "\n",
        "  if reg_model_weights is None:\n",
        "    reg_model_weights = [1.0 for _ in range(num_models)]\n",
        "\n",
        "  for m_idx, m in enumerate(models):\n",
        "    m_stats = stats[f\"m{m_idx+1}\"]\n",
        "\n",
        "    # Keep the `opt` argument untouched: rebinding it to the optimizer object made\n",
        "    # `opt == 'SGD'` false for every model after the first, silently switching to Adam.\n",
        "    if opt == 'SGD':\n",
        "      optimizer = torch.optim.SGD(m.parameters(), lr=lr_max, momentum=0.9, weight_decay=weight_decay)\n",
        "    else:\n",
        "      optimizer = torch.optim.Adam(m.parameters(), lr=lr_max, weight_decay=weight_decay)\n",
        "    if use_scheduler:\n",
        "      scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 0, lr_max, step_size_up=(len(train_dl)*max_epoch)//2,\n",
        "                                                    mode='triangular', cycle_momentum=False)\n",
        "    else:\n",
        "      scheduler = None\n",
        "    perturb_sampler = dl_to_sampler(perturb_dl)\n",
        "\n",
        "    for epoch in range(max_epoch):\n",
        "      for itr, (x, y) in enumerate(train_dl):\n",
        "        (x_tilde, _) = perturb_sampler()\n",
        "        erm_loss = F.binary_cross_entropy_with_logits(m(x), y)\n",
        "\n",
        "        if use_diversity_reg and m_idx != 0:\n",
        "          adv_loss = []\n",
        "          # Predictions of the already-trained models are constants for this step.\n",
        "          with torch.no_grad():\n",
        "            set_eval_mode(models)\n",
        "            ps = [torch.sigmoid(m_(x_tilde)) for m_ in models[:m_idx]]\n",
        "            set_train_mode(models)\n",
        "          psm = torch.sigmoid(m(x_tilde))\n",
        "          for i in range(len(ps)):\n",
        "            # Negative log-probability that the two models disagree on x_tilde\n",
        "            # (1e-7 keeps the log finite).\n",
        "            al = - ((1.-ps[i]) * psm + ps[i] * (1.-psm) + 1e-7).log().mean()\n",
        "            adv_loss.append(al*reg_model_weights[i])\n",
        "        else:\n",
        "          adv_loss = [torch.tensor([0]).to(DEVICE)]\n",
        "\n",
        "        # Weighted average of the per-model disagreement terms.\n",
        "        adv_loss = sum(adv_loss)/sum(reg_model_weights[:len(adv_loss)])\n",
        "        loss = erm_loss + alpha * adv_loss\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        if scheduler is not None: scheduler.step()\n",
        "\n",
        "        # Global step within the current model's training run.\n",
        "        itr_ = itr + epoch * len(train_dl)\n",
        "        if itr_ % 200 == 0:\n",
        "          set_eval_mode(models)\n",
        "          print_str = f\"[m{m_idx+1}] {epoch}/{itr_} [train] loss: {erm_loss.item():.2f} adv-loss: {adv_loss.item():.2f} \"\n",
        "          if itr_ != 0 and scheduler is not None:\n",
        "            print_str += f\"[lr] {scheduler.get_last_lr()[0]:.5f} \"\n",
        "          m_stats[\"loss\"].append((itr_, erm_loss.item()))\n",
        "          m_stats[\"adv-loss\"].append((itr_, adv_loss.item()))\n",
        "          acc = get_acc(m, valid_dl)\n",
        "          acc_rm = get_acc(m, valid_rm_dl)\n",
        "          acc_rc = get_acc(m, valid_rc_dl)\n",
        "          m_stats[\"acc\"].append((itr_, acc))\n",
        "          m_stats[\"rm-acc\"].append((itr_, acc_rm))\n",
        "          m_stats[\"rc-acc\"].append((itr_, acc_rc))\n",
        "          print_str += f\" acc: {acc:.2f}, {Fore.BLUE} r0/1-acc: {acc_rm:.2f} {Style.RESET_ALL}\"\n",
        "          set_train_mode(models)\n",
        "          print(print_str)\n",
        "\n",
        "    test_acc = get_acc(m, test_dl)\n",
        "    test_rm_acc = get_acc(m, test_rm_dl)\n",
        "    test_rc_acc = get_acc(m, test_rc_dl)\n",
        "    acc_on_perturb = get_acc(m, perturb_dl)\n",
        "    # Ensemble only the models trained so far; the remaining ones are still at\n",
        "    # their random initialisation and would only add noise to the metric.\n",
        "    ensemble_acc = get_ens_acc(models[:m_idx+1], test_rm_dl)\n",
        "    m_stats[\"test-acc\"] = test_acc\n",
        "    m_stats[\"test-rm-acc\"] = test_rm_acc\n",
        "    m_stats[\"test-rc-acc\"] = test_rc_acc\n",
        "    m_stats[\"ens-test-rm-acc\"] = ensemble_acc\n",
        "    print(f\"[m{m_idx+1}] [test] acc: {test_acc:.3f}, r-acc: {test_rm_acc:.3f}, r-acc-ens: {ensemble_acc:.3f}, acc-on-perturb: {acc_on_perturb:.3f}\")\n",
        "\n",
        "  return stats"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def simultaneous_train_divdis(num_models, train_dl, valid_dl, valid_rm_dl, valid_rc_dl, test_dl, test_rm_dl,\n",
        "                     test_rc_dl, perturb_dl, alpha=10, max_epoch=100, opt='SGD',\n",
        "                     use_diversity_reg=True, reg_model_weights=None, lr_max=0.2, weight_decay=1e-5, use_scheduler=True):\n",
        "  \"\"\"Train one LeNet with `num_models` output heads jointly (DivDis).\n",
        "\n",
        "  The loss is the sum of the per-head binary cross-entropies on `train_dl` plus, when\n",
        "  `use_diversity_reg` is set, `alpha` times `DivDisLoss` evaluated on the model's\n",
        "  outputs for batches drawn from `perturb_dl`.\n",
        "\n",
        "  Args:\n",
        "    num_models: number of heads (output columns) of the network.\n",
        "    train_dl: loader of (x, y) training batches.\n",
        "    valid_dl, valid_rm_dl, valid_rc_dl: validation loaders evaluated every 200 steps.\n",
        "    test_dl, test_rm_dl, test_rc_dl: test loaders evaluated once after training.\n",
        "    perturb_dl: loader of (x_tilde, _) batches used for the diversity term.\n",
        "    alpha: weight of the diversity term in the total loss.\n",
        "    max_epoch: number of passes over `train_dl`.\n",
        "    opt: 'SGD' selects SGD with momentum 0.9; any other value selects Adam.\n",
        "    use_diversity_reg: when False only the ERM loss is optimised.\n",
        "    reg_model_weights: unused here; kept so the signature matches the other trainers.\n",
        "    lr_max: learning rate (peak of the cyclic schedule when `use_scheduler` is set).\n",
        "    weight_decay: weight decay passed to the optimizer.\n",
        "    use_scheduler: use a triangular CyclicLR rising from 0 to `lr_max` over the first\n",
        "      half of training and falling back over the second half.\n",
        "\n",
        "  Returns:\n",
        "    dict mapping \"m<k>\" to the logged (step, value) curves ('loss', 'adv-loss', 'acc',\n",
        "    'rm-acc', 'rc-acc') and the final test metrics of the k-th head.\n",
        "  \"\"\"\n",
        "  # Project-local module; the exact semantics of mode=\"mi\" are defined there.\n",
        "  from divdis import DivDisLoss\n",
        "  loss_fn = DivDisLoss(heads=num_models, mode=\"mi\", reduction=\"mean\")\n",
        "\n",
        "  m = LeNet(num_classes=num_models).to(DEVICE)\n",
        "  m.train()\n",
        "\n",
        "  stats = {f\"m{i+1}\": {\"acc\": [], \"rm-acc\": [], \"rc-acc\": [], \"loss\": [], \"adv-loss\": []} for i in range(num_models)}\n",
        "\n",
        "  if opt == 'SGD':\n",
        "    optimizer = torch.optim.SGD(m.parameters(), lr=lr_max, momentum=0.9, weight_decay=weight_decay)\n",
        "  else:\n",
        "    optimizer = torch.optim.Adam(m.parameters(), lr=lr_max, weight_decay=weight_decay)\n",
        "  if use_scheduler:\n",
        "    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 0, lr_max, step_size_up=(len(train_dl)*max_epoch)//2,\n",
        "                                                  mode='triangular', cycle_momentum=False)\n",
        "  else:\n",
        "    scheduler = None\n",
        "  perturb_sampler = dl_to_sampler(perturb_dl)\n",
        "\n",
        "  for epoch in range(max_epoch):\n",
        "    for itr, (x, y) in enumerate(train_dl):\n",
        "      (x_tilde,_) = perturb_sampler()\n",
        "      logits = m(x)\n",
        "      # One chunk per head along the last dim.\n",
        "      logits_chunked = torch.chunk(logits, num_models, dim=-1)\n",
        "      erm_losses = [F.binary_cross_entropy_with_logits(logit, y) for logit in logits_chunked]\n",
        "      erm_loss = sum(erm_losses)\n",
        "\n",
        "      if use_diversity_reg:\n",
        "        adv_loss = loss_fn(m(x_tilde))\n",
        "        loss = erm_loss + alpha * adv_loss\n",
        "      else:\n",
        "        # Define adv_loss anyway: the logging below reads it and used to raise\n",
        "        # NameError when the diversity term was disabled.\n",
        "        adv_loss = torch.zeros((), device=DEVICE)\n",
        "        loss = erm_loss\n",
        "\n",
        "      optimizer.zero_grad()\n",
        "      loss.backward()\n",
        "      optimizer.step()\n",
        "      if scheduler is not None: scheduler.step()\n",
        "\n",
        "      # Global training step.\n",
        "      itr_ = itr + epoch * len(train_dl)\n",
        "      if itr_ % 200 == 0:\n",
        "        for m_idx in range(num_models):\n",
        "          m_stats = stats[f\"m{m_idx+1}\"]\n",
        "          m.eval()\n",
        "          print_str = f\"[m{m_idx+1}] {epoch}/{itr_} [train] loss: {erm_losses[m_idx].item():.2f} adv-loss: {adv_loss.item():.2f} \"\n",
        "          if itr_ != 0 and scheduler is not None:\n",
        "            print_str += f\"[lr] {scheduler.get_last_lr()[0]:.5f} \"\n",
        "          # Record this head's own loss (the value printed above), not the sum over heads.\n",
        "          m_stats[\"loss\"].append((itr_, erm_losses[m_idx].item()))\n",
        "          m_stats[\"adv-loss\"].append((itr_, adv_loss.item()))\n",
        "          acc = get_acc(m, valid_dl, idx=m_idx)\n",
        "          acc_rm = get_acc(m, valid_rm_dl, idx=m_idx)\n",
        "          acc_rc = get_acc(m, valid_rc_dl, idx=m_idx)\n",
        "          m_stats[\"acc\"].append((itr_, acc))\n",
        "          m_stats[\"rm-acc\"].append((itr_, acc_rm))\n",
        "          m_stats[\"rc-acc\"].append((itr_, acc_rc))\n",
        "          print_str += f\" acc: {acc:.2f}, {Fore.BLUE} r0/1-acc: {acc_rm:.2f} {Style.RESET_ALL}\"\n",
        "          m.train()\n",
        "          print(print_str)\n",
        "\n",
        "  # The ensemble accuracy does not depend on the head index, so compute it once.\n",
        "  ensemble_acc = get_ens_acc(m, test_rm_dl)\n",
        "  for m_idx in range(num_models):\n",
        "    m_stats = stats[f\"m{m_idx+1}\"]\n",
        "    m.eval()\n",
        "    test_acc = get_acc(m, test_dl, idx=m_idx)\n",
        "    test_rm_acc = get_acc(m, test_rm_dl, idx=m_idx)\n",
        "    test_rc_acc = get_acc(m, test_rc_dl, idx=m_idx)\n",
        "    acc_on_perturb = get_acc(m, perturb_dl, idx=m_idx)\n",
        "    m_stats[\"test-acc\"] = test_acc\n",
        "    m_stats[\"test-rm-acc\"] = test_rm_acc\n",
        "    m_stats[\"test-rc-acc\"] = test_rc_acc\n",
        "    m_stats[\"ens-test-rm-acc\"] = ensemble_acc\n",
        "    print(f\"[m{m_idx+1}] [test] acc: {test_acc:.3f}, r-acc: {test_rm_acc:.3f}, r-acc-ens: {ensemble_acc:.3f}, acc-on-perturb: {acc_on_perturb:.3f}\")\n",
        "\n",
        "  return stats"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def sequential_train_divdis(num_models, train_dl, valid_dl, valid_rm_dl, valid_rc_dl, test_dl, test_rm_dl,\n",
        "                     test_rc_dl, perturb_dl, alpha=10, max_epoch=100, opt='SGD',\n",
        "                     use_diversity_reg=True, reg_model_weights=None, lr_max=0.2, weight_decay=1e-5, use_scheduler=True):\n",
        "  \"\"\"Train `num_models` LeNet classifiers one after another with a DivDis diversity term.\n",
        "\n",
        "  Every model minimises binary cross-entropy on `train_dl`. From the second model on\n",
        "  (and only if `use_diversity_reg` is set), `alpha * DivDisLoss` is added; it is computed on\n",
        "  a batch drawn from `perturb_dl`, over the concatenated logits of the previously trained\n",
        "  models (evaluated without gradients) and of the model currently being trained.\n",
        "\n",
        "  Args:\n",
        "    num_models: number of models trained in sequence.\n",
        "    train_dl: loader yielding (x, y) training batches.\n",
        "    valid_dl, valid_rm_dl, valid_rc_dl: loaders scored with `get_acc` every 200 steps.\n",
        "    test_dl, test_rm_dl, test_rc_dl: loaders scored once a model has finished training.\n",
        "    perturb_dl: loader providing the batches the diversity loss is computed on.\n",
        "    alpha: weight of the diversity loss.\n",
        "    max_epoch: number of passes over `train_dl` per model.\n",
        "    opt: 'SGD' selects SGD with momentum 0.9; any other value selects Adam.\n",
        "    use_diversity_reg: set to False to train with the ERM loss only.\n",
        "    reg_model_weights: not used by this function.\n",
        "    lr_max: optimizer learning rate and peak value of the cyclic schedule.\n",
        "    weight_decay: weight decay handed to the optimizer.\n",
        "    use_scheduler: whether a triangular CyclicLR is stepped after every batch.\n",
        "\n",
        "  Returns:\n",
        "    dict keyed `m1` .. `mN`; each entry holds the logged (step, value) lists `acc`, `rm-acc`,\n",
        "    `rc-acc`, `loss`, `adv-loss` plus the final `test-acc`, `test-rm-acc`, `test-rc-acc`\n",
        "    and `ens-test-rm-acc` values.\n",
        "  \"\"\"\n",
        "\n",
        "  from divdis import DivDisLoss\n",
        "  # NOTE(review): the loss is built for `num_models` heads, yet model k is fed only k+1\n",
        "  # concatenated logits below -- confirm DivDisLoss copes with that when num_models > 2.\n",
        "  loss_fn = DivDisLoss(heads=num_models, mode=\"mi\", reduction=\"mean\")\n",
        "\n",
        "  models = [LeNet(num_classes=1).to(DEVICE) for _ in range(num_models)]\n",
        "  set_train_mode(models)\n",
        "\n",
        "  stats = {f\"m{i+1}\": {\"acc\": [], \"rm-acc\": [], \"rc-acc\": [], \"loss\": [], \"adv-loss\": []} for i in range(len(models))}\n",
        "  steps_per_epoch = len(train_dl)\n",
        "\n",
        "  for m_idx, m in enumerate(models):\n",
        "    key = f\"m{m_idx+1}\"\n",
        "\n",
        "    # Keep the `opt` string untouched: rebinding it to the optimizer object made the 'SGD'\n",
        "    # check fail for every model after the first one, silently switching them to Adam.\n",
        "    if opt == 'SGD':\n",
        "      optimizer = torch.optim.SGD(m.parameters(), lr=lr_max, momentum=0.9, weight_decay=weight_decay)\n",
        "    else:\n",
        "      optimizer = torch.optim.Adam(m.parameters(), lr=lr_max, weight_decay=weight_decay)\n",
        "    if use_scheduler:\n",
        "      # One triangular cycle over the whole run: ramp up for the first half of the steps.\n",
        "      scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 0, lr_max, step_size_up=(steps_per_epoch*max_epoch)//2,\n",
        "                                                    mode='triangular', cycle_momentum=False)\n",
        "    else:\n",
        "      scheduler = None\n",
        "    perturb_sampler = dl_to_sampler(perturb_dl)\n",
        "\n",
        "    for epoch in range(max_epoch):\n",
        "      for itr, (x, y) in enumerate(train_dl):\n",
        "        (x_tilde, _) = perturb_sampler()\n",
        "        erm_loss = F.binary_cross_entropy_with_logits(m(x), y)\n",
        "\n",
        "        if use_diversity_reg and m_idx != 0:\n",
        "          # Earlier models only serve as fixed references: eval mode, no gradients.\n",
        "          with torch.no_grad():\n",
        "            set_eval_mode(models)\n",
        "            prev_logits = [m_(x_tilde) for m_ in models[:m_idx]]\n",
        "            set_train_mode(models)\n",
        "          curr_logits = m(x_tilde)\n",
        "          prev_logits.append(curr_logits)\n",
        "          target_logits = torch.cat(prev_logits, dim=-1)\n",
        "          adv_loss = loss_fn(target_logits)\n",
        "          loss = erm_loss + alpha * adv_loss\n",
        "        else:\n",
        "          # Placeholder so the logging below can always call adv_loss.item().\n",
        "          adv_loss = torch.tensor([0])\n",
        "          loss = erm_loss\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        if scheduler is not None: scheduler.step()\n",
        "\n",
        "        # Global step counter across epochs; metrics are logged every 200 steps.\n",
        "        itr_ = itr + epoch * steps_per_epoch\n",
        "        if itr_ % 200 == 0:\n",
        "          set_eval_mode(models)\n",
        "          print_str = f\"[m{m_idx+1}] {epoch}/{itr_} [train] loss: {erm_loss.item():.2f} adv-loss: {adv_loss.item():.2f} \"\n",
        "          if itr_ != 0 and scheduler is not None:\n",
        "            print_str += f\"[lr] {scheduler.get_last_lr()[0]:.5f} \"\n",
        "          stats[key][\"loss\"].append((itr_, erm_loss.item()))\n",
        "          stats[key][\"adv-loss\"].append((itr_, adv_loss.item()))\n",
        "          acc = get_acc(m, valid_dl)\n",
        "          acc_rm = get_acc(m, valid_rm_dl)\n",
        "          acc_rc = get_acc(m, valid_rc_dl)\n",
        "          stats[key][\"acc\"].append((itr_, acc))\n",
        "          stats[key][\"rm-acc\"].append((itr_, acc_rm))\n",
        "          stats[key][\"rc-acc\"].append((itr_, acc_rc))\n",
        "          print_str += f\" acc: {acc:.2f}, {Fore.BLUE} r0/1-acc: {acc_rm:.2f} {Style.RESET_ALL}\"\n",
        "          set_train_mode(models)\n",
        "          print(print_str)\n",
        "\n",
        "    # Score the finished model in eval mode, exactly like the periodic validation above,\n",
        "    # then restore train mode for the models that are still to be trained.\n",
        "    set_eval_mode(models)\n",
        "    test_acc = get_acc(m, test_dl)\n",
        "    test_rm_acc = get_acc(m, test_rm_dl)\n",
        "    test_rc_acc = get_acc(m, test_rc_dl)\n",
        "    acc_on_perturb = get_acc(m, perturb_dl)\n",
        "    ensemble_acc = get_ens_acc(models, test_rm_dl)\n",
        "    set_train_mode(models)\n",
        "    stats[key][\"test-acc\"] = test_acc\n",
        "    stats[key][\"test-rm-acc\"] = test_rm_acc\n",
        "    stats[key][\"test-rc-acc\"] = test_rc_acc\n",
        "    stats[key][\"ens-test-rm-acc\"] = ensemble_acc\n",
        "    print(f\"[m{m_idx+1}] [test] acc: {test_acc:.3f}, r-acc: {test_rm_acc:.3f}, r-acc-ens: {ensemble_acc:.3f}, acc-on-perturb: {acc_on_perturb:.3f}\")\n",
        "\n",
        "  return stats"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "zYOM2OZNp25o"
      },
      "source": [
        "# Experiments with $\\mathcal{D}_\\text{ood} = \\mathcal{D}_\\text{test}$"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def build_mc_perturb_dataset(mnist_test, cifar_test, including_labels=True, degree_of_balance=None):\n",
        "  \"\"\"Build the MNIST-stacked-with-CIFAR perturbation set as a TensorDataset on DEVICE.\n",
        "\n",
        "  Args:\n",
        "    mnist_test: MNIST source handed to `keep_only_lbls` (digits 0 and 1 are kept).\n",
        "    cifar_test: CIFAR source handed to `keep_only_lbls` (classes 1 and 9 are kept).\n",
        "    including_labels: if True the dataset holds (X, Y), otherwise only (X,).\n",
        "    degree_of_balance: number in [0, 1]. With exactly 1, independently shuffled MNIST and\n",
        "      CIFAR images are paired index-wise and Y is the CIFAR label returned by\n",
        "      `keep_only_lbls`. Otherwise digit 1 is paired with class 1 (Y=1) and digit 0 with\n",
        "      class 9 (Y=0), and the first `int(degree_of_balance * N)` MNIST parts are permuted\n",
        "      among themselves.\n",
        "\n",
        "  Returns:\n",
        "    torch.utils.data.TensorDataset whose tensors have been moved to DEVICE.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if `degree_of_balance` is None or lies outside [0, 1].\n",
        "  \"\"\"\n",
        "  # Explicit check instead of `assert`: the default None used to die with an opaque TypeError\n",
        "  # on the comparison, and asserts are skipped entirely under `python -O`.\n",
        "  if degree_of_balance is None or not (0 <= degree_of_balance <= 1):\n",
        "    raise ValueError(f\"degree_of_balance must be a number in [0, 1], got {degree_of_balance!r}\")\n",
        "\n",
        "  if degree_of_balance == 1:\n",
        "    # Fully balanced: shuffle both sources independently, then pair them index-wise.\n",
        "    X_m_test, _ = keep_only_lbls(mnist_test, lbls=[0,1])\n",
        "    m_rand_idx = torch.randperm(len(X_m_test))\n",
        "    X_m_test = format_mnist(X_m_test.view(-1, 1, 28, 28))[m_rand_idx]\n",
        "\n",
        "    X_c_test, Y_c_test = keep_only_lbls(cifar_test, lbls=[1,9])\n",
        "    c_rand_idx = torch.randperm(len(X_c_test))\n",
        "    X_c_test = X_c_test[c_rand_idx]\n",
        "    Y_c_test = Y_c_test[c_rand_idx]\n",
        "\n",
        "    # Truncate both sides to the shorter source so every image has a partner.\n",
        "    min_l = min(len(X_m_test), len(X_c_test))\n",
        "    # Stack along dim 2 (presumably image height, assuming format_mnist keeps an NCHW layout).\n",
        "    X = torch.cat((X_m_test[:min_l], X_c_test[:min_l]), dim=2)\n",
        "    Y = Y_c_test[:min_l].float().view(-1,1)\n",
        "\n",
        "  else:\n",
        "    # Filter the class 0 and 1 in the given MNIST data\n",
        "    X_m_0, _ = keep_only_lbls(mnist_test, lbls=[0])\n",
        "    X_m_1, _ = keep_only_lbls(mnist_test, lbls=[1])\n",
        "    X_m_0 = format_mnist(X_m_0.view(-1, 1, 28, 28))\n",
        "    X_m_1 = format_mnist(X_m_1.view(-1, 1, 28, 28))\n",
        "    # Filter the car and truck for the given CIFAR10 data\n",
        "    X_c_1, _ = keep_only_lbls(cifar_test, lbls=[1])\n",
        "    X_c_9, _ = keep_only_lbls(cifar_test, lbls=[9])\n",
        "\n",
        "    # Shuffle each group on its own (same order of random draws as before).\n",
        "    X_c_1 = X_c_1[torch.randperm(len(X_c_1))]\n",
        "    X_c_9 = X_c_9[torch.randperm(len(X_c_9))]\n",
        "    X_m_0 = X_m_0[torch.randperm(len(X_m_0))]\n",
        "    X_m_1 = X_m_1[torch.randperm(len(X_m_1))]\n",
        "\n",
        "    # Pair digit 1 with class 1 and digit 0 with class 9, truncating to the shorter group.\n",
        "    min_11 = min(len(X_m_1), len(X_c_1))\n",
        "    min_09 = min(len(X_m_0), len(X_c_9))\n",
        "    X_top = torch.cat((X_m_1[:min_11], X_m_0[:min_09]), dim=0)\n",
        "    X_bottom = torch.cat((X_c_1[:min_11], X_c_9[:min_09]), dim=0)\n",
        "\n",
        "    # Permute only the first `balanced_len` MNIST parts; the remainder keeps its pairing.\n",
        "    # NOTE(review): the prefix holds digit-1 images only while balanced_len <= min_11, so\n",
        "    # small degrees leave the digit/vehicle pairing unchanged -- confirm this is intended.\n",
        "    balanced_len = int(degree_of_balance * len(X_top))\n",
        "    partial_shuffle = torch.cat([torch.randperm(balanced_len), torch.arange(balanced_len, len(X_top))])\n",
        "    X_top = X_top[partial_shuffle]\n",
        "\n",
        "    X = torch.cat((X_top, X_bottom), dim=2)\n",
        "    Y = torch.cat((torch.ones((min_11,)), torch.zeros((min_09,))))\n",
        "\n",
        "    # Final joint shuffle so the two groups are interleaved.\n",
        "    shuffle = torch.randperm(len(X))\n",
        "    X = X[shuffle]\n",
        "    Y = Y[shuffle].float().view(-1,1)\n",
        "\n",
        "  if including_labels:\n",
        "    return torch.utils.data.TensorDataset(X.to(DEVICE), Y.to(DEVICE))\n",
        "  return torch.utils.data.TensorDataset(X.to(DEVICE))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 233
        },
        "id": "bkJKGj7dp3CX",
        "outputId": "1113b368-a859-4197-8637-2674fc2725a0"
      },
      "outputs": [],
      "source": [
        "# The OOD (perturbation) set is built from the *_perturb_base splits.\n",
        "mnist_test, cifar_test = mnist_perturb_base, cifar_perturb_base\n",
        "\n",
        "# degree_of_balance=0.0 leaves the digit/vehicle pairing unpermuted.\n",
        "data_perturb = build_mc_perturb_dataset(\n",
        "    mnist_test,\n",
        "    cifar_test,\n",
        "    including_labels=True,\n",
        "    degree_of_balance=0.0,\n",
        ")\n",
        "perturb_dl = DataLoader(data_perturb, batch_size=256, shuffle=True)\n",
        "\n",
        "print(f\"OOD dataset size: {len(data_perturb)}\")\n",
        "print(\"OOD dataset:\")\n",
        "plot_samples(data_perturb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lvny2VG6p3FW"
      },
      "outputs": [],
      "source": [
        "# Sequential D-BAT experiment (single run); stats of every run are collected in all_stats.\n",
        "all_stats = []\n",
        "for _ in range(1):\n",
        "  stats = sequential_train_dbat(\n",
        "      2, train_dl, valid_dl, valid_rm_dl, valid_rc_dl,\n",
        "      test_dl, test_rm_dl, test_rc_dl, perturb_dl,\n",
        "      alpha=5, max_epoch=70, lr_max=0.001, use_scheduler=True, opt=\"Adam\",\n",
        "  )\n",
        "  all_stats.append(stats)\n",
        "  print_stats(stats)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Simultaneous DivDis experiment (single run); note that all_stats is reset here.\n",
        "all_stats = []\n",
        "for _ in range(1):\n",
        "  stats = simultaneous_train_divdis(\n",
        "      2, train_dl, valid_dl, valid_rm_dl, valid_rc_dl,\n",
        "      test_dl, test_rm_dl, test_rc_dl, perturb_dl,\n",
        "      alpha=50, max_epoch=70, lr_max=0.001, use_scheduler=True, opt=\"Adam\",\n",
        "  )\n",
        "  all_stats.append(stats)\n",
        "  print_stats(stats)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sequential DivDis experiment with 2 models (single run); note that all_stats is reset here.\n",
        "all_stats = []\n",
        "for _ in range(1):\n",
        "  stats = sequential_train_divdis(\n",
        "      2, train_dl, valid_dl, valid_rm_dl, valid_rc_dl,\n",
        "      test_dl, test_rm_dl, test_rc_dl, perturb_dl,\n",
        "      alpha=50, max_epoch=70, lr_max=0.001, use_scheduler=True, opt=\"Adam\",\n",
        "  )\n",
        "  all_stats.append(stats)\n",
        "  print_stats(stats)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10 (default, Jun 22 2022, 20:18:18) \n[GCC 9.4.0]"
    },
    "vscode": {
      "interpreter": {
        "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
      }
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "02bf2addd49248bc865b8515d2427739": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "02f58aca7dee4a3a9f631a6a87c7668b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d222845c70134afc8e569a413b8b07af",
            "max": 1648877,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_d4fabb860e8849288979fe79786d6f70",
            "value": 1648877
          }
        },
        "061c6ab8faf443dfa6e63f4162328d57": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "08195e51bf244c0b9d5c83e2d2b3b2c2": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_f3c008785b334c93bb744c0dfa032b3b",
              "IPY_MODEL_2e6efdb8a1bc4818aee537182a04dac3",
              "IPY_MODEL_4e02d6ee3e994276941bded69012d632"
            ],
            "layout": "IPY_MODEL_7a772f9fc8124493b99e7c203a2fd535"
          }
        },
        "0ace65dc15f0472fb47ab899a0551e9a": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0b9611c5b47d418d80e031710bbc3a97": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7e9f2a6fb4f14f1fab863a1ea655b639",
            "placeholder": "​",
            "style": "IPY_MODEL_1ffdca0097e74db18017a7987f46f46c",
            "value": "100%"
          }
        },
        "11e0abca12ea40c98aecc6f73d825bcf": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "129f62f0bef9443c9e143935531f0fad": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1db0d5ee53bd47e7bd2870aa6c8639b1": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1ffdca0097e74db18017a7987f46f46c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "25dcd06a0db041d4bcd2d370709810a5": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2d931ca4fc49498fbd793abd1ed7b7f5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c32a83d1b93d4cd3acd1cac758ab82f9",
            "placeholder": "​",
            "style": "IPY_MODEL_9e413d8fe45945219c88e333d3a822f6",
            "value": " 1648877/1648877 [00:00&lt;00:00, 9068243.98it/s]"
          }
        },
        "2e6efdb8a1bc4818aee537182a04dac3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7baa1e51abc3434b9eebfbf975d50188",
            "max": 170498071,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_fb5229eab1cb452c831e5229e9ed5ac3",
            "value": 170498071
          }
        },
        "3308ab945b3248f4a94f3a50fa0ad00a": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "37fb9b7807dc49f99fd2a37f4a61ca05": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_889b16931d6d43afa0d7fde84f27e929",
            "placeholder": "​",
            "style": "IPY_MODEL_a2fa795380574174a28294214fab29bc",
            "value": "100%"
          }
        },
        "3c7baf6fe7c54a66baa6ca1c6b162774": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_42f7a54652bc4e12818caa0ac51688db",
              "IPY_MODEL_c2522ec6f48f4c35aff0c2f90037175a",
              "IPY_MODEL_78890e24670f4f5ba9982708475d2a4a"
            ],
            "layout": "IPY_MODEL_11e0abca12ea40c98aecc6f73d825bcf"
          }
        },
        "42f7a54652bc4e12818caa0ac51688db": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e3139d6003834907838bccef67f7471a",
            "placeholder": "​",
            "style": "IPY_MODEL_c91214aeb3ae455a96f61b09e5a4abb0",
            "value": "100%"
          }
        },
        "43899976ba064c39a360f5f980b0caf4": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0ace65dc15f0472fb47ab899a0551e9a",
            "placeholder": "​",
            "style": "IPY_MODEL_bbcc01bcba3f4e61b00795d95ae3a5d8",
            "value": "100%"
          }
        },
        "45230628dc164fc5b6a6b736f9568903": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4e02d6ee3e994276941bded69012d632": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_061c6ab8faf443dfa6e63f4162328d57",
            "placeholder": "​",
            "style": "IPY_MODEL_fd85711851c340f787245de139b7bd19",
            "value": " 170498071/170498071 [00:06&lt;00:00, 30567021.71it/s]"
          }
        },
        "500a9b18a3fa425da61080412bc53115": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "50c323e671af4e279880557b3e461618": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "57663c6514c546f08f23e6962b040a27": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_c59d0e94b22045deb2146cafc5253757",
              "IPY_MODEL_6d91e5a5ac7b405abb724eac3875bd26",
              "IPY_MODEL_f741d9f146484900be3ac5779b1f76b3"
            ],
            "layout": "IPY_MODEL_25dcd06a0db041d4bcd2d370709810a5"
          }
        },
        "663ce22902af4f4c94aedbdb9c551558": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "68dc44b7570e455e940d8cab6e7d4154": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7a9fe76e5f6f422380dd626086483613",
            "max": 170498071,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_7f071b80e275427abc41b4b43f3a6bfa",
            "value": 170498071
          }
        },
        "6a8fd1f5529249cc9c4bf047e408b8ea": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6d91e5a5ac7b405abb724eac3875bd26": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_edb8b9e215a648439c6c4face3ada113",
            "max": 9912422,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_50c323e671af4e279880557b3e461618",
            "value": 9912422
          }
        },
        "7645ccf2a7d24715a129f0c4831d404b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_43899976ba064c39a360f5f980b0caf4",
              "IPY_MODEL_02f58aca7dee4a3a9f631a6a87c7668b",
              "IPY_MODEL_2d931ca4fc49498fbd793abd1ed7b7f5"
            ],
            "layout": "IPY_MODEL_3308ab945b3248f4a94f3a50fa0ad00a"
          }
        },
        "78890e24670f4f5ba9982708475d2a4a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_129f62f0bef9443c9e143935531f0fad",
            "placeholder": "​",
            "style": "IPY_MODEL_500a9b18a3fa425da61080412bc53115",
            "value": " 28881/28881 [00:00&lt;00:00, 297392.49it/s]"
          }
        },
        "7a772f9fc8124493b99e7c203a2fd535": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7a9fe76e5f6f422380dd626086483613": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7aef4900d2284e0c9b00326ac5d9c479": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7baa1e51abc3434b9eebfbf975d50188": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7bccf661f3904cfe965eb9a40b2d075d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_0b9611c5b47d418d80e031710bbc3a97",
              "IPY_MODEL_68dc44b7570e455e940d8cab6e7d4154",
              "IPY_MODEL_a7d3d59a874a4041927ec17107f7425b"
            ],
            "layout": "IPY_MODEL_6a8fd1f5529249cc9c4bf047e408b8ea"
          }
        },
        "7e9f2a6fb4f14f1fab863a1ea655b639": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7f071b80e275427abc41b4b43f3a6bfa": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "866d4df1f14147ab9b55906ddb8de51a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_02bf2addd49248bc865b8515d2427739",
            "placeholder": "​",
            "style": "IPY_MODEL_663ce22902af4f4c94aedbdb9c551558",
            "value": " 4542/4542 [00:00&lt;00:00, 66545.09it/s]"
          }
        },
        "889b16931d6d43afa0d7fde84f27e929": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8bd6858e05c24935b352907ba363f438": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "8f13ee26b3cd4a9aae22ea4480981c7a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9e413d8fe45945219c88e333d3a822f6": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a2de69155b7d4adf881bb5aa777b01ef": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a2fa795380574174a28294214fab29bc": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "a7d3d59a874a4041927ec17107f7425b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1db0d5ee53bd47e7bd2870aa6c8639b1",
            "placeholder": "​",
            "style": "IPY_MODEL_8f13ee26b3cd4a9aae22ea4480981c7a",
            "value": " 170498071/170498071 [00:05&lt;00:00, 33386245.23it/s]"
          }
        },
        "bbcc01bcba3f4e61b00795d95ae3a5d8": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c2522ec6f48f4c35aff0c2f90037175a": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_d5867d70563e45e7bf9e52d5903ff657",
            "max": 28881,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_fc31d40bf2114c80a4be84eed637cb9e",
            "value": 28881
          }
        },
        "c32a83d1b93d4cd3acd1cac758ab82f9": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c59d0e94b22045deb2146cafc5253757": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fe9ff4d95e1a4f6b939510c11c4f58f2",
            "placeholder": "​",
            "style": "IPY_MODEL_7aef4900d2284e0c9b00326ac5d9c479",
            "value": "100%"
          }
        },
        "c91214aeb3ae455a96f61b09e5a4abb0": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "cb2ad3d2884a42569edef29b9c18423e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_37fb9b7807dc49f99fd2a37f4a61ca05",
              "IPY_MODEL_ebf03580c1a8499ab713b07301ca830f",
              "IPY_MODEL_866d4df1f14147ab9b55906ddb8de51a"
            ],
            "layout": "IPY_MODEL_cdaed5fd140a4f72ab03092786eabe2b"
          }
        },
        "cdaed5fd140a4f72ab03092786eabe2b": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d222845c70134afc8e569a413b8b07af": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d4fabb860e8849288979fe79786d6f70": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "d5867d70563e45e7bf9e52d5903ff657": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e3139d6003834907838bccef67f7471a": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ebf03580c1a8499ab713b07301ca830f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_45230628dc164fc5b6a6b736f9568903",
            "max": 4542,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_f3f84fbeb6f342bd9192ddc6f7863b13",
            "value": 4542
          }
        },
        "edb8b9e215a648439c6c4face3ada113": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f3c008785b334c93bb744c0dfa032b3b": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f91e1ae924a9402983ffac448b714fae",
            "placeholder": "​",
            "style": "IPY_MODEL_a2de69155b7d4adf881bb5aa777b01ef",
            "value": "100%"
          }
        },
        "f3f84fbeb6f342bd9192ddc6f7863b13": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "f5f2039571f940cf8a13ac3de59e9426": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f741d9f146484900be3ac5779b1f76b3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f5f2039571f940cf8a13ac3de59e9426",
            "placeholder": "​",
            "style": "IPY_MODEL_8bd6858e05c24935b352907ba363f438",
            "value": " 9912422/9912422 [00:00&lt;00:00, 32882696.06it/s]"
          }
        },
        "f91e1ae924a9402983ffac448b714fae": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fb5229eab1cb452c831e5229e9ed5ac3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "fc31d40bf2114c80a4be84eed637cb9e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "fd85711851c340f787245de139b7bd19": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "fe9ff4d95e1a4f6b939510c11c4f58f2": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
