{"nbformat_minor": 5, "nbformat": 4, "cells": [{"cell_type": "code", "source": ["from google.colab import drive"], "outputs": [], "execution_count": 1, "id": "q4wgKhonGDJn", "metadata": {"id": "q4wgKhonGDJn", "executionInfo": {"status": "ok", "timestamp": 1643617140060, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 1486}}}, {"source": [""], "cell_type": "markdown", "id": "koLSDU3KHn5e", "metadata": {"id": "koLSDU3KHn5e"}}, {"cell_type": "code", "source": ["drive.mount('/content/drive', force_remount= True)"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["Mounted at /content/drive\n"]}], "execution_count": 2, "id": "QGzq7HyiGDDv", "metadata": {"outputId": "f1eb914e-1006-4de6-8d36-92e3b0f3997a", "id": "QGzq7HyiGDDv", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643617167438, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 26915}}}, {"cell_type": "code", "source": ["!ls"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["drive  sample_data\n"]}], "execution_count": 1, "id": "ePwFUQvrGq1A", "metadata": {"outputId": "1ad2c94e-120b-44de-9bae-90f6816946be", "id": "ePwFUQvrGq1A", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643617588236, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 601}}}, {"cell_type": "code", "source": ["cd /content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["/content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet\n"]}], "execution_count": 2, "id": "YPOC6rPIGC9Y", "metadata": 
{"outputId": "feea5194-2fc3-4569-82dd-d432f636e873", "id": "YPOC6rPIGC9Y", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643617588849, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 1}}}, {"cell_type": "code", "source": ["!pip install wandb"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["Requirement already satisfied: wandb in /usr/local/lib/python3.7/dist-packages (0.12.9)\n", "Requirement already satisfied: yaspin>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.1.0)\n", "Requirement already satisfied: subprocess32>=3.5.3 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.5.4)\n", "Requirement already satisfied: protobuf>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.17.3)\n", "Requirement already satisfied: pathtools in /usr/local/lib/python3.7/dist-packages (from wandb) (0.1.2)\n", "Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (7.1.2)\n", "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (5.4.8)\n", "Requirement already satisfied: shortuuid>=0.5.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.0.8)\n", "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.5.4)\n", "Requirement already satisfied: configparser>=3.8.1 in /usr/local/lib/python3.7/dist-packages (from wandb) (5.2.0)\n", "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.23.0)\n", "Requirement already satisfied: GitPython>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.1.26)\n", "Requirement already satisfied: six>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.15.0)\n", "Requirement already satisfied: 
promise<3,>=2.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.3)\n", "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from wandb) (5.3.1)\n", "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.8.2)\n", "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (0.4.0)\n", "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.7/dist-packages (from GitPython>=1.0.0->wandb) (4.0.9)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from GitPython>=1.0.0->wandb) (3.10.0.2)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.7/dist-packages (from gitdb<5,>=4.0.1->GitPython>=1.0.0->wandb) (5.0.0)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (1.24.3)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (3.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2.10)\n", "Requirement already satisfied: termcolor<2.0.0,>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from yaspin>=1.0.0->wandb) (1.1.0)\n"]}], "execution_count": 3, "id": "p2IwKmpUGCyL", "metadata": {"outputId": "d631c237-73e4-43f2-8db3-d89372876651", "id": "p2IwKmpUGCyL", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643617592673, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 2371}}}, 
{"cell_type": "code", "source": ["import pandas as pd\n", "import yaml\n", "import argparse\n", "import torch\n", "from model import YNet"], "outputs": [], "execution_count": 4, "id": "determined-township", "metadata": {"id": "determined-township", "executionInfo": {"status": "ok", "timestamp": 1643617595567, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 1225}}}, {"cell_type": "code", "source": ["%load_ext autoreload\n", "%autoreload 2"], "outputs": [], "execution_count": 5, "id": "finished-potential", "metadata": {"id": "finished-potential", "executionInfo": {"status": "ok", "timestamp": 1643617596134, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 1}}}, {"source": ["#### Some hyperparameters and settings"], "cell_type": "markdown", "id": "e567fbe0", "metadata": {"id": "e567fbe0"}}, {"cell_type": "code", "source": ["CONFIG_FILE_PATH = 'config/inD_longterm.yaml'  # yaml config file containing all the hyperparameters\n", "EXPERIMENT_NAME = 'ind_longterm'  # arbitrary name for this experiment\n", "DATASET_NAME = 'ind'\n", "\n", "TRAIN_DATA_PATH = 'data/inD/train.pkl'\n", "TRAIN_IMAGE_PATH = 'data/inD/train'\n", "VAL_DATA_PATH = 'data/inD/test.pkl'\n", "VAL_IMAGE_PATH = 'data/inD/test'\n", "OBS_LEN = 5 # in timesteps\n", "PRED_LEN = 30 # in timesteps\n", "NUM_GOALS = 20 # K_e\n", "NUM_TRAJ = 10 # K_a\n", "\n", "BATCH_SIZE = 8"], "outputs": [], "execution_count": 7, "id": "arabic-thickness", "metadata": {"id": "arabic-thickness", "executionInfo": {"status": "ok", "timestamp": 1643617601476, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 399}}}, {"cell_type": "code", "source": ["pip install -r requirements.txt"], "outputs": 
[{"output_type": "stream", "name": "stdout", "text": ["Collecting tqdm==4.48.0\n", "  Downloading tqdm-4.48.0-py2.py3-none-any.whl (67 kB)\n", "\u001b[?25l\r\u001b[K     |\u2588\u2588\u2588\u2588\u2589                           | 10 kB 33.8 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258b                      | 20 kB 29.7 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258c                 | 30 kB 20.1 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258e            | 40 kB 16.8 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258f       | 51 kB 8.5 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588   | 61 kB 9.2 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 67 kB 4.4 MB/s \n", "\u001b[?25hCollecting pyyaml==5.3.1\n", "  Downloading PyYAML-5.3.1.tar.gz (269 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 269 kB 18.0 MB/s \n", "\u001b[?25hRequirement already satisfied: matplotlib==3.2.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 3)) (3.2.2)\n", "Collecting torch==1.5.1\n", "  Downloading torch-1.5.1-cp37-cp37m-manylinux1_x86_64.whl (753.2 MB)\n", "\u001b[K     
|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 753.2 MB 12 kB/s \n", "\u001b[?25hRequirement already satisfied: pandas==1.1.5 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 5)) (1.1.5)\n", "Collecting opencv-python==4.4.0.42\n", "  Downloading opencv_python-4.4.0.42-cp37-cp37m-manylinux2014_x86_64.whl (49.4 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 49.4 MB 1.2 MB/s \n", "\u001b[?25hCollecting scipy==1.5.0\n", "  Downloading scipy-1.5.0-cp37-cp37m-manylinux1_x86_64.whl (25.9 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 25.9 MB 99.9 MB/s \n", "\u001b[?25hCollecting segmentation_models_pytorch==0.1.0\n", "  Downloading segmentation_models_pytorch-0.1.0-py3-none-any.whl (42 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 42 kB 1.6 MB/s \n", "\u001b[?25hRequirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (0.11.0)\n", "Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.19.5)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (3.0.7)\n", "Requirement already satisfied: python-dateutil>=2.1 in 
/usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (2.8.2)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.3.2)\n", "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from torch==1.5.1->-r requirements.txt (line 4)) (0.16.0)\n", "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas==1.1.5->-r requirements.txt (line 5)) (2018.9)\n", "Collecting pretrainedmodels==0.7.4\n", "  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 58 kB 8.2 MB/s \n", "\u001b[?25hRequirement already satisfied: torchvision>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (0.11.1+cu111)\n", "Collecting efficientnet-pytorch>=0.5.1\n", "  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)\n", "Collecting munch\n", "  Downloading munch-2.5.0-py2.py3-none-any.whl (10 kB)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib==3.2.2->-r requirements.txt (line 3)) (1.15.0)\n", "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision>=0.3.0->segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (7.1.2)\n", "Collecting torchvision>=0.3.0\n", "  Downloading torchvision-0.11.3-cp37-cp37m-manylinux1_x86_64.whl (23.2 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 23.2 MB 1.1 MB/s \n", 
"\u001b[?25h  Downloading torchvision-0.11.2-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 23.3 MB 1.1 MB/s \n", "\u001b[?25h  Downloading torchvision-0.11.1-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 23.3 MB 1.2 MB/s \n", "\u001b[?25h  Downloading torchvision-0.10.1-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 22.1 MB 1.2 MB/s \n", "\u001b[?25h  Downloading torchvision-0.10.0-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 22.1 MB 56 kB/s \n", "\u001b[?25h  Downloading torchvision-0.9.1-cp37-cp37m-manylinux1_x86_64.whl (17.4 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 17.4 MB 65.7 MB/s \n", "\u001b[?25h  Downloading torchvision-0.9.0-cp37-cp37m-manylinux1_x86_64.whl (17.3 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 17.3 MB 89.6 MB/s \n", "\u001b[?25h  Downloading 
torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl (12.8 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 12.8 MB 75.5 MB/s \n", "\u001b[?25h  Downloading torchvision-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (12.7 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 12.7 MB 68.1 MB/s \n", "\u001b[?25h  Downloading torchvision-0.8.0-cp37-cp37m-manylinux1_x86_64.whl (11.8 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 11.8 MB 46.9 MB/s \n", "\u001b[?25h  Downloading torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl (5.9 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 5.9 MB 72.2 MB/s \n", "\u001b[?25h  Downloading torchvision-0.6.1-cp37-cp37m-manylinux1_x86_64.whl (6.6 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 6.6 MB 48.6 MB/s \n", "\u001b[?25hBuilding wheels for collected packages: pyyaml, pretrainedmodels, efficientnet-pytorch\n", "  Building wheel for pyyaml (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", "  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp37-cp37m-linux_x86_64.whl size=44636 sha256=28d47031ec264246114265250c3e99d3677a3532a37b4459a4ed23f0f4b73069\n", "  Stored in directory: /root/.cache/pip/wheels/5e/03/1e/e1e954795d6f35dfc7b637fe2277bff021303bd9570ecea653\n", "  Building wheel for pretrainedmodels (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "  Created wheel for pretrainedmodels: filename=pretrainedmodels-0.7.4-py3-none-any.whl size=60965 sha256=01a940f0046eb2f1293f0faa5b234ab43342b86ba78d08c0e7f96109cd3302cd\n", "  Stored in directory: /root/.cache/pip/wheels/ed/27/e8/9543d42de2740d3544db96aefef63bda3f2c1761b3334f4873\n", "  Building wheel for efficientnet-pytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16446 sha256=20028317f38a8d940287b8cbe049c3f906a93790a3d9b5bbf5147a12427298e2\n", "  Stored in directory: /root/.cache/pip/wheels/0e/cc/b2/49e74588263573ff778da58cc99b9c6349b496636a7e165be6\n", "Successfully built pyyaml pretrainedmodels efficientnet-pytorch\n", "Installing collected packages: torch, tqdm, torchvision, munch, pretrainedmodels, efficientnet-pytorch, segmentation-models-pytorch, scipy, pyyaml, opencv-python\n", "  Attempting uninstall: torch\n", "    Found existing installation: torch 1.10.0+cu111\n", "    Uninstalling torch-1.10.0+cu111:\n", "      Successfully uninstalled torch-1.10.0+cu111\n", "  Attempting uninstall: tqdm\n", "    Found existing installation: tqdm 4.62.3\n", "    Uninstalling tqdm-4.62.3:\n", "      Successfully uninstalled tqdm-4.62.3\n", "  Attempting uninstall: torchvision\n", "    Found existing installation: torchvision 0.11.1+cu111\n", "    Uninstalling torchvision-0.11.1+cu111:\n", "      Successfully uninstalled torchvision-0.11.1+cu111\n", "  Attempting uninstall: scipy\n", "    Found existing installation: scipy 1.4.1\n", "    Uninstalling scipy-1.4.1:\n", "      
Successfully uninstalled scipy-1.4.1\n", "  Attempting uninstall: pyyaml\n", "    Found existing installation: PyYAML 3.13\n", "    Uninstalling PyYAML-3.13:\n", "      Successfully uninstalled PyYAML-3.13\n", "  Attempting uninstall: opencv-python\n", "    Found existing installation: opencv-python 4.1.2.30\n", "    Uninstalling opencv-python-4.1.2.30:\n", "      Successfully uninstalled opencv-python-4.1.2.30\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.5.1 which is incompatible.\n", "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.5.1 which is incompatible.\n", "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n", "Successfully installed efficientnet-pytorch-0.7.1 munch-2.5.0 opencv-python-4.4.0.42 pretrainedmodels-0.7.4 pyyaml-5.3.1 scipy-1.5.0 segmentation-models-pytorch-0.1.0 torch-1.5.1 torchvision-0.6.1 tqdm-4.48.0\n"]}, {"output_type": "display_data", "data": {"application/vnd.colab-display-data+json": {"pip_warning": {"packages": ["cv2", "torch", "tqdm", "yaml"]}}}, "metadata": {}}], "execution_count": 10, "id": "6BhBP_J5FhgB", "metadata": {"outputId": "e1c01031-6a3d-4e9f-8ae7-06357167859e", "id": "6BhBP_J5FhgB", "colab": {"base_uri": "https://localhost:8080/", "height": 1000}, "executionInfo": {"status": "ok", "timestamp": 1643617531259, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 134521}}}, {"source": ["#### Load config file and print hyperparameters"], "cell_type": "markdown", "id": "2f729e8f", "metadata": {"id": "2f729e8f"}}, {"cell_type": "code", "source": ["with open(CONFIG_FILE_PATH) as file:\n", "    params = yaml.load(file, 
Loader=yaml.FullLoader)\n", "experiment_name = CONFIG_FILE_PATH.split('.yaml')[0].split('config/')[1]\n", "params"], "outputs": [{"output_type": "execute_result", "data": {"text/plain": ["{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n", " 'batch_size': 8,\n", " 'decoder_channels': [64, 64, 64, 32, 32],\n", " 'encoder_channels': [32, 32, 64, 64, 64],\n", " 'kernlen': 31,\n", " 'learning_rate': 0.0001,\n", " 'loss_scale': 1000,\n", " 'nsig': 4,\n", " 'num_epochs': 300,\n", " 'rel_threshold': 0.002,\n", " 'resize': 0.33,\n", " 'segmentation_model_fp': 'segmentation_models/inD_segmentation.pth',\n", " 'semantic_classes': 6,\n", " 'temperature': 1.8,\n", " 'unfreeze': 100,\n", " 'use_CWS': True,\n", " 'use_TTST': True,\n", " 'use_features_only': False,\n", " 'viz_epoch': 10,\n", " 'waypoints': [14, 29]}"]}, "execution_count": 9, "metadata": {}}], "execution_count": 9, "id": "dangerous-cutting", "metadata": {"outputId": "6520b672-dc1d-4814-db14-bd6b4a01b8b1", "id": "dangerous-cutting", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643617608479, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 385}}}, {"source": ["#### Wandb INIT\n", "\n", "**TODO:** the code cell below is empty, so `wandb.init` is not called at this point and no Weights & Biases run is started here."], "cell_type": "markdown", "id": "699a7543", "metadata": {"id": "699a7543"}}, {"cell_type": "code", "source": [""], "outputs": [], "execution_count": null, "id": "65f560e7", "metadata": {"id": "65f560e7"}}, {"source": ["#### Load preprocessed data\n", "\n", "The train/val pickles are read with the `pickle5` backport rather than `pd.read_pickle` (kept commented out below), presumably because they were saved with pickle protocol 5, which the Python 3.7 runtime used here cannot read natively. Only unpickle files from a trusted source: loading a pickle can execute arbitrary code."], "cell_type": "markdown", "id": "amber-pressure", "metadata": {"id": "amber-pressure"}}, {"cell_type": "code", "source": ["#df_train = pd.read_pickle(TRAIN_DATA_PATH)\n", "#df_val = pd.read_pickle(VAL_DATA_PATH)\n", "!pip3 install pickle5\n", "#df_train = pd.read_pickle(TRAIN_DATA_PATH)\n", "#df_val = pd.read_pickle(VAL_DATA_PATH)\n", "\n", "import pickle5 as pickle \n", "with open(TRAIN_DATA_PATH, \"rb\") as fh:\n", "   
 df_train = pickle.load(fh)\n", "with open(VAL_DATA_PATH, \"rb\") as fh1:\n", "    df_val = pickle.load(fh1)"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["Collecting pickle5\n", "  Downloading pickle5-0.0.12-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (256 kB)\n", "\u001b[?25l\r\u001b[K     |\u2588\u258e                              | 10 kB 23.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u258b                             | 20 kB 30.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2589                            | 30 kB 19.1 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u258f                          | 40 kB 15.6 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u258d                         | 51 kB 7.6 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258a                        | 61 kB 7.7 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588                       | 71 kB 8.1 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258e                     | 81 kB 9.1 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258c                    | 92 kB 9.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2589                   | 102 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588                  | 112 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258d                | 122 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258b               | 133 kB 7.3 MB/s eta 0:00:01\r\u001b[K     
|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588              | 143 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258f            | 153 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258c           | 163 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258a          | 174 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588         | 184 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258e       | 194 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258b      | 204 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2589     | 215 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258f   | 225 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258d  | 235 kB 7.3 MB/s eta 0:00:01\r\u001b[K     
|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258a | 245 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 256 kB 7.3 MB/s eta 0:00:01\r\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 256 kB 7.3 MB/s \n", "\u001b[?25hInstalling collected packages: pickle5\n", "Successfully installed pickle5-0.0.12\n"]}], "execution_count": 10, "id": "german-feature", "metadata": {"outputId": "afba0d8a-9ab6-44ae-9f6c-1c52b9075939", "id": "german-feature", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643617615097, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 4073}}}, {"cell_type": "code", "source": ["df_train.head()"], "outputs": [{"output_type": "execute_result", "data": {"text/plain": ["   trackId  frame         x         y sceneId  metaId\n", "0       31   2217  25.07654   6.78323      07       0\n", "1       31   2242  26.11484   7.72170      07       0\n", "2       31   2267  27.05390   8.94723      07       0\n", "3       31   2292  28.08326  10.18219      07       0\n", "4       31   2317  29.08530  11.39276      07       0"], "text/html": ["\n", "  <div id=\"df-c0a0db93-6fe4-405a-a281-026bf6e79687\">\n", "    <div class=\"colab-df-container\">\n", "      <div>\n", "<style scoped>\n", "    .dataframe tbody tr th:only-of-type {\n", "        vertical-align: middle;\n", "    }\n", "\n", "    .dataframe tbody tr th {\n", "        
vertical-align: top;\n", "    }\n", "\n", "    .dataframe thead th {\n", "        text-align: right;\n", "    }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>trackId</th>\n", "      <th>frame</th>\n", "      <th>x</th>\n", "      <th>y</th>\n", "      <th>sceneId</th>\n", "      <th>metaId</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>0</th>\n", "      <td>31</td>\n", "      <td>2217</td>\n", "      <td>25.07654</td>\n", "      <td>6.78323</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>1</th>\n", "      <td>31</td>\n", "      <td>2242</td>\n", "      <td>26.11484</td>\n", "      <td>7.72170</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>2</th>\n", "      <td>31</td>\n", "      <td>2267</td>\n", "      <td>27.05390</td>\n", "      <td>8.94723</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>3</th>\n", "      <td>31</td>\n", "      <td>2292</td>\n", "      <td>28.08326</td>\n", "      <td>10.18219</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>4</th>\n", "      <td>31</td>\n", "      <td>2317</td>\n", "      <td>29.08530</td>\n", "      <td>11.39276</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>\n", "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-c0a0db93-6fe4-405a-a281-026bf6e79687')\"\n", "              title=\"Convert this dataframe to an interactive table.\"\n", "              style=\"display:none;\">\n", "        \n", "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", "       width=\"24px\">\n", "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 
2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", "  </svg>\n", "      </button>\n", "      \n", "  <style>\n", "    .colab-df-container {\n", "      display:flex;\n", "      flex-wrap:wrap;\n", "      gap: 12px;\n", "    }\n", "\n", "    .colab-df-convert {\n", "      background-color: #E8F0FE;\n", "      border: none;\n", "      border-radius: 50%;\n", "      cursor: pointer;\n", "      display: none;\n", "      fill: #1967D2;\n", "      height: 32px;\n", "      padding: 0 0 0 0;\n", "      width: 32px;\n", "    }\n", "\n", "    .colab-df-convert:hover {\n", "      background-color: #E2EBFA;\n", "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", "      fill: #174EA6;\n", "    }\n", "\n", "    [theme=dark] .colab-df-convert {\n", "      background-color: #3B4455;\n", "      fill: #D2E3FC;\n", "    }\n", "\n", "    [theme=dark] .colab-df-convert:hover {\n", "      background-color: #434B5C;\n", "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", "      fill: #FFFFFF;\n", "    }\n", "  </style>\n", "\n", "      <script>\n", "        const buttonEl =\n", "          document.querySelector('#df-c0a0db93-6fe4-405a-a281-026bf6e79687 button.colab-df-convert');\n", "        buttonEl.style.display =\n", "          google.colab.kernel.accessAllowed ? 
'block' : 'none';\n", "\n", "        async function convertToInteractive(key) {\n", "          const element = document.querySelector('#df-c0a0db93-6fe4-405a-a281-026bf6e79687');\n", "          const dataTable =\n", "            await google.colab.kernel.invokeFunction('convertToInteractive',\n", "                                                     [key], {});\n", "          if (!dataTable) return;\n", "\n", "          const docLinkHtml = 'Like what you see? Visit the ' +\n", "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", "            + ' to learn more about interactive tables.';\n", "          element.innerHTML = '';\n", "          dataTable['output_type'] = 'display_data';\n", "          await google.colab.output.renderOutput(dataTable, element);\n", "          const docLink = document.createElement('div');\n", "          docLink.innerHTML = docLinkHtml;\n", "          element.appendChild(docLink);\n", "        }\n", "      </script>\n", "    </div>\n", "  </div>\n", "  "]}, "execution_count": 11, "metadata": {}}], "execution_count": 11, "id": "corporate-pharmacy", "metadata": {"outputId": "96ada800-356e-466c-bf2a-914d449d2d31", "id": "corporate-pharmacy", "colab": {"base_uri": "https://localhost:8080/", "height": 206}, "executionInfo": {"status": "ok", "timestamp": 1643617615099, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 9}}}, {"source": ["#### Initiate model"], "cell_type": "markdown", "id": "24dc5d7c", "metadata": {"id": "24dc5d7c"}}, {"cell_type": "code", "source": ["model = YNet(obs_len=OBS_LEN, pred_len=PRED_LEN, params=params)"], "outputs": [{"output_type": "stream", "name": "stderr", "text": ["/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 
'segmentation_models_pytorch.encoders.resnet.ResNetEncoder' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", "  warnings.warn(msg, SourceChangeWarning)\n", "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Conv2dReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", "  warnings.warn(msg, SourceChangeWarning)\n", "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Activation' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", "  warnings.warn(msg, SourceChangeWarning)\n"]}], "execution_count": 12, "id": "harmful-colleague", "metadata": {"outputId": "3c7f79ec-fb07-4363-f742-369dc9c2923e", "id": "harmful-colleague", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643617619720, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 4627}}}, {"source": ["#### Start training\n", "Note, the Val ADE and FDE are without TTST and CWS to save time. 
Therefore, the numbers will be worse than the final values."], "cell_type": "markdown", "id": "45e099fe", "metadata": {"id": "45e099fe"}}, {"cell_type": "code", "source": ["import weights_and_biases as wandb\n", "wandb.init_wandb(params.copy(), model.model)"], "outputs": [{"output_type": "display_data", "data": {"application/javascript": ["\n", "        window._wandbApiKey = new Promise((resolve, reject) => {\n", "            function loadScript(url) {\n", "            return new Promise(function(resolve, reject) {\n", "                let newScript = document.createElement(\"script\");\n", "                newScript.onerror = reject;\n", "                newScript.onload = resolve;\n", "                document.body.appendChild(newScript);\n", "                newScript.src = url;\n", "            });\n", "            }\n", "            loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", "            const iframe = document.createElement('iframe')\n", "            iframe.style.cssText = \"width:0;height:0;border:none\"\n", "            document.body.appendChild(iframe)\n", "            const handshake = new Postmate({\n", "                container: iframe,\n", "                url: 'https://wandb.ai/authorize'\n", "            });\n", "            const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", "            handshake.then(function(child) {\n", "                child.on('authorize', data => {\n", "                    clearTimeout(timeout)\n", "                    resolve(data)\n", "                });\n", "            });\n", "            })\n", "        });\n", "    "], "text/plain": ["<IPython.core.display.Javascript object>"]}, "metadata": {}}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"]}, {"output_type": "display_data", "data": {"text/plain": ["<IPython.core.display.HTML 
object>"], "text/html": ["\n", "                    Syncing run <strong><a href=\"https://wandb.ai/agv/ynet/runs/3321ywf0\" target=\"_blank\">morning-dragon-19</a></strong> to <a href=\"https://wandb.ai/agv/ynet\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n", "\n", "                "]}, "metadata": {}}], "execution_count": null, "id": "suiRZodUeLO4", "metadata": {"outputId": "1d0b1664-7504-4305-c8f1-2c860bd5b0ec", "id": "suiRZodUeLO4", "colab": {"base_uri": "https://localhost:8080/", "height": 69}}}, {"cell_type": "code", "source": ["model.train(df_train, df_val, params, train_image_path=TRAIN_IMAGE_PATH, val_image_path=VAL_IMAGE_PATH, \n", "            experiment_name=EXPERIMENT_NAME, batch_size=BATCH_SIZE, num_goals=NUM_GOALS, num_traj=NUM_TRAJ, \n", "            device=None, dataset_name= 'ind')"], "outputs": [{"output_type": "stream", "text": ["Preprocess data\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Prepare Dataset: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 208/208 [00:00<00:00, 1054.16it/s]\n", "/content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet/utils/dataloader.py:38: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", "  return np.array(trajectories), meta, scene_list\n", "Prepare Dataset: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 7/7 [00:00<00:00, 822.69it/s]\n", "Epoch:   0%|          | 0/300 [00:00<?, ?it/s]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Start training\n", "208\n", "Epoch 0: \n", "Val ADE: 24.52324104309082 \n", "Val FDE: 15.545023918151855\n", "Best Epoch 0: \n", "Val ADE: 24.52324104309082 \n", "Val FDE: 15.545023918151855\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   0%|          | 1/300 [05:17<26:19:58, 317.05s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   1%|          | 2/300 [10:31<26:11:20, 316.38s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 1: \n", "Val ADE: 31.442262649536133 \n", "Val FDE: 18.835987091064453\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   1%|          | 3/300 [15:46<26:03:37, 315.88s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 2: \n", "Val ADE: 27.511089324951172 \n", "Val FDE: 13.868932723999023\n", "Epoch 3: \n", "Val ADE: 21.271846771240234 \n", "Val FDE: 11.658187866210938\n", "Best Epoch 3: \n", "Val ADE: 21.271846771240234 \n", "Val FDE: 11.658187866210938\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: 
\u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   1%|\u258f         | 4/300 [21:02<25:57:44, 315.76s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   2%|\u258f         | 5/300 [26:16<25:51:01, 315.46s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 4: \n", "Val ADE: 23.82901382446289 \n", "Val FDE: 13.070942878723145\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   2%|\u258f         | 6/300 [31:31<25:44:40, 315.24s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 5: \n", "Val ADE: 33.927364349365234 \n", "Val FDE: 15.191147804260254\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   2%|\u258f         | 7/300 [36:46<25:38:41, 315.09s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 6: \n", "Val ADE: 25.80826187133789 \n", "Val FDE: 13.541108131408691\n", "Epoch 7: \n", "Val ADE: 21.215837478637695 \n", "Val FDE: 17.555953979492188\n", "Best Epoch 7: \n", "Val ADE: 21.215837478637695 \n", "Val FDE: 17.555953979492188\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   3%|\u258e         | 8/300 [42:01<25:33:46, 315.16s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 8: \n", "Val ADE: 19.34832191467285 \n", "Val FDE: 13.197525024414062\n", "Best Epoch 8: \n", "Val ADE: 19.34832191467285 \n", "Val FDE: 13.197525024414062\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   3%|\u258e         | 9/300 [47:16<25:28:48, 315.22s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 9: \n", "Val ADE: 16.951705932617188 \n", "Val FDE: 17.781843185424805\n", "Best Epoch 9: \n", "Val ADE: 16.951705932617188 \n", "Val FDE: 17.781843185424805\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   3%|\u258e         | 10/300 [52:32<25:23:47, 315.27s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 10: \n", "Val ADE: 9.43283748626709 \n", "Val FDE: 15.147685050964355\n", "Best Epoch 10: \n", "Val ADE: 9.43283748626709 \n", "Val FDE: 15.147685050964355\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   4%|\u258e         | 11/300 [57:47<25:18:36, 315.28s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 11: \n", "Val ADE: 9.431289672851562 \n", "Val FDE: 16.326217651367188\n", "Best Epoch 11: \n", "Val ADE: 9.431289672851562 \n", "Val FDE: 16.326217651367188\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   4%|\u258d         | 12/300 [1:03:03<25:13:29, 315.31s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 12: \n", "Val ADE: 8.449688911437988 \n", "Val FDE: 16.260046005249023\n", "Best Epoch 12: \n", "Val ADE: 8.449688911437988 \n", "Val FDE: 16.260046005249023\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   4%|\u258d         | 13/300 [1:08:18<25:08:25, 315.35s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 13: \n", "Val ADE: 6.904079914093018 \n", "Val FDE: 16.549118041992188\n", "Best Epoch 13: \n", "Val ADE: 6.904079914093018 \n", "Val FDE: 16.549118041992188\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   5%|\u258d         | 14/300 [1:13:33<25:03:09, 315.35s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   5%|\u258c         | 15/300 [1:18:48<24:56:56, 315.15s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 14: \n", "Val ADE: 8.350976943969727 \n", "Val FDE: 17.06983757019043\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   5%|\u258c         | 16/300 [1:24:03<24:51:00, 315.00s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 15: \n", "Val ADE: 7.119980335235596 \n", "Val FDE: 16.23307228088379\n", "Epoch 16: \n", "Val ADE: 6.2168965339660645 \n", "Val FDE: 16.745540618896484\n", "Best Epoch 16: \n", "Val ADE: 6.2168965339660645 \n", "Val FDE: 16.745540618896484\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   6%|\u258c         | 17/300 [1:29:18<24:46:15, 315.11s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 17: \n", "Val ADE: 5.149584770202637 \n", "Val FDE: 13.510305404663086\n", "Best Epoch 17: \n", "Val ADE: 5.149584770202637 \n", "Val FDE: 13.510305404663086\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   6%|\u258c         | 18/300 [1:34:33<24:41:13, 315.16s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Saved model to: saved_models/Ynet.pt\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   6%|\u258b         | 19/300 [1:39:48<24:35:20, 315.02s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 18: \n", "Val ADE: 6.612066268920898 \n", "Val FDE: 17.32286834716797\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   7%|\u258b         | 20/300 [1:45:03<24:29:39, 314.93s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 19: \n", "Val ADE: 6.247178554534912 \n", "Val FDE: 15.675963401794434\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   7%|\u258b         | 21/300 [1:50:17<24:24:02, 314.85s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 20: \n", "Val ADE: 6.097778797149658 \n", "Val FDE: 14.187601089477539\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   7%|\u258b         | 22/300 [1:55:32<24:18:34, 314.80s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 21: \n", "Val ADE: 6.787241458892822 \n", "Val FDE: 13.701334953308105\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   8%|\u258a         | 23/300 [2:00:47<24:13:13, 314.78s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 22: \n", "Val ADE: 7.450379848480225 \n", "Val FDE: 13.574320793151855\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   8%|\u258a         | 24/300 [2:06:01<24:07:53, 314.76s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 23: \n", "Val ADE: 6.7749481201171875 \n", "Val FDE: 12.174635887145996\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   8%|\u258a         | 25/300 [2:11:16<24:02:34, 314.74s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 
24: \n", "Val ADE: 8.040609359741211 \n", "Val FDE: 15.903623580932617\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   9%|\u258a         | 26/300 [2:16:31<23:57:19, 314.74s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 25: \n", "Val ADE: 7.765658855438232 \n", "Val FDE: 14.131928443908691\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   9%|\u2589         | 27/300 [2:21:46<23:51:57, 314.72s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 26: \n", "Val ADE: 7.066863059997559 \n", "Val FDE: 11.959335327148438\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:   9%|\u2589         | 28/300 [2:27:00<23:46:38, 314.70s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 27: \n", "Val ADE: 8.887083053588867 \n", "Val FDE: 15.644576072692871\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  10%|\u2589         | 29/300 [2:32:15<23:41:24, 314.70s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 28: \n", "Val ADE: 8.01439380645752 \n", "Val FDE: 13.944046974182129\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  10%|\u2588         | 30/300 [2:37:30<23:36:06, 314.69s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 29: \n", "Val ADE: 8.040870666503906 \n", "Val FDE: 13.148353576660156\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  10%|\u2588         | 31/300 [2:42:44<23:30:52, 314.69s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 30: \n", "Val ADE: 8.784407615661621 \n", "Val FDE: 16.92945671081543\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  11%|\u2588         | 32/300 [2:47:59<23:25:36, 314.69s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 31: \n", "Val ADE: 8.764410972595215 \n", "Val FDE: 16.009187698364258\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  11%|\u2588   
      | 33/300 [2:53:14<23:20:24, 314.70s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 32: \n", "Val ADE: 10.523109436035156 \n", "Val FDE: 19.550888061523438\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  11%|\u2588\u258f        | 34/300 [2:58:28<23:15:07, 314.69s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 33: \n", "Val ADE: 9.594310760498047 \n", "Val FDE: 17.422550201416016\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  12%|\u2588\u258f        | 35/300 [3:03:43<23:09:53, 314.69s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 34: \n", "Val ADE: 8.17928695678711 \n", "Val FDE: 14.617682456970215\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  12%|\u2588\u258f        | 36/300 [3:08:58<23:04:43, 314.71s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 35: \n", "Val ADE: 9.265252113342285 \n", "Val FDE: 15.799777030944824\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  12%|\u2588\u258f        | 37/300 [3:14:13<22:59:27, 314.70s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 36: \n", "Val ADE: 10.083176612854004 \n", "Val FDE: 18.598773956298828\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  13%|\u2588\u258e        | 38/300 [3:19:27<22:54:08, 314.69s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 37: \n", "Val ADE: 9.642986297607422 \n", "Val FDE: 17.64706802368164\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  13%|\u2588\u258e        | 39/300 [3:24:42<22:48:57, 314.70s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 38: \n", "Val ADE: 9.195450782775879 \n", "Val FDE: 16.452146530151367\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  13%|\u2588\u258e        | 40/300 [3:29:57<22:43:41, 314.70s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 39: 
\n", "Val ADE: 8.474781036376953 \n", "Val FDE: 15.312761306762695\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  14%|\u2588\u258e        | 41/300 [3:35:11<22:38:26, 314.70s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 40: \n", "Val ADE: 9.311485290527344 \n", "Val FDE: 16.22816276550293\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  14%|\u2588\u258d        | 42/300 [3:40:26<22:33:14, 314.71s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 41: \n", "Val ADE: 9.702507972717285 \n", "Val FDE: 17.192020416259766\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  14%|\u2588\u258d        | 43/300 [3:45:41<22:28:00, 314.71s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 42: \n", "Val ADE: 8.051592826843262 \n", "Val FDE: 13.59714126586914\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  15%|\u2588\u258d        | 44/300 [3:50:55<22:22:48, 314.72s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 43: \n", "Val ADE: 8.020970344543457 \n", "Val FDE: 14.879944801330566\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  15%|\u2588\u258c        | 45/300 [3:56:10<22:17:37, 314.73s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 44: \n", "Val ADE: 8.095160484313965 \n", "Val FDE: 14.337087631225586\n"]}], "execution_count": null, "id": "optional-colleague", "metadata": {"outputId": "5ad117fb-39eb-48a0-fe4d-88888d3ecc5b", "id": "optional-colleague", "colab": {"base_uri": "https://localhost:8080/"}}}, {"cell_type": "code", "source": [""], "outputs": [], "execution_count": null, "id": "16fe307e", "metadata": {"id": "16fe307e"}}], "metadata": {"kernelspec": {"display_name": "Python 3", "name": "python3"}, "language_info": {"name": "python"}, "colab": {"provenance": [], "machine_shape": "hm", "name": "train_inD_longterm.ipynb"}, "accelerator": "GPU"}}