{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "from experiments.atari import EXPERIMENTED_GAME\n",
    "\n",
    "\n",
    "# Runs to plot: experiment folders under figures/, the k values and seeds that\n",
    "# appear in the score file names, and the games to aggregate over.\n",
    "experiments = [\"ut30_uh6000\"]\n",
    "ks = [5]\n",
    "seeds = [11, 21, 12, 22, 13]\n",
    "games = EXPERIMENTED_GAME\n",
    "\n",
    "# Baselines drawn on the IQM plot.\n",
    "# Longer list kept for reference: ['DQN (Nature)', 'Quantile (JAX)_dopamine', 'DQN (Adam)', 'C51', 'REM', 'Rainbow', 'IQN', 'M-IQN']\n",
    "baselines = ['DQN (Nature)', 'DQN (Adam)', 'C51', 'REM']\n",
    "# Baselines drawn on the performance profile, e.g. ['DQN (Nature)', 'DQN (Adam)', 'C51', 'REM']\n",
    "baselines_performance_profile = []\n",
    "\n",
    "# Epoch indices handed to the IQM computation (curves are plotted against selected_epochs + 1).\n",
    "# Sparse alternative: np.array([1, 10, 25, 50, 75, 100, 125, 150, 175, 200]) - 1\n",
    "selected_epochs = np.arange(200)\n",
    "# Thresholds used as the x-axis of the performance profile.\n",
    "taus = np.linspace(0.0, 8.0, 81)\n",
    "\n",
    "# Switches: shade the confidence intervals, and enable the optional DQN curves,\n",
    "# the head-std plot and the approximation-error plot.\n",
    "plot_std = True\n",
    "add_dqn = False\n",
    "add_head_std = False\n",
    "add_approximation_error = False"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from idqn.utils.baselines_scores import get_baselines_scores\n",
    "\n",
    "# Every score table starts as a (200, len(seeds)) block of NaNs and is then filled\n",
    "# one seed column at a time from the matching .npy file.\n",
    "if add_dqn:\n",
    "    dqn_scores = {}\n",
    "    for experiment in experiments:\n",
    "        dqn_scores[experiment] = {}\n",
    "        for game in games:\n",
    "            seed_columns = np.full((200, len(seeds)), np.nan)\n",
    "            dqn_scores[experiment][game] = seed_columns\n",
    "            for column, seed in enumerate(seeds):\n",
    "                seed_columns[:, column] = np.load(f\"figures/{experiment}/{game}/DQN/J_{seed}.npy\")\n",
    "\n",
    "\n",
    "# iDQN scores are keyed by \"<experiment>_<k>\" and then by game.\n",
    "idqn_scores = {}\n",
    "for experiment in experiments:\n",
    "    for k in ks:\n",
    "        per_game = idqn_scores[f\"{experiment}_{k}\"] = {}\n",
    "        for game in games:\n",
    "            seed_columns = np.full((200, len(seeds)), np.nan)\n",
    "            per_game[game] = seed_columns\n",
    "            for column, seed in enumerate(seeds):\n",
    "                seed_columns[:, column] = np.load(f\"figures/{experiment}/{game}/iDQN/{k}_J_{seed}.npy\")\n",
    "\n",
    "baselines_scores = get_baselines_scores(baselines, games)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IQM vs iterations & performance profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from idqn.utils.process_scores import compute_iqm_and_confidence_interval\n",
    "from experiments.atari import COLORS, LABEL, ORDER\n",
    "\n",
    "\n",
    "plt.rc(\"font\", size=15)\n",
    "plt.rc(\"lines\", linewidth=3)\n",
    "# These figure names are reused by the other plotting cells. clear=True guarantees a\n",
    "# blank canvas when a figure with the same name is still open (e.g. a backend that\n",
    "# keeps figures alive between cells), so curves and legends never stack up.\n",
    "fig = plt.figure(\"Main figure\", clear=True)\n",
    "ax = fig.add_subplot(111)\n",
    "fig_legend = plt.figure(\"Legend figure\", clear=True)\n",
    "lines = []\n",
    "\n",
    "# iDQN curves: one IQM line (plus optional confidence band) per experiment/k pair.\n",
    "for experiment in experiments:\n",
    "    for k in ks:\n",
    "        run_key = f\"{experiment}_{k}\"\n",
    "        iqms, iqms_confidence_interval = compute_iqm_and_confidence_interval(idqn_scores[run_key], selected_epochs)\n",
    "        lines.append(ax.plot(selected_epochs + 1, iqms, label=f\"iDQN {LABEL[run_key]}\", color=COLORS[run_key], zorder=ORDER[run_key])[0])\n",
    "        if plot_std:\n",
    "            ax.fill_between(selected_epochs + 1, iqms_confidence_interval[0, :], iqms_confidence_interval[1, :], color=COLORS[run_key], alpha=0.3, zorder=ORDER[run_key])\n",
    "\n",
    "# Optional DQN curves; dqn_scores only exists when add_dqn was set before the extraction cell ran.\n",
    "if add_dqn:\n",
    "    for experiment in experiments:\n",
    "        iqms, iqms_confidence_interval = compute_iqm_and_confidence_interval(dqn_scores[experiment], selected_epochs)\n",
    "        lines.append(ax.plot(selected_epochs + 1, iqms, label=f\"DQN {LABEL[experiment]}\", color=COLORS[experiment], zorder=ORDER[experiment])[0])\n",
    "        if plot_std:\n",
    "            ax.fill_between(selected_epochs + 1, iqms_confidence_interval[0, :], iqms_confidence_interval[1, :], color=COLORS[experiment], zorder=ORDER[experiment], alpha=0.3)\n",
    "\n",
    "# Baseline curves.\n",
    "for baseline in baselines:\n",
    "    iqms, iqms_confidence_interval = compute_iqm_and_confidence_interval(baselines_scores[baseline], selected_epochs)\n",
    "    lines.append(ax.plot(selected_epochs + 1, iqms, label=LABEL[baseline], color=COLORS[baseline], zorder=ORDER[baseline])[0])\n",
    "    if plot_std:\n",
    "        ax.fill_between(selected_epochs + 1, iqms_confidence_interval[0, :], iqms_confidence_interval[1, :], color=COLORS[baseline], zorder=ORDER[baseline], alpha=0.3)\n",
    "\n",
    "ax.grid(zorder=0)\n",
    "ax.set_xlabel(\"Number of Frames (in millions)\")\n",
    "ax.set_ylabel(\"IQM Human Normalized Score\")\n",
    "# ax.set_xticklabels([])\n",
    "\n",
    "# The legend lives in its own figure so it can be saved separately from the plot.\n",
    "if len(lines) < 6:\n",
    "    # Few entries: a single row.\n",
    "    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=len(lines))\n",
    "else:\n",
    "    # Many entries: two rows. The legend is filled column by column, so the entries are\n",
    "    # interleaved to make them read left to right in their original order.\n",
    "    ncols = int(np.ceil(len(lines) / 2))\n",
    "    def flip(items):\n",
    "        \"\"\"Return items reordered as items[0::ncols], items[1::ncols], ... chained together.\"\"\"\n",
    "        return itertools.chain(*[items[i::ncols] for i in range(ncols)])\n",
    "    fig_legend.legend(flip(lines), flip([line.get_label() for line in lines]), ncols=ncols)\n",
    "\n",
    "# Output goes under the first experiment's folder; a single-game run gets its own sub-folder and a title.\n",
    "if len(games) == 1 and len(experiments) > 0:\n",
    "    ax.set_title(games[0])\n",
    "    fig.savefig(f\"figures/{experiments[0]}/{games[0]}/J.pdf\", bbox_inches='tight')\n",
    "    _ = fig_legend.savefig(f\"figures/{experiments[0]}/{games[0]}/J_legend.pdf\", bbox_inches='tight')\n",
    "elif len(experiments) > 0:\n",
    "    fig.savefig(f\"figures/{experiments[0]}/J.pdf\", bbox_inches='tight')\n",
    "    _ = fig_legend.savefig(f\"figures/{experiments[0]}/J_legend.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from idqn.utils.process_scores import compute_performance_profile_and_confidence_interval\n",
    "from experiments.atari import COLORS, LABEL, ORDER\n",
    "\n",
    "\n",
    "# The performance profile is only drawn when several games are aggregated.\n",
    "if len(games) > 1:\n",
    "    plt.rc(\"font\", size=15)\n",
    "    plt.rc(\"lines\", linewidth=3)\n",
    "    # These figure names are reused by the other plotting cells. clear=True guarantees a\n",
    "    # blank canvas when a figure with the same name is still open (e.g. a backend that\n",
    "    # keeps figures alive between cells), so curves and legends never stack up.\n",
    "    fig = plt.figure(\"Main figure\", clear=True)\n",
    "    ax = fig.add_subplot(111)\n",
    "    fig_legend = plt.figure(\"Legend figure\", clear=True)\n",
    "    lines = []\n",
    "\n",
    "    # iDQN profiles: one curve (plus optional confidence band) per experiment/k pair.\n",
    "    for experiment in experiments:\n",
    "        for k in ks:\n",
    "            run_key = f\"{experiment}_{k}\"\n",
    "            performance_profile, performance_profile_confidence_interval = compute_performance_profile_and_confidence_interval(idqn_scores[run_key], taus)\n",
    "            lines.append(ax.plot(taus, performance_profile, label=f\"iDQN {LABEL[run_key]}\", color=COLORS[run_key], zorder=ORDER[run_key])[0])\n",
    "            if plot_std:\n",
    "                ax.fill_between(taus, performance_profile_confidence_interval[0, :], performance_profile_confidence_interval[1, :], color=COLORS[run_key], zorder=ORDER[run_key], alpha=0.3)\n",
    "\n",
    "    # Baseline profiles. LABEL is indexed directly, as in the IQM cell, so a missing\n",
    "    # label raises KeyError instead of silently producing an unlabeled legend entry.\n",
    "    for baseline in baselines_performance_profile:\n",
    "        performance_profile, performance_profile_confidence_interval = compute_performance_profile_and_confidence_interval(baselines_scores[baseline], taus)\n",
    "        lines.append(ax.plot(taus, performance_profile, label=LABEL[baseline], color=COLORS[baseline], zorder=ORDER[baseline])[0])\n",
    "        if plot_std:\n",
    "            ax.fill_between(taus, performance_profile_confidence_interval[0, :], performance_profile_confidence_interval[1, :], color=COLORS[baseline], zorder=ORDER[baseline], alpha=0.3)\n",
    "\n",
    "    ax.grid(zorder=0)\n",
    "    ax.set_xlabel(r\"Human Normalized Score $(\\tau)$\")\n",
    "    ax.set_ylabel(r\"Fraction of runs with score $> \\tau$\")\n",
    "    # The legend lives in its own figure so it can be saved separately from the plot.\n",
    "    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=len(lines))\n",
    "    # Output goes under the first experiment's folder.\n",
    "    if len(experiments) > 0:\n",
    "        fig.savefig(f\"figures/{experiments[0]}/P.pdf\", bbox_inches='tight')\n",
    "        _ = fig_legend.savefig(f\"figures/{experiments[0]}/P_legend.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Head std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from idqn.utils.process_scores import compute_iqm_and_confidence_interval\n",
    "from experiments.atari import COLORS, LABEL\n",
    "\n",
    "\n",
    "if add_head_std:\n",
    "    plt.rc(\"font\", size=15)\n",
    "    plt.rc(\"lines\", linewidth=3)\n",
    "    # These figure names are reused by the other plotting cells. clear=True guarantees a\n",
    "    # blank canvas when a figure with the same name is still open (e.g. a backend that\n",
    "    # keeps figures alive between cells), so curves and legends never stack up.\n",
    "    fig = plt.figure(\"Main figure\", clear=True)\n",
    "    ax = fig.add_subplot(111)\n",
    "    fig_legend = plt.figure(\"Legend figure\", clear=True)\n",
    "    lines = []\n",
    "\n",
    "    # Load the {k}_S_{seed}.npy files into NaN-initialised (200, len(seeds)) tables,\n",
    "    # keyed by \"<experiment>_<k>\" and then by game.\n",
    "    head_stds = {}\n",
    "    for experiment in experiments:\n",
    "        for k in ks:\n",
    "            run_key = f\"{experiment}_{k}\"\n",
    "            head_stds[run_key] = {}\n",
    "            for game in games:\n",
    "                head_stds[run_key][game] = np.full((200, len(seeds)), np.nan)\n",
    "                for idx_seed, seed in enumerate(seeds):\n",
    "                    head_stds[run_key][game][:, idx_seed] = np.load(f\"figures/{experiment}/{game}/iDQN/{k}_S_{seed}.npy\")\n",
    "\n",
    "\n",
    "    # One IQM curve (plus optional confidence band) per experiment/k pair; normalize=False\n",
    "    # is passed because these values are not game scores.\n",
    "    for experiment in experiments:\n",
    "        for k in ks:\n",
    "            run_key = f\"{experiment}_{k}\"\n",
    "            iqms, iqms_confidence_interval = compute_iqm_and_confidence_interval(head_stds[run_key], selected_epochs, normalize=False)\n",
    "            lines.append(ax.plot(selected_epochs + 1, iqms, label=f\"iDQN {LABEL[run_key]}\", color=COLORS[run_key])[0])\n",
    "            if plot_std:\n",
    "                ax.fill_between(selected_epochs + 1, iqms_confidence_interval[0, :], iqms_confidence_interval[1, :], color=COLORS[run_key], alpha=0.3)\n",
    "\n",
    "\n",
    "    ax.grid(zorder=0)\n",
    "    ax.set_xlabel(\"Number of Frames (in millions)\")\n",
    "    ax.set_ylabel(\"IQM inter-head standard deviation\")\n",
    "    # The legend lives in its own figure so it can be saved separately from the plot.\n",
    "    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=len(lines))\n",
    "    # Files are only written for a single-game run; with several games the figure is displayed but not saved.\n",
    "    if len(games) == 1 and len(experiments) > 0:\n",
    "        ax.set_title(games[0])\n",
    "        fig.savefig(f\"figures/{experiments[0]}/{games[0]}/S.pdf\", bbox_inches='tight')\n",
    "        _ = fig_legend.savefig(f\"figures/{experiments[0]}/{games[0]}/S_legend.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Approximation error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from idqn.utils.process_scores import compute_iqm_and_confidence_interval\n",
    "from experiments.atari import COLORS, LABEL\n",
    "\n",
    "\n",
    "if add_approximation_error:\n",
    "    plt.rc(\"font\", size=15)\n",
    "    plt.rc(\"lines\", linewidth=3)\n",
    "    # These figure names are reused by the other plotting cells. clear=True guarantees a\n",
    "    # blank canvas when a figure with the same name is still open (e.g. a backend that\n",
    "    # keeps figures alive between cells), so curves and legends never stack up.\n",
    "    fig = plt.figure(\"Main figure\", clear=True)\n",
    "    ax = fig.add_subplot(111)\n",
    "    fig_legend = plt.figure(\"Legend figure\", clear=True)\n",
    "    lines = []\n",
    "\n",
    "    # Load the {k}_A_{seed}.npy files into NaN-initialised (200, len(seeds)) tables,\n",
    "    # keyed by \"<experiment>_<k>\" and then by game.\n",
    "    approximation_errors = {}\n",
    "    for experiment in experiments:\n",
    "        for k in ks:\n",
    "            run_key = f\"{experiment}_{k}\"\n",
    "            approximation_errors[run_key] = {}\n",
    "            for game in games:\n",
    "                approximation_errors[run_key][game] = np.full((200, len(seeds)), np.nan)\n",
    "                for idx_seed, seed in enumerate(seeds):\n",
    "                    approximation_errors[run_key][game][:, idx_seed] = np.load(f\"figures/{experiment}/{game}/iDQN/{k}_A_{seed}.npy\")\n",
    "\n",
    "\n",
    "    # One IQM curve (plus optional confidence band) per experiment/k pair; normalize=False\n",
    "    # is passed because these values are not game scores. Runs with k <= 1 are labelled \"DQN (Adam)\".\n",
    "    for experiment in experiments:\n",
    "        for k in ks:\n",
    "            run_key = f\"{experiment}_{k}\"\n",
    "            iqms, iqms_confidence_interval = compute_iqm_and_confidence_interval(approximation_errors[run_key], selected_epochs, normalize=False)\n",
    "            lines.append(ax.plot(selected_epochs + 1, iqms, label=f\"iDQN {LABEL[run_key]}\" if k > 1 else \"DQN (Adam)\", color=COLORS[run_key])[0])\n",
    "            if plot_std:\n",
    "                ax.fill_between(selected_epochs + 1, iqms_confidence_interval[0, :], iqms_confidence_interval[1, :], color=COLORS[run_key], alpha=0.3)\n",
    "\n",
    "\n",
    "    ax.grid(zorder=0)\n",
    "    ax.set_xlabel(\"Number of Frames (in millions)\")\n",
    "    ax.set_ylabel(\"IQM approximation error\")\n",
    "    # The legend lives in its own figure so it can be saved separately from the plot.\n",
    "    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=len(lines))\n",
    "    # Files are only written for a single-game run; with several games the figure is displayed but not saved.\n",
    "    if len(games) == 1 and len(experiments) > 0:\n",
    "        ax.set_title(games[0])\n",
    "        fig.savefig(f\"figures/{experiments[0]}/{games[0]}/A.pdf\", bbox_inches='tight')\n",
    "        _ = fig_legend.savefig(f\"figures/{experiments[0]}/{games[0]}/A_legend.pdf\", bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_cpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
