{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled helpers for formatting cells with black (blackcellmagic extension):\n",
    "# %load_ext blackcellmagic\n",
    "# %%black -l 120\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from itertools import zip_longest\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from experiments.base.compute_iqm import get_iqm_and_conf_per_epoch\n",
    "\n",
    "env_name = \"lunar_lander\"\n",
    "experiment_folders = [\"25_25/dqn\", \"50_50/dqn\", \"100_100/dqn\", \"200_200/dqn\"]\n",
    "\n",
    "# Results are read relative to the current working directory: ../<env_name>/exp_output\n",
    "base_path = os.path.join(os.path.abspath(\"\"), \"..\", env_name, \"exp_output\")\n",
    "\n",
    "# Per experiment: \"iqm\" and \"confidence\" (as returned by get_iqm_and_conf_per_epoch) and \"x_values\".\n",
    "experiment_data = {experiment: {} for experiment in experiment_folders}\n",
    "\n",
    "for experiment in experiment_folders:\n",
    "    experiment_path = os.path.join(base_path, experiment, \"episode_returns_and_lengths\")\n",
    "\n",
    "    # List the directory once so that row i of returns_experiment always corresponds to seed_files[i]\n",
    "    # (the seed ids reported below are derived from this same list).\n",
    "    seed_files = os.listdir(experiment_path)\n",
    "\n",
    "    returns_experiment_ = []\n",
    "\n",
    "    for experiment_file in seed_files:\n",
    "        # Context manager: the file handle is closed as soon as the content has been parsed.\n",
    "        with open(os.path.join(experiment_path, experiment_file), \"rb\") as seed_file:\n",
    "            list_episode_returns = json.load(seed_file)[\"episode_returns\"]\n",
    "\n",
    "        # One value per entry of \"episode_returns\": the mean of the returns stored in that entry.\n",
    "        returns_experiment_.append([np.mean(episode_returns) for episode_returns in list_episode_returns])\n",
    "\n",
    "    # Files may hold different numbers of entries: pad the shorter ones with NaN, one row per file.\n",
    "    returns_experiment = np.array(list(zip_longest(*returns_experiment_, fillvalue=np.nan))).T\n",
    "\n",
    "    with open(os.path.join(experiment_path, \"../../parameters.json\"), \"rb\") as parameters_file:\n",
    "        p = json.load(parameters_file)\n",
    "\n",
    "    print(f\"Plot {experiment} with {returns_experiment.shape[0]} seeds.\")\n",
    "    if returns_experiment.shape[1] < p[\"dqn\"][\"n_epochs\"]:\n",
    "        print(f\"!!! All the {returns_experiment.shape[0]} seeds are not complete !!!\")\n",
    "    elif np.isnan(returns_experiment).any():\n",
    "        # Assumes file names look like \"<seed>.json\". os.path.splitext removes only the extension, whereas\n",
    "        # str.strip(\".json\") would strip any leading/trailing '.', 'j', 's', 'o' or 'n' characters.\n",
    "        seeds = np.array([int(os.path.splitext(seed_file_name)[0]) for seed_file_name in seed_files])\n",
    "        print(f\"!!! Seeds {seeds[np.isnan(returns_experiment).any(axis=1)]} are not complete !!!\")\n",
    "\n",
    "    experiment_data[experiment][\"iqm\"], experiment_data[experiment][\"confidence\"] = get_iqm_and_conf_per_epoch(\n",
    "        returns_experiment\n",
    "    )\n",
    "    # x axis: cumulative number of training steps at the end of each epoch.\n",
    "    experiment_data[experiment][\"x_values\"] = (\n",
    "        np.arange(1, returns_experiment.shape[1] + 1) * p[\"dqn\"][\"n_training_steps_per_epoch\"]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments import DISPLAY_NAME\n",
    "from experiments.lunar_lander import COLORS, ORDERS\n",
    "\n",
    "\n",
    "plt.rc(\"font\", family=\"serif\", serif=\"Times New Roman\", size=18)\n",
    "plt.rc(\"lines\", linewidth=3)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 3))\n",
    "\n",
    "for experiment in experiment_folders:\n",
    "    # Label, color and zorder are looked up with the second path component of the experiment folder.\n",
    "    # NOTE(review): every entry of experiment_folders ends in \"/dqn\", so all curves share one label, color\n",
    "    # and zorder -- confirm this is intended.\n",
    "    style_key = experiment.split(\"/\")[1]\n",
    "    curve = experiment_data[experiment]\n",
    "\n",
    "    ax.plot(\n",
    "        curve[\"x_values\"],\n",
    "        curve[\"iqm\"],\n",
    "        label=DISPLAY_NAME[style_key],\n",
    "        color=COLORS[style_key],\n",
    "        zorder=ORDERS[style_key],\n",
    "    )\n",
    "    # Shade the band between the first and second entries of \"confidence\".\n",
    "    ax.fill_between(\n",
    "        curve[\"x_values\"],\n",
    "        curve[\"confidence\"][0],\n",
    "        curve[\"confidence\"][1],\n",
    "        color=COLORS[style_key],\n",
    "        alpha=0.3,\n",
    "        zorder=ORDERS[style_key],\n",
    "    )\n",
    "\n",
    "# Axis-level setting: apply it once instead of once per plotted experiment.\n",
    "ax.ticklabel_format(style=\"sci\", axis=\"x\", scilimits=(0, 0))\n",
    "\n",
    "ax.set_xlabel(\"Env Steps\")\n",
    "ax.set_ylabel(\"IQM Return\")\n",
    "\n",
    "ax.grid()\n",
    "# Place the legend outside the axes, to the right of the plot area.\n",
    "ax.legend(ncols=1, frameon=False, loc=\"center\", bbox_to_anchor=(1.25, 0.5))\n",
    "ax.set_title(\"Lunar Lander\")\n",
    "fig.savefig(f\"../{env_name}/exp_output/performances.pdf\", bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_cpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
