{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q4wgKhonGDJn"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive"
      ],
      "id": "q4wgKhonGDJn"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "koLSDU3KHn5e"
      },
      "source": [
        "#### Mount Google Drive and set up the working directory"
      ],
      "id": "koLSDU3KHn5e"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QGzq7HyiGDDv",
        "outputId": "381a2fa5-6507-4735-d9df-1be4b62bcf4e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "drive.mount('/content/drive', force_remount= True)"
      ],
      "id": "QGzq7HyiGDDv"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ePwFUQvrGq1A",
        "outputId": "1ad2c94e-120b-44de-9bae-90f6816946be"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "drive  sample_data\n"
          ]
        }
      ],
      "source": [
        "!ls"
      ],
      "id": "ePwFUQvrGq1A"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YPOC6rPIGC9Y",
        "outputId": "05102601-cc1b-4169-cfb5-26781ddac609"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet\n"
          ]
        }
      ],
      "source": [
        "cd /content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet"
      ],
      "id": "YPOC6rPIGC9Y"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "p2IwKmpUGCyL",
        "outputId": "d68d9228-49f5-4bfe-e605-93e26c9982ff"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting wandb\n",
            "  Downloading wandb-0.12.9-py2.py3-none-any.whl (1.7 MB)\n",
            "\u001b[K     |████████████████████████████████| 1.7 MB 9.1 MB/s \n",
            "\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.3)\n",
            "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from wandb) (3.13)\n",
            "Collecting GitPython>=1.0.0\n",
            "  Downloading GitPython-3.1.26-py3-none-any.whl (180 kB)\n",
            "\u001b[K     |████████████████████████████████| 180 kB 43.9 MB/s \n",
            "\u001b[?25hRequirement already satisfied: protobuf>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.17.3)\n",
            "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (5.4.8)\n",
            "Collecting shortuuid>=0.5.0\n",
            "  Downloading shortuuid-1.0.8-py3-none-any.whl (9.5 kB)\n",
            "Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (7.1.2)\n",
            "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.8.2)\n",
            "Collecting sentry-sdk>=1.0.0\n",
            "  Downloading sentry_sdk-1.5.4-py2.py3-none-any.whl (143 kB)\n",
            "\u001b[K     |████████████████████████████████| 143 kB 56.0 MB/s \n",
            "\u001b[?25hCollecting docker-pycreds>=0.4.0\n",
            "  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n",
            "Requirement already satisfied: six>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.15.0)\n",
            "Collecting subprocess32>=3.5.3\n",
            "  Downloading subprocess32-3.5.4.tar.gz (97 kB)\n",
            "\u001b[K     |████████████████████████████████| 97 kB 7.0 MB/s \n",
            "\u001b[?25hCollecting configparser>=3.8.1\n",
            "  Downloading configparser-5.2.0-py3-none-any.whl (19 kB)\n",
            "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.23.0)\n",
            "Collecting pathtools\n",
            "  Downloading pathtools-0.1.2.tar.gz (11 kB)\n",
            "Collecting yaspin>=1.0.0\n",
            "  Downloading yaspin-2.1.0-py3-none-any.whl (18 kB)\n",
            "Collecting gitdb<5,>=4.0.1\n",
            "  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)\n",
            "\u001b[K     |████████████████████████████████| 63 kB 1.8 MB/s \n",
            "\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from GitPython>=1.0.0->wandb) (3.10.0.2)\n",
            "Collecting smmap<6,>=3.0.1\n",
            "  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (1.24.3)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2.10)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (3.0.4)\n",
            "Requirement already satisfied: termcolor<2.0.0,>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from yaspin>=1.0.0->wandb) (1.1.0)\n",
            "Building wheels for collected packages: subprocess32, pathtools\n",
            "  Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for subprocess32: filename=subprocess32-3.5.4-py3-none-any.whl size=6502 sha256=247792d87b550c9f62e05db144861e1698f60f13eded8ee63f498702016a1683\n",
            "  Stored in directory: /root/.cache/pip/wheels/50/ca/fa/8fca8d246e64f19488d07567547ddec8eb084e8c0d7a59226a\n",
            "  Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=0d2381307eea6d5cdd41910ae4b41ae8fc9d6fcc9b5b58ade7958fe2aeb92c16\n",
            "  Stored in directory: /root/.cache/pip/wheels/3e/31/09/fa59cef12cdcfecc627b3d24273699f390e71828921b2cbba2\n",
            "Successfully built subprocess32 pathtools\n",
            "Installing collected packages: smmap, gitdb, yaspin, subprocess32, shortuuid, sentry-sdk, pathtools, GitPython, docker-pycreds, configparser, wandb\n",
            "Successfully installed GitPython-3.1.26 configparser-5.2.0 docker-pycreds-0.4.0 gitdb-4.0.9 pathtools-0.1.2 sentry-sdk-1.5.4 shortuuid-1.0.8 smmap-5.0.0 subprocess32-3.5.4 wandb-0.12.9 yaspin-2.1.0\n"
          ]
        }
      ],
      "source": [
        "!pip install wandb"
      ],
      "id": "p2IwKmpUGCyL"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "determined-township"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import yaml\n",
        "import argparse\n",
        "import torch\n",
        "from model import YNet"
      ],
      "id": "determined-township"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "finished-potential"
      },
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2"
      ],
      "id": "finished-potential"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e567fbe0"
      },
      "source": [
        "#### Some hyperparameters and settings"
      ],
      "id": "e567fbe0"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "arabic-thickness"
      },
      "outputs": [],
      "source": [
        "CONFIG_FILE_PATH = 'config/sdd_trajnet.yaml'  # yaml config file containing all the hyperparameters\n",
        "EXPERIMENT_NAME = 'sdd_trajnet'  # arbitrary name for this experiment\n",
        "DATASET_NAME = 'sdd'\n",
        "\n",
        "TRAIN_DATA_PATH = 'data/SDD/train_trajnet.pkl'\n",
        "TRAIN_IMAGE_PATH = 'data/SDD/train'\n",
        "VAL_DATA_PATH = 'data/SDD/test_trajnet.pkl'\n",
        "VAL_IMAGE_PATH = 'data/SDD/test'\n",
        "OBS_LEN = 8  # in timesteps\n",
        "PRED_LEN = 12  # in timesteps\n",
        "NUM_GOALS = 20  # K_e\n",
        "NUM_TRAJ = 1  # K_a\n",
        "\n",
        "BATCH_SIZE = 4"
      ],
      "id": "arabic-thickness"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "6BhBP_J5FhgB",
        "outputId": "b79001b8-7077-4a7c-f543-747e397b26f7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting tqdm==4.48.0\n",
            "  Downloading tqdm-4.48.0-py2.py3-none-any.whl (67 kB)\n",
            "\u001b[?25l\r\u001b[K     |████▉                           | 10 kB 28.5 MB/s eta 0:00:01\r\u001b[K     |█████████▋                      | 20 kB 27.4 MB/s eta 0:00:01\r\u001b[K     |██████████████▌                 | 30 kB 18.0 MB/s eta 0:00:01\r\u001b[K     |███████████████████▎            | 40 kB 15.0 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▏       | 51 kB 10.2 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████   | 61 kB 10.4 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 67 kB 4.3 MB/s \n",
            "\u001b[?25hCollecting pyyaml==5.3.1\n",
            "  Downloading PyYAML-5.3.1.tar.gz (269 kB)\n",
            "\u001b[K     |████████████████████████████████| 269 kB 19.0 MB/s \n",
            "\u001b[?25hRequirement already satisfied: matplotlib==3.2.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 3)) (3.2.2)\n",
            "Collecting torch==1.5.1\n",
            "  Downloading torch-1.5.1-cp37-cp37m-manylinux1_x86_64.whl (753.2 MB)\n",
            "\u001b[K     |████████████████████████████████| 753.2 MB 15 kB/s \n",
            "\u001b[?25hRequirement already satisfied: pandas==1.1.5 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 5)) (1.1.5)\n",
            "Collecting opencv-python==4.4.0.42\n",
            "  Downloading opencv_python-4.4.0.42-cp37-cp37m-manylinux2014_x86_64.whl (49.4 MB)\n",
            "\u001b[K     |████████████████████████████████| 49.4 MB 49.1 MB/s \n",
            "\u001b[?25hCollecting scipy==1.5.0\n",
            "  Downloading scipy-1.5.0-cp37-cp37m-manylinux1_x86_64.whl (25.9 MB)\n",
            "\u001b[K     |████████████████████████████████| 25.9 MB 1.3 MB/s \n",
            "\u001b[?25hCollecting segmentation_models_pytorch==0.1.0\n",
            "  Downloading segmentation_models_pytorch-0.1.0-py3-none-any.whl (42 kB)\n",
            "\u001b[K     |████████████████████████████████| 42 kB 1.3 MB/s \n",
            "\u001b[?25hRequirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (2.8.2)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (0.11.0)\n",
            "Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.19.5)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (3.0.7)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.3.2)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from torch==1.5.1->-r requirements.txt (line 4)) (0.16.0)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas==1.1.5->-r requirements.txt (line 5)) (2018.9)\n",
            "Collecting efficientnet-pytorch>=0.5.1\n",
            "  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)\n",
            "Collecting pretrainedmodels==0.7.4\n",
            "  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)\n",
            "\u001b[K     |████████████████████████████████| 58 kB 6.8 MB/s \n",
            "\u001b[?25hRequirement already satisfied: torchvision>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (0.11.1+cu111)\n",
            "Collecting munch\n",
            "  Downloading munch-2.5.0-py2.py3-none-any.whl (10 kB)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib==3.2.2->-r requirements.txt (line 3)) (1.15.0)\n",
            "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision>=0.3.0->segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (7.1.2)\n",
            "Collecting torchvision>=0.3.0\n",
            "  Downloading torchvision-0.11.3-cp37-cp37m-manylinux1_x86_64.whl (23.2 MB)\n",
            "\u001b[K     |████████████████████████████████| 23.2 MB 1.2 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.11.2-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)\n",
            "\u001b[K     |████████████████████████████████| 23.3 MB 1.3 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.11.1-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)\n",
            "\u001b[K     |████████████████████████████████| 23.3 MB 457 kB/s \n",
            "\u001b[?25h  Downloading torchvision-0.10.1-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)\n",
            "\u001b[K     |████████████████████████████████| 22.1 MB 344 kB/s \n",
            "\u001b[?25h  Downloading torchvision-0.10.0-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)\n",
            "\u001b[K     |████████████████████████████████| 22.1 MB 1.1 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.9.1-cp37-cp37m-manylinux1_x86_64.whl (17.4 MB)\n",
            "\u001b[K     |████████████████████████████████| 17.4 MB 84 kB/s \n",
            "\u001b[?25h  Downloading torchvision-0.9.0-cp37-cp37m-manylinux1_x86_64.whl (17.3 MB)\n",
            "\u001b[K     |████████████████████████████████| 17.3 MB 45.5 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl (12.8 MB)\n",
            "\u001b[K     |████████████████████████████████| 12.8 MB 33.9 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (12.7 MB)\n",
            "\u001b[K     |████████████████████████████████| 12.7 MB 40.0 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.8.0-cp37-cp37m-manylinux1_x86_64.whl (11.8 MB)\n",
            "\u001b[K     |████████████████████████████████| 11.8 MB 49.3 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl (5.9 MB)\n",
            "\u001b[K     |████████████████████████████████| 5.9 MB 50.7 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.6.1-cp37-cp37m-manylinux1_x86_64.whl (6.6 MB)\n",
            "\u001b[K     |████████████████████████████████| 6.6 MB 40.1 MB/s \n",
            "\u001b[?25hBuilding wheels for collected packages: pyyaml, pretrainedmodels, efficientnet-pytorch\n",
            "  Building wheel for pyyaml (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp37-cp37m-linux_x86_64.whl size=44636 sha256=cd88e76549df6e3876a852c73665e14a4ca1f8b9b5bf6d0703948bdcb9a90d54\n",
            "  Stored in directory: /root/.cache/pip/wheels/5e/03/1e/e1e954795d6f35dfc7b637fe2277bff021303bd9570ecea653\n",
            "  Building wheel for pretrainedmodels (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pretrainedmodels: filename=pretrainedmodels-0.7.4-py3-none-any.whl size=60965 sha256=489d386f4a884e10ab175363ca49502049c3d28557508f070f8221dfe135f779\n",
            "  Stored in directory: /root/.cache/pip/wheels/ed/27/e8/9543d42de2740d3544db96aefef63bda3f2c1761b3334f4873\n",
            "  Building wheel for efficientnet-pytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16446 sha256=cda86caec871085550d055e8db38d2ac98593fd5bd961bb42c2d30ba27ee182b\n",
            "  Stored in directory: /root/.cache/pip/wheels/0e/cc/b2/49e74588263573ff778da58cc99b9c6349b496636a7e165be6\n",
            "Successfully built pyyaml pretrainedmodels efficientnet-pytorch\n",
            "Installing collected packages: torch, tqdm, torchvision, munch, pretrainedmodels, efficientnet-pytorch, segmentation-models-pytorch, scipy, pyyaml, opencv-python\n",
            "  Attempting uninstall: torch\n",
            "    Found existing installation: torch 1.10.0+cu111\n",
            "    Uninstalling torch-1.10.0+cu111:\n",
            "      Successfully uninstalled torch-1.10.0+cu111\n",
            "  Attempting uninstall: tqdm\n",
            "    Found existing installation: tqdm 4.62.3\n",
            "    Uninstalling tqdm-4.62.3:\n",
            "      Successfully uninstalled tqdm-4.62.3\n",
            "  Attempting uninstall: torchvision\n",
            "    Found existing installation: torchvision 0.11.1+cu111\n",
            "    Uninstalling torchvision-0.11.1+cu111:\n",
            "      Successfully uninstalled torchvision-0.11.1+cu111\n",
            "  Attempting uninstall: scipy\n",
            "    Found existing installation: scipy 1.4.1\n",
            "    Uninstalling scipy-1.4.1:\n",
            "      Successfully uninstalled scipy-1.4.1\n",
            "  Attempting uninstall: pyyaml\n",
            "    Found existing installation: PyYAML 3.13\n",
            "    Uninstalling PyYAML-3.13:\n",
            "      Successfully uninstalled PyYAML-3.13\n",
            "  Attempting uninstall: opencv-python\n",
            "    Found existing installation: opencv-python 4.1.2.30\n",
            "    Uninstalling opencv-python-4.1.2.30:\n",
            "      Successfully uninstalled opencv-python-4.1.2.30\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.5.1 which is incompatible.\n",
            "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.5.1 which is incompatible.\n",
            "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n",
            "Successfully installed efficientnet-pytorch-0.7.1 munch-2.5.0 opencv-python-4.4.0.42 pretrainedmodels-0.7.4 pyyaml-5.3.1 scipy-1.5.0 segmentation-models-pytorch-0.1.0 torch-1.5.1 torchvision-0.6.1 tqdm-4.48.0\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "cv2",
                  "torch",
                  "tqdm",
                  "yaml"
                ]
              }
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "pip install -r requirements.txt"
      ],
      "id": "6BhBP_J5FhgB"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2f729e8f"
      },
      "source": [
        "#### Load config file and print hyperparameters"
      ],
      "id": "2f729e8f"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dangerous-cutting",
        "outputId": "6226400a-9e7b-4b0f-a4e2-3e745c49b1e2"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
              " 'batch_size': 8,\n",
              " 'decoder_channels': [64, 64, 64, 32, 32],\n",
              " 'encoder_channels': [32, 32, 64, 64, 64],\n",
              " 'kernlen': 31,\n",
              " 'learning_rate': 0.0001,\n",
              " 'loss_scale': 1000,\n",
              " 'nsig': 4,\n",
              " 'num_epochs': 300,\n",
              " 'rel_threshold': 0.002,\n",
              " 'resize': 0.25,\n",
              " 'segmentation_model_fp': 'segmentation_models/SDD_segmentation.pth',\n",
              " 'semantic_classes': 6,\n",
              " 'temperature': 1.8,\n",
              " 'unfreeze': 100,\n",
              " 'use_CWS': True,\n",
              " 'use_TTST': True,\n",
              " 'use_features_only': False,\n",
              " 'viz_epoch': 10,\n",
              " 'waypoints': [14, 29]}"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "with open(CONFIG_FILE_PATH) as file:\n",
        "    params = yaml.load(file, Loader=yaml.FullLoader)\n",
        "experiment_name = CONFIG_FILE_PATH.split('.yaml')[0].split('config/')[1]\n",
        "params"
      ],
      "id": "dangerous-cutting"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "699a7543"
      },
      "source": [
        "#### Wandb INIT"
      ],
      "id": "699a7543"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "65f560e7"
      },
      "outputs": [],
      "source": [
        ""
      ],
      "id": "65f560e7"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "amber-pressure"
      },
      "source": [
        "#### Load preprocessed Data"
      ],
      "id": "amber-pressure"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "german-feature",
        "outputId": "3dd2f167-0169-4413-e7e8-00eea3bb851b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting pickle5\n",
            "  Downloading pickle5-0.0.12-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (256 kB)\n",
            "\u001b[?25l\r\u001b[K     |█▎                              | 10 kB 26.2 MB/s eta 0:00:01\r\u001b[K     |██▋                             | 20 kB 31.2 MB/s eta 0:00:01\r\u001b[K     |███▉                            | 30 kB 20.2 MB/s eta 0:00:01\r\u001b[K     |█████▏                          | 40 kB 17.1 MB/s eta 0:00:01\r\u001b[K     |██████▍                         | 51 kB 10.3 MB/s eta 0:00:01\r\u001b[K     |███████▊                        | 61 kB 11.9 MB/s eta 0:00:01\r\u001b[K     |█████████                       | 71 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |██████████▎                     | 81 kB 10.7 MB/s eta 0:00:01\r\u001b[K     |███████████▌                    | 92 kB 11.7 MB/s eta 0:00:01\r\u001b[K     |████████████▉                   | 102 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |██████████████                  | 112 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |███████████████▍                | 122 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |████████████████▋               | 133 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |██████████████████              | 143 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |███████████████████▏            | 153 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |████████████████████▌           | 163 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▊          | 174 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |███████████████████████         | 184 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▎       | 194 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▋      | 204 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▉     | 215 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▏   | 225 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▍  | 235 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▊ | 245 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 256 kB 9.7 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 
256 kB 9.7 MB/s \n",
            "\u001b[?25hInstalling collected packages: pickle5\n",
            "Successfully installed pickle5-0.0.12\n"
          ]
        }
      ],
      "source": [
        "#df_train = pd.read_pickle(TRAIN_DATA_PATH)\n",
        "#df_val = pd.read_pickle(VAL_DATA_PATH)\n",
        "!pip3 install pickle5\n",
        "#df_train = pd.read_pickle(TRAIN_DATA_PATH)\n",
        "#df_val = pd.read_pickle(VAL_DATA_PATH)\n",
        "\n",
        "import pickle5 as pickle \n",
        "with open(TRAIN_DATA_PATH, \"rb\") as fh:\n",
        "    df_train = pickle.load(fh)\n",
        "with open(VAL_DATA_PATH, \"rb\") as fh1:\n",
        "    df_val = pickle.load(fh1)"
      ],
      "id": "german-feature"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "corporate-pharmacy",
        "outputId": "ecb47b1c-3343-4f17-9496-70274640e79a"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-ec47adc5-d504-4e63-9077-ba69f101d52d\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>trackId</th>\n",
              "      <th>frame</th>\n",
              "      <th>x</th>\n",
              "      <th>y</th>\n",
              "      <th>sceneId</th>\n",
              "      <th>metaId</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>2</td>\n",
              "      <td>6881</td>\n",
              "      <td>17.0</td>\n",
              "      <td>893.5</td>\n",
              "      <td>bookstore_0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>2</td>\n",
              "      <td>6911</td>\n",
              "      <td>31.0</td>\n",
              "      <td>904.0</td>\n",
              "      <td>bookstore_0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>2</td>\n",
              "      <td>6941</td>\n",
              "      <td>63.0</td>\n",
              "      <td>910.5</td>\n",
              "      <td>bookstore_0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>2</td>\n",
              "      <td>6971</td>\n",
              "      <td>98.5</td>\n",
              "      <td>917.5</td>\n",
              "      <td>bookstore_0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>2</td>\n",
              "      <td>7001</td>\n",
              "      <td>134.0</td>\n",
              "      <td>919.5</td>\n",
              "      <td>bookstore_0</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-ec47adc5-d504-4e63-9077-ba69f101d52d')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-ec47adc5-d504-4e63-9077-ba69f101d52d button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-ec47adc5-d504-4e63-9077-ba69f101d52d');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "   trackId  frame      x      y      sceneId  metaId\n",
              "0        2   6881   17.0  893.5  bookstore_0       0\n",
              "1        2   6911   31.0  904.0  bookstore_0       0\n",
              "2        2   6941   63.0  910.5  bookstore_0       0\n",
              "3        2   6971   98.5  917.5  bookstore_0       0\n",
              "4        2   7001  134.0  919.5  bookstore_0       0"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "source": [
        "df_train.head()"
      ],
      "id": "corporate-pharmacy"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "24dc5d7c"
      },
      "source": [
        "#### Initialize model"
      ],
      "id": "24dc5d7c"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "harmful-colleague",
        "outputId": "6d6fbce4-9763-47f3-a71d-df7c2fcac253"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.encoders.resnet.ResNetEncoder' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Conv2dReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Activation' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n"
          ]
        }
      ],
      "source": [
        "model = YNet(obs_len=OBS_LEN, pred_len=PRED_LEN, params=params)"
      ],
      "id": "harmful-colleague"
    },
    {
      "cell_type": "code",
      "source": [
        "model.load(f'pretrained_models/{experiment_name}_weights.pt')"
      ],
      "metadata": {
        "id": "6eFG2LJ-h_1_"
      },
      "id": "6eFG2LJ-h_1_",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "45e099fe"
      },
      "source": [
        "#### Start evaluation\n",
        "Note: the validation (Val) ADE and FDE logged during training are computed without TTST and CWS to save time, so those numbers will be worse than the final values."
      ],
      "id": "45e099fe"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 240,
          "referenced_widgets": [
            "3586a897e9444afc8a0356824801220a",
            "cf603b405dae45789925083f966e04b0",
            "961e3587280247148346b9ebef196714",
            "b3fc2ad3a06440c4ad841ebf22ee7a46",
            "6b56f05f1a374e96ab1f9c5a93c8e5d0",
            "e123e221a5b042419f86c33856729662",
            "f71a887754d3466dbfbfd99c05ea41ee",
            "73234e617a2345fd9d13339011b39fcd"
          ]
        },
        "id": "suiRZodUeLO4",
        "outputId": "0b6517e7-0cf7-4efb-db90-2408111c34ae"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "Finishing last run (ID:2fd2enb8) before initializing another..."
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<br/>Waiting for W&B process to finish, PID 599... <strong style=\"color:green\">(success).</strong>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3586a897e9444afc8a0356824801220a",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "VBox(children=(Label(value=' 0.11MB of 0.11MB uploaded (0.00MB deduped)\\r'), FloatProgress(value=1.0, max=1.0)…"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<style>\n",
              "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: right }\n",
              "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; width: 100% }\n",
              "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
              "    </style>\n",
              "<div class=\"wandb-row\"><div class=\"wandb-col\">\n",
              "</div><div class=\"wandb-col\">\n",
              "</div></div>\n",
              "Synced 4 W&B file(s), 0 media file(s), 3 artifact file(s) and 1 other file(s)\n",
              "<br/>Synced <strong style=\"color:#cdcd00\">silvery-universe-20</strong>: <a href=\"https://wandb.ai/agv/ynet/runs/2fd2enb8\" target=\"_blank\">https://wandb.ai/agv/ynet/runs/2fd2enb8</a><br/>\n",
              "Find logs at: <code>w_andb/wandb/run-20220131_114917-2fd2enb8/logs</code><br/>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "Successfully finished last run (ID:2fd2enb8). Initializing new run:<br/>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "                    Syncing run <strong><a href=\"https://wandb.ai/agv/ynet/runs/dh4np6an\" target=\"_blank\">gallant-frog-21</a></strong> to <a href=\"https://wandb.ai/agv/ynet\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
              "\n",
              "                "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "import weights_and_biases as wandb\n",
        "wandb.init_wandb(params.copy(), model.model)"
      ],
      "id": "suiRZodUeLO4"
    },
    {
      "cell_type": "code",
      "source": [
        "model.evaluate(df_test, params, image_path=TEST_IMAGE_PATH,\n",
        "               batch_size=BATCH_SIZE, rounds=ROUNDS, \n",
        "               num_goals=NUM_GOALS, num_traj=NUM_TRAJ, device=None, dataset_name=DATASET_NAME)"
      ],
      "metadata": {
        "id": "ojV3AXC3dFHv"
      },
      "id": "ojV3AXC3dFHv",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "16fe307e"
      },
      "outputs": [],
      "source": [
        ""
      ],
      "id": "16fe307e"
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "eVal_sdd_short_term_experiment_.ipynb",
      "provenance": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "3586a897e9444afc8a0356824801220a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "VBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "VBoxView",
            "_dom_classes": [],
            "_model_name": "VBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_cf603b405dae45789925083f966e04b0",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_961e3587280247148346b9ebef196714",
              "IPY_MODEL_b3fc2ad3a06440c4ad841ebf22ee7a46"
            ]
          }
        },
        "cf603b405dae45789925083f966e04b0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "961e3587280247148346b9ebef196714": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "LabelModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "LabelView",
            "style": "IPY_MODEL_6b56f05f1a374e96ab1f9c5a93c8e5d0",
            "_dom_classes": [],
            "description": "",
            "_model_name": "LabelModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 0.12MB of 0.12MB uploaded (0.00MB deduped)\r",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_e123e221a5b042419f86c33856729662"
          }
        },
        "b3fc2ad3a06440c4ad841ebf22ee7a46": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_f71a887754d3466dbfbfd99c05ea41ee",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_73234e617a2345fd9d13339011b39fcd"
          }
        },
        "6b56f05f1a374e96ab1f9c5a93c8e5d0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "e123e221a5b042419f86c33856729662": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f71a887754d3466dbfbfd99c05ea41ee": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "73234e617a2345fd9d13339011b39fcd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}