{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q4wgKhonGDJn"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive"
      ],
      "id": "q4wgKhonGDJn"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "koLSDU3KHn5e"
      },
      "source": [
        "## Setup: mount Google Drive, switch to the project directory and install dependencies"
      ],
      "id": "koLSDU3KHn5e"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QGzq7HyiGDDv",
        "outputId": "ed0488a9-c3ba-448f-d851-67c671bde8c0"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "# Mount Google Drive; the project code lives under /content/drive/MyDrive\n",
        "drive.mount(mountpoint='/content/drive', force_remount=True)"
      ],
      "id": "QGzq7HyiGDDv"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ePwFUQvrGq1A",
        "outputId": "685c2c6e-508e-40d2-a2a0-8dea4c93feab"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "drive  sample_data\n"
          ]
        }
      ],
      "source": [
        "# Sanity check: after mounting, `drive` should be listed in the working directory\n",
        "!ls ."
      ],
      "id": "ePwFUQvrGq1A"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YPOC6rPIGC9Y",
        "outputId": "34bda1c7-43a2-4639-8a75-114a94265e25"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet\n"
          ]
        }
      ],
      "source": [
        "# Work from the ynet/ project directory so the relative paths used below\n",
        "# (requirements.txt, config/, data/, the local `model` module) resolve\n",
        "%cd /content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet"
      ],
      "id": "YPOC6rPIGC9Y"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "p2IwKmpUGCyL",
        "outputId": "6ce3e57a-9847-476f-dcd9-76a48cc3fe49"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting wandb\n",
            "  Downloading wandb-0.12.9-py2.py3-none-any.whl (1.7 MB)\n",
            "\u001b[?25l\r\u001b[K     |▏                               | 10 kB 30.8 MB/s eta 0:00:01\r\u001b[K     |▍                               | 20 kB 17.5 MB/s eta 0:00:01\r\u001b[K     |▋                               | 30 kB 14.5 MB/s eta 0:00:01\r\u001b[K     |▊                               | 40 kB 13.3 MB/s eta 0:00:01\r\u001b[K     |█                               | 51 kB 7.2 MB/s eta 0:00:01\r\u001b[K     |█▏                              | 61 kB 7.8 MB/s eta 0:00:01\r\u001b[K     |█▍                              | 71 kB 8.1 MB/s eta 0:00:01\r\u001b[K     |█▌                              | 81 kB 9.1 MB/s eta 0:00:01\r\u001b[K     |█▊                              | 92 kB 8.9 MB/s eta 0:00:01\r\u001b[K     |██                              | 102 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██                              | 112 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██▎                             | 122 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██▌                             | 133 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██▊                             | 143 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██▉                             | 153 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███                             | 163 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███▎                            | 174 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███▍                            | 184 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███▋                            | 194 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███▉                            | 204 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████                            | 215 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████▏                           | 225 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████▍                           | 235 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████▋                           | 245 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████▉                           | 256 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████                           | 266 
kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████▏                          | 276 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████▍                          | 286 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████▌                          | 296 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████▊                          | 307 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████                          | 317 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████▏                         | 327 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████▎                         | 337 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████▌                         | 348 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████▊                         | 358 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████▉                         | 368 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████                         | 378 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████▎                        | 389 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████▌                        | 399 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████▋                        | 409 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████▉                        | 419 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████                        | 430 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████▎                       | 440 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████▍                       | 450 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████▋                       | 460 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████▉                       | 471 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████                       | 481 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████▏                      | 491 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████▍                      | 501 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████▋                      | 512 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████▊                      | 522 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████                      | 
532 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████▏                     | 542 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████▎                     | 552 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████▌                     | 563 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████▊                     | 573 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████                     | 583 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████                     | 593 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████▎                    | 604 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████▌                    | 614 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████▋                    | 624 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████▉                    | 634 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████                    | 645 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████▎                   | 655 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████▍                   | 665 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████▋                   | 675 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████▉                   | 686 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████                   | 696 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████▏                  | 706 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████▍                  | 716 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████▋                  | 727 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████▊                  | 737 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████                  | 747 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████▏                 | 757 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████▍                 | 768 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████▌                 | 778 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████▊                 | 788 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████                 
| 798 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████                 | 808 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████▎                | 819 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████▌                | 829 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████▊                | 839 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████▉                | 849 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████                | 860 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████▎               | 870 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████▌               | 880 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████▋               | 890 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████▉               | 901 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████               | 911 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████▏              | 921 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████▍              | 931 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████▋              | 942 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████▉              | 952 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████              | 962 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████▏             | 972 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████▍             | 983 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████▌             | 993 kB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████▊             | 1.0 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████             | 1.0 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████▏            | 1.0 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████▎            | 1.0 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████▌            | 1.0 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████▊            | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████▉          
  | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████            | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████▎           | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████▌           | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████▋           | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████▉           | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████           | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▎          | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▍          | 1.1 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▋          | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▉          | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████          | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████▏         | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████▍         | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████▋         | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████▊         | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████         | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████▏        | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████▎        | 1.2 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████▌        | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████▊        | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████        | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████        | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▎       | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▌       | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▊       | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▉   
    | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████       | 1.3 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▎      | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▍      | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▋      | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▉      | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████      | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▏     | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▍     | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▋     | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▊     | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████     | 1.4 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▏    | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▍    | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▌    | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▊    | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████    | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▏   | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▎   | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▌   | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▊   | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▉   | 1.5 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████   | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▎  | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▌  | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▋  | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     
|█████████████████████████████▉  | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████  | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▏ | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▍ | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▋ | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▉ | 1.6 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████ | 1.7 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▏| 1.7 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▍| 1.7 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▌| 1.7 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▊| 1.7 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 1.7 MB 7.1 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 1.7 MB 7.1 MB/s \n",
            "\u001b[?25hRequirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (5.4.8)\n",
            "Collecting yaspin>=1.0.0\n",
            "  Downloading yaspin-2.1.0-py3-none-any.whl (18 kB)\n",
            "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from wandb) (3.13)\n",
            "Collecting pathtools\n",
            "  Downloading pathtools-0.1.2.tar.gz (11 kB)\n",
            "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.8.2)\n",
            "Requirement already satisfied: protobuf>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.17.3)\n",
            "Collecting subprocess32>=3.5.3\n",
            "  Downloading subprocess32-3.5.4.tar.gz (97 kB)\n",
            "\u001b[K     |████████████████████████████████| 97 kB 8.1 MB/s \n",
            "\u001b[?25hCollecting docker-pycreds>=0.4.0\n",
            "  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n",
            "Collecting configparser>=3.8.1\n",
            "  Downloading configparser-5.2.0-py3-none-any.whl (19 kB)\n",
            "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.23.0)\n",
            "Collecting GitPython>=1.0.0\n",
            "  Downloading GitPython-3.1.26-py3-none-any.whl (180 kB)\n",
            "\u001b[K     |████████████████████████████████| 180 kB 81.1 MB/s \n",
            "\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.3)\n",
            "Requirement already satisfied: six>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.15.0)\n",
            "Collecting sentry-sdk>=1.0.0\n",
            "  Downloading sentry_sdk-1.5.4-py2.py3-none-any.whl (143 kB)\n",
            "\u001b[K     |████████████████████████████████| 143 kB 98.4 MB/s \n",
            "\u001b[?25hCollecting shortuuid>=0.5.0\n",
            "  Downloading shortuuid-1.0.8-py3-none-any.whl (9.5 kB)\n",
            "Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (7.1.2)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from GitPython>=1.0.0->wandb) (3.10.0.2)\n",
            "Collecting gitdb<5,>=4.0.1\n",
            "  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)\n",
            "\u001b[K     |████████████████████████████████| 63 kB 2.2 MB/s \n",
            "\u001b[?25hCollecting smmap<6,>=3.0.1\n",
            "  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2.10)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (1.24.3)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (3.0.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n",
            "Requirement already satisfied: termcolor<2.0.0,>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from yaspin>=1.0.0->wandb) (1.1.0)\n",
            "Building wheels for collected packages: subprocess32, pathtools\n",
            "  Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for subprocess32: filename=subprocess32-3.5.4-py3-none-any.whl size=6502 sha256=5c33088bd8006cb7090bcc5bfa0b85646b7695d6c0a91cd1427ac109af26714e\n",
            "  Stored in directory: /root/.cache/pip/wheels/50/ca/fa/8fca8d246e64f19488d07567547ddec8eb084e8c0d7a59226a\n",
            "  Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=f24ed1a1c487f69e490ce22c58d65552635e2ae41c7608e652ab40dd28c4f300\n",
            "  Stored in directory: /root/.cache/pip/wheels/3e/31/09/fa59cef12cdcfecc627b3d24273699f390e71828921b2cbba2\n",
            "Successfully built subprocess32 pathtools\n",
            "Installing collected packages: smmap, gitdb, yaspin, subprocess32, shortuuid, sentry-sdk, pathtools, GitPython, docker-pycreds, configparser, wandb\n",
            "Successfully installed GitPython-3.1.26 configparser-5.2.0 docker-pycreds-0.4.0 gitdb-4.0.9 pathtools-0.1.2 sentry-sdk-1.5.4 shortuuid-1.0.8 smmap-5.0.0 subprocess32-3.5.4 wandb-0.12.9 yaspin-2.1.0\n"
          ]
        }
      ],
      "source": [
        "# Install all pinned dependencies *before* the imports cell below. Installing\n",
        "# torch/pyyaml after they have been imported leaves the old versions loaded\n",
        "# in the running kernel (Colab then asks for a runtime restart).\n",
        "# %pip (not !pip) targets the kernel's own environment; -q hides progress bars.\n",
        "%pip install -q wandb==0.12.9 -r requirements.txt"
      ],
      "id": "p2IwKmpUGCyL"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "determined-township"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import yaml\n",
        "import argparse\n",
        "import torch\n",
        "from model import YNet"
      ],
      "id": "determined-township"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "finished-potential"
      },
      "outputs": [],
      "source": [
        "# Re-import edited project modules (e.g. `model`) automatically before each cell runs\n",
        "%load_ext autoreload\n",
        "%autoreload 2"
      ],
      "id": "finished-potential"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e567fbe0"
      },
      "source": [
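        "#### Hyperparameters and settings\n",
        "Dataset paths, observation/prediction horizons (in timesteps), goal/trajectory sample counts and batch size for the inD long-term experiment."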
      ],
      "id": "e567fbe0"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "arabic-thickness"
      },
      "outputs": [],
      "source": [
        "# --- Experiment ---\n",
        "# YAML config file containing all the hyperparameters\n",
        "CONFIG_FILE_PATH = 'config/inD_longterm.yaml'\n",
        "# Arbitrary label for this experiment\n",
        "EXPERIMENT_NAME = 'ind_longterm'\n",
        "DATASET_NAME = 'ind'\n",
        "\n",
        "# --- Data (paths are relative to the ynet/ working directory) ---\n",
        "TRAIN_DATA_PATH = 'data/inD/train.pkl'\n",
        "TRAIN_IMAGE_PATH = 'data/inD/train'\n",
        "# NOTE: the test split doubles as the validation set here\n",
        "VAL_DATA_PATH = 'data/inD/test.pkl'\n",
        "VAL_IMAGE_PATH = 'data/inD/test'\n",
        "\n",
        "# --- Horizons (both measured in timesteps) ---\n",
        "OBS_LEN = 5    # observed length\n",
        "PRED_LEN = 30  # prediction length\n",
        "\n",
        "# --- Sampling ---\n",
        "NUM_GOALS = 20  # K_e\n",
        "NUM_TRAJ = 1    # K_a\n",
        "\n",
        "BATCH_SIZE = 8"
      ],
      "id": "arabic-thickness"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "6BhBP_J5FhgB",
        "outputId": "897b2f71-9d15-4d19-8273-c2c1ce97a9f4"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting tqdm==4.48.0\n",
            "  Downloading tqdm-4.48.0-py2.py3-none-any.whl (67 kB)\n",
            "\u001b[?25l\r\u001b[K     |████▉                           | 10 kB 32.6 MB/s eta 0:00:01\r\u001b[K     |█████████▋                      | 20 kB 32.1 MB/s eta 0:00:01\r\u001b[K     |██████████████▌                 | 30 kB 19.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████▎            | 40 kB 16.4 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▏       | 51 kB 11.1 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████   | 61 kB 12.0 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 67 kB 4.7 MB/s \n",
            "\u001b[?25hCollecting pyyaml==5.3.1\n",
            "  Downloading PyYAML-5.3.1.tar.gz (269 kB)\n",
            "\u001b[K     |████████████████████████████████| 269 kB 20.8 MB/s \n",
            "\u001b[?25hRequirement already satisfied: matplotlib==3.2.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 3)) (3.2.2)\n",
            "Collecting torch==1.5.1\n",
            "  Downloading torch-1.5.1-cp37-cp37m-manylinux1_x86_64.whl (753.2 MB)\n",
            "\u001b[K     |████████████████████████████████| 753.2 MB 13 kB/s \n",
            "\u001b[?25hRequirement already satisfied: pandas==1.1.5 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 5)) (1.1.5)\n",
            "Collecting opencv-python==4.4.0.42\n",
            "  Downloading opencv_python-4.4.0.42-cp37-cp37m-manylinux2014_x86_64.whl (49.4 MB)\n",
            "\u001b[K     |████████████████████████████████| 49.4 MB 1.2 MB/s \n",
            "\u001b[?25hCollecting scipy==1.5.0\n",
            "  Downloading scipy-1.5.0-cp37-cp37m-manylinux1_x86_64.whl (25.9 MB)\n",
            "\u001b[K     |████████████████████████████████| 25.9 MB 1.2 MB/s \n",
            "\u001b[?25hCollecting segmentation_models_pytorch==0.1.0\n",
            "  Downloading segmentation_models_pytorch-0.1.0-py3-none-any.whl (42 kB)\n",
            "\u001b[K     |████████████████████████████████| 42 kB 1.8 MB/s \n",
            "\u001b[?25hRequirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (0.11.0)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (2.8.2)\n",
            "Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.19.5)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.3.2)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (3.0.7)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from torch==1.5.1->-r requirements.txt (line 4)) (0.16.0)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas==1.1.5->-r requirements.txt (line 5)) (2018.9)\n",
            "Requirement already satisfied: torchvision>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (0.11.1+cu111)\n",
            "Collecting efficientnet-pytorch>=0.5.1\n",
            "  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)\n",
            "Collecting pretrainedmodels==0.7.4\n",
            "  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)\n",
            "\u001b[K     |████████████████████████████████| 58 kB 8.3 MB/s \n",
            "\u001b[?25hCollecting munch\n",
            "  Downloading munch-2.5.0-py2.py3-none-any.whl (10 kB)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib==3.2.2->-r requirements.txt (line 3)) (1.15.0)\n",
            "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision>=0.3.0->segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (7.1.2)\n",
            "Collecting torchvision>=0.3.0\n",
            "  Downloading torchvision-0.11.3-cp37-cp37m-manylinux1_x86_64.whl (23.2 MB)\n",
            "\u001b[K     |████████████████████████████████| 23.2 MB 1.2 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.11.2-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)\n",
            "\u001b[K     |████████████████████████████████| 23.3 MB 1.2 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.11.1-cp37-cp37m-manylinux1_x86_64.whl (23.3 MB)\n",
            "\u001b[K     |████████████████████████████████| 23.3 MB 540 kB/s \n",
            "\u001b[?25h  Downloading torchvision-0.10.1-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)\n",
            "\u001b[K     |████████████████████████████████| 22.1 MB 75.0 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.10.0-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)\n",
            "\u001b[K     |████████████████████████████████| 22.1 MB 57.8 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.9.1-cp37-cp37m-manylinux1_x86_64.whl (17.4 MB)\n",
            "\u001b[K     |████████████████████████████████| 17.4 MB 70.9 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.9.0-cp37-cp37m-manylinux1_x86_64.whl (17.3 MB)\n",
            "\u001b[K     |████████████████████████████████| 17.3 MB 47.7 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl (12.8 MB)\n",
            "\u001b[K     |████████████████████████████████| 12.8 MB 68.9 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (12.7 MB)\n",
            "\u001b[K     |████████████████████████████████| 12.7 MB 74 kB/s \n",
            "\u001b[?25h  Downloading torchvision-0.8.0-cp37-cp37m-manylinux1_x86_64.whl (11.8 MB)\n",
            "\u001b[K     |████████████████████████████████| 11.8 MB 52.9 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl (5.9 MB)\n",
            "\u001b[K     |████████████████████████████████| 5.9 MB 93.3 MB/s \n",
            "\u001b[?25h  Downloading torchvision-0.6.1-cp37-cp37m-manylinux1_x86_64.whl (6.6 MB)\n",
            "\u001b[K     |████████████████████████████████| 6.6 MB 39.4 MB/s \n",
            "\u001b[?25hBuilding wheels for collected packages: pyyaml, pretrainedmodels, efficientnet-pytorch\n",
            "  Building wheel for pyyaml (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp37-cp37m-linux_x86_64.whl size=44636 sha256=b5a6baa2b5013e0038cd42c9dba6edcd8ed5b329585decf0306f7654a7cb46b6\n",
            "  Stored in directory: /root/.cache/pip/wheels/5e/03/1e/e1e954795d6f35dfc7b637fe2277bff021303bd9570ecea653\n",
            "  Building wheel for pretrainedmodels (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pretrainedmodels: filename=pretrainedmodels-0.7.4-py3-none-any.whl size=60965 sha256=6911155523e8a34e993f0366ce2945fcdc11e80f7d14f8ba4f1c22a8d777282b\n",
            "  Stored in directory: /root/.cache/pip/wheels/ed/27/e8/9543d42de2740d3544db96aefef63bda3f2c1761b3334f4873\n",
            "  Building wheel for efficientnet-pytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16446 sha256=92db62c2667db69e4d8173aff20bf4363e8de16ecbf3f143d081b059580c0ad0\n",
            "  Stored in directory: /root/.cache/pip/wheels/0e/cc/b2/49e74588263573ff778da58cc99b9c6349b496636a7e165be6\n",
            "Successfully built pyyaml pretrainedmodels efficientnet-pytorch\n",
            "Installing collected packages: torch, tqdm, torchvision, munch, pretrainedmodels, efficientnet-pytorch, segmentation-models-pytorch, scipy, pyyaml, opencv-python\n",
            "  Attempting uninstall: torch\n",
            "    Found existing installation: torch 1.10.0+cu111\n",
            "    Uninstalling torch-1.10.0+cu111:\n",
            "      Successfully uninstalled torch-1.10.0+cu111\n",
            "  Attempting uninstall: tqdm\n",
            "    Found existing installation: tqdm 4.62.3\n",
            "    Uninstalling tqdm-4.62.3:\n",
            "      Successfully uninstalled tqdm-4.62.3\n",
            "  Attempting uninstall: torchvision\n",
            "    Found existing installation: torchvision 0.11.1+cu111\n",
            "    Uninstalling torchvision-0.11.1+cu111:\n",
            "      Successfully uninstalled torchvision-0.11.1+cu111\n",
            "  Attempting uninstall: scipy\n",
            "    Found existing installation: scipy 1.4.1\n",
            "    Uninstalling scipy-1.4.1:\n",
            "      Successfully uninstalled scipy-1.4.1\n",
            "  Attempting uninstall: pyyaml\n",
            "    Found existing installation: PyYAML 3.13\n",
            "    Uninstalling PyYAML-3.13:\n",
            "      Successfully uninstalled PyYAML-3.13\n",
            "  Attempting uninstall: opencv-python\n",
            "    Found existing installation: opencv-python 4.1.2.30\n",
            "    Uninstalling opencv-python-4.1.2.30:\n",
            "      Successfully uninstalled opencv-python-4.1.2.30\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.5.1 which is incompatible.\n",
            "torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.5.1 which is incompatible.\n",
            "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n",
            "Successfully installed efficientnet-pytorch-0.7.1 munch-2.5.0 opencv-python-4.4.0.42 pretrainedmodels-0.7.4 pyyaml-5.3.1 scipy-1.5.0 segmentation-models-pytorch-0.1.0 torch-1.5.1 torchvision-0.6.1 tqdm-4.48.0\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "cv2",
                  "torch",
                  "tqdm",
                  "yaml"
                ]
              }
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "# requirements.txt replaces packages Colab has already imported (torch, cv2,\n",
        "# tqdm, yaml per the pip_warning in this cell's output): restart the runtime\n",
        "# after this cell so the pinned versions are the ones actually loaded.\n",
        "%pip install -r requirements.txt"
      ],
      "id": "6BhBP_J5FhgB"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2f729e8f"
      },
      "source": [
        "#### Load config file and print hyperparameters"
      ],
      "id": "2f729e8f"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dangerous-cutting",
        "outputId": "ccae9e8a-ae20-4ac3-8532-b3d4b7047de7"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
              " 'batch_size': 8,\n",
              " 'decoder_channels': [64, 64, 64, 32, 32],\n",
              " 'encoder_channels': [32, 32, 64, 64, 64],\n",
              " 'kernlen': 31,\n",
              " 'learning_rate': 0.0001,\n",
              " 'loss_scale': 1000,\n",
              " 'nsig': 4,\n",
              " 'num_epochs': 300,\n",
              " 'rel_threshold': 0.002,\n",
              " 'resize': 0.33,\n",
              " 'segmentation_model_fp': 'segmentation_models/inD_segmentation.pth',\n",
              " 'semantic_classes': 6,\n",
              " 'temperature': 1.8,\n",
              " 'unfreeze': 100,\n",
              " 'use_CWS': True,\n",
              " 'use_TTST': True,\n",
              " 'use_features_only': False,\n",
              " 'viz_epoch': 10,\n",
              " 'waypoints': [14, 29]}"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ],
      "source": [
        "with open(CONFIG_FILE_PATH) as file:\n",
        "    # The config is a plain mapping of scalars/lists, so safe_load yields the same\n",
        "    # dict while avoiding FullLoader's code-execution issue in the pinned PyYAML 5.3.1.\n",
        "    params = yaml.safe_load(file)\n",
        "experiment_name = CONFIG_FILE_PATH.split('.yaml')[0].split('config/')[1]\n",
        "params"
      ],
      "id": "dangerous-cutting"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "699a7543"
      },
      "source": [
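        "#### Weights & Biases (wandb) initialization\n",
        "W&B is initialized further below, right after the model is built, because `init_wandb` takes the model as an argument."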
      ],
      "id": "699a7543"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "65f560e7"
      },
      "outputs": [],
      "source": [
        ""
      ],
      "id": "65f560e7"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "amber-pressure"
      },
      "source": [
        "#### Load preprocessed data"
      ],
      "id": "amber-pressure"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "german-feature",
        "outputId": "810d53f1-57ae-47df-e25d-03532a831a6e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: pickle5 in /usr/local/lib/python3.7/dist-packages (0.0.12)\n"
          ]
        }
      ],
      "source": [
        "# The preprocessed pickles are read with the pickle5 backport, presumably because\n",
        "# they use pickle protocol 5, which this Python 3.7 runtime cannot read natively.\n",
        "%pip install pickle5==0.0.12\n",
        "\n",
        "import pickle5 as pickle\n",
        "\n",
        "# NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
        "with open(TRAIN_DATA_PATH, \"rb\") as fh:\n",
        "    df_train = pickle.load(fh)\n",
        "with open(VAL_DATA_PATH, \"rb\") as fh:\n",
        "    df_val = pickle.load(fh)"
      ],
      "id": "german-feature"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "corporate-pharmacy",
        "outputId": "0d212542-4bc4-4100-b1c7-52696234b9da"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-3d81353e-04f2-4cf4-b22c-2997dcb238d6\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>trackId</th>\n",
              "      <th>frame</th>\n",
              "      <th>x</th>\n",
              "      <th>y</th>\n",
              "      <th>sceneId</th>\n",
              "      <th>metaId</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>31</td>\n",
              "      <td>2217</td>\n",
              "      <td>25.07654</td>\n",
              "      <td>6.78323</td>\n",
              "      <td>07</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>31</td>\n",
              "      <td>2242</td>\n",
              "      <td>26.11484</td>\n",
              "      <td>7.72170</td>\n",
              "      <td>07</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>31</td>\n",
              "      <td>2267</td>\n",
              "      <td>27.05390</td>\n",
              "      <td>8.94723</td>\n",
              "      <td>07</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>31</td>\n",
              "      <td>2292</td>\n",
              "      <td>28.08326</td>\n",
              "      <td>10.18219</td>\n",
              "      <td>07</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>31</td>\n",
              "      <td>2317</td>\n",
              "      <td>29.08530</td>\n",
              "      <td>11.39276</td>\n",
              "      <td>07</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-3d81353e-04f2-4cf4-b22c-2997dcb238d6')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-3d81353e-04f2-4cf4-b22c-2997dcb238d6 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-3d81353e-04f2-4cf4-b22c-2997dcb238d6');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "   trackId  frame         x         y sceneId  metaId\n",
              "0       31   2217  25.07654   6.78323      07       0\n",
              "1       31   2242  26.11484   7.72170      07       0\n",
              "2       31   2267  27.05390   8.94723      07       0\n",
              "3       31   2292  28.08326  10.18219      07       0\n",
              "4       31   2317  29.08530  11.39276      07       0"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ],
      "source": [
        "df_train.head(n=5)"
      ],
      "id": "corporate-pharmacy"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "24dc5d7c"
      },
      "source": [
        "#### Initialize model"
      ],
      "id": "24dc5d7c"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "harmful-colleague",
        "outputId": "d96f9832-066d-4c7f-9a56-0a24ecda9042"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.encoders.resnet.ResNetEncoder' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Conv2dReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Activation' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n"
          ]
        }
      ],
      "source": [
        "model = YNet(\n",
        "    obs_len=OBS_LEN,\n",
        "    pred_len=PRED_LEN,\n",
        "    params=params,\n",
        ")"
      ],
      "id": "harmful-colleague"
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "45e099fe"
      },
      "source": [
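        "#### Start training\n",
        "Note: to save time, the validation ADE and FDE (average and final displacement errors) are computed without TTST (test-time sampling trick) and CWS (conditional waypoint sampling). They will therefore be worse than the final values obtained with both enabled."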
      ],
      "id": "45e099fe"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "id": "suiRZodUeLO4",
        "outputId": "057c3099-5309-4ac9-e029-b20a71c61fa7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33magv\u001b[0m (use `wandb login --relogin` to force relogin)\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "                    Syncing run <strong><a href=\"https://wandb.ai/agv/ynet/runs/cmnwtqw0\" target=\"_blank\">kind-grass-18</a></strong> to <a href=\"https://wandb.ai/agv/ynet\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
              "\n",
              "                "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Project helper module; aliased so it does not shadow the real `wandb` package name.\n",
        "import weights_and_biases as wb_utils\n",
        "\n",
        "# Pass a copy so the helper cannot mutate the `params` dict used for training.\n",
        "wb_utils.init_wandb(params.copy(), model.model)"
      ],
      "id": "suiRZodUeLO4"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "16fe307e"
      },
      "outputs": [],
      "source": [
        "model.train(\n",
        "    df_train,\n",
        "    df_val,\n",
        "    params,\n",
        "    train_image_path=TRAIN_IMAGE_PATH,\n",
        "    val_image_path=VAL_IMAGE_PATH,\n",
        "    experiment_name=EXPERIMENT_NAME,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_goals=NUM_GOALS,\n",
        "    num_traj=NUM_TRAJ,\n",
        "    device=None,\n",
        "    dataset_name='ind',\n",
        ")"
      ],
      "id": "16fe307e"
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "train_inD_longterm_experiment.ipynb",
      "provenance": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}