{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "import osmnx as ox\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.colors import to_hex\n",
    "from places import generate_places\n",
    "from publib import set_style, fix_style\n",
    "\n",
    "set_style(['article'])  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "network_type = \"all\"\n",
    "place_name1 = \"huangpu\" # huangpu \n",
    "place_name2 = \"pasadena\" # pasadena \n",
    "k=16\n",
    "nchunk=10\n",
    "\n",
    "data_dir1 = f\"../data/{place_name1}/\"\n",
    "data_dir2 = f\"../data/{place_name2}/\"\n",
    "places1 = generate_places(place_name1)\n",
    "places2 = generate_places(place_name2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !!!! Note that you need to recreate the dataset, \n",
    "# !!!! since the saved city network is not up-to-date!\n",
    "G1 = ox.graph_from_place(places1, network_type=network_type, retain_all=False)\n",
    "G2 = ox.graph_from_place(places2, network_type=network_type, retain_all=False)\n",
    "len(list(G1.nodes)), len(list(G2.nodes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels1 = torch.load(f'{data_dir1}{nchunk}-chunk_{k}-hop_node_labels.pt').numpy()\n",
    "labels2 = torch.load(f'{data_dir2}{nchunk}-chunk_{k}-hop_node_labels.pt').numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_colors(labels):\n",
    "    cmap = sns.color_palette(\"rocket\", as_cmap=True)\n",
    "    palette = cmap(np.linspace(0.1, 0.85, 10))\n",
    "    palette_hex = [to_hex(color) for color in palette]\n",
    "    color_array = [palette_hex[index] for index in labels]\n",
    "    return color_array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_font = 23\n",
    "node_size = 20\n",
    "fig, ax = plt.subplots(1,2, figsize=(24, 7), gridspec_kw={'width_ratios': [1, 1]})\n",
    "ox.plot_graph(\n",
    "    G1, ax=ax[0], node_size=node_size+10, node_color=get_colors(labels1), node_alpha=0.5,\n",
    "    edge_linewidth=1, edge_color='black', edge_alpha=0.5, show=False, \n",
    ")\n",
    "ox.plot_graph(\n",
    "    G2, ax=ax[1], node_size=node_size, node_color=get_colors(labels2), node_alpha=0.5,\n",
    "    edge_linewidth=1, edge_color='black', edge_alpha=0.5, show=False, \n",
    ")\n",
    "\n",
    "ax[0].set_title(f\"(a) Huangpu, Shanghai\", fontsize=my_font)\n",
    "ax[1].set_title(f\"(b) Pasadena, Los Angeles\", fontsize=my_font)\n",
    "\n",
    "\n",
    "plt.tight_layout(pad=10.0)\n",
    "fix_style('article')\n",
    "# plt.savefig(f\"../Figures/label_visual_map_{place_name1}_and_{place_name2}_{k}-hop_{nchunk}-chunk.jpg\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d_labels = {\n",
    "    \"paris\":\"Paris\", \"shanghai\":\"Shanghai\", \"la\":\"L.A.\", \"london\":\"London\",\n",
    "}\n",
    "\n",
    "d_color = {\n",
    "    \"paris\":\"#c5373e\", # Red, \"#9c251c\", rgb(197, 55, 62)\n",
    "    \"shanghai\":   \"#006eae\", # Blue, \"#00498d\", rgb(0, 110, 174)\n",
    "    \"la\":  \"#439130\", # Green, \"#1c6e2b\",   rgb(67, 145, 48)\n",
    "    \"london\":   \"#6e788e\", # Grey # \"#43536a\",   rgb(110, 120, 142)\n",
    "}\n",
    "\n",
    "d_marker = {\"paris\":'x', \"shanghai\":'o', \"la\":'^', \"london\":\"s\"}\n",
    "\n",
    "\n",
    "eccentricities = {}\n",
    "for dataset_name in [\"paris\", \"shanghai\", \"la\", \"london\"]:\n",
    "    dataset_path = f'../data/{dataset_name}'\n",
    "    eccentricities[dataset_name] = torch.load(f'{dataset_path}/{k}-hop_eccentricities.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_bins = 100\n",
    "fontsize = 26\n",
    "\n",
    "# Create the figure and axis\n",
    "fig, ax = plt.subplots(1, 1, figsize=(11, 8))\n",
    "\n",
    "# Plot histograms for each dataset on the same subplot\n",
    "for dataset_name in [\"paris\", \"shanghai\", \"la\", \"london\"]:\n",
    "    ax.hist(\n",
    "        x=eccentricities[dataset_name]/1000,\n",
    "        bins=num_bins,\n",
    "        range=(0, 10),  # Limit the x-axis to focus on the main distribution\n",
    "\n",
    "        density=True,\n",
    "        alpha=0.7,\n",
    "        label=d_labels[dataset_name],\n",
    "        color=d_color[dataset_name],\n",
    "        edgecolor='black'\n",
    "    )\n",
    "\n",
    "\n",
    "# Set labels and title\n",
    "ax.set_xlabel(\"{k}-hop eccentricity (km)\", fontsize=fontsize)\n",
    "ax.set_ylabel(\"Density\", fontsize=fontsize)\n",
    "ax.set_title(\"Distributions of node eccentricity estimations\", fontsize=fontsize + 2)\n",
    "ax.tick_params(axis='both', labelsize=fontsize - 3)\n",
    "# ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))\n",
    "\n",
    "# Add gridlines for readability\n",
    "ax.grid(visible=True, linestyle='--', alpha=0.6)\n",
    "\n",
    "# Add legend\n",
    "ax.legend(fontsize=fontsize - 5, frameon=True)\n",
    "\n",
    "# Apply custom styling and adjust layout\n",
    "plt.tight_layout()\n",
    "fix_style('article')\n",
    "plt.savefig(f\"../Figures/label_{k}-hop_eccentricity_distributions.jpg\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
