{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "from experiments.atari import EXPERIMENTED_GAME\n",
    "\n",
    "\n",
    "experiments = [\"ut30_uh6000\"]\n",
    "baselines = ['Quantile (JAX)_dopamine', 'REM', 'Rainbow', 'IQN', 'M-IQN'] # ['DQN (Nature)', 'Quantile (JAX)_dopamine', 'DQN (Adam)', 'C51', 'REM', 'Rainbow', 'IQN', 'M-IQN']\n",
    "baselines_performance_profile = []\n",
    "idx_game += 1\n",
    "games = [EXPERIMENTED_GAME[idx_game]]\n",
    "ks_idqn = [5]\n",
    "ks_iiqn = []\n",
    "seeds = [11, 12, 21, 22, 13]\n",
    "selected_epochs = np.arange(200) # np.array([1, 10, 25, 50, 75, 100, 125, 150, 175, 200]) - 1\n",
    "taus = np.linspace(0.0, 8.0, 81)\n",
    "show = {\"dqn\": False, \"iqn\": False, \"head_std\": False, \"approximation_error\": False, \"std\": True}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from idqn.utils.baselines_scores import get_baselines_scores\n",
    "\n",
    "def collect_data(scores, algorithm, idqn_key_k=\"\", idqn_key_path=\"\"):\n",
    "    for experiment in experiments:\n",
    "        experiment_key = f\"{algorithm}_{experiment}{idqn_key_k}\"\n",
    "        scores[experiment_key] = {}\n",
    "        for game in games:\n",
    "            scores[experiment_key][game] = np.zeros((200, len(seeds))) * np.nan\n",
    "            for idx_seed, seed in enumerate(seeds):\n",
    "                scores[experiment_key][game][:, idx_seed] = np.load(f\"figures/{experiment}/{game}/{algorithm}/{idqn_key_path}J_{seed}.npy\")\n",
    "\n",
    "if show[\"dqn\"]:\n",
    "    dqn_scores = {}\n",
    "    collect_data(dqn_scores, \"DQN\")\n",
    "\n",
    "if show[\"iqn\"]:\n",
    "    iqn_scores = {}\n",
    "    collect_data(iqn_scores, \"IQN\")\n",
    "\n",
    "if len(ks_idqn) > 0:\n",
    "    idqn_scores = {}\n",
    "    for k in ks_idqn:\n",
    "        collect_data(idqn_scores, \"iDQN\", f\"_{k}\", f\"{k}_\")\n",
    "\n",
    "if len(ks_iiqn) > 0:\n",
    "    iiqn_scores = {}\n",
    "    for k in ks_iiqn:\n",
    "        collect_data(iiqn_scores, \"iIQN\", f\"_{k}\", f\"{k}_\")\n",
    "\n",
    "baselines_scores = get_baselines_scores(baselines, games)\n",
    "baselines_performance_profile_scores = get_baselines_scores(baselines_performance_profile, games)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IQM vs iterations & performance profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from idqn.utils.process_scores import compute_iqm_and_confidence_interval\n",
    "from experiments.atari import COLORS, LABEL, ORDER\n",
    "\n",
    "\n",
    "plt.rc(\"font\", size=18)  # 21 for main paper, 18 for the table of figures and 15 big figures.\n",
    "plt.rc(\"lines\", linewidth=3)\n",
    "fig = plt.figure(\"Main figure\")\n",
    "ax = fig.add_subplot(111)\n",
    "fig_legend = plt.figure(\"Legend figure\")\n",
    "lines = []\n",
    "\n",
    "\n",
    "def plot_iqm(scores, normalize=True):\n",
    "    for experiment in scores.keys():\n",
    "        iqms, iqms_confidence_interval = compute_iqm_and_confidence_interval(scores[experiment], selected_epochs, normalize)\n",
    "        lines.append(ax.plot(selected_epochs + 1, iqms, label=LABEL[experiment], color=COLORS[experiment], zorder=ORDER[experiment])[0])\n",
    "        if show[\"std\"]:\n",
    "            ax.fill_between(selected_epochs + 1, iqms_confidence_interval[0, :], iqms_confidence_interval[1, :], color=COLORS[experiment], zorder=ORDER[experiment], alpha=0.3)\n",
    "\n",
    "\n",
    "if len(ks_iiqn) > 0:\n",
    "    plot_iqm(iiqn_scores)\n",
    "\n",
    "if len(ks_idqn) > 0:\n",
    "    plot_iqm(idqn_scores)\n",
    "\n",
    "if show[\"dqn\"]:\n",
    "    plot_iqm(dqn_scores)\n",
    "\n",
    "if show[\"iqn\"]:\n",
    "    plot_iqm(iqn_scores)\n",
    "\n",
    "plot_iqm(baselines_scores)\n",
    "\n",
    "print(games[0])\n",
    "ax.grid(zorder=0)\n",
    "# ax.set_xticklabels([])\n",
    "ax.set_xlabel(\"Number of Frames (in millions)\")\n",
    "# ax.set_ylabel(\"IQM Human Normalized Score\")\n",
    "\n",
    "\n",
    "if len(lines) < 6:\n",
    "    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=len(lines))\n",
    "else:\n",
    "    import itertools\n",
    "    ncols = int(np.ceil(len(lines) / 2))\n",
    "    def flip(items):\n",
    "        return itertools.chain(*[items[i::ncols] for i in range(ncols)])\n",
    "    fig_legend.legend(flip(lines), flip([line.get_label() for line in lines]), ncols=ncols)\n",
    "\n",
    "if len(games) == 1 and len(experiments) > 0:\n",
    "    ax.set_title(games[0])\n",
    "    fig.savefig(f\"figures/{experiments[0]}/{games[0]}/J.pdf\", bbox_inches='tight')\n",
    "    _ = fig_legend.savefig(f\"figures/{experiments[0]}/{games[0]}/J_legend.pdf\", bbox_inches='tight')\n",
    "elif len(experiments) > 0:\n",
    "    fig.savefig(f\"figures/{experiments[0]}/J.pdf\", bbox_inches='tight')\n",
    "    _ = fig_legend.savefig(f\"figures/{experiments[0]}/J_legend.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if len(games) > 1:\n",
    "    from idqn.utils.process_scores import compute_performance_profile_and_confidence_interval\n",
    "\n",
    "    plt.rc(\"font\", size=15)\n",
    "    plt.rc(\"lines\", linewidth=3)\n",
    "    fig = plt.figure(\"Main figure\")\n",
    "    ax = fig.add_subplot(111)\n",
    "    fig_legend = plt.figure(\"Legend figure\")\n",
    "    lines = []\n",
    "\n",
    "\n",
    "    def plot_performance_profile(scores):\n",
    "        for experiment in scores.keys():\n",
    "            performance_profile, performance_profile_confidence_interval = compute_performance_profile_and_confidence_interval(scores[experiment], taus)\n",
    "            lines.append(ax.plot(taus, performance_profile, label=LABEL[experiment], color=COLORS[experiment], zorder=ORDER[experiment])[0])\n",
    "            if show[\"std\"]:\n",
    "                ax.fill_between(taus, performance_profile_confidence_interval[0, :], performance_profile_confidence_interval[1, :], color=COLORS[experiment], zorder=ORDER[experiment], alpha=0.3)\n",
    "\n",
    "\n",
    "    if show[\"dqn\"]:\n",
    "        plot_performance_profile(dqn_scores)\n",
    "\n",
    "    if show[\"iqn\"]:\n",
    "        plot_performance_profile(iqn_scores)\n",
    "\n",
    "    if len(ks_idqn) > 0:\n",
    "        plot_performance_profile(idqn_scores)\n",
    "\n",
    "    if len(ks_iiqn) > 0:\n",
    "        plot_performance_profile(iiqn_scores)\n",
    "\n",
    "    plot_performance_profile(baselines_performance_profile_scores)\n",
    "\n",
    "    ax.grid(zorder=0)\n",
    "    ax.set_xlabel(r\"Human Normalized Score $(\\tau)$\")\n",
    "    ax.set_ylabel(r\"Fraction of runs with score $> \\tau$\")\n",
    "    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=len(lines))\n",
    "    if len(experiments) > 0:\n",
    "        fig.savefig(f\"figures/{experiments[0]}/P.pdf\", bbox_inches='tight')\n",
    "        _ = fig_legend.savefig(f\"figures/{experiments[0]}/P_legend.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Head std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if show[\"head_std\"]:\n",
    "    plt.rc(\"font\", size=15)\n",
    "    plt.rc(\"lines\", linewidth=3)\n",
    "    fig = plt.figure(\"Main figure\")\n",
    "    ax = fig.add_subplot(111)\n",
    "    fig_legend = plt.figure(\"Legend figure\")\n",
    "    lines = []\n",
    "\n",
    "    head_stds = {}\n",
    "    for experiment in experiments:\n",
    "        for k in ks_idqn:\n",
    "            head_stds[f\"iDQN_{experiment}_{k}\"] = {}\n",
    "            for game in games:\n",
    "                head_stds[f\"iDQN_{experiment}_{k}\"][game] = np.zeros((200, len(seeds))) * np.nan\n",
    "                for idx_seed, seed in enumerate(seeds):\n",
    "                    head_stds[f\"iDQN_{experiment}_{k}\"][game][:, idx_seed] = np.load(f\"figures/{experiment}/{game}/iDQN/{k}_S_{seed}.npy\")\n",
    "\n",
    "\n",
    "    plot_iqm(head_stds, normalize=False)\n",
    "\n",
    "    ax.grid(zorder=0)\n",
    "    ax.set_xlabel(\"Number of Frames (in millions)\")\n",
    "    ax.set_ylabel(\"IQM inter-head standard deviation\")\n",
    "    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=len(lines))\n",
    "    if len(games) == 1 and len(experiments) > 0:\n",
    "        ax.set_title(games[0])\n",
    "        fig.savefig(f\"figures/{experiments[0]}/{games[0]}/S.pdf\", bbox_inches='tight')\n",
    "        _ = fig_legend.savefig(f\"figures/{experiments[0]}/{games[0]}/S_legend.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Approximation error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if show[\"approximation_error\"]:\n",
    "    plt.rc(\"font\", size=15)\n",
    "    plt.rc(\"lines\", linewidth=3)\n",
    "    fig = plt.figure(\"Main figure\")\n",
    "    ax = fig.add_subplot(111)\n",
    "    fig_legend = plt.figure(\"Legend figure\")\n",
    "    lines = []\n",
    "\n",
    "    approximation_errors = {}\n",
    "    for experiment in experiments:\n",
    "        for k in ks_idqn:\n",
    "            approximation_errors[f\"{experiment}_{k}\"] = {}\n",
    "            for game in games:\n",
    "                approximation_errors[f\"{experiment}_{k}\"][game] = np.zeros((200, len(seeds))) * np.nan\n",
    "                for idx_seed, seed in enumerate(seeds):\n",
    "                    approximation_errors[f\"{experiment}_{k}\"][game][:, idx_seed] = np.load(f\"figures/{experiment}/{game}/iDQN/{k}_A_{seed}.npy\")\n",
    "\n",
    "\n",
    "    plot_iqm(approximation_errors)\n",
    "\n",
    "    ax.grid(zorder=0)\n",
    "    ax.set_xlabel(\"Number of Frames (in millions)\")\n",
    "    ax.set_ylabel(\"IQM approximation error\")\n",
    "    fig_legend.legend(lines, [line.get_label() for line in lines], ncols=len(lines))\n",
    "    if len(games) == 1 and len(experiments) > 0:\n",
    "        ax.set_title(games[0])\n",
    "        fig.savefig(f\"figures/{experiments[0]}/{games[0]}/A.pdf\", bbox_inches='tight')\n",
    "        _ = fig_legend.savefig(f\"figures/{experiments[0]}/{games[0]}/A_legend.pdf\", bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_cpu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
