{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f5e2bf88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run from the ynet directory: the config/ and data/ paths used below are relative to it.\n",
    "# Written as an explicit %cd line magic so it does not depend on automagic being enabled\n",
    "# or on the name `cd` staying unbound in the kernel namespace.\n",
    "%cd /content/drive/MyDrive/Human-Path-Prediction-master (1)/ynet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "972f2449",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import yaml\n",
    "import argparse\n",
    "import torch\n",
    "from model import YNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bc64493a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable the autoreload extension; mode 2 re-imports edited modules before each cell runs,\n",
    "# so changes to the project's .py files are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "78150b6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment identity ---\n",
    "CONFIG_FILE_PATH = 'config/inD_longterm.yaml'  # YAML file holding all the hyperparameters\n",
    "EXPERIMENT_NAME = 'ind_longterm'  # free-form label for this run\n",
    "DATASET_NAME = 'ind'\n",
    "\n",
    "# --- Data locations (relative paths) ---\n",
    "# NOTE(review): the validation paths point at the test split files.\n",
    "TRAIN_DATA_PATH = 'data/inD/train.pkl'\n",
    "TRAIN_IMAGE_PATH = 'data/inD/train'\n",
    "VAL_DATA_PATH = 'data/inD/test.pkl'\n",
    "VAL_IMAGE_PATH = 'data/inD/test'\n",
    "\n",
    "# --- Sequence lengths, in timesteps ---\n",
    "OBS_LEN = 5\n",
    "PRED_LEN = 30\n",
    "\n",
    "# --- Sampling counts ---\n",
    "NUM_GOALS = 20  # K_e\n",
    "NUM_TRAJ = 1  # K_a\n",
    "\n",
    "BATCH_SIZE = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "33d48432",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell binds the same constants, to the same values, as the\n",
    "# configuration cell directly above it; it is a candidate for deletion.\n",
    "\n",
    "# Config file (all hyperparameters), run label, dataset key.\n",
    "CONFIG_FILE_PATH = 'config/inD_longterm.yaml'\n",
    "EXPERIMENT_NAME = 'ind_longterm'\n",
    "DATASET_NAME = 'ind'\n",
    "\n",
    "# Pickled trajectory frames and their image directories.\n",
    "TRAIN_DATA_PATH = 'data/inD/train.pkl'\n",
    "TRAIN_IMAGE_PATH = 'data/inD/train'\n",
    "VAL_DATA_PATH = 'data/inD/test.pkl'\n",
    "VAL_IMAGE_PATH = 'data/inD/test'\n",
    "\n",
    "# Observed / predicted horizon, both in timesteps.\n",
    "OBS_LEN, PRED_LEN = 5, 30\n",
    "\n",
    "# K_e and K_a respectively.\n",
    "NUM_GOALS, NUM_TRAJ = 20, 1\n",
    "\n",
    "BATCH_SIZE = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "59af295d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'CWS_params': {'ratio': 2, 'rot': True, 'sigma_factor': 6},\n",
      " 'batch_size': 8,\n",
      " 'decoder_channels': [64, 64, 64, 32, 32],\n",
      " 'encoder_channels': [32, 32, 64, 64, 64],\n",
      " 'kernlen': 31,\n",
      " 'learning_rate': 0.0001,\n",
      " 'loss_scale': 1000,\n",
      " 'nsig': 4,\n",
      " 'num_epochs': 300,\n",
      " 'rel_threshold': 0.002,\n",
      " 'resize': 0.33,\n",
      " 'segmentation_model_fp': 'segmentation_models/inD_segmentation.pth',\n",
      " 'semantic_classes': 6,\n",
      " 'temperature': 1.8,\n",
      " 'unfreeze': 100,\n",
      " 'use_CWS': True,\n",
      " 'use_TTST': True,\n",
      " 'use_features_only': False,\n",
      " 'viz_epoch': 10,\n",
      " 'waypoints': [14, 29]}"
     ]
    }
   ],
   "source": [
    "# Read the hyperparameters from the YAML config file.\n",
    "# safe_load only builds plain Python containers and scalars (dict, list, str, int, float,\n",
    "# bool), which is everything this config contains, and never constructs arbitrary objects.\n",
    "with open(CONFIG_FILE_PATH) as config_file:\n",
    "    params = yaml.safe_load(config_file)\n",
    "\n",
    "# Config stem, e.g. 'config/foo.yaml' -> 'foo'.\n",
    "# NOTE(review): not referenced by any other cell; training is passed EXPERIMENT_NAME instead.\n",
    "experiment_name = CONFIG_FILE_PATH.split('.yaml')[0].split('config/')[1]\n",
    "\n",
    "# Show the loaded hyperparameters.\n",
    "params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e0d3aa3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle5 backports pickle protocol 5 to Python 3.5-3.7. From Python 3.8 on the standard\n",
    "# library reads protocol 5 natively, so the backport is only installed where it is needed.\n",
    "import sys\n",
    "\n",
    "if sys.version_info >= (3, 8):\n",
    "    import pickle\n",
    "else:\n",
    "    # Install with the interpreter backing this kernel, not whichever pip3 is first on PATH.\n",
    "    !{sys.executable} -m pip install pickle5\n",
    "    import pickle5 as pickle\n",
    "\n",
    "# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.\n",
    "with open(TRAIN_DATA_PATH, \"rb\") as train_file:\n",
    "    df_train = pickle.load(train_file)\n",
    "with open(VAL_DATA_PATH, \"rb\") as val_file:\n",
    "    df_val = pickle.load(val_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ead95079",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   trackId  frame         x         y sceneId  metaId\n",
      "0       31   2217  25.07654   6.78323      07       0\n",
      "1       31   2242  26.11484   7.72170      07       0\n",
      "2       31   2267  27.05390   8.94723      07       0\n",
      "3       31   2292  28.08326  10.18219      07       0\n",
      "4       31   2317  29.08530  11.39276      07       0"
     ]
    }
   ],
   "source": [
    "# Peek at the first five rows of the training frame to check that it loaded.\n",
    "df_train.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "04e745f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the YNet wrapper from the sequence lengths and the loaded hyperparameters.\n",
    "model = YNet(\n",
    "    obs_len=OBS_LEN,\n",
    "    pred_len=PRED_LEN,\n",
    "    params=params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "da01af65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training, handing over the train/val frames together with their image directories.\n",
    "# dataset_name uses the DATASET_NAME constant rather than a repeated 'ind' literal, so the\n",
    "# dataset is selected in one place (the configuration cell).\n",
    "model.train(\n",
    "    df_train,\n",
    "    df_val,\n",
    "    params,\n",
    "    train_image_path=TRAIN_IMAGE_PATH,\n",
    "    val_image_path=VAL_IMAGE_PATH,\n",
    "    experiment_name=EXPERIMENT_NAME,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_goals=NUM_GOALS,\n",
    "    num_traj=NUM_TRAJ,\n",
    "    device=None,  # presumably lets YNet pick the device itself -- confirm in model.py\n",
    "    dataset_name=DATASET_NAME,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e44bbdec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): df_train / df_val are already loaded from the same paths by an earlier\n",
    "# cell; this cell repeats that load and is a candidate for deletion.\n",
    "\n",
    "# pickle5 backports pickle protocol 5 to Python 3.5-3.7. From Python 3.8 on the standard\n",
    "# library reads protocol 5 natively, so the backport is only installed where it is needed.\n",
    "import sys\n",
    "\n",
    "if sys.version_info >= (3, 8):\n",
    "    import pickle\n",
    "else:\n",
    "    # Install with the interpreter backing this kernel, not whichever pip3 is first on PATH.\n",
    "    !{sys.executable} -m pip install pickle5\n",
    "    import pickle5 as pickle\n",
    "\n",
    "# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.\n",
    "with open(TRAIN_DATA_PATH, \"rb\") as train_file:\n",
    "    df_train = pickle.load(train_file)\n",
    "with open(VAL_DATA_PATH, \"rb\") as val_file:\n",
    "    df_val = pickle.load(val_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e62b5559",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   trackId  frame         x         y sceneId  metaId\n",
      "0       31   2217  25.07654   6.78323      07       0\n",
      "1       31   2242  26.11484   7.72170      07       0\n",
      "2       31   2267  27.05390   8.94723      07       0\n",
      "3       31   2292  28.08326  10.18219      07       0\n",
      "4       31   2317  29.08530  11.39276      07       0"
     ]
    }
   ],
   "source": [
    "# First five rows of the training frame.\n",
    "# NOTE(review): an earlier cell already shows this same preview.\n",
    "df_train.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1ee3c924",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this rebinds `model` to a newly constructed YNet, replacing the instance\n",
    "# that model.train() was called on earlier in the notebook -- confirm this is intended.\n",
    "model = YNet(\n",
    "    obs_len=OBS_LEN,\n",
    "    pred_len=PRED_LEN,\n",
    "    params=params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9615468e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "                    Syncing run <strong><a href=\"https://wandb.ai/agv/ynet/runs/3ajw53j0\" target=\"_blank\">eternal-cosmos-9</a></strong> to <a href=\"https://wandb.ai/agv/ynet\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
       "\n",
       "                "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE(review): `weights_and_biases` is bound to the alias `wandb`, which is also the name\n",
    "# of the upstream Weights & Biases package -- presumably this is a project helper module.\n",
    "import weights_and_biases as wandb\n",
    "\n",
    "# Start experiment tracking for this run.\n",
    "wandb.init_wandb(\n",
    "    params.copy(),  # shallow copy of the hyperparameter dict\n",
    "    model.model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "076c9bd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `model` is constructed again here with the same arguments as before; the\n",
    "# new instance replaces the one whose `.model` was handed to init_wandb in the cell above.\n",
    "model = YNet(\n",
    "    obs_len=OBS_LEN,\n",
    "    pred_len=PRED_LEN,\n",
    "    params=params,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f7bb6e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this repeats the import and init_wandb call made two cells earlier, now\n",
    "# against the most recently constructed `model` -- confirm a second init is intended.\n",
    "import weights_and_biases as wandb\n",
    "\n",
    "wandb.init_wandb(\n",
    "    params.copy(),  # shallow copy of the hyperparameter dict\n",
    "    model.model,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
