{"nbformat_minor": 5, "nbformat": 4, "cells": [{"cell_type": "code", "source": ["from google.colab import drive"], "outputs": [], "execution_count": 2, "id": "q4wgKhonGDJn", "metadata": {"id": "q4wgKhonGDJn", "executionInfo": {"status": "ok", "timestamp": 1643887401188, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 468}}}, {"source": ["#### Mount Google Drive and switch to the YNet directory\n", "The `config/` and `data/` paths used below are relative, so the working directory must be `ynet/` before they are read."], "cell_type": "markdown", "id": "koLSDU3KHn5e", "metadata": {"id": "koLSDU3KHn5e"}}, {"cell_type": "code", "source": ["drive.mount('/content/drive', force_remount=True)"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["Mounted at /content/drive\n"]}], "execution_count": 3, "id": "QGzq7HyiGDDv", "metadata": {"outputId": "79e73b22-757a-4053-f68b-d9987b2fe9cf", "id": "QGzq7HyiGDDv", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643887431163, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 27956}}}, {"cell_type": "code", "source": ["!ls"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["drive  sample_data\n"]}], "execution_count": null, "id": "ePwFUQvrGq1A", "metadata": {"outputId": "06a8f9ce-2acc-4da6-f6af-3a9993bfbd88", "id": "ePwFUQvrGq1A", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643829772255, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 431}}}, {"cell_type": "code", "source": ["%cd /content/drive/MyDrive/Human-Path-Prediction-master/ynet"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["/content/drive/MyDrive/Human-Path-Prediction-master/ynet\n"]}], "execution_count": 31, "id": "YPOC6rPIGC9Y", "metadata": {"outputId": 
"8af76835-b357-475b-adb4-95195219401c", "id": "YPOC6rPIGC9Y", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643896284188, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 2164}}}, {"cell_type": "code", "source": ["!pip install wandb"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["Collecting wandb\n", "  Downloading wandb-0.12.10-py2.py3-none-any.whl (1.7 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.7 MB 5.1 MB/s \n", "\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (7.1.2)\n", "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.8.2)\n", "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from wandb) (3.13)\n", "Collecting GitPython>=1.0.0\n", "  Downloading GitPython-3.1.26-py3-none-any.whl (180 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 180 kB 46.1 MB/s \n", "\u001b[?25hRequirement already satisfied: protobuf>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.17.3)\n", "Collecting sentry-sdk>=1.0.0\n", "  Downloading sentry_sdk-1.5.4-py2.py3-none-any.whl (143 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 143 kB 53.8 MB/s \n", "\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in 
/usr/local/lib/python3.7/dist-packages (from wandb) (2.3)\n", "Collecting yaspin>=1.0.0\n", "  Downloading yaspin-2.1.0-py3-none-any.whl (18 kB)\n", "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (5.4.8)\n", "Requirement already satisfied: six>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.15.0)\n", "Collecting shortuuid>=0.5.0\n", "  Downloading shortuuid-1.0.8-py3-none-any.whl (9.5 kB)\n", "Collecting docker-pycreds>=0.4.0\n", "  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", "Collecting pathtools\n", "  Downloading pathtools-0.1.2.tar.gz (11 kB)\n", "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.23.0)\n", "Collecting gitdb<5,>=4.0.1\n", "  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 63 kB 1.9 MB/s \n", "\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from GitPython>=1.0.0->wandb) (3.10.0.2)\n", "Collecting smmap<6,>=3.0.1\n", "  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2.10)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (3.0.4)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (1.24.3)\n", "Requirement already satisfied: termcolor<2.0.0,>=1.1.0 in 
/usr/local/lib/python3.7/dist-packages (from yaspin>=1.0.0->wandb) (1.1.0)\n", "Building wheels for collected packages: pathtools\n", "  Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "  Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=4c35635acf73f56a09b5094fa63b57526f8f54d86f909414731962855b83d9ce\n", "  Stored in directory: /root/.cache/pip/wheels/3e/31/09/fa59cef12cdcfecc627b3d24273699f390e71828921b2cbba2\n", "Successfully built pathtools\n", "Installing collected packages: smmap, gitdb, yaspin, shortuuid, sentry-sdk, pathtools, GitPython, docker-pycreds, wandb\n", "Successfully installed GitPython-3.1.26 docker-pycreds-0.4.0 gitdb-4.0.9 pathtools-0.1.2 sentry-sdk-1.5.4 shortuuid-1.0.8 smmap-5.0.0 wandb-0.12.10 yaspin-2.1.0\n"]}], "execution_count": 5, "id": "p2IwKmpUGCyL", "metadata": {"outputId": "5519a511-a649-4d77-abf4-2199d4cbefdf", "id": "p2IwKmpUGCyL", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643887448708, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 7604}}}, {"cell_type": "code", "source": ["import pandas as pd\n", "import yaml\n", "import argparse\n", "import torch\n", "from model import YNet"], "outputs": [], "execution_count": 32, "id": "determined-township", "metadata": {"id": "determined-township", "executionInfo": {"status": "ok", "timestamp": 1643896291509, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 2548}}}, {"cell_type": "code", "source": ["%load_ext autoreload\n", "%autoreload 2"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["The autoreload extension is already loaded. 
To reload it, use:\n", "  %reload_ext autoreload\n"]}], "execution_count": 22, "id": "finished-potential", "metadata": {"outputId": "e784f9d4-fd28-43b5-8d1b-55ea916476be", "id": "finished-potential", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643891794574, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 2426}}}, {"source": ["#### Some hyperparameters and settings"], "cell_type": "markdown", "id": "e567fbe0", "metadata": {"id": "e567fbe0"}}, 
{"cell_type": "code", "source": ["CONFIG_FILE_PATH = 'config/inD_longterm.yaml'  # yaml config file containing all the hyperparameters\n", "EXPERIMENT_NAME = 'inD_longterm'  # arbitrary name for this experiment\n", "DATASET_NAME = 'inD'\n", "\n", "TRAIN_DATA_PATH = 'data/inD/train.pkl'\n", "TRAIN_IMAGE_PATH = 'data/inD/train'\n", "VAL_DATA_PATH = 'data/inD/test.pkl'\n", "VAL_IMAGE_PATH = 'data/inD/test'\n", "OBS_LEN = 5  # in timesteps\n", "PRED_LEN = 30  # in timesteps\n", "NUM_GOALS = 20  # K_e\n", "NUM_TRAJ = 1  # K_a\n", "\n", "# Passed explicitly as batch_size= to model.evaluate below. NOTE(review): the yaml config carries its\n", "# own 'batch_size' (8 at the last run) -- confirm which of the two values is intended.\n", "BATCH_SIZE = 4"], "outputs": [], "execution_count": 36, "id": "arabic-thickness", "metadata": {"id": "arabic-thickness", "executionInfo": {"status": "ok", "timestamp": 1643896417509, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 2260}}}, 
{"cell_type": "code", "source": ["# Installs the repo's pinned versions (e.g. torch 1.5.1, pandas 1.1.5, pyyaml 5.3.1) over Colab's preinstalled ones.\n", "# Restart the runtime afterwards so that torch / pandas / yaml are re-imported at the pinned versions.\n", "%pip install -r requirements.txt"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["Collecting tqdm==4.48.0\n", "  Downloading tqdm-4.48.0-py2.py3-none-any.whl (67 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 67 kB 2.9 MB/s \n", "\u001b[?25hCollecting 
pyyaml==5.3.1\n", "  Downloading PyYAML-5.3.1.tar.gz (269 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 269 kB 10.6 MB/s \n", "\u001b[?25hRequirement already satisfied: matplotlib==3.2.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 3)) (3.2.2)\n", "Collecting torch==1.5.1\n", "  Downloading torch-1.5.1-cp37-cp37m-manylinux1_x86_64.whl (753.2 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 753.2 MB 13 kB/s \n", "\u001b[?25hCollecting pandas==1.1.5\n", "  Downloading pandas-1.1.5-cp37-cp37m-manylinux1_x86_64.whl (9.5 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 9.5 MB 49.8 MB/s \n", "\u001b[?25hCollecting opencv-python==4.4.0.42\n", "  Downloading opencv_python-4.4.0.42-cp37-cp37m-manylinux2014_x86_64.whl (49.4 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 49.4 MB 274 kB/s \n", "\u001b[?25hCollecting scipy==1.5.0\n", "  Downloading scipy-1.5.0-cp37-cp37m-manylinux1_x86_64.whl (25.9 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 25.9 MB 1.2 MB/s \n", "\u001b[?25hCollecting segmentation_models_pytorch==0.1.0\n", "  Downloading segmentation_models_pytorch-0.1.0-py3-none-any.whl (42 kB)\n", "\u001b[K     
|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 42 kB 1.2 MB/s \n", "\u001b[?25hRequirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.19.5)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (3.0.7)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (0.11.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.3.2)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (2.8.2)\n", "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from torch==1.5.1->-r requirements.txt (line 4)) (0.16.0)\n", "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas==1.1.5->-r requirements.txt (line 5)) (2018.9)\n", "Collecting pretrainedmodels==0.7.4\n", "  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 58 kB 6.2 MB/s \n", "\u001b[?25hCollecting efficientnet-pytorch>=0.5.1\n", "  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)\n", "Requirement already satisfied: torchvision>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (0.11.1+cu111)\n", "Collecting munch\n", "  Downloading 
munch-2.5.0-py2.py3-none-any.whl (10 kB)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib==3.2.2->-r requirements.txt (line 3)) (1.15.0)\n", "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision>=0.3.0->segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (7.1.2)\n", "Collecting torchvision>=0.3.0\n", "  Downloading torchvision-0.11.3-cp37-cp37m-manylinux1_x86_64.whl (23.2 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 23.2 MB 142 kB/s \n", "\u001b[?25h  Downloading torchvision-0.11.2-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 23.3 MB 1.2 MB/s \n", "\u001b[?25h  Downloading torchvision-0.11.1-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 23.3 MB 441 kB/s \n", "\u001b[?25h  Downloading torchvision-0.10.1-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 22.1 MB 317 kB/s \n", "\u001b[?25h  Downloading torchvision-0.10.0-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)\n", "\u001b[K     
|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 22.1 MB 369 kB/s \n", "\u001b[?25h  Downloading torchvision-0.9.1-cp37-cp37m-manylinux1_x86_64.whl (17.4 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 17.4 MB 78 kB/s \n", "\u001b[?25h  Downloading torchvision-0.9.0-cp37-cp37m-manylinux1_x86_64.whl (17.3 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 17.3 MB 27.5 MB/s \n", "\u001b[?25h  Downloading torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl (12.8 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 12.8 MB 24.2 MB/s \n", "\u001b[?25h  Downloading torchvision-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (12.7 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 12.7 MB 33.6 MB/s \n", "\u001b[?25h  Downloading torchvision-0.8.0-cp37-cp37m-manylinux1_x86_64.whl (11.8 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 11.8 MB 54.5 MB/s \n", "\u001b[?25h  Downloading torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl (5.9 MB)\n", "\u001b[K     
|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 5.9 MB 30.5 MB/s \n", "\u001b[?25h  Downloading torchvision-0.6.1-cp37-cp37m-manylinux1_x86_64.whl (6.6 MB)\n", "\u001b[K     |\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 6.6 MB 53.3 MB/s \n", "\u001b[?25hBuilding wheels for collected packages: pyyaml, pretrainedmodels, efficientnet-pytorch\n", "  Building wheel for pyyaml (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp37-cp37m-linux_x86_64.whl size=44636 sha256=3383aff5877165f93fb212abb28ccd905186324b0fcf71cc8d665502099dedda\n", "  Stored in directory: /root/.cache/pip/wheels/5e/03/1e/e1e954795d6f35dfc7b637fe2277bff021303bd9570ecea653\n", "  Building wheel for pretrainedmodels (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "  Created wheel for pretrainedmodels: filename=pretrainedmodels-0.7.4-py3-none-any.whl size=60965 sha256=78dd3e11bcf475bdf65edc107ac7e39b05631bfcde92474022758293c5deea52\n", "  Stored in directory: /root/.cache/pip/wheels/ed/27/e8/9543d42de2740d3544db96aefef63bda3f2c1761b3334f4873\n", "  Building wheel for efficientnet-pytorch (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", "  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16446 sha256=b2c7c32d2c7f7b1fbe50dd6a982141cadb376f2418a8e9e2a3ed3952e86287e1\n", "  Stored in directory: /root/.cache/pip/wheels/0e/cc/b2/49e74588263573ff778da58cc99b9c6349b496636a7e165be6\n", "Successfully built pyyaml pretrainedmodels efficientnet-pytorch\n", "Installing collected packages: torch, tqdm, torchvision, munch, pretrainedmodels, efficientnet-pytorch, segmentation-models-pytorch, scipy, pyyaml, pandas, opencv-python\n", "  Attempting uninstall: torch\n", "    Found existing installation: torch 1.10.0+cu111\n", "    Uninstalling torch-1.10.0+cu111:\n", "      Successfully uninstalled torch-1.10.0+cu111\n", "  Attempting uninstall: tqdm\n", "    Found existing installation: tqdm 4.62.3\n", "    Uninstalling tqdm-4.62.3:\n", "      Successfully uninstalled tqdm-4.62.3\n", "  Attempting uninstall: torchvision\n", "    Found existing installation: torchvision 0.11.1+cu111\n", "    Uninstalling torchvision-0.11.1+cu111:\n", "      Successfully uninstalled torchvision-0.11.1+cu111\n", "  Attempting uninstall: scipy\n", "    Found existing installation: scipy 1.4.1\n", "    Uninstalling scipy-1.4.1:\n", "      Successfully uninstalled scipy-1.4.1\n", "  Attempting uninstall: pyyaml\n", "    Found existing installation: PyYAML 3.13\n", "    Uninstalling PyYAML-3.13:\n", "      Successfully uninstalled PyYAML-3.13\n", "  Attempting uninstall: pandas\n", "    Found existing installation: pandas 1.3.5\n", "    Uninstalling pandas-1.3.5:\n", "      Successfully uninstalled pandas-1.3.5\n", "  Attempting uninstall: opencv-python\n", "    Found existing installation: opencv-python 4.1.2.30\n", "    Uninstalling opencv-python-4.1.2.30:\n", "      Successfully uninstalled opencv-python-4.1.2.30\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.5.1 which is incompatible.\n", "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.5.1 which is incompatible.\n", "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n", "Successfully installed efficientnet-pytorch-0.7.1 munch-2.5.0 opencv-python-4.4.0.42 pandas-1.1.5 pretrainedmodels-0.7.4 pyyaml-5.3.1 scipy-1.5.0 segmentation-models-pytorch-0.1.0 torch-1.5.1 torchvision-0.6.1 tqdm-4.48.0\n"]}, {"output_type": "display_data", "data": {"application/vnd.colab-display-data+json": {"pip_warning": {"packages": ["cv2", "pandas", "torch", "tqdm", "yaml"]}}}, "metadata": {}}], "execution_count": 9, "id": "6BhBP_J5FhgB", "metadata": {"outputId": "f446badc-380b-496b-c001-d5e6d5e8b763", "id": "6BhBP_J5FhgB", "colab": {"base_uri": "https://localhost:8080/", "height": 1000}, "executionInfo": {"status": "ok", "timestamp": 1643887651066, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 144818}}}, {"source": ["#### Load config file and print hyperparameters"], "cell_type": "markdown", "id": "2f729e8f", "metadata": {"id": "2f729e8f"}}, 
{"cell_type": "code", "source": ["import os\n", "\n", "# safe_load only builds plain YAML types (dict/list/str/number/bool), which is all this hyperparameter file contains.\n", "with open(CONFIG_FILE_PATH) as file:\n", "    params = yaml.safe_load(file)\n", "# 'config/inD_longterm.yaml' -> 'inD_longterm'; unlike split('config/')[1], this works for config files in any directory.\n", "experiment_name = os.path.splitext(os.path.basename(CONFIG_FILE_PATH))[0]\n", "params"], "outputs": [{"output_type": "execute_result", "data": {"text/plain": ["{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n", " 'batch_size': 8,\n", " 'decoder_channels': [64, 64, 64, 32, 32],\n", " 'encoder_channels': [32, 32, 64, 64, 64],\n", " 'kernlen': 31,\n", " 'learning_rate': 0.0001,\n", " 'loss_scale': 1000,\n", " 'nsig': 4,\n", " 'num_epochs': 300,\n", " 'rel_threshold': 0.002,\n", " 'resize': 0.33,\n", " 'segmentation_model_fp': 'segmentation_models/inD_segmentation.pth',\n", " 'semantic_classes': 6,\n", " 'temperature': 1.8,\n", " 'unfreeze': 100,\n", " 'use_CWS': True,\n", " 'use_TTST': False,\n", " 'use_features_only': False,\n", " 'viz_epoch': 10,\n", " 'waypoints': [14, 29]}"]}, "execution_count": 37, "metadata": {}}], "execution_count": 37, "id": "dangerous-cutting", "metadata": {"outputId": "dad55781-773d-4bf6-a6ee-104b8e0c9b24", "id": "dangerous-cutting", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643896425546, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 2487}}}, 
{"source": ["#### Wandb INIT\n", "W&B is initialised further below, after the model is created, because `init_wandb` receives `model.model`."], "cell_type": "markdown", "id": "699a7543", "metadata": {"id": "699a7543"}}, {"source": ["#### Load preprocessed Data"], "cell_type": "markdown", "id": "amber-pressure", "metadata": {"id": "amber-pressure"}}, 
{"cell_type": "code", "source": ["%pip install -q pickle5\n", "# NOTE(review): the .pkl files presumably use pickle protocol 5, which the stdlib pickle of this Python 3.7\n", "# runtime (and therefore pd.read_pickle) cannot read; the pickle5 backport loads them instead.\n", "# Only load pickles you produced yourself: unpickling can execute arbitrary code.\n", "import pickle5 as pickle\n", "\n", "\n", "def load_pickle(path):\n", "    \"\"\"Unpickle and return the object stored at `path`.\"\"\"\n", "    with open(path, \"rb\") as fh:\n", "        return pickle.load(fh)\n", "\n", "\n", "df_train = load_pickle(TRAIN_DATA_PATH)\n", "df_val = load_pickle(VAL_DATA_PATH)"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["Requirement already satisfied: pickle5 in /usr/local/lib/python3.7/dist-packages (0.0.12)\n"]}], "execution_count": 38, "id": "german-feature", "metadata": {"outputId": "4f0a44f5-5f83-4639-e230-82ee376befff", "id": "german-feature", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 
1643896434089, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 5989}}}, {"cell_type": "code", "source": ["df_train.head()"], "outputs": [{"output_type": "execute_result", "data": {"text/plain": ["   trackId  frame         x         y sceneId  metaId\n", "0       31   2217  25.07654   6.78323      07       0\n", "1       31   2242  26.11484   7.72170      07       0\n", "2       31   2267  27.05390   8.94723      07       0\n", "3       31   2292  28.08326  10.18219      07       0\n", "4       31   2317  29.08530  11.39276      07       0"], "text/html": ["\n", "  <div id=\"df-aac27958-db4d-45e7-8064-53ae75142648\">\n", "    <div class=\"colab-df-container\">\n", "      <div>\n", "<style scoped>\n", "    .dataframe tbody tr th:only-of-type {\n", "        vertical-align: middle;\n", "    }\n", "\n", "    .dataframe tbody tr th {\n", "        vertical-align: top;\n", "    }\n", "\n", "    .dataframe thead th {\n", "        text-align: right;\n", "    }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>trackId</th>\n", "      <th>frame</th>\n", "      <th>x</th>\n", "      <th>y</th>\n", "      <th>sceneId</th>\n", "      <th>metaId</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>0</th>\n", "      <td>31</td>\n", "      <td>2217</td>\n", "      <td>25.07654</td>\n", "      <td>6.78323</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>1</th>\n", "      <td>31</td>\n", "      <td>2242</td>\n", "      <td>26.11484</td>\n", "      <td>7.72170</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>2</th>\n", "      <td>31</td>\n", "      <td>2267</td>\n", "      <td>27.05390</td>\n", "      <td>8.94723</td>\n", "      <td>07</td>\n", "      
<td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>3</th>\n", "      <td>31</td>\n", "      <td>2292</td>\n", "      <td>28.08326</td>\n", "      <td>10.18219</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>4</th>\n", "      <td>31</td>\n", "      <td>2317</td>\n", "      <td>29.08530</td>\n", "      <td>11.39276</td>\n", "      <td>07</td>\n", "      <td>0</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>\n", "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-aac27958-db4d-45e7-8064-53ae75142648')\"\n", "              title=\"Convert this dataframe to an interactive table.\"\n", "              style=\"display:none;\">\n", "        \n", "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", "       width=\"24px\">\n", "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", "  </svg>\n", "      </button>\n", "      \n", "  <style>\n", "    .colab-df-container {\n", "      display:flex;\n", "      flex-wrap:wrap;\n", "      gap: 12px;\n", "    }\n", "\n", "    .colab-df-convert {\n", "      background-color: #E8F0FE;\n", "      border: none;\n", "      border-radius: 50%;\n", "      cursor: pointer;\n", "      display: none;\n", "      fill: #1967D2;\n", "      height: 32px;\n", "      padding: 0 0 0 0;\n", "      width: 32px;\n", "    }\n", "\n", "    .colab-df-convert:hover {\n", "      background-color: #E2EBFA;\n", "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 
0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", "      fill: #174EA6;\n", "    }\n", "\n", "    [theme=dark] .colab-df-convert {\n", "      background-color: #3B4455;\n", "      fill: #D2E3FC;\n", "    }\n", "\n", "    [theme=dark] .colab-df-convert:hover {\n", "      background-color: #434B5C;\n", "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", "      fill: #FFFFFF;\n", "    }\n", "  </style>\n", "\n", "      <script>\n", "        const buttonEl =\n", "          document.querySelector('#df-aac27958-db4d-45e7-8064-53ae75142648 button.colab-df-convert');\n", "        buttonEl.style.display =\n", "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n", "\n", "        async function convertToInteractive(key) {\n", "          const element = document.querySelector('#df-aac27958-db4d-45e7-8064-53ae75142648');\n", "          const dataTable =\n", "            await google.colab.kernel.invokeFunction('convertToInteractive',\n", "                                                     [key], {});\n", "          if (!dataTable) return;\n", "\n", "          const docLinkHtml = 'Like what you see? 
Visit the ' +\n", "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", "            + ' to learn more about interactive tables.';\n", "          element.innerHTML = '';\n", "          dataTable['output_type'] = 'display_data';\n", "          await google.colab.output.renderOutput(dataTable, element);\n", "          const docLink = document.createElement('div');\n", "          docLink.innerHTML = docLinkHtml;\n", "          element.appendChild(docLink);\n", "        }\n", "      </script>\n", "    </div>\n", "  </div>\n", "  "]}, "execution_count": 39, "metadata": {}}], "execution_count": 39, "id": "corporate-pharmacy", "metadata": {"outputId": "7b5b791a-e4be-4c2e-e096-8637dcf4881b", "id": "corporate-pharmacy", "colab": {"base_uri": "https://localhost:8080/", "height": 206}, "executionInfo": {"status": "ok", "timestamp": 1643896439680, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 2142}}}, {"source": ["#### Initiate model"], "cell_type": "markdown", "id": "24dc5d7c", "metadata": {"id": "24dc5d7c"}}, {"cell_type": "code", "source": ["model = YNet(obs_len=OBS_LEN, pred_len=PRED_LEN, params=params)"], "outputs": [{"output_type": "stream", "name": "stderr", "text": ["/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.encoders.resnet.ResNetEncoder' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", "  warnings.warn(msg, SourceChangeWarning)\n", "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Conv2dReLU' has changed. 
you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", "  warnings.warn(msg, SourceChangeWarning)\n", "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Activation' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", "  warnings.warn(msg, SourceChangeWarning)\n"]}], "execution_count": 40, "id": "harmful-colleague", "metadata": {"outputId": "29590f8e-9506-464d-f586-5b9cc2b9adfe", "id": "harmful-colleague", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643896448007, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 4727}}}, {"source": ["#### Start training\n", "Note, the Val ADE and FDE are without TTST and CWS to save time. 
Therefore, the numbers will be worse than the final values."], "cell_type": "markdown", "id": "45e099fe", "metadata": {"id": "45e099fe"}}, {"cell_type": "code", "source": ["import weights_and_biases as wandb\n", "wandb.init_wandb(params.copy(), model.model)"], "outputs": [{"output_type": "display_data", "data": {"text/plain": ["<IPython.core.display.HTML object>"], "text/html": ["Finishing last run (ID:2jz9cve3) before initializing another..."]}, "metadata": {}}], "execution_count": null, "id": "suiRZodUeLO4", "metadata": {"outputId": "73d84102-a70a-476a-dc6b-a7200f6bd687", "id": "suiRZodUeLO4", "colab": {"base_uri": "https://localhost:8080/"}}}, {"cell_type": "code", "source": ["model.load('/content/drive/MyDrive/Human-Path-Prediction-master/ynet/pretrained_models/fg/sdd_longterm_weights.pt')"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["<All keys matched successfully>\n"]}], "execution_count": 29, "id": "uA_h9s9eZwE4", "metadata": {"outputId": "91cca0c7-c091-48f3-a803-6b0940217f27", "id": "uA_h9s9eZwE4", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643891850533, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, "user_tz": -330, "elapsed": 2312}}}, {"cell_type": "code", "source": ["model.evaluate(df_val, params, image_path='data/SDD/test',\n", "               batch_size=BATCH_SIZE, rounds=3, \n", "               num_goals=NUM_GOALS, num_traj=NUM_TRAJ, device=None, dataset_name=DATASET_NAME)"], "outputs": [{"output_type": "stream", "name": "stdout", "text": ["Preprocess data\n"]}, {"output_type": "stream", "name": "stderr", "text": ["Prepare Dataset: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 11/11 [00:00<00:00, 798.10it/s]\n", "/content/drive/MyDrive/Human-Path-Prediction-master/ynet/utils/dataloader.py:38: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences 
(which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", "  return np.array(trajectories), meta, scene_list\n", "Round:   0%|          | 0/3 [00:00<?, ?it/s]"]}, {"output_type": "stream", "name": "stdout", "text": ["(288, 512, 3)\n", "(288, 512, 3)\n", "(512, 352, 3)\n", "(512, 384, 3)\n", "(512, 384, 3)\n", "(192, 384, 3)\n", "(512, 384, 3)\n", "(512, 352, 3)\n", "(512, 352, 3)\n", "(512, 384, 3)\n", "(512, 352, 3)\n", "Start testing\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rRound:  33%|\u2588\u2588\u2588\u258e      | 1/3 [05:59<11:59, 359.73s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Round 0: \n", "Test ADE: 46.681884765625 \n", "Test FDE: 62.60225296020508\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rRound:  67%|\u2588\u2588\u2588\u2588\u2588\u2588\u258b   | 2/3 [11:55<05:58, 358.45s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Round 1: \n", "Test ADE: 46.647457122802734 \n", "Test FDE: 64.13143920898438\n"]}, {"output_type": "stream", "name": "stderr", "text": ["Round: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 3/3 [17:56<00:00, 358.97s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Round 2: \n", "Test ADE: 46.869285583496094 \n", "Test FDE: 61.698974609375\n", "\n", "\n", "Average performance over 3 rounds: \n", "Test ADE: 46.73287582397461 \n", "Test FDE: 62.81088892618815\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\n"]}], "execution_count": 30, "id": "wng62LWfZ0Ap", "metadata": {"outputId": "e30e8811-f9b1-434c-c4bd-3066aade868e", "id": "wng62LWfZ0Ap", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1643892932197, "user": {"userId": "05569193017561439961", "displayName": "ai np", "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64"}, 
"user_tz": -330, "elapsed": 1079825}}}, {"cell_type": "code", "source": ["model.train(df_train, df_val, params, train_image_path=TRAIN_IMAGE_PATH, val_image_path=VAL_IMAGE_PATH, \n", "            experiment_name=EXPERIMENT_NAME, batch_size=BATCH_SIZE, num_goals=NUM_GOALS, num_traj=NUM_TRAJ, \n", "            device=None, dataset_name= 'sdd')"], "outputs": [{"output_type": "stream", "text": ["Preprocess data\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Prepare Dataset: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 184/184 [00:00<00:00, 770.99it/s]\n", "/content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet/utils/dataloader.py:38: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n", "  return np.array(trajectories), meta, scene_list\n", "Prepare Dataset: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 11/11 [00:00<00:00, 682.28it/s]\n", "Epoch:   0%|          | 0/300 [00:00<?, ?it/s]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Start training\n", "184\n", "Epoch 0: \n", "Val ADE: 301.7619323730469 \n", "Val FDE: 255.6228485107422\n", "Best Epoch 0: \n", "Val ADE: 301.7619323730469 \n", "Val FDE: 255.6228485107422\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   0%|          | 1/300 [06:43<33:28:45, 403.09s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 1: \n", "Val ADE: 201.84591674804688 \n", "Val FDE: 151.03102111816406\n", "Best Epoch 1: \n", "Val ADE: 201.84591674804688 \n", "Val FDE: 151.03102111816406\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   1%|          | 2/300 [13:26<33:22:16, 403.14s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 2: \n", "Val ADE: 193.38035583496094 \n", "Val FDE: 148.112060546875\n", "Best Epoch 2: \n", "Val ADE: 193.38035583496094 \n", "Val FDE: 148.112060546875\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   1%|          | 3/300 [20:09<33:15:30, 403.13s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 3: \n", "Val ADE: 181.08123779296875 \n", "Val FDE: 138.97171020507812\n", "Best Epoch 3: \n", "Val ADE: 181.08123779296875 \n", "Val FDE: 138.97171020507812\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   1%|\u258f         | 4/300 [26:52<33:08:43, 403.12s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   2%|\u258f         | 5/300 [33:34<33:00:42, 402.86s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 4: \n", "Val ADE: 182.64474487304688 \n", "Val FDE: 137.41041564941406\n", "Epoch 5: \n", "Val ADE: 173.60455322265625 \n", "Val FDE: 131.7052001953125\n", "Best Epoch 5: \n", "Val ADE: 173.60455322265625 \n", "Val FDE: 131.7052001953125\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   2%|\u258f         | 6/300 [40:17<32:54:04, 402.87s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 6: \n", "Val ADE: 122.37327575683594 \n", "Val FDE: 137.73631286621094\n", "Best Epoch 6: \n", "Val ADE: 122.37327575683594 \n", "Val FDE: 137.73631286621094\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   2%|\u258f         | 7/300 [47:00<32:47:14, 402.85s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 7: \n", "Val ADE: 89.133056640625 \n", "Val FDE: 92.61986541748047\n", "Best Epoch 7: \n", "Val ADE: 89.133056640625 \n", "Val FDE: 92.61986541748047\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   3%|\u258e         | 8/300 [53:43<32:40:30, 402.85s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 8: \n", "Val ADE: 86.85653686523438 \n", "Val FDE: 85.07341766357422\n", "Best Epoch 8: \n", "Val ADE: 86.85653686523438 \n", "Val FDE: 85.07341766357422\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   3%|\u258e         | 9/300 [1:00:26<32:33:57, 402.88s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   3%|\u258e         | 10/300 [1:07:08<32:26:10, 402.66s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 9: \n", "Val ADE: 88.74711608886719 \n", "Val FDE: 90.28411102294922\n", "Epoch 10: \n", "Val ADE: 85.20935821533203 \n", "Val FDE: 90.6670150756836\n", "Best Epoch 10: \n", "Val ADE: 85.20935821533203 \n", "Val FDE: 90.6670150756836\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   4%|\u258e         | 11/300 [1:13:51<32:19:37, 402.69s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   4%|\u258d         | 12/300 [1:20:33<32:12:03, 402.51s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 11: \n", "Val ADE: 86.12983703613281 \n", "Val FDE: 85.69186401367188\n", "Epoch 12: \n", "Val ADE: 82.30138397216797 \n", "Val FDE: 93.57181549072266\n", "Best Epoch 12: \n", "Val ADE: 82.30138397216797 \n", "Val FDE: 93.57181549072266\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   4%|\u258d         | 13/300 [1:27:16<32:05:50, 402.61s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 13: \n", "Val ADE: 80.12541961669922 \n", "Val FDE: 96.78948974609375\n", "Best Epoch 13: \n", "Val ADE: 80.12541961669922 \n", "Val FDE: 96.78948974609375\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   5%|\u258d         | 14/300 [1:33:58<31:59:25, 402.68s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 14: \n", "Val ADE: 76.92031860351562 \n", "Val FDE: 90.86177062988281\n", "Best Epoch 14: \n", "Val ADE: 76.92031860351562 \n", "Val FDE: 90.86177062988281\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   5%|\u258c         | 15/300 [1:40:41<31:53:07, 402.76s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   5%|\u258c         | 16/300 [1:47:24<31:45:28, 402.57s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 15: \n", "Val ADE: 79.80490112304688 \n", "Val FDE: 94.581787109375\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   6%|\u258c         | 17/300 [1:54:06<31:38:09, 402.44s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 16: \n", "Val ADE: 82.22650909423828 \n", "Val FDE: 98.23641967773438\n", "Epoch 17: \n", "Val ADE: 74.87488555908203 \n", "Val FDE: 92.14183044433594\n", "Best Epoch 17: \n", "Val ADE: 74.87488555908203 \n", "Val FDE: 92.14183044433594\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   6%|\u258c         | 18/300 [2:00:48<31:31:46, 402.51s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   6%|\u258b         | 19/300 [2:07:30<31:24:18, 402.34s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 18: \n", "Val ADE: 78.0372085571289 \n", "Val FDE: 91.91835021972656\n", "Epoch 19: \n", "Val ADE: 72.20713806152344 \n", "Val FDE: 85.36895751953125\n", "Best Epoch 19: \n", "Val ADE: 72.20713806152344 \n", "Val FDE: 85.36895751953125\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   7%|\u258b         | 20/300 [2:14:13<31:18:05, 402.45s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 20: \n", "Val ADE: 71.57848358154297 \n", "Val FDE: 90.71351623535156\n", "Best Epoch 20: \n", "Val ADE: 71.57848358154297 \n", "Val FDE: 90.71351623535156\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   7%|\u258b         | 21/300 [2:20:56<31:11:43, 402.52s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n", "Epoch 21: \n", "Val ADE: 68.15691375732422 \n", "Val FDE: 88.47090148925781\n", "Best Epoch 21: \n", "Val ADE: 68.15691375732422 \n", "Val FDE: 88.47090148925781\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:   7%|\u258b         | 22/300 [2:27:38<31:05:13, 402.57s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   8%|\u258a         | 23/300 [2:34:20<30:57:46, 402.40s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 22: \n", "Val ADE: 75.18169403076172 \n", "Val FDE: 89.41007232666016\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   8%|\u258a         | 24/300 [2:41:02<30:50:31, 402.29s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 23: \n", "Val ADE: 76.58006286621094 \n", "Val FDE: 94.20979309082031\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   8%|\u258a         | 25/300 [2:47:45<30:43:37, 402.24s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 24: \n", "Val ADE: 72.92988586425781 \n", "Val FDE: 87.30663299560547\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   9%|\u258a         | 
26/300 [2:54:27<30:36:55, 402.25s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 25: \n", "Val ADE: 78.54552459716797 \n", "Val FDE: 91.18190002441406\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   9%|\u2589         | 27/300 [3:01:09<30:30:28, 402.30s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 26: \n", "Val ADE: 75.70587921142578 \n", "Val FDE: 91.01994323730469\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:   9%|\u2589         | 28/300 [3:07:52<30:23:52, 402.32s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 27: \n", "Val ADE: 76.4176254272461 \n", "Val FDE: 88.73723602294922\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  10%|\u2589         | 29/300 [3:14:34<30:17:21, 402.37s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 28: \n", "Val ADE: 71.9253921508789 \n", "Val FDE: 87.2762222290039\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  10%|\u2588         | 30/300 [3:21:16<30:10:35, 402.35s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 29: \n", "Val ADE: 78.49124908447266 \n", "Val FDE: 92.36381530761719\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  10%|\u2588         | 31/300 [3:27:58<30:03:20, 402.23s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 30: \n", "Val ADE: 71.29315185546875 \n", "Val FDE: 89.93594360351562\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  11%|\u2588         | 32/300 [3:34:40<29:56:13, 402.14s/it]"], "name": "stderr", "metadata": {"tags": null}}, 
{"output_type": "stream", "text": ["Epoch 31: \n", "Val ADE: 72.03828430175781 \n", "Val FDE: 95.21200561523438\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  11%|\u2588         | 33/300 [3:41:22<29:49:23, 402.11s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 32: \n", "Val ADE: 75.97977447509766 \n", "Val FDE: 85.79773712158203\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  11%|\u2588\u258f        | 34/300 [3:48:04<29:42:33, 402.08s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 33: \n", "Val ADE: 81.12572479248047 \n", "Val FDE: 99.3475570678711\n", "Epoch 34: \n", "Val ADE: 66.18566131591797 \n", "Val FDE: 83.22686004638672\n", "Best Epoch 34: \n", "Val ADE: 66.18566131591797 \n", "Val FDE: 83.22686004638672\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Saving files without folders. If you want to preserve sub directories pass base_path to wandb.save, i.e. 
wandb.save(\"/mnt/folder/file.h5\", base_path=\"/mnt\")\n", "Epoch:  12%|\u2588\u258f        | 35/300 [3:54:47<29:36:39, 402.26s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Saved model to: saved_models/Ynet.pt\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  12%|\u2588\u258f        | 36/300 [4:01:29<29:29:37, 402.19s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 35: \n", "Val ADE: 68.09703063964844 \n", "Val FDE: 81.17948913574219\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  12%|\u2588\u258f        | 37/300 [4:08:11<29:22:33, 402.11s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 36: \n", "Val ADE: 70.6156234741211 \n", "Val FDE: 85.33181762695312\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  13%|\u2588\u258e        | 38/300 [4:14:53<29:15:35, 402.04s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 37: \n", "Val ADE: 71.63475799560547 \n", "Val FDE: 92.97126007080078\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  13%|\u2588\u258e        | 39/300 [4:21:35<29:08:56, 402.06s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 38: \n", "Val ADE: 74.99298858642578 \n", "Val FDE: 93.91376495361328\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  13%|\u2588\u258e        | 40/300 [4:28:17<29:02:14, 402.05s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 39: \n", "Val ADE: 71.88893127441406 \n", "Val FDE: 82.456787109375\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  14%|\u2588\u258e        | 
41/300 [4:34:59<28:55:24, 402.02s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 40: \n", "Val ADE: 67.46001434326172 \n", "Val FDE: 81.40061950683594\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  14%|\u2588\u258d        | 42/300 [4:41:41<28:48:45, 402.04s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 41: \n", "Val ADE: 69.84492492675781 \n", "Val FDE: 86.85179901123047\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  14%|\u2588\u258d        | 43/300 [4:48:23<28:41:47, 401.98s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 42: \n", "Val ADE: 72.41036987304688 \n", "Val FDE: 87.54203796386719\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  15%|\u2588\u258d        | 44/300 [4:55:05<28:35:00, 401.96s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 43: \n", "Val ADE: 70.14878845214844 \n", "Val FDE: 90.79217529296875\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  15%|\u2588\u258c        | 45/300 [5:01:47<28:28:16, 401.95s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 44: \n", "Val ADE: 74.57258605957031 \n", "Val FDE: 90.74398803710938\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  15%|\u2588\u258c        | 46/300 [5:08:29<28:21:45, 401.99s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 45: \n", "Val ADE: 71.58636474609375 \n", "Val FDE: 86.26039123535156\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  16%|\u2588\u258c        | 47/300 [5:15:11<28:15:29, 402.09s/it]"], "name": 
"stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 46: \n", "Val ADE: 72.58191680908203 \n", "Val FDE: 89.1121597290039\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  16%|\u2588\u258c        | 48/300 [5:21:53<28:08:51, 402.11s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 47: \n", "Val ADE: 72.31369018554688 \n", "Val FDE: 90.41600799560547\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  16%|\u2588\u258b        | 49/300 [5:28:36<28:02:28, 402.18s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 48: \n", "Val ADE: 69.5880126953125 \n", "Val FDE: 82.61445617675781\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  17%|\u2588\u258b        | 50/300 [5:35:18<27:55:43, 402.17s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 49: \n", "Val ADE: 71.56086730957031 \n", "Val FDE: 84.25231170654297\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  17%|\u2588\u258b        | 51/300 [5:42:00<27:49:08, 402.20s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 50: \n", "Val ADE: 74.74462890625 \n", "Val FDE: 93.85311889648438\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  17%|\u2588\u258b        | 52/300 [5:48:42<27:42:35, 402.24s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 51: \n", "Val ADE: 71.78186798095703 \n", "Val FDE: 83.74314880371094\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  18%|\u2588\u258a        | 53/300 [5:55:25<27:35:52, 402.24s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": 
"stream", "text": ["Epoch 52: \n", "Val ADE: 74.55719757080078 \n", "Val FDE: 90.2347412109375\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  18%|\u2588\u258a        | 54/300 [6:02:06<27:28:46, 402.14s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 53: \n", "Val ADE: 70.1415786743164 \n", "Val FDE: 89.1491928100586\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  18%|\u2588\u258a        | 55/300 [6:08:48<27:21:53, 402.09s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 54: \n", "Val ADE: 69.09719848632812 \n", "Val FDE: 85.7782974243164\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  19%|\u2588\u258a        | 56/300 [6:15:31<27:15:07, 402.08s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 55: \n", "Val ADE: 69.2908935546875 \n", "Val FDE: 88.9847640991211\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  19%|\u2588\u2589        | 57/300 [6:22:12<27:08:16, 402.04s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 56: \n", "Val ADE: 71.8873519897461 \n", "Val FDE: 89.47559356689453\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  19%|\u2588\u2589        | 58/300 [6:28:55<27:01:38, 402.06s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 57: \n", "Val ADE: 72.58536529541016 \n", "Val FDE: 87.47309875488281\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  20%|\u2588\u2589        | 59/300 [6:35:36<26:54:41, 402.00s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 58: \n", "Val ADE: 70.79429626464844 
\n", "Val FDE: 90.87911224365234\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  20%|\u2588\u2588        | 60/300 [6:42:18<26:47:50, 401.96s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 59: \n", "Val ADE: 73.35765075683594 \n", "Val FDE: 92.65557861328125\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  20%|\u2588\u2588        | 61/300 [6:49:00<26:41:09, 401.96s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 60: \n", "Val ADE: 73.51614379882812 \n", "Val FDE: 93.17198944091797\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  21%|\u2588\u2588        | 62/300 [6:55:42<26:34:28, 401.97s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 61: \n", "Val ADE: 72.31851196289062 \n", "Val FDE: 91.13289642333984\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  21%|\u2588\u2588        | 63/300 [7:02:24<26:27:48, 401.98s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 62: \n", "Val ADE: 72.23517608642578 \n", "Val FDE: 98.02088165283203\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  21%|\u2588\u2588\u258f       | 64/300 [7:09:06<26:21:09, 401.99s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 63: \n", "Val ADE: 74.33807373046875 \n", "Val FDE: 95.19650268554688\n"], "name": "stdout", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["\rEpoch:  22%|\u2588\u2588\u258f       | 65/300 [7:15:48<26:14:43, 402.06s/it]"], "name": "stderr", "metadata": {"tags": null}}, {"output_type": "stream", "text": ["Epoch 64: \n", "Val ADE: 72.4472427368164 \n", "Val FDE: 95.61434936523438\n"], "name": 
"stdout", "metadata": {"tags": null}}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  22%|\u2588\u2588\u258f       | 66/300 [7:22:31<26:08:12, 402.10s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 65: \n", "Val ADE: 74.37706756591797 \n", "Val FDE: 99.89218139648438\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  22%|\u2588\u2588\u258f       | 67/300 [7:29:13<26:01:26, 402.09s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 66: \n", "Val ADE: 72.69349670410156 \n", "Val FDE: 91.93955993652344\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  23%|\u2588\u2588\u258e       | 68/300 [7:35:55<25:54:38, 402.06s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 67: \n", "Val ADE: 74.20513153076172 \n", "Val FDE: 97.18782043457031\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  23%|\u2588\u2588\u258e       | 69/300 [7:42:37<25:48:03, 402.09s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 68: \n", "Val ADE: 75.07601165771484 \n", "Val FDE: 94.36964416503906\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  23%|\u2588\u2588\u258e       | 70/300 [7:49:19<25:41:17, 402.08s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 69: \n", "Val ADE: 85.26095581054688 \n", "Val FDE: 100.542236328125\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  24%|\u2588\u2588\u258e       | 71/300 [7:56:01<25:34:59, 402.18s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 70: \n", "Val ADE: 74.46408081054688 \n", "Val FDE: 93.24795532226562\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  24%|\u2588\u2588\u258d       | 72/300 [8:02:43<25:28:07, 402.14s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 71: \n", "Val ADE: 74.33038330078125 \n", "Val FDE: 97.97826385498047\n"]}, {"output_type": "stream", "name": "stderr", 
"text": ["\rEpoch:  24%|\u2588\u2588\u258d       | 73/300 [8:09:26<25:21:30, 402.16s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 72: \n", "Val ADE: 73.51270294189453 \n", "Val FDE: 87.97372436523438\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  25%|\u2588\u2588\u258d       | 74/300 [8:16:08<25:14:56, 402.20s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 73: \n", "Val ADE: 74.89344787597656 \n", "Val FDE: 99.4470443725586\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  25%|\u2588\u2588\u258c       | 75/300 [8:22:50<25:08:29, 402.27s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 74: \n", "Val ADE: 75.96234130859375 \n", "Val FDE: 96.75408172607422\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  25%|\u2588\u2588\u258c       | 76/300 [8:29:33<25:01:58, 402.32s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 75: \n", "Val ADE: 75.01667022705078 \n", "Val FDE: 92.12223815917969\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  26%|\u2588\u2588\u258c       | 77/300 [8:36:15<24:55:00, 402.25s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 76: \n", "Val ADE: 75.16649627685547 \n", "Val FDE: 97.27088928222656\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  26%|\u2588\u2588\u258c       | 78/300 [8:42:57<24:47:51, 402.12s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 77: \n", "Val ADE: 72.44387817382812 \n", "Val FDE: 87.00078582763672\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  26%|\u2588\u2588\u258b       | 79/300 [8:49:39<24:40:50, 402.04s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 78: \n", "Val ADE: 76.05685424804688 \n", "Val FDE: 103.36338806152344\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  27%|\u2588\u2588\u258b       | 80/300 [8:56:20<24:33:54, 
401.98s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 79: \n", "Val ADE: 74.80158996582031 \n", "Val FDE: 93.39835357666016\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  27%|\u2588\u2588\u258b       | 81/300 [9:03:02<24:27:06, 401.95s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 80: \n", "Val ADE: 78.92801666259766 \n", "Val FDE: 101.9036865234375\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  27%|\u2588\u2588\u258b       | 82/300 [9:09:44<24:20:16, 401.91s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 81: \n", "Val ADE: 77.72897338867188 \n", "Val FDE: 106.47919464111328\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  28%|\u2588\u2588\u258a       | 83/300 [9:16:26<24:13:35, 401.91s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 82: \n", "Val ADE: 76.36753845214844 \n", "Val FDE: 95.8289566040039\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  28%|\u2588\u2588\u258a       | 84/300 [9:23:08<24:07:19, 402.03s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 83: \n", "Val ADE: 71.68247985839844 \n", "Val FDE: 91.23521423339844\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  28%|\u2588\u2588\u258a       | 85/300 [9:29:50<24:00:46, 402.08s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 84: \n", "Val ADE: 76.81002044677734 \n", "Val FDE: 100.6324691772461\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  29%|\u2588\u2588\u258a       | 86/300 [9:36:33<23:54:06, 402.09s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 85: \n", "Val ADE: 76.82654571533203 \n", "Val FDE: 102.10348510742188\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  29%|\u2588\u2588\u2589       | 87/300 [9:43:15<23:47:18, 402.06s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 86: \n", 
"Val ADE: 78.30652618408203 \n", "Val FDE: 103.19623565673828\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  29%|\u2588\u2588\u2589       | 88/300 [9:49:56<23:40:26, 402.01s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 87: \n", "Val ADE: 74.9014892578125 \n", "Val FDE: 100.20053100585938\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  30%|\u2588\u2588\u2589       | 89/300 [9:56:39<23:33:48, 402.03s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 88: \n", "Val ADE: 77.68144989013672 \n", "Val FDE: 94.29315185546875\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  30%|\u2588\u2588\u2588       | 90/300 [10:03:21<23:27:39, 402.19s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 89: \n", "Val ADE: 75.15917205810547 \n", "Val FDE: 101.39463806152344\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  30%|\u2588\u2588\u2588       | 91/300 [10:10:04<23:21:17, 402.29s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 90: \n", "Val ADE: 78.68143463134766 \n", "Val FDE: 102.17784881591797\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  31%|\u2588\u2588\u2588       | 92/300 [10:16:46<23:14:44, 402.33s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 91: \n", "Val ADE: 76.76518249511719 \n", "Val FDE: 108.03560638427734\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  31%|\u2588\u2588\u2588       | 93/300 [10:23:29<23:08:17, 402.40s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 92: \n", "Val ADE: 74.84069061279297 \n", "Val FDE: 103.66519927978516\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  31%|\u2588\u2588\u2588\u258f      | 94/300 [10:30:11<23:01:27, 402.37s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 93: \n", "Val ADE: 77.72318267822266 \n", "Val FDE: 99.5692367553711\n"]}, 
{"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  32%|\u2588\u2588\u2588\u258f      | 95/300 [10:36:53<22:54:31, 402.30s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 94: \n", "Val ADE: 77.06853485107422 \n", "Val FDE: 106.30477142333984\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  32%|\u2588\u2588\u2588\u258f      | 96/300 [10:43:36<22:48:06, 402.38s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 95: \n", "Val ADE: 81.12006378173828 \n", "Val FDE: 102.43267822265625\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  32%|\u2588\u2588\u2588\u258f      | 97/300 [10:50:18<22:41:32, 402.43s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 96: \n", "Val ADE: 76.07836151123047 \n", "Val FDE: 107.00360870361328\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  33%|\u2588\u2588\u2588\u258e      | 98/300 [10:57:01<22:35:04, 402.50s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 97: \n", "Val ADE: 81.31273651123047 \n", "Val FDE: 104.62193298339844\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  33%|\u2588\u2588\u2588\u258e      | 99/300 [11:03:44<22:28:39, 402.59s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 98: \n", "Val ADE: 76.32565307617188 \n", "Val FDE: 99.85575866699219\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  33%|\u2588\u2588\u2588\u258e      | 100/300 [11:10:26<22:22:06, 402.63s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 99: \n", "Val ADE: 78.27252197265625 \n", "Val FDE: 102.0928955078125\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  34%|\u2588\u2588\u2588\u258e      | 101/300 [11:18:23<23:29:03, 424.84s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 100: \n", "Val ADE: 71.67366027832031 \n", "Val FDE: 93.67620849609375\n"]}, {"output_type": "stream", "name": 
"stderr", "text": ["\rEpoch:  34%|\u2588\u2588\u2588\u258d      | 102/300 [11:26:20<24:13:25, 440.43s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 101: \n", "Val ADE: 72.38790893554688 \n", "Val FDE: 97.3553695678711\n"]}, {"output_type": "stream", "name": "stderr", "text": ["\rEpoch:  34%|\u2588\u2588\u2588\u258d      | 103/300 [11:34:17<24:41:55, 451.35s/it]"]}, {"output_type": "stream", "name": "stdout", "text": ["Epoch 102: \n", "Val ADE: 70.12110900878906 \n", "Val FDE: 97.36795806884766\n"]}], "execution_count": null, "id": "optional-colleague", "metadata": {"outputId": "da69270e-f010-48f3-ad7f-4cb308592979", "id": "optional-colleague", "colab": {"base_uri": "https://localhost:8080/"}}}, {"cell_type": "code", "source": [""], "outputs": [], "execution_count": null, "id": "16fe307e", "metadata": {"id": "16fe307e"}}], "metadata": {"kernelspec": {"display_name": "Python 3", "name": "python3"}, "language_info": {"name": "python"}, "colab": {"provenance": [], "machine_shape": "hm", "name": "train_inD_longterm.ipynb"}, "accelerator": "GPU"}}