{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# Device selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "hidden": true,
    "id": "GMQvFN5oZVoL"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Select the CUDA device. GPU index 1 is preferred (as before), but fall back to\n",
    "# cuda:0 on single-GPU machines (where 'cuda:1' is not a valid device) and to\n",
    "# the CPU when CUDA is unavailable.\n",
    "PREFERRED_CUDA_INDEX = 1\n",
    "n_gpus = torch.cuda.device_count()\n",
    "print(n_gpus)\n",
    "if n_gpus > PREFERRED_CUDA_INDEX:\n",
    "    device = f'cuda:{PREFERRED_CUDA_INDEX}'\n",
    "elif n_gpus > 0:\n",
    "    device = 'cuda:0'\n",
    "else:\n",
    "    device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7F4maoBuXtmf"
   },
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "W0ifSsnGX58I",
    "outputId": "fc104731-6897-4806-d963-39c55ea07f07"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "cXcRcXgaXfpR"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-03-07 11:22:45.119591: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-03-07 11:22:45.804003: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
      "2023-03-07 11:22:45.804066: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
      "2023-03-07 11:22:45.804072: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    }
   ],
   "source": [
    "import gc\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from os.path import join\n",
    "import sys\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Taken from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py \n",
    "from pytorchtools import EarlyStopping\n",
    "\n",
    "import torch.autograd as autograd\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchmetrics.image.fid import FrechetInceptionDistance\n",
    "\n",
    "from torchvision.datasets import MNIST\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.utils import make_grid, save_image\n",
    "\n",
    "\n",
    "def make_dir(folder):\n",
    "    \"\"\"Create `folder` (and any missing parents) if needed and return its path.\"\"\"\n",
    "    os.makedirs(folder, exist_ok=True)\n",
    "    return folder\n",
    "\n",
    "\n",
    "# Root directory for all outputs: set the WGAN_GP_ROOT environment variable or edit the default.\n",
    "# NOTE: DATASETS_DIR is scanned to resume training, so keep only generated datasets in it.\n",
    "ROOT = make_dir(os.environ.get('WGAN_GP_ROOT', join(os.getcwd(), 'wgan_gp_runs')))\n",
    "SAMPLES_DIR  = make_dir(join(ROOT, \"samples\"))\n",
    "DATASETS_DIR = make_dir(join(ROOT, \"datasets\"))\n",
    "MODELS_DIR   = make_dir(join(ROOT, \"models\"))\n",
    "METRICS_DIR  = make_dir(join(ROOT, \"metrics\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "id": "GGnQvXLKZ2wx"
   },
   "source": [
    "# Generator + Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "hidden": true,
    "id": "5_skbBEve4PW"
   },
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    \"\"\"Conditional MLP generator mapping [z, one-hot label] vectors to images.\n",
    "\n",
    "    Input width is `latent_dim + num_classes`; output shape is (N, *img_shape),\n",
    "    with values in [-1, 1] from the final Tanh.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        img_shape:   tuple = (1,28,28),\n",
    "        num_classes: int   = 10,\n",
    "        latent_dim:  int   = 100\n",
    "        ):\n",
    "\n",
    "        super().__init__()\n",
    "        self.img_shape = img_shape\n",
    "        n_pixels = int(np.prod(self.img_shape))\n",
    "\n",
    "        def linear_block(n_in, n_out, normalize=True):\n",
    "            # NOTE(review): BatchNorm1d's second positional argument is `eps`, so this is\n",
    "            # eps=0.8 (momentum keeps its default). Kept unchanged so results stay\n",
    "            # comparable with earlier runs.\n",
    "            stage = [nn.Linear(n_in, n_out)]\n",
    "            if normalize:\n",
    "                stage.append(nn.BatchNorm1d(n_out, eps=0.8))\n",
    "            stage.append(nn.LeakyReLU(0.2, inplace=True))\n",
    "            return stage\n",
    "\n",
    "        # Hidden widths; the first block is not batch-normalized.\n",
    "        widths = [latent_dim + num_classes, 128, 256, 512, 1024]\n",
    "        layers = []\n",
    "        for depth, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):\n",
    "            layers.extend(linear_block(n_in, n_out, normalize=depth > 0))\n",
    "        layers.extend([nn.Linear(widths[-1], n_pixels), nn.Tanh()])\n",
    "        self.model = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, z):\n",
    "        \"\"\"Map (N, latent_dim + num_classes) inputs to (N, *img_shape) images.\"\"\"\n",
    "        flat = self.model(z)\n",
    "        return flat.view(flat.shape[0], *self.img_shape)\n",
    "\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"Conditional WGAN critic scoring flattened [image, one-hot label] vectors.\n",
    "\n",
    "    Each input sample must flatten to prod(img_shape) + num_classes features; the\n",
    "    output is one unbounded score per sample, shape (N, 1) (no sigmoid).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        img_shape: tuple = (1, 28, 28),\n",
    "        num_classes: int = 10\n",
    "        ):\n",
    "\n",
    "        super().__init__()\n",
    "        self.img_shape = img_shape\n",
    "\n",
    "        widths = [int(np.prod(self.img_shape) + num_classes), 512, 256]\n",
    "        layers = []\n",
    "        for n_in, n_out in zip(widths[:-1], widths[1:]):\n",
    "            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2, inplace=True)]\n",
    "        layers.append(nn.Linear(widths[-1], 1))\n",
    "        self.model = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, img):\n",
    "        \"\"\"Flatten every sample and return its (N, 1) critic score.\"\"\"\n",
    "        return self.model(img.view(img.shape[0], -1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rgG4-A2mZzsT"
   },
   "source": [
    "# WGAN_GP Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "PKir0bl_Z7Zm"
   },
   "outputs": [],
   "source": [
    "def preprocess(dataset: TensorDataset):\n",
    "    \"\"\"Rescale pixel values from [0, 255] to [-1, 1]; labels are passed through unchanged.\"\"\"\n",
    "    pixels, targets = dataset.tensors[0], dataset.tensors[1]\n",
    "    scaled = pixels / 255 * 2 - 1\n",
    "    return TensorDataset(scaled, targets)\n",
    "\n",
    "\n",
    "class WGAN_GP(nn.Module):\n",
    "    \"\"\"Conditional WGAN-GP: a Generator, a Discriminator (critic) and their Adam optimizers.\n",
    "\n",
    "    Args:\n",
    "        device: device holding both networks and every tensor sampled here.\n",
    "        generation: index of this model in the chain of generations; used to name\n",
    "            the dataset files written by `make_new_dataset`.\n",
    "        hparams: hyperparameters; omitted keys fall back to `DEFAULT_HPARAMS`.\n",
    "            A merged copy is stored, so no dict is shared between instances.\n",
    "    \"\"\"\n",
    "\n",
    "    DEFAULT_HPARAMS = {\n",
    "        'num_classes': 10,\n",
    "        'channels'   : 1,\n",
    "        'width'      : 28,\n",
    "        'height'     : 28,\n",
    "        'latent_dim' : 100,\n",
    "        'batch_size' : 64,\n",
    "        'n_critic'   : 5,\n",
    "        'lr'         : 0.0002,\n",
    "        'b1'         : 0.5,\n",
    "        'b2'         : 0.999,\n",
    "        'lambda_gp'  : 10\n",
    "    }\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        device: torch.device,\n",
    "        generation: int,\n",
    "        hparams: dict = None\n",
    "    ):\n",
    "\n",
    "        super(WGAN_GP, self).__init__()\n",
    "        self.device = device\n",
    "        self.generation = generation\n",
    "        # Merge into a fresh dict rather than sharing a mutable default argument.\n",
    "        self.hparams = {**self.DEFAULT_HPARAMS, **(hparams or {})}\n",
    "\n",
    "        # Modules, sized from the hyperparameters (defaults: 1x28x28 images, 10 classes)\n",
    "        img_shape = (self.hparams['channels'], self.hparams['height'], self.hparams['width'])\n",
    "        self.generator = Generator(img_shape=img_shape,\n",
    "                                   num_classes=self.hparams['num_classes'],\n",
    "                                   latent_dim=self.hparams['latent_dim']).to(self.device)\n",
    "        self.discriminator = Discriminator(img_shape=img_shape,\n",
    "                                           num_classes=self.hparams['num_classes']).to(self.device)\n",
    "\n",
    "        # Optimizers\n",
    "        betas = (self.hparams['b1'], self.hparams['b2'])\n",
    "        self.opt_G = torch.optim.Adam(self.generator.parameters(), lr=self.hparams['lr'], betas=betas)\n",
    "        self.opt_D = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams['lr'], betas=betas)\n",
    "\n",
    "        # Samples\n",
    "        self.validation_z = self.sample_z(self.hparams['num_classes'])\n",
    "        self.validation_labels = torch.arange(self.hparams['num_classes'], device=self.device)\n",
    "\n",
    "    def compute_gradient_penalty(self, real_samples, fake_samples):\n",
    "        \"\"\"Calculates the gradient penalty loss for WGAN GP.\n",
    "\n",
    "        Returns mean((||dD(x_hat)/dx_hat||_2 - 1)^2), where x_hat is a random per-sample\n",
    "        interpolation of `real_samples` and `fake_samples`. In `training_step` both are\n",
    "        2-D: flattened images with their one-hot labels appended.\n",
    "        \"\"\"\n",
    "        # One weight per sample; broadcasts over the already-flattened features.\n",
    "        alpha = torch.rand(real_samples.size(0), 1, device=self.device)\n",
    "        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)\n",
    "        d_interpolates = self.discriminator(interpolates)\n",
    "        # create_graph=True so the penalty itself can be back-propagated into D.\n",
    "        gradients = torch.autograd.grad(\n",
    "            outputs=d_interpolates,\n",
    "            inputs=interpolates,\n",
    "            grad_outputs=torch.ones_like(d_interpolates),\n",
    "            create_graph=True,\n",
    "            retain_graph=True,\n",
    "            only_inputs=True,\n",
    "        )[0]\n",
    "        gradients = gradients.view(gradients.size(0), -1)\n",
    "        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()\n",
    "\n",
    "    def one_hot(self, labels):\n",
    "        \"\"\"One-hot encode integer class labels into `num_classes` columns.\"\"\"\n",
    "        return F.one_hot(labels.long(), self.hparams['num_classes'])\n",
    "\n",
    "    def cat(self, imgs, labels):\n",
    "        \"\"\"Append labels to 2-D `imgs` along dim 1 (1-D labels are one-hot encoded first).\"\"\"\n",
    "        one_hot_labels = self.one_hot(labels) if labels.dim() == 1 else labels\n",
    "        # Cast explicitly rather than relying on torch.cat's dtype promotion.\n",
    "        return torch.cat((imgs, one_hot_labels.to(imgs.dtype)), dim=1)\n",
    "\n",
    "    def sample_z(self, length):\n",
    "        \"\"\"Draw `length` standard-normal latent vectors of size `latent_dim`.\"\"\"\n",
    "        return torch.randn(length, self.hparams['latent_dim'], device=self.device)\n",
    "\n",
    "    def forward(self, z, labels):\n",
    "        \"\"\"Generate images conditioned on `labels` from latents `z`.\"\"\"\n",
    "        return self.generator(self.cat(z, labels))\n",
    "\n",
    "    def flatten(self, imgs):\n",
    "        \"\"\"Reshape images to (N, channels * width * height).\"\"\"\n",
    "        return imgs.reshape(-1, self.hparams['channels'] * self.hparams['width'] * self.hparams['height'])\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Run one critic update and, every `n_critic` batches, one generator update.\n",
    "\n",
    "        Returns:\n",
    "            (d_loss, lambda_gp * gradient_penalty, g_loss) as Python floats; g_loss is\n",
    "            None on batches where the generator is not updated.\n",
    "        \"\"\"\n",
    "        imgs = batch[0].to(self.device)\n",
    "        labels = batch[1].to(self.device)\n",
    "\n",
    "        # Sample latents (reused for the generator update below)\n",
    "        z = self.sample_z(len(labels))\n",
    "\n",
    "        ########################################################\n",
    "        # Train Discriminator\n",
    "        ########################################################\n",
    "        self.opt_D.zero_grad()\n",
    "\n",
    "        # Detached: the critic loss must not back-propagate through the generator\n",
    "        # (those generator gradients were discarded anyway).\n",
    "        fake_imgs = self.cat(self.flatten(self(z, labels).detach()), labels)\n",
    "        fake_validity = self.discriminator(fake_imgs)\n",
    "\n",
    "        real_imgs = self.cat(self.flatten(imgs), labels)\n",
    "        real_validity = self.discriminator(real_imgs)\n",
    "\n",
    "        gradient_penalty = self.compute_gradient_penalty(real_imgs, fake_imgs)\n",
    "        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + self.hparams['lambda_gp'] * gradient_penalty\n",
    "\n",
    "        d_loss.backward()\n",
    "        self.opt_D.step()\n",
    "\n",
    "        d_loss_value = float(d_loss.item())\n",
    "        gp_value = float(self.hparams['lambda_gp'] * gradient_penalty.item())\n",
    "\n",
    "        ########################################################\n",
    "        # Train Generator (once every n_critic batches)\n",
    "        ########################################################\n",
    "        if batch_idx % self.hparams['n_critic'] != 0:\n",
    "            return d_loss_value, gp_value, None\n",
    "\n",
    "        self.opt_G.zero_grad()\n",
    "        fake_imgs = self.cat(self.flatten(self(z, labels)), labels)\n",
    "        g_loss = -torch.mean(self.discriminator(fake_imgs))\n",
    "        g_loss.backward()\n",
    "        self.opt_G.step()\n",
    "\n",
    "        return d_loss_value, gp_value, float(g_loss.item())\n",
    "\n",
    "    def postprocess(self, tensor: torch.Tensor):\n",
    "        \"\"\"Map values in [-1, 1] to uint8 pixels in [0, 255], clamping out-of-range values.\"\"\"\n",
    "        return ((tensor / 2 + 0.5).clamp(0, 1) * 255).round().to(torch.uint8)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def make_new_dataset(self, previous_dataset, train=False, save=False):\n",
    "        \"\"\"Generate a synthetic dataset with the same labels as `previous_dataset`.\n",
    "\n",
    "        Args:\n",
    "            previous_dataset: TensorDataset whose second tensor holds the class labels.\n",
    "            train: selects the file prefix ('train' or 'test') used when saving.\n",
    "            save: if True, write the uint8 dataset to DATASETS_DIR as\n",
    "                '{prefix}_{generation:03}.pt', using this model's own generation index.\n",
    "\n",
    "        Returns:\n",
    "            The generated dataset rescaled to [-1, 1] by `preprocess`.\n",
    "        \"\"\"\n",
    "        labels = previous_dataset.tensors[1].to(self.device)\n",
    "        imgs = self(self.sample_z(len(labels)), labels).detach().cpu()\n",
    "        new_dataset = TensorDataset(self.postprocess(imgs), labels.cpu())\n",
    "\n",
    "        if save:\n",
    "            prefix = 'train' if train else 'test'\n",
    "            # Name by self.generation, not by a notebook-global loop variable.\n",
    "            torch.save(new_dataset, join(DATASETS_DIR, f'{prefix}_{int(self.generation):03}.pt'))\n",
    "        del imgs\n",
    "\n",
    "        # Convert new dataset to [-1, 1]\n",
    "        return preprocess(new_dataset)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def calculate_fid(self, dataset: TensorDataset, batch_size: int = 50):\n",
    "        \"\"\"Return the FID between `dataset` images and images generated for the same labels.\n",
    "\n",
    "        Both sets are converted to uint8 in [0, 255] and repeated 3 times along dim 1\n",
    "        (assumes single-channel images), as FrechetInceptionDistance(normalize=False)\n",
    "        expects. All samples are used, including a final partial batch.\n",
    "        \"\"\"\n",
    "        real_imgs = dataset.tensors[0].to(self.device)\n",
    "        labels    = dataset.tensors[1].to(self.device)\n",
    "\n",
    "        fake_imgs = self.postprocess(self(self.sample_z(len(labels)), labels).detach())\n",
    "        real_imgs = self.postprocess(real_imgs)\n",
    "\n",
    "        # New shape will be (N, 3, H, W)\n",
    "        fake_imgs = torch.cat([fake_imgs] * 3, dim=1)\n",
    "        real_imgs = torch.cat([real_imgs] * 3, dim=1)\n",
    "\n",
    "        # normalize=False -> images must be uint8 in [0, 255]\n",
    "        fid = FrechetInceptionDistance(normalize=False).to(self.device)\n",
    "        for start in range(0, len(labels), batch_size):\n",
    "            stop = start + batch_size\n",
    "            fid.update(real_imgs[start:stop], real=True)\n",
    "            fid.update(fake_imgs[start:stop], real=False)\n",
    "\n",
    "        val = float(fid.compute().item())\n",
    "        del fake_imgs, real_imgs\n",
    "        return val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "KlgfzBSQ6x0k"
   },
   "outputs": [],
   "source": [
    "# Load raw MNIST (uint8 images, integer labels), add a channel axis -> (N, 1, 28, 28)\n",
    "# and rescale pixels to [-1, 1] to match the generator's Tanh output range.\n",
    "def _load_mnist(train):\n",
    "    raw = MNIST(os.getcwd(), train=train, download=True)\n",
    "    return preprocess(TensorDataset(raw.data.unsqueeze(1), raw.targets))\n",
    "\n",
    "\n",
    "real_train = _load_mnist(train=True)\n",
    "real_test  = _load_mnist(train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 104,
     "referenced_widgets": [
      "366041823e42417789867cea8c6e8385",
      "270bceca43b546d2b5764ede89c4fe71",
      "642f33504ee34af0b9118de0cd0ee605",
      "9e39d565171e451baf47a0dbb74aec62",
      "6e7abe6f49854c7c9c2be6b3a9c64bab",
      "5f7d5dbff7974362bf0b005ae4dfc1ac",
      "3fbcff2f4ec642a489d30c5a60002e17",
      "eb3dcc16cc4a460b95a78376124d59d7",
      "b1c15292898e4554b3fccd4823c288fd",
      "f69a0eed8a6b44a1a38683e20185d3b1",
      "a477ec05c246488491c18606b131d550"
     ]
    },
    "id": "d_WLbrXzex53",
    "outputId": "a08de719-8062-4a2b-9e5a-9c31d3cb923f",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|██████████▉                                                                   | 140/1001 [18:53<1:56:08,  8.09s/it]\n",
      "  7%|█████▌                                                                         | 71/1001 [09:31<1:24:47,  5.47s/it]"
     ]
    }
   ],
   "source": [
    "generations = 50\n",
    "epochs = 1000 + 1 #400 + 1\n",
    "sample_interval = 20 #10\n",
    "\n",
    "\n",
    "def completed_generations(folder):\n",
    "    \"\"\"Return the generations g for which both train_{g:03}.pt and test_{g:03}.pt exist in `folder`.\n",
    "\n",
    "    Files not matching '{train|test}_NNN.pt' are ignored, and a generation whose test file\n",
    "    is missing (e.g. the run stopped between the two saves) is not counted.\n",
    "    \"\"\"\n",
    "    found = {'train': set(), 'test': set()}\n",
    "    for fname in os.listdir(folder):\n",
    "        stem, ext = os.path.splitext(fname)\n",
    "        split, _, number = stem.partition('_')\n",
    "        if ext == '.pt' and split in found and number.isdigit():\n",
    "            found[split].add(int(number))\n",
    "    return found['train'] & found['test']\n",
    "\n",
    "\n",
    "finished = completed_generations(DATASETS_DIR)\n",
    "# Start from the real dataset if no complete synthetic generation exists yet\n",
    "if not finished:\n",
    "    train_dataset = real_train\n",
    "    test_dataset  = real_test\n",
    "    init_gen = 0\n",
    "# Otherwise resume from the latest generation that has both its train and test set\n",
    "else:\n",
    "    prev_gen = max(finished)\n",
    "    # Load the corresponding 60k and 10k datasets (saved as uint8) and normalize to [-1, 1]\n",
    "    train_dataset = preprocess(torch.load(join(DATASETS_DIR, f'train_{prev_gen:03}.pt')))\n",
    "    test_dataset  = preprocess(torch.load(join(DATASETS_DIR, f'test_{prev_gen:03}.pt')))\n",
    "    init_gen = prev_gen + 1\n",
    "    del prev_gen\n",
    "\n",
    "\n",
    "for g in range(init_gen, init_gen + generations):\n",
    "    checkpoint_path = join(MODELS_DIR, 'checkpoint.pt')\n",
    "    wgan = WGAN_GP(device, g)\n",
    "    early_stopping = EarlyStopping(patience=2, path=checkpoint_path,\n",
    "                                   trace_func=lambda x: None)\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size = wgan.hparams['batch_size'],\n",
    "                                  shuffle=True, pin_memory=True)\n",
    "    # FID against the current (possibly synthetic) test set, one value per sample_interval\n",
    "    fid_madc = []\n",
    "\n",
    "    for epoch in tqdm(range(epochs)):\n",
    "        for i, batch in enumerate(train_dataloader):\n",
    "            wgan.training_step(batch, i)\n",
    "\n",
    "        if epoch % sample_interval == 0:\n",
    "            newest_fid = wgan.calculate_fid(test_dataset)\n",
    "            fid_madc.append(newest_fid)\n",
    "            # EarlyStopping writes the best model so far to checkpoint_path (loaded below)\n",
    "            early_stopping(newest_fid, wgan)\n",
    "\n",
    "        if early_stopping.early_stop:\n",
    "            break\n",
    "\n",
    "    # Load the best model\n",
    "    wgan.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "\n",
    "    # Rename the best model ('checkpoint.pt') to 'gan_{g:03}.pt'; os.replace also overwrites on Windows\n",
    "    os.replace(checkpoint_path, join(MODELS_DIR, f'gan_{g:03}.pt'))\n",
    "\n",
    "    # Calculate FID wrt real dataset. Save FID scores as .npz\n",
    "    np.savez(join(METRICS_DIR, f'fid_{g:03}.npz'),\n",
    "             fid_madc=fid_madc,\n",
    "             fid_real=wgan.calculate_fid(real_test))\n",
    "\n",
    "    # Generate and save the datasets the next generation trains on.\n",
    "    train_dataset = wgan.make_new_dataset(train_dataset, train=True, save=True)\n",
    "    test_dataset  = wgan.make_new_dataset(test_dataset, train=False, save=True)\n",
    "    # Free up CUDA memory.\n",
    "    del wgan, train_dataloader\n",
    "    gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that every dataset the GAN trains on / is evaluated against holds float images in [-1, 1]\n",
    "for name, dataset in [('real_train', real_train), ('real_test', real_test),\n",
    "                      ('train_dataset', train_dataset), ('test_dataset', test_dataset)]:\n",
    "    imgs = dataset.tensors[0]\n",
    "    assert imgs.is_floating_point(), f'{name}: expected float images, got {imgs.dtype}'\n",
    "    assert -1 <= imgs.min().item() and imgs.max().item() <= 1, f'{name}: pixel values outside [-1, 1]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debugging aid: uncomment the line below to print a detailed CUDA memory report\n",
    "# for the device selected at the top of the notebook (`device`).\n",
    "#print(torch.cuda.memory_summary(device=device, abbreviated=False))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "7F4maoBuXtmf",
    "XRP4f-r_ZPml",
    "GGnQvXLKZ2wx",
    "jkRHkvoHrqjL"
   ],
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "270bceca43b546d2b5764ede89c4fe71": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_5f7d5dbff7974362bf0b005ae4dfc1ac",
      "placeholder": "​",
      "style": "IPY_MODEL_3fbcff2f4ec642a489d30c5a60002e17",
      "value": "100%"
     }
    },
    "366041823e42417789867cea8c6e8385": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_270bceca43b546d2b5764ede89c4fe71",
       "IPY_MODEL_642f33504ee34af0b9118de0cd0ee605",
       "IPY_MODEL_9e39d565171e451baf47a0dbb74aec62"
      ],
      "layout": "IPY_MODEL_6e7abe6f49854c7c9c2be6b3a9c64bab"
     }
    },
    "3fbcff2f4ec642a489d30c5a60002e17": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "5f7d5dbff7974362bf0b005ae4dfc1ac": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "642f33504ee34af0b9118de0cd0ee605": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_eb3dcc16cc4a460b95a78376124d59d7",
      "max": 95628359,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_b1c15292898e4554b3fccd4823c288fd",
      "value": 95628359
     }
    },
    "6e7abe6f49854c7c9c2be6b3a9c64bab": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9e39d565171e451baf47a0dbb74aec62": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f69a0eed8a6b44a1a38683e20185d3b1",
      "placeholder": "​",
      "style": "IPY_MODEL_a477ec05c246488491c18606b131d550",
      "value": " 91.2M/91.2M [00:03&lt;00:00, 24.7MB/s]"
     }
    },
    "a477ec05c246488491c18606b131d550": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "b1c15292898e4554b3fccd4823c288fd": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "eb3dcc16cc4a460b95a78376124d59d7": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f69a0eed8a6b44a1a38683e20185d3b1": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
