{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import wandb\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "api = wandb.Api(timeout=120)\n",
    "runs = api.runs(\"IPRO_experiments\")\n",
    "env_id = \"minecart-v0\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "run_hists = {'ppo': {'arg1': [], 'arg2': [], 'arg3': []},\n",
    "             'dqn': {'arg1': [], 'arg2': [], 'arg3': []},\n",
    "             'a2c': {'arg1': [], 'arg2': [], 'arg3': []}}\n",
    "\n",
    "for run in runs:\n",
    "    if run.config['env_id'] == env_id:\n",
    "        name = run.name\n",
    "        splitted = name.split('_')\n",
    "        alg = splitted[0]\n",
    "        arg = splitted[-1]\n",
    "        run_hists[alg][arg].append(run.history(keys=['outer/hypervolume', 'outer/coverage']))\n",
    "        print(f'Added run {name} to {alg} - {arg}')"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def extract_iter_hist(hist):\n",
    "    hypervolumes = list(hist['outer/hypervolume'].values)\n",
    "    coverages = list(np.clip(list(hist['outer/coverage'].values), 0, 1))\n",
    "    return hypervolumes, coverages"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "run_data = {'ppo': {'arg1': [], 'arg2': [], 'arg3': []},\n",
    "            'dqn': {'arg1': [], 'arg2': [], 'arg3': []},\n",
    "            'a2c': {'arg1': [], 'arg2': [], 'arg3': []}}\n",
    "\n",
    "for alg in run_hists:\n",
    "    print(f\"Extracting data for {alg}\")\n",
    "    for arg in run_hists[alg]:\n",
    "        print(f\"Extracting data for {alg} - {arg}\")\n",
    "        for seed, hist in enumerate(run_hists[alg][arg]):\n",
    "            print(f'Run {seed}')\n",
    "            hypervolumes, coverages = extract_iter_hist(hist)\n",
    "            run_data[alg][arg].append((hypervolumes, coverages))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "best_data = {'ppo': [],\n",
    "             'dqn': [],\n",
    "             'a2c': []}\n",
    "\n",
    "# Extract the best argument i.e. the argument with the largest mean final hypervolume\n",
    "max_iter = 0\n",
    "\n",
    "for alg in run_data:\n",
    "    best_arg = None\n",
    "    best_mean = -1\n",
    "    for arg in run_data[alg]:\n",
    "        hypervolumes = [tpl[0] for tpl in run_data[alg][arg]]\n",
    "        arg_mean = np.mean([hv[-1] for hv in hypervolumes])\n",
    "        print(f\"Mean final hypervolume for {alg} - {arg}: {arg_mean}\")\n",
    "\n",
    "        if arg_mean > best_mean and len(run_data[alg][arg]) == 5:\n",
    "            best_mean = arg_mean\n",
    "            best_arg = arg\n",
    "            max_iter = max(max_iter, max([len(hv) for hv in hypervolumes]))\n",
    "    print(f\"Best argument for {alg} is {best_arg} with mean final hypervolume {best_mean}\")\n",
    "    print('---')\n",
    "    best_data[alg] = run_data[alg][best_arg]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def fill_iterations(hypervolumes, coverages, max_iter):\n",
    "    while len(hypervolumes) < max_iter:\n",
    "        hypervolumes.append(hypervolumes[-1])\n",
    "        coverages.append(coverages[-1])"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "print(f\"Max iterations: {max_iter}\")\n",
    "for alg in best_data:\n",
    "    for seed, (hypervolumes, coverages) in enumerate(best_data[alg]):\n",
    "        fill_iterations(hypervolumes, coverages, max_iter)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Make dictionaries with the data for all seeds.\n",
    "for alg in best_data:\n",
    "    hv_dict = {alg: [], 'Iteration': [], 'Seed': []}\n",
    "    cov_dict = {alg: [], 'Iteration': [], 'Seed': []}\n",
    "\n",
    "    for seed, (hypervolumes, coverages) in enumerate(best_data[alg]):\n",
    "        hv_dict[alg].extend(hypervolumes)\n",
    "        cov_dict[alg].extend(coverages)\n",
    "        hv_dict['Iteration'].extend(range(max_iter))\n",
    "        cov_dict['Iteration'].extend(range(max_iter))\n",
    "        hv_dict['Seed'].extend([seed] * max_iter)\n",
    "        cov_dict['Seed'].extend([seed] * max_iter)\n",
    "\n",
    "    hv_df = pd.DataFrame.from_dict(hv_dict)\n",
    "    cov_df = pd.DataFrame.from_dict(cov_dict)\n",
    "    hv_df.to_csv(f'results/{alg}_{env_id}_hv.csv', index=False)\n",
    "    cov_df.to_csv(f'results/{alg}_{env_id}_cov.csv', index=False)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
