{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combining Histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-06-07 11:22:01.140993: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-06-07 11:22:01.676636: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.4/lib64:\n",
      "2023-06-07 11:22:01.676700: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.4/lib64:\n",
      "2023-06-07 11:22:01.676704: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import ticker\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "from collections import defaultdict, Counter\n",
    "from trajdata import AgentType\n",
    "from trajdata_analysis.analysis.radian_formatter import Multiple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "AV_hue_order = [\n",
    "    \"nusc_trainval\",\n",
    "    # \"nusc_mini\",\n",
    "    \"waymo_train\",\n",
    "    \"interaction_single\",\n",
    "    \"interaction_multi\",\n",
    "    # \"lyft_sample\",\n",
    "    # \"lyft_train\",\n",
    "    \"lyft_train_full\",\n",
    "    \"nuplan_mini\",\n",
    "]\n",
    "\n",
    "peds_hue_order = [\n",
    "    \"sdd\",\n",
    "    \"eupeds_ETH\",\n",
    "    \"eupeds_UCY\",\n",
    "]\n",
    "\n",
    "combined_hue_order = AV_hue_order + peds_hue_order\n",
    "\n",
    "# Set this to set the hue order for all following plots.\n",
    "all_hue_order = combined_hue_order\n",
    "\n",
    "label_map = {\n",
    "    \"nusc_trainval\": \"nuSc\",\n",
    "    # \"nusc_mini\": \"nuScenes Mini\",\n",
    "    \"waymo_train\": \"WOMD\",\n",
    "    \"interaction_single\": \"INT S\",\n",
    "    \"interaction_multi\": \"INT M\",\n",
    "    # \"lyft_sample\": \"Lyft L5 (s)\",\n",
    "    # \"lyft_train\": \"Lyft\",\n",
    "    \"lyft_train_full\": \"Lyft\",\n",
    "    \"nuplan_mini\": \"nuPlan\",\n",
    "\n",
    "    \"sdd\": \"SDD\",\n",
    "    \"eupeds_ETH\": \"ETH\",\n",
    "    \"eupeds_UCY\": \"UCY\",\n",
    "}\n",
    "\n",
    "def fix_legend_labels(ax):\n",
    "    if isinstance(ax, plt.Axes):\n",
    "        for text in ax.get_legend().get_texts():\n",
    "            # the first result will be all handles, i.e. the dots in the legend\n",
    "            # the second result will be all legend text\n",
    "            text.set_text(label_map[text.get_text()])\n",
    "\n",
    "        ax.get_legend().set_title(\"Dataset\")\n",
    "    else:\n",
    "        for text in ax.legend.get_texts():\n",
    "            # the first result will be all handles, i.e. the dots in the legend\n",
    "            # the second result will be all legend text\n",
    "            text.set_text(label_map[text.get_text()])\n",
    "\n",
    "        ax.legend.set_title(\"Dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "const_bins = dict()\n",
    "const_bins[\"speed\"] = np.linspace(0, 60, 101)\n",
    "const_bins[\"acc\"] = np.linspace(0, 100, 101)\n",
    "const_bins[\"jerk\"] = np.linspace(-100, 100, 101)\n",
    "n_headings = 57\n",
    "const_bins[\"heading\"] = np.linspace(-np.pi - np.pi / (n_headings - 1), np.pi - np.pi / (n_headings - 1), n_headings)\n",
    "const_bins[\"ae_dist\"] = np.linspace(0, 200, 101)\n",
    "const_bins[\"length\"] = np.arange(0, 20.5, 0.5)\n",
    "const_bins[\"max_dh\"] = np.linspace(0, 4 * np.pi, 101)\n",
    "const_bins[\"rel_dh\"] = np.linspace(-4 * np.pi, 4 * np.pi, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "Path(\"../plots/combined\").mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pointwise Analyses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = Path(\"../plots\")\n",
    "\n",
    "histograms = defaultdict(dict)\n",
    "bins = defaultdict(dict)\n",
    "\n",
    "for hist_file in results_dir.glob(\"**/*.npz\"):\n",
    "    env_name = hist_file.parent.stem\n",
    "    \n",
    "    if env_name in {\"lyft_train\", \"nusc_mini\", \"lyft_sample\"}:\n",
    "        continue\n",
    "    \n",
    "    hist_bins_data = np.load(hist_file)\n",
    "    \n",
    "    for name in [x for x in hist_bins_data.files if x.endswith(\"_bins\")]:\n",
    "        short_key = name[:-len(\"_bins\")]\n",
    "        if short_key not in bins[env_name]:\n",
    "            bins[env_name][short_key] = hist_bins_data[name]\n",
    "\n",
    "    for name in [x for x in hist_bins_data.files if x.endswith(\"_hist\")]:\n",
    "        short_key = name[:-len(\"_hist\")]\n",
    "        if short_key not in histograms[env_name]:\n",
    "            histograms[env_name][short_key] = hist_bins_data[name]\n",
    "        else:\n",
    "            histograms[env_name][short_key] += hist_bins_data[name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def associate_closest(arr: np.ndarray, reference_arr: np.ndarray) -> np.ndarray:\n",
    "    indices = np.searchsorted(reference_arr, arr, side='right') - 1\n",
    "    return reference_arr[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_hists(histograms, bins, value_name: str, combine_bins: bool = False) -> pd.DataFrame:\n",
    "    env_dfs = list()\n",
    "    for env_name in histograms:\n",
    "        if value_name in histograms[env_name]:\n",
    "            weights = histograms[env_name][value_name]\n",
    "            if combine_bins:\n",
    "                values = associate_closest(bins[env_name][value_name][:-1], const_bins[\"_\".join(value_name.split(\"_\")[:-1])])\n",
    "            else:\n",
    "                values = bins[env_name][value_name][:-1]\n",
    "            \n",
    "            env_df = pd.DataFrame(\n",
    "                data={\n",
    "                    \"value\": values,\n",
    "                    \"weight\": np.stack(weights, axis=0).sum(axis=0) if isinstance(weights, list) else weights\n",
    "                }\n",
    "            )\n",
    "            \n",
    "            if combine_bins:\n",
    "                env_df = env_df.groupby(\"value\")[\"weight\"].sum().reset_index()\n",
    "            \n",
    "            env_df[\"env_name\"] = env_name\n",
    "            env_dfs.append(env_df)\n",
    "\n",
    "    if len(env_dfs) == 0:\n",
    "        return None\n",
    "\n",
    "    return pd.concat(env_dfs, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "normal_keys_labels_dict = {\n",
    "    \"speed\": (r\"Speed $(m/s)$\", True),\n",
    "    \"acc\": (r\"Acceleration $(m/s^2)$\", True),\n",
    "    \"jerk\": (r\"Jerk $(m/s^3)$\", True),\n",
    "    \"ae_dist\": (\"Agent-Ego Distance (m)\", False),\n",
    "    \"length\": (\"Agent Observation Length (s)\", True)\n",
    "}\n",
    "\n",
    "heading_keys_labels_dict = {\n",
    "    \"heading\": (\"Heading (radians)\", True),\n",
    "    \"max_dh\": (\"Max. Heading Change (radians)\", True),\n",
    "    \"rel_dh\": (\"Rel. Heading Change (radians)\", True)\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Trying Heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.ticker as ticker\n",
    "from matplotlib.colors import LogNorm\n",
    "import copy\n",
    "from trajdata_analysis.analysis.radian_formatter import multiple_formatter\n",
    "\n",
    "def vis_heatmap(key: str, label:str, norm_data_df: pd.DataFrame, ax: plt.Axes, hide_ylabel: bool = False, hide_cbar_label: bool = False, radianx: bool = False) -> None:\n",
    "    my_cmap = copy.deepcopy(sns.color_palette('rocket', as_cmap=True))\n",
    "    my_cmap.set_bad(my_cmap.colors[0])\n",
    "    \n",
    "    if key == \"acc\":\n",
    "        # Matching to nearest multiples of g = 9.81\n",
    "        xticks = xticklabels = [0, 20, 39, 59, 78, 98]\n",
    "        num_ticks = len(xticks)\n",
    "    elif radianx:\n",
    "        num_ticks = 5\n",
    "        # the index of the position of yticks\n",
    "        xticks = np.linspace(0, len(const_bins[key]) - 1, num_ticks, dtype=int)\n",
    "        # the content of labels of these yticks\n",
    "        xticklabels = [const_bins[key][idx] for idx in xticks]\n",
    "    else:\n",
    "        num_ticks = 5\n",
    "        # the index of the position of yticks\n",
    "        xticks = np.linspace(0, len(const_bins[key]) - 1, num_ticks, dtype=int)\n",
    "        # the content of labels of these yticks\n",
    "        xticklabels = [const_bins[key][idx] for idx in xticks]\n",
    "    \n",
    "    sns.heatmap(\n",
    "        norm_data_df,\n",
    "        ax=ax,\n",
    "        norm=LogNorm(),\n",
    "        cmap=my_cmap,\n",
    "        # xticklabels=xticklabels,\n",
    "        # vmin=0.0,\n",
    "        # vmax=1.0,\n",
    "        cbar_kws={'label': 'Proportion'} if not hide_cbar_label else None,\n",
    "    )\n",
    "    \n",
    "    ax.hlines(range(1, norm_data_df.shape[0]), *ax.get_xlim(), colors=\"w\")\n",
    "    \n",
    "    ax.set_xlabel(label)\n",
    "    if not hide_ylabel:\n",
    "        ax.set_ylabel(\"Dataset\")\n",
    "    else:\n",
    "        ax.set_ylabel(None)\n",
    "    \n",
    "    ax.set_yticklabels([label_map[x.get_text()] for x in ax.get_yticklabels()])\n",
    "    if key == \"acc\":\n",
    "        ax.set_xticks(xticks, labels=[f\"{int(np.round(x/9.81, 0))}g\" for x in xticklabels])\n",
    "    elif radianx:\n",
    "        format_fn = multiple_formatter(denominator=4)\n",
    "        ax.set_xticks(xticks, labels=[format_fn(x, -1) for x in xticklabels])\n",
    "    else:\n",
    "        ax.set_xticks(xticks, labels=[int(x) for x in xticklabels])\n",
    "        \n",
    "    ax.tick_params(axis='x', rotation=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key, (label, logscale) in normal_keys_labels_dict.items():\n",
    "    agent_type_hists = dict()\n",
    "    agent_types_list = [x.name for x in AgentType]\n",
    "    for agent_type in agent_types_list:\n",
    "        agent_type_hists[agent_type] = combine_hists(histograms, bins, f\"{key}_{agent_type}\", combine_bins=key==\"length\")\n",
    "        \n",
    "        if agent_type_hists[agent_type] is None:\n",
    "            del agent_type_hists[agent_type]\n",
    "            continue\n",
    "    \n",
    "    if key == \"ae_dist\":\n",
    "        fig, axes = plt.subplots(ncols=len(agent_type_hists), figsize=(len(agent_types_list)*3, 2))\n",
    "    else:\n",
    "        fig, axes = plt.subplots(ncols=len(agent_type_hists), figsize=(len(agent_types_list)*2.65, 2.75))\n",
    "    agent_type_hist: pd.DataFrame\n",
    "    for i, (agent_type, agent_type_hist) in enumerate(agent_type_hists.items()):        \n",
    "        agent_data_df = agent_type_hist.pivot(index=\"env_name\", columns=\"value\", values=\"weight\")\n",
    "        norm_data_df = agent_data_df.div(agent_data_df.sum(axis=1), axis=0)\n",
    "        \n",
    "        norm_data_df.index = pd.CategoricalIndex(norm_data_df.index, categories=combined_hue_order)\n",
    "        norm_data_df.sort_index(inplace=True)\n",
    "        \n",
    "        vis_heatmap(key, label, norm_data_df, ax=axes[i], hide_ylabel=i > 0, hide_cbar_label=i < len(agent_types_list) - 1)\n",
    "        axes[i].set_title(agent_type.capitalize() + \"s\")\n",
    "    \n",
    "    # fig.subplots_adjust(wspace=0.2, hspace=0)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f\"../plots/combined/{key}_heatmap.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key, (label, logscale) in heading_keys_labels_dict.items():\n",
    "    agent_type_hists = dict()\n",
    "    agent_types_list = [x.name for x in AgentType]\n",
    "    for agent_type in agent_types_list:\n",
    "        agent_type_hists[agent_type] = combine_hists(histograms, bins, f\"{key}_{agent_type}\")\n",
    "        \n",
    "        if agent_type_hists[agent_type] is None:\n",
    "            del agent_type_hists[agent_type]\n",
    "            continue\n",
    "    \n",
    "    fig, axes = plt.subplots(ncols=len(agent_type_hists), figsize=(len(agent_types_list)*2.65, 2.75))\n",
    "    agent_type_hist: pd.DataFrame\n",
    "    for i, (agent_type, agent_type_hist) in enumerate(agent_type_hists.items()):\n",
    "        agent_data_df = agent_type_hist.pivot(index=\"env_name\", columns=\"value\", values=\"weight\")\n",
    "        norm_data_df = agent_data_df.div(agent_data_df.sum(axis=1), axis=0)\n",
    "        \n",
    "        norm_data_df.index = pd.CategoricalIndex(norm_data_df.index, categories=combined_hue_order)\n",
    "        norm_data_df.sort_index(inplace=True)\n",
    "        \n",
    "        vis_heatmap(key, label, norm_data_df, ax=axes[i], hide_ylabel=i > 0, hide_cbar_label=i < len(agent_types_list) - 1, radianx=True)\n",
    "        axes[i].set_title(agent_type.capitalize() + \"s\")\n",
    "    \n",
    "    # fig.subplots_adjust(wspace=0.2, hspace=0)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(f\"../plots/combined/{key}_heatmap.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/3580281287.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n"
     ]
    }
   ],
   "source": [
    "for key, (label, logscale) in heading_keys_labels_dict.items():\n",
    "    agent_type_hists = dict()\n",
    "    for agent_type in [x.name for x in AgentType]:\n",
    "        agent_type_hists[agent_type] = combine_hists(histograms, bins, f\"{key}_{agent_type}\")\n",
    "        \n",
    "        if agent_type_hists[agent_type] is None:\n",
    "            del agent_type_hists[agent_type]\n",
    "            continue\n",
    "    \n",
    "    fig, axes = plt.subplots(ncols=len(agent_type_hists), sharey=True, figsize=(15, 3))\n",
    "    for i, (agent_type, agent_type_hist) in enumerate(agent_type_hists.items()):\n",
    "        sns.histplot(\n",
    "            data=agent_type_hist,\n",
    "            x=\"value\",\n",
    "            weights=\"weight\",\n",
    "            hue=\"env_name\",\n",
    "            hue_order=all_hue_order,\n",
    "            element=\"step\",\n",
    "            fill=False,\n",
    "            ax=axes[i],\n",
    "            bins=const_bins[key],\n",
    "            stat=\"proportion\",\n",
    "            palette=sns.color_palette(),\n",
    "            common_norm=False\n",
    "        )\n",
    "        if i < len(agent_type_hists) - 1:\n",
    "            axes[i].get_legend().remove()\n",
    "        else:\n",
    "            fix_legend_labels(axes[i])\n",
    "            sns.move_legend(axes[i], \"upper left\", bbox_to_anchor=(1, 1))\n",
    "            \n",
    "        axes[i].set_xlabel(label)\n",
    "        if logscale:\n",
    "            axes[i].set_yscale(\"log\")\n",
    "        else:\n",
    "            axes[i].yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))\n",
    "            \n",
    "        major = Multiple(denominator=2)\n",
    "        axes[i].xaxis.set_major_formatter(major.formatter())\n",
    "        axes[i].set_title(f\"{agent_type.capitalize()}s\")\n",
    "        \n",
    "    fig.savefig(f\"../plots/combined/{key}.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/2552665191.py:12: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  sns.histplot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n"
     ]
    }
   ],
   "source": [
    "for key, (label, logscale) in normal_keys_labels_dict.items():\n",
    "    # Merge the per-dataset histograms of this quantity, keeping only agent types that have data.\n",
    "    agent_type_hists = dict()\n",
    "    for agent_type in [x.name for x in AgentType]:\n",
    "        combined_hist = combine_hists(histograms, bins, f\"{key}_{agent_type}\")\n",
    "        if combined_hist is not None:\n",
    "            agent_type_hists[agent_type] = combined_hist\n",
    "\n",
    "    if not agent_type_hists:\n",
    "        # Nothing to plot for this quantity; plt.subplots(ncols=0) would raise.\n",
    "        continue\n",
    "\n",
    "    # One colour per hue level; the full default palette triggers seaborn's \"more values than needed\" warning.\n",
    "    palette = sns.color_palette(n_colors=len(all_hue_order))\n",
    "\n",
    "    # squeeze=False keeps `axes` 2-D even when only one agent type has data, so axes[0, i] always works.\n",
    "    fig, axes = plt.subplots(ncols=len(agent_type_hists), sharey=True, figsize=(15, 3), squeeze=False)\n",
    "    for i, (agent_type, agent_type_hist) in enumerate(agent_type_hists.items()):\n",
    "        ax = axes[0, i]\n",
    "        sns.histplot(\n",
    "            data=agent_type_hist,\n",
    "            x=\"value\",\n",
    "            weights=\"weight\",\n",
    "            hue=\"env_name\",\n",
    "            hue_order=all_hue_order,\n",
    "            element=\"step\",\n",
    "            fill=False,\n",
    "            ax=ax,\n",
    "            bins=const_bins[key],\n",
    "            stat=\"proportion\",\n",
    "            palette=palette,\n",
    "            common_norm=False\n",
    "        )\n",
    "        # Only the right-most panel keeps its legend, moved outside the axes.\n",
    "        if i < len(agent_type_hists) - 1:\n",
    "            ax.get_legend().remove()\n",
    "        else:\n",
    "            fix_legend_labels(ax)\n",
    "            sns.move_legend(ax, \"upper left\", bbox_to_anchor=(1, 1))\n",
    "        ax.set_xlabel(label)\n",
    "        if logscale:\n",
    "            ax.set_yscale(\"log\")\n",
    "        else:\n",
    "            ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))\n",
    "        ax.set_title(f\"{agent_type.capitalize()}s\")\n",
    "\n",
    "    fig.savefig(f\"../plots/combined/{key}.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scene-wise Analyses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numbers\n",
    "\n",
    "def aggregate_histograms(raw_data):\n",
    "    \"\"\"Collapse one pickled histogram entry (possibly a list of per-scene parts) into one aggregate.\n",
    "\n",
    "    - ``np.ndarray`` / ``Counter``: returned unchanged (already aggregated).\n",
    "    - list of numbers or arrays: summed.\n",
    "    - list of ``Counter``s: counts merged.\n",
    "    - list of dicts whose values are arrays/numbers: summed per key.\n",
    "\n",
    "    Returns None for an empty list or any other input, so callers can tell the entry\n",
    "    could not be aggregated. The input arrays are never modified.\n",
    "    \"\"\"\n",
    "    if isinstance(raw_data, (np.ndarray, Counter)):\n",
    "        return raw_data\n",
    "\n",
    "    if not isinstance(raw_data, list) or len(raw_data) == 0:\n",
    "        return None\n",
    "\n",
    "    first = raw_data[0]\n",
    "    if isinstance(first, (numbers.Number, np.ndarray)):\n",
    "        return sum(raw_data)\n",
    "\n",
    "    # Counter subclasses dict, so it must be checked before the generic dict case.\n",
    "    if isinstance(first, Counter):\n",
    "        return sum(raw_data, Counter())\n",
    "\n",
    "    if isinstance(first, dict) and first:\n",
    "        first_value = next(iter(first.values()))\n",
    "        if isinstance(first_value, (np.ndarray, numbers.Number)):\n",
    "            accumulator = dict()\n",
    "            for d in raw_data:\n",
    "                for key, arr in d.items():\n",
    "                    if key in accumulator:\n",
    "                        # Out-of-place add: `+=` would write into the first input array\n",
    "                        # (and fail when an int array meets float data).\n",
    "                        accumulator[key] = accumulator[key] + arr\n",
    "                    else:\n",
    "                        accumulator[key] = arr\n",
    "            return accumulator\n",
    "\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = Path(\"../plots\")\n",
    "\n",
    "# Environment directories whose pickles are left out of the combined analysis.\n",
    "SKIPPED_ENVS = {\"lyft_train\", \"lyft_sample\", \"nusc_mini\"}\n",
    "\n",
    "# Hard-coded total position count for waymo_train, replacing whatever its pickle provides.\n",
    "# NOTE(review): the source of this number is not recorded here - document how it was obtained.\n",
    "WAYMO_TRAIN_NUM_POSITIONS = 1614728198\n",
    "\n",
    "scene_histograms = defaultdict(dict)\n",
    "scene_bins = defaultdict(dict)\n",
    "\n",
    "for hist_file in results_dir.glob(\"**/*.pkl\"):\n",
    "    env_name = hist_file.parent.stem\n",
    "    if env_name in SKIPPED_ENVS:\n",
    "        continue\n",
    "\n",
    "    # pickle.load can execute arbitrary code: only point results_dir at trusted files.\n",
    "    with open(hist_file, \"rb\") as f:\n",
    "        hist_bins_data = pickle.load(f)\n",
    "\n",
    "    for name in [x for x in hist_bins_data.keys() if x.endswith(\"_bins\")]:\n",
    "        scene_bins[env_name][name[:-len(\"_bins\")]] = hist_bins_data[name]\n",
    "\n",
    "    for name in [x for x in hist_bins_data.keys() if x.endswith(\"_hist\")]:\n",
    "        scene_histograms[env_name][name[:-len(\"_hist\")]] = aggregate_histograms(hist_bins_data[name])\n",
    "\n",
    "    # Remaining entries (neither bin edges nor histograms) are stored under their own name.\n",
    "    for name in [x for x in hist_bins_data.keys() if not (x.endswith(\"_hist\") or x.endswith(\"_bins\"))]:\n",
    "        scene_histograms[env_name][name] = aggregate_histograms(hist_bins_data[name])\n",
    "\n",
    "# Applied once, after loading, so it overrides the value read from the waymo_train pickle.\n",
    "if \"waymo_train\" in scene_histograms:\n",
    "    scene_histograms[\"waymo_train\"][\"num_positions\"] = WAYMO_TRAIN_NUM_POSITIONS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_key_df(scene_bins, scene_histograms, key):\n",
    "    \"\"\"Flatten the `key` histogram of every environment into one long-form DataFrame.\n",
    "\n",
    "    One row per bin with columns ``env_name``, ``value`` (the bin's left edge) and ``weight``\n",
    "    (the bin's count). When an environment stores a dict of histograms for `key`, every entry\n",
    "    contributes its own rows and an extra ``agent_type`` column holds the dict key.\n",
    "    \"\"\"\n",
    "    columns = defaultdict(list)\n",
    "    for env_name, env_bins in scene_bins.items():\n",
    "        edges = env_bins[key]\n",
    "        n_bins = len(edges) - 1\n",
    "        env_hist = scene_histograms[env_name][key]\n",
    "\n",
    "        if not isinstance(env_hist, dict):\n",
    "            columns[\"env_name\"].extend([env_name] * n_bins)\n",
    "            columns[\"value\"].extend(edges[:-1].tolist())\n",
    "            columns[\"weight\"].extend(env_hist.tolist())\n",
    "            continue\n",
    "\n",
    "        for agent_type, counts in env_hist.items():\n",
    "            columns[\"env_name\"].extend([env_name] * n_bins)\n",
    "            columns[\"value\"].extend(edges[:-1].tolist())\n",
    "            columns[\"weight\"].extend(counts.tolist())\n",
    "            columns[\"agent_type\"].extend([agent_type] * n_bins)\n",
    "\n",
    "    return pd.DataFrame(columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scene-level quantities: key -> (axis label, log-scale flag used when plotting the histogram).\n",
    "scene_keys_labels_dict = dict(\n",
    "    sim_agents=(\"Simultaneous Agents\", False),\n",
    "    max_sim_agents=(\"Max. Simultaneous Agents\", False),\n",
    "    path_efficiency=(\"Path Efficiency (%)\", True),\n",
    "    agent_density=(r\"Agent Density (agent/$m^2$)\", True),\n",
    "    max_acc=(r\"Max. Acceleration $(m/s^2)$\", True),\n",
    ")\n",
    "\n",
    "# Shared bin edges per quantity, so every dataset is histogrammed on the same grid.\n",
    "scene_const_bins = {\n",
    "    # 0, 5, ..., 250: fifty bins of width 5.\n",
    "    \"sim_agents\": np.linspace(0, 250, 51).astype(int),\n",
    "    \"max_sim_agents\": np.linspace(0, 250, 51).astype(int),\n",
    "    # 0.00, 0.01, ..., 1.01.\n",
    "    \"path_efficiency\": np.linspace(0, 1.01, 102),\n",
    "    # 21 log-spaced edges from 1e-4 to 1.\n",
    "    \"agent_density\": np.logspace(-4, 0, 21),\n",
    "    # 21 edges from 0 to 9.81 m/s^2 (i.e. up to 1 g).\n",
    "    \"max_acc\": np.linspace(0, 9.81, 21),\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scene-wise Heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.ticker as ticker\n",
    "from matplotlib.colors import LogNorm\n",
    "import copy\n",
    "\n",
    "def vis_heatmap(key: str, label: str, norm_data_df: pd.DataFrame, ax: plt.Axes, hide_ylabel: bool = False, hide_cbar_label: bool = False) -> None:\n",
    "    \"\"\"Draw a heatmap of per-dataset bin proportions for one scene quantity on a log colour scale.\n",
    "\n",
    "    Args:\n",
    "        key: Quantity name; selects its bin edges in ``scene_const_bins`` and the x-tick format.\n",
    "        label: x-axis label.\n",
    "        norm_data_df: One row per dataset (index = env name), one column per bin, values = proportions.\n",
    "        ax: Axes to draw on.\n",
    "        hide_ylabel: Clear the \"Dataset\" y-axis label (e.g. for panels right of the first).\n",
    "        hide_cbar_label: Omit the \"Proportion\" colour-bar label.\n",
    "    \"\"\"\n",
    "    my_cmap = copy.deepcopy(sns.color_palette('rocket', as_cmap=True))\n",
    "    # Masked cells (NaN, and presumably zero proportions that LogNorm cannot map) get the darkest colour.\n",
    "    my_cmap.set_bad(my_cmap.colors[0])\n",
    "\n",
    "    bin_edges = scene_const_bins[key]\n",
    "    num_ticks = 5 if key in (\"path_efficiency\", \"agent_density\") else 6\n",
    "    # Column positions of the x ticks, evenly spaced over the bin edges.\n",
    "    xticks = np.linspace(0, len(bin_edges) - 1, num_ticks, dtype=int)\n",
    "    # Values labelled at those positions (formatted per key after plotting).\n",
    "    if key == \"path_efficiency\":\n",
    "        xticklabels = np.linspace(0, 1, num_ticks)\n",
    "    elif key == \"agent_density\":\n",
    "        xticklabels = np.logspace(-4, 0, num_ticks)\n",
    "    else:\n",
    "        xticklabels = [bin_edges[idx] for idx in xticks]\n",
    "\n",
    "    sns.heatmap(\n",
    "        norm_data_df,\n",
    "        ax=ax,\n",
    "        norm=LogNorm(),\n",
    "        cmap=my_cmap,\n",
    "        cbar_kws={'label': 'Proportion'} if not hide_cbar_label else None,\n",
    "    )\n",
    "\n",
    "    # White separator lines between dataset rows.\n",
    "    ax.hlines(range(1, norm_data_df.shape[0]), *ax.get_xlim(), colors=\"w\")\n",
    "\n",
    "    ax.set_xlabel(label)\n",
    "    ax.set_ylabel(None if hide_ylabel else \"Dataset\")\n",
    "\n",
    "    # Replace raw env names with display names from the notebook-level label_map.\n",
    "    ax.set_yticklabels([label_map[x.get_text()] for x in ax.get_yticklabels()])\n",
    "    if key == \"path_efficiency\":\n",
    "        ax.set_xticks(xticks, labels=[f\"{x*100:.0f}\" for x in xticklabels])\n",
    "    elif key == \"agent_density\":\n",
    "        tick = ticker.ScalarFormatter(useOffset=False, useMathText=True)\n",
    "        tick.set_powerlimits((0,0))\n",
    "        ax.set_xticks(xticks, labels=[u\"${}$\".format(tick.format_data(x)) for x in xticklabels])\n",
    "    elif key == \"max_acc\":\n",
    "        # Express accelerations as multiples of g = 9.81 m/s^2.\n",
    "        ax.set_xticks(xticks, labels=[f\"{np.round(x/9.81, 2):.1f}g\" for x in xticklabels])\n",
    "    else:\n",
    "        ax.set_xticks(xticks, labels=[np.round(x, 2) for x in xticklabels])\n",
    "\n",
    "    ax.tick_params(axis='x', rotation=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _normalized_env_table(long_df):\n",
    "    \"\"\"Pivot (env_name, value, weight) rows into an env x bin table whose rows sum to 1, ordered by combined_hue_order.\"\"\"\n",
    "    table = long_df.pivot(index=\"env_name\", columns=\"value\", values=\"weight\")\n",
    "    table = table.div(table.sum(axis=1), axis=0)\n",
    "    table.index = pd.CategoricalIndex(table.index, categories=combined_hue_order)\n",
    "    return table.sort_index()\n",
    "\n",
    "\n",
    "for key, (label, logscale) in scene_keys_labels_dict.items():\n",
    "    data_df = get_key_df(scene_bins, scene_histograms, key)\n",
    "\n",
    "    if \"agent_type\" not in data_df.columns:\n",
    "        # A single heatmap for the whole quantity.\n",
    "        fig, ax = plt.subplots(figsize=(4, 2))\n",
    "        if key == \"max_acc\":\n",
    "            # The eupeds_ETH / eupeds_UCY datasets are left out of the acceleration heatmap.\n",
    "            data_df = data_df[~data_df[\"env_name\"].isin([\"eupeds_ETH\", \"eupeds_UCY\"])]\n",
    "        vis_heatmap(key, label, _normalized_env_table(data_df), ax=ax)\n",
    "    else:\n",
    "        # One panel per agent type; only the last keeps its colour-bar label.\n",
    "        agent_types_list = [x.name for x in AgentType]\n",
    "        n_types = len(agent_types_list)\n",
    "        fig, axes = plt.subplots(ncols=n_types, figsize=(n_types*3, 2.5))\n",
    "        for i, agent_type in enumerate(agent_types_list):\n",
    "            per_type_df = _normalized_env_table(data_df[data_df[\"agent_type\"] == agent_type])\n",
    "            vis_heatmap(key, label, per_type_df, ax=axes[i], hide_ylabel=i > 0, hide_cbar_label=i < n_types - 1)\n",
    "            axes[i].set_title(agent_type.capitalize() + \"s\")\n",
    "        fig.tight_layout()\n",
    "\n",
    "    fig.savefig(f\"../plots/combined/{key}_heatmap.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_905158/1080391044.py:4: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  g = sns.displot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/1080391044.py:4: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  g = sns.displot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/1080391044.py:4: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  g = sns.displot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/1080391044.py:4: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  g = sns.displot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n",
      "/tmp/ipykernel_905158/1080391044.py:4: UserWarning: The palette list has more values (10) than needed (9), which may not be intended.\n",
      "  g = sns.displot(\n",
      "/home/bivanovic/anaconda3/envs/nuplan/lib/python3.9/site-packages/seaborn/distributions.py:407: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n",
      "  and estimate_kws[\"bins\"] == \"auto\"\n"
     ]
    }
   ],
   "source": [
    "def _format_scene_hist_axis(ax, key, label, logscale):\n",
    "    \"\"\"Apply the x label, axis scales and tick formatters shared by every scene-histogram panel.\"\"\"\n",
    "    ax.set_xlabel(label)\n",
    "    if logscale:\n",
    "        ax.set_yscale(\"log\")\n",
    "        # agent_density uses log-spaced bins, so its x-axis is logarithmic as well.\n",
    "        if key == \"agent_density\":\n",
    "            ax.set_xscale(\"log\")\n",
    "    else:\n",
    "        ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))\n",
    "    if key == \"path_efficiency\":\n",
    "        # Efficiency values are fractions (bins span 0-1.01); show them as percentages.\n",
    "        ax.xaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))\n",
    "\n",
    "\n",
    "for key, (label, logscale) in scene_keys_labels_dict.items():\n",
    "    data_df = get_key_df(scene_bins, scene_histograms, key)\n",
    "    has_agent_types = \"agent_type\" in data_df.columns\n",
    "    agent_types_list = [x.name for x in AgentType]\n",
    "\n",
    "    g = sns.displot(\n",
    "        data=data_df,\n",
    "        x=\"value\",\n",
    "        weights=\"weight\",\n",
    "        hue=\"env_name\",\n",
    "        hue_order=all_hue_order,\n",
    "        col=\"agent_type\" if has_agent_types else None,\n",
    "        col_order=agent_types_list if has_agent_types else None,\n",
    "        element=\"step\",\n",
    "        fill=False,\n",
    "        bins=scene_const_bins[key],\n",
    "        stat=\"proportion\",\n",
    "        # One colour per hue level; the full default palette triggers seaborn's \"more values than needed\" warning.\n",
    "        palette=sns.color_palette(n_colors=len(all_hue_order)),\n",
    "        common_norm=False\n",
    "    )\n",
    "    fix_legend_labels(g)\n",
    "\n",
    "    if has_agent_types:\n",
    "        # Facets follow col_order, so panel i shows agent_types_list[i].\n",
    "        for i, ax in enumerate(g.axes[0]):\n",
    "            _format_scene_hist_axis(ax, key, label, logscale)\n",
    "            ax.set_title(f\"{agent_types_list[i].capitalize()}s\")\n",
    "    else:\n",
    "        # Single panel: same formatting (incl. log-x / percent-x) as the faceted case.\n",
    "        _format_scene_hist_axis(g.ax, key, label, logscale)\n",
    "\n",
    "    g.fig.savefig(f\"../plots/combined/{key}.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(g.fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discrete scene statistics to plot: key -> (y-axis label, log-scale y-axis?, normaliser).\n",
    "# The normaliser names the statistic each value is divided by (None keeps raw values).\n",
    "discrete_scene_keys_labels_dict = dict(\n",
    "    agent_counts=(\"Proportion\", False, \"num_total_agents\"),\n",
    "    num_collisions=(\"Collision Rate\", True, \"num_total_agents\"),\n",
    "    num_total_agents=(\"Number of Agents\", True, None),\n",
    "    num_stationary=(\"Proportion Stationary\", True, \"num_total_agents\"),\n",
    "    collided_agent_classes=(\"Collision Rate\", True, \"agent_counts\"),\n",
    "    offroad_agent_types=(\"Offroad Rate\", True, \"agent_counts\"),\n",
    "    # num_offroad_positions=(\"Offroad Fraction\", True, \"num_positions\"),\n",
    "    num_veh_acc_triggers=(\"Harsh Acceleration Rate\", True, \"num_positions\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_discrete_scene_df(key: str, norm_var: \"str | None\", histograms=None) -> pd.DataFrame:\n",
    "    \"\"\"Collect one discrete scene statistic across environments into a long-form DataFrame.\n",
    "\n",
    "    Args:\n",
    "        key: Statistic to read from each environment's entry.\n",
    "        norm_var: Statistic to divide by, or None to keep raw values. \"num_total_agents\" and\n",
    "            \"num_positions\" divide by that environment's total; \"agent_counts\" (per-agent-type\n",
    "            statistics only) divides each agent type's value by that type's count.\n",
    "        histograms: Mapping env_name -> {statistic: value}; defaults to the notebook-level\n",
    "            ``scene_histograms``.\n",
    "\n",
    "    Returns:\n",
    "        Dict-valued statistics give columns ``env_name``, ``agent_type``, ``value``; scalar ones\n",
    "        give ``env_name``, ``dummy`` (constant \"x\", a shared bar-plot category) and ``value``.\n",
    "        Environments without `key`, or whose normaliser is missing or None, are skipped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `norm_var` is not supported for the statistic's type.\n",
    "    \"\"\"\n",
    "    if histograms is None:\n",
    "        histograms = scene_histograms\n",
    "\n",
    "    hist_data_dict = defaultdict(list)\n",
    "    for env_name, env_hists in histograms.items():\n",
    "        if key not in env_hists:\n",
    "            continue\n",
    "        # The ratio is undefined when the normaliser is absent or could not be aggregated (None).\n",
    "        if norm_var is not None and env_hists.get(norm_var) is None:\n",
    "            continue\n",
    "\n",
    "        env_data = env_hists[key]\n",
    "        if isinstance(env_data, dict):\n",
    "            for agent_type, v in env_data.items():\n",
    "                hist_data_dict[\"env_name\"].append(env_name)\n",
    "                hist_data_dict[\"agent_type\"].append(agent_type)\n",
    "                if norm_var is None:\n",
    "                    hist_data_dict[\"value\"].append(v)\n",
    "                elif norm_var in {\"num_total_agents\", \"num_positions\"}:\n",
    "                    hist_data_dict[\"value\"].append(v / env_hists[norm_var])\n",
    "                elif norm_var == \"agent_counts\":\n",
    "                    hist_data_dict[\"value\"].append(v / env_hists[norm_var][agent_type])\n",
    "                else:\n",
    "                    raise ValueError(f\"Unsupported norm_var {norm_var!r} for per-agent-type statistic {key!r}\")\n",
    "\n",
    "        elif isinstance(env_data, numbers.Number):\n",
    "            hist_data_dict[\"env_name\"].append(env_name)\n",
    "            hist_data_dict[\"dummy\"].append(\"x\")\n",
    "            if norm_var is None:\n",
    "                hist_data_dict[\"value\"].append(env_data)\n",
    "            elif norm_var in {\"num_total_agents\", \"num_positions\"}:\n",
    "                hist_data_dict[\"value\"].append(env_data / env_hists[norm_var])\n",
    "            else:\n",
    "                raise ValueError(f\"Unsupported norm_var {norm_var!r} for scalar statistic {key!r}\")\n",
    "\n",
    "    return pd.DataFrame(hist_data_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              env_name  agent_type     value\n",
      "0           eupeds_ETH  PEDESTRIAN  1.000000\n",
      "1          waymo_train     VEHICLE  0.882812\n",
      "2          waymo_train  PEDESTRIAN  0.109464\n",
      "3          waymo_train     BICYCLE  0.007724\n",
      "4      lyft_train_full     VEHICLE  0.272329\n",
      "5      lyft_train_full     UNKNOWN  0.689803\n",
      "6      lyft_train_full  PEDESTRIAN  0.033842\n",
      "7      lyft_train_full     BICYCLE  0.004025\n",
      "8           eupeds_UCY  PEDESTRIAN  1.000000\n",
      "9   interaction_single  PEDESTRIAN  0.042136\n",
      "10  interaction_single     VEHICLE  0.957864\n",
      "11         nuplan_mini     VEHICLE  0.451529\n",
      "12         nuplan_mini  PEDESTRIAN  0.537104\n",
      "13         nuplan_mini     BICYCLE  0.011368\n",
      "14                 sdd     BICYCLE  0.408738\n",
      "15                 sdd  PEDESTRIAN  0.507961\n",
      "16                 sdd     UNKNOWN  0.028350\n",
      "17                 sdd     VEHICLE  0.054951\n",
      "18   interaction_multi  PEDESTRIAN  0.048389\n",
      "19   interaction_multi     VEHICLE  0.951611\n",
      "20       nusc_trainval     VEHICLE  0.726854\n",
      "21       nusc_trainval  PEDESTRIAN  0.242120\n",
      "22       nusc_trainval  MOTORCYCLE  0.015675\n",
      "23       nusc_trainval     BICYCLE  0.015352\n",
      "186223274\n",
      "             env_name dummy     value\n",
      "0          eupeds_ETH     x  0.040053\n",
      "1         waymo_train     x  0.535906\n",
      "2     lyft_train_full     x  0.000883\n",
      "3          eupeds_UCY     x  0.000000\n",
      "4  interaction_single     x  0.052455\n",
      "5         nuplan_mini     x  0.000507\n",
      "6                 sdd     x  0.051068\n",
      "7   interaction_multi     x  0.044840\n",
      "8       nusc_trainval     x  0.174640\n"
     ]
    }
   ],
   "source": [
    "for key, (label, logscale, norm_var) in discrete_scene_keys_labels_dict.items():\n",
    "    data_df = get_discrete_scene_df(key, norm_var)\n",
    "\n",
    "    # Textual summaries for a few statistics.\n",
    "    if key == \"num_total_agents\":\n",
    "        print(data_df[\"value\"].sum(axis=0))\n",
    "    elif key in (\"num_stationary\", \"agent_counts\"):\n",
    "        print(data_df)\n",
    "\n",
    "    # The off-road plot uses AV_hue_order without waymo_train; every other plot uses all_hue_order.\n",
    "    if key == \"offroad_agent_types\":\n",
    "        hue_order = [x for x in AV_hue_order if x != \"waymo_train\"]\n",
    "    else:\n",
    "        hue_order = all_hue_order\n",
    "    per_agent_type = \"agent_type\" in data_df.columns\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(5, 3))\n",
    "    barplot_kwargs = dict(\n",
    "        data=data_df,\n",
    "        y=\"value\",\n",
    "        hue=\"env_name\",\n",
    "        hue_order=hue_order,\n",
    "        palette=sns.color_palette(),\n",
    "        ax=ax,\n",
    "    )\n",
    "    if per_agent_type:\n",
    "        # One bar group per agent type, in AgentType order.\n",
    "        sns.barplot(x=\"agent_type\", order=[x.name for x in AgentType], **barplot_kwargs)\n",
    "    else:\n",
    "        # Scalar statistic: all bars share the single \"dummy\" category.\n",
    "        sns.barplot(x=\"dummy\", **barplot_kwargs)\n",
    "\n",
    "    fix_legend_labels(ax)\n",
    "    sns.move_legend(ax, \"upper left\", bbox_to_anchor=(1, 1))\n",
    "    ax.set_ylabel(label)\n",
    "\n",
    "    if per_agent_type:\n",
    "        ax.set_xlabel(\"Agent Type\")\n",
    "        ax.tick_params(axis='x', labelrotation=15)\n",
    "    else:\n",
    "        # Hide the meaningless dummy tick label and the axis label.\n",
    "        ax.set_xticklabels([])\n",
    "        ax.set_xlabel(None)\n",
    "\n",
    "    if logscale:\n",
    "        ax.set_yscale(\"log\")\n",
    "    else:\n",
    "        ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))\n",
    "\n",
    "    fig.savefig(f\"../plots/combined/{key}.pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nuplan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
