{
   "cells": [
      {
         "cell_type": "code",
         "execution_count": 49,
         "metadata": {},
         "outputs": [],
         "source": [
            "def tag_dict_to_str(tag_dict):\n",
            "    \"\"\"Serialise a tag dict into a flat 'key: value, ' string.\n",
            "\n",
            "    Pairs are emitted in the dict's iteration order and every pair,\n",
            "    including the last, is followed by ', ' -- so a non-empty result ends\n",
            "    with a trailing separator and an empty dict yields ''.\n",
            "    Values are converted with str(); keys must already be strings.\n",
            "    \"\"\"\n",
            "    pieces = (key + ': ' + str(val) + ', ' for key, val in tag_dict.items())\n",
            "    return ''.join(pieces)\n",
            "\n",
            "\n",
            "def has_not_relevant_key(string, not_relevant_keys):\n",
            "    \"\"\"Return True if tag key `string` matches an entry of `not_relevant_keys`.\n",
            "\n",
            "    The key is lower-cased and checked both as a whole and piece by piece:\n",
            "    keys containing ':' are split on ':', otherwise keys containing '_' are\n",
            "    split on '_'. Checking the whole key is required for multi-part entries\n",
            "    such as 'created_by' or 'admin_level', which splitting alone would hide.\n",
            "\n",
            "    Args:\n",
            "        string: tag key, e.g. 'addr:street'.\n",
            "        not_relevant_keys: set of lower-case key tokens to treat as not relevant.\n",
            "\n",
            "    Returns:\n",
            "        bool: whether the whole key or one of its pieces is in the set.\n",
            "    \"\"\"\n",
            "    lowered = string.lower()\n",
            "    if ':' in lowered:\n",
            "        pieces = lowered.split(':')\n",
            "    elif '_' in lowered:\n",
            "        pieces = lowered.split('_')\n",
            "    else:\n",
            "        pieces = []\n",
            "    return lowered in not_relevant_keys or any(\n",
            "        piece in not_relevant_keys for piece in pieces)\n",
            "\n",
            "\n",
            "# Tag-key tokens passed to has_not_relevant_key(); when a pair's anchor and\n",
            "# positive indices coincide, tags whose key matches are dropped from the\n",
            "# positive. Judging by the entries these are mostly import/source prefixes,\n",
            "# references, names and contact details.\n",
            "# Entries are lower-cased because keys are lower-cased before the lookup;\n",
            "# mixed-case entries (e.g. 'NaPTANAreaCode') would otherwise never match.\n",
            "not_relevant_keys = {key.lower() for key in [\n",
            "    'addr', 'comment', 'contact', 'source', 'name', 'tiger',\n",
            "    'ref', 'created_by', 'nysgissam', 'wikidata', 'operator',\n",
            "    'lacounty', 'osak', 'source_ref', 'nhd', 'admin_level',\n",
            "    'wikipedia', 'yh', 'gnis', 'at_bev', 'mml', 'postal_code',\n",
            "    'raba', 'nycdoitt', 'maaamet', 'pmfsefin', 'old_name', 'official_name',\n",
            "    'chicago', 'linz', 'it', 'destination', 'date', 'lojic',\n",
            "    'geobase', 'mapillary', 'clc', 'ssr', 'unsigned_ref', 'naptan',\n",
            "    'mvdgis', 'linz2osm', 'gns', 'note', 'metcouncil', 'url',\n",
            "    'route_ref', 'gtfs', 'uic', 'attribution', 'ts',\n",
            "    'id', 'survey', 'stif', 'network', 'location',\n",
            "    'tmc', 'fixme', 'wabe', 'object', 'description', 'check_date',\n",
            "    'tec', 'qroti', 'dcgis', 'website', 'short_name', 'image',\n",
            "    'NaPTANAreaCode', 'vrs', 'cxx', 'in', 'code', 'massgis', 'original_osm_id',\n",
            "    'bbr', 'shape', 'lnam', 'redwood_city_ca', 'email', 'KSJ2',\n",
            "    'canvec', 'uuid', 'sorting_name', 'phone', 'inegi', 'ine',\n",
            "    'brand', 'cesena', 'mobile', 'strazakosm', 'ipp',\n",
            "    'fhrs', 'alt_name', 'old_street', 'unocha',\n",
            "    'wikimedia_commons', 'brn', 'fid', 'notas',\n",
            "    'fax', 'sangis', 'okato', 'nhd-shp', 'surrey', 'statscan',\n",
            "    'panoramax',\n",
            "]}\n",
            "\n",
            "# Number of sampling passes over every tagset group when building pairs.\n",
            "# Output file names carry a '<N>_rep' suffix that should match this value.\n",
            "repeat_no = 20"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [],
         "source": [
            "import pickle\n",
            "import gc\n",
            "\n",
            "# Load the grouped tagsets. The pair-building cell treats this as a mapping\n",
            "# from a group key to an indexable sequence of tag dicts -- TODO confirm\n",
            "# against whatever produced the pickle.\n",
            "# The cyclic GC is paused while unpickling, presumably to avoid repeated\n",
            "# collection passes over the many objects being created; try/finally makes\n",
            "# sure it is re-enabled even if the file is missing or corrupt.\n",
            "# NOTE: pickle.load can execute arbitrary code -- only load trusted files.\n",
            "gc.disable()\n",
            "try:\n",
            "    with open('osm_planet_unique_relevant_tagsets.pkl', 'rb') as file:\n",
            "        unique_relevant_tagsets = pickle.load(file)\n",
            "        print(\"Loaded file \" + file.name)\n",
            "finally:\n",
            "    gc.enable()"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [],
         "source": [
            "# Grouping keys of the loaded tagsets, plus a quick size report.\n",
            "keys = list(unique_relevant_tagsets.keys())\n",
            "print(f\"Unique relevant tagsets: {len(unique_relevant_tagsets)}\")"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": 50,
         "metadata": {},
         "outputs": [],
         "source": [
            "import random\n",
            "\n",
            "# Seed the RNG so that re-running the notebook regenerates the same pairs.\n",
            "random_seed = 0\n",
            "random.seed(random_seed)\n",
            "\n",
            "# Build (anchor, positive) string pairs whose two members are drawn from the\n",
            "# same group of tagsets.\n",
            "dataset_dict = {'anchor': [], 'positive': []}\n",
            "for pass_idx in range(repeat_no):\n",
            "    for tagsets in unique_relevant_tagsets.values():\n",
            "        group_size = len(tagsets)\n",
            "        # A group joins pass k only if it has at least k tagsets, so larger groups\n",
            "        # yield more pairs. Empty groups are skipped because\n",
            "        # random.randrange(0, 0) raises ValueError.\n",
            "        # NOTE(review): with `<` a group of size n is sampled min(n + 1, repeat_no)\n",
            "        # times; switch to `<=` if exactly n samples per group were intended.\n",
            "        if group_size == 0 or group_size < pass_idx:\n",
            "            continue\n",
            "        idx_anchor = random.randrange(0, group_size)\n",
            "        idx_positive = random.randrange(0, group_size)\n",
            "        anchor_tags = tagsets[idx_anchor]\n",
            "        if idx_anchor != idx_positive:\n",
            "            positive_tags = tagsets[idx_positive]\n",
            "        else:\n",
            "            # Same tagset drawn twice: use a copy without not-relevant keys as the\n",
            "            # positive, so the pair differs whenever such keys are present.\n",
            "            positive_tags = {key: val for key, val in anchor_tags.items()\n",
            "                             if not has_not_relevant_key(key, not_relevant_keys)}\n",
            "        dataset_dict['anchor'].append(tag_dict_to_str(anchor_tags))\n",
            "        dataset_dict['positive'].append(tag_dict_to_str(positive_tags))"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [],
         "source": [
            "# Sanity check: the two columns must stay aligned row by row.\n",
            "assert len(dataset_dict['anchor']) == len(dataset_dict['positive'])\n",
            "print(len(dataset_dict['anchor']))\n",
            "print(len(dataset_dict['positive']))\n",
            "\n",
            "# Cache the raw pair lists; the '<N>_rep' tag in the name comes from repeat_no.\n",
            "# GC is paused while pickling (presumably for speed on a very large object)\n",
            "# and re-enabled in `finally` even if writing fails.\n",
            "gc.disable()\n",
            "try:\n",
            "    with open(f'relevant_tags_pairs_dataset_osm_planet_{repeat_no}_rep.pkl', 'wb') as handle:\n",
            "        pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
            "finally:\n",
            "    gc.enable()"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [],
         "source": [
            "# Wrap the column dict (equal-length 'anchor' / 'positive' lists) in a\n",
            "# Hugging Face `datasets.Dataset` so it can be sharded and exported.\n",
            "from datasets import Dataset\n",
            "\n",
            "dataset = Dataset.from_dict(dataset_dict)"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [],
         "source": [
            "import os\n",
            "\n",
            "# Split the dataset into contiguous row ranges and write each one as a\n",
            "# separate Parquet file. The '<N>_rep' tag is derived from repeat_no so the\n",
            "# paths stay in sync with the sampling setting.\n",
            "num_shards = 128\n",
            "output_dir = f\"relevant_tags_pairs_dataset_parquet_sharded_{repeat_no}_rep\"\n",
            "output_path_template = (output_dir\n",
            "                        + f\"/relevant_tags_pairs_dataset_osm_planet_{repeat_no}_rep\"\n",
            "                        + \"_{index:05d}.parquet\")\n",
            "# Create the target directory up front in case the writer does not create\n",
            "# missing parent directories itself.\n",
            "os.makedirs(output_dir, exist_ok=True)\n",
            "for index in range(num_shards):\n",
            "    shard = dataset.shard(index=index, num_shards=num_shards, contiguous=True)\n",
            "    shard.to_parquet(output_path_template.format(index=index))"
         ]
      }
   ],
   "metadata": {
      "kernelspec": {
         "display_name": "sd_map_encoding",
         "language": "python",
         "name": "python3"
      },
      "language_info": {
         "codemirror_mode": {
            "name": "ipython",
            "version": 3
         },
         "file_extension": ".py",
         "mimetype": "text/x-python",
         "name": "python",
         "nbconvert_exporter": "python",
         "pygments_lexer": "ipython3",
         "version": "3.8.0"
      }
   },
   "nbformat": 4,
   "nbformat_minor": 2
}
