{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "earlier-essay",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import pathlib\n",
    "\n",
    "from IPython.display import display, Math, Latex\n",
    "import numpy as np\n",
    "from tabulate import tabulate\n",
    "\n",
    "from offline_rl.rewards.evaluation.distance_matrix import DistanceMatrix\n",
    "from offline_rl.utils.file_utils import load_json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "based-supplement",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Taken from https://github.com/HumanCompatibleAI/evaluating-rewards/blob/7b99ec9b415d805bd77041f2f7807d112dec9802/src/evaluating_rewards/scripts/pipeline/combined_distances.py#L362\n",
    "def _fixed_width_format(x: float, figs: int = 3, return_zero_if_unrepr:bool=True) -> str:\n",
    "    \"\"\"Format x as a number targeting `figs+1` characters.\n",
    "    This is intended for cramming as much information as possible in a fixed-width\n",
    "    format. If `x >= 10 ** figs`, then we format it as an integer. Note this will\n",
    "    use more than `figs` characters if `x >= 10 ** (figs+1)`. Otherwise, we format\n",
    "    it as a float with as many significant figures as we can fit into the space.\n",
    "    If `x < 10 ** (-figs + 1)`, then we represent it as \"<\"+str(10 ** (-figs + 1)),\n",
    "    unless `x == 0` in which case we format `x` as \"0\" exactly.\n",
    "    Args:\n",
    "        x: The number to format. The code assumes this is non-negative; the return\n",
    "            value may exceed the character target if it is negative.\n",
    "        figs: The number of digits to target.\n",
    "    Returns:\n",
    "        The number formatted as described above.\n",
    "    \"\"\"\n",
    "    smallest_representable = 10 ** (-figs + 1)\n",
    "    if 0 < x < 10 ** (-figs + 1):\n",
    "        if return_zero_if_unrepr:\n",
    "            return \"0.00\"\n",
    "        else:\n",
    "            return \"<\" + str(smallest_representable)\n",
    "\n",
    "    raw_repr = str(x).replace(\".\", \"\")\n",
    "    num_leading_zeros = 0\n",
    "    for digit in raw_repr:\n",
    "        if digit == \"0\":\n",
    "            num_leading_zeros += 1\n",
    "        else:\n",
    "            break\n",
    "    if x >= 10 ** figs:\n",
    "        # No decimal point gives us an extra character to use\n",
    "        figs += 1\n",
    "    fstr = \"{:.\" + str(max(0, figs - num_leading_zeros)) + \"g}\"\n",
    "    res = fstr.format(x)\n",
    "\n",
    "    delta = (figs + 1) - len(res)\n",
    "    # g drops trailing zeros, add them back\n",
    "    if delta > 0 and \".\" in res:\n",
    "        res += \"0\" * delta\n",
    "    if delta > 1 and \".\" not in res:\n",
    "        res += \".\" + \"0\" * (delta - 1)\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "unavailable-throw",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insert the path where you saved the experiment outputs.\n",
    "basedir = \"<path-to-store-results>\"\n",
    "basedir_path = pathlib.Path(basedir)\n",
    "if not basedir_path.is_dir():\n",
    "    # Give a readable error when the placeholder above was not replaced.\n",
    "    raise FileNotFoundError(\n",
    "        f\"basedir {basedir!r} is not an existing directory; \"\n",
    "        \"set it to the path where the experiment outputs were saved.\"\n",
    "    )\n",
    "\n",
    "# Every sub-directory of basedir is treated as the results of one seed.\n",
    "seeddirs = sorted(d for d in basedir_path.iterdir() if d.is_dir())\n",
    "for seeddir in seeddirs:\n",
    "    print(seeddir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "detected-playing",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the environment whose results are visualized.\n",
    "# Options: BouncingBallsEnv-v0, CustomReacherEnv-v0, PointMazeLeftVel-v0\n",
    "# The table-generation cell compares this string with each constants class's\n",
    "# `env_name` and skips every environment that does not match exactly.\n",
    "env_name_to_visualize = \"BouncingBallsEnv-v0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "integrated-advocacy",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Common constants\n",
    "ORDERED_METRIC_LABELS = [\n",
    "    \"TRRE_random\",\n",
    "    \"TRRE_expert\",\n",
    "    \"EPIC_random\",\n",
    "    \"EPIC_expert\",\n",
    "    \"CORR_random\",\n",
    "    \"CORR_expert\",\n",
    "]\n",
    "\n",
    "ORDERED_REWARD_LABELS = [\n",
    "    \"GT\",\n",
    "    \"SHAPING\",\n",
    "    \"FEASIBILITY\",\n",
    "    \"REGRESS OOD\",\n",
    "    \"REGRESS\",\n",
    "    \"PREF\",\n",
    "]\n",
    "REFERENCE_PRETTY_REWARD_LABEL = \"GT\"\n",
    "\n",
    "METRIC_LABEL_TO_DIRNAME = {\n",
    "    \"CORR_random\": \"no_canonicalization_random_dataset\",\n",
    "    \"CORR_expert\": \"no_canonicalization_expert_dataset\",\n",
    "    \"EPIC_random\": f\"EPIC_random_dataset\",\n",
    "    \"EPIC_expert\": f\"EPIC_expert_dataset\",\n",
    "    \"TRRE_random\": f\"EPIC_random_dataset_out_of_dist\",\n",
    "    \"TRRE_expert\": f\"EPIC_expert_dataset_out_of_dist\",\n",
    "}\n",
    "\n",
    "# Environment-specific constants.\n",
    "class BouncingBallsEnvVisualizationConstants:\n",
    "    env_name = \"BouncingBallsEnv-v0\"\n",
    "    \n",
    "    reward_label_to_pretty_reward_label = {\n",
    "             'goal_+1':\"GT\",\n",
    "             'goal_+1_shaping_+1':\"SHAPING\",\n",
    "             'feas_random_goal_+1_shaping_+1': \"FEASIBILITY\",\n",
    "             f'regression_random_dataset': \"REGRESS\",\n",
    "             f'regression_random_dataset_ood': \"REGRESS OOD\",\n",
    "             f'preference_random_dataset': \"PREF\",\n",
    "    }\n",
    "    \n",
    "class PointMazeLeftVelVisualizationConstants:\n",
    "    env_name = \"PointMazeLeftVel-v0\"\n",
    "    \n",
    "    reward_label_to_pretty_reward_label = {\n",
    "            'ground_truth':\"GT\",\n",
    "            f'regression_random_dataset': \"REGRESS\",\n",
    "    }\n",
    "    \n",
    "class CustomReacherEnvVisualizationConstants:\n",
    "    env_name = \"CustomReacherEnv-v0\"\n",
    "    \n",
    "    reward_label_to_pretty_reward_label = {\n",
    "             'ground_truth':\"GT\",\n",
    "             'ground_truth_shaping':\"SHAPING\",\n",
    "             f'regression_random_dataset': \"REGRESS\",\n",
    "             f'regression_random_dataset_ood': \"REGRESS OOD\",\n",
    "             f'preference_random_dataset': \"PREF\",\n",
    "    }\n",
    "    \n",
    "envs_constants = [\n",
    "    BouncingBallsEnvVisualizationConstants(),\n",
    "    CustomReacherEnvVisualizationConstants(),\n",
    "    PointMazeLeftVelVisualizationConstants(),\n",
    "]\n",
    "\n",
    "# Fail loudly on an unknown environment name. Otherwise nothing below would run,\n",
    "# `tables` would never be assigned, and the following cells would either raise a\n",
    "# NameError or silently reuse a stale `tables` from an earlier run.\n",
    "selected_envs_constants = [c for c in envs_constants if c.env_name == env_name_to_visualize]\n",
    "if not selected_envs_constants:\n",
    "    raise ValueError(\n",
    "        f\"Unknown env_name_to_visualize {env_name_to_visualize!r}; \"\n",
    "        f\"expected one of {[c.env_name for c in envs_constants]}\"\n",
    "    )\n",
    "\n",
    "verbose = False\n",
    "for env_constants in selected_envs_constants:\n",
    "    print(f\"Generating table for {env_constants.env_name}\")\n",
    "\n",
    "    # Invert the mapping so that rewards can be looked up by their pretty label.\n",
    "    pretty_reward_label_to_reward_label = {\n",
    "        pretty: raw for raw, pretty in env_constants.reward_label_to_pretty_reward_label.items()\n",
    "    }\n",
    "    reference_reward_label = pretty_reward_label_to_reward_label[REFERENCE_PRETTY_REWARD_LABEL]\n",
    "\n",
    "    # One table per seed. Each table row is\n",
    "    # [pretty label, one scaled distance per metric ..., mean policy return].\n",
    "    tables = []\n",
    "    for seeddir in seeddirs:\n",
    "        if verbose:\n",
    "            print(f\"Collecting results from seed dir: {seeddir}\")\n",
    "        env_reward_evaluation_dir = os.path.join(seeddir, env_constants.env_name, \"reward_evaluation\")\n",
    "\n",
    "        # Collect the distance matrices.\n",
    "        # NOTE(review): the `.pkl` suffix suggests pickle-based loading; only point\n",
    "        # this at results you produced yourself -- confirm in DistanceMatrix.load.\n",
    "        metric_label_to_dist_mat = {}\n",
    "        for label, metric_dirname in METRIC_LABEL_TO_DIRNAME.items():\n",
    "            filepath = os.path.join(env_reward_evaluation_dir, metric_dirname, \"distances.pkl\")\n",
    "            metric_label_to_dist_mat[label] = DistanceMatrix.load(filepath)\n",
    "\n",
    "        # Collect results for the different rewards.\n",
    "        table = []\n",
    "        for pretty_reward_label in ORDERED_REWARD_LABELS:\n",
    "            if verbose:\n",
    "                print(f\"\\t{pretty_reward_label}\")\n",
    "            # Not every environment defines every reward; skip the missing ones.\n",
    "            if pretty_reward_label not in pretty_reward_label_to_reward_label:\n",
    "                continue\n",
    "\n",
    "            row = [pretty_reward_label]\n",
    "            reward_label = pretty_reward_label_to_reward_label[pretty_reward_label]\n",
    "\n",
    "            # Collect reward evaluation results.\n",
    "            for metric_label in ORDERED_METRIC_LABELS:\n",
    "                if verbose:\n",
    "                    print(f\"\\t\\t{metric_label}\")\n",
    "                dist_mat = metric_label_to_dist_mat[metric_label]\n",
    "                dist = dist_mat.distance_between(reference_reward_label, reward_label)\n",
    "                # Distances are reported scaled by 1000.\n",
    "                row.append(dist * 1000)\n",
    "\n",
    "            # Collect policy evaluation results: the mean return stored under the\n",
    "            # reference reward's key in this reward's `returns.json`.\n",
    "            policy_evaluation_output_filepath = os.path.join(\n",
    "                seeddir,\n",
    "                env_constants.env_name,\n",
    "                \"policy_evaluation\",\n",
    "                reward_label,\n",
    "                \"returns.json\",\n",
    "            )\n",
    "            policy_evaluation_output = load_json(policy_evaluation_output_filepath)\n",
    "            reference_evaluation_output = policy_evaluation_output[reference_reward_label]\n",
    "            policy_return = float(reference_evaluation_output[\"mean\"])\n",
    "\n",
    "            row.append(policy_return)\n",
    "\n",
    "            # Add the row to table.\n",
    "            table.append(row)\n",
    "\n",
    "        tables.append(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "recovered-projection",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-seed tables into one array indexed as (seed, reward, column).\n",
    "# Each row mixes a string label with floats, so the numeric columns are converted\n",
    "# back with `astype(float)` below.\n",
    "tables = np.array(tables)\n",
    "if tables.ndim != 3 or tables.shape[0] == 0:\n",
    "    raise ValueError(\n",
    "        f\"Expected a non-empty (seed, reward, column) table but got shape {tables.shape}; \"\n",
    "        \"check that `basedir` contains seed directories with results.\"\n",
    "    )\n",
    "labels = tables[0, :, 0, None]\n",
    "num_seeds = len(tables)\n",
    "print(f\"num_seeds: {num_seeds}\")\n",
    "sqrt_num_seeds = np.sqrt(num_seeds)\n",
    "\n",
    "# Column 0 is the label, the last column is the policy return, and everything in\n",
    "# between is a distance.\n",
    "distance_values = tables[..., 1:-1]\n",
    "mean_distance_values = distance_values.astype(float).mean(axis=0)\n",
    "# NumPy's default ddof=0 applies here, i.e. this is the population standard deviation.\n",
    "std_distance_values = distance_values.astype(float).std(axis=0)\n",
    "# Build the string arrays from lists so that NumPy sizes the string dtype to the\n",
    "# longest entry instead of truncating to a fixed width.\n",
    "formatted_mean_distance_values = np.array(\n",
    "    [[_fixed_width_format(value) for value in row] for row in mean_distance_values],\n",
    "    dtype=str,\n",
    ")\n",
    "formatted_std_err_distance_values = np.array(\n",
    "    [[_fixed_width_format(value / sqrt_num_seeds) for value in row] for row in std_distance_values],\n",
    "    dtype=str,\n",
    ")\n",
    "\n",
    "include_std_err_in_return_mean = True\n",
    "return_values = tables[..., -1].astype(float)\n",
    "mean_return_values = return_values.mean(axis=0)\n",
    "std_return_values = return_values.std(axis=0)\n",
    "std_err_return_strings = [_fixed_width_format(value / sqrt_num_seeds) for value in std_return_values]\n",
    "mean_return_strings = [_fixed_width_format(value) for value in mean_return_values]\n",
    "if include_std_err_in_return_mean:\n",
    "    # Show the return as \"mean $\\pm$ standard error\" in the mean table.\n",
    "    mean_return_strings = [\n",
    "        r\"{} $\\pm$ {}\".format(mean_str, std_err_str)\n",
    "        for mean_str, std_err_str in zip(mean_return_strings, std_err_return_strings)\n",
    "    ]\n",
    "\n",
    "formatted_mean_return_values = np.array(mean_return_strings, dtype=str).reshape(-1, 1)\n",
    "formatted_std_err_return_values = np.array(std_err_return_strings, dtype=str).reshape(-1, 1)\n",
    "\n",
    "\n",
    "mean_table = np.concatenate((labels, formatted_mean_distance_values, formatted_mean_return_values), axis=-1)\n",
    "std_err_table = np.concatenate((labels, formatted_std_err_distance_values, formatted_std_err_return_values), axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "creative-gossip",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_results_for_table(rows):\n",
    "    s = \"\"\n",
    "    for i, row in enumerate(rows):\n",
    "        for j, col in enumerate(row):\n",
    "            if j == 0:\n",
    "                s += f\"{col:15s}\"\n",
    "            else:\n",
    "                s += f\"{col:5}\"\n",
    "            \n",
    "            if j < len(row) - 1:\n",
    "                s += \" & \"        \n",
    "            else:\n",
    "                s += r\" \\\\\" + \"\\n\"\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "duplicate-denmark",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the LaTeX rows of the mean table, then those of the standard-error table.\n",
    "for banner, table_to_print in (\n",
    "    (\"-------------MEAN---------------\", mean_table),\n",
    "    (\"-----------STD ERR--------------\", std_err_table),\n",
    "):\n",
    "    print(banner)\n",
    "    print(make_results_for_table(table_to_print))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adopted-activity",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
