{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "eth-pecnet-torch-no-pooling.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "IwqqqKRoOKd4"
      },
      "source": [
        "import os\n",
        "import math\n",
        "import random\n",
        "import pickle\n",
        "\n",
        "import yaml\n",
        "import numpy as np\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "\n",
        "from torch.utils import data\n",
        "from torch.autograd import Variable\n",
        "from torch.nn.utils import weight_norm\n",
        "from torch.distributions.normal import Normal"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dYDYJBYuqAPA"
      },
      "source": [
        "def from_npz_cache(file_path: str):\n",
        "  '''Load a cached ETH split from an ``.npz`` archive.\n",
        "\n",
        "  Returns the tuple (observations, obs_speed, targets, target_speed,\n",
        "  mean, std), read from the archive entries of the same names.\n",
        "  '''\n",
        "  # allow_pickle=True lets numpy unpickle object arrays; unpickling can run\n",
        "  # arbitrary code, so only point this at archives from a trusted source.\n",
        "  # The `with` block closes the archive's file handle; indexing the NpzFile\n",
        "  # reads each array fully, so the returned arrays stay valid afterwards.\n",
        "  with np.load(file_path, allow_pickle=True) as npz:\n",
        "    return (npz['observations'], npz['obs_speed'], npz['targets'],\n",
        "            npz['target_speed'], npz['mean'], npz['std'])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ruu6dVTuTTg6",
        "outputId": "1bcf3730-8cb3-4100-e3da-d50925e3ba80",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "def download_assets():\n",
        "  '''Download the PECNet config and the ETH train/test caches into /content.\n",
        "\n",
        "  A file is fetched only when it is missing, so re-running the cell is cheap.\n",
        "  Uses IPython `!` shell escapes, so it only runs inside Jupyter/Colab.\n",
        "  '''\n",
        "  if not os.path.exists('/content/optimal.yaml'):\n",
        "    !wget -P /content https://raw.githubusercontent.com/HarshayuGirase/PECNet/master/config/optimal.yaml\n",
        "\n",
        "  # Write to the same absolute path the existence check tests (not the\n",
        "  # current directory). `-r` is dropped: recursion is meaningless with -O\n",
        "  # and only produced a wget warning. Certificate checking stays on because\n",
        "  # these archives are later loaded with allow_pickle=True.\n",
        "  if not os.path.exists('/content/eth_test.npz'):\n",
        "    !wget 'https://docs.google.com/uc?export=download&id=1QvGR2cyduaO2kffcEtINlobS7cG5UBWp' -O /content/eth_test.npz\n",
        "\n",
        "  if not os.path.exists('/content/eth_train.npz'):\n",
        "    !wget 'https://docs.google.com/uc?export=download&id=14m6fQNddxsomfDQBmcNew06j-PVsOAsx' -O /content/eth_train.npz\n",
        "\n",
        "\n",
        "download_assets()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2020-10-07 22:04:25--  https://raw.githubusercontent.com/HarshayuGirase/PECNet/master/config/optimal.yaml\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 600 [text/plain]\n",
            "Saving to: ‘/content/optimal.yaml’\n",
            "\n",
            "optimal.yaml        100%[===================>]     600  --.-KB/s    in 0s      \n",
            "\n",
            "2020-10-07 22:04:26 (30.5 MB/s) - ‘/content/optimal.yaml’ saved [600/600]\n",
            "\n",
            "WARNING: combining -O with -r or -p will mean that all downloaded content\n",
            "will be placed in the single file you specified.\n",
            "\n",
            "--2020-10-07 22:04:26--  https://docs.google.com/uc?export=download&id=1QvGR2cyduaO2kffcEtINlobS7cG5UBWp\n",
            "Resolving docs.google.com (docs.google.com)... 74.125.195.102, 74.125.195.139, 74.125.195.113, ...\n",
            "Connecting to docs.google.com (docs.google.com)|74.125.195.102|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 Moved Temporarily\n",
            "Location: https://doc-08-a4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/3vt5o3abftsbv7csnv7nf6akllonrvqv/1602108225000/12025453477311729321/*/1QvGR2cyduaO2kffcEtINlobS7cG5UBWp?e=download [following]\n",
            "Warning: wildcards not supported in HTTP.\n",
            "--2020-10-07 22:04:26--  https://doc-08-a4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/3vt5o3abftsbv7csnv7nf6akllonrvqv/1602108225000/12025453477311729321/*/1QvGR2cyduaO2kffcEtINlobS7cG5UBWp?e=download\n",
            "Resolving doc-08-a4-docs.googleusercontent.com (doc-08-a4-docs.googleusercontent.com)... 172.253.117.132, 2607:f8b0:400e:c0a::84\n",
            "Connecting to doc-08-a4-docs.googleusercontent.com (doc-08-a4-docs.googleusercontent.com)|172.253.117.132|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 234500 (229K) [application/x-zip]\n",
            "Saving to: ‘eth_test.npz’\n",
            "\n",
            "eth_test.npz        100%[===================>] 229.00K  --.-KB/s    in 0.002s  \n",
            "\n",
            "2020-10-07 22:04:27 (99.9 MB/s) - ‘eth_test.npz’ saved [234500/234500]\n",
            "\n",
            "FINISHED --2020-10-07 22:04:27--\n",
            "Total wall clock time: 1.1s\n",
            "Downloaded: 1 files, 229K in 0.002s (99.9 MB/s)\n",
            "WARNING: combining -O with -r or -p will mean that all downloaded content\n",
            "will be placed in the single file you specified.\n",
            "\n",
            "--2020-10-07 22:04:27--  https://docs.google.com/uc?export=download&id=14m6fQNddxsomfDQBmcNew06j-PVsOAsx\n",
            "Resolving docs.google.com (docs.google.com)... 74.125.142.138, 74.125.142.101, 74.125.142.100, ...\n",
            "Connecting to docs.google.com (docs.google.com)|74.125.142.138|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 Moved Temporarily\n",
            "Location: https://doc-0o-a4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/5gkhok28gkm1eiddgtg5gub8f1kqcpgd/1602108225000/12025453477311729321/*/14m6fQNddxsomfDQBmcNew06j-PVsOAsx?e=download [following]\n",
            "Warning: wildcards not supported in HTTP.\n",
            "--2020-10-07 22:04:28--  https://doc-0o-a4-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/5gkhok28gkm1eiddgtg5gub8f1kqcpgd/1602108225000/12025453477311729321/*/14m6fQNddxsomfDQBmcNew06j-PVsOAsx?e=download\n",
            "Resolving doc-0o-a4-docs.googleusercontent.com (doc-0o-a4-docs.googleusercontent.com)... 172.253.117.132, 2607:f8b0:400e:c0a::84\n",
            "Connecting to doc-0o-a4-docs.googleusercontent.com (doc-0o-a4-docs.googleusercontent.com)|172.253.117.132|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: unspecified [application/x-zip]\n",
            "Saving to: ‘eth_train.npz’\n",
            "\n",
            "eth_train.npz           [  <=>               ]  18.50M  58.0MB/s    in 0.3s    \n",
            "\n",
            "2020-10-07 22:04:29 (58.0 MB/s) - ‘eth_train.npz’ saved [19398020]\n",
            "\n",
            "FINISHED --2020-10-07 22:04:29--\n",
            "Total wall clock time: 1.9s\n",
            "Downloaded: 1 files, 18M in 0.3s (58.0 MB/s)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W_oUNg8xS1un"
      },
      "source": [
        "# Run on the GPU when torch can see one, otherwise fall back to the CPU.\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LE074vugOlvW"
      },
      "source": [
        "## Model Definitions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e9JbEpzcOjo6"
      },
      "source": [
        "class MLP(nn.Module):\n",
        "  '''Stack of Linear layers with an activation between consecutive layers.\n",
        "\n",
        "  Args:\n",
        "      input_dim: size of the input features.\n",
        "      output_dim: size of the last layer's output.\n",
        "      hidden_size: sizes of the hidden layers, in order.\n",
        "      activation: 'relu' or 'sigmoid'; applied after every layer except the last.\n",
        "      discrim: if True, a sigmoid is applied to the last layer's output.\n",
        "      dropout: dropout probability after each hidden activation (-1 disables it);\n",
        "          after the second layer (i == 1) min(0.1, dropout/3) is used instead.\n",
        "\n",
        "  Raises:\n",
        "      ValueError: if `activation` is not one of the supported names.\n",
        "  '''\n",
        "  def __init__(self, input_dim, output_dim, hidden_size=(1024, 512), activation='relu', discrim=False, dropout=-1):\n",
        "    super(MLP, self).__init__()\n",
        "    dims = [input_dim, *hidden_size, output_dim]\n",
        "    self.layers = nn.ModuleList(\n",
        "        [nn.Linear(dims[i], dims[i+1]) for i in range(len(dims)-1)])\n",
        "\n",
        "    if activation == 'relu':\n",
        "        self.activation = nn.ReLU()\n",
        "    elif activation == 'sigmoid':\n",
        "        self.activation = nn.Sigmoid()\n",
        "    else:\n",
        "        # Fail fast: an unknown name would otherwise surface later as an\n",
        "        # AttributeError inside forward().\n",
        "        raise ValueError(\"unsupported activation: {!r}\".format(activation))\n",
        "\n",
        "    self.sigmoid = nn.Sigmoid() if discrim else None\n",
        "    self.dropout = dropout\n",
        "\n",
        "  def forward(self, x):\n",
        "    last = len(self.layers) - 1\n",
        "    for i, layer in enumerate(self.layers):\n",
        "        x = layer(x)\n",
        "        if i != last:\n",
        "            x = self.activation(x)\n",
        "            if self.dropout != -1:\n",
        "                p = min(0.1, self.dropout/3) if i == 1 else self.dropout\n",
        "                # F.dropout follows train()/eval(); a freshly built nn.Dropout\n",
        "                # module is always in training mode and would keep dropping\n",
        "                # activations during evaluation.\n",
        "                x = F.dropout(x, p=p, training=self.training)\n",
        "        elif self.sigmoid is not None:\n",
        "            x = self.sigmoid(x)\n",
        "    return x"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "N_1XDMj4Ot7r"
      },
      "source": [
        "class PECNet(nn.Module):\n",
        "  def __init__(self,\n",
        "               enc_past_size,\n",
        "               enc_dest_size,\n",
        "               enc_latent_size,\n",
        "               dec_size,\n",
        "               predictor_size,\n",
        "               fdim,\n",
        "               zdim,\n",
        "               sigma,\n",
        "               past_length,\n",
        "               future_length,\n",
        "               verbose):\n",
        "    '''\n",
        "    Args:\n",
        "        enc_past_size, enc_dest_size, enc_latent_size, dec_size,\n",
        "        predictor_size: hidden layer sizes of the corresponding MLPs\n",
        "        fdim: size of the past and destination feature vectors\n",
        "        zdim: size of the latent variable z\n",
        "        sigma: Standard deviation used for sampling N(0, sigma) at test time\n",
        "        past_length: Length of past history (number of timesteps)\n",
        "        future_length: Length of future trajectory to be predicted\n",
        "        verbose: if True, print the layer widths of every sub-network\n",
        "    '''\n",
        "    super(PECNet, self).__init__()\n",
        "\n",
        "    self.zdim = zdim\n",
        "    self.sigma = sigma\n",
        "\n",
        "    # flattened past (x, y, x, y, ...) -> features\n",
        "    self.encoder_past = MLP(input_dim = past_length*2, output_dim = fdim, hidden_size=enc_past_size)\n",
        "\n",
        "    # 2-d destination -> features\n",
        "    self.encoder_dest = MLP(input_dim = 2, output_dim = fdim, hidden_size=enc_dest_size)\n",
        "\n",
        "    # [past features, destination features] -> [mu, logvar] of z\n",
        "    self.encoder_latent = MLP(input_dim = 2*fdim, output_dim = 2*zdim, hidden_size=enc_latent_size)\n",
        "\n",
        "    # [past features, z] -> generated 2-d destination\n",
        "    self.decoder = MLP(input_dim = fdim + zdim, output_dim = 2, hidden_size=dec_size)\n",
        "\n",
        "    # [past features, destination features] -> the future points before the destination\n",
        "    self.predictor = MLP(input_dim = 2*fdim, output_dim = 2*(future_length-1), hidden_size=predictor_size)\n",
        "\n",
        "    architecture = lambda net: [l.in_features for l in net.layers] + [net.layers[-1].out_features]\n",
        "\n",
        "    if verbose:\n",
        "        print(\"Past Encoder architecture : {}\".format(architecture(self.encoder_past)))\n",
        "        print(\"Dest Encoder architecture : {}\".format(architecture(self.encoder_dest)))\n",
        "        print(\"Latent Encoder architecture : {}\".format(architecture(self.encoder_latent)))\n",
        "        print(\"Decoder architecture : {}\".format(architecture(self.decoder)))\n",
        "        print(\"Predictor architecture : {}\".format(architecture(self.predictor)))\n",
        "\n",
        "  def forward(self, x, dest = None, device=torch.device('cpu')):\n",
        "    '''Generate destinations; during training also return the CVAE terms.\n",
        "\n",
        "    Args:\n",
        "        x: flattened past trajectories, shape (batch, past_length*2)\n",
        "        dest: ground-truth destinations (batch, 2); required iff training\n",
        "        device: kept for backward compatibility. The latent noise is\n",
        "            created directly on the device and in the dtype of the encoded\n",
        "            past, so it always matches the model.\n",
        "\n",
        "    Returns:\n",
        "        training: (generated_dest, mu, logvar, pred_future)\n",
        "        eval: generated_dest\n",
        "    '''\n",
        "    # provide destination iff training\n",
        "    assert self.training ^ (dest is None)\n",
        "\n",
        "    # encode\n",
        "    ftraj = self.encoder_past(x)\n",
        "\n",
        "    if not self.training:\n",
        "        # sample z ~ N(0, sigma^2) on the model's device and dtype\n",
        "        z = torch.randn(x.size(0), self.zdim, dtype=ftraj.dtype, device=ftraj.device) * self.sigma\n",
        "\n",
        "    else:\n",
        "        # during training, use the destination to produce generated_dest and use it again to predict final future points\n",
        "\n",
        "        # CVAE code\n",
        "        dest_features = self.encoder_dest(dest)\n",
        "        features = torch.cat((ftraj, dest_features), dim = 1)\n",
        "        latent = self.encoder_latent(features)\n",
        "\n",
        "        mu = latent[:, 0:self.zdim] # 2-d array\n",
        "        logvar = latent[:, self.zdim:] # 2-d array\n",
        "\n",
        "        # reparameterisation trick: z = mu + std * eps, eps ~ N(0, I)\n",
        "        std = logvar.mul(0.5).exp_()\n",
        "        eps = torch.randn_like(std)\n",
        "        z = eps.mul(std).add_(mu)\n",
        "\n",
        "    decoder_input = torch.cat((ftraj, z), dim = 1)\n",
        "    generated_dest = self.decoder(decoder_input)\n",
        "\n",
        "    if self.training:\n",
        "        generated_dest_features = self.encoder_dest(generated_dest)\n",
        "\n",
        "        prediction_features = torch.cat((ftraj, generated_dest_features), dim = 1)\n",
        "\n",
        "        pred_future = self.predictor(prediction_features)\n",
        "        return generated_dest, mu, logvar, pred_future\n",
        "\n",
        "    return generated_dest\n",
        "\n",
        "  # separated from forward to let the caller choose the best destination\n",
        "  def predict(self, past, generated_dest):\n",
        "    '''Predict the future points that precede the destination.\n",
        "\n",
        "    Args:\n",
        "        past: flattened past trajectories, shape (batch, past_length*2)\n",
        "        generated_dest: chosen destinations, shape (batch, 2)\n",
        "\n",
        "    Returns:\n",
        "        Tensor of shape (batch, 2*(future_length-1)).\n",
        "    '''\n",
        "    ftraj = self.encoder_past(past)\n",
        "    generated_dest_features = self.encoder_dest(generated_dest)\n",
        "\n",
        "    prediction_features = torch.cat((ftraj, generated_dest_features), dim = 1)\n",
        "\n",
        "    interpolated_future = self.predictor(prediction_features)\n",
        "    return interpolated_future"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_I_gQft7Pa7j"
      },
      "source": [
        "## Loss"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Vg0a_KB_PDoz"
      },
      "source": [
        "def calculate_loss(x, reconstructed_x, mean, log_var, criterion, future, interpolated_future):\n",
        "  '''Return the PECNet loss terms (RCL_dest, KLD, ADL_traj).\n",
        "\n",
        "  RCL_dest is criterion(x, reconstructed_x) on the destinations, ADL_traj is\n",
        "  criterion(future, interpolated_future) on the remaining future points, and\n",
        "  KLD is the KL divergence of N(mean, exp(log_var)) from N(0, I), summed\n",
        "  over every element.\n",
        "  '''\n",
        "  # closed-form KL term, summed over batch and latent dimensions\n",
        "  kl_elements = 1 + log_var - mean.pow(2) - log_var.exp()\n",
        "  kld = -0.5 * kl_elements.sum()\n",
        "\n",
        "  # reconstruction losses\n",
        "  destination_loss = criterion(x, reconstructed_x)\n",
        "  trajectory_loss = criterion(future, interpolated_future)  # better with l2 loss\n",
        "\n",
        "  return destination_loss, kld, trajectory_loss"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zIS-8upZPWL0"
      },
      "source": [
        "## Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xXclR9kOPQr6"
      },
      "source": [
        "class SocialDataset(data.Dataset):\n",
        "  '''Full trajectories (observed steps followed by target steps) from an ETH cache.\n",
        "\n",
        "  Args:\n",
        "      npz_path: archive read with from_npz_cache.\n",
        "      set_name: split label; not used by this class.\n",
        "      id: not used by this class.\n",
        "      verbose: if True, print the shapes of the observed, target and\n",
        "          concatenated arrays.\n",
        "  '''\n",
        "  def __init__(self, npz_path, set_name=\"train\", id=False, verbose=True):\n",
        "    observations, _, targets, _, _, _ = from_npz_cache(npz_path)\n",
        "    # concatenate the observed and target steps along axis 1\n",
        "    self.traj = np.concatenate([observations, targets], axis=1)\n",
        "    if verbose:\n",
        "      # matches the shape printout recorded in this notebook's training output\n",
        "      print(observations.shape)\n",
        "      print(targets.shape)\n",
        "      print(self.traj.shape)\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.traj)\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    return self.traj[idx]\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iMM0iIL1UkuP"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZU3HNOdDSsYB"
      },
      "source": [
        "def load_hyper_parameters(file_name='optimal.yaml'):\n",
        "  '''Parse a hyper-parameter YAML file located in /content.\n",
        "\n",
        "  Args:\n",
        "      file_name: name of the file inside /content.\n",
        "\n",
        "  Returns:\n",
        "      The parsed YAML document (the cells below index it like a dict).\n",
        "  '''\n",
        "  with open(os.path.join('/content', file_name), 'r') as file:\n",
        "    # yaml.load without an explicit Loader warns since PyYAML 5.1 and raises\n",
        "    # TypeError in PyYAML 6; safe_load only builds plain YAML types.\n",
        "    hyper_params = yaml.safe_load(file)\n",
        "\n",
        "  return hyper_params"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bpl-Tpo1YBAU"
      },
      "source": [
        "# Settings from /content/optimal.yaml; later cells read and override entries.\n",
        "hyper_params = load_hyper_parameters(file_name='optimal.yaml')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FNRYRrWFVD1J"
      },
      "source": [
        "def train(train_dataset, model, optimizer, batch_size=100):\n",
        "  '''Run one training epoch and return the summed losses.\n",
        "\n",
        "  Reads the notebook-level `hyper_params` ('data_scale', 'past_length',\n",
        "  'kld_reg', 'adl_reg') and `device`.\n",
        "\n",
        "  Args:\n",
        "      train_dataset: dataset whose items are full trajectories, past steps\n",
        "          first along dim 1.\n",
        "      model: PECNet instance (batches are sent to it as float64).\n",
        "      optimizer: optimizer over the model's parameters.\n",
        "      batch_size: trajectories per optimisation step.\n",
        "\n",
        "  Returns:\n",
        "      (train_loss, total_rcl, total_kld, total_adl), each summed over batches.\n",
        "  '''\n",
        "  dataloader = data.DataLoader(\n",
        "      train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
        "\n",
        "  model.train()\n",
        "  train_loss = 0\n",
        "  total_rcl, total_kld, total_adl = 0, 0, 0\n",
        "  criterion = nn.MSELoss()\n",
        "  past_length = hyper_params['past_length']\n",
        "\n",
        "  for trajx in dataloader:\n",
        "    # shift every trajectory so its first point is the origin, then rescale\n",
        "    traj = (trajx - trajx[:, :1, :]) * hyper_params[\"data_scale\"]\n",
        "    # .to() converts from any input dtype; the legacy torch.DoubleTensor(t)\n",
        "    # constructor raises TypeError unless t is already float64.\n",
        "    traj = traj.to(device=device, dtype=torch.float64)\n",
        "    x = traj[:, :past_length, :]\n",
        "    y = traj[:, past_length:, :]\n",
        "\n",
        "    x = x.view(-1, x.shape[1]*x.shape[2]) # (x,y,x,y ... )\n",
        "    dest = y[:, -1, :]\n",
        "    future = y[:, :-1, :].view(y.size(0), -1)\n",
        "\n",
        "    dest_recon, mu, logvar, interpolated_future = model.forward(x, dest=dest, device=device)\n",
        "\n",
        "    optimizer.zero_grad()\n",
        "    rcl, kld, adl = calculate_loss(dest, dest_recon, mu, logvar, criterion, future, interpolated_future)\n",
        "    loss = rcl + kld*hyper_params[\"kld_reg\"] + adl*hyper_params[\"adl_reg\"]\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    train_loss += loss.item()\n",
        "    total_rcl += rcl.item()\n",
        "    total_kld += kld.item()\n",
        "    total_adl += adl.item()\n",
        "\n",
        "  return train_loss, total_rcl, total_kld, total_adl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zlB0EaGZVO21"
      },
      "source": [
        "def test(test_dataset, model, best_of_n = 1):\n",
        "  '''Evaluate test metrics. Assumes all test data is in one batch.\n",
        "\n",
        "  For every trajectory `best_of_n` destinations are sampled; the one closest\n",
        "  to the true endpoint is used to predict the intermediate points. Errors\n",
        "  are divided by hyper_params[\"data_scale\"] to undo the input scaling.\n",
        "  Reads the notebook-level `hyper_params` and `device`.\n",
        "\n",
        "  Returns:\n",
        "      (l2error_overall, l2error_dest, l2error_avg_dest): best-of-N ADE,\n",
        "      best-of-N destination error, and destination error averaged over\n",
        "      all N samples.\n",
        "\n",
        "  Raises:\n",
        "      ValueError: if test_dataset is empty.\n",
        "  '''\n",
        "  # validate before touching the model, so a bad call leaves it unchanged\n",
        "  assert type(best_of_n) is int and best_of_n >= 1\n",
        "  if len(test_dataset) == 0:\n",
        "    raise ValueError('test_dataset is empty')\n",
        "\n",
        "  dataloader = data.DataLoader(\n",
        "      test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=0)\n",
        "\n",
        "  model.eval()\n",
        "  data_scale = hyper_params[\"data_scale\"]\n",
        "  past_length = hyper_params['past_length']\n",
        "\n",
        "  with torch.no_grad():\n",
        "    for trajx in dataloader:\n",
        "      # shift every trajectory so its first point is the origin, then rescale\n",
        "      traj = (trajx - trajx[:, :1, :]) * data_scale\n",
        "      # .to() converts from any input dtype; the legacy torch.DoubleTensor(t)\n",
        "      # constructor raises TypeError unless t is already float64.\n",
        "      traj = traj.to(device=device, dtype=torch.float64)\n",
        "      x = traj[:, :past_length, :]\n",
        "      y = traj[:, past_length:, :].cpu().numpy()\n",
        "\n",
        "      # flatten the past to (x, y, x, y, ...)\n",
        "      x = x.view(-1, x.shape[1]*x.shape[2])\n",
        "\n",
        "      dest = y[:, -1, :]\n",
        "      all_l2_errors_dest = []\n",
        "      all_guesses = []\n",
        "      for _ in range(best_of_n):\n",
        "        dest_recon = model.forward(x, device=device).cpu().numpy()\n",
        "        all_guesses.append(dest_recon)\n",
        "        all_l2_errors_dest.append(np.linalg.norm(dest_recon - dest, axis = 1))\n",
        "\n",
        "      all_l2_errors_dest = np.array(all_l2_errors_dest) # (best_of_n, batch)\n",
        "      all_guesses = np.array(all_guesses) # (best_of_n, batch, 2)\n",
        "      # average error over all samples\n",
        "      l2error_avg_dest = np.mean(all_l2_errors_dest)\n",
        "\n",
        "      # per trajectory, pick the sample closest to the true destination\n",
        "      indices = np.argmin(all_l2_errors_dest, axis = 0)\n",
        "      best_guess_dest = all_guesses[indices, np.arange(x.shape[0]), :]\n",
        "\n",
        "      # taking the minimum error out of all guesses\n",
        "      l2error_dest = np.mean(np.min(all_l2_errors_dest, axis = 0))\n",
        "\n",
        "      best_guess_tensor = torch.from_numpy(best_guess_dest).to(device=device, dtype=torch.float64)\n",
        "      interpolated_future = model.predict(x, best_guess_tensor).cpu().numpy()\n",
        "\n",
        "      # final overall prediction: intermediate points, then the destination\n",
        "      predicted_future = np.concatenate((interpolated_future, best_guess_dest), axis = 1)\n",
        "      predicted_future = np.reshape(predicted_future, (-1, hyper_params['future_length'], 2)) # making sure\n",
        "      # ADE error\n",
        "      l2error_overall = np.mean(np.linalg.norm(y - predicted_future, axis = 2))\n",
        "\n",
        "      l2error_overall /= data_scale\n",
        "      l2error_dest /= data_scale\n",
        "      l2error_avg_dest /= data_scale\n",
        "\n",
        "      print('Test time error in destination best: {:0.3f} and mean: {:0.3f}'.format(l2error_dest, l2error_avg_dest))\n",
        "      print('Test time error overall (ADE) best: {:0.3f}'.format(l2error_overall))\n",
        "\n",
        "  return l2error_overall, l2error_dest, l2error_avg_dest"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "i1yVNgNXVSG4"
      },
      "source": [
        "def run_training():\n",
        "  '''Train PECNet on ETH and evaluate on the test split after every epoch.\n",
        "\n",
        "  Reads the notebook-level `hyper_params` and `device`. Whenever the test\n",
        "  ADE beats the best so far and is below 10.25, a checkpoint (hyper-params,\n",
        "  model and optimizer state) is written to /content/trained.pt.\n",
        "  '''\n",
        "  # positional PECNet arguments, in constructor order\n",
        "  ctor_keys = (\"enc_past_size\", \"enc_dest_size\", \"enc_latent_size\", \"dec_size\",\n",
        "               \"predictor_hidden_size\", \"fdim\", \"zdim\", \"sigma\",\n",
        "               \"past_length\", \"future_length\")\n",
        "  model = PECNet(*[hyper_params[key] for key in ctor_keys], verbose=True)\n",
        "  model = model.double().to(device)\n",
        "  optimizer = optim.Adam(model.parameters(), lr=hyper_params[\"learning_rate\"])\n",
        "\n",
        "  train_dataset = SocialDataset('/content/eth_train.npz', set_name=\"train\", verbose=True)\n",
        "  test_dataset = SocialDataset('/content/eth_test.npz', set_name=\"test\", verbose=True)\n",
        "\n",
        "  N = hyper_params[\"n_values\"]\n",
        "  save_path = '/content/trained.pt'\n",
        "  # initial \"best\" values; only results below them count as improvements\n",
        "  best_test_loss = 50\n",
        "  best_endpoint_loss = 50\n",
        "\n",
        "  for e in range(hyper_params['num_epochs']):\n",
        "    train_loss, rcl, kld, adl = train(train_dataset, model, optimizer)\n",
        "    test_loss, final_point_loss_best, final_point_loss_avg = test(test_dataset, model, best_of_n=N)\n",
        "\n",
        "    if test_loss < best_test_loss:\n",
        "      print(\"Epoch: \", e)\n",
        "      print('################## BEST PERFORMANCE {:0.2f} ########'.format(test_loss))\n",
        "      best_test_loss = test_loss\n",
        "      if best_test_loss < 10.25:\n",
        "        checkpoint = {\n",
        "            'hyper_params': hyper_params,\n",
        "            'model_state_dict': model.state_dict(),\n",
        "            'optimizer_state_dict': optimizer.state_dict(),\n",
        "        }\n",
        "        torch.save(checkpoint, save_path)\n",
        "        print(\"Saved model to:\\n{}\".format(save_path))\n",
        "\n",
        "    best_endpoint_loss = min(best_endpoint_loss, final_point_loss_best)\n",
        "\n",
        "    summary = [\n",
        "        (\"Train Loss\", train_loss),\n",
        "        (\"RCL\", rcl),\n",
        "        (\"KLD\", kld),\n",
        "        (\"ADL\", adl),\n",
        "        (\"Test ADE\", test_loss),\n",
        "        (\"Test Average FDE (Across  all samples)\", final_point_loss_avg),\n",
        "        (\"Test Min FDE\", final_point_loss_best),\n",
        "        (\"Test Best ADE Loss So Far (N = {})\".format(N), best_test_loss),\n",
        "        (\"Test Best Min FDE (N = {})\".format(N), best_endpoint_loss),\n",
        "    ]\n",
        "    for label, value in summary:\n",
        "      print(label, value)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8AY1gRGjWC1F",
        "outputId": "8e940704-3420-4afe-8f9c-540c3aa2d50c",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "# Overrides applied on top of optimal.yaml for this run.\n",
        "hyper_params.update(num_epochs=100, data_scale=170.)\n",
        "\n",
        "run_training()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Past Encoder architecture : [16, 512, 256, 16]\n",
            "Dest Encoder architecture : [2, 8, 16, 16]\n",
            "Latent Encoder architecture : [32, 8, 50, 32]\n",
            "Decoder architecture : [32, 1024, 512, 1024, 2]\n",
            "Predictor architecture : [32, 1024, 512, 256, 22]\n",
            "(30307, 8, 2)\n",
            "(30307, 12, 2)\n",
            "(30307, 20, 2)\n",
            "(364, 8, 2)\n",
            "(364, 12, 2)\n",
            "(364, 20, 2)\n",
            "Test time error in destination best: 1.859 and mean: 2.044\n",
            "Test time error overall (ADE) best: 1.012\n",
            "Epoch:  0\n",
            "################## BEST PERFORMANCE 1.01 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 34612243.75906248\n",
            "RCL 21301493.53796042\n",
            "KLD 2653148.7481085635\n",
            "ADL 10657601.472993527\n",
            "Test ADE 1.0117048493366219\n",
            "Test Average FDE (Across  all samples) 2.044313183439747\n",
            "Test Min FDE 1.8589745748486588\n",
            "Test Best ADE Loss So Far (N = 20) 1.0117048493366219\n",
            "Test Best Min FDE (N = 20) 1.8589745748486588\n",
            "Test time error in destination best: 1.464 and mean: 1.886\n",
            "Test time error overall (ADE) best: 0.829\n",
            "Epoch:  1\n",
            "################## BEST PERFORMANCE 0.83 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 6829166.328049693\n",
            "RCL 4697129.545153512\n",
            "KLD 434069.1561092482\n",
            "ADL 1697967.6267869268\n",
            "Test ADE 0.8287784019866089\n",
            "Test Average FDE (Across  all samples) 1.88639762741892\n",
            "Test Min FDE 1.464003211263869\n",
            "Test Best ADE Loss So Far (N = 20) 0.8287784019866089\n",
            "Test Best Min FDE (N = 20) 1.464003211263869\n",
            "Test time error in destination best: 1.607 and mean: 2.021\n",
            "Test time error overall (ADE) best: 0.885\n",
            "Train Loss 5286333.055170444\n",
            "RCL 3642328.3708912986\n",
            "KLD 383957.7018769727\n",
            "ADL 1260046.9824021705\n",
            "Test ADE 0.8853607849523882\n",
            "Test Average FDE (Across  all samples) 2.020894551979837\n",
            "Test Min FDE 1.6073888104677805\n",
            "Test Best ADE Loss So Far (N = 20) 0.8287784019866089\n",
            "Test Best Min FDE (N = 20) 1.464003211263869\n",
            "Test time error in destination best: 1.566 and mean: 2.038\n",
            "Test time error overall (ADE) best: 0.883\n",
            "Train Loss 4774117.820770824\n",
            "RCL 3264692.033918683\n",
            "KLD 377591.47991671815\n",
            "ADL 1131834.306935424\n",
            "Test ADE 0.8826310795270104\n",
            "Test Average FDE (Across  all samples) 2.038487598598257\n",
            "Test Min FDE 1.5664593783966796\n",
            "Test Best ADE Loss So Far (N = 20) 0.8287784019866089\n",
            "Test Best Min FDE (N = 20) 1.464003211263869\n",
            "Test time error in destination best: 1.527 and mean: 2.092\n",
            "Test time error overall (ADE) best: 0.819\n",
            "Epoch:  4\n",
            "################## BEST PERFORMANCE 0.82 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 3813784.7211804283\n",
            "RCL 2417362.754210177\n",
            "KLD 448121.55062388565\n",
            "ADL 948300.4163463645\n",
            "Test ADE 0.8185424965400152\n",
            "Test Average FDE (Across  all samples) 2.092072066565918\n",
            "Test Min FDE 1.5268001268593747\n",
            "Test Best ADE Loss So Far (N = 20) 0.8185424965400152\n",
            "Test Best Min FDE (N = 20) 1.464003211263869\n",
            "Test time error in destination best: 1.697 and mean: 2.290\n",
            "Test time error overall (ADE) best: 0.943\n",
            "Train Loss 3353044.089055967\n",
            "RCL 2058205.754061189\n",
            "KLD 455665.86602485913\n",
            "ADL 839172.4689699194\n",
            "Test ADE 0.9434719506193274\n",
            "Test Average FDE (Across  all samples) 2.2896389626900495\n",
            "Test Min FDE 1.6970373044165616\n",
            "Test Best ADE Loss So Far (N = 20) 0.8185424965400152\n",
            "Test Best Min FDE (N = 20) 1.464003211263869\n",
            "Test time error in destination best: 1.549 and mean: 2.224\n",
            "Test time error overall (ADE) best: 0.949\n",
            "Train Loss 2953473.5488613397\n",
            "RCL 1693540.453335677\n",
            "KLD 483222.63240626367\n",
            "ADL 776710.4631193954\n",
            "Test ADE 0.9491415880419456\n",
            "Test Average FDE (Across  all samples) 2.2238131034493827\n",
            "Test Min FDE 1.548727914087897\n",
            "Test Best ADE Loss So Far (N = 20) 0.8185424965400152\n",
            "Test Best Min FDE (N = 20) 1.464003211263869\n",
            "Test time error in destination best: 1.640 and mean: 2.309\n",
            "Test time error overall (ADE) best: 0.858\n",
            "Train Loss 2025661.6162999535\n",
            "RCL 977226.8300401097\n",
            "KLD 499074.6800190267\n",
            "ADL 549360.1062408183\n",
            "Test ADE 0.8584181311711119\n",
            "Test Average FDE (Across  all samples) 2.309423465277844\n",
            "Test Min FDE 1.6404261838801413\n",
            "Test Best ADE Loss So Far (N = 20) 0.8185424965400152\n",
            "Test Best Min FDE (N = 20) 1.464003211263869\n",
            "Test time error in destination best: 1.602 and mean: 2.316\n",
            "Test time error overall (ADE) best: 0.850\n",
            "Train Loss 1559620.9402833525\n",
            "RCL 649672.7513337175\n",
            "KLD 487355.5675335417\n",
            "ADL 422592.62141609326\n",
            "Test ADE 0.8504572415658632\n",
            "Test Average FDE (Across  all samples) 2.315639500566094\n",
            "Test Min FDE 1.6015506294911153\n",
            "Test Best ADE Loss So Far (N = 20) 0.8185424965400152\n",
            "Test Best Min FDE (N = 20) 1.464003211263869\n",
            "Test time error in destination best: 1.363 and mean: 2.124\n",
            "Test time error overall (ADE) best: 0.775\n",
            "Epoch:  9\n",
            "################## BEST PERFORMANCE 0.77 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 1258870.7425069017\n",
            "RCL 433809.9402925813\n",
            "KLD 476379.032162941\n",
            "ADL 348681.77005137934\n",
            "Test ADE 0.7746172842006712\n",
            "Test Average FDE (Across  all samples) 2.12351395109288\n",
            "Test Min FDE 1.362651485883445\n",
            "Test Best ADE Loss So Far (N = 20) 0.7746172842006712\n",
            "Test Best Min FDE (N = 20) 1.362651485883445\n",
            "Test time error in destination best: 1.557 and mean: 2.297\n",
            "Test time error overall (ADE) best: 0.788\n",
            "Train Loss 1109109.0115557453\n",
            "RCL 346461.3759841955\n",
            "KLD 455995.38615090406\n",
            "ADL 306652.24942064594\n",
            "Test ADE 0.7879987386433679\n",
            "Test Average FDE (Across  all samples) 2.296530305136194\n",
            "Test Min FDE 1.557201470145007\n",
            "Test Best ADE Loss So Far (N = 20) 0.7746172842006712\n",
            "Test Best Min FDE (N = 20) 1.362651485883445\n",
            "Test time error in destination best: 1.392 and mean: 2.166\n",
            "Test time error overall (ADE) best: 0.745\n",
            "Epoch:  11\n",
            "################## BEST PERFORMANCE 0.74 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 1018916.0803336541\n",
            "RCL 295331.2358810173\n",
            "KLD 437964.5204155948\n",
            "ADL 285620.32403704204\n",
            "Test ADE 0.744641534502008\n",
            "Test Average FDE (Across  all samples) 2.1662520354764205\n",
            "Test Min FDE 1.3916874541712263\n",
            "Test Best ADE Loss So Far (N = 20) 0.744641534502008\n",
            "Test Best Min FDE (N = 20) 1.362651485883445\n",
            "Test time error in destination best: 1.396 and mean: 2.207\n",
            "Test time error overall (ADE) best: 0.799\n",
            "Train Loss 953868.4284604302\n",
            "RCL 258929.16254007525\n",
            "KLD 426713.91070747626\n",
            "ADL 268225.35521287785\n",
            "Test ADE 0.7989802115632846\n",
            "Test Average FDE (Across  all samples) 2.2065032858434175\n",
            "Test Min FDE 1.396389659217013\n",
            "Test Best ADE Loss So Far (N = 20) 0.744641534502008\n",
            "Test Best Min FDE (N = 20) 1.362651485883445\n",
            "Test time error in destination best: 1.404 and mean: 2.211\n",
            "Test time error overall (ADE) best: 0.756\n",
            "Train Loss 942145.4598648218\n",
            "RCL 255573.42360680213\n",
            "KLD 414284.518070493\n",
            "ADL 272287.51818752656\n",
            "Test ADE 0.7559791943316725\n",
            "Test Average FDE (Across  all samples) 2.2113660863420006\n",
            "Test Min FDE 1.4037430652824885\n",
            "Test Best ADE Loss So Far (N = 20) 0.744641534502008\n",
            "Test Best Min FDE (N = 20) 1.362651485883445\n",
            "Test time error in destination best: 1.361 and mean: 2.196\n",
            "Test time error overall (ADE) best: 0.814\n",
            "Train Loss 908232.2274725782\n",
            "RCL 239483.71910799763\n",
            "KLD 401887.1260462454\n",
            "ADL 266861.382318335\n",
            "Test ADE 0.8143356076008699\n",
            "Test Average FDE (Across  all samples) 2.19599488176107\n",
            "Test Min FDE 1.3610738596426468\n",
            "Test Best ADE Loss So Far (N = 20) 0.744641534502008\n",
            "Test Best Min FDE (N = 20) 1.3610738596426468\n",
            "Test time error in destination best: 1.357 and mean: 2.196\n",
            "Test time error overall (ADE) best: 0.736\n",
            "Epoch:  15\n",
            "################## BEST PERFORMANCE 0.74 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 856228.4105304981\n",
            "RCL 213785.28330043718\n",
            "KLD 390020.2776365307\n",
            "ADL 252422.84959352957\n",
            "Test ADE 0.7358210743614584\n",
            "Test Average FDE (Across  all samples) 2.1961671111004506\n",
            "Test Min FDE 1.3565381411802209\n",
            "Test Best ADE Loss So Far (N = 20) 0.7358210743614584\n",
            "Test Best Min FDE (N = 20) 1.3565381411802209\n",
            "Test time error in destination best: 1.242 and mean: 2.122\n",
            "Test time error overall (ADE) best: 0.706\n",
            "Epoch:  16\n",
            "################## BEST PERFORMANCE 0.71 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 834327.6043315968\n",
            "RCL 204230.8932745334\n",
            "KLD 380671.6408970613\n",
            "ADL 249425.07016000233\n",
            "Test ADE 0.7063442028401266\n",
            "Test Average FDE (Across  all samples) 2.12177448610007\n",
            "Test Min FDE 1.2424826531108946\n",
            "Test Best ADE Loss So Far (N = 20) 0.7063442028401266\n",
            "Test Best Min FDE (N = 20) 1.2424826531108946\n",
            "Test time error in destination best: 1.403 and mean: 2.271\n",
            "Test time error overall (ADE) best: 0.790\n",
            "Train Loss 829709.8479268287\n",
            "RCL 203492.67383070008\n",
            "KLD 371443.45176569215\n",
            "ADL 254773.72233043698\n",
            "Test ADE 0.7900625972676825\n",
            "Test Average FDE (Across  all samples) 2.2711559730680877\n",
            "Test Min FDE 1.4033248242595344\n",
            "Test Best ADE Loss So Far (N = 20) 0.7063442028401266\n",
            "Test Best Min FDE (N = 20) 1.2424826531108946\n",
            "Test time error in destination best: 1.333 and mean: 2.212\n",
            "Test time error overall (ADE) best: 0.777\n",
            "Train Loss 815024.2351593054\n",
            "RCL 194255.39380555\n",
            "KLD 367853.4451993647\n",
            "ADL 252915.39615439103\n",
            "Test ADE 0.7771789846056808\n",
            "Test Average FDE (Across  all samples) 2.2116281826657227\n",
            "Test Min FDE 1.3332934084366728\n",
            "Test Best ADE Loss So Far (N = 20) 0.7063442028401266\n",
            "Test Best Min FDE (N = 20) 1.2424826531108946\n",
            "Test time error in destination best: 1.225 and mean: 2.157\n",
            "Test time error overall (ADE) best: 0.775\n",
            "Train Loss 786110.8729962491\n",
            "RCL 182278.3319758631\n",
            "KLD 357787.92464859434\n",
            "ADL 246044.61637179196\n",
            "Test ADE 0.7747127157017116\n",
            "Test Average FDE (Across  all samples) 2.1566299512382403\n",
            "Test Min FDE 1.225401606987127\n",
            "Test Best ADE Loss So Far (N = 20) 0.7063442028401266\n",
            "Test Best Min FDE (N = 20) 1.225401606987127\n",
            "Test time error in destination best: 1.146 and mean: 2.074\n",
            "Test time error overall (ADE) best: 0.684\n",
            "Epoch:  20\n",
            "################## BEST PERFORMANCE 0.68 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 759218.7798558826\n",
            "RCL 172257.35700443518\n",
            "KLD 347415.3255071512\n",
            "ADL 239546.0973442959\n",
            "Test ADE 0.6844007368371239\n",
            "Test Average FDE (Across  all samples) 2.073530146099513\n",
            "Test Min FDE 1.1457466002174401\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.1457466002174401\n",
            "Test time error in destination best: 1.276 and mean: 2.237\n",
            "Test time error overall (ADE) best: 0.829\n",
            "Train Loss 733611.9032964308\n",
            "RCL 159337.34072973597\n",
            "KLD 339592.7222124669\n",
            "ADL 234681.84035422798\n",
            "Test ADE 0.8293735137415932\n",
            "Test Average FDE (Across  all samples) 2.23678260753382\n",
            "Test Min FDE 1.2755660716227746\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.1457466002174401\n",
            "Test time error in destination best: 1.296 and mean: 2.218\n",
            "Test time error overall (ADE) best: 0.716\n",
            "Train Loss 759635.9053899152\n",
            "RCL 173324.21456489884\n",
            "KLD 335137.44461665524\n",
            "ADL 251174.2462083619\n",
            "Test ADE 0.7163618738909483\n",
            "Test Average FDE (Across  all samples) 2.218440028030651\n",
            "Test Min FDE 1.2962841625199286\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.1457466002174401\n",
            "Test time error in destination best: 1.212 and mean: 2.181\n",
            "Test time error overall (ADE) best: 0.709\n",
            "Train Loss 707155.213838704\n",
            "RCL 150287.21111635413\n",
            "KLD 324566.24588295014\n",
            "ADL 232301.75683940007\n",
            "Test ADE 0.709347452674737\n",
            "Test Average FDE (Across  all samples) 2.1813392493305135\n",
            "Test Min FDE 1.2121075932922707\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.1457466002174401\n",
            "Test time error in destination best: 1.447 and mean: 2.463\n",
            "Test time error overall (ADE) best: 0.914\n",
            "Train Loss 708322.8068023149\n",
            "RCL 151156.7564668539\n",
            "KLD 319272.57207511034\n",
            "ADL 237893.47826035033\n",
            "Test ADE 0.9142161896279299\n",
            "Test Average FDE (Across  all samples) 2.462630707524586\n",
            "Test Min FDE 1.4471845813046083\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.1457466002174401\n",
            "Test time error in destination best: 1.208 and mean: 2.223\n",
            "Test time error overall (ADE) best: 0.698\n",
            "Train Loss 690053.7564077784\n",
            "RCL 143153.20444268725\n",
            "KLD 312400.48710784636\n",
            "ADL 234500.0648572446\n",
            "Test ADE 0.6978944501032603\n",
            "Test Average FDE (Across  all samples) 2.222675983949705\n",
            "Test Min FDE 1.2076326677501663\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.1457466002174401\n",
            "Test time error in destination best: 1.549 and mean: 2.569\n",
            "Test time error overall (ADE) best: 0.875\n",
            "Train Loss 661681.0436436591\n",
            "RCL 133882.34773646915\n",
            "KLD 305561.90775067755\n",
            "ADL 222236.78815651216\n",
            "Test ADE 0.8748006942677825\n",
            "Test Average FDE (Across  all samples) 2.5686722485109845\n",
            "Test Min FDE 1.549146913039662\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.1457466002174401\n",
            "Test time error in destination best: 1.042 and mean: 2.109\n",
            "Test time error overall (ADE) best: 0.741\n",
            "Train Loss 660271.2681267016\n",
            "RCL 135270.6561986829\n",
            "KLD 296643.2979023845\n",
            "ADL 228357.31402563376\n",
            "Test ADE 0.7412580702110079\n",
            "Test Average FDE (Across  all samples) 2.109303828802836\n",
            "Test Min FDE 1.041939478950858\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.041939478950858\n",
            "Test time error in destination best: 1.171 and mean: 2.253\n",
            "Test time error overall (ADE) best: 0.729\n",
            "Train Loss 643111.1358848532\n",
            "RCL 128930.82389641553\n",
            "KLD 290841.39780066896\n",
            "ADL 223338.91418776903\n",
            "Test ADE 0.7285333324941398\n",
            "Test Average FDE (Across  all samples) 2.2526964358839945\n",
            "Test Min FDE 1.1706928047227871\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.041939478950858\n",
            "Test time error in destination best: 1.263 and mean: 2.312\n",
            "Test time error overall (ADE) best: 0.719\n",
            "Train Loss 630566.8145391985\n",
            "RCL 122754.01639184519\n",
            "KLD 283813.2950188866\n",
            "ADL 223999.50312846675\n",
            "Test ADE 0.7190955445972684\n",
            "Test Average FDE (Across  all samples) 2.3116398844576107\n",
            "Test Min FDE 1.263353757964243\n",
            "Test Best ADE Loss So Far (N = 20) 0.6844007368371239\n",
            "Test Best Min FDE (N = 20) 1.041939478950858\n",
            "Test time error in destination best: 1.065 and mean: 2.168\n",
            "Test time error overall (ADE) best: 0.680\n",
            "Epoch:  30\n",
            "################## BEST PERFORMANCE 0.68 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 611244.5387398099\n",
            "RCL 116606.73549684537\n",
            "KLD 276814.7765299411\n",
            "ADL 217823.02671302366\n",
            "Test ADE 0.6795406544390876\n",
            "Test Average FDE (Across  all samples) 2.168078340076661\n",
            "Test Min FDE 1.064521136078038\n",
            "Test Best ADE Loss So Far (N = 20) 0.6795406544390876\n",
            "Test Best Min FDE (N = 20) 1.041939478950858\n",
            "Test time error in destination best: 1.040 and mean: 2.156\n",
            "Test time error overall (ADE) best: 0.644\n",
            "Epoch:  31\n",
            "################## BEST PERFORMANCE 0.64 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 608600.5870425225\n",
            "RCL 115881.1964499542\n",
            "KLD 271892.54227941914\n",
            "ADL 220826.8483131492\n",
            "Test ADE 0.6440782506268214\n",
            "Test Average FDE (Across  all samples) 2.1562077097742396\n",
            "Test Min FDE 1.0397219162855094\n",
            "Test Best ADE Loss So Far (N = 20) 0.6440782506268214\n",
            "Test Best Min FDE (N = 20) 1.0397219162855094\n",
            "Test time error in destination best: 1.111 and mean: 2.212\n",
            "Test time error overall (ADE) best: 0.687\n",
            "Train Loss 586227.954921314\n",
            "RCL 108369.42871968073\n",
            "KLD 263934.3877449168\n",
            "ADL 213924.13845671684\n",
            "Test ADE 0.6873624252981458\n",
            "Test Average FDE (Across  all samples) 2.2118185305895706\n",
            "Test Min FDE 1.110725913330438\n",
            "Test Best ADE Loss So Far (N = 20) 0.6440782506268214\n",
            "Test Best Min FDE (N = 20) 1.0397219162855094\n",
            "Test time error in destination best: 1.047 and mean: 2.185\n",
            "Test time error overall (ADE) best: 0.643\n",
            "Epoch:  33\n",
            "################## BEST PERFORMANCE 0.64 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 580164.6114811627\n",
            "RCL 104877.6892588582\n",
            "KLD 260476.89299489223\n",
            "ADL 214810.02922741193\n",
            "Test ADE 0.6428865447843487\n",
            "Test Average FDE (Across  all samples) 2.1849753646848518\n",
            "Test Min FDE 1.0472493203424398\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0397219162855094\n",
            "Test time error in destination best: 1.110 and mean: 2.267\n",
            "Test time error overall (ADE) best: 0.666\n",
            "Train Loss 567591.6041440457\n",
            "RCL 102297.38750522435\n",
            "KLD 253871.4247049639\n",
            "ADL 211422.79193385757\n",
            "Test ADE 0.6659222368838809\n",
            "Test Average FDE (Across  all samples) 2.26664231228371\n",
            "Test Min FDE 1.110021392470709\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0397219162855094\n",
            "Test time error in destination best: 1.039 and mean: 2.224\n",
            "Test time error overall (ADE) best: 0.701\n",
            "Train Loss 577341.5245906383\n",
            "RCL 109095.16184342183\n",
            "KLD 249606.98796021088\n",
            "ADL 218639.37478700545\n",
            "Test ADE 0.70110232972305\n",
            "Test Average FDE (Across  all samples) 2.2237643827518006\n",
            "Test Min FDE 1.0389957242433674\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0389957242433674\n",
            "Test time error in destination best: 1.093 and mean: 2.280\n",
            "Test time error overall (ADE) best: 0.671\n",
            "Train Loss 546004.4836786675\n",
            "RCL 96322.77001868935\n",
            "KLD 244769.44834114486\n",
            "ADL 204912.2653188323\n",
            "Test ADE 0.670587753376003\n",
            "Test Average FDE (Across  all samples) 2.280354352166648\n",
            "Test Min FDE 1.0933125324814714\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0389957242433674\n",
            "Test time error in destination best: 1.039 and mean: 2.236\n",
            "Test time error overall (ADE) best: 0.666\n",
            "Train Loss 548738.820821487\n",
            "RCL 99562.69157534775\n",
            "KLD 239773.33418249493\n",
            "ADL 209402.7950636445\n",
            "Test ADE 0.6661927518049782\n",
            "Test Average FDE (Across  all samples) 2.2364478943931263\n",
            "Test Min FDE 1.0389193848886262\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0389193848886262\n",
            "Test time error in destination best: 1.248 and mean: 2.490\n",
            "Test time error overall (ADE) best: 0.733\n",
            "Train Loss 533483.1560551883\n",
            "RCL 93397.98142266758\n",
            "KLD 233092.6624915886\n",
            "ADL 206992.51214093185\n",
            "Test ADE 0.7333870123080168\n",
            "Test Average FDE (Across  all samples) 2.4904766825100344\n",
            "Test Min FDE 1.2483714574087024\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0389193848886262\n",
            "Test time error in destination best: 1.162 and mean: 2.389\n",
            "Test time error overall (ADE) best: 0.701\n",
            "Train Loss 516112.6916195358\n",
            "RCL 86977.14999932455\n",
            "KLD 225857.2125895048\n",
            "ADL 203278.3290307066\n",
            "Test ADE 0.7014729517442106\n",
            "Test Average FDE (Across  all samples) 2.3886025348865774\n",
            "Test Min FDE 1.162178451230367\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0389193848886262\n",
            "Test time error in destination best: 1.150 and mean: 2.405\n",
            "Test time error overall (ADE) best: 0.703\n",
            "Train Loss 507789.5289926716\n",
            "RCL 86023.312718316\n",
            "KLD 219611.96757912048\n",
            "ADL 202154.24869523544\n",
            "Test ADE 0.7027747414094077\n",
            "Test Average FDE (Across  all samples) 2.404831584241948\n",
            "Test Min FDE 1.150479508999285\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0389193848886262\n",
            "Test time error in destination best: 1.252 and mean: 2.603\n",
            "Test time error overall (ADE) best: 0.775\n",
            "Train Loss 493516.7667109792\n",
            "RCL 81492.3154071425\n",
            "KLD 213496.5322825932\n",
            "ADL 198527.91902124343\n",
            "Test ADE 0.775338522537629\n",
            "Test Average FDE (Across  all samples) 2.6030647480839804\n",
            "Test Min FDE 1.2516825231728843\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0389193848886262\n",
            "Test time error in destination best: 1.125 and mean: 2.438\n",
            "Test time error overall (ADE) best: 0.684\n",
            "Train Loss 491793.6907578779\n",
            "RCL 82369.5917335893\n",
            "KLD 209159.669138559\n",
            "ADL 200264.42988572942\n",
            "Test ADE 0.6839821420455174\n",
            "Test Average FDE (Across  all samples) 2.43772706410531\n",
            "Test Min FDE 1.124734832565348\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 1.0389193848886262\n",
            "Test time error in destination best: 0.968 and mean: 2.271\n",
            "Test time error overall (ADE) best: 0.658\n",
            "Train Loss 486376.26983908756\n",
            "RCL 81781.80458716798\n",
            "KLD 204681.9088097467\n",
            "ADL 199912.5564421728\n",
            "Test ADE 0.6575848178518063\n",
            "Test Average FDE (Across  all samples) 2.2707901242325277\n",
            "Test Min FDE 0.9681398622531562\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 0.9681398622531562\n",
            "Test time error in destination best: 1.064 and mean: 2.413\n",
            "Test time error overall (ADE) best: 0.657\n",
            "Train Loss 477171.4867058537\n",
            "RCL 79360.33753205501\n",
            "KLD 200763.83651397564\n",
            "ADL 197047.31265982316\n",
            "Test ADE 0.6565532908242859\n",
            "Test Average FDE (Across  all samples) 2.41295754025157\n",
            "Test Min FDE 1.0638353954307094\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 0.9681398622531562\n",
            "Test time error in destination best: 0.899 and mean: 2.266\n",
            "Test time error overall (ADE) best: 0.652\n",
            "Train Loss 462597.1220174572\n",
            "RCL 73836.57314966733\n",
            "KLD 193964.76501964868\n",
            "ADL 194795.7838481411\n",
            "Test ADE 0.6515868012983843\n",
            "Test Average FDE (Across  all samples) 2.2656581474621356\n",
            "Test Min FDE 0.898781535551971\n",
            "Test Best ADE Loss So Far (N = 20) 0.6428865447843487\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.988 and mean: 2.342\n",
            "Test time error overall (ADE) best: 0.636\n",
            "Epoch:  46\n",
            "################## BEST PERFORMANCE 0.64 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 461314.29484130716\n",
            "RCL 75452.3322832993\n",
            "KLD 189238.40223156236\n",
            "ADL 196623.56032644576\n",
            "Test ADE 0.6357970005307566\n",
            "Test Average FDE (Across  all samples) 2.3415843362896194\n",
            "Test Min FDE 0.9883010541723853\n",
            "Test Best ADE Loss So Far (N = 20) 0.6357970005307566\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.099 and mean: 2.491\n",
            "Test time error overall (ADE) best: 0.664\n",
            "Train Loss 445605.13463607023\n",
            "RCL 68882.71587145746\n",
            "KLD 183572.38379033355\n",
            "ADL 193150.03497427923\n",
            "Test ADE 0.6640549514428185\n",
            "Test Average FDE (Across  all samples) 2.4914041301361762\n",
            "Test Min FDE 1.0989492015731963\n",
            "Test Best ADE Loss So Far (N = 20) 0.6357970005307566\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.005 and mean: 2.408\n",
            "Test time error overall (ADE) best: 0.656\n",
            "Train Loss 433767.60976675537\n",
            "RCL 64971.172859058766\n",
            "KLD 178981.1621256711\n",
            "ADL 189815.2747820257\n",
            "Test ADE 0.6559046934961348\n",
            "Test Average FDE (Across  all samples) 2.4081377165636018\n",
            "Test Min FDE 1.0054684305087214\n",
            "Test Best ADE Loss So Far (N = 20) 0.6357970005307566\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.089 and mean: 2.524\n",
            "Test time error overall (ADE) best: 0.662\n",
            "Train Loss 432922.37605368526\n",
            "RCL 67488.16565066877\n",
            "KLD 176089.48368254505\n",
            "ADL 189344.7267204718\n",
            "Test ADE 0.6617766620635102\n",
            "Test Average FDE (Across  all samples) 2.5237200326305897\n",
            "Test Min FDE 1.0890280021848175\n",
            "Test Best ADE Loss So Far (N = 20) 0.6357970005307566\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.085 and mean: 2.479\n",
            "Test time error overall (ADE) best: 0.657\n",
            "Train Loss 448466.755493413\n",
            "RCL 74701.95629606799\n",
            "KLD 175620.7329767901\n",
            "ADL 198144.06622055493\n",
            "Test ADE 0.6570567589021936\n",
            "Test Average FDE (Across  all samples) 2.4791757396074767\n",
            "Test Min FDE 1.084682005276522\n",
            "Test Best ADE Loss So Far (N = 20) 0.6357970005307566\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.127 and mean: 2.532\n",
            "Test time error overall (ADE) best: 0.697\n",
            "Train Loss 408347.1551133628\n",
            "RCL 57203.69149996019\n",
            "KLD 172044.13686734842\n",
            "ADL 179099.3267460544\n",
            "Test ADE 0.6972223362730942\n",
            "Test Average FDE (Across  all samples) 2.532116785945637\n",
            "Test Min FDE 1.1274704912690954\n",
            "Test Best ADE Loss So Far (N = 20) 0.6357970005307566\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.938 and mean: 2.373\n",
            "Test time error overall (ADE) best: 0.709\n",
            "Train Loss 429733.1820635876\n",
            "RCL 67545.78936902912\n",
            "KLD 171242.4702362373\n",
            "ADL 190944.9224583213\n",
            "Test ADE 0.708779310422362\n",
            "Test Average FDE (Across  all samples) 2.37296000785075\n",
            "Test Min FDE 0.9383303562855098\n",
            "Test Best ADE Loss So Far (N = 20) 0.6357970005307566\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.948 and mean: 2.369\n",
            "Test time error overall (ADE) best: 0.614\n",
            "Epoch:  53\n",
            "################## BEST PERFORMANCE 0.61 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 413978.780400934\n",
            "RCL 60591.791360822885\n",
            "KLD 169501.65527674305\n",
            "ADL 183885.33376336805\n",
            "Test ADE 0.614479696152802\n",
            "Test Average FDE (Across  all samples) 2.369055924410505\n",
            "Test Min FDE 0.9476688283018174\n",
            "Test Best ADE Loss So Far (N = 20) 0.614479696152802\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.999 and mean: 2.418\n",
            "Test time error overall (ADE) best: 0.622\n",
            "Train Loss 416623.07764979074\n",
            "RCL 61694.574164751386\n",
            "KLD 169217.78647550073\n",
            "ADL 185710.71700953846\n",
            "Test ADE 0.6221846505333888\n",
            "Test Average FDE (Across  all samples) 2.418410850141989\n",
            "Test Min FDE 0.9994033340518058\n",
            "Test Best ADE Loss So Far (N = 20) 0.614479696152802\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.004 and mean: 2.461\n",
            "Test time error overall (ADE) best: 0.623\n",
            "Train Loss 399662.0698029581\n",
            "RCL 55584.49128991621\n",
            "KLD 166595.36191301717\n",
            "ADL 177482.21660002464\n",
            "Test ADE 0.6230941700449631\n",
            "Test Average FDE (Across  all samples) 2.4612121772198825\n",
            "Test Min FDE 1.0035314464445177\n",
            "Test Best ADE Loss So Far (N = 20) 0.614479696152802\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.032 and mean: 2.457\n",
            "Test time error overall (ADE) best: 0.646\n",
            "Train Loss 411920.5340381381\n",
            "RCL 62057.071323711156\n",
            "KLD 166852.25508104658\n",
            "ADL 183011.2076333798\n",
            "Test ADE 0.6456230625838442\n",
            "Test Average FDE (Across  all samples) 2.4568711818906013\n",
            "Test Min FDE 1.031724375225655\n",
            "Test Best ADE Loss So Far (N = 20) 0.614479696152802\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.077 and mean: 2.535\n",
            "Test time error overall (ADE) best: 0.692\n",
            "Train Loss 402918.6108875422\n",
            "RCL 57683.55851739079\n",
            "KLD 166539.28018016144\n",
            "ADL 178695.77218998957\n",
            "Test ADE 0.691794584462955\n",
            "Test Average FDE (Across  all samples) 2.5349035215101114\n",
            "Test Min FDE 1.0768605578408708\n",
            "Test Best ADE Loss So Far (N = 20) 0.614479696152802\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.970 and mean: 2.445\n",
            "Test time error overall (ADE) best: 0.596\n",
            "Epoch:  58\n",
            "################## BEST PERFORMANCE 0.60 ########\n",
            "Saved model to:\n",
            "/content/trained.pt\n",
            "Train Loss 404376.3336454817\n",
            "RCL 58077.96159147892\n",
            "KLD 165734.63142018951\n",
            "ADL 180563.74063381332\n",
            "Test ADE 0.5962630104931447\n",
            "Test Average FDE (Across  all samples) 2.44542151267942\n",
            "Test Min FDE 0.9696051770994767\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.033 and mean: 2.492\n",
            "Test time error overall (ADE) best: 0.719\n",
            "Train Loss 403218.7650029678\n",
            "RCL 57801.58891355836\n",
            "KLD 164898.3872129649\n",
            "ADL 180518.78887644477\n",
            "Test ADE 0.7194316572221744\n",
            "Test Average FDE (Across  all samples) 2.4916862465132774\n",
            "Test Min FDE 1.0330001004818032\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.037 and mean: 2.484\n",
            "Test time error overall (ADE) best: 0.629\n",
            "Train Loss 414476.102424577\n",
            "RCL 63092.99526417272\n",
            "KLD 165096.6534263332\n",
            "ADL 186286.45373407105\n",
            "Test ADE 0.6287058747681717\n",
            "Test Average FDE (Across  all samples) 2.4837403059132934\n",
            "Test Min FDE 1.0368276745208063\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.074 and mean: 2.487\n",
            "Test time error overall (ADE) best: 0.661\n",
            "Train Loss 393142.0415893652\n",
            "RCL 53633.42184677941\n",
            "KLD 163629.96036207516\n",
            "ADL 175878.6593805107\n",
            "Test ADE 0.6614908540308584\n",
            "Test Average FDE (Across  all samples) 2.48733274594753\n",
            "Test Min FDE 1.0739365838578352\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.993 and mean: 2.488\n",
            "Test time error overall (ADE) best: 0.641\n",
            "Train Loss 398431.84525581956\n",
            "RCL 56159.65238901473\n",
            "KLD 163968.496483561\n",
            "ADL 178303.69638324398\n",
            "Test ADE 0.6409855895902775\n",
            "Test Average FDE (Across  all samples) 2.4878553316870353\n",
            "Test Min FDE 0.9926133487830033\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.967 and mean: 2.445\n",
            "Test time error overall (ADE) best: 0.634\n",
            "Train Loss 390165.291629127\n",
            "RCL 53704.703024048475\n",
            "KLD 163273.1037812373\n",
            "ADL 173187.48482384114\n",
            "Test ADE 0.6339721988283585\n",
            "Test Average FDE (Across  all samples) 2.4453560247926864\n",
            "Test Min FDE 0.9666122284062223\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.026 and mean: 2.546\n",
            "Test time error overall (ADE) best: 0.702\n",
            "Train Loss 397448.5086249084\n",
            "RCL 57036.65604298488\n",
            "KLD 163556.66109444748\n",
            "ADL 176855.19148747573\n",
            "Test ADE 0.7021041666310919\n",
            "Test Average FDE (Across  all samples) 2.5459225756928214\n",
            "Test Min FDE 1.0260433944482732\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.957 and mean: 2.448\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Train Loss 395189.31617321464\n",
            "RCL 55765.41637139006\n",
            "KLD 163233.0148975751\n",
            "ADL 176190.88490424995\n",
            "Test ADE 0.606610502642645\n",
            "Test Average FDE (Across  all samples) 2.448389706390669\n",
            "Test Min FDE 0.9569350215701061\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.984 and mean: 2.450\n",
            "Test time error overall (ADE) best: 0.623\n",
            "Train Loss 391185.91675108747\n",
            "RCL 55323.32031972692\n",
            "KLD 162497.75181957608\n",
            "ADL 173364.84461178462\n",
            "Test ADE 0.6232416591312895\n",
            "Test Average FDE (Across  all samples) 2.449666504884343\n",
            "Test Min FDE 0.9838063800029347\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.971 and mean: 2.424\n",
            "Test time error overall (ADE) best: 0.622\n",
            "Train Loss 396013.3802064539\n",
            "RCL 57138.02185010163\n",
            "KLD 162639.7786935517\n",
            "ADL 176235.57966280013\n",
            "Test ADE 0.622057399174458\n",
            "Test Average FDE (Across  all samples) 2.4236062612948\n",
            "Test Min FDE 0.9710506292768747\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.938 and mean: 2.433\n",
            "Test time error overall (ADE) best: 0.609\n",
            "Train Loss 385027.04536172835\n",
            "RCL 52048.06827755656\n",
            "KLD 161927.15044746335\n",
            "ADL 171051.82663670857\n",
            "Test ADE 0.6086663733654294\n",
            "Test Average FDE (Across  all samples) 2.4328798968327927\n",
            "Test Min FDE 0.9375373717243265\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.945 and mean: 2.430\n",
            "Test time error overall (ADE) best: 0.653\n",
            "Train Loss 382076.21770393185\n",
            "RCL 51272.61701657357\n",
            "KLD 161522.84523594019\n",
            "ADL 169280.7554514183\n",
            "Test ADE 0.6529303579304586\n",
            "Test Average FDE (Across  all samples) 2.4299308387538847\n",
            "Test Min FDE 0.944613161908236\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.947 and mean: 2.447\n",
            "Test time error overall (ADE) best: 0.623\n",
            "Train Loss 396307.3425885095\n",
            "RCL 56548.743957465216\n",
            "KLD 162800.42964045535\n",
            "ADL 176958.16899058898\n",
            "Test ADE 0.6232906619156242\n",
            "Test Average FDE (Across  all samples) 2.447333897519856\n",
            "Test Min FDE 0.9474627427956889\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.043 and mean: 2.558\n",
            "Test time error overall (ADE) best: 0.648\n",
            "Train Loss 386142.0002591601\n",
            "RCL 52673.803396554904\n",
            "KLD 161740.8276747592\n",
            "ADL 171727.3691878465\n",
            "Test ADE 0.648221680496352\n",
            "Test Average FDE (Across  all samples) 2.558211932085924\n",
            "Test Min FDE 1.0429240145872005\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.034 and mean: 2.544\n",
            "Test time error overall (ADE) best: 0.658\n",
            "Train Loss 388090.35804481106\n",
            "RCL 53551.902415212404\n",
            "KLD 161736.28748395518\n",
            "ADL 172802.16814564337\n",
            "Test ADE 0.6581383925178389\n",
            "Test Average FDE (Across  all samples) 2.543732021817386\n",
            "Test Min FDE 1.0337797354227733\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.030 and mean: 2.538\n",
            "Test time error overall (ADE) best: 0.647\n",
            "Train Loss 394490.9425016654\n",
            "RCL 55745.26869115314\n",
            "KLD 162227.77049434165\n",
            "ADL 176517.90331617076\n",
            "Test ADE 0.6468051483560087\n",
            "Test Average FDE (Across  all samples) 2.537626247314648\n",
            "Test Min FDE 1.03001708632186\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.036 and mean: 2.524\n",
            "Test time error overall (ADE) best: 0.643\n",
            "Train Loss 387525.5853659728\n",
            "RCL 54364.42183794848\n",
            "KLD 160707.22107314027\n",
            "ADL 172453.9424548837\n",
            "Test ADE 0.6434031950297475\n",
            "Test Average FDE (Across  all samples) 2.5238139032601805\n",
            "Test Min FDE 1.0355931699507508\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.959 and mean: 2.511\n",
            "Test time error overall (ADE) best: 0.622\n",
            "Train Loss 372887.79560531647\n",
            "RCL 47827.0153954956\n",
            "KLD 160900.25076459846\n",
            "ADL 164160.52944522194\n",
            "Test ADE 0.6216759573167088\n",
            "Test Average FDE (Across  all samples) 2.511183754196256\n",
            "Test Min FDE 0.9591107613373845\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.982 and mean: 2.496\n",
            "Test time error overall (ADE) best: 0.623\n",
            "Train Loss 378889.0675979117\n",
            "RCL 50824.52057726787\n",
            "KLD 160904.64418514888\n",
            "ADL 167159.90283549475\n",
            "Test ADE 0.6226894075739541\n",
            "Test Average FDE (Across  all samples) 2.49574514954517\n",
            "Test Min FDE 0.9823406069617394\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.942 and mean: 2.460\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Train Loss 382858.03520604107\n",
            "RCL 53293.22793376715\n",
            "KLD 160842.7962042508\n",
            "ADL 168722.0110680235\n",
            "Test ADE 0.6130816750867463\n",
            "Test Average FDE (Across  all samples) 2.4595319530443045\n",
            "Test Min FDE 0.9421003466006893\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.907 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.606\n",
            "Train Loss 379714.8359479103\n",
            "RCL 50577.138013453274\n",
            "KLD 160369.07188701033\n",
            "ADL 168768.6260474463\n",
            "Test ADE 0.6057831032312946\n",
            "Test Average FDE (Across  all samples) 2.4336976570892537\n",
            "Test Min FDE 0.9071842091681676\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.923 and mean: 2.432\n",
            "Test time error overall (ADE) best: 0.611\n",
            "Train Loss 374532.72045281186\n",
            "RCL 49964.448940543094\n",
            "KLD 159648.01668143843\n",
            "ADL 164920.25483083015\n",
            "Test ADE 0.6106179329743062\n",
            "Test Average FDE (Across  all samples) 2.432336519010085\n",
            "Test Min FDE 0.9234169273470135\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.036 and mean: 2.552\n",
            "Test time error overall (ADE) best: 0.641\n",
            "Train Loss 379643.53686190076\n",
            "RCL 52635.48507031887\n",
            "KLD 159523.3942911368\n",
            "ADL 167484.65750044506\n",
            "Test ADE 0.6411940527388308\n",
            "Test Average FDE (Across  all samples) 2.551567320054414\n",
            "Test Min FDE 1.0357000448913\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.059 and mean: 2.595\n",
            "Test time error overall (ADE) best: 0.645\n",
            "Train Loss 389770.6917939073\n",
            "RCL 56292.79851127905\n",
            "KLD 160547.5289071842\n",
            "ADL 172930.36437544416\n",
            "Test ADE 0.6445167503460848\n",
            "Test Average FDE (Across  all samples) 2.5947196290933503\n",
            "Test Min FDE 1.0593428288811326\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.935 and mean: 2.460\n",
            "Test time error overall (ADE) best: 0.648\n",
            "Train Loss 367227.5219080574\n",
            "RCL 46574.5417743887\n",
            "KLD 159130.84902096193\n",
            "ADL 161522.13111270682\n",
            "Test ADE 0.6475494016790779\n",
            "Test Average FDE (Across  all samples) 2.459937922928019\n",
            "Test Min FDE 0.9346281734545362\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.985 and mean: 2.513\n",
            "Test time error overall (ADE) best: 0.632\n",
            "Train Loss 365378.2349376059\n",
            "RCL 45542.13838530039\n",
            "KLD 159740.50360243872\n",
            "ADL 160095.59294986664\n",
            "Test ADE 0.6320186333541048\n",
            "Test Average FDE (Across  all samples) 2.51349262812383\n",
            "Test Min FDE 0.9854818500973619\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.986 and mean: 2.489\n",
            "Test time error overall (ADE) best: 0.641\n",
            "Train Loss 378103.6628517093\n",
            "RCL 52675.55028106876\n",
            "KLD 158642.97890112264\n",
            "ADL 166785.1336695177\n",
            "Test ADE 0.6413963735314264\n",
            "Test Average FDE (Across  all samples) 2.488865172430446\n",
            "Test Min FDE 0.9858890892321418\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.900 and mean: 2.421\n",
            "Test time error overall (ADE) best: 0.624\n",
            "Train Loss 375081.47561740695\n",
            "RCL 51371.628885469494\n",
            "KLD 159185.92596697828\n",
            "ADL 164523.92076495863\n",
            "Test ADE 0.6241198994433426\n",
            "Test Average FDE (Across  all samples) 2.4210100231475318\n",
            "Test Min FDE 0.899910608006728\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.985 and mean: 2.512\n",
            "Test time error overall (ADE) best: 0.631\n",
            "Train Loss 374287.13833142736\n",
            "RCL 50421.4724526451\n",
            "KLD 159450.50919513224\n",
            "ADL 164415.15668364987\n",
            "Test ADE 0.6308972360835536\n",
            "Test Average FDE (Across  all samples) 2.5122446508102523\n",
            "Test Min FDE 0.9852673467500458\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.961 and mean: 2.461\n",
            "Test time error overall (ADE) best: 0.612\n",
            "Train Loss 371921.40380313894\n",
            "RCL 49695.41887768077\n",
            "KLD 158564.92743576574\n",
            "ADL 163661.05748969267\n",
            "Test ADE 0.6118252700992702\n",
            "Test Average FDE (Across  all samples) 2.461273281718508\n",
            "Test Min FDE 0.9608758612987061\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.022 and mean: 2.553\n",
            "Test time error overall (ADE) best: 0.663\n",
            "Train Loss 381803.02249693585\n",
            "RCL 54124.82242473636\n",
            "KLD 159549.99335454043\n",
            "ADL 168128.2067176595\n",
            "Test ADE 0.6632007714325806\n",
            "Test Average FDE (Across  all samples) 2.5533893392857805\n",
            "Test Min FDE 1.0222551612495299\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.935 and mean: 2.457\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Train Loss 369042.08756623394\n",
            "RCL 48825.11027151084\n",
            "KLD 159045.8323267413\n",
            "ADL 161171.1449679819\n",
            "Test ADE 0.6005850487140039\n",
            "Test Average FDE (Across  all samples) 2.4570296278249084\n",
            "Test Min FDE 0.9353613406593334\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.031 and mean: 2.508\n",
            "Test time error overall (ADE) best: 0.649\n",
            "Train Loss 365783.04098475294\n",
            "RCL 47222.056872944966\n",
            "KLD 159101.12070093802\n",
            "ADL 159459.8634108699\n",
            "Test ADE 0.6492273680604035\n",
            "Test Average FDE (Across  all samples) 2.507661562070466\n",
            "Test Min FDE 1.03145202765158\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.967 and mean: 2.466\n",
            "Test time error overall (ADE) best: 0.622\n",
            "Train Loss 385026.5139158478\n",
            "RCL 56260.129397115896\n",
            "KLD 159021.42307896377\n",
            "ADL 169744.96143976817\n",
            "Test ADE 0.6224015630229084\n",
            "Test Average FDE (Across  all samples) 2.46633716811398\n",
            "Test Min FDE 0.9668196028349241\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.002 and mean: 2.519\n",
            "Test time error overall (ADE) best: 0.636\n",
            "Train Loss 378248.0185931097\n",
            "RCL 52212.59116902631\n",
            "KLD 158830.14727073885\n",
            "ADL 167205.28015334439\n",
            "Test ADE 0.6355834229343523\n",
            "Test Average FDE (Across  all samples) 2.5193301194415048\n",
            "Test Min FDE 1.0024295899352529\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.986 and mean: 2.520\n",
            "Test time error overall (ADE) best: 0.628\n",
            "Train Loss 362271.6838262041\n",
            "RCL 46717.4240409829\n",
            "KLD 158521.89715550514\n",
            "ADL 157032.3626297163\n",
            "Test ADE 0.6282302706738816\n",
            "Test Average FDE (Across  all samples) 2.519934355346562\n",
            "Test Min FDE 0.9858212331485884\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.931 and mean: 2.429\n",
            "Test time error overall (ADE) best: 0.648\n",
            "Train Loss 364820.8267111119\n",
            "RCL 48039.09480139832\n",
            "KLD 158323.56254858227\n",
            "ADL 158458.16936113164\n",
            "Test ADE 0.6483404641399098\n",
            "Test Average FDE (Across  all samples) 2.428931348739617\n",
            "Test Min FDE 0.9309648855146231\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.102 and mean: 2.620\n",
            "Test time error overall (ADE) best: 0.718\n",
            "Train Loss 369829.82372734445\n",
            "RCL 49835.709416217935\n",
            "KLD 157929.279114177\n",
            "ADL 162064.83519694945\n",
            "Test ADE 0.7184398729899418\n",
            "Test Average FDE (Across  all samples) 2.6204496068232683\n",
            "Test Min FDE 1.101744091545386\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 1.033 and mean: 2.557\n",
            "Test time error overall (ADE) best: 0.650\n",
            "Train Loss 375780.74631498\n",
            "RCL 52007.63701830604\n",
            "KLD 159743.18594939352\n",
            "ADL 164029.9233472804\n",
            "Test ADE 0.6496204989457736\n",
            "Test Average FDE (Across  all samples) 2.557244435679376\n",
            "Test Min FDE 1.0333143446280935\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.955 and mean: 2.457\n",
            "Test time error overall (ADE) best: 0.622\n",
            "Train Loss 359142.1447553678\n",
            "RCL 44927.316265713634\n",
            "KLD 158680.9490672391\n",
            "ADL 155533.87942241514\n",
            "Test ADE 0.622043470118972\n",
            "Test Average FDE (Across  all samples) 2.4569229034456463\n",
            "Test Min FDE 0.955439082531653\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.898781535551971\n",
            "Test time error in destination best: 0.896 and mean: 2.422\n",
            "Test time error overall (ADE) best: 0.655\n",
            "Train Loss 358525.75600768044\n",
            "RCL 44337.04976281552\n",
            "KLD 158838.8598109375\n",
            "ADL 155349.84643392768\n",
            "Test ADE 0.65465532049485\n",
            "Test Average FDE (Across  all samples) 2.4221333260528097\n",
            "Test Min FDE 0.896489962710824\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.896489962710824\n",
            "Test time error in destination best: 0.967 and mean: 2.509\n",
            "Test time error overall (ADE) best: 0.621\n",
            "Train Loss 360360.3460951231\n",
            "RCL 45707.48825190898\n",
            "KLD 158360.25805727395\n",
            "ADL 156292.59978594017\n",
            "Test ADE 0.6211952213851184\n",
            "Test Average FDE (Across  all samples) 2.5085194045146166\n",
            "Test Min FDE 0.9671530097954933\n",
            "Test Best ADE Loss So Far (N = 20) 0.5962630104931447\n",
            "Test Best Min FDE (N = 20) 0.896489962710824\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4DL0vHVUPzyl"
      },
      "source": [
        "## Test Pretrained Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qrz1mDd9Yhkv"
      },
      "source": [
        "# Reload the checkpoint written during training: hyper-parameters + model weights.\n",
        "checkpoint = torch.load(\"/content/trained.pt\", map_location=device)\n",
        "hyper_params = checkpoint[\"hyper_params\"]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MKMCFWf1FDQh",
        "outputId": "95136287-0e54-4ed7-80ad-bfbfce7e7081",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "# Scale factor test_model multiplies trajectories by (and divides reported errors by).\n",
        "hyper_params[\"data_scale\"]"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "170.0"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZB9f4TCCPe3T"
      },
      "source": [
        "def test_model(test_dataset, model, best_of_n = 1):\n",
        "  \"\"\"Evaluate `model` on `test_dataset` with best-of-N destination sampling.\n",
        "\n",
        "  Relies on the notebook globals `hyper_params` (data_scale, past_length,\n",
        "  future_length) and `device`.\n",
        "\n",
        "  Args:\n",
        "    test_dataset: dataset of trajectories, batched as (batch, time, coord);\n",
        "      the first past_length steps are the observed history.\n",
        "    model: PECNet-style model; `forward(x, device=...)` returns a (possibly\n",
        "      stochastic) destination guess and `predict(x, dest)` returns the\n",
        "      intermediate waypoints for that destination.\n",
        "    best_of_n: destination guesses drawn per trajectory; the guess closest\n",
        "      to the ground-truth endpoint is used to build the full path.\n",
        "\n",
        "  Returns:\n",
        "    (ade, fde_best, fde_mean), each divided by hyper_params[\"data_scale\"]:\n",
        "    ADE of the path built from the best guess, mean best-of-N endpoint\n",
        "    error, and mean endpoint error over all guesses.\n",
        "\n",
        "  Raises:\n",
        "    AssertionError: if best_of_n is not a positive int.\n",
        "    ValueError: if test_dataset is empty.\n",
        "  \"\"\"\n",
        "  assert type(best_of_n) == int and best_of_n >= 1, \"best_of_n must be a positive int\"\n",
        "  if len(test_dataset) == 0:\n",
        "    raise ValueError(\"test_dataset is empty\")\n",
        "\n",
        "  scale = hyper_params[\"data_scale\"]\n",
        "  past_len = hyper_params[\"past_length\"]\n",
        "  future_len = hyper_params[\"future_length\"]\n",
        "\n",
        "  # One batch holding the whole test set, so every metric covers all trajectories.\n",
        "  dataloader = data.DataLoader(\n",
        "      test_dataset, batch_size=len(test_dataset), shuffle=False, num_workers=0)\n",
        "\n",
        "  model.eval()\n",
        "\n",
        "  with torch.no_grad():\n",
        "    for trajx in dataloader:\n",
        "      # Cast explicitly: the legacy torch.DoubleTensor(tensor) constructor\n",
        "      # raises unless the batch is already float64.\n",
        "      traj = trajx.to(device=device, dtype=torch.float64)\n",
        "      # Make each trajectory relative to its first point, then rescale.\n",
        "      traj = (traj - traj[:, :1, :]) * scale\n",
        "\n",
        "      # Flatten the observed history to (batch, past_len * coord).\n",
        "      x = traj[:, :past_len, :]\n",
        "      x = x.view(-1, x.shape[1] * x.shape[2])\n",
        "      y = traj[:, past_len:, :].cpu().numpy()\n",
        "      dest = y[:, -1, :]\n",
        "\n",
        "      # Draw best_of_n destination guesses and record each one's L2 error.\n",
        "      all_guesses = []\n",
        "      all_l2_errors_dest = []\n",
        "      for _ in range(best_of_n):\n",
        "        dest_recon = model.forward(x, device=device).cpu().numpy()\n",
        "        all_guesses.append(dest_recon)\n",
        "        all_l2_errors_dest.append(np.linalg.norm(dest_recon - dest, axis=1))\n",
        "      all_guesses = np.stack(all_guesses)                # (best_of_n, batch, coord)\n",
        "      all_l2_errors_dest = np.stack(all_l2_errors_dest)  # (best_of_n, batch)\n",
        "\n",
        "      # Mean error over every guess, and over each trajectory's best guess.\n",
        "      l2error_avg_dest = np.mean(all_l2_errors_dest)\n",
        "      l2error_dest = np.mean(np.min(all_l2_errors_dest, axis=0))\n",
        "      best_idx = np.argmin(all_l2_errors_dest, axis=0)\n",
        "      best_guess_dest = all_guesses[best_idx, np.arange(x.shape[0]), :]\n",
        "\n",
        "      # Interpolate intermediate waypoints conditioned on the best destination.\n",
        "      best_guess_dest_t = torch.as_tensor(\n",
        "          best_guess_dest, dtype=torch.float64, device=device)\n",
        "      interpolated_future = model.predict(x, best_guess_dest_t).cpu().numpy()\n",
        "\n",
        "      # Full predicted path: interpolated waypoints followed by the destination.\n",
        "      predicted_future = np.concatenate((interpolated_future, best_guess_dest), axis=1)\n",
        "      predicted_future = predicted_future.reshape(-1, future_len, 2)\n",
        "\n",
        "      # ADE over every future step of the best-guess path.\n",
        "      l2error_overall = np.mean(np.linalg.norm(y - predicted_future, axis=2))\n",
        "\n",
        "      # Undo the data_scale applied to the inputs above.\n",
        "      l2error_overall /= scale\n",
        "      l2error_dest /= scale\n",
        "      l2error_avg_dest /= scale\n",
        "\n",
        "      print('Test time error in destination best: {:0.3f} and mean: {:0.3f}'.format(l2error_dest, l2error_avg_dest))\n",
        "      print('Test time error overall (ADE) best: {:0.3f}'.format(l2error_overall))\n",
        "\n",
        "  return l2error_overall, l2error_dest, l2error_avg_dest"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dXZeP0NAP_Ak"
      },
      "source": [
        "def run_test(N=20, num_samples=150, test_path=\"/content/eth_test.npz\"):\n",
        "  \"\"\"Rebuild PECNet from the checkpoint and report run-averaged ADE/FDE.\n",
        "\n",
        "  test_model's best-of-N errors vary from run to run, so the evaluation is\n",
        "  repeated num_samples times and the results are averaged.\n",
        "\n",
        "  Relies on the notebook globals `checkpoint`, `hyper_params`, `device`,\n",
        "  `PECNet`, `SocialDataset` and `test_model`.\n",
        "\n",
        "  Args:\n",
        "    N: destination guesses per trajectory (best_of_n for test_model).\n",
        "    num_samples: number of evaluation runs to average.\n",
        "    test_path: .npz file holding the test split.\n",
        "\n",
        "  Returns:\n",
        "    (average_ade, average_fde): run-averaged ADE and best-of-N FDE.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if num_samples < 1 (the average would divide by zero).\n",
        "  \"\"\"\n",
        "  if num_samples < 1:\n",
        "    raise ValueError(\"num_samples must be >= 1\")\n",
        "\n",
        "  model = PECNet(\n",
        "      hyper_params[\"enc_past_size\"],\n",
        "      hyper_params[\"enc_dest_size\"],\n",
        "      hyper_params[\"enc_latent_size\"],\n",
        "      hyper_params[\"dec_size\"],\n",
        "      hyper_params[\"predictor_hidden_size\"],\n",
        "      hyper_params[\"fdim\"],\n",
        "      hyper_params[\"zdim\"],\n",
        "      hyper_params[\"sigma\"],\n",
        "      hyper_params[\"past_length\"],\n",
        "      hyper_params[\"future_length\"], verbose=True)\n",
        "  model = model.double().to(device)\n",
        "  model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
        "\n",
        "  test_dataset = SocialDataset(test_path, set_name=\"test\", verbose=True)\n",
        "\n",
        "  total_ade, total_fde = 0.0, 0.0\n",
        "  for _ in range(num_samples):\n",
        "    ade, fde_best, _fde_mean = test_model(test_dataset, model, best_of_n=N)\n",
        "    total_ade += ade\n",
        "    total_fde += fde_best\n",
        "\n",
        "  average_ade = total_ade / num_samples\n",
        "  average_fde = total_fde / num_samples\n",
        "  print()\n",
        "  print(\"Average ADE:\", average_ade)\n",
        "  print(\"Average FDE:\", average_fde)\n",
        "  return average_ade, average_fde"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xduIj7cbQv5l",
        "outputId": "c18167c4-e124-4b15-fada-6608eda6b268",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "# Best-of-20 evaluation of the loaded checkpoint on the ETH test split.\n",
        "run_test(N=20)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Past Encoder architecture : [16, 512, 256, 16]\n",
            "Dest Encoder architecture : [2, 8, 16, 16]\n",
            "Latent Encoder architecture : [32, 8, 50, 32]\n",
            "Decoder architecture : [32, 1024, 512, 1024, 2]\n",
            "Predictor architecture : [32, 1024, 512, 256, 22]\n",
            "(364, 8, 2)\n",
            "(364, 12, 2)\n",
            "(364, 20, 2)\n",
            "Test time error in destination best: 0.980 and mean: 2.444\n",
            "Test time error overall (ADE) best: 0.606\n",
            "Test time error in destination best: 1.006 and mean: 2.435\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.970 and mean: 2.416\n",
            "Test time error overall (ADE) best: 0.604\n",
            "Test time error in destination best: 0.949 and mean: 2.432\n",
            "Test time error overall (ADE) best: 0.600\n",
            "Test time error in destination best: 0.929 and mean: 2.450\n",
            "Test time error overall (ADE) best: 0.598\n",
            "Test time error in destination best: 0.983 and mean: 2.444\n",
            "Test time error overall (ADE) best: 0.618\n",
            "Test time error in destination best: 0.967 and mean: 2.428\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.994 and mean: 2.440\n",
            "Test time error overall (ADE) best: 0.611\n",
            "Test time error in destination best: 0.956 and mean: 2.429\n",
            "Test time error overall (ADE) best: 0.600\n",
            "Test time error in destination best: 0.964 and mean: 2.424\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.967 and mean: 2.433\n",
            "Test time error overall (ADE) best: 0.604\n",
            "Test time error in destination best: 0.981 and mean: 2.415\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.966 and mean: 2.440\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.975 and mean: 2.442\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Test time error in destination best: 0.987 and mean: 2.451\n",
            "Test time error overall (ADE) best: 0.612\n",
            "Test time error in destination best: 0.955 and mean: 2.444\n",
            "Test time error overall (ADE) best: 0.599\n",
            "Test time error in destination best: 0.950 and mean: 2.421\n",
            "Test time error overall (ADE) best: 0.599\n",
            "Test time error in destination best: 0.988 and mean: 2.438\n",
            "Test time error overall (ADE) best: 0.608\n",
            "Test time error in destination best: 0.959 and mean: 2.426\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.963 and mean: 2.418\n",
            "Test time error overall (ADE) best: 0.604\n",
            "Test time error in destination best: 0.983 and mean: 2.446\n",
            "Test time error overall (ADE) best: 0.606\n",
            "Test time error in destination best: 0.977 and mean: 2.441\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.961 and mean: 2.433\n",
            "Test time error overall (ADE) best: 0.596\n",
            "Test time error in destination best: 0.965 and mean: 2.427\n",
            "Test time error overall (ADE) best: 0.602\n",
            "Test time error in destination best: 1.000 and mean: 2.443\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Test time error in destination best: 0.987 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.612\n",
            "Test time error in destination best: 0.965 and mean: 2.432\n",
            "Test time error overall (ADE) best: 0.608\n",
            "Test time error in destination best: 0.982 and mean: 2.428\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.969 and mean: 2.426\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 1.003 and mean: 2.421\n",
            "Test time error overall (ADE) best: 0.615\n",
            "Test time error in destination best: 0.971 and mean: 2.444\n",
            "Test time error overall (ADE) best: 0.610\n",
            "Test time error in destination best: 0.962 and mean: 2.449\n",
            "Test time error overall (ADE) best: 0.606\n",
            "Test time error in destination best: 1.006 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.619\n",
            "Test time error in destination best: 0.972 and mean: 2.417\n",
            "Test time error overall (ADE) best: 0.608\n",
            "Test time error in destination best: 0.981 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.968 and mean: 2.455\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.946 and mean: 2.439\n",
            "Test time error overall (ADE) best: 0.596\n",
            "Test time error in destination best: 1.011 and mean: 2.442\n",
            "Test time error overall (ADE) best: 0.616\n",
            "Test time error in destination best: 0.989 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.610\n",
            "Test time error in destination best: 0.988 and mean: 2.436\n",
            "Test time error overall (ADE) best: 0.604\n",
            "Test time error in destination best: 0.985 and mean: 2.448\n",
            "Test time error overall (ADE) best: 0.604\n",
            "Test time error in destination best: 0.970 and mean: 2.421\n",
            "Test time error overall (ADE) best: 0.611\n",
            "Test time error in destination best: 0.958 and mean: 2.445\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.994 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.618\n",
            "Test time error in destination best: 0.962 and mean: 2.430\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.965 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.598\n",
            "Test time error in destination best: 0.985 and mean: 2.439\n",
            "Test time error overall (ADE) best: 0.609\n",
            "Test time error in destination best: 0.945 and mean: 2.443\n",
            "Test time error overall (ADE) best: 0.595\n",
            "Test time error in destination best: 0.966 and mean: 2.435\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.927 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.575\n",
            "Test time error in destination best: 0.981 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Test time error in destination best: 0.991 and mean: 2.446\n",
            "Test time error overall (ADE) best: 0.615\n",
            "Test time error in destination best: 0.990 and mean: 2.439\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.973 and mean: 2.433\n",
            "Test time error overall (ADE) best: 0.600\n",
            "Test time error in destination best: 0.965 and mean: 2.454\n",
            "Test time error overall (ADE) best: 0.604\n",
            "Test time error in destination best: 0.982 and mean: 2.435\n",
            "Test time error overall (ADE) best: 0.610\n",
            "Test time error in destination best: 0.941 and mean: 2.441\n",
            "Test time error overall (ADE) best: 0.594\n",
            "Test time error in destination best: 0.950 and mean: 2.428\n",
            "Test time error overall (ADE) best: 0.596\n",
            "Test time error in destination best: 0.963 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.609\n",
            "Test time error in destination best: 0.962 and mean: 2.445\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.955 and mean: 2.438\n",
            "Test time error overall (ADE) best: 0.589\n",
            "Test time error in destination best: 0.991 and mean: 2.454\n",
            "Test time error overall (ADE) best: 0.609\n",
            "Test time error in destination best: 0.975 and mean: 2.428\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.994 and mean: 2.439\n",
            "Test time error overall (ADE) best: 0.612\n",
            "Test time error in destination best: 0.982 and mean: 2.436\n",
            "Test time error overall (ADE) best: 0.610\n",
            "Test time error in destination best: 0.958 and mean: 2.426\n",
            "Test time error overall (ADE) best: 0.596\n",
            "Test time error in destination best: 0.979 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.597\n",
            "Test time error in destination best: 1.014 and mean: 2.449\n",
            "Test time error overall (ADE) best: 0.620\n",
            "Test time error in destination best: 0.980 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.609\n",
            "Test time error in destination best: 0.990 and mean: 2.413\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.978 and mean: 2.431\n",
            "Test time error overall (ADE) best: 0.608\n",
            "Test time error in destination best: 0.967 and mean: 2.441\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.967 and mean: 2.441\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.959 and mean: 2.406\n",
            "Test time error overall (ADE) best: 0.596\n",
            "Test time error in destination best: 0.973 and mean: 2.439\n",
            "Test time error overall (ADE) best: 0.602\n",
            "Test time error in destination best: 1.004 and mean: 2.438\n",
            "Test time error overall (ADE) best: 0.616\n",
            "Test time error in destination best: 0.964 and mean: 2.432\n",
            "Test time error overall (ADE) best: 0.600\n",
            "Test time error in destination best: 0.958 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.978 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.973 and mean: 2.433\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.975 and mean: 2.448\n",
            "Test time error overall (ADE) best: 0.610\n",
            "Test time error in destination best: 0.966 and mean: 2.433\n",
            "Test time error overall (ADE) best: 0.600\n",
            "Test time error in destination best: 0.956 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.595\n",
            "Test time error in destination best: 0.990 and mean: 2.425\n",
            "Test time error overall (ADE) best: 0.611\n",
            "Test time error in destination best: 0.935 and mean: 2.430\n",
            "Test time error overall (ADE) best: 0.586\n",
            "Test time error in destination best: 0.952 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.598\n",
            "Test time error in destination best: 0.962 and mean: 2.451\n",
            "Test time error overall (ADE) best: 0.595\n",
            "Test time error in destination best: 1.002 and mean: 2.453\n",
            "Test time error overall (ADE) best: 0.615\n",
            "Test time error in destination best: 0.990 and mean: 2.435\n",
            "Test time error overall (ADE) best: 0.610\n",
            "Test time error in destination best: 0.989 and mean: 2.439\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.976 and mean: 2.431\n",
            "Test time error overall (ADE) best: 0.608\n",
            "Test time error in destination best: 0.972 and mean: 2.427\n",
            "Test time error overall (ADE) best: 0.611\n",
            "Test time error in destination best: 0.990 and mean: 2.435\n",
            "Test time error overall (ADE) best: 0.609\n",
            "Test time error in destination best: 0.995 and mean: 2.477\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.993 and mean: 2.455\n",
            "Test time error overall (ADE) best: 0.619\n",
            "Test time error in destination best: 0.980 and mean: 2.453\n",
            "Test time error overall (ADE) best: 0.614\n",
            "Test time error in destination best: 0.985 and mean: 2.444\n",
            "Test time error overall (ADE) best: 0.611\n",
            "Test time error in destination best: 0.947 and mean: 2.429\n",
            "Test time error overall (ADE) best: 0.596\n",
            "Test time error in destination best: 0.951 and mean: 2.432\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.982 and mean: 2.443\n",
            "Test time error overall (ADE) best: 0.610\n",
            "Test time error in destination best: 1.004 and mean: 2.428\n",
            "Test time error overall (ADE) best: 0.608\n",
            "Test time error in destination best: 0.959 and mean: 2.450\n",
            "Test time error overall (ADE) best: 0.599\n",
            "Test time error in destination best: 0.955 and mean: 2.429\n",
            "Test time error overall (ADE) best: 0.599\n",
            "Test time error in destination best: 0.966 and mean: 2.455\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Test time error in destination best: 0.957 and mean: 2.419\n",
            "Test time error overall (ADE) best: 0.606\n",
            "Test time error in destination best: 0.972 and mean: 2.440\n",
            "Test time error overall (ADE) best: 0.602\n",
            "Test time error in destination best: 0.994 and mean: 2.438\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Test time error in destination best: 0.979 and mean: 2.435\n",
            "Test time error overall (ADE) best: 0.612\n",
            "Test time error in destination best: 1.005 and mean: 2.432\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Test time error in destination best: 0.997 and mean: 2.445\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.995 and mean: 2.432\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.998 and mean: 2.441\n",
            "Test time error overall (ADE) best: 0.612\n",
            "Test time error in destination best: 0.970 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.598\n",
            "Test time error in destination best: 0.961 and mean: 2.424\n",
            "Test time error overall (ADE) best: 0.599\n",
            "Test time error in destination best: 0.977 and mean: 2.436\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.943 and mean: 2.442\n",
            "Test time error overall (ADE) best: 0.592\n",
            "Test time error in destination best: 0.959 and mean: 2.442\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.987 and mean: 2.433\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.959 and mean: 2.420\n",
            "Test time error overall (ADE) best: 0.595\n",
            "Test time error in destination best: 0.962 and mean: 2.440\n",
            "Test time error overall (ADE) best: 0.608\n",
            "Test time error in destination best: 0.978 and mean: 2.439\n",
            "Test time error overall (ADE) best: 0.610\n",
            "Test time error in destination best: 0.970 and mean: 2.460\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.955 and mean: 2.423\n",
            "Test time error overall (ADE) best: 0.599\n",
            "Test time error in destination best: 0.950 and mean: 2.434\n",
            "Test time error overall (ADE) best: 0.597\n",
            "Test time error in destination best: 0.980 and mean: 2.441\n",
            "Test time error overall (ADE) best: 0.613\n",
            "Test time error in destination best: 0.982 and mean: 2.425\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.974 and mean: 2.409\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Test time error in destination best: 0.968 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 1.024 and mean: 2.438\n",
            "Test time error overall (ADE) best: 0.621\n",
            "Test time error in destination best: 0.967 and mean: 2.438\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.962 and mean: 2.435\n",
            "Test time error overall (ADE) best: 0.598\n",
            "Test time error in destination best: 0.966 and mean: 2.444\n",
            "Test time error overall (ADE) best: 0.601\n",
            "Test time error in destination best: 0.939 and mean: 2.420\n",
            "Test time error overall (ADE) best: 0.585\n",
            "Test time error in destination best: 1.009 and mean: 2.445\n",
            "Test time error overall (ADE) best: 0.607\n",
            "Test time error in destination best: 0.950 and mean: 2.417\n",
            "Test time error overall (ADE) best: 0.594\n",
            "Test time error in destination best: 0.982 and mean: 2.455\n",
            "Test time error overall (ADE) best: 0.611\n",
            "Test time error in destination best: 0.987 and mean: 2.443\n",
            "Test time error overall (ADE) best: 0.603\n",
            "Test time error in destination best: 0.952 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.602\n",
            "Test time error in destination best: 1.013 and mean: 2.439\n",
            "Test time error overall (ADE) best: 0.621\n",
            "Test time error in destination best: 0.977 and mean: 2.449\n",
            "Test time error overall (ADE) best: 0.602\n",
            "Test time error in destination best: 0.963 and mean: 2.437\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.977 and mean: 2.441\n",
            "Test time error overall (ADE) best: 0.604\n",
            "Test time error in destination best: 0.947 and mean: 2.415\n",
            "Test time error overall (ADE) best: 0.599\n",
            "Test time error in destination best: 0.973 and mean: 2.429\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.963 and mean: 2.438\n",
            "Test time error overall (ADE) best: 0.606\n",
            "Test time error in destination best: 0.979 and mean: 2.414\n",
            "Test time error overall (ADE) best: 0.605\n",
            "Test time error in destination best: 0.970 and mean: 2.432\n",
            "Test time error overall (ADE) best: 0.600\n",
            "Test time error in destination best: 0.979 and mean: 2.436\n",
            "Test time error overall (ADE) best: 0.606\n",
            "Test time error in destination best: 0.938 and mean: 2.425\n",
            "Test time error overall (ADE) best: 0.588\n",
            "Test time error in destination best: 0.983 and mean: 2.433\n",
            "Test time error overall (ADE) best: 0.610\n",
            "\n",
            "Average ADE: 0.604837519227152\n",
            "Average FDE: 0.9734009180144085\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2PNTyMgyheh0"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}