{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ab35b148-af7b-4b8b-b5ab-0b62e3c05c9a",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87770b33-86f9-4c8c-aeb1-17c7732d3d85",
   "metadata": {},
   "source": [
    "Credits\n",
    "- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)\n",
    "- [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo)\n",
    "- [Gymnasium](https://gymnasium.farama.org/)\n",
    "- [rliable](https://github.com/google-research/rliable)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "956ce510-e4d1-48db-acfa-68edca31e2f5",
   "metadata": {},
   "source": [
    "### Import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a87068ef-8aa6-4bec-a6d9-51edcaa248d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.evaluation_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc470a00-8323-48b2-95b2-da4c1c58a13e",
   "metadata": {},
   "source": [
    "### Define configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "427e10ec-42fb-406d-85ba-57c30b38ede1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# global theme for plotting\n",
    "THEME = 'rliable' # recommended\n",
    "THEME_CHOICES = ['rliable', 'rl_zoo3', 'default']\n",
    "# fail fast on a typo instead of handing an unlisted theme name to set_theme()\n",
    "assert THEME in THEME_CHOICES, f\"THEME must be one of {THEME_CHOICES}, got {THEME!r}\"\n",
    "\n",
    "# text rendering in matplotlib (default is mathtext)\n",
    "USE_LATEX = True\n",
    "\n",
    "# evaluation npz log info\n",
    "# set 'BASE_LOG_PATH' to the '--log-folder' that you set during training\n",
    "# if you wish to regenerate the post-processed evaluation results, e.g., post_processed_results_56.pkl\n",
    "BASE_LOG_PATH = '' # to be set by user\n",
    "ALGO = 'ppo'\n",
    "WILDCARD = '**'\n",
    "EVALUATION_FILE = 'evaluations.npz'\n",
    "N_EVALUATION_FILE = 1 # per test case (i.e., per game, self-attention type, seed)\n",
    "N_EVALUATION_CHECKPOINT = 50 # total no. of eval checkpoints\n",
    "N_EVALUATION_EPISODE = 5 # no. of episodes per eval checkpoint\n",
    "EVALUATION_TIMESTEP_KEY = 'timesteps'\n",
    "EVALUATION_RESULT_KEY = 'results'\n",
    "EVALUATION_EP_LEN_KEY = 'ep_lengths'\n",
    "\n",
    "# experiment variables\n",
    "NUM_TIMESTEPS = 1e7 # this is the total timesteps set in the ppo.yml\n",
    "GAMES = [] # all games with 'NoFrameskip-v4' suffix (filled in by a later cell)\n",
    "GAMES_56 = [] # games focused in this paper (filled in by a later cell)\n",
    "GAMES_EXCLUDED_IN_AGENT57_NO_VERSION = ['Adventure', 'AirRaid', 'Carnival', 'ElevatorAction', 'JourneyEscape', 'Pooyan'] # games excluded by Agent57\n",
    "GAME_VERSION = 'NoFrameskip-v4'\n",
    "SELF_ATTN_TYPES = ['NA', 'SWA', 'CWRA', 'CWCA', 'CWRCA']\n",
    "SEEDS = ['0', '1', '10', '42', '1234']\n",
    "# derived from SEEDS so the two can never drift apart: {'0':0, '1':1, '10':2, '42':3, '1234':4}\n",
    "SEEDS_TO_RUNS = {seed: run_idx for run_idx, seed in enumerate(SEEDS)} # map seeds to column indices in 'mean_per_eval' array\n",
    "\n",
    "# post-processed evaluation results in pickle file format\n",
    "PICKLE_FILE_PATH_56 = 'data/post_processed_results_56.pkl'\n",
    "\n",
    "# smoothing factors\n",
    "SMOOTH_WEIGHT = 0.98  # weight used in Debiased Exponential Moving Average (DEMA)\n",
    "\n",
    "# figures\n",
    "FIGURE_PATH_EVALUATION = 'figures'\n",
    "LINEPLOT_PATH = 'lineplot' # learning curves\n",
    "RLIABLE_PATH = 'rliable' # overall performance with stratified bootstrap CIs\n",
    "DEMA_PATH = 'debiased_ema'\n",
    "UNSMOOTH_PATH = 'unsmoothed'\n",
    "FIGURE_EXT = '.pdf'\n",
    "DPI = 300\n",
    "\n",
    "# tables\n",
    "TABLE_PATH_EVALUATION = 'tables'\n",
    "TABLE_MARKDOWN_EXT = '.md' # for results table in markdown format\n",
    "TABLE_LATEX_EXT = '.txt' # for results table in latex format\n",
    "\n",
    "# winners\n",
    "WINNER_PATH_EVALUATION = 'winners'\n",
    "WINNER_EXT = '.json' # for winning games"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57802ca6-9d91-4abf-8fea-d25d5978cf50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get all Atari games whose id contains GAME_VERSION ('NoFrameskip-v4'), skipping the '-ram' variants\n",
    "# str() keeps the membership test valid when a registry entry_point is a callable or None instead of a string\n",
    "atari_game_list = [\n",
    "    env_id for env_id, env_spec in gym.envs.registry.items()\n",
    "    if GAME_VERSION in env_id and '-ram' not in env_id and 'AtariEnv' in str(env_spec.entry_point)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b249bae-ad83-4c57-a67f-d15327c6ad64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GAMES: an independent copy of the discovered Atari game list\n",
    "GAMES = list(atari_game_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2d4cae1-0b94-4040-b0c4-436768037bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# append GAME_VERSION to every name in GAMES_EXCLUDED_IN_AGENT57_NO_VERSION so the ids match the entries of GAMES\n",
    "GAMES_EXCLUDED_IN_AGENT57 = [f\"{game}{GAME_VERSION}\" for game in GAMES_EXCLUDED_IN_AGENT57_NO_VERSION]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29cb71a2-41b7-4816-8fcf-d4733593a4ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GAMES_56: every game in GAMES that is not on the Agent57 exclusion list (order of GAMES is kept)\n",
    "GAMES_56 = list(filter(lambda game: game not in GAMES_EXCLUDED_IN_AGENT57, GAMES))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "170ea560-162e-4616-9fc9-2535c0cc4fae",
   "metadata": {},
   "source": [
    "### Set theme for all the plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "468550df-a32b-412f-bf13-93b75088519e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply the plotting theme chosen in the config cell (THEME)\n",
    "# set_theme is brought into scope by the star import of scripts.evaluation_utils\n",
    "set_theme(THEME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d73c256-9b73-4c4e-a006-6887942ff2af",
   "metadata": {},
   "source": [
    "### Re-generate post-processed results (optional)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d765f94-954e-447d-8b6e-bca1836cf9a8",
   "metadata": {},
   "source": [
    "You may uncomment the code in the following cell to re-generate the post-processed evaluation results. Remember to set `BASE_LOG_PATH` in the `Define configs` section first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ab24c64-e831-45ca-bdee-e1cae92dc272",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment the block below to re-generate the results from the raw evaluation logs\n",
    "# (set BASE_LOG_PATH in the config cell first).\n",
    "# game_list = GAMES_56.copy()\n",
    "# sat_list = SELF_ATTN_TYPES.copy()\n",
    "# seed_list = SEEDS.copy()\n",
    "# learning_curves = False # set to True to see all learning curves\n",
    "# results_table = False # set to True to see the results table\n",
    "# verbose = False # Set to True to see more logs\n",
    "# save_file = True\n",
    "# output_fp = PICKLE_FILE_PATH_56\n",
    "\n",
    "# post_processed_results_56 = load_all_evaluation_files(game_list, sat_list, seed_list, \\\n",
    "#                                                       learning_curves=learning_curves, results_table=results_table, verbose=verbose, \\\n",
    "#                                                       save_file=save_file, output_fp=output_fp)\n",
    "\n",
    "# Every later cell reads `post_processed_results_56`, so make sure it exists on a fresh kernel:\n",
    "# if nothing above (or the star import) defined it, fall back to the cached file at PICKLE_FILE_PATH_56.\n",
    "# NOTE(review): assumes the cache was written with pickle.dump -- confirm against load_all_evaluation_files.\n",
    "# pickle can execute arbitrary code on load, so only point PICKLE_FILE_PATH_56 at a file you generated yourself.\n",
    "import pickle\n",
    "\n",
    "if 'post_processed_results_56' not in globals():\n",
    "    with open(PICKLE_FILE_PATH_56, 'rb') as pickle_file:\n",
    "        post_processed_results_56 = pickle.load(pickle_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bef37ff-2ce9-43c6-bef8-32ae46c010e8",
   "metadata": {},
   "source": [
    "### Create results table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10180d7b-bf52-4454-9bff-26cfaec328a6",
   "metadata": {},
   "source": [
    "Get the results table using **all evaluation scores**, i.e., `use_last=False`, in **Markdown** format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "836add67-0b5f-4072-91fb-880beb2ae312",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# results table over all evaluation scores (use_last=False), requested in Markdown form;\n",
    "# the flags below also ask for the table to be saved as .md and the winning games as .json\n",
    "results_table_df_56_all, results_table_markdown_56_all, winning_games_56_all = get_results_table(\n",
    "    post_processed_results_56,\n",
    "    use_last=False,\n",
    "    markdown=True,\n",
    "    latex=False,\n",
    "    save_table_as_md=True,\n",
    "    save_table_as_txt=False,\n",
    "    save_winner_as_json=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6091d8df-9765-4ffc-918b-fc51b9d582fc",
   "metadata": {},
   "source": [
    "You may copy and paste the above cell output into this Markdown cell to better visualize the table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ddcc042-f4aa-403a-95ad-aae40a1efbe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# check the winning games for each sat\n",
    "# winning_games_56_all is indexed by SAT name; \\033[1m ... \\033[0m are the ANSI bold / reset escapes around that name\n",
    "for sat in SELF_ATTN_TYPES:\n",
    "    print(f\"'\\033[1m{sat}\\033[0m': {winning_games_56_all[sat]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7306b325-9f6c-4cd9-977a-3697b6886354",
   "metadata": {},
   "source": [
    "Get the results table using **all evaluation scores**, i.e., `use_last=False`, in **LaTeX** format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaedf334-a454-4809-a992-8be8ba4d0fd7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# results table over all evaluation scores (use_last=False), requested in LaTeX form and saved as .txt\n",
    "# NOTE(review): this rebinds results_table_df_56_all, results_table_markdown_56_all and winning_games_56_all\n",
    "# from the Markdown cell above; what the second return value holds when markdown=False is not visible here\n",
    "results_table_df_56_all, results_table_markdown_56_all, winning_games_56_all = get_results_table(\n",
    "    post_processed_results_56,\n",
    "    use_last=False,\n",
    "    markdown=False,\n",
    "    latex=True,\n",
    "    save_table_as_md=False,\n",
    "    save_table_as_txt=True,\n",
    "    save_winner_as_json=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1315b8c3-e0ff-4a40-8ee1-3accb458096d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# display the first value returned by the most recent get_results_table call\n",
    "# (the `_df` suffix suggests a pandas DataFrame -- not confirmed from this notebook)\n",
    "results_table_df_56_all"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1323e3d-eeda-449b-bcb6-1348a87139d1",
   "metadata": {},
   "source": [
    "### Plot performance per game"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e725d1b7-1434-47e3-aea5-4cbf7bc74a76",
   "metadata": {},
   "source": [
    "Without smoothing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac58f1bc-a335-4717-a863-253c19385b1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-game learning curves, smoothing switched off\n",
    "results = post_processed_results_56\n",
    "game_list = GAMES_56.copy()\n",
    "sat_list = SELF_ATTN_TYPES.copy()\n",
    "seed_list = SEEDS.copy()\n",
    "smoothing = False\n",
    "smooth_weight = SMOOTH_WEIGHT # config-cell constant (0.98) instead of a repeated literal\n",
    "no_million = False\n",
    "hue = 'a'\n",
    "n_boot = 10\n",
    "seed_boot = 42\n",
    "linewidth = 2\n",
    "suptitle = None\n",
    "fontsize_suptitle = 30\n",
    "position_suptitle = (0.5, 1.1)\n",
    "figsize = (40,60)\n",
    "fontsize_subtitle = 20\n",
    "ncols = 6\n",
    "fontsize_legend = 20\n",
    "legend_title = 'SAT'\n",
    "fontsize_legend_title = 22\n",
    "savefig = True\n",
    "\n",
    "plot_grouped_results_table(\n",
    "    results, game_list, sat_list, seed_list,\n",
    "    smoothing=smoothing, smooth_weight=smooth_weight, no_million=no_million,\n",
    "    hue=hue, n_boot=n_boot, seed_boot=seed_boot, linewidth=linewidth,\n",
    "    suptitle=suptitle, fontsize_suptitle=fontsize_suptitle, position_suptitle=position_suptitle,\n",
    "    figsize=figsize, fontsize_subtitle=fontsize_subtitle, ncols=ncols,\n",
    "    fontsize_legend=fontsize_legend, legend_title=legend_title, fontsize_legend_title=fontsize_legend_title,\n",
    "    savefig=savefig,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e9ef436-02f2-4c1f-b7b5-94310441eb96",
   "metadata": {},
   "source": [
    "With DEMA smoothing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a68019ed-20b6-4ae1-a154-88a1df167cb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-game learning curves, smoothing switched on\n",
    "results = post_processed_results_56\n",
    "game_list = GAMES_56.copy()\n",
    "sat_list = SELF_ATTN_TYPES.copy()\n",
    "seed_list = SEEDS.copy()\n",
    "smoothing = True\n",
    "smooth_weight = SMOOTH_WEIGHT # config-cell DEMA weight (0.98) instead of a repeated literal\n",
    "no_million = False\n",
    "hue = 'a'\n",
    "n_boot = 10\n",
    "seed_boot = 42\n",
    "linewidth = 2\n",
    "suptitle = None\n",
    "fontsize_suptitle = 30\n",
    "position_suptitle = (0.5, 1.1)\n",
    "figsize = (40,60)\n",
    "fontsize_subtitle = 20\n",
    "ncols = 6\n",
    "fontsize_legend = 20\n",
    "legend_title = 'SAT'\n",
    "fontsize_legend_title = 22\n",
    "savefig = True\n",
    "\n",
    "plot_grouped_results_table(\n",
    "    results, game_list, sat_list, seed_list,\n",
    "    smoothing=smoothing, smooth_weight=smooth_weight, no_million=no_million,\n",
    "    hue=hue, n_boot=n_boot, seed_boot=seed_boot, linewidth=linewidth,\n",
    "    suptitle=suptitle, fontsize_suptitle=fontsize_suptitle, position_suptitle=position_suptitle,\n",
    "    figsize=figsize, fontsize_subtitle=fontsize_subtitle, ncols=ncols,\n",
    "    fontsize_legend=fontsize_legend, legend_title=legend_title, fontsize_legend_title=fontsize_legend_title,\n",
    "    savefig=savefig,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d873073e-cd7c-447c-b3c3-c0bf3534ce37",
   "metadata": {},
   "source": [
    "### rliable plots"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76ecbe10-5464-4156-8407-5e1496e17f36",
   "metadata": {},
   "source": [
    "Copy the Human and Random scores from [deep_rl_precipice_colab.ipynb](https://github.com/google-research/rliable/blob/master/deep_rl_precipice_colab.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6d4db82-f2d4-4040-9182-9e279633bb66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reference score table copied from the rliable colab linked above: one game per line, whitespace-separated\n",
    "# NOTE(review): the first two numeric columns are presumably the Human and Random scores (per the markdown above);\n",
    "# the meaning of the remaining columns is not shown in this notebook -- confirm against get_reference_scores\n",
    "score_str = \"\"\"alien 7127.70 227.80 297638.17 ± 37054.55 464232.43 ± 7988.66 741812.63\n",
    "amidar 1719.50 5.80 29660.08 ± 880.39 31331.37 ± 817.79 28634.39\n",
    "assault 742.00 222.40 67212.67 ± 6150.59 110100.04 ± 346.06 143972.03\n",
    "asterix 8503.30 210.00 991384.42 ± 9493.32 999354.03 ± 12.94 998425.00\n",
    "asteroids 47388.70 719.10 150854.61 ± 16116.72 431072.45 ± 1799.13 6785558.64\n",
    "atlantis 29028.10 12850.00 1528841.76 ± 28282.53 1660721.85 ± 14643.83 1674767.20\n",
    "bank_heist 753.10 14.20 23071.50 ± 15834.73 27117.85 ± 963.12 1278.98\n",
    "battle_zone 37187.50 2360.00 934134.88 ± 38916.03 992600.31 ± 1096.19 848623.00\n",
    "beam_rider 16926.50 363.90 300509.80 ± 13075.35 390603.06 ± 23304.09 4549993.53\n",
    "berzerk 2630.40 123.70 61507.83 ± 26539.54 77725.62 ± 4556.93 85932.60\n",
    "bowling 160.70 23.10 251.18 ± 13.22 161.77 ± 99.84 260.13\n",
    "boxing 12.10 0.10 100.00 ± 0.00 100.00 ± 0.00 100.00\n",
    "breakout 30.50 1.70 790.40 ± 60.05 863.92 ± 0.08 864.00\n",
    "centipede 12017.00 2090.90 412847.86 ± 26087.14 908137.24 ± 7330.99 1159049.27\n",
    "chopper_command 7387.80 811.00 999900.00 ± 0.00 999900.00 ± 0.00 991039.70\n",
    "crazy_climber 35829.40 10780.50 565909.85 ± 89183.85 729482.83 ± 87975.74 458315.40\n",
    "defender 18688.90 2874.50 677642.78 ± 16858.59 730714.53 ± 715.54 839642.95\n",
    "demon_attack 1971.00 152.10 143161.44 ± 220.32 143913.32 ± 92.93 143964.26\n",
    "double_dunk -16.40 -18.60 23.93 ± 0.06 24.00 ± 0.00 23.94\n",
    "enduro 860.50 0.00 2367.71 ± 8.69 2378.66 ± 3.66 2382.44\n",
    "fishing_derby -38.70 -91.70 86.97 ± 3.25 90.34 ± 2.66 91.16\n",
    "freeway 29.60 0.00 32.59 ± 0.71 34.00 ± 0.00 33.03\n",
    "frostbite 4334.70 65.20 541280.88 ± 17485.76 309077.30 ± 274879.03 631378.53\n",
    "gopher 2412.50 257.60 117777.08 ± 3108.06 129736.13 ± 653.03 130345.58\n",
    "gravitar 3351.40 173.00 19213.96 ± 348.25 21068.03 ± 497.25 6682.70\n",
    "hero 30826.40 1027.00 114736.26 ± 49116.60 49339.62 ± 4617.76 49244.11\n",
    "ice_hockey 0.90 -11.20 63.64 ± 6.48 86.59 ± 0.59 67.04\n",
    "jamesbond 302.80 29.00 135784.96 ± 9132.28 158142.36 ± 904.45 41063.25\n",
    "kangaroo 3035.00 52.00 24034.16 ± 12565.88 18284.99 ± 817.25 16763.60\n",
    "krull 2665.50 1598.00 251997.31 ± 20274.39 245315.44 ± 48249.07 269358.27\n",
    "kung_fu_master 22736.30 258.50 206845.82 ± 11112.10 267766.63 ± 2895.73 204824.00\n",
    "montezuma_revenge 4753.30 0.00 9352.01 ± 2939.78 3000.00 ± 0.00 0.00\n",
    "ms_pacman 6951.60 307.30 63994.44 ± 6652.16 62595.90 ± 1755.82 243401.10\n",
    "name_this_game 8049.00 2292.30 54386.77 ± 6148.50 138030.67 ± 5279.91 157177.85\n",
    "phoenix 7242.60 761.40 908264.15 ± 28978.92 990638.12 ± 6278.77 955137.84\n",
    "pitfall 6463.70 -229.40 18756.01 ± 9783.91 0.00 ± 0.00 0.00\n",
    "pong 14.60 -20.70 20.67 ± 0.47 21.00 ± 0.00 21.00\n",
    "private_eye 69571.30 24.90 79716.46 ± 29515.48 40700.00 ± 0.00 15299.98\n",
    "qbert 13455.00 163.90 580328.14 ± 151251.66 777071.30 ± 190653.94 72276.00\n",
    "riverraid 17118.00 1338.50 63318.67 ± 5659.55 93569.66 ± 13308.08 323417.18\n",
    "road_runner 7845.00 11.50 243025.80 ± 79555.98 593186.78 ± 88650.69 613411.80\n",
    "robotank 11.90 2.20 127.32 ± 12.50 144.00 ± 0.00 131.13\n",
    "seaquest 42054.70 68.40 999997.63 ± 1.42 999999.00 ± 0.00 999976.52\n",
    "skiing -4336.90 -17098.10 -4202.60 ± 607.85 -3851.44 ± 517.52 -29968.36\n",
    "solaris 12326.70 1236.30 44199.93 ± 8055.50 67306.29 ± 10378.22 56.62\n",
    "space_invaders 1668.70 148.00 48680.86 ± 5894.01 67898.71 ± 1744.74 74335.30\n",
    "star_gunner 10250.00 664.00 839573.53 ± 67132.17 998600.28 ± 218.66 549271.70\n",
    "surround 6.50 -10.00 9.50 ± 0.19 10.00 ± 0.00 9.99\n",
    "tennis -8.30 -23.80 23.84 ± 0.10 24.00 ± 0.00 0.00\n",
    "time_pilot 5229.20 3568.00 405425.31 ± 17044.45 460596.49 ± 3139.33 476763.90\n",
    "tutankham 167.60 11.40 2354.91 ± 3421.43 483.78 ± 37.90 491.48\n",
    "up_n_down 11693.20 533.40 623805.73 ± 23493.75 702700.36 ± 8937.59 715545.61\n",
    "venture 1187.50 0.00 2623.71 ± 442.13 2258.93 ± 29.90 0.40\n",
    "video_pinball 17667.90 0.00 992340.74 ± 12867.87 999645.92 ± 57.93 981791.88\n",
    "wizard_of_wor 4756.50 563.50 157306.41 ± 16000.00 183090.81 ± 6070.10 197126.00\n",
    "yars_revenge 54576.90 3092.90 998532.37 ± 375.82 999807.02 ± 54.85 553311.46\n",
    "zaxxon 9173.30 32.50 249808.90 ± 58261.59 370649.03 ± 19761.32 725853.90\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac00edd0-98d4-4ffd-b401-bffed01db664",
   "metadata": {},
   "outputs": [],
   "source": [
    "# parse score_str into the two reference-score objects used for human normalisation\n",
    "# NOTE(review): the skip list uses 'Surround' while score_str spells games in lower snake_case ('surround');\n",
    "# confirm that get_reference_scores normalises names before comparing\n",
    "ALL_HUMAN_SCORES, ALL_RANDOM_SCORES = get_reference_scores(score_str, games_to_skip=['Surround'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d83fbc0-987d-4ccd-a2f8-76f52f4415eb",
   "metadata": {},
   "source": [
    "Obtain Human Normalized Scores (HNS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e5b811b-44d9-4968-ab8f-086a7007b066",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inputs for the human-normalised-score (HNS) computation\n",
    "results = post_processed_results_56\n",
    "sat_list = SELF_ATTN_TYPES.copy()\n",
    "game_list = GAMES_56.copy()\n",
    "# n_games, n_runs and n_eval are not passed to get_hns_dict below; they are only set as notebook globals here\n",
    "n_games = len(game_list)\n",
    "n_runs = len(SEEDS)\n",
    "n_eval = N_EVALUATION_CHECKPOINT\n",
    "\n",
    "# three HNS dictionaries are returned: last-eval, mean-eval and all-eval (per the target names)\n",
    "last_eval_hns_dict_56, mean_eval_hns_dict_56, all_eval_hns_dict_56 = get_hns_dict(results, sat_list, game_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e639c1a0-0e23-43c2-8de5-b87771491417",
   "metadata": {},
   "source": [
    "#### Aggregate performance"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fee42177-064b-48a5-99c4-ca0747b06d9c",
   "metadata": {},
   "source": [
    "Use `mean evals`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "429db2ba-86b4-4d33-a456-367cbb023a9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use mean evals\n",
    "hns_dicts = [last_eval_hns_dict_56, mean_eval_hns_dict_56]\n",
    "use_last = False\n",
    "task_bootstrap = False\n",
    "reps = 2000\n",
    "# fixed seed (same value as seed_boot in the line-plot cells) so the figure is reproducible on re-run\n",
    "# NOTE(review): assumes `seed` feeds the bootstrap RNG and that None meant unseeded -- confirm in scripts.evaluation_utils\n",
    "seed = 42\n",
    "subfigure_width = 3.5\n",
    "row_height = 0.6\n",
    "xlabel_y_coord = -0.02\n",
    "interval_height = 0.6\n",
    "wspace = 0.11\n",
    "adjust_bottom = 0.2\n",
    "savefig = True\n",
    "figname = \"Aggregate performance (bootstrap over runs and use mean evals)\"\n",
    "\n",
    "plot_aggregate_performance(\n",
    "    hns_dicts, use_last=use_last, task_bootstrap=task_bootstrap, reps=reps, seed=seed,\n",
    "    subfigure_width=subfigure_width, row_height=row_height, xlabel_y_coord=xlabel_y_coord,\n",
    "    interval_height=interval_height, wspace=wspace, adjust_bottom=adjust_bottom,\n",
    "    savefig=savefig, figname=figname,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59a0e0e9-1b5b-4edd-ace1-471356c9c0b5",
   "metadata": {},
   "source": [
    "#### Performance profile"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55936323-2652-4a0b-b523-b83156671691",
   "metadata": {},
   "source": [
    "Use `mean evals`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c9b27d-cc31-43d7-b5a7-cd530a9cc0f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use mean evals, use score dist, with inset plot\n",
    "hns_dicts = [last_eval_hns_dict_56, mean_eval_hns_dict_56]\n",
    "use_last = False\n",
    "tau_start = 0\n",
    "tau_stop = 8\n",
    "tau_num = 81\n",
    "use_score_distribution = True\n",
    "task_bootstrap = False\n",
    "reps = 2000\n",
    "# fixed seed (same value as seed_boot in the line-plot cells) so the figure is reproducible on re-run\n",
    "# NOTE(review): assumes `seed` feeds the bootstrap RNG and that None meant unseeded -- confirm in scripts.evaluation_utils\n",
    "seed = 42\n",
    "alpha = 0.15\n",
    "figsize = (7, 5)\n",
    "linestyles=None\n",
    "linewidth = 2.0\n",
    "inset = True\n",
    "inset_x_coord = 0.23\n",
    "inset_y_coord = 0.4\n",
    "inset_width = 0.55\n",
    "inset_height = 0.45\n",
    "inset_xlim_lower = 1\n",
    "inset_xlim_upper = 4\n",
    "inset_ylim_lower = 0.0\n",
    "inset_ylim_upper = 0.25\n",
    "inset_xticks = [1, 2, 3, 4]\n",
    "legend_loc = 'best'\n",
    "savefig = True # to save the figure, turn on plt.tight_layout()\n",
    "figname = \"Performance profiles with inset (bootstrap over runs and use mean evals, use score dist)\"\n",
    "\n",
    "plot_performance_profile(\n",
    "    hns_dicts, use_last=use_last, tau_start=tau_start, tau_stop=tau_stop, tau_num=tau_num,\n",
    "    use_score_distribution=use_score_distribution, task_bootstrap=task_bootstrap,\n",
    "    reps=reps, seed=seed, figsize=figsize, alpha=alpha, linestyles=linestyles, linewidth=linewidth,\n",
    "    inset=inset, inset_x_coord=inset_x_coord, inset_y_coord=inset_y_coord, inset_width=inset_width, inset_height=inset_height,\n",
    "    inset_xlim_lower=inset_xlim_lower, inset_xlim_upper=inset_xlim_upper,\n",
    "    inset_ylim_lower=inset_ylim_lower, inset_ylim_upper=inset_ylim_upper, inset_xticks=inset_xticks,\n",
    "    legend_loc=legend_loc, savefig=savefig, figname=figname,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f31b416-d6cf-4c41-97e0-32d6fcf7c4e8",
   "metadata": {},
   "source": [
    "#### Probability of improvement"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9999a2d-bc4f-4da6-ba46-a46a22b439e3",
   "metadata": {},
   "source": [
    "Use `mean evals`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a80b5cf-bafb-4f60-b4f9-1afb782d6955",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use mean evals\n",
    "hns_dicts = [last_eval_hns_dict_56, mean_eval_hns_dict_56]\n",
    "use_last = False\n",
    "task_bootstrap = False\n",
    "reps = 1000\n",
    "# fixed seed (same value as seed_boot in the line-plot cells) so the figure is reproducible on re-run\n",
    "# NOTE(review): assumes `seed` feeds the bootstrap RNG and that None meant unseeded -- confirm in scripts.evaluation_utils\n",
    "seed = 42\n",
    "figsize = (8, 6)\n",
    "alpha = 0.75\n",
    "interval_height = 0.6\n",
    "wrect = 5\n",
    "ticklabelsize = 'x-large'\n",
    "labelsize = 'x-large'\n",
    "ylabel_x_coordinate = 0.08\n",
    "savefig = True\n",
    "figname = \"Probability of improvement (bootstrap over runs and use mean evals)\"\n",
    "\n",
    "plot_prob_of_improvement(\n",
    "    hns_dicts, use_last=use_last, task_bootstrap=task_bootstrap, reps=reps, seed=seed, figsize=figsize, alpha=alpha,\n",
    "    interval_height=interval_height, wrect=wrect, ticklabelsize=ticklabelsize, labelsize=labelsize,\n",
    "    ylabel_x_coordinate=ylabel_x_coordinate, savefig=savefig, figname=figname,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3397da1-e5df-4ed3-a12a-e153f57c222b",
   "metadata": {},
   "source": [
    "Use `mean evals` and set `algo_X`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e70c37ba-a01c-4a09-b765-661e10646801",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use mean evals; one figure per SAT, with that SAT passed as algo_X\n",
    "# settings that do not depend on the SAT are assigned once, outside the loop\n",
    "use_last = False\n",
    "task_bootstrap = False\n",
    "reps = 1000\n",
    "# fixed seed (same value as seed_boot in the line-plot cells) so the figures are reproducible on re-run\n",
    "# NOTE(review): assumes `seed` feeds the bootstrap RNG and that None meant unseeded -- confirm in scripts.evaluation_utils\n",
    "seed = 42\n",
    "figsize = (4, 3)\n",
    "alpha = 0.75\n",
    "interval_height = 0.6\n",
    "wrect = 5\n",
    "ticklabelsize = 'x-large'\n",
    "labelsize = 'x-large'\n",
    "ylabel_x_coordinate = 0.15\n",
    "savefig = True\n",
    "\n",
    "# iterating the list directly is enough: the loop never modifies SELF_ATTN_TYPES\n",
    "for sat in SELF_ATTN_TYPES:\n",
    "    # rebuilt on every iteration, as before, in case the callee mutates the list it receives\n",
    "    hns_dicts = [last_eval_hns_dict_56, mean_eval_hns_dict_56]\n",
    "    algo_X = sat\n",
    "    figname = f\"Probability of improvement - {algo_X} vs rest (bootstrap over runs and use mean evals)\"\n",
    "\n",
    "    plot_prob_of_improvement(\n",
    "        hns_dicts, use_last=use_last, task_bootstrap=task_bootstrap, reps=reps, seed=seed, algo_X=algo_X,\n",
    "        figsize=figsize, alpha=alpha, interval_height=interval_height, wrect=wrect, ticklabelsize=ticklabelsize,\n",
    "        labelsize=labelsize, ylabel_x_coordinate=ylabel_x_coordinate, savefig=savefig, figname=figname,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8afe29aa-8857-4bcc-a5f6-fecd4c62fa94",
   "metadata": {},
   "source": [
    "#### Sample Efficiency Curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72744cd1-6724-438c-b25d-3d1434875934",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample-efficiency curves from the all-eval HNS dictionary\n",
    "hns_dict = all_eval_hns_dict_56\n",
    "results = post_processed_results_56\n",
    "downsample_factor = 5\n",
    "task_bootstrap = False\n",
    "reps = 2000\n",
    "# fixed seed (same value as seed_boot in the line-plot cells) so the figure is reproducible on re-run\n",
    "# NOTE(review): assumes `seed` feeds the bootstrap RNG and that None meant unseeded -- confirm in scripts.evaluation_utils\n",
    "seed = 42\n",
    "figsize=(7,5)\n",
    "labelsize='xx-large'\n",
    "ticklabelsize='xx-large'\n",
    "marker = 'o'\n",
    "linewidth = 2\n",
    "savefig = True\n",
    "figname = \"Sample efficiency (bootstrap over runs)\"\n",
    "\n",
    "plot_sample_efficiency(\n",
    "    hns_dict, results, downsample_factor=downsample_factor, task_bootstrap=task_bootstrap,\n",
    "    reps=reps, seed=seed, figsize=figsize, ticklabelsize=ticklabelsize, labelsize=labelsize,\n",
    "    marker=marker, linewidth=linewidth, savefig=savefig, figname=figname,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
