{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "from torch.utils.data import Dataset, DataLoader, random_split\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['../data/record\\\\Town01_100npcs_1', '../data/record\\\\Town01_200npcs_1', '../data/record\\\\Town04_100npcs_1', '../data/record\\\\Town05_100npcs_1', '../data/record\\\\Town05_100npcs_2', '../data/record\\\\Town07_100npcs_1', '../data/record\\\\Town07_100npcs_2', '../data/record\\\\Town12_1200npcs_1', '../data/record\\\\Town12_1200npcs_2', '../data/record\\\\Town12_1200npcs_3', '../data/record\\\\Town12_800npcs_1']\n"
     ]
    }
   ],
   "source": [
    "dataset_paths = []\n",
    "records_folder = \"../data/record\"\n",
    "for name in os.listdir(records_folder):\n",
    "   dir = os.path.join(records_folder, name)\n",
    "   if os.path.isdir(dir):\n",
    "      dataset_paths.append(dir)\n",
    "\n",
    "print(dataset_paths)\n",
    "\n",
    "actions = [\n",
    "    \"Void\",\n",
    "    \"Left\",\n",
    "    \"Right\",\n",
    "    \"Straight\",\n",
    "    \"LaneFollow\",\n",
    "    \"ChangeLaneLeft\",\n",
    "    \"ChangeLaneRight\",\n",
    "    \"RoadEnd\",\n",
    "    \"Other\",\n",
    "]\n",
    "actions_index_map = {a: i for i, a in enumerate(actions)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "46653 11664\n"
     ]
    }
   ],
   "source": [
    "class BehaviorPredictionDataset(Dataset):\n",
    "    def __init__(self, dataset_paths, return_metadata=False, transform=None, sensor_range=50.0, bev_range=50.0):\n",
    "        super().__init__()\n",
    "        self.records = []\n",
    "        self.return_metadata = return_metadata\n",
    "        self.transform = transform\n",
    "        self.sensor_range = sensor_range\n",
    "        self.bev_range = bev_range\n",
    "\n",
    "        for path in dataset_paths:\n",
    "            self.read_data_folder(path)\n",
    "    \n",
    "    def read_data_folder(self, path):\n",
    "        for agent in os.listdir(os.path.join(path, 'agents')):\n",
    "            agent_folder = os.path.join(path, 'agents', agent)\n",
    "\n",
    "            agent_location = {}\n",
    "            for line in open(os.path.join(agent_folder, 'gt_location', 'data.jsonl'), 'r'):\n",
    "                record = json.loads(line)\n",
    "                agent_location[record['frame']] = {\n",
    "                    'location': record['location'],\n",
    "                    'rotation': record['rotation'],\n",
    "                }\n",
    "\n",
    "            vehicle_bbox_records = []\n",
    "            \n",
    "            for line in open(os.path.join(agent_folder, 'gt_vehicle_bbox', 'data.jsonl'), 'r'):\n",
    "                record = json.loads(line)\n",
    "                vehicle_bbox_records.append(record)\n",
    "                frame = record['frame']\n",
    "                for vehicle in record['vehicles']:\n",
    "                    distance = np.linalg.norm(np.array(agent_location[frame]['location'])-np.array(vehicle['location']))\n",
    "                    if distance > self.sensor_range:\n",
    "                        continue\n",
    "                    self.records.append({\n",
    "                        'vehicle_id': vehicle['id'],\n",
    "                        'frame': frame,\n",
    "                        'timestamp': record['timestamp'],\n",
    "                        'current_action': vehicle['current_action'],\n",
    "                        'distance': distance,\n",
    "                        'agent_location': np.array(agent_location[frame]['location']),\n",
    "                        'agent_rotation': np.array(agent_location[frame]['rotation']),\n",
    "                        'vehicle_location': np.array(vehicle['location']),\n",
    "                        'vehicle_rotation': np.array(vehicle['rotation']),\n",
    "                        'bev_image_path': os.path.join(agent_folder, 'birds_view_semantic_camera', str(frame)+'.png'),\n",
    "                        'output_path': os.path.join(agent_folder, 'pred_vehicle_current_action', 'data.jsonl'),\n",
    "                    })\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.records)\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        record = self.records[index]\n",
    "        agent_image = cv2.imread(record['bev_image_path'])\n",
    "        agent_image = cv2.cvtColor(agent_image, cv2.COLOR_BGR2RGB)\n",
    "        y = actions_index_map[record['current_action']]\n",
    "        width = agent_image.shape[1]\n",
    "        height = agent_image.shape[0]\n",
    "        assert agent_image.shape[0] == agent_image.shape[1]\n",
    "        img_scale =  agent_image.shape[0] / (self.bev_range*2)\n",
    "        \n",
    "        image = agent_image\n",
    "        # Rotate the image to standard rotation\n",
    "        agent_pitch = record['agent_rotation'][2]\n",
    "        # print(agent_pitch)\n",
    "        mat = cv2.getRotationMatrix2D((width/2, height/2), -agent_pitch, 1.0)\n",
    "        image = cv2.warpAffine(src=agent_image, M=mat, dsize=(width, height))\n",
    "\n",
    "        translation = record['vehicle_location'] - record['agent_location']\n",
    "        # Move the vehicle to the center\n",
    "        translation_matrix = np.array([\n",
    "            [1, 0, -translation[1]*img_scale],\n",
    "            [0, 1, translation[0]*img_scale]\n",
    "        ], dtype=np.float32)\n",
    "\n",
    "        image = cv2.warpAffine(src=image, M=translation_matrix, dsize=(width, height))\n",
    "        \n",
    "        # print(translation[0], translation[1])\n",
    "        # print(translation[0]*img_scale, translation[1]*img_scale)\n",
    "\n",
    "        vehicle_pitch = record['vehicle_rotation'][2]\n",
    "        # print(vehicle_pitch)\n",
    "        mat = cv2.getRotationMatrix2D((width/2, height/2), vehicle_pitch, 1.0)\n",
    "        image = cv2.warpAffine(src=image, M=mat, dsize=(width, height))\n",
    "\n",
    "        # print(record['vehicle_id'])\n",
    "\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "\n",
    "        if self.return_metadata:\n",
    "            return image, y, self.records[index]\n",
    "        return image, y\n",
    "\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.ToPILImage(),\n",
    "    transforms.Resize(size=(224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "\n",
    "dataset = BehaviorPredictionDataset(dataset_paths, transform=preprocess)\n",
    "\n",
    "train_size = int(len(dataset) * 0.8)\n",
    "test_size = len(dataset) - train_size\n",
    "print(train_size, test_size)\n",
    "train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(233))\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    ")\n",
    "\n",
    "test_loader = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=64,\n",
    "    shuffle=False,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "58317\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsTAAALEwEAmpwYAABKsklEQVR4nO29eZwcZZ34//5UH3Mfue9AAiEQQAKEQxERRUVXRbxdvyrqeuyqX/3p7oq6369+dQ/Xa9XVRWV1PVfAa0UFFVEUxCAkEMKRhCSEJJNzkrln+qp+fn98qqZ7evqoma6erp6pd16T6amqrnq666nP83k+z+cQYwwhISFzF6veDQgJCakvoRAICZnjhEIgJGSOEwqBkJA5TigEQkLmOKEQCAmZ49RMCIjI1SKyU0R2i8j1tbpOSEhIdUgt/AREJALsAp4HHATuB15njHnM94uFhIRURa00gYuB3caYvcaYFHATcE2NrhUSElIF0RqddwVwIO/vg8AlpQ4WkdBtMSSk9vQaYxYVbqyVEKiIiLwdeHu9rh8SMgd5qtjGWgmBHmBV3t8rnW3jGGO+BnwNQk0gJKSe1MomcD+wTkTWiEgceC1wa42uFRISUgU10QSMMRkReTfwKyACfMMY82gtrhUSElIdNVkinHIjwulASMhMsMUYs6lwY+gxGBIyxwmFQEjIHCcUAiEhc5xQCISEzHFCIRASMscJhUBIyBwnFAIhIXOcUAiEhMxxQiEQEjLHCYVASMgcJxQCISFznFAIhITMcUIhEBIyxwmFQEjIHCcUAiEhc5xpCwERWSUivxORx0TkURF5r7P9YyLSIyIPOT8v8q+5ISEhflNNZqEM8AFjzFYR6QC2iMgdzr5/M8Z8pvrmhYSE1JppCwFjzGHgsPN6SEQeR1ONh4SENBC+2ARE5FTgfOA+Z9O7ReRhEfmGiMzz4xohISG1oWohICLtwI+A9xljBoEbgNOAjaim8NkS73u7iDwgIg9U24aQkJDpU1WiURGJAT8HfmWM+VyR/acCPzfGnFPhPGGi0ZDg0QqcBZwADl4FmWc5G+vJMPA94IkKx30EmAfYwPWAgRKJRqdtExARAb4OPJ4vAERkmWMvALgWeGS61wgJqSvtwCnASSBzMXAF0FzXJsEAcLuH4zYBy5zXlwCbSx5ZzerAZcAbgO0i8pCz7cPA60RkIyp69gHvqOIaISH1oxf4JZC6ADgbaKP+rjXNHtqwGugA4s7fy8ocW93qwD2AFNl123TPGRISKLLA6CuB16AqQb0FgEuxxy6fi1A1xuVVwE9KHh2UTxUSElDOQkfWutXuLSDr/JSjhYmCYi0qxIoTCoGQkJJ0olOASiPvTHIYGKlwTGF7B4ADJY8OhUBISEmeCZxJcLQAgH4gMYXjj6PTgdLaQygEQkKKsgF4IZWMasEmA3wUrhgse1SQRFxISJVYVJ4ve6EZXfxa6eF8w+gaokHX5TvLHGuAY857moEllH4EM0AfOvLHgEXo1MQL+ee01VXA49EhIQ3HKcAh4qRJUe6BjaCPaD/6eJUnwSJuJsGDDPFi4FRKK83bWcbNtJJiLy/FcBm5OXm64NijRLmFs9jPk6xkmNcBi1HhUPik9tHJ71nJ/ZxkMUe4ClgH/BFdu5zMJc7nvI/j2NzjnPeFcM8DzuvihKXJQwLNcnK27ig6fragD/RpQIwuxojyKCfG35Nh4ugWRe3781GT2n4gVeJ6FjrmLgKSQA9RRomREwIZ8h+objKsJosFHAJOjq/NT6aNDKvIEkFn9T1EyBABDBY2kbxjYxjmY2h32nECYQwhgkEKHugo0EXOhHkcGHNaGXVa/EM9tKjHYCgEinHeq2HbLfVuRWARrKrs5QYwzv/Tu74SBVpoZYwxTkF45YXvZf6Wu2nmUN4Rhe8VIpO2WxQb6SNEsAq2R4ggRT/95GtFiRVte9SzAm5hEZnUMm1DlPxVgIl/5dNCmgyv4GfYZP11G57VmFLjxGxHKNaVIlhOxzfAy/gQX+dNdIwfW9iJ3FG7cHvEOcMP2Msn+B7HuQnY5ax8exMI+hArBhijE/hrlnMRzzCv4SWezjL3eIA+zmd+0X0B0gS0Q1nEJklfr6OOl+OyQIY0pUeh/LMY/DE0FcOi1EPnD0IUi+iUzv9c1LEk3z++mx/ztzyf+AR11S8M8Dke4UYeYYzk+PYmmib1gyainMPp/Cdn01KT1sxuBAnudCAizaaVc7Ho5nI+zst4+oT9MaAp728LnRcWzr4sdPQpVJ/iedvuJ8m/8n/oZz8U7Ugtea+zwK+A8kss0+MlwBomfwp3DPWicgoWTUW2RljEQt7G+byftf40dyaxRiEbJ1RU/SXQQmCTbDIPEKYVUHNOM8HyUKsD614Ch/8ehi+vd0tmFaWEQOgsFCi+Cnkq8ZzliatgeEm9WzFnCPWtQPG+ejcgILy33g0IFLaluqFVI/NUqAlUwfQXuUJCvNO/AIa6anf+UBMICQk4C47X9vxVCwER2QcMoX6PGWPMJhGZD9yM+lvuA15tjOmr9lpBo5j5bqqaQeHxpbQLXdpUTzd3tcM9zl1sLHcOU2J7qXaUOqbYcaXeG2fiWktI7ah0/3ZxpOQ+vzSBK40x+Q7N1wN3GmM+KSLXO39/sNSbR0jxZ/ZN6YKmakW8+Ptrpd5nMUU9DuyCa9oF2wyQwSJJjCGiHEVYgiGOwQYEQxNpYo4Dql5rste6jSnpXZ8t0oZiZErsK7bNQr3iTy3pYZeP9299Zqdf9W+XhUWMGFls0mWjHoq3wHVy/lu+UPKdVS8ROprApnwhICI7gWcbYw6LyDLgLmPM+lLnaJJlZhlv9nxNgyFbtRNP8a5eIeCqCopnhKl0vTEshuhGH6lW1LfBFRXueXtpoZcO59EXLJodh58xx4s868t35h31GrSI0zLJ6WcihSKoPLW7P8Xw3rZatauJJuYxjzFGGGSoTGsKh5OJe46pEKiZ27ABfu34/3/VGPM1YElexuEjaMzkBETk7cDbASzaGWLIh6bMPobGx9QFZY6KkiRBM70I0EKcFazAJste9mDP8KMD2inSZElXzIITUol++hljjFTJsKfq8EMIPNMY0yMii4E7RGRH/k5jjCkWIOQIi68BRGVxaGQvS+nINKUJaCWL6gmCRYQIadI+TJtC6kmcODFijkZXG6oWAsaYHuf3MRH5CXAxcNStP+BMB45Ve525zdRG8hRJDnCANOkZnQKE+EuUCBGiZMjUVJuryk9ARNqcisSISBvwfLTYyK3Am5zD3gT8tJrrzGXiZNEMM2OoMHCDn9zgpjS6ZpAYN7/Z2IwwQnqSeTCkkWiimU46iBP3YFydPtVqAkuAn2gxIqLAfxtjfiki9wO3iMhbgaeAV1d5nTlLKzZZjpHBRtX+LBOnB0nAJkb/uES3EFppJUKEfvpnuMUhfmBhESWKzIA/X1VCwBizFzivyPYTaFyqJzSrSga7QnMEQ4QkEUYR7IIMK0KGVtK0UjkAJ0ucQSIkKGZRzdBOmg5PbRdsLAwWaSKMYhMn66yOZxEM1ngumHK3M04/MQYplmqqBRhjEcnxePD8M+lVYnnLRxpb2DSF5BXBxCKJRabgbmbzXkWxp+CJIKSIkqBwlUYzJei5sh5ClC2SxBlGiqjoWeIk6aSSki1AhCRx+pHx9mSdPXGE+VgIgwyQIFF2OhBniAjDkz6XVwLRS+LtI8SHBxkrkfTAxSLFvBV7WfG6vVhd2bwvD7IDwtE7TuHYtvVkKnQMixQrL7if7hcPI7ECIZCF3q+2cfDQc0hMWH3PkXPSSdIhT9Ld3EPb6gzt548x+GScnq3NjKQjGJZhWIOhFYNByCAYBIiRJTp+piyLFj3AsneeQFqKu/WcvKWLfQ9dhT0hqHr2YmHTzeO0zOtFovmd2/lusjB2YiEnOR/j4cEVbJo5wrz4Y1jNBUIgC8nkfAbS60kyr+K5OtnFKe85QHRpGgps3kO/iLPnj1d6EE42rV1HWfu6rcRW5AkBEZK9TRy6cz3HtjeTIYuN7Rh4J7qFxcfFYA9tPImUmP5lidBTpiWBEALGFuwiqZgmI9jJCKn9FtKuj5JLdliwh0snWSo8T3qoncQ+g0QnCwE72YLBMEYbGdopnnfAQhgDRsnaQ6RGR+FohOHBOEPZ+SToQjPknYKu72dQ1V39BdIM0sooUUf6J1KL6N/XgtXktMfkfRQDI0NxjEfVMEuWBImaziNrj5CiCysDks1/aJ17YSBFO1NJOWPTxlh2GVLoc2Mgk23z2AfBppnEwSaio+KeGrC0TQPNjpj30J5MEyM9nURS7uezSNDGwOBieodXMkAnEz9fhJyGkSXDIO2MkmSB012KawJZBHXcLb5SFIh8AjFZaDp5LZVuqCrXSaKMMvEp0b02zdh4uQmGGGNYJdZds8RI0MoAi4FVFE+gJajangIGiXCcFk6QJcYoq9FUlTHnx02F7ap0BhggykHaGXGmExmE9KQkkrk2RclWXCp0Wyfj+fFqtbY8MxisSR07l40pO/731M5Znsrns0gTdaak+Vu1TTFPU1JBtR2LUdxPkqGJIVaSYjlq9yk8R/7fNnCEVo7S5GiY5TjJDRDkHINmkvd7qeNyD3p1iHOjStea16/UQpN8lPuaWoBObLoYoQkhheaeL7QpRJioUcwjwwDGsfpni2blmx6C0E47bbRzlCNkPCTZDibiaY4+1XNWS5YYKaoL61PLTwSbDsc1OE6KKCkWA90ez9BOll5Uy5z+5wpDicvidW1WgGaM47DjDTcZmv8qu3H+xYk1vHFwtmNh0Uqr4+Y9lb6gPpmWD85goRAogY4/I+DZU8vO+0nid0hJE0mewWMs9tAetWeMMcBAg9sFZjcWFm200UwzY4yRnZJDkA2MEcWu+g6Hw0QJBGgjyQhPofN7N3inMP2XO5U5QYyTREiTYAAtBRFHnXmyznHxvPcknB9vwmI1R/ncBfdx39ZH+QCvqqjg29ikSFUI3gmpB/k2mxgxhhkmRQpDFI3Kd9O5u30nnzjaZ04CQxVtAV4IhUAZ4mSJ0M8QIxjiqBAoXIbRwhVRErSRcqTyAAnSzj53RSBCLmeyBaSIM1rE8FWcXhbQ/MQIL96Y4JMP2RyuMFfOkiVNmjhxLKzQfTggCEKMKG10MMYogww6PgDq6xGjl/T4QKOOYBNx6zEN0UaSiOf7GqGUH0EghIAg43PX4g1S14qpYGGxiqXMZ4ETTpNjmFH28BRJj0k9dcU3TWEJKsW1UQsRmrAQ2og4D52rAbgW7XwBYrBo9izJM7TygaHXcPrOFAs4jVH6Jhn8DIYkyfF1ZfVN0FEniELAwqKDdrpoIzJJqdVQqEKiRbZHx4/Xu5MgwxEOezaIFu9dxasSlaJ078ztsbGJEWMe87GxyZAu6JnQjmDGS48X6xtJ5xuIYNGS902Uf5RP0oFqD5MJxBLh02SjuZXfAKVMI9Ob9cScEAwKZsZZsqRqFGEnef/7TQSbGIYMTY4b8URsDIfJcg/9PMx2HmEXaWyGnX/1QhCaaOZ01vASLudpLAVUuC7Gor3owzbVnpDbYzBkyJAocn+jFH9gJ593avewUruywEPYPMJxrmMFzHB+h9s5zDs5J9h1B+536g6EZqzpo8tOhqQzFUiRZhdj3Mj9/I67GZnh2H5B6KSdK7iI63gOZxOllWZizmOoa+W1sU5PtVfXut+pLd+QIkt7HaonjZGllUhw/QQgfPj9QLP5iKMyR4Bm5tPOep7HD1nCV/k1PRydkRHIwmIda/gn3sZFxGgmPqOdLWj9yY0IqE0xt8q0lBG1oel4lhNFWEict3I+3+cDvILn0eTR83C6WFiczmq+yF/zLNpon2EBEDI1wnszB9BRKMJ6InyEq8mS5KfcXdSuUC1RolzKOXyNN7KIWDjKNADhPZpDWMBKoryAjZzGGt99CARhJYv5ItexOBQADcO075OIrBeRh/J+BkXkfSLyMRHpydv+Ij8bHFIdAlzLWt7JVaxkua/nbqaJV3E1q4ss+IUEl2kLAWPMTmPMRmPMRuBCYBT4ibP739x9xpjbPJ2v6L8xEvwPhhNF9oVMFwvhDWzgCtYS98k+YGGxgVP5IBt9FwC5O+7trk+th3jvUVPrfWYK7fbes0s/Baboj5fz+mUTeC6wxxjzlJNqbEps6d6C9VxrcosEdbJzfSeSwOOXcunOn/JfLGI9mhOmNKU952ezGPESLyAIZ3Aa89jBUR/ywDbRxHt4h6drm7z/mXAPi9/PEX7KIAdYwrOwOJvyjmNZDvElhBjL+OsKLeljkLs5yRMs5Om084yyRx/h+xznEdawkXb+gnJRqLCHB/gCQpwL+d/A6hLH2YzwKDv4TzJkWMWVLOdVJc+a4gG2831OsJ91XMQq3kKUhVDUMWqILfwzPTzGp7i95Dl98RMQkW8AW40xXxKRjwHXAYPAA8AHKpUgk6VieL2HCyWAbefAH7+KcD5n0cwmehngAOCm3dQMu29jHX/Bggk+Vfm8h71sZsektE3qFzg5SGesiHdemjTJvPcbsqRIkSA9aRFO6/8UGuIyFPdCLBVTUCp4KAFYLOMM/pn38sbxxOPlH8jjGN7JjdzF9rLHVUIFyhr+yPsqCgGD4SCGexgmQZZuOunHJoMmauulhyZaWU6aCB/nCb7DAMMYmBRsXexKGfQeSpHjJ7dFj7Wd4yqlFHEdwL0c67UdhlwvAFXNKy0i2uQcgCvFobq96z+BwVoVHxGROPBS4EPOphuATzjX/gTwWeAtRd43XnyETpCY41wbFUzG5Hw5M2AMmPFn0AApDAM8xs95jP8L7Cg8Pb/mNfyI/8fLKV746IdcxpEy9dkakcPAm/k73sIZvITP8BNeXFYQLEJYSTfNxElUkXzEQljN4ooCIIvh2/yeN/MBYGvF8wown9IPxQZgRZHt89CHNMXkSI9S5D+Ifh7rCpp0wbZS5NeW8oIf+aT9mA68ENUCjgK4vwFE5Ebg58XelF98ZM2KLvNKETqj87l0/jN41L6dRdFV2CbJscxJfp86xs+Luj0XL+011zHs4gE+w5c4j/9dUg1VlrGEVlqqEgIRIryUF1RoE/yEe3kz7wO2eTqvAU6U2V9pEtMBPB2N53Qp9vBOTu2qFCvslaEZU5D1p5RAmNr2yREMpXq31qAcxPIgAtwjBssc44cQeB3wffcPt+iI8+e1aB2Csuw7NMhnPgkwzPx1S5nXfCad0cWctA/RHF/EwUXHYJ0PLZ1DHGIz9/AJ3sONZcfny7mYn/MQJxmY1nXUbBPnGSwse5wNJDiCVwHgB0PAr309owW8EziTyfpJtY+SO3HwQh/wReCJKq+pVNVyp+DI84B35G3+lIhsRIXovoJ9FbA5+cSfJsc6daF2FZ/MmP6n/AgiKWwGSEHZ/MTraKGrqi9WmE83p1Y46iRZHm34YiiXAs9BJyH1XATdip89uNq6AyMUVMo0xryhqhaVwv3OI8U2FqN0ldb/IsXoHBADB4EtUNbmXTxg13920sN3uXsGrlQrBHgRmmCm3l4QCfysg9wYTl0l0wmUyzNQWgj8jD8x1tBZeL3hFiorx0wJgQW0cS7LZuBKtWIpcCqVi8PWGhst6uVfaHhjCAF3HSWkKI1QjmSUFIfLmqeCznp0XlpvLcBGbQH+fZeNIQTAT+1nVhEDzqh3IzxgkyXZ0JL8QrylAq81ukTu56pY4wiBaQng2T/vT0NJVx/vjra1J0aUjobQWYrRhlqmq6134Qcj6MKpf0bWxhECU25paZtAJU+62UKp9e98cvV8qr1SedqIs7zKgh314wJ0RSAIj8sAkzNeV0cQPtUUye/a5ZJTZSnVOeNE5kQ+fpvKphR/OkDludooaY7UMc9hdawglwa83riOzv7RgEIgZCaJkKKFo1hVdrxFtHGhz6HLM0MLOhUoFyw0UxjgON4L4ngjFAIhZWnjIGe89C7i1mCFcTAD3Fdybz9j7OC4z62bCc4BzoOKpcZnAoOuCszV6YCbvj+Sxs/lkZDytHGQtf/ZSzReSZW3KecSnCDDSZ9HsNrTAlyBVqYOwqNiAecCG/FTMwnCJ6uMm5s6ArSeAOvPBTuLUdom0CgfOwhkwWPdlybcoNBitBFnRcMZBtegcQJtlQ6cQZaj6TsW+XbGxnkaXBtgaz9E3RHHlQ7FKC0EYkQDYeKpNQaDXWE9uVJR+CxgKvQSLyvWabIMN5yX5lnAMoJhEHSxUCNlpYwG3mmcbMNuT5MsRFPqL1F2gau0FbWZ6BxZHTBOepPST3Flt2ELIpXyBFQWA1ka0elzIaoFjKHz8ATwJPrYnEP97AT+LOy6NJ4QAJD8dFSlvoxyfgKNowBVQ9aDJlC5O3n5rioLgSW0s4lVDRZCZAM9wJ+AgxBLwMr9cCIKQ28HczH1eYT87b+NIwRcLCCaQL2mYkzHi2sEKzCedEHHJsLgj8G2qzNEZbAZ89mqXXt2AwfhrM2wpF9jhxajXe+uG2HMDSpq7EGlcYSAO2R1AOt3wa5vQ/+1wKuB852DDBpc8QD60SbfnK8AD9BJmr+C8dp8+9DA2zPQu+zSBzzunGcNKnC+4+enCjyGKKknIWuqE5zHGeFBenxr18zwICxJwtMSao9zVaYFwIltsPURyARl5WD6eBICTiLRFwPHjDHnONvmAzejonAf8GpjTJ9ouuEvoMHXo8B1xpjKCeUqNsL53Qos64VdDwFvQqPlL3Q+igEOoOu6S9AMdRNpAvpYguGV5Pyve4F+1OKab8EeQjP3jaK9YIxGEgKGbMUqQ5USVSZYTs+t68nYpUNovbgRddLMWhaU8SQIImlYbEMnE78kNwu21Cs9TSVz7tTwqgl8E/gS8O28bdcDdxpjPiki1zt/fxDNObjO+bkETTx6SVWtjJFbpjLAWCsMdQOPAr8FdpGL806hsdYvpJiEfinwK5r5CZ8hNe68kn8jhdwc1w3BSaFfVWNZtw2Vy19X6k5JFtHz4EVkyxjBvMazZRvOGHsmpI5A+lBuUxa1De5YDJkN1EeZroNh0BjzBxE5tWDzNcCzndffAu5ChcA1wLeN5jLfLCLdBXkHp4a7IiJ5fy9NwtPvhn0Pw8EjaFBF/pdiUC3gWmDlhNMtABZxAIu7gKOElCdLlBSdVZ9nmBQH6a++QTPGhcC1sP9uGDwJpyRUSdwL9C6DgfeAOY1GnwpAdWJsSd6DfQTVv0GjLQ7kHXfQ2VZeCJyO2mEEnX7vdbbHmFi3wQLm29BxGFYfVie1x9C85BMYpNTIbTNKmKCg9uxfCd95PXzkX92VikYyx14APB3GzoaxF0PfIxDth9EzILsWTTIShNDi6vFFlzHGGBGZ0h2eUHcgBlwJnI2uxqxHp+gn0WnACiZ6rVnonGwJeq8y6IygkfrYHGCoE/50mb5uJc5iOurbIM9E0bFrG+oPcAkkzkU7Wgva+RpfA3CpRggcddV8EVlGLg18D+ps7bLS2TaB/LoDslAMHajRr8P5WUTOeD+vRAsMqoM8xZQEgCcv2FlAvasyiFG/LoBDDHD3uHoXYGJAiw2rfg/Ht8CxdwBXgw9ToqBSjTi7FTXP4/z+ad72N4pyKTBQ0R4Qyfu9CBW2C1FRsorSuR1HUNGTKLaztNtwE01zwmMwiyHt07Sniy6WsIR22idsj5QRqafvhhvera9tDKmgT8HWAa8EXm3gWaNw6TFY+DtUKwiSmulvelhPQkBEvo+jqIvIQRF5K/BJ4Hki8gRwlfM3wG3ojH43cCPwN5UvUOJ1uZwhoMJiPfBMVGvbQM4yUfZys18AgLf0Yl6+iShRVrGKDfPPZpmTMViAFaxkKctYUuJLj6dhmVPpbQ3zeQnneG36zLMI7T9L0UG/BdVhO+8D9jObK115XR14XYldzy1yrAHeVU2jPBNFl+8Xo0v6feiSPkDRAqAh08HGZje7ebLvSZLY4zEARzjudKDKUQFDJNlP2bq09eUEulh0Wt62BJBpQeMHgjSJXIifOaaD4THoLu23VzqwCDb64DehqwhD6DTh5DawwyVAPzAYkiRJGBilHcMiIIMwjEWCbg/+ExYS7NhNLWnNeLzVMHAPcOANqMU6SHSgmY8j+DHQBUMIuDdgOsTI2WwsVJIfAI7b3ECSHdhcSIRnQMPYpoOIwTBCC7AW9ws3jGKzq+SUI4PK506gjSaWl7Tw1pszgMtg1zaIb9Xo4W3A3reDeQXB6zlCrvbybBECTUxcATDoNCyBCofl5DyAM6iUbia3TJs/XdsK7NGXv+UN3MV1Tp2i93M/7+ccFhANlGpXO9RfsLpOEmWYeTxGL+egNynfrdq12BafDuw4B/7+M3Db1dBGjBWBe5hA1cfXA8+Fod2w+Q6Qo2AuBfNsglFroJA4cBGwGV0eq45gCIF8LXEQdcv8rfN3E3A52vfyjaIJ9LtIoisEFmrcyVsUyJIm66gYaT5JltcCC+ZINgFvocSVaOcAFx/cyl2nLeJEcgn6wE/uNn1MXsk1AilHThxlmD9P8CELAhHUGngluj59LphzHMczf/3z/Wc+fpVEC5bHgxv/cyc503Ya9QNwB5sIKpy70dYngR3AH9CVnMuBVwGbKNDiyqUbm81U95mFDLHOLDEBDaAqniewu8i2jn54+h36ej6tbPCydDOjrEW93F1jVH4euyALAH8JlhAYRg2f1wAvc7bZ5NyQihEnpxEk0T46Dw0nGKpVQ0MKKfbInHIAPv4v+noxTVxYJKqzfkSAU4Cn1bshdScY04EkOtKngHtRe8AmVFN7rMJ7m1CHoibn/bcBh8q+I2SGyBcMuxngFnbXrS2TORV4BcFIJV5fgiEEXNU/g47i+8ilsI+itptiLbXRB7/Nef0olFuKNmQxc2pKUG/H4RwnGeOxQEVtLqKxtQD/pizBmA40o4a+fjQg6EXkHvpm1BBazKCfQLWGw3n73aVGQeOGV6NRoWfAb2LvYpgeopoRwv/PETBsH92Gy1HKVcgAGee+nMsy3sVlNW+LN1YDf0ljawEz7DY8Y6xAF5XXoTlBmilvE2hFtbpTyMURgAqQC9C4j2tQt+IhOJJ+GSlaSDFIUEbIoGMQRvdYpE1sfEv+3nLegvtOgX/4qL62MEQDoYUJqgVsrHM7qqUbv1YHgjEdgImWfNcX4Ax0GbTUYOameXJXrVzfiQywBQ0vvgKNJzgLfjnwYZ4z+hiP8wwyPpZ2DioWUnUF5gSr2Hr+OYyyGIsMWU6Qm78dJ8IYpTrjSDs8dKG+PkmCRzlRVVv8YSmqBUzHPTVILMOvfAbBEQL5RNF5/r3oQ70XLQQD+oCPkXOdjqIawWVohab70VyjGYgMQXY7mBHo2trFh0dfxrP5V37DCSI+Fm+YWfJzrZVC00+tYCEbWVrV1RJ0keCZgI49fRzEOA4qMTJlA2wXHINXOCkZu2jhdBZW1RZ/6CB4bsDToVJ2yKmdKXi4/qZt6PRgQd4+dyn3BOpT0Ib2zgw6IF3gbNsFpw/HOdqSYaQ/y29St7OJSwFhHlaD1h4QlvEadN6UT3TCMRAlRjuXcDEX+Xh1g2DyEmpkiJB10ogbYCTuFIlyQgmWHYe33aSv2zGsqvt0IIpOBYLouVg/gikE4mjC4HXoiJ8/fU+jhsABC3Z3wBELWvq09y1ANSTHYf3QvBT2QrjkUbDGtnKc+bTSxQDHqnanrQdRIhyqY7bjPmKoAaYbAEM/fRxgAYZUBL52KXRG4K9+N/m9g2TZW/caRJ1o4GsjGwT9J5hCAHRAa0VHeKtgezIOu9fBkWuBIbC/oIIjhoYV9wNHYGgtrN8H6weW8zu+yCH+nau4lNsxJMfTFoV4p52JCRuWoE4ZzsMdhUyJ+UEMob2uKzLuctGldWxDMAmuEHBLsR9CBfepzvYUcPAM6Pkn1FDwX+pleD6q6eWHWUcgfS+89tj7uSL1V8Sc4Jdb+AW7uZVk8ZREIVPCAmwQiHRDdqD4UcdJ8GBdDYPN6Fqxf9V8ZwsVJ8Yi8g0ROSYij+Rt+7SI7BCRh0XkJyLS7Ww/VUTGROQh5+cr025ZFl0h6EGFgYsNjHWjFWOdgiO9aN6jLWjBoB7GVw72XgZfXvN39ER2jp/ibK5wfAUaC7XH5zz43Z/Rgp8R5/fMKN+5lQHJlq5d2kKMxXUt8d2BOqDMjQjSqeBFE/gmkwuP3AF8yBiTEZF/BT6E1hwA2GOM2VhVqwzqCDQfrWyQ35ubgSUJ2NuL6v5ZXS1ZgmYnbkWnBSndRTPcc7qh/xANH0uQxfA57uKsgu351g3j/B0FzmYRF9fcEp7JXXhYIweLMUqKwzN5A9zkOzaOI1oE7VAZZkeAUBq/AuIqCoFihUeMMb/O+3Mzmp7RP1JoPEEr2sJC43csgc4TxoDDqs8sIZfvwiYXTBSBdbthaHQLKTYQpx1p0FBig80nudLz8a/iVdzCLTVsEUAm1xVHIVOiRy2gjXNZxg9q3BpAB/3no5p/AvU1eWAQ+r8DIxehhS1OIWi+clPjJH5VxPLDJvAWtCahyxoReRBV4v/BGFO0GvWEugOFKzZZVFgX84WIAt1HoOO7QBra/6Qjfy/6nQgq7LvQpUIL/vQ0+Hz2Q/yvA2muSryANN1zIIbAAp98ISyyZMs8MK4mEslApsQlmyhWGbJGnA7j7hFtaLappSOw42bYcysceSFk3o+OMo1KP4EQAiLyEfSR+56z6TCw2hhzQkQuBP5HRM42xgwWvndC3YGlBYVL3PoOxYgAi3rhlNvhZAssHdR7uQQVHoKuYDWjz0EaWASPrxjhzmM3cklyFWPmLMxscxuWCJGWhXQ2qeRsirSw2qyjWltchFHaOcgg6yof3JHFjLnTtIm0QZVuSx4RJlasym/ABcDKMbhtF/S5qmbItIWAiFyHVip+rpNhGGOMG9WPMWaLiOxBnX8fmPIFymlqcWBRDI4uhQdbof2QCoFOVKuwUCGwCJ5/6Byah0Z5zSPP5Lkj17KI59ONwZptBiKxiDR10d7WgQHa460ssDurFgJNHGNJy70MjblpeN1MnO7rPLJZ7OFBigmBGaXU5d1l58VD0LeTcLlQmZYQEJGrgb8HrjDGjOZtXwScNMbYIrIWdffxt+yMG7PSNgxrj6uD+hA5o1/EOaYVIqfDe556Cxv65rOMq2lx1rhjZGh8w1AB2TSpvl0cGA+ljvIga6o+bYQ0sa4xJ6FQCp3luVb+IYQ0YBHJwplPRRmV06u+Zk1pAU49Ajt/RSgElIpCwCk88mxgoYgcBD6KrgY0AXeICMBmY8w7gWcBHxcRN3nzO40xJ31tsQ0cWQiPbITmA9C1M6/WALkJ6nArz9u1nrWpy1jLxRNO0cjmIO8Y/I6UbGGMMfbjBt9YDNJMAmglYuCyvY4aGGSiQFeSsCJ1Di+rA8UKj3y9xLE/An5UbaMm4PbjfA10YBnseys0H4HYfwEPF3ljJ68ZfTOnF53LzoV8g1n8Kb5ijX/3LdgI/RjHcSNCdtxLQJiYADrQZJvBqaTUuAizK4Ao/3m0UW+XKKr270Ut/RM02wiwBBJPg8QeiguBCJ2cRbxIrvs7uZlkiYSZswfXY6BarAmqUxRIOZ0vsBrVEVTtL5WNaqSdiaWGGpEBpl+sYyLBEAKj6OeJofrk71Eh5yYUWQyTp7fOxH9CHvx8+ugnwRD6IV2FohfYzH+QZtjPTzAnUC/EbjLOmm6GIZror2ubJmGAP6ODx8q87W5K+pPAti4aO7UYaPJNf5yvgiMEHkbvi4Va/x/N2z9pudDk/ZSa945yA5/mNm4atxU6Tm2cZHuZ94XkYxNj7KTaABK0oE42GoWXYXQ8lDhQHHV+DrobIiCroaMVRpfAiSvBy5JnoDmOXxaYYAiBLOr3fwCdqi1F/f/7K72xfN3dB/jDNNYmQ/JJspAjo1dgsFBp3EluLholO57mOYBMCGaaD/3vJeex0HixIxPxUnPaG8EQAqBD9DA6n3NzAriEg/Y08dZJXE/6YkfbNDPKKnLZXCbXjrcDn5vBRkeUs5gdAURJ1AHEH5tA8Gw7o0x2iy7sneP30SKIH2H2Ud4K3Rgu2KPMnmXBfjS3vj/Ct/GeoEkrI7PM6SekRqTROeZswHXD8YfGEwIhHnEtrCHKCLrsNBsYxq+pAIRCYBZjMZMmHwOcENi2oOKhdSKFrhsG3X5RiTTwOzRWzx9CIRDiCxkLfncR3HZtvVtSCtfppD9vmxuI0kiW5wPAViam26qO4KwOTIksgV2WakjKrQ94RCAdhdFSvluBIIVaneej64d3o1OE89BSVeWqKASBLKoF7MVPt/fGEAKzYVVnVpBBU/W4EQIJJqQXGwM70EmcE8AeNADqW7D8VrgwDfv/CNt6gfcQ7EfCRivr9Pt61iB/YsWCuuanDBmnhSHG2AXj8RgnaWGI/AxGkUAL7FFgM0QHYP2P4dkZtZ122jDwBOx7AiZlcAwa/ge/BV8I+BUHEzItogwzj8fp5UJasGhhgImueE6HbIgwwixwAubthIsyOafBBcCKrbDvboItBAapRbB28A2DbpLYkLrQTD+LTt2CjGeAMwU/eQTeZcOZDozG1L7mYgOZZRTPSxYkjuCnQdBlunUHPiYiPXn1BV6Ut+9DIrJbRHaKyAuqbqGN2nJCpog/KpQ4FQgrrktnIXoM7P6qL1meZ6Ppaa9jmnlUW2FkE2w9HXajX9N+4PEzCH658tqsZHjRBL4JXF1k+78ZYzY6P7cBiMgG4LVo2dergf8QkepmiQYmVwxzE2bEUCNP6BQzXYbwqUiJQLQTal54+Az0ls8HXsM0tA8BzoS+6+G2tXAD8LOVMHgZdc+NWCemVXegDNcANzkJR58Ukd3AxWh9oOlTVoyE8QPV8CSTo9Ld5yqQEQHHUUNxgmlatNKoank+ZG6EzABa5bkRStUfphbTgWoMg+8WkTeimYQ/YIzpQ7/NzXnHHGRyHW2goO6AhYZ3t5AblprJ3eiSuSvdNERhTcHp8jRgKU10040gRImyxEnI+gRP4HWoFQOtg9BR6yJDt6HawCNMU0oNO28+D13lmJx5Krg8jBbd8ZfpCoEbgE+gt+ETwGfRIiSemVB3YIEY/oLJwtid1kaYqLOO3/w4WnZoLX4nNZ5LxInTRjuWo1G10U6z1cxT2ac8P2dRAy/YrT8uBi1LZvmpUiSB7dWcoJEdzVLUYqlsWkLAGDMekykiNwI/d/7sAVblHboSr6FbbsWCGLqc6woEtwNl8o4bsVDJYKGSfAGhEChG5VHcAEMM0cPB8ZDgQ/QgWSFDhjYg1RcpXWSwDMfmw6+fDm/4xZTfWmMCOdGpG9OtO7DMGONGMFyL6lcAtwL/LSKfA5ajSv6fK55wDDU/xtDsVQ+jk4gkmhSwnYKJq2sQdF+XyiqZnwRjJtavpnONWrWrFctDGe5+sqQKHopMnto1zFIODLwUMw3jqxiIBM7Ho5E1gdoUUp1u3YFni8hGVKTuA94BYIx5VERuQbMgZoB3GWMqd4P8xL/9zu+DedsmzTPdiqOgsub55JwJXENhC+2cQ4zF5KqaevkCi30lTp3zotsLq6VGUcHkxVgZwbtByj1vqX35WMynk/M9CIFfcYCnGCyTGCRKYpoFxLr64PJfVz4uxCtu3/ZXsvpad8A5/p+Af6qmUZXpBR5CVyL7gXtRi1E+T+NbvIOXc1ltmxJSkiZgVeAC9NzIQUMDeDcVcBo69T3i61kbdG0tQa7I3gC6blTIXvwOtAiZDYwCu6Ah606cAR60u6nSoEIgxC/iRIlgNdyYOH1sdOBoRLtAYaJX/846Cyj1xYRW4EqcylK6aKcWncsWGAmkM6d/6bpnA7NACJTyGLQJb3RlehlihERNMgYPtcB9630/rU8EzlhRN4IfSjxtMuxhgIcZxMLCFMRha+2ibNGubwo6iI4bEx8TDazJYmETday17jmzRTqYmXSG4tsAp1V6jnnM5wwWIXkjtU2W+zhKmjGKj2r69yK62VBhDjnMMGkfk1bmk2qCQ6uA7fppfIlRCPGdWSwE0nyGX3ATJ4gRwyY94YGzAZsUdpGHMFMwX9Rj80tsGKLYREgTZ5QWZ7nSYEg5/yafM0O2YGknS5ZMEY0l5VwR4AIu5Jf884T9CdK8iE8xwMHx1k1ELd+v4ipu4T2T2jJTxNOw4pi+HsBmT42ETUkEXTmOorbkgfKHl8a9P7PTcjKLhQAc4yaOcVO9m1EVv+YOKBAChjQDfN7Du5ugaiFgiJDGnoazUPcwXOnUgTtEmj/4VEDTMx2oh0sbuoC0Dcf1fqpTnyPoatRC56RNzKbAtQYWArNbOs8U4vwrRYQxuq0d9GfPwy4Tzllp1V2LmM3gvXKD0lY5rxegHuY/Bwan0o4R4FZYcB8kzoDRpWC6UI/Vc9Alu5l6jCZV3vGFBhUCbj4ByNXIC5kOVoXvLs4wi7ruZ6jvLGynGvF0rxObyYyxrhBwP56gGsE8ikTjujYY16XYzZUGcBcs/wVccBBGHsqVAEy3Q8/zYew01GntDGqfW602S4QNKgQgtO7OFDZWVxb67SozkktFgeMrWTSF2AJUe3eT0xwrdvAI6oH6BGo8sID1wDxo/ylcdBDWOJtdWZEchh0/1pKAJy6EQ5eCeT0BT7JYlAYWAiFBZ6QZnlwB5+ypw8WzaOhaGo1ljQCPkucomD+i7oWF34LVW3Jh6wNroX8xrN0LS8hpFK4pIApc6Bx7dAv84SE4vAnN0NBYU9RQCITUjOFWePDMOgkBUAFwPxriFsMJao8BS8mN2E6SwSWPwdNRrSENDOyF/r2qSZRLeR9FI17Ps+HI78GcS6MJgXAyHeILxSZnkQy0TXtZzicMOgUYz2rRhEae5keFLoITq3OVy2PoQsDpqA2h0jOdxYl6Xejh4OAxCzSB0DBYDYJVdnUA8GR+yTA5FWT7KGx63L2Opi4LHk7i0eNXwuYDcGRU01PMRw3/XlZG+4DHzwJzFaEQqBuN98UHBavCEqFNM0O9q7BN6a6S82+cSHMGVjvBnjNuGCxE0IebBIxuQ1Nhnocu9XWC/QLoWQC9R9XLqe1P8PRdmuSmUvc6CpjnUftUy7XBS1KRbwAvBo4ZY85xtt2Mmk8BuoF+Y8xGJyvx48BOZ99mY8w7/W50LvlguEJQLZUezAwdHBu9lGyFIbFSmou6CoEIat1/GpDNwIkdsP/f4akrgTejw/1qMEshMQYJGwYXwKPfg9hRNQyWe1ISUIsQ35nCiybwTeBLwLfdDcaY17ivReSzTHTI3GOM2ehT+0pQogJOyJSp5CyUJUqywghngJ0C0gWr+0uPhxWnHbWiCbgEfZgNsDIJ83fBQBP0Pxsdz9zsUa6t4LmwLwN9P4RzDsGZlF79G4ZGVqorimZjzB8oUQNIRAR4NfB9n9tVAdepI9QEgoABdsVT/O3bv8tXXl5vS2ABUVQDcAdqQQf+lcDZe4CfUTy3wBJIXgPHPg5/vgR+H9VVBjdmy2UMrWSUXUvtp6XN1KI+QrXi63LgqDHmibxta0TkQdQv6x+MMXdXeY0ijKCZg/oIBUF1WESqHqENkLBg58Yop7a3wo+LXUeI1WM6YKECoPDScaB7lOJZqdw3zgO6YPjvYOej0PNLiG+FWFJtBQk0FmHg3ah/cq2FwEJqUaK7WiHwOiZqAYeB1caYEyJyIfA/InK2MWaSo+aE4iNTJovOQO4Hfgs8qp9kKWr8OYyjooVUwq/8tTqpaMFOpCk2WlkIkXpMB2x0vFjLxOWLFDAwH61CXG50tYA1kFkO/ZtQoXECereCaYfUBtTAOBOegk0V2jo9pi0ERCQKvBz1mwLAKT+WdF5vEZE9qFP1A4Xvn1B8JFfy1hutQOchGPwerLkfVo3qPWhFv6Mn0YixgGmmQcS3x9Jk4OhmMr3zgGfV9lpTwUYLjh5iYkWMk8AjK1BjQaWWubEEy1DDgg3JC1ABEWNmA4j8pxr97CpghzFmPDm4iCxyC5CKyFo0hGP6VUG6geehwVr5WADDkNkPiRHA6LEGzSM5TFiZbKYxNhy9FzNaSr2uI8PAXehS3hiqKW5vhoEzUBVhKrgPfhc66twDfBH1RmpMQ/W06g4YY76OVh8uNAg+C/i4iKRRnf2dxpipFRY/zWlVM3CB87sTleiO4wkjQEc/vGBQ9zWj92YENRP0offDDfgIqT1iQfc6ogcC+CAYNEv9j9HBNLsc0m8FcwnF60l4wUbjkr+MFsZoBt5axfnqx3TrDmCMua7Ith8BP5p2awRV2U5Bv9c4KrUfQyPCQFu8BDjVQLutOR5SwA40zHMlupxzpnOOh6lFIddZhF9WAYtI8woyAw8Br/ThfD5jcIKHuoFXoK4v01XjB1Fb1JfJpb7/DnA1U9cspkobqon4l6UpEIubTawkyUEd1Xeiv915/VNoH12Ofr/N6HxOUF/VPmfbPFQIL0G1tDFnWws6J3ySxswyDXyV7fQzRIIEWQy/Yodv5454Wh3IUKmrZCSKLLka1hW3Xk+l1lJtSaNqwXRi80fRTvk95/dI3r4k8BXgY2gHrAUWcD5qYjtY4VjvBEIIbGAx/4/X8NLBz2pPuQsdxaOo8XY52ouand+rmdirks7frmEwi95fA+xBhUCDCgCA93KJM9s0jq+kn3Oc8g9ClH6Wci9HuYJ0meUpYydIb/kwmb43owvzU73STFHKybkSu4CvoitSoyXO8VvgZcAzptu4CkTQ9c7pJ3cpRiCEwCM8yBvP3cYbtsN9fbDLLUd+HnApOqr/CW3tuUxeKo07x7vlBgV96LehU4lGLDaTR6qklbMJXaDZULBdiNNKFxurvnaUBJ0r9nOsp7z6KRKhJXYWzUOjVV+ztrijxVTYDXwC7UzlbB5Z4NOoRrBkWq2rjNeamlM7Y91JY+joz/J6Xsph+1YOL4WhZ+MEZqCC71Q01qPYd+saAMdQgeA6ceyj4QVAIUKEjXyUP/Oh8S2lOoUfbrqC0fLCZbCAc5PN7Pjkv9fPNdgz0xEC96DGKS9Gz6fQEetlU7yGV5ajhjD/CEQM7gra2X7gW1zJzfTNh6FNqNZzFjrYxVFD38oib3Ytvz3o6O/ep4XMbA7IGcLC4rU8j+j4vwhRrKI/M+WcI8AahKiJEDGB6FJlSJEz5nnlTNRQ5ZVf4afhbiKd+L0CEYg7Jgjv441cMr+FLa9A3YtiTHTCSqDJIYql2G8HFpNbKRhFhf1SGjHlW0UyMzraxiAqHjTQ8qOkMDnfQH2wUceBqZT3XszU5uE9NFKplUCMkwcZ4pvXod9zCypEM87rJDrSH0Mf6nxSzk8rKs6Ooqs3neS0gqBrp9NgjFHPbimVPn6cCJEKY0Elq74N/JZhXltGTQ2OEEii1uJeVLX2gput1CtjaCf214Cn+B9CHwghAHDNN2GBRPifc4XugQxXHIL/uhL9HnvR5T6LnJblGgIj5PSZJmfbcVRzyKKOXaNMTfAHGJsMn+afuJhn0oTKysIxx5V9p2LxNB8evUoCx8bml9zNa3lR1deaGdx8FF7pQkcgr+vzJ9BOOJUphFcO4XdgTGCEwG/F4qT5A195eCNfZRWPRE5ywf2wdS2MF67pBLajzkRt5KKJW1DfAdcI6FaeHkYFyCwRAEKEKFGS/I6Xesh79SpexS3cUtU1DVGG9s2jnD4QIcJLeE5V15lZEmhVoVWVDsxjNTk1tRIWfhvvcgzihOf4RiBsAgDDl2f5sHUlhixDsVG++grYehGa7+EKtL5DKzry70KXZO8D/ixwr6gH5xbgQaAfFd4wq9yGr+AT/IIhYrx+xq6ZoJsD/EVZHwEBOhvK+DKM1hiYCueT61SVeBGapLAWuB75/hEMIRAHcwZ8/nJDH9tZz+W6InAOuiQYQYVfBn3Ae1ChHG2BJ86Ah9fAKFiH4rQd7KR7Zzfdd3XT/UQ30VRglJ2qEITf8SGuJMqH+XfP7/PDkz9De9nzCMUXboLLdGokd+It62gUdUuuRb+zUZfaqYXjVCIYQsAAA2BSWe6V+1gllxDZD7ExkDTaygTq+deG5nmXKBxfD5l3A6+klXk8nQ/wI/bQl/fvuTy3fp+rJgjdHuf5brGccnhzG65MJT0ggrdHaGaYqk0A1FHlFCqbSZ+HLm/Vih7wubBrMIbJNHAHZDps/rLz/2NDVzfP2LWCRZvhkdMy7DrvqPpq/CH/TecCHwXOpp0n+Adewgd5Zj1aP+eJYDVg7NxUdaQWNJPeLkr77QvwHGqzKlA7giEEACJgJeDVA6v44sCn6UJzme57fB9rn1yHOa1QfWsCFhJlMX/D6XxwxhscAjpNWc5iz4ttwSCNzisrB0ZNZAPqwXac4sa5ZajtoLEIjhDogHP3n8u3+DKaulAxRDCJxfDoId1QkCNgEaVtvCdIk5xNlsEp4qeLRMRZmTAYbDLj46iFxUIWkBVDz2KDhcWKo2VPFQCG0eQUY1S24mec4zLoXPSNaPTgPlSQuLES7WgI9TzfW1trvCQVWYWmG3cTNn/NGPMFEZkP3IxOlvYBrzbG9DkZiL+AmkhHgeuMMVsrXWfTvk3cz/0TtmWBUTLo2igqAJrJfe/oLGF3wbmOkOAYQ3yLh3mkeBnaOUAcocs3B51uulnJKsYY4wS9JEliYxMlygIWMtYE775+gI50jO/+fbtPV60lXqIJnUQVzfdApA/Gng7Zi4FPogaq29GoQhvNu3M1jeid5kUTyAAfMMZsFZEOYIuI3AFcB9xpjPmkiFwPXA98EHghmlZsHZrA7Qbnd0maaZskAADGsPkxR3IbVCrkMXlet58hPsUf+C6/ZYAfoQEdswfvRdfWgAfnHa9dNlefwNBOO1mypEmTIMExjmGw6c08iDXYRV7ayQnXCY7a6VLJLjAM1o1wzh/VaXDrHdD7z2AuRqcGp6NGOtcvoDFzDXrJLHQYHXAxxgyJyOOoff4aVPwBfAvNAvBBZ/u3jTEG2Cwi3SKyzDlPUU7hzKLbBxji//KFUi2jmCT/CvdzI58gxX2VPlqDkRtdvVnZ55HmFEbZ6cTMrS96lNeU47300k8fmQKruiCkSEI2gznyZ+yul5S4TtCEgBdNIAGtGU1msxIYHYT7fw+Js1C1P45Kh5nEr0xQOaa0ROiUGTsfddNZkvdgHyEX5LuCXDIwUFPqinLnLaY8psmylUPojMMbhxjgKTaTmmWjP8BVXDfFd1gMM8wxbgb+e9rXtcgSYxgLM0kATMBksfsehqd+Nu1rBY9eaB3JBT2cCbTeqdsbNKloMTwLZxFpR/MHvs8YM6hTf8UYY6aaNjy/7sBqVk/aP8wYf8c3PJ1rGNjMbm7hZ/yeWyB/CjELaOd0biqpEZWmhXnM4xqqSb0cZZRFLVvoHdtEirbSXV8iyPyNMNhIc+JKXTYFEoODEa1P2AKYk2iPmz3RaZ6EgIjEUAHwPWOMW1/mqKvmi8gyGLfA9TDRYL+SvOrwLvl1BzbJpgl3I0WG73AnO/ispw/xGKP8AzdxJ//MrMsiAryLj03D5mzRTAtdnF7VtSMkWLjgcfoObqBs9ZtInMgZb4O9D1V1vZkjiVr3y6nzG+DotXB0BbSdhAUnYWQhOhWYHQIAvK0OCPB14HFjzOfydt0KvAk1lb4J+Gne9neLyE2oQXCgnD2gkDQ2X+RP/B0f9foW9tNPjAPMRgEA8C/85YR5u7c5nLeCraXzEuWoNHPOkgWThYEnMMnit9q7QXOmGEYXtU4rc0wraly9AkZOwkgP6jW4lPoJAf+v60UTuAx4A7BdRB5ytn0YffhvEZG3oib4Vzv7bkO/ud2oLf/NXhtjk+XT3MFH+AzwUKXD87AIWherJd7mcN4SanoxM1VysLXJqmHw4D3Y2eKt8yJsZpYU3nLRu5b/DlQA1Jv82Hl/8LI6cA+l798kx3xnVeBdU2nEKIb/5iA7uJFPcBtFqpaVJUqMWIA802uNn13Av3MZrMwoWLWIofeJCJp2LgYcdEsRTdVrsN4sQadl/pXYCsSn7+Eg/4cPsZfvTuv9a2ijg84pB4eG+EfEjnLxg5fRFAmoRtYEbEI1eRtIDcOxnWgCkMUETU8pjoVOX7qZep7E0gRCCAxyjEF+jq67OumTWpZpnUGTRFXbQl/tnLp7Bs0spIXZtDgVKDzk0YjbEd625ZmIzwkvfKMJDU13k9FcZMPDj8OBm9BqRKcQkMehDILmNPBX6w3Ep7Yi8PJr4hy3h2nC0J8U1s/vojWzlGR2lJ7Uk6TT7WxLDtN3ErU09KXguFaAacFNIu1WHJndzGT6jhStnOw/mxStZb/ZCMJZyRh1qzO0nNwDXslNJI46AHX0wq7/gd07YfAidHa7nIA8FiVYiN8VjgLxaQ3wxFn9DGcyRCMwloFs00FitGBj059JkbUNCTeT8CiwMw3H832IXYPJLMklVgILw4WeVMEsXobwSsW4bFo4kTwfu8rwWIsa5xO4BB0kM6hr2r0Vjo+hQX9dA7BqMzy1C/Y8DEPPRwPYghr/0IW2zb8BLxhCIAvbkintt84zvH9siAnJE0xKe1IburpzrLC6nbtCMLuFgJZqebeH42y81F6rbLUX0j48EDXNNmyhnilub25hshAwTJaJgvantcCSk7Dij7DtSTiYBV5AQB6PAuJoWM5DaDLN6gmGFce9QTZq9Ew4f+f/uGl1j6MFXnpiTFSLgrcIVQsyZPk8t3s40tvKvIUmBal15aCa+gm4tSddis1IDLkM1IUIOrieloGl+1Gv96DWDYjizGV8PWMwcPv1IHpTik17bFQ56Afswm41F4RABsNz2Zun7URQl0zXNWg+Oqv9I0NoZ95Y9oz6rdX+e1NhU0O2o1m94miUbyGjwJ3oQsBidMrQzkSBkcRRnvzP7e8fglo5/ROpwRECe5zfbv8u9hm9OcHNKt7Cl51XBsM+dnIvf8rb34QmYx5Cu203+hWOjNdrrx89y+EHr4T3fVH/rqmo2Qw87Fyk2BJ6Fi1OcxLta01M9Ls5FfUgPgpwBxqK/RxmZQmrAoIjBAqn8lMWxBGCUuPGP7r5R94w/pdhGX9kCYPkUvc0AS9sgu64ysfDNvzHKNjEqLdxa7QFdqzT1zWvQOQajCvhTi/zjxXUf2AVaizsPQgPfB4O7QGuQpcPWwiOphlndmoC0ybFLsbo4nKewfvo5zhR4uTfMI1lnxw3H3WSZhVubSrytQhCE02kybCL7TxZZVEPL3ycH7KU9vFkHll+yQ30FrQL2gUWOE/YmHE1XP/chitRyhQrWYgmcu0MpIhuRWtanEdOZnZmobsXtt8MO34BY+eh4THrCcanWEHZYK4pMguEwCi/5Vb+yCO8l5fxNzyNwm5dqpMXnwuX2pr7P02KNJ+edou9Mp9lgD7Of8TwOt7O4YJHbhD4P0m4yPHROQ5sHW/rzNh9S83QTtkPn/jHGWnC9Iig04CLmKj1W+jU4LJRuHAUdt8FW47A0AfRpAL1tqfPR4WAP8uEs0AIGJI8QZIb+Bw/5nbezAW8kr9mDRsd66L7NfmrzNWqwkyOB4CXcyNJfkCKbfQXKTqRAR42mjYT9LOmxl/V17gVs2GeP6tYtaEbXW0rNu0XVOuOA+elwXoU7vsujLwPtSzWkw7gUrSKUl/VZwuMEHgecA/TDQaOAoYRnmIr/8I2Psv3iTiqrmAcMVD4Ghj/O39/qX3utmL7C89R7BrF/i48X/5rG2GIBPpYl/Z/yFIsAfZ0CmwUI0MzwyRpxwSnu/hDB95KJ0WA9Vl49Dcw8r/QHNf1tA9o4Te/vDMDc1c3M1UBkEHNwDHU3bMd6CeLRRbLU9lI/fjFbmapr6XS1+XOfKcS2txcog2FHAI+ximo/eoehFxyiyZ0YXARmgj0CIuw8KOTNDPMKm6lhxcxxsJgLs6chS6PjOE9rqYNnQp4rZoSBySD5s5ZR93co8fxTwgFRghMvbDSEdRr6kJ0nrYe/9YPg2IFzmcQYRdr+W9O0gF8Hqcgo7Pf9ZMYxObXJBggQ5wxytfDqWQYtIHYklE4Wk6sFnesSQK9FqzI1tBP4LWoZR/09j8O/NrD+zpRnxuvRNBuNrAZxjaic4nZQWCEwNQ5gdYlO42JXumFjh4dqBdJJyo49pBzp52qV1il412/Zy9VAMuds9gDl6aVP3EF8BOSqO4UnfT+Do5yAT8jzUqG+T3H2VAkg2OO/YwxUmbaYGDaMnH/afCp98ONU8ouUYJz0Vu4Gx2MXXnfzkSl62z0lv+owvniTP05Xg9s/wGMXYzGF9RbG/CHBhYCTgFDflPhOAv4D7Qcwg+BL+GXz/VMM4bhzCZYn0wh3Dhpf5vAs1rg1Fa4KXWA9OBdJPjrsuf8LfdygENF7RvVYgSyzvC/mzF+wPHpnchCA4Q6gQuAr1A6NkqA1cDLgR+XOCaG9yrj+bQBzzJwxy9g+Hzq7YzlFw0sBFwqdV4bHS0NufpzgZzZViQL/GMSHil1gAU7u2HDCvjDPni2h0d7FUtpo4VBT6m2pkYsBQudlIMHeYLfTiF9/ASyqOLXjAqEi9H4ETdQspgytRBV33cU2ZdGjepJSs9R3G5TqND1ApkMGsVWr8cnhX4h/sQ3iGYDqy8ichwt8NZb6dgAs5DGbj80/mdo9PZDbT/DKcaYRYUbAyEEAETkAWPMpnq3Y7o0evuh8T9Do7cf6vMZ6u36FBISUmdCIRASMscJkhD4Wr0bUCWN3n5o/M/Q6O2HOnyGwNgEQkJC6kOQNIGQkJA6UHchICJXi8hOEdktItfXuz1eEZF9IrJdRB4SkQecbfNF5A4RecL5HShvEhH5hogcE5FH8rYVbbMoX3Tuy8MickH9Wj7e1mLt/5iI9Dj34SEReVHevg857d8pIi+oT6tziMgqEfmdiDwmIo+KyHud7fW9B8aYuv2grhp7UC/uOLAN2FDPNk2h7fuAhQXbPgVc77y+HvjXerezoH3PQn3uHqnUZrSe5O2oD96lwH0Bbf/HgL8tcuwGpz81obnC9gCROrd/GXCB87oD2OW0s673oN6awMXAbmPMXmNMCrgJuKbObaqGa4BvOa+/Bbysfk2ZjDHmDzApKUGpNl8DfNsom4FupwR93SjR/lJcA9xkjEkaY55Eow4urlnjPGCMOWyM2eq8HkLDnVZQ53tQbyGwAk2J63LQ2dYIGODXIrJFRN7ubFticmXYj6DVI4NOqTY30r15t6MufyNvChbo9ovIqcD5wH3U+R7UWwg0Ms80xlyARia9S0Selb/TqD7XUEsvjdhm4AY0lHQjWmb4s3VtjQdEpB2Nc3yfMWZC0EY97kG9hUAPmiPDZaWzLfAYY3qc38eAn6Cq5lFXXXN+H6tfCz1Tqs0NcW+MMUeNMbYxJgvcSE7lD2T7RSSGCoDvGWPcOMe63oN6C4H7gXUiskZE4miKiFvr3KaKiEibiHS4r4Hno8F9t6JpaXF+/7Q+LZwSpdp8K/BGx0J9KTCQp7IGhoI58rXkgixvBV4rIk0isgZNB/TnmW5fPiIiwNeBx40xn8vbVd97UE9raZ4FdBdqvf1Ivdvjsc1rUcvzNuBRt91ojto70QyQvwHm17utBe3+Pqoyp9H55VtLtRm1SH/ZuS/bgU0Bbf93nPY97Dw0y/KO/4jT/p3ACwPQ/meiqv7DaFqsh5z+X9d7EHoMhoTMceo9HQgJCakzoRAICZnjhEIgJGSOEwqBkJA5TigEQkLmOKEQCAmZ44RCICRkjhMKgZCQOc7/D8kyrT9QwVfYAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(len(dataset))\n",
    "img, label = dataset[212]\n",
    "plt.imshow(np.array(img).transpose(1, 2, 0))\n",
    "print(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in C:\\Users\\IX/.cache\\torch\\hub\\pytorch_vision_v0.10.0\n",
      "d:\\Workspace\\Autopilot\\venv\\lib\\site-packages\\torchvision\\models\\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "d:\\Workspace\\Autopilot\\venv\\lib\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNeXt50_32X4D_Weights.IMAGENET1K_V1`. You can also use `weights=ResNeXt50_32X4D_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    }
   ],
   "source": [
    "from torch import nn\n",
    "import torch.optim as optim\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)\n",
    "\n",
    "class BehaviorPredictionModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnext50_32x4d', pretrained=True)\n",
    "        self.linear = nn.Linear(1000, len(actions))\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.backbone(x)\n",
    "        x = self.linear(x)\n",
    "        return x\n",
    "\n",
    "model = BehaviorPredictionModel()\n",
    "model.to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0)\n",
    "\n",
    "save_dir = '../tmp/BehaviorPrediction/'\n",
    "os.makedirs(save_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    model.load_state_dict(torch.load(os.path.join(save_dir, 'model.pth')))\n",
    "    optimizer.load_state_dict(torch.load(os.path.join(save_dir, 'optimizer.pth')))\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 training loss: 1.1275527605733262 accuracy: 0.7796068787574768\n",
      "Epoch 0 testing loss: 0.26713911456162814 accuracy: 0.9148662686347961\n",
      "Epoch 1 training loss: 0.19734056659083715 accuracy: 0.9360598921775818\n",
      "Epoch 1 testing loss: 0.17803489180699072 accuracy: 0.9439300298690796\n",
      "Epoch 2 training loss: 0.13539895854139156 accuracy: 0.9568945169448853\n",
      "Epoch 2 testing loss: 0.13783496369918188 accuracy: 0.9553326368331909\n",
      "Epoch 3 training loss: 0.10317823439178572 accuracy: 0.9662401080131531\n",
      "Epoch 3 testing loss: 0.13175108337194705 accuracy: 0.9619341492652893\n",
      "Epoch 4 training loss: 0.09654968286850407 accuracy: 0.9685765504837036\n",
      "Epoch 4 testing loss: 0.11733692055160963 accuracy: 0.966906726360321\n",
      "Epoch 5 training loss: 0.06958486934975608 accuracy: 0.9768932461738586\n",
      "Epoch 5 testing loss: 0.14087544557704626 accuracy: 0.9603052139282227\n",
      "Epoch 6 training loss: 0.06030145941172191 accuracy: 0.9801942110061646\n",
      "Epoch 6 testing loss: 0.10435182989010068 accuracy: 0.9703360795974731\n",
      "Epoch 7 training loss: 0.05049743943314984 accuracy: 0.9832808375358582\n",
      "Epoch 7 testing loss: 0.12204952689544221 accuracy: 0.965877890586853\n",
      "Epoch 8 training loss: 0.05035491841790919 accuracy: 0.9834951758384705\n",
      "Epoch 8 testing loss: 0.10858839685020999 accuracy: 0.9702503681182861\n",
      "Epoch 9 training loss: 0.043587001611396226 accuracy: 0.9858101606369019\n",
      "Epoch 9 testing loss: 0.1310888304034136 accuracy: 0.9671639204025269\n",
      "Epoch 10 training loss: 0.04168992170448248 accuracy: 0.985917329788208\n",
      "Epoch 10 testing loss: 0.12311954002145453 accuracy: 0.966821014881134\n",
      "Epoch 11 training loss: 0.033135084892102214 accuracy: 0.9882965683937073\n",
      "Epoch 11 testing loss: 0.1047798677047338 accuracy: 0.9737654328346252\n",
      "Epoch 12 training loss: 0.032569226836672294 accuracy: 0.9892611503601074\n",
      "Epoch 12 testing loss: 0.11879688798821444 accuracy: 0.9704217910766602\n",
      "Epoch 13 training loss: 0.028555356944023776 accuracy: 0.9906544089317322\n",
      "Epoch 13 testing loss: 0.1067112349118475 accuracy: 0.9719650149345398\n",
      "Epoch 14 training loss: 0.030376502292039732 accuracy: 0.9905686974525452\n",
      "Epoch 14 testing loss: 0.10276189308977417 accuracy: 0.974022626876831\n",
      "Epoch 15 training loss: 0.03094094119199829 accuracy: 0.9900327920913696\n",
      "Epoch 15 testing loss: 0.12100944843186483 accuracy: 0.9677640795707703\n",
      "Epoch 16 training loss: 0.025050855693983045 accuracy: 0.9924120903015137\n",
      "Epoch 16 testing loss: 0.11378846955237544 accuracy: 0.9726508855819702\n",
      "Epoch 17 training loss: 0.019308931344186434 accuracy: 0.9937838912010193\n",
      "Epoch 17 testing loss: 0.12710691963165333 accuracy: 0.9738511443138123\n",
      "Epoch 18 training loss: 0.022636536291647694 accuracy: 0.9925192594528198\n",
      "Epoch 18 testing loss: 0.11035199841017494 accuracy: 0.9761659502983093\n",
      "Epoch 19 training loss: 0.023995953809338753 accuracy: 0.9924978017807007\n",
      "Epoch 19 testing loss: 0.13032327135999386 accuracy: 0.9743655920028687\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 20\n",
    "for epoch in range(num_epochs):\n",
    "    train_loss = 0.0\n",
    "    train_correct = 0.0\n",
    "    for i, data in enumerate(train_loader):\n",
    "        x, y = data\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(x)\n",
    "        loss = criterion(y_pred, y)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            _, predictions = torch.max(y_pred, 1)\n",
    "            train_correct = train_correct + (predictions == y).float().sum()\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.item()\n",
    "    train_loss = train_loss / len(train_loader)\n",
    "    train_acc = train_correct / len(train_dataset)\n",
    "    print(f'Epoch {epoch} training loss: {train_loss} accuracy: {train_acc}')\n",
    "\n",
    "    with torch.no_grad():\n",
    "        test_loss = 0.0\n",
    "        test_correct = 0.0\n",
    "        for i, data in enumerate(test_loader):\n",
    "            x, y = data\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            y_pred = model(x)\n",
    "            loss = criterion(y_pred, y)\n",
    "            _, predictions = torch.max(y_pred, 1)\n",
    "            test_correct = test_correct + (predictions == y).float().sum()\n",
    "            test_loss += loss.item()\n",
    "        test_loss = test_loss / len(test_loader)\n",
    "        test_acc = test_correct / len(test_dataset)\n",
    "        print(f'Epoch {epoch} testing loss: {test_loss} accuracy: {test_acc}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))\n",
    "torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pth'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss: 1.2253678088221702 accuracy: 0.838726282119751\n"
     ]
    }
   ],
   "source": [
    "def write_pred():\n",
    "    dataset = BehaviorPredictionDataset(dataset_paths, return_metadata=True, transform=preprocess)\n",
    "    dataloader = DataLoader(\n",
    "        dataset,\n",
    "        batch_size=8,\n",
    "        shuffle=False,\n",
    "    )\n",
    "\n",
    "    save_objs = {}\n",
    "\n",
    "    total_loss = 0.0\n",
    "    total_samples = 0.0\n",
    "    total_correct = 0.0\n",
    "    for i, data in enumerate(dataloader):\n",
    "        x, y, meta = data\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        y_pred = model(x)\n",
    "        loss = criterion(y_pred, y)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            _, predictions = torch.max(y_pred, 1)\n",
    "            total_samples += y.shape[0]\n",
    "            total_correct += (predictions == y).float().sum()\n",
    "\n",
    "        for i in range(y.shape[0]):\n",
    "            output_path = meta['output_path'][i]\n",
    "            if output_path not in save_objs:\n",
    "                save_objs[output_path] = {}\n",
    "            frame = int(meta['frame'][i])\n",
    "            if frame not in save_objs[output_path]:\n",
    "                save_objs[output_path][frame] = {}\n",
    "            save_objs[output_path][frame][meta['vehicle_id'][i]] = {\n",
    "                'id': meta['vehicle_id'][i],\n",
    "                'frame': int(meta['frame'][i]),\n",
    "                'timestamp': float(meta['timestamp'][i]),\n",
    "                'location': meta['vehicle_location'][i].tolist(),\n",
    "                'rotation': meta['vehicle_rotation'][i].tolist(),\n",
    "                'current_action': actions[predictions[i]],\n",
    "                'current_action_gt': meta['current_action'][i],\n",
    "            }\n",
    "\n",
    "        total_loss += loss.item()\n",
    "    total_loss = total_loss / len(dataloader)\n",
    "    total_acc = total_correct / total_samples\n",
    "    print(f'Average loss: {total_loss} accuracy: {total_acc}')\n",
    "\n",
    "    for filepath, frames in save_objs.items():\n",
    "        os.makedirs(os.path.dirname(filepath), exist_ok=True)\n",
    "        with open(filepath, 'w') as file:\n",
    "            for frame in sorted(list(frames.keys())):\n",
    "                line = json.dumps(save_objs[filepath][frame])\n",
    "                file.write(line+\"\\n\")\n",
    "\n",
    "write_pred()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "da7e20478b5a775c0656e92accb7702d4567ea1fb2ed254df86c89875ad00a36"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
