{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SB92TGw0TwQL",
        "outputId": "a9be44cf-a58a-41c3-f5ad-a29b263300a4"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/drive/MyDrive/ynet\n"
          ]
        }
      ],
      "source": [
        "#For correct folder \n",
        "%cd /content/drive/MyDrive/ynet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "m5AEj6Ujihwn",
        "outputId": "992031c3-e5f2-472a-a9da-01b4b20ef830"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " \u001b[0m\u001b[01;34mconfig\u001b[0m/                           requirements.txt\n",
            " \u001b[01;34mdata\u001b[0m/                             \u001b[01;34msegmentation_models\u001b[0m/\n",
            "'evaluate_inD_longterm(1).ipynb'   test.py\n",
            " evaluate_inD_longterm.ipynb       train_inD_longterm.ipynb\n",
            " evaluate_SDD_longterm.ipynb       train.py\n",
            " evaluate_SDD_trajnet.ipynb        train_SDD_longterm.ipynb\n",
            " \u001b[01;34mimages\u001b[0m/                           train_SDD_trajnet.ipynb\n",
            " model.py                          \u001b[01;34mutils\u001b[0m/\n",
            " \u001b[01;34mpretrained_models\u001b[0m/                \u001b[01;34mw_andb\u001b[0m/\n",
            " \u001b[01;34m__pycache__\u001b[0m/                      weights_and_biases.py\n",
            " README.md\n"
          ]
        }
      ],
      "source": [
        "%ls"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X4sajokVU91C"
      },
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "YHQNPU5bU-9J",
        "outputId": "690ae6da-e82b-49ed-fdaa-bf5474500547"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: tqdm==4.48.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 1)) (4.48.0)\n",
            "Collecting pyyaml==5.3.1\n",
            "  Using cached PyYAML-5.3.1-cp37-cp37m-linux_x86_64.whl\n",
            "Requirement already satisfied: matplotlib==3.2.2 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 3)) (3.2.2)\n",
            "Requirement already satisfied: opencv-python==4.4.0.42 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 6)) (4.4.0.42)\n",
            "Requirement already satisfied: scipy==1.5.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 7)) (1.5.0)\n",
            "Requirement already satisfied: segmentation_models_pytorch==0.1.0 in /usr/local/lib/python3.7/dist-packages (from -r requirements.txt (line 8)) (0.1.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.3.2)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (0.11.0)\n",
            "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (2.8.2)\n",
            "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (3.0.7)\n",
            "Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.7/dist-packages (from matplotlib==3.2.2->-r requirements.txt (line 3)) (1.19.5)\n",
            "Requirement already satisfied: pretrainedmodels==0.7.4 in /usr/local/lib/python3.7/dist-packages (from segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (0.7.4)\n",
            "Requirement already satisfied: efficientnet-pytorch>=0.5.1 in /usr/local/lib/python3.7/dist-packages (from segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (0.7.1)\n",
            "Requirement already satisfied: torchvision>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (0.11.1+cu111)\n",
            "Requirement already satisfied: munch in /usr/local/lib/python3.7/dist-packages (from pretrainedmodels==0.7.4->segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (2.5.0)\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from pretrainedmodels==0.7.4->segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (1.10.0+cu111)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib==3.2.2->-r requirements.txt (line 3)) (1.15.0)\n",
            "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision>=0.3.0->segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (7.1.2)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->pretrainedmodels==0.7.4->segmentation_models_pytorch==0.1.0->-r requirements.txt (line 8)) (3.10.0.2)\n",
            "Installing collected packages: pyyaml\n",
            "  Attempting uninstall: pyyaml\n",
            "    Found existing installation: PyYAML 6.0\n",
            "    Uninstalling PyYAML-6.0:\n",
            "      Successfully uninstalled PyYAML-6.0\n",
            "Successfully installed pyyaml-5.3.1\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "yaml"
                ]
              }
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: pytorch-lightning in /usr/local/lib/python3.7/dist-packages (1.5.9)\n",
            "Requirement already satisfied: pyDeprecate==0.3.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (0.3.1)\n",
            "Requirement already satisfied: torchmetrics>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (0.7.1)\n",
            "Requirement already satisfied: PyYAML>=5.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (5.3.1)\n",
            "Requirement already satisfied: setuptools==59.5.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (59.5.0)\n",
            "Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (4.48.0)\n",
            "Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.19.5)\n",
            "Requirement already satisfied: torch>=1.7.* in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (1.10.0+cu111)\n",
            "Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (2022.1.0)\n",
            "Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (0.18.2)\n",
            "Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (21.3)\n",
            "Requirement already satisfied: tensorboard>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (2.7.0)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning) (3.10.0.2)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.23.0)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.8.1)\n",
            "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=17.0->pytorch-lightning) (3.0.7)\n",
            "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.35.0)\n",
            "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.4.6)\n",
            "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.37.1)\n",
            "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.43.0)\n",
            "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.3.6)\n",
            "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (0.6.1)\n",
            "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.8.1)\n",
            "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.0)\n",
            "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (1.0.1)\n",
            "Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.2.0->pytorch-lightning) (3.17.3)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.4->tensorboard>=2.2.0->pytorch-lightning) (1.15.0)\n",
            "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.2.8)\n",
            "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.8)\n",
            "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (4.2.4)\n",
            "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (1.3.0)\n",
            "Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (4.10.1)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning) (3.7.0)\n",
            "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning) (0.4.8)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2021.10.8)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (3.0.4)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.10)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.24.3)\n",
            "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning) (3.1.1)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.2.0)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (21.4.0)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (4.0.2)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (6.0.2)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.7.2)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (1.3.0)\n",
            "Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (0.13.0)\n",
            "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning) (2.0.11)\n",
            "Requirement already satisfied: wandb in /usr/local/lib/python3.7/dist-packages (0.12.10)\n",
            "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.23.0)\n",
            "Requirement already satisfied: six>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.15.0)\n",
            "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (5.4.8)\n",
            "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (0.4.0)\n",
            "Requirement already satisfied: protobuf>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.17.3)\n",
            "Requirement already satisfied: GitPython>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (3.1.26)\n",
            "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.8.2)\n",
            "Requirement already satisfied: shortuuid>=0.5.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.0.8)\n",
            "Requirement already satisfied: pathtools in /usr/local/lib/python3.7/dist-packages (from wandb) (0.1.2)\n",
            "Requirement already satisfied: yaspin>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.1.0)\n",
            "Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (7.1.2)\n",
            "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from wandb) (5.3.1)\n",
            "Requirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (2.3)\n",
            "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from wandb) (1.5.4)\n",
            "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.7/dist-packages (from GitPython>=1.0.0->wandb) (4.0.9)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from GitPython>=1.0.0->wandb) (3.10.0.2)\n",
            "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.7/dist-packages (from gitdb<5,>=4.0.1->GitPython>=1.0.0->wandb) (5.0.0)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2.10)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (3.0.4)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n",
            "Requirement already satisfied: termcolor<2.0.0,>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from yaspin>=1.0.0->wandb) (1.1.0)\n",
            "Requirement already satisfied: pickle5 in /usr/local/lib/python3.7/dist-packages (0.0.12)\n",
            "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (5.3.1)\n",
            "Collecting PyYAML\n",
            "  Using cached PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n",
            "Installing collected packages: PyYAML\n",
            "  Attempting uninstall: PyYAML\n",
            "    Found existing installation: PyYAML 5.3.1\n",
            "    Uninstalling PyYAML-5.3.1:\n",
            "      Successfully uninstalled PyYAML-5.3.1\n",
            "Successfully installed PyYAML-6.0\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "yaml"
                ]
              }
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting pandas==1.3.0\n",
            "  Downloading pandas-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (10.8 MB)\n",
            "\u001b[K     |████████████████████████████████| 10.8 MB 5.0 MB/s \n",
            "\u001b[?25hRequirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas==1.3.0) (2018.9)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas==1.3.0) (2.8.2)\n",
            "Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.7/dist-packages (from pandas==1.3.0) (1.19.5)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas==1.3.0) (1.15.0)\n",
            "Installing collected packages: pandas\n",
            "  Attempting uninstall: pandas\n",
            "    Found existing installation: pandas 1.1.0\n",
            "    Uninstalling pandas-1.1.0:\n",
            "      Successfully uninstalled pandas-1.1.0\n",
            "Successfully installed pandas-1.3.0\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "pandas"
                ]
              }
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "#Installing requires library \n",
        "!pip install -r requirements.txt\n",
        "!pip install pytorch-lightning\n",
        "!pip install wandb\n",
        "!pip3 install pickle5\n",
        "!pip install -U PyYAML\n",
        "!pip install pandas==1.3.0\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0KgyVaYQS2Ww"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import yaml\n",
        "import argparse\n",
        "import torch\n",
        "from model import YNet\n",
        "import pytorch_lightning as pl\n",
        "import pickle5 as pickle\n",
        "import weights_and_biases as wandb\n",
        "import utils\n",
        "import cv2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s7EqaLpWVoRn"
      },
      "source": [
        "\n",
        "# Some hyperparameters and settings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pGbHKoTuTl7A"
      },
      "outputs": [],
      "source": [
        "CONFIG_FILE_PATH = 'config/eth_trajnet.yaml'  # yaml config file containing all the hyperparameters\n",
        "EXPERIMENT_NAME = 'eth_ucy'  # arbitrary name for this experiment\n",
        "DATASET_NAME = 'eth'\n",
        "\n",
        "TRAIN_DATA_PATH = '/content/drive/MyDrive/ynet/data/eth_ucy/zara2_train.pkl'#train pkl file path \n",
        "TRAIN_IMAGE_PATH = '/content/drive/MyDrive/ynet/data/eth_ucy/train/zara2/oracle.png'#Train_image_file_path we are not giving train folder path beacuse there is only one scene image\n",
        "VAL_DATA_PATH = '/content/drive/MyDrive/ynet/data/eth_ucy/zara2_test.pkl'#Test_pkl_file_path\n",
        "VAL_IMAGE_PATH = '/content/drive/MyDrive/ynet/data/eth_ucy/test'#path to test folder of dataset\n",
        "OBS_LEN = 8  # in timesteps\n",
        "PRED_LEN = 12  # in timesteps\n",
        "NUM_GOALS = 20  # K_e\n",
        "NUM_TRAJ = 1  # K_a\n",
        "\n",
        "BATCH_SIZE = 4  \n",
        "scene_name = \"dataset_name\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bI3ew5fyfCLH",
        "outputId": "e9e971fd-bed6-44dc-d9f5-3d8e0f3ed0d8"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
              " 'batch_size': 8,\n",
              " 'decoder_channels': [64, 64, 64, 32, 32],\n",
              " 'encoder_channels': [32, 32, 64, 64, 64],\n",
              " 'kernlen': 31,\n",
              " 'learning_rate': 0.0001,\n",
              " 'loss_scale': 1000,\n",
              " 'nsig': 4,\n",
              " 'num_epochs': 300,\n",
              " 'rel_threshold': 0.002,\n",
              " 'resize': 0.67,\n",
              " 'segmentation_model_fp': 'segmentation_models/SDD_segmentation.pth',\n",
              " 'semantic_classes': 6,\n",
              " 'temperature': 1.8,\n",
              " 'unfreeze': 100,\n",
              " 'use_CWS': True,\n",
              " 'use_TTST': True,\n",
              " 'use_features_only': False,\n",
              " 'viz_epoch': 10,\n",
              " 'waypoints': [11]}"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "with open(CONFIG_FILE_PATH) as file:\n",
        "    params = yaml.load(file, Loader=yaml.FullLoader)\n",
        "experiment_name = CONFIG_FILE_PATH.split('.yaml')[0].split('config/')[1]\n",
        "params['resize'] = 0.67\n",
        "params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ADiVkrgakslt"
      },
      "outputs": [],
      "source": [
        "with open(TRAIN_DATA_PATH, \"rb\") as fh:\n",
        "\tdf_train = pickle.load(fh)\n",
        "with open(VAL_DATA_PATH, \"rb\") as fh:\n",
        "\tdf_val = pickle.load(fh)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 423
        },
        "id": "vmfrnqCPlxxg",
        "outputId": "d7e50c04-fd6b-4d7c-8134-caeb9021b4ac"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-4e2a6fb7-82af-4443-9eff-6029f77f7f47\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>frame</th>\n",
              "      <th>trackId</th>\n",
              "      <th>x</th>\n",
              "      <th>y</th>\n",
              "      <th>sceneId</th>\n",
              "      <th>metaId</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>7.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>403</td>\n",
              "      <td>718</td>\n",
              "      <td>zara2</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>17.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>402</td>\n",
              "      <td>697</td>\n",
              "      <td>zara2</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>27.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>401</td>\n",
              "      <td>676</td>\n",
              "      <td>zara2</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>37.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>400</td>\n",
              "      <td>655</td>\n",
              "      <td>zara2</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>47.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>399</td>\n",
              "      <td>634</td>\n",
              "      <td>zara2</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7781</th>\n",
              "      <td>8057.0</td>\n",
              "      <td>159.0</td>\n",
              "      <td>368</td>\n",
              "      <td>347</td>\n",
              "      <td>zara2</td>\n",
              "      <td>158</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7782</th>\n",
              "      <td>8067.0</td>\n",
              "      <td>159.0</td>\n",
              "      <td>368</td>\n",
              "      <td>325</td>\n",
              "      <td>zara2</td>\n",
              "      <td>158</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7783</th>\n",
              "      <td>8077.0</td>\n",
              "      <td>159.0</td>\n",
              "      <td>369</td>\n",
              "      <td>303</td>\n",
              "      <td>zara2</td>\n",
              "      <td>158</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7784</th>\n",
              "      <td>8087.0</td>\n",
              "      <td>159.0</td>\n",
              "      <td>371</td>\n",
              "      <td>281</td>\n",
              "      <td>zara2</td>\n",
              "      <td>158</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7785</th>\n",
              "      <td>8097.0</td>\n",
              "      <td>159.0</td>\n",
              "      <td>372</td>\n",
              "      <td>259</td>\n",
              "      <td>zara2</td>\n",
              "      <td>158</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>6240 rows × 6 columns</p>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-4e2a6fb7-82af-4443-9eff-6029f77f7f47')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-4e2a6fb7-82af-4443-9eff-6029f77f7f47 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-4e2a6fb7-82af-4443-9eff-6029f77f7f47');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "       frame  trackId    x    y sceneId  metaId\n",
              "0        7.0      1.0  403  718   zara2       0\n",
              "1       17.0      1.0  402  697   zara2       0\n",
              "2       27.0      1.0  401  676   zara2       0\n",
              "3       37.0      1.0  400  655   zara2       0\n",
              "4       47.0      1.0  399  634   zara2       0\n",
              "...      ...      ...  ...  ...     ...     ...\n",
              "7781  8057.0    159.0  368  347   zara2     158\n",
              "7782  8067.0    159.0  368  325   zara2     158\n",
              "7783  8077.0    159.0  369  303   zara2     158\n",
              "7784  8087.0    159.0  371  281   zara2     158\n",
              "7785  8097.0    159.0  372  259   zara2     158\n",
              "\n",
              "[6240 rows x 6 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 21
        }
      ],
      "source": [
        "# df_train: we set sceneId to the dataset name because the dataset contains only one scene\n",
        "#df_train['sceneId'] = \"zara2\"\n",
        "#df_train = df_train.iloc[: , 1:]\n",
        "df_train"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# df_val: we set sceneId to the dataset name because the dataset contains only one scene\n",
        "#df_val['sceneId'] = \"zara2\"\n",
        "#df_val = df_val.iloc[: , 1:]\n",
        "df_val"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 423
        },
        "id": "gVtihdMXHA73",
        "outputId": "239dfb56-e243-43c8-cc47-a42403868647"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-e9d980ef-92c5-4063-b824-4f97eeb10324\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>frame</th>\n",
              "      <th>trackId</th>\n",
              "      <th>x</th>\n",
              "      <th>y</th>\n",
              "      <th>sceneId</th>\n",
              "      <th>metaId</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>7797</th>\n",
              "      <td>7997.0</td>\n",
              "      <td>160.0</td>\n",
              "      <td>419</td>\n",
              "      <td>690</td>\n",
              "      <td>zara2</td>\n",
              "      <td>159</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7798</th>\n",
              "      <td>8007.0</td>\n",
              "      <td>160.0</td>\n",
              "      <td>416</td>\n",
              "      <td>659</td>\n",
              "      <td>zara2</td>\n",
              "      <td>159</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7799</th>\n",
              "      <td>8017.0</td>\n",
              "      <td>160.0</td>\n",
              "      <td>413</td>\n",
              "      <td>629</td>\n",
              "      <td>zara2</td>\n",
              "      <td>159</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7800</th>\n",
              "      <td>8027.0</td>\n",
              "      <td>160.0</td>\n",
              "      <td>410</td>\n",
              "      <td>600</td>\n",
              "      <td>zara2</td>\n",
              "      <td>159</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7801</th>\n",
              "      <td>8037.0</td>\n",
              "      <td>160.0</td>\n",
              "      <td>407</td>\n",
              "      <td>572</td>\n",
              "      <td>zara2</td>\n",
              "      <td>159</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9500</th>\n",
              "      <td>10397.0</td>\n",
              "      <td>202.0</td>\n",
              "      <td>353</td>\n",
              "      <td>361</td>\n",
              "      <td>zara2</td>\n",
              "      <td>201</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9501</th>\n",
              "      <td>10407.0</td>\n",
              "      <td>202.0</td>\n",
              "      <td>355</td>\n",
              "      <td>387</td>\n",
              "      <td>zara2</td>\n",
              "      <td>201</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9502</th>\n",
              "      <td>10417.0</td>\n",
              "      <td>202.0</td>\n",
              "      <td>356</td>\n",
              "      <td>413</td>\n",
              "      <td>zara2</td>\n",
              "      <td>201</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9503</th>\n",
              "      <td>10427.0</td>\n",
              "      <td>202.0</td>\n",
              "      <td>357</td>\n",
              "      <td>438</td>\n",
              "      <td>zara2</td>\n",
              "      <td>201</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9504</th>\n",
              "      <td>10437.0</td>\n",
              "      <td>202.0</td>\n",
              "      <td>358</td>\n",
              "      <td>463</td>\n",
              "      <td>zara2</td>\n",
              "      <td>201</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>1240 rows × 6 columns</p>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-e9d980ef-92c5-4063-b824-4f97eeb10324')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-e9d980ef-92c5-4063-b824-4f97eeb10324 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-e9d980ef-92c5-4063-b824-4f97eeb10324');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "        frame  trackId    x    y sceneId  metaId\n",
              "7797   7997.0    160.0  419  690   zara2     159\n",
              "7798   8007.0    160.0  416  659   zara2     159\n",
              "7799   8017.0    160.0  413  629   zara2     159\n",
              "7800   8027.0    160.0  410  600   zara2     159\n",
              "7801   8037.0    160.0  407  572   zara2     159\n",
              "...       ...      ...  ...  ...     ...     ...\n",
              "9500  10397.0    202.0  353  361   zara2     201\n",
              "9501  10407.0    202.0  355  387   zara2     201\n",
              "9502  10417.0    202.0  356  413   zara2     201\n",
              "9503  10427.0    202.0  357  438   zara2     201\n",
              "9504  10437.0    202.0  358  463   zara2     201\n",
              "\n",
              "[1240 rows x 6 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "_g_T5zd1SkNy",
        "outputId": "a55800f0-7db3-4c97-aa7a-ca49a72a847b"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'1.10.0+cu111'"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ],
      "source": [
        "torch.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bIsOhwDjl2Ug",
        "outputId": "b7418720-ab3b-492c-c4ec-d72cc8db800a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'segmentation_models_pytorch.encoders.resnet.ResNetEncoder' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torchvision.models.resnet.Bottleneck' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Identity' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.ModuleList' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Conv2dReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'segmentation_models_pytorch.base.modules.Activation' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n",
            "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.Softmax' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
            "  warnings.warn(msg, SourceChangeWarning)\n"
          ]
        }
      ],
      "source": [
        "model = YNet(obs_len=OBS_LEN, pred_len=PRED_LEN, params=params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 104
        },
        "id": "h80HDvDPoBvm",
        "outputId": "33f49372-855d-419e-9f5b-e3f0f6f0bf66"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n",
            "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: \n",
            "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33magv\u001b[0m (use `wandb login --relogin` to force relogin)\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "                    Syncing run <strong><a href=\"https://wandb.ai/agv/ynet/runs/1e6vo1xu\" target=\"_blank\">wandering-totem-51</a></strong> to <a href=\"https://wandb.ai/agv/ynet\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
              "\n",
              "                "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "!wandb login --relogin\n",
        "wandb.init_wandb(params.copy(), model.model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dn_bPH7KTLyE"
      },
      "outputs": [],
      "source": [
        "model.train(df_train, df_val, params, train_image_path=TRAIN_IMAGE_PATH, val_image_path=VAL_IMAGE_PATH,\n",
        "            experiment_name=EXPERIMENT_NAME, batch_size=BATCH_SIZE, num_goals=NUM_GOALS, num_traj=NUM_TRAJ, \n",
        "            device=None, dataset_name=DATASET_NAME,scene_name=scene_name)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "train_ETH_UCY_training_exp_file.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}