{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a26cad62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicit magic: a bare `cd` only works while IPython's automagic is enabled\n",
    "# and no Python variable named `cd` exists.\n",
    "%cd /content/drive/MyDrive/Human-Path-Prediction-master/ynet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9a801e28",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import yaml\n",
    "import argparse\n",
    "import torch\n",
    "from model import YNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "05994212",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %reload_ext loads the extension on first use and reloads it afterwards, so\n",
    "# re-running this cell does not print an 'already loaded' warning.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "330a18bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment identity ---\n",
    "DATASET_NAME = 'sdd'\n",
    "EXPERIMENT_NAME = 'sdd_longterm'  # free-form label for this run\n",
    "CONFIG_FILE_PATH = 'config/sdd_longterm.yaml'  # YAML file holding the hyperparameters\n",
    "\n",
    "# --- Data locations (relative to the working directory) ---\n",
    "TRAIN_DATA_PATH = 'data/SDD/train_longterm.pkl'\n",
    "VAL_DATA_PATH = 'data/SDD/test_longterm.pkl'\n",
    "TRAIN_IMAGE_PATH = 'data/SDD/train'\n",
    "VAL_IMAGE_PATH = 'data/SDD/test'\n",
    "\n",
    "# --- Sequence lengths, in timesteps ---\n",
    "OBS_LEN, PRED_LEN = 5, 30\n",
    "\n",
    "# --- Sampling: K_e goals, K_a trajectories ---\n",
    "NUM_GOALS, NUM_TRAJ = 20, 1\n",
    "\n",
    "BATCH_SIZE = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c33f847c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
      " 'batch_size': 8,\n",
      " 'decoder_channels': [64, 64, 64, 32, 32],\n",
      " 'encoder_channels': [32, 32, 64, 64, 64],\n",
      " 'kernlen': 31,\n",
      " 'learning_rate': 0.0001,\n",
      " 'loss_scale': 1000,\n",
      " 'nsig': 4,\n",
      " 'num_epochs': 300,\n",
      " 'rel_threshold': 0.002,\n",
      " 'resize': 0.25,\n",
      " 'segmentation_model_fp': 'segmentation_models/SDD_segmentation.pth',\n",
      " 'semantic_classes': 6,\n",
      " 'temperature': 1.8,\n",
      " 'unfreeze': 100,\n",
      " 'use_CWS': False,\n",
      " 'use_TTST': False,\n",
      " 'use_features_only': False,\n",
      " 'viz_epoch': 10,\n",
      " 'waypoints': [14, 29]}"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# NOTE(review): the config is assumed to be a local, trusted file; switch to\n",
    "# yaml.safe_load if it could ever come from elsewhere.\n",
    "with open(CONFIG_FILE_PATH) as config_file:\n",
    "    params = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# File name without directory or extension ('config/foo.yaml' -> 'foo').\n",
    "# Unlike chained str.split, this does not fail when the path lacks 'config/'.\n",
    "experiment_name = Path(CONFIG_FILE_PATH).stem\n",
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "07315f62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel, whereas `!pip3` may\n",
    "# resolve to a different interpreter. TODO: pin a pickle5 version.\n",
    "%pip install -q pickle5\n",
    "\n",
    "# pickle5 provides pickle protocol 5 on Python older than 3.8 (this kernel is 3.7).\n",
    "import pickle5 as pickle\n",
    "\n",
    "# NOTE(review): unpickling can execute arbitrary code - only load trusted files.\n",
    "with open(TRAIN_DATA_PATH, \"rb\") as train_file:\n",
    "    df_train = pickle.load(train_file)\n",
    "with open(VAL_DATA_PATH, \"rb\") as val_file:\n",
    "    df_val = pickle.load(val_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e1f96d6e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   trackId  frame      x      y      sceneId  metaId\n",
      "0        2   6881   17.0  893.5  bookstore_0       0\n",
      "1        2   6911   31.0  904.0  bookstore_0       0\n",
      "2        2   6941   63.0  910.5  bookstore_0       0\n",
      "3        2   6971   98.5  917.5  bookstore_0       0\n",
      "4        2   7001  134.0  919.5  bookstore_0       0"
     ]
    }
   ],
   "source": [
    "# Preview the first five rows of the training frame.\n",
    "df_train.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "565c097f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   trackId  frame      x      y      sceneId  metaId\n",
      "0        2   6881   17.0  893.5  bookstore_0       0\n",
      "1        2   6911   31.0  904.0  bookstore_0       0\n",
      "2        2   6941   63.0  910.5  bookstore_0       0\n",
      "3        2   6971   98.5  917.5  bookstore_0       0\n",
      "4        2   7001  134.0  919.5  bookstore_0       0"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this repeats the preceding preview cell; confirm whether a look\n",
    "# at df_val was intended here instead.\n",
    "df_train.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "16fb2956",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model first: the W&B and weight-loading cells below both use `model`,\n",
    "# which would otherwise be undefined on a fresh kernel.\n",
    "model = YNet(obs_len=OBS_LEN, pred_len=PRED_LEN, params=params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "145839a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import weights_and_biases as wandb\n",
    "\n",
    "# init_wandb receives a shallow copy of the hyperparameters and `model.model`.\n",
    "wandb.init_wandb(params.copy(), model.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c967b8c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative to the working directory, like the config and data paths in this notebook.\n",
    "model.load('pretrained_models/fg/sdd_longterm_weights.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "52f4465d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                    Syncing run <strong><a href=\"https://wandb.ai/agv/ynet/runs/2jz9cve3\" target=\"_blank\">whole-sunset-29</a></strong> to <a href=\"https://wandb.ai/agv/ynet\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
       "\n",
       "                "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import weights_and_biases as wandb\n",
    "\n",
    "# Hand init_wandb a shallow copy of the hyperparameters together with `model.model`.\n",
    "wandb_config = params.copy()\n",
    "wandb.init_wandb(wandb_config, model.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c08818cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative to the working directory, like the config and data paths in this notebook.\n",
    "model.load('pretrained_models/fg/sdd_longterm_weights.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0178b4fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the configured validation image directory instead of repeating the literal.\n",
    "model.evaluate(\n",
    "    df_val,\n",
    "    params,\n",
    "    image_path=VAL_IMAGE_PATH,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    rounds=3,\n",
    "    num_goals=NUM_GOALS,\n",
    "    num_traj=NUM_TRAJ,\n",
    "    device=None,\n",
    "    dataset_name=DATASET_NAME,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a248ff72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which experiment and dataset this section runs.\n",
    "DATASET_NAME = 'sdd'\n",
    "EXPERIMENT_NAME = 'sdd_longterm'  # free-form label for this run\n",
    "CONFIG_FILE_PATH = 'config/sdd_longterm.yaml'  # YAML file holding the hyperparameters\n",
    "\n",
    "# Pickled trajectories and scene image folders, train then validation.\n",
    "TRAIN_DATA_PATH, TRAIN_IMAGE_PATH = 'data/SDD/train_longterm.pkl', 'data/SDD/train'\n",
    "VAL_DATA_PATH, VAL_IMAGE_PATH = 'data/SDD/test_longterm.pkl', 'data/SDD/test'\n",
    "\n",
    "OBS_LEN = 5    # observed timesteps\n",
    "PRED_LEN = 30  # predicted timesteps\n",
    "NUM_GOALS = 20  # K_e\n",
    "NUM_TRAJ = 1    # K_a\n",
    "\n",
    "BATCH_SIZE = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "da841d2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
      " 'batch_size': 8,\n",
      " 'decoder_channels': [64, 64, 64, 32, 32],\n",
      " 'encoder_channels': [32, 32, 64, 64, 64],\n",
      " 'kernlen': 31,\n",
      " 'learning_rate': 0.0001,\n",
      " 'loss_scale': 1000,\n",
      " 'nsig': 4,\n",
      " 'num_epochs': 300,\n",
      " 'rel_threshold': 0.002,\n",
      " 'resize': 0.25,\n",
      " 'segmentation_model_fp': 'segmentation_models/SDD_segmentation.pth',\n",
      " 'semantic_classes': 6,\n",
      " 'temperature': 1.8,\n",
      " 'unfreeze': 100,\n",
      " 'use_CWS': True,\n",
      " 'use_TTST': False,\n",
      " 'use_features_only': False,\n",
      " 'viz_epoch': 10,\n",
      " 'waypoints': [14, 29]}"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# NOTE(review): the config is assumed to be a local, trusted file; switch to\n",
    "# yaml.safe_load if it could ever come from elsewhere.\n",
    "with open(CONFIG_FILE_PATH) as config_file:\n",
    "    params = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# File name without directory or extension ('config/foo.yaml' -> 'foo').\n",
    "# Unlike chained str.split, this does not fail when the path lacks 'config/'.\n",
    "experiment_name = Path(CONFIG_FILE_PATH).stem\n",
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d1e52451",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel, whereas `!pip3` may\n",
    "# resolve to a different interpreter. TODO: pin a pickle5 version.\n",
    "%pip install -q pickle5\n",
    "\n",
    "# pickle5 provides pickle protocol 5 on Python older than 3.8 (this kernel is 3.7).\n",
    "import pickle5 as pickle\n",
    "\n",
    "# NOTE(review): unpickling can execute arbitrary code - only load trusted files.\n",
    "with open(TRAIN_DATA_PATH, \"rb\") as train_file:\n",
    "    df_train = pickle.load(train_file)\n",
    "with open(VAL_DATA_PATH, \"rb\") as val_file:\n",
    "    df_val = pickle.load(val_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a1d5b6d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate YNet from the sequence lengths and the loaded hyperparameters.\n",
    "model = YNet(\n",
    "    obs_len=OBS_LEN,\n",
    "    pred_len=PRED_LEN,\n",
    "    params=params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6c4ff396",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative to the working directory, like the config and data paths in this notebook.\n",
    "model.load('pretrained_models/fg/sdd_longterm_weights.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4bcdf266",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the configured validation image directory instead of repeating the literal.\n",
    "model.evaluate(\n",
    "    df_val,\n",
    "    params,\n",
    "    image_path=VAL_IMAGE_PATH,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    rounds=3,\n",
    "    num_goals=NUM_GOALS,\n",
    "    num_traj=NUM_TRAJ,\n",
    "    device=None,\n",
    "    dataset_name=DATASET_NAME,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e8fc2851",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import yaml\n",
    "import argparse\n",
    "import torch\n",
    "from model import YNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ae048be9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# autoreload is already loaded by an earlier cell; %reload_ext avoids the\n",
    "# 'already loaded' warning that %load_ext prints on a second load.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "471db192",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment identity ---\n",
    "DATASET_NAME = 'sdd'\n",
    "EXPERIMENT_NAME = 'sdd_longterm'  # free-form label for this run\n",
    "CONFIG_FILE_PATH = 'config/sdd_longterm.yaml'  # YAML file holding the hyperparameters\n",
    "\n",
    "# --- Data locations (relative to the working directory) ---\n",
    "TRAIN_DATA_PATH = 'data/SDD/train_longterm.pkl'\n",
    "VAL_DATA_PATH = 'data/SDD/test_longterm.pkl'\n",
    "TRAIN_IMAGE_PATH = 'data/SDD/train'\n",
    "VAL_IMAGE_PATH = 'data/SDD/test'\n",
    "\n",
    "# --- Sequence lengths, in timesteps ---\n",
    "OBS_LEN, PRED_LEN = 5, 30\n",
    "\n",
    "# --- Sampling: K_e goals, K_a trajectories ---\n",
    "NUM_GOALS, NUM_TRAJ = 20, 1\n",
    "\n",
    "BATCH_SIZE = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "2dadab4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell assigns the same values as the configuration cell\n",
    "# directly before it; one of the two can probably be dropped.\n",
    "DATASET_NAME = 'sdd'\n",
    "EXPERIMENT_NAME = 'sdd_longterm'  # free-form label for this run\n",
    "CONFIG_FILE_PATH = 'config/sdd_longterm.yaml'  # YAML file holding the hyperparameters\n",
    "\n",
    "# Pickled trajectories and scene image folders, train then validation.\n",
    "TRAIN_DATA_PATH, TRAIN_IMAGE_PATH = 'data/SDD/train_longterm.pkl', 'data/SDD/train'\n",
    "VAL_DATA_PATH, VAL_IMAGE_PATH = 'data/SDD/test_longterm.pkl', 'data/SDD/test'\n",
    "\n",
    "OBS_LEN = 5    # observed timesteps\n",
    "PRED_LEN = 30  # predicted timesteps\n",
    "NUM_GOALS = 20  # K_e\n",
    "NUM_TRAJ = 1    # K_a\n",
    "\n",
    "BATCH_SIZE = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "12faa7f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
      " 'batch_size': 8,\n",
      " 'decoder_channels': [64, 64, 64, 32, 32],\n",
      " 'encoder_channels': [32, 32, 64, 64, 64],\n",
      " 'kernlen': 31,\n",
      " 'learning_rate': 0.0001,\n",
      " 'loss_scale': 1000,\n",
      " 'nsig': 4,\n",
      " 'num_epochs': 300,\n",
      " 'rel_threshold': 0.002,\n",
      " 'resize': 0.25,\n",
      " 'segmentation_model_fp': 'segmentation_models/SDD_segmentation.pth',\n",
      " 'semantic_classes': 6,\n",
      " 'temperature': 1.8,\n",
      " 'unfreeze': 100,\n",
      " 'use_CWS': True,\n",
      " 'use_TTST': True,\n",
      " 'use_features_only': False,\n",
      " 'viz_epoch': 10,\n",
      " 'waypoints': [14, 29]}"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# NOTE(review): the config is assumed to be a local, trusted file; switch to\n",
    "# yaml.safe_load if it could ever come from elsewhere.\n",
    "with open(CONFIG_FILE_PATH) as config_file:\n",
    "    params = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# File name without directory or extension ('config/foo.yaml' -> 'foo').\n",
    "# Unlike chained str.split, this does not fail when the path lacks 'config/'.\n",
    "experiment_name = Path(CONFIG_FILE_PATH).stem\n",
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "18e62294",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel, whereas `!pip3` may\n",
    "# resolve to a different interpreter. TODO: pin a pickle5 version.\n",
    "%pip install -q pickle5\n",
    "\n",
    "# pickle5 provides pickle protocol 5 on Python older than 3.8 (this kernel is 3.7).\n",
    "import pickle5 as pickle\n",
    "\n",
    "# NOTE(review): unpickling can execute arbitrary code - only load trusted files.\n",
    "with open(TRAIN_DATA_PATH, \"rb\") as train_file:\n",
    "    df_train = pickle.load(train_file)\n",
    "with open(VAL_DATA_PATH, \"rb\") as val_file:\n",
    "    df_val = pickle.load(val_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "2674984e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   trackId  frame      x      y      sceneId  metaId\n",
      "0        2   6881   17.0  893.5  bookstore_0       0\n",
      "1        2   6911   31.0  904.0  bookstore_0       0\n",
      "2        2   6941   63.0  910.5  bookstore_0       0\n",
      "3        2   6971   98.5  917.5  bookstore_0       0\n",
      "4        2   7001  134.0  919.5  bookstore_0       0"
     ]
    }
   ],
   "source": [
    "# Preview the first five rows of the training frame.\n",
    "df_train.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ff578248",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate YNet from the sequence lengths and the loaded hyperparameters.\n",
    "model = YNet(\n",
    "    obs_len=OBS_LEN,\n",
    "    pred_len=PRED_LEN,\n",
    "    params=params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "9af35a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative to the working directory, like the config and data paths in this notebook.\n",
    "model.load('pretrained_models/fg/sdd_longterm_weights.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "d1da8f4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the configured validation image directory instead of repeating the literal.\n",
    "model.evaluate(\n",
    "    df_val,\n",
    "    params,\n",
    "    image_path=VAL_IMAGE_PATH,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    rounds=3,\n",
    "    num_goals=NUM_GOALS,\n",
    "    num_traj=NUM_TRAJ,\n",
    "    device=None,\n",
    "    dataset_name=DATASET_NAME,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "ff91e073",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicit magic: a bare `cd` only works while IPython's automagic is enabled\n",
    "# and no Python variable named `cd` exists.\n",
    "%cd /content/drive/MyDrive/Human-Path-Prediction-master/ynet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "ff271d5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import yaml\n",
    "import argparse\n",
    "import torch\n",
    "from model import YNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "b394de29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment identity ---\n",
    "DATASET_NAME = 'inD'\n",
    "EXPERIMENT_NAME = 'inD_longterm'  # free-form label for this run\n",
    "CONFIG_FILE_PATH = 'config/inD_longterm.yaml'  # YAML file holding the hyperparameters\n",
    "\n",
    "# --- Data locations (relative to the working directory) ---\n",
    "# NOTE(review): a later configuration cell points the two pickle paths at\n",
    "# 'data/inD/train.pkl' and 'data/inD/test.pkl' instead; confirm which files exist.\n",
    "TRAIN_DATA_PATH = 'data/inD/train_longterm.pkl'\n",
    "VAL_DATA_PATH = 'data/inD/test_longterm.pkl'\n",
    "TRAIN_IMAGE_PATH = 'data/inD/train'\n",
    "VAL_IMAGE_PATH = 'data/inD/test'\n",
    "\n",
    "# --- Sequence lengths, in timesteps ---\n",
    "OBS_LEN, PRED_LEN = 5, 30\n",
    "\n",
    "# --- Sampling: K_e goals, K_a trajectories ---\n",
    "NUM_GOALS, NUM_TRAJ = 20, 1\n",
    "\n",
    "BATCH_SIZE = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "9be29205",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
      " 'batch_size': 8,\n",
      " 'decoder_channels': [64, 64, 64, 32, 32],\n",
      " 'encoder_channels': [32, 32, 64, 64, 64],\n",
      " 'kernlen': 31,\n",
      " 'learning_rate': 0.0001,\n",
      " 'loss_scale': 1000,\n",
      " 'nsig': 4,\n",
      " 'num_epochs': 300,\n",
      " 'rel_threshold': 0.002,\n",
      " 'resize': 0.33,\n",
      " 'segmentation_model_fp': 'segmentation_models/inD_segmentation.pth',\n",
      " 'semantic_classes': 6,\n",
      " 'temperature': 1.8,\n",
      " 'unfreeze': 100,\n",
      " 'use_CWS': True,\n",
      " 'use_TTST': False,\n",
      " 'use_features_only': False,\n",
      " 'viz_epoch': 10,\n",
      " 'waypoints': [14, 29]}"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# NOTE(review): the config is assumed to be a local, trusted file; switch to\n",
    "# yaml.safe_load if it could ever come from elsewhere.\n",
    "with open(CONFIG_FILE_PATH) as config_file:\n",
    "    params = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# File name without directory or extension ('config/foo.yaml' -> 'foo').\n",
    "# Unlike chained str.split, this does not fail when the path lacks 'config/'.\n",
    "experiment_name = Path(CONFIG_FILE_PATH).stem\n",
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "da46a0cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel, whereas `!pip3` may\n",
    "# resolve to a different interpreter. TODO: pin a pickle5 version.\n",
    "%pip install -q pickle5\n",
    "\n",
    "# pickle5 provides pickle protocol 5 on Python older than 3.8 (this kernel is 3.7).\n",
    "import pickle5 as pickle\n",
    "\n",
    "# NOTE(review): unpickling can execute arbitrary code - only load trusted files.\n",
    "with open(TRAIN_DATA_PATH, \"rb\") as train_file:\n",
    "    df_train = pickle.load(train_file)\n",
    "with open(VAL_DATA_PATH, \"rb\") as val_file:\n",
    "    df_val = pickle.load(val_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "df6d8688",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which experiment and dataset this section runs.\n",
    "DATASET_NAME = 'inD'\n",
    "EXPERIMENT_NAME = 'inD_longterm'  # free-form label for this run\n",
    "CONFIG_FILE_PATH = 'config/inD_longterm.yaml'  # YAML file holding the hyperparameters\n",
    "\n",
    "# Pickled trajectories and scene image folders, train then validation.\n",
    "TRAIN_DATA_PATH, TRAIN_IMAGE_PATH = 'data/inD/train.pkl', 'data/inD/train'\n",
    "VAL_DATA_PATH, VAL_IMAGE_PATH = 'data/inD/test.pkl', 'data/inD/test'\n",
    "\n",
    "OBS_LEN = 5    # observed timesteps\n",
    "PRED_LEN = 30  # predicted timesteps\n",
    "NUM_GOALS = 20  # K_e\n",
    "NUM_TRAJ = 1    # K_a\n",
    "\n",
    "BATCH_SIZE = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "b5341d66",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
      " 'batch_size': 8,\n",
      " 'decoder_channels': [64, 64, 64, 32, 32],\n",
      " 'encoder_channels': [32, 32, 64, 64, 64],\n",
      " 'kernlen': 31,\n",
      " 'learning_rate': 0.0001,\n",
      " 'loss_scale': 1000,\n",
      " 'nsig': 4,\n",
      " 'num_epochs': 300,\n",
      " 'rel_threshold': 0.002,\n",
      " 'resize': 0.33,\n",
      " 'segmentation_model_fp': 'segmentation_models/inD_segmentation.pth',\n",
      " 'semantic_classes': 6,\n",
      " 'temperature': 1.8,\n",
      " 'unfreeze': 100,\n",
      " 'use_CWS': True,\n",
      " 'use_TTST': False,\n",
      " 'use_features_only': False,\n",
      " 'viz_epoch': 10,\n",
      " 'waypoints': [14, 29]}"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# NOTE(review): the config is assumed to be a local, trusted file; switch to\n",
    "# yaml.safe_load if it could ever come from elsewhere.\n",
    "with open(CONFIG_FILE_PATH) as config_file:\n",
    "    params = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# File name without directory or extension ('config/foo.yaml' -> 'foo').\n",
    "# Unlike chained str.split, this does not fail when the path lacks 'config/'.\n",
    "experiment_name = Path(CONFIG_FILE_PATH).stem\n",
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "4aa3c31b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel, whereas `!pip3` may\n",
    "# resolve to a different interpreter. TODO: pin a pickle5 version.\n",
    "%pip install -q pickle5\n",
    "\n",
    "# pickle5 provides pickle protocol 5 on Python older than 3.8 (this kernel is 3.7).\n",
    "import pickle5 as pickle\n",
    "\n",
    "# NOTE(review): unpickling can execute arbitrary code - only load trusted files.\n",
    "with open(TRAIN_DATA_PATH, \"rb\") as train_file:\n",
    "    df_train = pickle.load(train_file)\n",
    "with open(VAL_DATA_PATH, \"rb\") as val_file:\n",
    "    df_val = pickle.load(val_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "701ae2ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   trackId  frame         x         y sceneId  metaId\n",
      "0       31   2217  25.07654   6.78323      07       0\n",
      "1       31   2242  26.11484   7.72170      07       0\n",
      "2       31   2267  27.05390   8.94723      07       0\n",
      "3       31   2292  28.08326  10.18219      07       0\n",
      "4       31   2317  29.08530  11.39276      07       0"
     ]
    }
   ],
   "source": [
    "# Preview the first five rows of the training frame.\n",
    "df_train.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "e110ca10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate YNet from the sequence lengths and the loaded hyperparameters.\n",
    "model = YNet(\n",
    "    obs_len=OBS_LEN,\n",
    "    pred_len=PRED_LEN,\n",
    "    params=params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "69bf931d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import weights_and_biases as wandb\n",
    "\n",
    "# Hand init_wandb a shallow copy of the hyperparameters together with `model.model`.\n",
    "wandb_config = params.copy()\n",
    "wandb.init_wandb(wandb_config, model.model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
