{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from os import path as osp\n",
    "import numpy as np\n",
    "from av2.datasets.sensor.av2_sensor_dataloader import AV2SensorDataLoader\n",
    "from av2.map.map_api import ArgoverseStaticMap\n",
    "from shapely.geometry import Polygon, MultiPolygon\n",
    "from copy import deepcopy\n",
    "from av2_to_wgs_conversion import convert_city_coords_to_wgs84\n",
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "overpass_useragent = 'NeurIPS' # use a more unique useragent here for yourself\n",
    "overpass_headers = {\n",
    "    'Connection': 'keep-alive',\n",
    "    'sec-ch-ua': '\"Google Chrome 80\"',\n",
    "    'Accept': '*/*',\n",
    "    'Sec-Fetch-Dest': 'empty',\n",
    "    'User-Agent': overpass_useragent,\n",
    "    'Content-Type': 'application/x-www-form-urlencoded; charset=UTF-8',\n",
    "    'Origin': 'https://overpass-turbo.eu',\n",
    "    'Sec-Fetch-Site': 'cross-site',\n",
    "    'Sec-Fetch-Mode': 'cors',\n",
    "    'Referer': 'https://overpass-turbo.eu/',\n",
    "    'Accept-Language': '',\n",
    "    'dnt': '1',\n",
    "}\n",
    "\n",
    "FAIL_LOGS = [\n",
    "    # official\n",
    "    '75e8adad-50a6-3245-8726-5e612db3d165',\n",
    "    '54bc6dbc-ebfb-3fba-b5b3-57f88b4b79ca',\n",
    "    'af170aac-8465-3d7b-82c5-64147e94af7d',\n",
    "    '6e106cf8-f6dd-38f6-89c8-9be7a71e7275',\n",
    "    # observed\n",
    "    '01bb304d-7bd8-35f8-bbef-7086b688e35e',\n",
    "    '453e5558-6363-38e3-bf9b-42b5ba0a6f1d',\n",
    "    # observed ll2_custom\n",
    "    '8940f5f1-13e0-3094-99ba-da2d17639774', \n",
    "    'c08279c0-10b4-3d21-b13f-a1c1a0b87f8b', \n",
    "    'c96a09c8-46ed-391f-8a66-c46fa8b76029'\n",
    "]\n",
    "\n",
    "av2_root_path = \"/tmp/argoverse2/sensor\"\n",
    "av2_geo_split_root_path = \"geographical-splits/near_extrapolation_splits/argoverse2/argo_geo_txts\"\n",
    "splits = ['train', 'val', 'test']\n",
    "\n",
    "out_path = \"out\"\n",
    "\n",
    "def get_log_ids(geo_split_root_path, split):\n",
    "    log_ids_file_path = osp.join(geo_split_root_path, split + '.txt')\n",
    "    log_ids = []\n",
    "    with open(log_ids_file_path, 'r') as f:\n",
    "        for line in f:\n",
    "            log_ids.append(line.strip())\n",
    "    for l in FAIL_LOGS:\n",
    "        if l in log_ids:\n",
    "            log_ids.remove(l)\n",
    "    return log_ids\n",
    "\n",
    "already_processed_scenario_maps = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "log_ids_train = get_log_ids(av2_geo_split_root_path, 'train')\n",
    "log_ids_val = get_log_ids(av2_geo_split_root_path, 'val')\n",
    "log_ids_test = get_log_ids(av2_geo_split_root_path, 'test')\n",
    "\n",
    "root_path_train = osp.join(av2_root_path, 'train')\n",
    "root_path_val = osp.join(av2_root_path, 'val')\n",
    "root_path_test = osp.join(av2_root_path, 'test')\n",
    "\n",
    "loader_train = AV2SensorDataLoader(Path(root_path_train), Path(root_path_train))\n",
    "loader_val = AV2SensorDataLoader(Path(root_path_val), Path(root_path_val))\n",
    "loader_test = AV2SensorDataLoader(Path(root_path_test), Path(root_path_test))\n",
    "\n",
    "log_ids_train_og = list(loader_train.get_log_ids())\n",
    "log_ids_val_og = list(loader_val.get_log_ids())\n",
    "log_ids_test_og = list(loader_test.get_log_ids())\n",
    "# Initialize the dictionaries\n",
    "loader = {}\n",
    "root_path = {}\n",
    "\n",
    "for log_id in log_ids_train:\n",
    "            if log_id in log_ids_train_og:\n",
    "                loader[log_id] = loader_train\n",
    "                root_path[log_id] = root_path_train\n",
    "            elif log_id in log_ids_val_og:\n",
    "                loader[log_id] = loader_val\n",
    "                root_path[log_id] = root_path_val\n",
    "            elif log_id in log_ids_test_og:\n",
    "                loader[log_id] = loader_test\n",
    "                root_path[log_id] = root_path_test\n",
    "\n",
    "for log_id in log_ids_val:\n",
    "            if log_id in log_ids_train_og:\n",
    "                loader[log_id] = loader_train\n",
    "                root_path[log_id] = root_path_train\n",
    "            elif log_id in log_ids_val_og:\n",
    "                loader[log_id] = loader_val\n",
    "                root_path[log_id] = root_path_val\n",
    "            elif log_id in log_ids_test_og:\n",
    "                loader[log_id] = loader_test\n",
    "                root_path[log_id] = root_path_test\n",
    "\n",
    "for log_id in log_ids_test:\n",
    "            if log_id in log_ids_train_og:\n",
    "                loader[log_id] = loader_train\n",
    "                root_path[log_id] = root_path_train\n",
    "            elif log_id in log_ids_val_og:\n",
    "                loader[log_id] = loader_val\n",
    "                root_path[log_id] = root_path_val\n",
    "            elif log_id in log_ids_test_og:\n",
    "                loader[log_id] = loader_test\n",
    "                root_path[log_id] = root_path_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_osm_map_for_logids(log_ids, root_path):\n",
    "    for log_id in log_ids:\n",
    "        data_root = root_path[log_id]\n",
    "        log_map_dirpath = Path(osp.join(data_root, log_id, \"map\"))\n",
    "        vector_data_fnames = sorted(log_map_dirpath.glob(\"log_map_archive_*.json\"))\n",
    "        if not len(vector_data_fnames) == 1:\n",
    "            raise RuntimeError(f\"JSON file containing vector map data is missing (searched in {log_map_dirpath})\")\n",
    "        vector_data_fname = vector_data_fnames[0]\n",
    "        vector_data_json_path = vector_data_fname\n",
    "\n",
    "        avm = ArgoverseStaticMap.from_json(vector_data_json_path)\n",
    "        av2_city_name = avm.log_id[-14:-11]\n",
    "        scene_ls_list = avm.get_scenario_lane_segments()\n",
    "        scene_ls_polygon_list = []\n",
    "        for ls in scene_ls_list:\n",
    "           scene_ls_polygon_list.append(Polygon(ls.polygon_boundary))\n",
    "\n",
    "        if avm.log_id not in already_processed_scenario_maps:\n",
    "            print(\"AV2 Map log_id: \" + avm.log_id)\n",
    "            bounds = MultiPolygon(scene_ls_polygon_list).bounds\n",
    "            print(\"Scenario Map Bounds: \" + str(bounds))\n",
    "            bounds_with_buffer = list(deepcopy(bounds))\n",
    "            bounds_with_buffer[0] -= 100\n",
    "            bounds_with_buffer[1] -= 100\n",
    "            bounds_with_buffer[2] += 100\n",
    "            bounds_with_buffer[3] += 100\n",
    "            print(\"Scenario Map Bounds with Buffer: \" + str(bounds_with_buffer))\n",
    "\n",
    "            bounds_with_buffer_latlon = convert_city_coords_to_wgs84(np.array([[bounds_with_buffer[0], bounds_with_buffer[1]], [bounds_with_buffer[2], bounds_with_buffer[3]]]), av2_city_name)\n",
    "\n",
    "            print(\"Scenario Map Bounds with Buffer in Lat/Lon: \" + str(bounds_with_buffer_latlon))\n",
    "\n",
    "            query_str = \"(node({0},{1},{2},{3}); rel(bn)->.x; way({0},{1},{2},{3}); node(w)->.x; rel(bw);); out meta;\".format(bounds_with_buffer_latlon[0, 0], bounds_with_buffer_latlon[0, 1], bounds_with_buffer_latlon[1, 0], bounds_with_buffer_latlon[1, 1])\n",
    "            data = {\n",
    "              'data': query_str\n",
    "            }\n",
    "            response = requests.post('https://overpass-api.de/api/interpreter', headers=overpass_headers, data=data)\n",
    "            Path(out_path).mkdir(parents=True, exist_ok=True)\n",
    "            with open(osp.join(out_path, avm.log_id + '.osm'), 'w') as f:\n",
    "                f.write(response.text)\n",
    "            already_processed_scenario_maps[avm.log_id] = True\n",
    "        else:\n",
    "            #print(\"AV2 Map \" + avm.log_id + \" already processed, skipping...\")\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_osm_map_for_logids(log_ids_train, root_path)\n",
    "get_osm_map_for_logids(log_ids_val, root_path)\n",
    "get_osm_map_for_logids(log_ids_test, root_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sd_map_encoding",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
