{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3035bb16",
   "metadata": {
    "cellId": "9c89p320il9z9rlv6wv1pm",
    "id": "3035bb16"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from model_rnn import MainModel\n",
    "from dataset import LocalizationDataset\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from ModelParams import ModelParams\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "85d4720a",
   "metadata": {
    "cellId": "5gg9z2ucnml262zy30ts37h",
    "id": "85d4720a"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "num_tracks = 10000\n",
    "track_data = np.array(pd.read_csv('trajectories.csv', header=None))\n",
    "world = np.loadtxt('environment.csv', delimiter=',')\n",
    "points = []\n",
    "for i, line in enumerate(world):\n",
    "    nb_y = world.shape[0] - i - 1\n",
    "    for j, block in enumerate(line):\n",
    "        if block == 2:\n",
    "            points.extend([[j+0.5, nb_y+0.5]])\n",
    "\n",
    "points = torch.from_numpy(np.array(points)).double().to(device)\n",
    "\n",
    "all_data = dict()\n",
    "\n",
    "track_len = track_data.shape[0] // num_tracks\n",
    "track = np.zeros((num_tracks, track_len, track_data.shape[1]))\n",
    "for i in range(num_tracks):\n",
    "    track[i] = track_data[i*track_len:(i+1)*track_len]\n",
    "\n",
    "eval_test_numbers = np.random.choice(num_tracks, size=num_tracks//5, replace=False)\n",
    "eval_numbers = eval_test_numbers[:len(eval_test_numbers)//2]\n",
    "test_numbers = eval_test_numbers[len(eval_test_numbers)//2:]\n",
    "train_numbers = np.setdiff1d(np.arange(num_tracks), eval_test_numbers)\n",
    "\n",
    "train_data = dict()\n",
    "eval_data = dict()\n",
    "test_data = dict()\n",
    "\n",
    "train_data['tracks'] = track[train_numbers]\n",
    "train_data['map'] = world\n",
    "\n",
    "eval_data['tracks'] = track[eval_numbers]\n",
    "eval_data['map'] = world\n",
    "\n",
    "test_data['tracks'] = track[test_numbers]\n",
    "test_data['map'] = world\n",
    "\n",
    "\n",
    "train_dataset = LocalizationDataset(train_data)\n",
    "eval_dataset = LocalizationDataset(eval_data)\n",
    "test_dataset = LocalizationDataset(test_data)\n",
    "\n",
    "params = ModelParams()\n",
    "train_loader = DataLoader(train_dataset, batch_size=params.batch_size, pin_memory=True, shuffle=True)\n",
    "eval_loader = DataLoader(eval_dataset, batch_size=params.batch_size, pin_memory=True, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=params.batch_size, pin_memory=True, shuffle=False)\n",
    "\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7ea6dfc0",
   "metadata": {
    "cellId": "shfnuoxcnqcsqqk5ghc0r",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 469
    },
    "id": "7ea6dfc0",
    "outputId": "35be7436-d5a4-4c1f-f807-24bf30bc7cae"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  3%|▎         | 131/5000 [4:19:57<159:32:01, 117.95s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "60.58 66.18\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|▎         | 131/5000 [4:20:27<161:20:43, 119.29s/it]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[5], line 21\u001b[0m\n\u001b[0;32m     19\u001b[0m model\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m     20\u001b[0m loss, last_loss, pred, particle_pred \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mstep(measurement, motion, location, params\u001b[38;5;241m.\u001b[39mbpdecay)\n\u001b[1;32m---> 21\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     22\u001b[0m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mclip_grad_norm_(model\u001b[38;5;241m.\u001b[39mparameters(), grad_clip)\n\u001b[0;32m     23\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python310\\site-packages\\torch\\_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m    477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    478\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m    479\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[0;32m    480\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    485\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[0;32m    486\u001b[0m     )\n\u001b[1;32m--> 487\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    488\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[0;32m    489\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Roaming\\Python\\Python310\\site-packages\\torch\\autograd\\__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m    195\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[0;32m    197\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[0;32m    198\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m    199\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 200\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[0;32m    201\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    202\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "#!g1.1\n",
    "model = MainModel(points).to(device).double()\n",
    "epochs = 5000\n",
    "grad_clip = 3\n",
    "MSE_min = np.Inf\n",
    "MSE_accum = []\n",
    "optimizer = torch.optim.RMSprop(model.parameters(), lr=5e-4)\n",
    "\n",
    "for epoch in tqdm(range(epochs)):\n",
    "    model.train()\n",
    "    train_loss = []\n",
    "    curr_loss1 = 0\n",
    "    for iteration, data in enumerate(train_loader):\n",
    "        _, measurement, location, motion = data\n",
    "        measurement = measurement.to(device).double()\n",
    "        location = location.to(device).double()\n",
    "        motion = motion.to(device).double()\n",
    "\n",
    "        model.zero_grad()\n",
    "        loss, last_loss, pred, particle_pred = model.step(measurement, motion, location, params.bpdecay)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
    "        optimizer.step()\n",
    "        curr_loss1 += loss.to('cpu').detach().numpy() / len(train_numbers)\n",
    "    train_loss.append(curr_loss1)\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = []\n",
    "    curr_loss2 = 0\n",
    "    y_pred = []\n",
    "    y_true = []\n",
    "    with torch.no_grad():\n",
    "        for iteration, data in enumerate(eval_loader):\n",
    "            _, measurement, location, motion = data\n",
    "            batch_size, seq_len = motion.size(0), motion.size(1)\n",
    "            measurement = measurement.to(device).double()\n",
    "            location = location.to(device).double()\n",
    "            motion = motion.to(device).double()\n",
    "            loss, last_loss, pred, particle_pred = model.step(measurement, motion, location, params.bpdecay)\n",
    "            curr_loss2 += loss.to('cpu').detach().numpy() / len(eval_numbers)\n",
    "            y_true_curr = location[:, :, :2]\n",
    "            y_pred_curr = torch.cat((pred[:, :, :1] * model.width, pred[:, :, 1:2] * model.height), dim=2)\n",
    "            y_true.extend(y_true_curr.to('cpu').numpy())\n",
    "            y_pred.extend(y_pred_curr.detach().to('cpu').numpy())\n",
    "\n",
    "    eval_loss.append(curr_loss2)\n",
    "    y_true = np.array(y_true)\n",
    "    y_pred = np.array(y_pred)\n",
    "    y_true = y_true.reshape(len(y_true), track_len, 2)\n",
    "    y_pred = y_pred.reshape(len(y_pred), track_len, 2)\n",
    "\n",
    "    X = (y_pred - y_true).reshape(y_pred.shape[0]*track_len, 2)\n",
    "    MSE = (X**2).sum(axis=-1).mean()\n",
    "    print('epoch, MSE', epoch, np.round(MSE, 3))\n",
    "    if MSE < MSE_min:\n",
    "        state = {\n",
    "            'epoch': epoch,\n",
    "            'state_dict': model.state_dict(),\n",
    "            'optimizer': optimizer.state_dict(),\n",
    "        }\n",
    "        filepath='mbpfn27.ptm'\n",
    "        torch.save(state, filepath)\n",
    "        #torch.save(state, '/content/gdrive/MyDrive/mbpfn27.ptm')\n",
    "        MSE_min = MSE\n",
    "\n",
    "    MSE_accum.append(MSE)\n",
    "\n",
    "    clear_output(True)\n",
    "    plt.figure()\n",
    "    plt.title(str(np.round(MSE_min, 2)))\n",
    "    plt.plot(MSE_accum)\n",
    "    plt.plot([MSE_min]*(len(MSE_accum)))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d398ad25",
   "metadata": {
    "cellId": "6gvjtmok1kwmqr1hwl0hzf",
    "id": "d398ad25",
    "outputId": "bd9d046a-4498-4686-ae27-b407ec2cdac8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:  327\n",
      "MSE mean and std, final point mean and std 58.873520447185086 55.00930962739599 5.532542629980847 3.6856876474648956\n"
     ]
    }
   ],
   "source": [
    "#!g1.1\n",
    "def load_checkpoint(filepath='mbpfn27.ptm'):\n",
    "    checkpoint = torch.load(filepath, map_location=torch.device('cpu'))\n",
    "    print('epoch: ', checkpoint['epoch'])\n",
    "    model.load_state_dict(checkpoint['state_dict'])\n",
    "    for parameter in model.parameters():\n",
    "        parameter.requires_grad = False\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "model = MainModel(points).to(device).double()\n",
    "\n",
    "model = load_checkpoint()\n",
    "\n",
    "model.eval()\n",
    "eval_loss = []\n",
    "curr_loss2 = 0\n",
    "y_pred = []\n",
    "y_true = []\n",
    "with torch.no_grad():\n",
    "    for iteration, data in enumerate(test_loader):\n",
    "        env_map, measurement, location, motion = data\n",
    "        batch_size, seq_len = motion.size(0), motion.size(1)\n",
    "        env_map = env_map.to(device).double()\n",
    "        measurement = measurement.to(device).double()\n",
    "        location = location.to(device).double()\n",
    "        motion = motion.to(device).double()\n",
    "        loss, last_loss, pred, particle_pred = model.step(measurement, motion, location, 0.1)\n",
    "        curr_loss2 += loss.to('cpu').detach().numpy() / len(eval_numbers)\n",
    "        y_true_curr = location[:, :, :2]\n",
    "        y_pred_curr = pred[:, :, :2]\n",
    "        y_pred_curr[:, :, 0] *= model.width\n",
    "        y_pred_curr[:, :, 1] *= model.height\n",
    "        y_true.extend(y_true_curr.to('cpu').numpy())\n",
    "        y_pred.extend(y_pred_curr.detach().to('cpu').numpy())\n",
    "\n",
    "eval_loss.append(curr_loss2)\n",
    "y_true = np.array(y_true)\n",
    "y_pred = np.array(y_pred)\n",
    "y_true = y_true.reshape(len(y_true), track_len, 2)\n",
    "y_pred = y_pred.reshape(len(y_pred), track_len, 2)\n",
    "\n",
    "X = (y_pred - y_true).reshape(y_pred.shape[0]*seq_len, 2)\n",
    "MSE_arr = (X**2).sum(axis=-1)\n",
    "Y = ((((y_pred - y_true)[:, -1])**2).sum(axis=-1)**0.5)\n",
    "print('MSE mean and std, final point mean and std', MSE_arr.mean(), MSE_arr.std(), Y.mean(), Y.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cad5998",
   "metadata": {
    "cellId": "xw34g7w45tx4usilocup",
    "id": "5cad5998"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b810608",
   "metadata": {
    "cellId": "ukmfxbee2u6ztki76nc53",
    "id": "0b810608"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f9bee46",
   "metadata": {
    "cellId": "echs5rsk1mvvnca9mufrzg",
    "id": "6f9bee46"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8599d74",
   "metadata": {
    "cellId": "gck4f9ag4gzljk8vg7yj",
    "id": "a8599d74"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee310f7c",
   "metadata": {
    "cellId": "tecp3gzyuymggi8fojs9bl",
    "id": "ee310f7c"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "machine_shape": "hm",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "notebookId": "2eff6be8-f7ea-4468-a171-8075d4bcb638",
  "notebookPath": "Jup_main.ipynb"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
