{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "import os\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "from collections import OrderedDict\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global vars for tracking and labeling data at load time.\n",
    "DIV_LINE_WIDTH = 50\n",
    "ROW_ORDER = [\"vanilla\", \"uniform\", \"kl\", \"klmc\", \"klmr\", \"mc\", \"mcv\", \"mr\", \"mrv\"]\n",
    "COL_ORDER = [\"NaturalReward\", \"NaturalCost\",\n",
    "             \"AdvUniformReward\", \"AdvUniformCost\",\n",
    "             \"AdvMadReward\", \"AdvMadCost\", \n",
    "             \"AdvAmadReward\", \"AdvAmadCost\", \n",
    "             \"AdvMaxCostReward\", \"AdvMaxCostCost\", \n",
    "             \"AdvMaxRewardReward\", \"AdvMaxRewardCost\", \n",
    "             \"AverageReward\", \"AverageCost\"]\n",
    "ENV_NAMES = {\"SafetyCarCircle-v0\": \"Car-Circle\",\n",
    "             \"SafetyAntRun-v0\": \"Ant-Run\",\n",
    "             \"SafetyAntCircle-v0\": \"Ant-Circle\",\n",
    "             \"SafetyCarRun-v0\": \"Car-Run\",\n",
    "             \"SafetyDroneCircle-v0\": \"Drone-Circle\",\n",
    "             \"SafetyDroneRun-v0\": \"Drone-Run\"}\n",
    "\n",
    "# change this\n",
    "NAME = \"ppo\"\n",
    "ENV = \"DroneRun\"\n",
    "TABLE_NAME = \"eval_optimal\"\n",
    "\n",
    "def get_datasets(logdir, data):\n",
    "    \"\"\"\n",
    "    Recursively look through logdir for output files produced by\n",
    "    spinup.logx.Logger. \n",
    "\n",
    "    Assumes that any file \"progress.txt\" is a valid hit. \n",
    "    \"\"\"\n",
    "    for root, _, files in os.walk(logdir):\n",
    "        if 'progress.txt' in files:\n",
    "            exp_name = None\n",
    "            env = None\n",
    "            try:\n",
    "                config_path = open(os.path.join(root, 'config.json'))\n",
    "                config = json.load(config_path)\n",
    "                if TABLE_NAME not in config[\"data_dir\"]:\n",
    "                    continue\n",
    "                if NAME not in config[\"data_dir\"]:\n",
    "                    continue\n",
    "                if ENV not in config[\"data_dir\"]:\n",
    "                    continue\n",
    "                env = config[\"env_cfg\"][\"env_name\"]\n",
    "                exp_name = config[\"exp_name\"]\n",
    "                exp_name = exp_name.split('_')[-1]\n",
    "                if exp_name not in ROW_ORDER:\n",
    "                    continue\n",
    "                if env not in list(data.keys()):\n",
    "                    data[env] = OrderedDict()\n",
    "                if exp_name not in list(data[env].keys()):\n",
    "                    data[env][exp_name] = OrderedDict()\n",
    "                print(root)\n",
    "            except:\n",
    "                print('No file named config.json')\n",
    "            try:\n",
    "                exp_data = pd.read_table(os.path.join(root, 'progress.txt'))\n",
    "                exp_data = exp_data.rename(columns=lambda x: x.split(\"/\")[-1])\n",
    "            except:\n",
    "                print('Could not read from %s' % os.path.join(root, 'progress.txt'))\n",
    "                continue\n",
    "            # Score, NoiseScale, Time\n",
    "            for (column_name, column_data) in exp_data.items():\n",
    "                if \"Reward\" in column_name or \"Cost\" in column_name:\n",
    "                    if column_name not in data[env][exp_name].keys():\n",
    "                        data[env][exp_name][column_name] = []\n",
    "                    data[env][exp_name][column_name] += list(column_data.values)\n",
    "            data[env][exp_name] = OrderedDict(sorted(data[env][exp_name].items(), key=lambda i:COL_ORDER.index(i[0])))\n",
    "    for env in data.keys():\n",
    "        for exp in data[env].keys():\n",
    "            for r in data[env][exp].keys():\n",
    "                mean = round(np.mean(data[env][exp][r]), 2)\n",
    "                std = round(np.std(data[env][exp][r]), 2)\n",
    "                data[env][exp][r] = str(mean) + \"$\\pm$\" + str(std)\n",
    "    return data\n",
    "\n",
    "\n",
    "def get_all_datasets(all_logdirs, exp_data):\n",
    "    \"\"\"\n",
    "    For every entry in all_logdirs,\n",
    "        1) check if the entry is a real directory and if it is, \n",
    "           pull data from it; \n",
    "\n",
    "        2) if not, check to see if the entry is a prefix for a \n",
    "           real directory, and pull data from that.\n",
    "    \"\"\"\n",
    "    logdirs = []\n",
    "    for logdir in all_logdirs:\n",
    "        if osp.isdir(logdir) and logdir[-1] == os.sep:\n",
    "            logdirs += [logdir]\n",
    "        else:\n",
    "            basedir = osp.dirname(logdir)\n",
    "            fulldir = lambda x: osp.join(basedir, x)\n",
    "            prefix = logdir.split(os.sep)[-1]\n",
    "            listdir = os.listdir(basedir)\n",
    "            logdirs += sorted([fulldir(x) for x in listdir if prefix in x])\n",
    "\n",
    "    # Verify logdirs\n",
    "    print('Plotting from...\\n' + '=' * DIV_LINE_WIDTH + '\\n')\n",
    "    for logdir in logdirs:\n",
    "        print(logdir)\n",
    "    print('\\n' + '=' * DIV_LINE_WIDTH)\n",
    "\n",
    "    # Load data from logdirs\n",
    "    for log in logdirs:\n",
    "        exp_data = get_datasets(log, exp_data)\n",
    "    return exp_data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logdir = [\n",
    "    \"/home/zijian/code/robust-safe-rl/data/\"\n",
    "    ]\n",
    "\n",
    "print(\"=\"*DIV_LINE_WIDTH)\n",
    "print(\"processing csv file\")\n",
    "exp_data = OrderedDict()\n",
    "get_all_datasets(logdir, exp_data)\n",
    "\n",
    "env_names = exp_data.keys()\n",
    "for env in env_names:\n",
    "    save_dir = osp.join(logdir[0], str(env)+\"_\"+TABLE_NAME+\".csv\")\n",
    "    exp_data[env] = OrderedDict(sorted(exp_data[env].items(), key=lambda i:ROW_ORDER.index(i[0])))\n",
    "    row_names = list(exp_data[env].keys())\n",
    "    column_names = list(exp_data[env][row_names[0]].keys())\n",
    "    exp_data_pd = pd.DataFrame.from_dict(exp_data[env], orient='index', columns=column_names)\n",
    "    exp_data_pd.to_csv(save_dir)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ant-Circle: optimal, mcv, mrv\n",
    "\n",
    "Ant-Run: optimal, mc, mr\n",
    "\n",
    "Car-Circle: optimal, mc, mr\n",
    "\n",
    "Car-Run: last, mcv, mrv\n",
    "\n",
    "Drone-Circle: optimal, mc, mr\n",
    "\n",
    "Drone-Run: optimal, mc, mr\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rsrl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "577644e077fb164ce2c5cd7fa96fdac8a28753748e1f22ed5e00e36b1cbe17de"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
