{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "\n",
    "# Avoid non-compliant Type 3 fonts\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42  # pylint: disable=wrong-import-position\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from ipywidgets import interact\n",
    "from IPython.display import display\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.set_option('display.max_rows', None)\n",
    "plt.rcParams['figure.figsize'] = (12, 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logs_dir = utils.get_logs_dir()\n",
    "eval_dir = utils.get_eval_dir()\n",
    "env_names = ['small_empty', 'small_divider', 'large_empty', 'large_doors', 'large_tunnels', 'large_rooms']\n",
    "step_size = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all runs\n",
    "cfgs = [utils.load_config(str(x / 'config.yml')) for x in tqdm(sorted(logs_dir.iterdir())) if x.is_dir()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extend_curves(curves, min_len=None):\n",
    "    if len(curves) == 0:\n",
    "        return curves\n",
    "    max_length = max(len(curve) for curve in curves)\n",
    "    if min_len is not None:\n",
    "        max_length = max(max_length, min_len)\n",
    "    for i, curve in enumerate(curves):\n",
    "        curves[i] = np.pad(curve, (0, max_length - len(curve)), 'edge')\n",
    "    return curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_curve_for_run(cfg):\n",
    "    eval_path = eval_dir / '{}.npy'.format(cfg.run_name)\n",
    "    data = np.load(eval_path, allow_pickle=True)\n",
    "    curves = []\n",
    "    for episode in data:\n",
    "        cubes = np.asarray([step['cubes'] for step in episode])\n",
    "        simulation_steps = np.array([step['simulation_steps'] for step in episode])\n",
    "        x = np.arange(0, simulation_steps[-1] + step_size, step_size)\n",
    "        xp, fp = simulation_steps, cubes\n",
    "        curves.append(np.floor(np.interp(x, xp, fp, left=0)))\n",
    "    return np.mean(extend_curves(curves), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_curves():\n",
    "    all_curves = {}\n",
    "    for cfg in tqdm(cfgs):\n",
    "        if cfg.experiment_name not in all_curves:\n",
    "            all_curves[cfg.experiment_name] = []\n",
    "        all_curves[cfg.experiment_name].append(get_curve_for_run(cfg))\n",
    "    return all_curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_curves = get_all_curves()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_cutoffs():\n",
    "    all_cutoffs = {}\n",
    "    for cfg in tqdm(cfgs):\n",
    "        robot_config_str = cfg.experiment_name.split('-')[0]\n",
    "        if robot_config_str not in all_cutoffs:\n",
    "            all_cutoffs[robot_config_str] = {}\n",
    "        if cfg.env_name not in all_cutoffs[robot_config_str]:\n",
    "            all_cutoffs[robot_config_str][cfg.env_name] = float('inf')\n",
    "\n",
    "        # Find the time at which the last cube was successfully foraged\n",
    "        y_mean = np.mean(extend_curves(all_curves[cfg.experiment_name]), axis=0)\n",
    "        this_len = np.searchsorted(y_mean > y_mean[-1] - 0.5, True)  # Subtract 0.5 since interpolated curves are continuous\n",
    "        all_cutoffs[robot_config_str][cfg.env_name] = min(all_cutoffs[robot_config_str][cfg.env_name], this_len)\n",
    "    return all_cutoffs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_cutoffs = get_all_cutoffs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_results():\n",
    "    all_results = {}\n",
    "    for cfg in tqdm(cfgs):\n",
    "        robot_config_str = cfg.experiment_name.split('-')[0]\n",
    "        if robot_config_str not in all_results:\n",
    "            all_results[robot_config_str] = {}\n",
    "        cutoff = all_cutoffs[robot_config_str][cfg.env_name]\n",
    "        curves = extend_curves(all_curves[cfg.experiment_name], min_len=(cutoff + 1))\n",
    "        cubes = [curve[cutoff] for curve in curves]\n",
    "        all_results[robot_config_str][cfg.experiment_name] = '{:.2f} ± {:.2f}'.format(np.mean(cubes), np.std(cubes))\n",
    "    return all_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_results = get_all_results()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_table():\n",
    "    def f(robot_config_str):\n",
    "        data = {'performance': all_results[robot_config_str]}\n",
    "        display(pd.DataFrame(data))\n",
    "\n",
    "    robot_config_str_dropdown = widgets.RadioButtons(options=sorted(all_results.keys()))\n",
    "    interact(f, robot_config_str=robot_config_str_dropdown)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Note: These metrics measure relative performance, see show_curves() for plots of absolute performance\n",
    "show_table()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_curves():\n",
    "    def plot_curves(experiment_names, env_name, fontsize=20):\n",
    "        for experiment_name in experiment_names:\n",
    "            # Plot cutoff\n",
    "            robot_config_str = experiment_name.split('-')[0]\n",
    "            best_len = all_cutoffs[robot_config_str][env_name]\n",
    "            plt.axvline(best_len * step_size, linewidth=1, c='r')\n",
    "\n",
    "            # Plot curve\n",
    "            curves = extend_curves(all_curves[experiment_name])\n",
    "            x = np.arange(0, (len(curves[0]) - 0.5) * step_size, step_size)\n",
    "            y_mean, y_std = np.mean(curves, axis=0), np.std(curves, axis=0)\n",
    "            label = '{} ({})'.format(experiment_name, all_results[robot_config_str][experiment_name])\n",
    "            plt.plot(x, y_mean, label=label)\n",
    "            plt.fill_between(x, y_mean - y_std, y_mean + y_std, alpha=0.2)\n",
    "\n",
    "        num_cubes = 20 if env_name.startswith('large') else 10\n",
    "        plt.xlim(0)\n",
    "        plt.ylim(0, num_cubes)\n",
    "        plt.xticks(fontsize=fontsize - 2)\n",
    "        plt.yticks(range(0, num_cubes + 1, 2), fontsize=fontsize - 2)\n",
    "        plt.xlabel('Simulation Steps', fontsize=fontsize)\n",
    "        plt.ylabel('Num Objects', fontsize=fontsize)\n",
    "        plt.legend(fontsize=fontsize - 2)\n",
    "\n",
    "    def f(env_name, experiment_names):\n",
    "        if len(experiment_names) == 0:\n",
    "            return\n",
    "        plot_curves(experiment_names, env_name)\n",
    "        plt.show()\n",
    "\n",
    "    env_name_radio = widgets.RadioButtons(options=env_names)\n",
    "    experiment_names_select = widgets.SelectMultiple(layout=widgets.Layout(width='60%', height='150px'))\n",
    "    def update_experiment_names_options(*_):\n",
    "        matching_experiment_names = []\n",
    "        for experiment_name in sorted(all_curves):\n",
    "            if env_name_radio.value in experiment_name:\n",
    "                matching_experiment_names.append(experiment_name)\n",
    "        experiment_names_select.options = matching_experiment_names\n",
    "        experiment_names_select.rows = len(matching_experiment_names)\n",
    "        experiment_names_select.value = ()\n",
    "    env_name_radio.observe(update_experiment_names_options)\n",
    "    interact(f, env_name=env_name_radio, experiment_names=experiment_names_select)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: The vertical red line is used to measure relative performance, the curves show absolute performance\n",
    "show_curves()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
