{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3035bb16",
   "metadata": {
    "cellId": "9c89p320il9z9rlv6wv1pm",
    "id": "3035bb16"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from model_rnn import MainModel\n",
    "from dataset import LocalizationDataset\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from ModelParams import ModelParams\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "85d4720a",
   "metadata": {
    "cellId": "5gg9z2ucnml262zy30ts37h",
    "id": "85d4720a"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "num_tracks = 10000\n",
    "track_data = np.array(pd.read_csv('trajectories.csv', header=None))\n",
    "world = np.loadtxt('environment.csv', delimiter=',')\n",
    "points = []\n",
    "for i, line in enumerate(world):\n",
    "    nb_y = world.shape[0] - i - 1\n",
    "    for j, block in enumerate(line):\n",
    "        if block == 2:\n",
    "            points.extend([[j+0.5, nb_y+0.5]])\n",
    "\n",
    "points = torch.from_numpy(np.array(points) / np.vstack((world.shape[1], world.shape[0])).reshape(1, 2)).double().to(device)\n",
    "\n",
    "all_data = dict()\n",
    "\n",
    "track_len = track_data.shape[0] // num_tracks\n",
    "track = np.zeros((num_tracks, track_len, track_data.shape[1]))\n",
    "for i in range(num_tracks):\n",
    "    track[i] = track_data[i*track_len:(i+1)*track_len]\n",
    "\n",
    "eval_test_numbers = np.random.choice(num_tracks, size=num_tracks//5, replace=False)\n",
    "eval_numbers = eval_test_numbers[:len(eval_test_numbers)//2]\n",
    "test_numbers = eval_test_numbers[len(eval_test_numbers)//2:]\n",
    "train_numbers = np.setdiff1d(np.arange(num_tracks), eval_test_numbers)\n",
    "\n",
    "train_data = dict()\n",
    "eval_data = dict()\n",
    "test_data = dict()\n",
    "\n",
    "train_data['tracks'] = track[train_numbers]\n",
    "train_data['map'] = world\n",
    "\n",
    "eval_data['tracks'] = track[eval_numbers]\n",
    "eval_data['map'] = world\n",
    "\n",
    "test_data['tracks'] = track[test_numbers]\n",
    "test_data['map'] = world\n",
    "\n",
    "\n",
    "train_dataset = LocalizationDataset(train_data)\n",
    "eval_dataset = LocalizationDataset(eval_data)\n",
    "test_dataset = LocalizationDataset(test_data)\n",
    "\n",
    "params = ModelParams()\n",
    "train_loader = DataLoader(train_dataset, batch_size=params.batch_size, pin_memory=True, shuffle=True)\n",
    "eval_loader = DataLoader(eval_dataset, batch_size=params.batch_size, pin_memory=True, shuffle=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=params.batch_size, pin_memory=True, shuffle=False)\n",
    "\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7ea6dfc0",
   "metadata": {
    "cellId": "shfnuoxcnqcsqqk5ghc0r",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 469
    },
    "id": "7ea6dfc0",
    "outputId": "35be7436-d5a4-4c1f-f807-24bf30bc7cae"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|                                                                                         | 0/5000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAAAAAAAAA tensor([[0.4626, 0.4181],\n",
      "        [0.4768, 0.4167],\n",
      "        [0.5063, 0.4724],\n",
      "        ...,\n",
      "        [0.4182, 0.4201],\n",
      "        [0.4914, 0.4614],\n",
      "        [0.4432, 0.3984]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2518, 3.6824, 3.6945, 4.1798, 4.2116],\n",
      "        [3.2487, 3.6715, 3.6879, 4.1665, 4.2096],\n",
      "        [3.1865, 3.6157, 3.6251, 4.1238, 4.1484],\n",
      "        ...,\n",
      "        [3.2644, 3.7175, 3.7180, 4.2209, 4.2222],\n",
      "        [3.2016, 3.6342, 3.6425, 4.1409, 4.1626],\n",
      "        [3.2767, 3.7095, 3.7216, 4.2037, 4.2355]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.3766, 0.4526],\n",
      "        [0.3903, 0.4770],\n",
      "        [0.3580, 0.5384],\n",
      "        ...,\n",
      "        [0.4338, 0.4541],\n",
      "        [0.4435, 0.4230],\n",
      "        [0.5465, 0.4348]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2479, 3.7143, 3.7347, 4.2005, 4.2544],\n",
      "        [3.2202, 3.6864, 3.7098, 4.1732, 4.2351],\n",
      "        [3.1741, 3.6550, 3.7040, 4.1229, 4.2521],\n",
      "        ...,\n",
      "        [3.2271, 3.6805, 3.6860, 4.1840, 4.1985],\n",
      "        [3.2533, 3.6954, 3.7009, 4.1970, 4.2116],\n",
      "        [3.2101, 3.6040, 3.6348, 4.0945, 4.1756]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4328, 0.4522],\n",
      "        [0.4659, 0.4635],\n",
      "        [0.5162, 0.4535],\n",
      "        ...,\n",
      "        [0.4141, 0.4848],\n",
      "        [0.4907, 0.4409],\n",
      "        [0.4890, 0.4721]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2292, 3.6827, 3.6880, 4.1861, 4.2000],\n",
      "        [3.2077, 3.6542, 3.6548, 4.1651, 4.1668],\n",
      "        [3.2014, 3.6182, 3.6354, 4.1190, 4.1643],\n",
      "        ...,\n",
      "        [3.2047, 3.6663, 3.6856, 4.1594, 4.2101],\n",
      "        [3.2214, 3.6464, 3.6600, 4.1469, 4.1828],\n",
      "        [3.1923, 3.6303, 3.6349, 4.1407, 4.1529]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4826, 0.5077],\n",
      "        [0.4448, 0.4409],\n",
      "        [0.4923, 0.5703],\n",
      "        ...,\n",
      "        [0.4377, 0.4633],\n",
      "        [0.4250, 0.5195],\n",
      "        [0.4347, 0.4780]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1605, 3.6089, 3.6158, 4.1199, 4.1382],\n",
      "        [3.2359, 3.6843, 3.6853, 4.1912, 4.1940],\n",
      "        [3.0982, 3.5517, 3.5736, 4.0569, 4.1142],\n",
      "        ...,\n",
      "        [3.2171, 3.6708, 3.6777, 4.1741, 4.1925],\n",
      "        [3.1684, 3.6316, 3.6576, 4.1231, 4.1913],\n",
      "        [3.2043, 3.6604, 3.6722, 4.1607, 4.1918]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4067, 0.5025],\n",
      "        [0.4534, 0.4719],\n",
      "        [0.4351, 0.4909],\n",
      "        ...,\n",
      "        [0.4685, 0.4225],\n",
      "        [0.4167, 0.5304],\n",
      "        [0.4500, 0.4320]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1907, 3.6561, 3.6822, 4.1444, 4.2131],\n",
      "        [3.2039, 3.6548, 3.6599, 4.1618, 4.1752],\n",
      "        [3.1919, 3.6494, 3.6647, 4.1481, 4.1883],\n",
      "        ...,\n",
      "        [3.2458, 3.6751, 3.6876, 4.1730, 4.2059],\n",
      "        [3.1610, 3.6275, 3.6587, 4.1147, 4.1968],\n",
      "        [3.2427, 3.6850, 3.6899, 4.1885, 4.2014]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4304, 0.3643],\n",
      "        [0.4171, 0.4168],\n",
      "        [0.4996, 0.5326],\n",
      "        ...,\n",
      "        [0.4150, 0.4507],\n",
      "        [0.4333, 0.5212],\n",
      "        [0.4480, 0.4496]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.3131, 3.7395, 3.7571, 4.2251, 4.2718],\n",
      "        [3.2678, 3.7208, 3.7208, 4.2242, 4.2243],\n",
      "        [3.1315, 3.5787, 3.5879, 4.0916, 4.1157],\n",
      "        ...,\n",
      "        [3.2365, 3.6939, 3.7036, 4.1921, 4.2176],\n",
      "        [3.1640, 3.6255, 3.6497, 4.1193, 4.1828],\n",
      "        [3.2267, 3.6763, 3.6768, 4.1847, 4.1859]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4696, 0.3495],\n",
      "        [0.4635, 0.5021],\n",
      "        [0.4189, 0.4771],\n",
      "        ...,\n",
      "        [0.4401, 0.5386],\n",
      "        [0.4259, 0.4962],\n",
      "        [0.4728, 0.5345]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.3147, 3.7159, 3.7481, 4.1914, 4.2765],\n",
      "        [3.1720, 3.6241, 3.6348, 4.1300, 4.1580],\n",
      "        [3.2104, 3.6699, 3.6858, 4.1656, 4.2073],\n",
      "        ...,\n",
      "        [3.1453, 3.6073, 3.6345, 4.1007, 4.1722],\n",
      "        [3.1901, 3.6503, 3.6695, 4.1454, 4.1959],\n",
      "        [3.1384, 3.5922, 3.6093, 4.0964, 4.1413]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4197, 0.4418],\n",
      "        [0.4123, 0.5180],\n",
      "        [0.4009, 0.5600],\n",
      "        ...,\n",
      "        [0.4713, 0.4609],\n",
      "        [0.4300, 0.4414],\n",
      "        [0.4443, 0.4711]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2434, 3.6987, 3.7047, 4.1995, 4.2153],\n",
      "        [3.1741, 3.6401, 3.6690, 4.1278, 4.2040],\n",
      "        [3.1387, 3.6124, 3.6562, 4.0904, 4.2055],\n",
      "        ...,\n",
      "        [3.2085, 3.6511, 3.6540, 4.1605, 4.1680],\n",
      "        [3.2403, 3.6932, 3.6963, 4.1972, 4.2054],\n",
      "        [3.2076, 3.6606, 3.6680, 4.1649, 4.1842]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5240, 0.5014],\n",
      "        [0.5097, 0.5441],\n",
      "        [0.4758, 0.4417],\n",
      "        ...,\n",
      "        [0.4155, 0.4946],\n",
      "        [0.4201, 0.4828],\n",
      "        [0.4496, 0.4796]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1534, 3.5848, 3.5911, 4.0995, 4.1160],\n",
      "        [3.1174, 3.5635, 3.5731, 4.0780, 4.1032],\n",
      "        [3.2253, 3.6582, 3.6675, 4.1610, 4.1856],\n",
      "        ...,\n",
      "        [3.1951, 3.6575, 3.6791, 4.1496, 4.2064],\n",
      "        [3.2046, 3.6646, 3.6817, 4.1598, 4.2048],\n",
      "        [3.1979, 3.6506, 3.6588, 4.1554, 4.1769]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4330, 0.4876],\n",
      "        [0.4516, 0.4007],\n",
      "        [0.3921, 0.4897],\n",
      "        ...,\n",
      "        [0.3858, 0.5263],\n",
      "        [0.4432, 0.5204],\n",
      "        [0.4676, 0.5478]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1958, 3.6533, 3.6682, 4.1518, 4.1911],\n",
      "        [3.2718, 3.7013, 3.7151, 4.1950, 4.2312],\n",
      "        [3.2077, 3.6749, 3.7014, 4.1605, 4.2303],\n",
      "        ...,\n",
      "        [3.1755, 3.6486, 3.6869, 4.1269, 4.2278],\n",
      "        [3.1614, 3.6205, 3.6418, 4.1175, 4.1734],\n",
      "        [3.1275, 3.5842, 3.6065, 4.0849, 4.1433]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5130, 0.4507],\n",
      "        [0.5525, 0.4788],\n",
      "        [0.4621, 0.4791],\n",
      "        ...,\n",
      "        [0.4243, 0.5015],\n",
      "        [0.4984, 0.4944],\n",
      "        [0.5014, 0.4707]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2050, 3.6224, 3.6396, 4.1228, 4.1678],\n",
      "        [3.1663, 3.5739, 3.5945, 4.0774, 4.1313],\n",
      "        [3.1942, 3.6440, 3.6487, 4.1527, 4.1650],\n",
      "        ...,\n",
      "        [3.1857, 3.6469, 3.6680, 4.1407, 4.1963],\n",
      "        [3.1681, 3.6100, 3.6111, 4.1260, 4.1289],\n",
      "        [3.1897, 3.6207, 3.6292, 4.1290, 4.1512]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5773, 0.4240],\n",
      "        [0.5563, 0.5019],\n",
      "        [0.5764, 0.5287],\n",
      "        ...,\n",
      "        [0.6079, 0.5059],\n",
      "        [0.5529, 0.5531],\n",
      "        [0.5570, 0.4920]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2115, 3.5850, 3.6275, 4.0676, 4.1792],\n",
      "        [3.1431, 3.5578, 3.5730, 4.0681, 4.1080],\n",
      "        [3.1115, 3.5261, 3.5396, 4.0420, 4.0773],\n",
      "        ...,\n",
      "        [3.1242, 3.5129, 3.5419, 4.0171, 4.0926],\n",
      "        [3.0953, 3.5320, 3.5321, 4.0589, 4.0590],\n",
      "        [3.1524, 3.5628, 3.5810, 4.0699, 4.1175]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5703, 0.4971],\n",
      "        [0.5970, 0.5329],\n",
      "        [0.5960, 0.5523],\n",
      "        ...,\n",
      "        [0.6416, 0.6134],\n",
      "        [0.6316, 0.5294],\n",
      "        [0.6026, 0.4943]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1435, 3.5489, 3.5694, 4.0556, 4.1094],\n",
      "        [3.1014, 3.5067, 3.5249, 4.0210, 4.0685],\n",
      "        [3.0832, 3.4967, 3.5092, 4.0174, 4.0499],\n",
      "        ...,\n",
      "        [3.0115, 3.4249, 3.4331, 3.9590, 3.9803],\n",
      "        [3.0950, 3.4801, 3.5094, 3.9883, 4.0645],\n",
      "        [3.1368, 3.5239, 3.5545, 4.0251, 4.1050]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6392, 0.6142],\n",
      "        [0.5838, 0.5883],\n",
      "        [0.5768, 0.4867],\n",
      "        ...,\n",
      "        [0.6196, 0.5547],\n",
      "        [0.6503, 0.5047],\n",
      "        [0.6685, 0.5303]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0114, 3.4264, 3.4337, 3.9611, 3.9800],\n",
      "        [3.0524, 3.4857, 3.4870, 4.0175, 4.0210],\n",
      "        [3.1515, 3.5494, 3.5747, 4.0519, 4.1181],\n",
      "        ...,\n",
      "        [3.0741, 3.4758, 3.4944, 3.9939, 4.0424],\n",
      "        [3.1135, 3.4788, 3.5204, 3.9763, 4.0847],\n",
      "        [3.0839, 3.4492, 3.4890, 3.9523, 4.0558]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6598, 0.6294],\n",
      "        [0.5454, 0.6235],\n",
      "        [0.6442, 0.5189],\n",
      "        ...,\n",
      "        [0.6704, 0.5971],\n",
      "        [0.5955, 0.5370],\n",
      "        [0.6853, 0.6538]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[2.9910, 3.4008, 3.4098, 3.9376, 3.9607],\n",
      "        [3.0308, 3.4778, 3.5002, 3.9923, 4.0506],\n",
      "        [3.1015, 3.4757, 3.5116, 3.9786, 4.0721],\n",
      "        ...,\n",
      "        [3.0191, 3.4100, 3.4314, 3.9346, 3.9901],\n",
      "        [3.0980, 3.5057, 3.5223, 4.0215, 4.0649],\n",
      "        [2.9604, 3.3661, 3.3754, 3.9074, 3.9315]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6143, 0.5670],\n",
      "        [0.6741, 0.6327],\n",
      "        [0.7391, 0.4715],\n",
      "        ...,\n",
      "        [0.6534, 0.6498],\n",
      "        [0.6174, 0.6529],\n",
      "        [0.6302, 0.5615]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0638, 3.4733, 3.4868, 3.9961, 4.0315],\n",
      "        [2.9839, 3.3871, 3.3993, 3.9229, 3.9545],\n",
      "        [3.1226, 3.4260, 3.5032, 3.8990, 4.0981],\n",
      "        ...,\n",
      "        [2.9732, 3.3950, 3.3961, 3.9394, 3.9422],\n",
      "        [2.9808, 3.4132, 3.4236, 3.9471, 3.9739],\n",
      "        [3.0645, 3.4631, 3.4829, 3.9819, 4.0334]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6286, 0.5602],\n",
      "        [0.6268, 0.5715],\n",
      "        [0.6391, 0.6090],\n",
      "        ...,\n",
      "        [0.5798, 0.4979],\n",
      "        [0.6258, 0.5660],\n",
      "        [0.5834, 0.5785]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0663, 3.4652, 3.4849, 3.9838, 4.0350],\n",
      "        [3.0559, 3.4604, 3.4763, 3.9830, 4.0243],\n",
      "        [3.0165, 3.4294, 3.4382, 3.9624, 3.9851],\n",
      "        ...,\n",
      "        [3.1399, 3.5406, 3.5636, 4.0462, 4.1065],\n",
      "        [3.0615, 3.4643, 3.4815, 3.9852, 4.0300],\n",
      "        [3.0619, 3.4927, 3.4941, 4.0235, 4.0272]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAAAAAAAAA tensor([[0.6686, 0.5376],\n",
      "        [0.5854, 0.6025],\n",
      "        [0.6738, 0.6098],\n",
      "        ...,\n",
      "        [0.6015, 0.5884],\n",
      "        [0.5711, 0.5797],\n",
      "        [0.6394, 0.5573]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0769, 3.4450, 3.4828, 3.9504, 4.0487],\n",
      "        [3.0384, 3.4729, 3.4778, 4.0033, 4.0161],\n",
      "        [3.0059, 3.4000, 3.4188, 3.9284, 3.9769],\n",
      "        ...,\n",
      "        [3.0471, 3.4721, 3.4759, 4.0037, 4.0135],\n",
      "        [3.0645, 3.4999, 3.5023, 4.0289, 4.0353],\n",
      "        [3.0660, 3.4579, 3.4816, 3.9741, 4.0355]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6259, 0.5916],\n",
      "        [0.6237, 0.6230],\n",
      "        [0.6039, 0.5282],\n",
      "        ...,\n",
      "        [0.6251, 0.5579],\n",
      "        [0.6074, 0.5618],\n",
      "        [0.7212, 0.5551]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0370, 3.4500, 3.4599, 3.9792, 4.0050],\n",
      "        [3.0075, 3.4346, 3.4348, 3.9743, 3.9748],\n",
      "        [3.1040, 3.5036, 3.5252, 4.0155, 4.0716],\n",
      "        ...,\n",
      "        [3.0694, 3.4694, 3.4887, 3.9878, 4.0380],\n",
      "        [3.0708, 3.4819, 3.4950, 4.0041, 4.0381],\n",
      "        [3.0462, 3.3918, 3.4405, 3.8951, 4.0211]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6339, 0.6511],\n",
      "        [0.5886, 0.5097],\n",
      "        [0.6145, 0.6129],\n",
      "        ...,\n",
      "        [0.5854, 0.5234],\n",
      "        [0.6340, 0.5591],\n",
      "        [0.6661, 0.6040]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[2.9776, 3.4057, 3.4107, 3.9451, 3.9582],\n",
      "        [3.1261, 3.5267, 3.5490, 4.0348, 4.0931],\n",
      "        [3.0198, 3.4478, 3.4482, 3.9855, 3.9867],\n",
      "        ...,\n",
      "        [3.1139, 3.5216, 3.5392, 4.0346, 4.0804],\n",
      "        [3.0657, 3.4614, 3.4829, 3.9789, 4.0349],\n",
      "        [3.0136, 3.4096, 3.4278, 3.9372, 3.9842]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4908, 0.3432],\n",
      "        [0.5555, 0.4505],\n",
      "        [0.5232, 0.4459],\n",
      "        ...,\n",
      "        [0.4755, 0.4736],\n",
      "        [0.4886, 0.4052],\n",
      "        [0.5062, 0.4257]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.3142, 3.7023, 3.7420, 4.1728, 4.2776],\n",
      "        [3.1924, 3.5875, 3.6167, 4.0817, 4.1582],\n",
      "        [3.2065, 3.6167, 3.6381, 4.1141, 4.1702],\n",
      "        ...,\n",
      "        [3.1951, 3.6406, 3.6411, 4.1533, 4.1547],\n",
      "        [3.2559, 3.6683, 3.6910, 4.1581, 4.2178],\n",
      "        [3.2310, 3.6422, 3.6643, 4.1358, 4.1938]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5971, 0.4144],\n",
      "        [0.5408, 0.4286],\n",
      "        [0.5224, 0.4752],\n",
      "        ...,\n",
      "        [0.4922, 0.4716],\n",
      "        [0.5627, 0.4145],\n",
      "        [0.4916, 0.4395]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2150, 3.5744, 3.6252, 4.0510, 4.1842],\n",
      "        [3.2177, 3.6122, 3.6431, 4.1016, 4.1829],\n",
      "        [3.1789, 3.6008, 3.6139, 4.1075, 4.1418],\n",
      "        ...,\n",
      "        [3.1916, 3.6278, 3.6335, 4.1376, 4.1525],\n",
      "        [3.2247, 3.6024, 3.6433, 4.0842, 4.1916],\n",
      "        [3.2223, 3.6463, 3.6606, 4.1463, 4.1838]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6055, 0.3974],\n",
      "        [0.5686, 0.4558],\n",
      "        [0.4926, 0.4868],\n",
      "        ...,\n",
      "        [0.5064, 0.3891],\n",
      "        [0.4874, 0.4610],\n",
      "        [0.5065, 0.5052]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2290, 3.5776, 3.6353, 4.0476, 4.1990],\n",
      "        [3.1835, 3.5738, 3.6053, 4.0678, 4.1501],\n",
      "        [3.1772, 3.6191, 3.6207, 4.1335, 4.1377],\n",
      "        ...,\n",
      "        [3.2657, 3.6630, 3.6949, 4.1452, 4.2293],\n",
      "        [3.2032, 3.6377, 3.6450, 4.1449, 4.1640],\n",
      "        [3.1553, 3.5972, 3.5976, 4.1155, 4.1164]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4848, 0.4637],\n",
      "        [0.5395, 0.5073],\n",
      "        [0.5661, 0.4085],\n",
      "        ...,\n",
      "        [0.5481, 0.4702],\n",
      "        [0.4895, 0.4652],\n",
      "        [0.5032, 0.4438]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2015, 3.6384, 3.6442, 4.1468, 4.1620],\n",
      "        [3.1430, 3.5686, 3.5776, 4.0830, 4.1066],\n",
      "        [3.2294, 3.6030, 3.6465, 4.0825, 4.1967],\n",
      "        ...,\n",
      "        [3.1759, 3.5825, 3.6042, 4.0839, 4.1407],\n",
      "        [3.1986, 3.6336, 3.6403, 4.1418, 4.1594],\n",
      "        [3.2147, 3.6344, 3.6507, 4.1340, 4.1769]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5240, 0.3854],\n",
      "        [0.5575, 0.4737],\n",
      "        [0.6228, 0.4649],\n",
      "        ...,\n",
      "        [0.5536, 0.4818],\n",
      "        [0.4972, 0.4360],\n",
      "        [0.5842, 0.4612]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2640, 3.6508, 3.6886, 4.1293, 4.2288],\n",
      "        [3.1697, 3.5728, 3.5961, 4.0739, 4.1352],\n",
      "        [3.1593, 3.5242, 3.5687, 4.0130, 4.1294],\n",
      "        ...,\n",
      "        [3.1631, 3.5713, 3.5914, 4.0757, 4.1281],\n",
      "        [3.2240, 3.6438, 3.6606, 4.1419, 4.1859],\n",
      "        [3.1738, 3.5579, 3.5923, 4.0512, 4.1413]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5965, 0.4714],\n",
      "        [0.5729, 0.3637],\n",
      "        [0.5958, 0.5050],\n",
      "        ...,\n",
      "        [0.5473, 0.4633],\n",
      "        [0.5569, 0.4006],\n",
      "        [0.6139, 0.4789]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1605, 3.5420, 3.5771, 4.0368, 4.1287],\n",
      "        [3.2705, 3.6238, 3.6811, 4.0882, 4.2390],\n",
      "        [3.1286, 3.5234, 3.5491, 4.0290, 4.0961],\n",
      "        ...,\n",
      "        [3.1826, 3.5870, 3.6103, 4.0863, 4.1476],\n",
      "        [3.2397, 3.6152, 3.6581, 4.0935, 4.2065],\n",
      "        [3.1484, 3.5235, 3.5616, 4.0180, 4.1176]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5511, 0.4751],\n",
      "        [0.5516, 0.4279],\n",
      "        [0.5330, 0.5254],\n",
      "        ...,\n",
      "        [0.5657, 0.4223],\n",
      "        [0.4469, 0.4893],\n",
      "        [0.5432, 0.4786]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1703, 3.5772, 3.5984, 4.0798, 4.1353],\n",
      "        [3.2151, 3.6037, 3.6378, 4.0913, 4.1810],\n",
      "        [3.1278, 3.5640, 3.5661, 4.0850, 4.0905],\n",
      "        ...,\n",
      "        [3.2164, 3.5954, 3.6351, 4.0793, 4.1834],\n",
      "        [3.1896, 3.6442, 3.6558, 4.1467, 4.1772],\n",
      "        [3.1693, 3.5817, 3.5997, 4.0865, 4.1336]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5233, 0.5307],\n",
      "        [0.5324, 0.3884],\n",
      "        [0.5466, 0.5110],\n",
      "        ...,\n",
      "        [0.5101, 0.3901],\n",
      "        [0.5288, 0.4036],\n",
      "        [0.5357, 0.4592]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1259, 3.5671, 3.5692, 4.0877, 4.0932],\n",
      "        [3.2585, 3.6422, 3.6815, 4.1204, 4.2239],\n",
      "        [3.1374, 3.5607, 3.5707, 4.0753, 4.1014],\n",
      "        ...,\n",
      "        [3.2636, 3.6594, 3.6920, 4.1414, 4.2274],\n",
      "        [3.2451, 3.6363, 3.6706, 4.1198, 4.2099],\n",
      "        [3.1900, 3.5988, 3.6200, 4.0986, 4.1542]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5774, 0.3826],\n",
      "        [0.5604, 0.3891],\n",
      "        [0.5268, 0.3645],\n",
      "        ...,\n",
      "        [0.5246, 0.4823],\n",
      "        [0.5327, 0.4504],\n",
      "        [0.5477, 0.4203]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2510, 3.6090, 3.6626, 4.0786, 4.2195],\n",
      "        [3.2497, 3.6190, 3.6660, 4.0932, 4.2169],\n",
      "        [3.2831, 3.6607, 3.7047, 4.1323, 4.2485],\n",
      "        ...,\n",
      "        [3.1714, 3.5950, 3.6068, 4.1036, 4.1344],\n",
      "        [3.1993, 3.6063, 3.6291, 4.1038, 4.1635],\n",
      "        [3.2236, 3.6113, 3.6464, 4.0971, 4.1894]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5792, 0.4105],\n",
      "        [0.5311, 0.4613],\n",
      "        [0.5479, 0.3951],\n",
      "        ...,\n",
      "        [0.4901, 0.4730],\n",
      "        [0.4924, 0.4664],\n",
      "        [0.5324, 0.3325]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2238, 3.5913, 3.6379, 4.0694, 4.1919],\n",
      "        [3.1894, 3.6015, 3.6208, 4.1026, 4.1533],\n",
      "        [3.2476, 3.6257, 3.6676, 4.1037, 4.2139],\n",
      "        ...,\n",
      "        [3.1910, 3.6288, 3.6335, 4.1393, 4.1517],\n",
      "        [3.1965, 3.6306, 3.6377, 4.1387, 4.1575],\n",
      "        [3.3120, 3.6748, 3.7289, 4.1358, 4.2783]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4495, 0.5242],\n",
      "        [0.4925, 0.4977],\n",
      "        [0.5362, 0.5277],\n",
      "        ...,\n",
      "        [0.5367, 0.4349],\n",
      "        [0.4809, 0.5068],\n",
      "        [0.4889, 0.4881]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1558, 3.6139, 3.6345, 4.1123, 4.1664],\n",
      "        [3.1668, 3.6116, 3.6131, 4.1272, 4.1309],\n",
      "        [3.1247, 3.5600, 3.5624, 4.0813, 4.0876],\n",
      "        ...,\n",
      "        [3.2129, 3.6119, 3.6400, 4.1039, 4.1777],\n",
      "        [3.1620, 3.6106, 3.6177, 4.1212, 4.1400],\n",
      "        [3.1771, 3.6214, 3.6216, 4.1368, 4.1374]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5514, 0.5284],\n",
      "        [0.5356, 0.5130],\n",
      "        [0.5561, 0.4962],\n",
      "        ...,\n",
      "        [0.5598, 0.5181],\n",
      "        [0.5687, 0.5233],\n",
      "        [0.5855, 0.4691]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1194, 3.5471, 3.5535, 4.0664, 4.0833],\n",
      "        [3.1388, 3.5687, 3.5751, 4.0854, 4.1020],\n",
      "        [3.1486, 3.5612, 3.5780, 4.0697, 4.1136],\n",
      "        ...,\n",
      "        [3.1266, 3.5458, 3.5575, 4.0607, 4.0914],\n",
      "        [3.1190, 3.5356, 3.5484, 4.0509, 4.0843],\n",
      "        [3.1659, 3.5523, 3.5849, 4.0479, 4.1333]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5748, 0.4786],\n",
      "        [0.6439, 0.6186],\n",
      "        [0.6361, 0.5367],\n",
      "        ...,\n",
      "        [0.5513, 0.6111],\n",
      "        [0.5740, 0.4248],\n",
      "        [0.5711, 0.4507]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1599, 3.5557, 3.5827, 4.0559, 4.1265],\n",
      "        [3.0059, 3.4201, 3.4275, 3.9556, 3.9747],\n",
      "        [3.0866, 3.4722, 3.5007, 3.9821, 4.0563],\n",
      "        ...,\n",
      "        [3.0407, 3.4847, 3.5018, 4.0030, 4.0475],\n",
      "        [3.2116, 3.5872, 3.6286, 4.0706, 4.1791],\n",
      "        [3.1876, 3.5746, 3.6081, 4.0666, 4.1544]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6298, 0.4866],\n",
      "        [0.6093, 0.5103],\n",
      "        [0.6511, 0.4975],\n",
      "        ...,\n",
      "        [0.5217, 0.5137],\n",
      "        [0.6190, 0.5118],\n",
      "        [0.5559, 0.5197]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1365, 3.5060, 3.5466, 4.0007, 4.1067],\n",
      "        [3.1195, 3.5093, 3.5374, 4.0146, 4.0879],\n",
      "        [3.1202, 3.4823, 3.5262, 3.9774, 4.0916],\n",
      "        ...,\n",
      "        [3.1425, 3.5799, 3.5822, 4.0988, 4.1046],\n",
      "        [3.1154, 3.5004, 3.5309, 4.0048, 4.0843],\n",
      "        [3.1262, 3.5481, 3.5583, 4.0641, 4.0907]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6163, 0.6036],\n",
      "        [0.5876, 0.5808],\n",
      "        [0.6018, 0.6433],\n",
      "        ...,\n",
      "        [0.5307, 0.5735],\n",
      "        [0.6523, 0.5467],\n",
      "        [0.5973, 0.5029]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0283, 3.4514, 3.4551, 3.9858, 3.9954],\n",
      "        [3.0585, 3.4879, 3.4899, 4.0189, 4.0240],\n",
      "        [2.9946, 3.4299, 3.4419, 3.9599, 3.9912],\n",
      "        ...,\n",
      "        [3.0828, 3.5274, 3.5395, 4.0443, 4.0760],\n",
      "        [3.0725, 3.4532, 3.4837, 3.9640, 4.0432],\n",
      "        [3.1301, 3.5234, 3.5500, 4.0281, 4.0977]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5968, 0.5337],\n",
      "        [0.5612, 0.5358],\n",
      "        [0.6183, 0.4764],\n",
      "        ...,\n",
      "        [0.5773, 0.5730],\n",
      "        [0.5607, 0.6167],\n",
      "        [0.5942, 0.5051]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1008, 3.5064, 3.5244, 4.0211, 4.0679],\n",
      "        [3.1093, 3.5347, 3.5419, 4.0551, 4.0738],\n",
      "        [3.1495, 3.5212, 3.5613, 4.0144, 4.1191],\n",
      "        ...,\n",
      "        [3.0690, 3.5008, 3.5020, 4.0308, 4.0339],\n",
      "        [3.0324, 3.4748, 3.4909, 3.9952, 4.0371],\n",
      "        [3.1289, 3.5246, 3.5498, 4.0305, 4.0963]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5502, 0.5665],\n",
      "        [0.5549, 0.5798],\n",
      "        [0.5430, 0.4894],\n",
      "        ...,\n",
      "        [0.6297, 0.5378],\n",
      "        [0.5178, 0.5636],\n",
      "        [0.5612, 0.5329]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0835, 3.5224, 3.5271, 4.0466, 4.0587],\n",
      "        [3.0693, 3.5087, 3.5158, 4.0325, 4.0510],\n",
      "        [3.1590, 3.5758, 3.5907, 4.0840, 4.1232],\n",
      "        ...,\n",
      "        [3.0874, 3.4769, 3.5033, 3.9881, 4.0567],\n",
      "        [3.0963, 3.5428, 3.5557, 4.0570, 4.0908],\n",
      "        [3.1121, 3.5364, 3.5444, 4.0558, 4.0767]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6090, 0.4944],\n",
      "        [0.6256, 0.6128],\n",
      "        [0.5974, 0.5020],\n",
      "        ...,\n",
      "        [0.5886, 0.6182],\n",
      "        [0.5802, 0.5240],\n",
      "        [0.5435, 0.5780]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1349, 3.5186, 3.5510, 4.0189, 4.1035],\n",
      "        [3.0167, 3.4386, 3.4423, 3.9747, 3.9843],\n",
      "        [3.1309, 3.5238, 3.5508, 4.0282, 4.0986],\n",
      "        ...,\n",
      "        [3.0225, 3.4580, 3.4666, 3.9873, 4.0096],\n",
      "        [3.1149, 3.5256, 3.5415, 4.0395, 4.0810],\n",
      "        [3.0746, 3.5166, 3.5263, 4.0370, 4.0625]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAAAAAAAAA tensor([[0.5869, 0.5556],\n",
      "        [0.5615, 0.5739],\n",
      "        [0.6173, 0.5677],\n",
      "        ...,\n",
      "        [0.5441, 0.5017],\n",
      "        [0.6223, 0.5139],\n",
      "        [0.5564, 0.4696]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0828, 3.5024, 3.5114, 4.0255, 4.0487],\n",
      "        [3.0730, 3.5100, 3.5135, 4.0367, 4.0459],\n",
      "        [3.0622, 3.4704, 3.4846, 3.9930, 4.0301],\n",
      "        ...,\n",
      "        [3.1470, 3.5680, 3.5799, 4.0799, 4.1110],\n",
      "        [3.1125, 3.4965, 3.5274, 4.0011, 4.0816],\n",
      "        [3.1739, 3.5759, 3.6001, 4.0760, 4.1394]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5542, 0.5270],\n",
      "        [0.6375, 0.5314],\n",
      "        [0.5518, 0.6016],\n",
      "        ...,\n",
      "        [0.5811, 0.5481],\n",
      "        [0.6491, 0.5820],\n",
      "        [0.5687, 0.5978]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1199, 3.5455, 3.5532, 4.0640, 4.0841],\n",
      "        [3.0913, 3.4741, 3.5045, 3.9821, 4.0612],\n",
      "        [3.0495, 3.4923, 3.5065, 4.0120, 4.0491],\n",
      "        ...,\n",
      "        [3.0916, 3.5114, 3.5208, 4.0329, 4.0574],\n",
      "        [3.0395, 3.4360, 3.4555, 3.9588, 4.0093],\n",
      "        [3.0480, 3.4861, 3.4944, 4.0118, 4.0335]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5016, 0.4426],\n",
      "        [0.5174, 0.4384],\n",
      "        [0.5882, 0.4534],\n",
      "        ...,\n",
      "        [0.4996, 0.4621],\n",
      "        [0.5373, 0.4936],\n",
      "        [0.5470, 0.4464]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2162, 3.6363, 3.6525, 4.1358, 4.1784],\n",
      "        [3.2154, 3.6257, 3.6474, 4.1216, 4.1787],\n",
      "        [3.1802, 3.5591, 3.5968, 4.0494, 4.1481],\n",
      "        ...,\n",
      "        [3.1984, 3.6270, 3.6374, 4.1328, 4.1600],\n",
      "        [3.1568, 3.5782, 3.5904, 4.0885, 4.1205],\n",
      "        [3.1989, 3.5969, 3.6247, 4.0910, 4.1641]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5426, 0.6338],\n",
      "        [0.5768, 0.5674],\n",
      "        [0.6035, 0.6594],\n",
      "        ...,\n",
      "        [0.6047, 0.5579],\n",
      "        [0.5434, 0.5353],\n",
      "        [0.6083, 0.5424]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0218, 3.4708, 3.4970, 3.9830, 4.0511],\n",
      "        [3.0745, 3.5043, 3.5070, 4.0325, 4.0395],\n",
      "        [2.9787, 3.4155, 3.4318, 3.9438, 3.9861],\n",
      "        ...,\n",
      "        [3.0753, 3.4863, 3.4997, 4.0076, 4.0425],\n",
      "        [3.1152, 3.5498, 3.5521, 4.0725, 4.0784],\n",
      "        [3.0891, 3.4920, 3.5109, 4.0078, 4.0568]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6308, 0.6121],\n",
      "        [0.6051, 0.6201],\n",
      "        [0.6240, 0.5334],\n",
      "        ...,\n",
      "        [0.6493, 0.6078],\n",
      "        [0.6311, 0.5423],\n",
      "        [0.7393, 0.5668]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0159, 3.4346, 3.4401, 3.9698, 3.9839],\n",
      "        [3.0157, 3.4474, 3.4518, 3.9818, 3.9931],\n",
      "        [3.0932, 3.4841, 3.5100, 3.9947, 4.0622],\n",
      "        ...,\n",
      "        [3.0147, 3.4215, 3.4337, 3.9527, 3.9841],\n",
      "        [3.0827, 3.4732, 3.4987, 3.9857, 4.0520],\n",
      "        [3.0303, 3.3703, 3.4211, 3.8748, 4.0061]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6039, 0.5826],\n",
      "        [0.6122, 0.6580],\n",
      "        [0.5933, 0.6350],\n",
      "        ...,\n",
      "        [0.6263, 0.6270],\n",
      "        [0.6639, 0.6386],\n",
      "        [0.6616, 0.6432]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0520, 3.4733, 3.4795, 4.0027, 4.0186],\n",
      "        [2.9775, 3.4119, 3.4253, 3.9433, 3.9780],\n",
      "        [3.0050, 3.4415, 3.4536, 3.9699, 4.0013],\n",
      "        ...,\n",
      "        [3.0029, 3.4300, 3.4302, 3.9704, 3.9709],\n",
      "        [2.9810, 3.3924, 3.3998, 3.9316, 3.9508],\n",
      "        [2.9773, 3.3917, 3.3972, 3.9328, 3.9469]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5836, 0.5942],\n",
      "        [0.5433, 0.6441],\n",
      "        [0.6240, 0.6488],\n",
      "        ...,\n",
      "        [0.6469, 0.6346],\n",
      "        [0.6109, 0.5639],\n",
      "        [0.6722, 0.5644]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0469, 3.4808, 3.4839, 4.0119, 4.0198],\n",
      "        [3.0119, 3.4620, 3.4910, 3.9729, 4.0483],\n",
      "        [2.9827, 3.4130, 3.4202, 3.9495, 3.9683],\n",
      "        ...,\n",
      "        [2.9897, 3.4088, 3.4124, 3.9491, 3.9584],\n",
      "        [3.0677, 3.4778, 3.4913, 4.0001, 4.0352],\n",
      "        [3.0501, 3.4268, 3.4581, 3.9405, 4.0217]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6152, 0.6539],\n",
      "        [0.6129, 0.6226],\n",
      "        [0.6779, 0.6845],\n",
      "        ...,\n",
      "        [0.6780, 0.6034],\n",
      "        [0.6179, 0.6000],\n",
      "        [0.6520, 0.6353]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[2.9805, 3.4136, 3.4249, 3.9466, 3.9758],\n",
      "        [3.0111, 3.4410, 3.4438, 3.9776, 3.9849],\n",
      "        [2.9331, 3.3537, 3.3557, 3.9031, 3.9082],\n",
      "        ...,\n",
      "        [3.0110, 3.4001, 3.4220, 3.9258, 3.9824],\n",
      "        [3.0312, 3.4520, 3.4572, 3.9851, 3.9985],\n",
      "        [2.9875, 3.4041, 3.4090, 3.9439, 3.9566]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6476, 0.5900],\n",
      "        [0.5863, 0.6134],\n",
      "        [0.6511, 0.6033],\n",
      "        ...,\n",
      "        [0.6426, 0.6633],\n",
      "        [0.6510, 0.6203],\n",
      "        [0.5440, 0.6299]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0323, 3.4328, 3.4496, 3.9584, 4.0019],\n",
      "        [3.0277, 3.4633, 3.4711, 3.9925, 4.0129],\n",
      "        [3.0186, 3.4226, 3.4365, 3.9520, 3.9881],\n",
      "        ...,\n",
      "        [2.9634, 3.3907, 3.3968, 3.9313, 3.9471],\n",
      "        [3.0022, 3.4132, 3.4222, 3.9483, 3.9715],\n",
      "        [3.0252, 3.4733, 3.4979, 3.9865, 4.0506]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6145, 0.6062],\n",
      "        [0.5978, 0.5664],\n",
      "        [0.6718, 0.5696],\n",
      "        ...,\n",
      "        [0.6543, 0.6365],\n",
      "        [0.6171, 0.6199],\n",
      "        [0.5341, 0.5882]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0263, 3.4515, 3.4539, 3.9870, 3.9932],\n",
      "        [3.0692, 3.4873, 3.4963, 4.0123, 4.0357],\n",
      "        [3.0452, 3.4242, 3.4539, 3.9396, 4.0167],\n",
      "        ...,\n",
      "        [2.9857, 3.4015, 3.4067, 3.9414, 3.9549],\n",
      "        [3.0124, 3.4409, 3.4417, 3.9793, 3.9814],\n",
      "        [3.0678, 3.5133, 3.5286, 4.0292, 4.0693]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.6163, 0.5823],\n",
      "        [0.5913, 0.6227],\n",
      "        [0.5734, 0.6426],\n",
      "        ...,\n",
      "        [0.6172, 0.5864],\n",
      "        [0.6159, 0.5607],\n",
      "        [0.6111, 0.6488]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0486, 3.4631, 3.4729, 3.9906, 4.0161],\n",
      "        [3.0174, 3.4528, 3.4619, 3.9824, 4.0059],\n",
      "        [3.0039, 3.4462, 3.4662, 3.9671, 4.0191],\n",
      "        ...,\n",
      "        [3.0444, 3.4601, 3.4690, 3.9888, 4.0119],\n",
      "        [3.0694, 3.4755, 3.4913, 3.9961, 4.0373],\n",
      "        [2.9866, 3.4201, 3.4311, 3.9525, 3.9810]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.7199, 0.6516],\n",
      "        [0.6332, 0.5424],\n",
      "        [0.6330, 0.5637],\n",
      "        ...,\n",
      "        [0.5822, 0.5931],\n",
      "        [0.6314, 0.5851],\n",
      "        [0.6077, 0.6504]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[2.9533, 3.3385, 3.3589, 3.8741, 3.9267],\n",
      "        [3.0820, 3.4714, 3.4975, 3.9836, 4.0514],\n",
      "        [3.0617, 3.4597, 3.4796, 3.9788, 4.0307],\n",
      "        ...,\n",
      "        [3.0483, 3.4825, 3.4856, 4.0132, 4.0214],\n",
      "        [3.0416, 3.4490, 3.4624, 3.9753, 4.0101],\n",
      "        [2.9860, 3.4206, 3.4331, 3.9516, 3.9839]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4614, 0.4551],\n",
      "        [0.5679, 0.5557],\n",
      "        [0.5280, 0.4797],\n",
      "        ...,\n",
      "        [0.5034, 0.4795],\n",
      "        [0.5446, 0.4873],\n",
      "        [0.5308, 0.4482]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.2171, 3.6626, 3.6643, 4.1715, 4.1761],\n",
      "        [3.0883, 3.5182, 3.5217, 4.0439, 4.0530],\n",
      "        [3.1729, 3.5937, 3.6071, 4.1009, 4.1361],\n",
      "        ...,\n",
      "        [3.1807, 3.6142, 3.6208, 4.1249, 4.1422],\n",
      "        [3.1606, 3.5757, 3.5917, 4.0830, 4.1249],\n",
      "        [3.2020, 3.6092, 3.6320, 4.1062, 4.1661]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4963, 0.5336],\n",
      "        [0.5102, 0.5936],\n",
      "        [0.5154, 0.5493],\n",
      "        ...,\n",
      "        [0.5842, 0.5430],\n",
      "        [0.4674, 0.4501],\n",
      "        [0.5838, 0.5767]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1316, 3.5797, 3.5901, 4.0914, 4.1186],\n",
      "        [3.0703, 3.5223, 3.5459, 4.0299, 4.0914],\n",
      "        [3.1106, 3.5560, 3.5655, 4.0715, 4.0964],\n",
      "        ...,\n",
      "        [3.0956, 3.5117, 3.5234, 4.0310, 4.0616],\n",
      "        [3.2200, 3.6605, 3.6652, 4.1671, 4.1795],\n",
      "        [3.0635, 3.4933, 3.4953, 4.0235, 4.0289]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5109, 0.5382],\n",
      "        [0.5081, 0.5494],\n",
      "        [0.5196, 0.5792],\n",
      "        ...,\n",
      "        [0.5768, 0.5811],\n",
      "        [0.5708, 0.5289],\n",
      "        [0.5556, 0.5403]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1226, 3.5677, 3.5754, 4.0834, 4.1034],\n",
      "        [3.1129, 3.5600, 3.5716, 4.0732, 4.1036],\n",
      "        [3.0809, 3.5288, 3.5457, 4.0415, 4.0855],\n",
      "        ...,\n",
      "        [3.0615, 3.4956, 3.4968, 4.0262, 4.0294],\n",
      "        [3.1130, 3.5307, 3.5425, 4.0475, 4.0784],\n",
      "        [3.1067, 3.5369, 3.5412, 4.0595, 4.0707]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5142, 0.5623],\n",
      "        [0.5708, 0.5617],\n",
      "        [0.5238, 0.5644],\n",
      "        ...,\n",
      "        [0.5591, 0.6083],\n",
      "        [0.6291, 0.5549],\n",
      "        [0.5633, 0.5008]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0988, 3.5460, 3.5595, 4.0593, 4.0947],\n",
      "        [3.0817, 3.5125, 3.5150, 4.0397, 4.0465],\n",
      "        [3.0936, 3.5388, 3.5503, 4.0548, 4.0848],\n",
      "        ...,\n",
      "        [3.0410, 3.4827, 3.4968, 4.0039, 4.0405],\n",
      "        [3.0711, 3.4678, 3.4891, 3.9846, 4.0400],\n",
      "        [3.1421, 3.5526, 3.5701, 4.0615, 4.1074]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4871, 0.6063],\n",
      "        [0.5354, 0.5369],\n",
      "        [0.5812, 0.5441],\n",
      "        ...,\n",
      "        [0.5874, 0.5004],\n",
      "        [0.5275, 0.5341],\n",
      "        [0.5878, 0.5160]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0659, 3.5250, 3.5586, 4.0233, 4.1112],\n",
      "        [3.1161, 3.5552, 3.5557, 4.0788, 4.0799],\n",
      "        [3.0954, 3.5135, 3.5240, 4.0337, 4.0612],\n",
      "        ...,\n",
      "        [3.1353, 3.5329, 3.5575, 4.0383, 4.1024],\n",
      "        [3.1212, 3.5619, 3.5637, 4.0834, 4.0882],\n",
      "        [3.1203, 3.5238, 3.5441, 4.0340, 4.0871]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5629, 0.5671],\n",
      "        [0.5898, 0.5507],\n",
      "        [0.5944, 0.5849],\n",
      "        ...,\n",
      "        [0.5671, 0.4984],\n",
      "        [0.5949, 0.5016],\n",
      "        [0.6039, 0.4698]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.0789, 3.5148, 3.5160, 4.0430, 4.0461],\n",
      "        [3.0865, 3.5027, 3.5138, 4.0238, 4.0528],\n",
      "        [3.0525, 3.4800, 3.4827, 4.0114, 4.0185],\n",
      "        ...,\n",
      "        [3.1432, 3.5508, 3.5701, 4.0584, 4.1089],\n",
      "        [3.1320, 3.5261, 3.5524, 4.0307, 4.0996],\n",
      "        [3.1600, 3.5369, 3.5746, 4.0300, 4.1287]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.4923, 0.5300],\n",
      "        [0.5579, 0.5730],\n",
      "        [0.5949, 0.5702],\n",
      "        ...,\n",
      "        [0.5714, 0.4955],\n",
      "        [0.7273, 0.5415],\n",
      "        [0.4917, 0.5947]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1363, 3.5850, 3.5955, 4.0959, 4.1235],\n",
      "        [3.0749, 3.5127, 3.5170, 4.0384, 4.0497],\n",
      "        [3.0664, 3.4877, 3.4947, 4.0143, 4.0327],\n",
      "        ...,\n",
      "        [3.1447, 3.5489, 3.5702, 4.0550, 4.1107],\n",
      "        [3.0577, 3.3946, 3.4489, 3.8925, 4.0332],\n",
      "        [3.0753, 3.5319, 3.5609, 4.0334, 4.1092]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5225, 0.5514],\n",
      "        [0.5328, 0.5457],\n",
      "        [0.4665, 0.5898],\n",
      "        ...,\n",
      "        [0.5952, 0.5534],\n",
      "        [0.6442, 0.4651],\n",
      "        [0.5702, 0.5595]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1065, 3.5504, 3.5585, 4.0678, 4.0891],\n",
      "        [3.1086, 3.5493, 3.5529, 4.0709, 4.0803],\n",
      "        [3.0883, 3.5503, 3.5848, 4.0445, 4.1349],\n",
      "        ...,\n",
      "        [3.0824, 3.4967, 3.5087, 4.0179, 4.0490],\n",
      "        [3.1533, 3.5067, 3.5574, 3.9923, 4.1247],\n",
      "        [3.0840, 3.5141, 3.5172, 4.0407, 4.0487]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAAAAAAAAA tensor([[0.5011, 0.5469],\n",
      "        [0.4687, 0.5649],\n",
      "        [0.5634, 0.5563],\n",
      "        ...,\n",
      "        [0.5721, 0.5722],\n",
      "        [0.5475, 0.5859],\n",
      "        [0.5797, 0.5804]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[3.1175, 3.5660, 3.5788, 4.0774, 4.1109],\n",
      "        [3.1110, 3.5694, 3.5963, 4.0680, 4.1383],\n",
      "        [3.0891, 3.5216, 3.5236, 4.0482, 4.0534],\n",
      "        ...,\n",
      "        [3.0713, 3.5055, 3.5056, 4.0359, 4.0360],\n",
      "        [3.0659, 3.5078, 3.5187, 4.0284, 4.0569],\n",
      "        [3.0612, 3.4945, 3.4947, 4.0262, 4.0267]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n",
      "AAAAAAAAAA tensor([[0.5789, 0.6463],\n",
      "        [0.5562, 0.5615],\n",
      "        [0.5157, 0.5913],\n",
      "        ...,\n",
      "        [0.6022, 0.4850],\n",
      "        [0.5927, 0.5104],\n",
      "        [0.5448, 0.5889]], dtype=torch.float64, grad_fn=<SliceBackward0>)\n",
      "BBBBBBBBBB tensor([[2.9986, 3.4401, 3.4596, 3.9622, 4.0130],\n",
      "        [3.0864, 3.5233, 3.5247, 4.0500, 4.0539],\n",
      "        [3.0707, 3.5211, 3.5425, 4.0307, 4.0866],\n",
      "        ...,\n",
      "        [3.1458, 3.5295, 3.5626, 4.0278, 4.1141],\n",
      "        [3.1242, 3.5229, 3.5462, 4.0306, 4.0914],\n",
      "        [3.0638, 3.5068, 3.5193, 4.0261, 4.0588]], dtype=torch.float64,\n",
      "       grad_fn=<TopkBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/5000 [00:10<?, ?it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_17264\\553613524.py\u001b[0m in \u001b[0;36m<cell line: 9>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     19\u001b[0m         \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     20\u001b[0m         \u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlast_loss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpred\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mparticle_pred\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmeasurement\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmotion\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbpdecay\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 21\u001b[1;33m         \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     22\u001b[0m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mclip_grad_norm_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_clip\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     23\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\work\\python_env\\lib\\site-packages\\torch\\_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m    486\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    487\u001b[0m             )\n\u001b[1;32m--> 488\u001b[1;33m         torch.autograd.backward(\n\u001b[0m\u001b[0;32m    489\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    490\u001b[0m         )\n",
      "\u001b[1;32mD:\\work\\python_env\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m    195\u001b[0m     \u001b[1;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    196\u001b[0m     \u001b[1;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 197\u001b[1;33m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[0;32m    198\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    199\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "#!g1.1\n",
    "# Train the model, measure the validation localisation error after every\n",
    "# epoch, and checkpoint the weights whenever that error improves.\n",
    "model = MainModel(points).to(device).double()\n",
    "epochs = 5000\n",
    "grad_clip = 3\n",
    "filepath = 'mbpfn27.ptm'  # checkpoint rewritten on every validation improvement\n",
    "MSE_min = np.inf  # np.Inf was removed in NumPy 2.0; np.inf works in all versions\n",
    "MSE_accum = []  # validation MSE per epoch, drives the live plot\n",
    "# Per-epoch loss histories, created once outside the epoch loop so they accumulate.\n",
    "train_loss = []\n",
    "eval_loss = []\n",
    "optimizer = torch.optim.RMSprop(model.parameters(), lr=5e-4)\n",
    "\n",
    "for epoch in tqdm(range(epochs)):\n",
    "    # ---- training pass ----\n",
    "    model.train()\n",
    "    curr_loss1 = 0\n",
    "    for _, measurement, location, motion in train_loader:\n",
    "        measurement = measurement.to(device).double()\n",
    "        location = location.to(device).double()\n",
    "        motion = motion.to(device).double()\n",
    "\n",
    "        model.zero_grad()\n",
    "        loss, last_loss, pred, particle_pred = model.step(measurement, motion, location, params.bpdecay)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
    "        optimizer.step()\n",
    "        # .item() detaches the scalar and copies it to the host in one call.\n",
    "        curr_loss1 += loss.item() / len(train_numbers)\n",
    "    train_loss.append(curr_loss1)\n",
    "\n",
    "    # ---- validation pass ----\n",
    "    model.eval()\n",
    "    curr_loss2 = 0\n",
    "    y_pred = []\n",
    "    y_true = []\n",
    "    with torch.no_grad():\n",
    "        for _, measurement, location, motion in eval_loader:\n",
    "            measurement = measurement.to(device).double()\n",
    "            location = location.to(device).double()\n",
    "            motion = motion.to(device).double()\n",
    "            loss, last_loss, pred, particle_pred = model.step(measurement, motion, location, params.bpdecay)\n",
    "            curr_loss2 += loss.item() / len(eval_numbers)\n",
    "            y_true_curr = location[:, :, :2]\n",
    "            # Scale predicted x/y by model.width/model.height to match location's units\n",
    "            # (presumably pred is normalised to [0, 1] - confirm against MainModel).\n",
    "            y_pred_curr = torch.cat((pred[:, :, :1] * model.width, pred[:, :, 1:2] * model.height), dim=2)\n",
    "            y_true.extend(y_true_curr.cpu().numpy())\n",
    "            y_pred.extend(y_pred_curr.cpu().numpy())\n",
    "    eval_loss.append(curr_loss2)\n",
    "\n",
    "    y_true = np.array(y_true).reshape(-1, track_len, 2)\n",
    "    y_pred = np.array(y_pred).reshape(-1, track_len, 2)\n",
    "    # Mean squared Euclidean distance over every validation track and time step.\n",
    "    X = (y_pred - y_true).reshape(-1, 2)\n",
    "    MSE = (X**2).sum(axis=-1).mean()\n",
    "\n",
    "    if MSE < MSE_min:\n",
    "        state = {\n",
    "            'epoch': epoch,\n",
    "            'state_dict': model.state_dict(),\n",
    "            'optimizer': optimizer.state_dict(),\n",
    "        }\n",
    "        torch.save(state, filepath)\n",
    "        MSE_min = MSE\n",
    "\n",
    "    MSE_accum.append(MSE)\n",
    "\n",
    "    # clear_output(wait=True) clears as soon as the next output arrives, so the\n",
    "    # epoch summary is printed after it to stay visible next to the plot.\n",
    "    clear_output(True)\n",
    "    print('epoch, MSE', epoch, np.round(MSE, 3))\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.set_title(str(np.round(MSE_min, 2)))\n",
    "    ax.plot(MSE_accum, label='validation MSE')\n",
    "    ax.plot([MSE_min] * len(MSE_accum), label='best so far')\n",
    "    ax.set_xlabel('epoch')\n",
    "    ax.set_ylabel('MSE')\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d398ad25",
   "metadata": {
    "cellId": "6gvjtmok1kwmqr1hwl0hzf",
    "id": "d398ad25",
    "outputId": "bd9d046a-4498-4686-ae27-b407ec2cdac8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:  327\n",
      "MSE mean and std, final point mean and std 58.873520447185086 55.00930962739599 5.532542629980847 3.6856876474648956\n"
     ]
    }
   ],
   "source": [
    "#!g1.1\n",
    "def load_checkpoint(filepath='mbpfn27.ptm', target_model=None):\n",
    "    '''Restore weights saved by the training loop and freeze the model.\n",
    "\n",
    "    filepath: checkpoint written by torch.save with 'epoch' and 'state_dict' keys.\n",
    "    target_model: module to load the weights into; defaults to the notebook-global\n",
    "        `model` (the original behaviour).\n",
    "    Returns the same module, switched to eval mode with requires_grad disabled.\n",
    "    '''\n",
    "    net = model if target_model is None else target_model\n",
    "    # Load onto the CPU first; load_state_dict copies into the module's own device.\n",
    "    checkpoint = torch.load(filepath, map_location=torch.device('cpu'))\n",
    "    print('epoch: ', checkpoint['epoch'])\n",
    "    net.load_state_dict(checkpoint['state_dict'])\n",
    "    net.requires_grad_(False)\n",
    "    net.eval()\n",
    "    return net\n",
    "\n",
    "model = MainModel(points).to(device).double()\n",
    "\n",
    "model = load_checkpoint()\n",
    "\n",
    "# Evaluate the restored model on the held-out test tracks.\n",
    "# NOTE(review): training used params.bpdecay; confirm 0.1 is the intended value here.\n",
    "test_bpdecay = 0.1\n",
    "model.eval()\n",
    "eval_loss = []\n",
    "curr_loss2 = 0\n",
    "y_pred = []\n",
    "y_true = []\n",
    "with torch.no_grad():\n",
    "    for _, measurement, location, motion in test_loader:\n",
    "        measurement = measurement.to(device).double()\n",
    "        location = location.to(device).double()\n",
    "        motion = motion.to(device).double()\n",
    "        loss, last_loss, pred, particle_pred = model.step(measurement, motion, location, test_bpdecay)\n",
    "        # Normalise by the number of test tracks, since this loop runs over test_loader.\n",
    "        curr_loss2 += loss.item() / len(test_numbers)\n",
    "        y_true_curr = location[:, :, :2]\n",
    "        # Build a new scaled tensor rather than multiplying a view of pred in place.\n",
    "        y_pred_curr = torch.cat((pred[:, :, :1] * model.width, pred[:, :, 1:2] * model.height), dim=2)\n",
    "        y_true.extend(y_true_curr.cpu().numpy())\n",
    "        y_pred.extend(y_pred_curr.cpu().numpy())\n",
    "\n",
    "eval_loss.append(curr_loss2)\n",
    "y_true = np.array(y_true).reshape(-1, track_len, 2)\n",
    "y_pred = np.array(y_pred).reshape(-1, track_len, 2)\n",
    "\n",
    "# Flatten over every track and time step (independent of any loop variable).\n",
    "X = (y_pred - y_true).reshape(-1, 2)\n",
    "MSE_arr = (X**2).sum(axis=-1)  # squared Euclidean error per time step\n",
    "Y = ((((y_pred - y_true)[:, -1])**2).sum(axis=-1)**0.5)  # Euclidean error at the last step\n",
    "print('MSE mean and std, final point mean and std', MSE_arr.mean(), MSE_arr.std(), Y.mean(), Y.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cad5998",
   "metadata": {
    "cellId": "xw34g7w45tx4usilocup",
    "id": "5cad5998"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b810608",
   "metadata": {
    "cellId": "ukmfxbee2u6ztki76nc53",
    "id": "0b810608"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f9bee46",
   "metadata": {
    "cellId": "echs5rsk1mvvnca9mufrzg",
    "id": "6f9bee46"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8599d74",
   "metadata": {
    "cellId": "gck4f9ag4gzljk8vg7yj",
    "id": "a8599d74"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee310f7c",
   "metadata": {
    "cellId": "tecp3gzyuymggi8fojs9bl",
    "id": "ee310f7c"
   },
   "outputs": [],
   "source": [
    "#!g1.1\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "V100",
   "machine_shape": "hm",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "notebookId": "2eff6be8-f7ea-4468-a171-8075d4bcb638",
  "notebookPath": "Jup_main.ipynb"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
