{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup to analyse an MDP Playground experiment\n",
    "from mdp_playground.analysis import MDPP_Analysis\n",
    "# Set dir_name to the location where the CSV files from running an experiment were saved\n",
    "dir_name = '/home/mdpp_8828049'\n",
    "# Set exp_name to the name that was given to the experiment when running it\n",
    "exp_name = 'dqn_space_invaders_p_noise'\n",
    "# Set the following to True to save PDFs of plots that you generate below\n",
    "save_fig = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Data loading\n",
    "# IMPORTANT: If evaluation_interval is set to None in experiment config file, set load_eval to False here - this would lead to errors in the cells below that plot evaluation metrics. Please ignore those.\n",
    "mdpp_analysis = MDPP_Analysis()\n",
    "train_stats, eval_stats, train_curves, eval_curves, train_aucs, eval_aucs = mdpp_analysis.load_data(dir_name, exp_name, load_eval=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-D: Plots showing reward after 20k timesteps when varying a single meta-feature\n",
    "# Plots across 10 runs: Training: with std dev across the runs\n",
    "mdpp_analysis.plot_1d_dimensions(train_stats, save_fig)\n",
    "mdpp_analysis.plot_1d_dimensions(train_aucs, save_fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plots across 10 runs: Evaluation: with std dev across the runs\n",
    "mdpp_analysis.plot_1d_dimensions(eval_stats, save_fig, train=False)\n",
    "mdpp_analysis.plot_1d_dimensions(eval_aucs, save_fig, train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This and the next cell do the same as the previous 2 cells but plot episode mean lengths instead of episode reward\n",
    "mdpp_analysis.plot_1d_dimensions(train_stats, save_fig, metric_num=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mdpp_analysis.plot_1d_dimensions(eval_stats, save_fig, train=False, metric_num=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D heatmap plots across 10 runs: Training runs: with std dev across the runs\n",
    "# There seems to be a bug with matplotlib - x and y axes tick labels are not correctly set even though we pass them. Please feel free to look into the code and suggest a correction if you find it.\n",
    "mdpp_analysis.plot_2d_heatmap(train_stats, save_fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# 2-D heatmap plots across 10 runs: Evaluation runs: with std dev across the runs\n",
    "mdpp_analysis.plot_2d_heatmap(eval_stats, save_fig, train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot learning curves: Training: Each curve corresponds to a different seed for the agent\n",
    "mdpp_analysis.plot_learning_curves(train_curves, save_fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot learning curves: Evaluation: Each curve corresponds to a different seed for the agent\n",
    "mdpp_analysis.plot_learning_curves(eval_curves, save_fig, train=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
