{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run directory; handed to the training scripts through their --dir argument in the next cell.\n",
    "path_base = \"runs/cliff_walking\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import sys\n",
    "\n",
    "# Launch every run with the interpreter that executes this kernel, so the child\n",
    "# processes use the same environment as the notebook. A bare \"python\" is looked\n",
    "# up on PATH and may resolve to a different installation.\n",
    "python_cmd = [sys.executable]\n",
    "\n",
    "# Command-line arguments shared by all training runs.\n",
    "shared_hyperparameters = [\"--env-id\", \"CliffWalkingEnv-v0\", \"--save-model\", \"--gamma\", \"0.95\", \"--n-quantile\", \"50\", \"--dir\", path_base]\n",
    "\n",
    "# One qrdqn.py configuration, four qrsrm.py configurations that differ only in\n",
    "# --alpha, and one qrsrm.py configuration that uses the WSCVaR risk measure.\n",
    "base_script1 = python_cmd + [\"qrdqn.py\"] + shared_hyperparameters\n",
    "base_script2, base_script3, base_script4, base_script5 = (\n",
    "    python_cmd + [\"qrsrm.py\", \"--alpha\", alpha] + shared_hyperparameters\n",
    "    for alpha in (\"0.1\", \"0.3\", \"0.5\", \"0.7\")\n",
    ")\n",
    "base_script6 = python_cmd + [\"qrsrm.py\", \"--risk-measure\", \"WSCVaR\", \"--alphas\", \"0.1,1.0\", \"--weights\", \"0.8,0.2\"] + shared_hyperparameters\n",
    "\n",
    "scripts = [base_script1, base_script2, base_script3, base_script4, base_script5, base_script6]\n",
    "\n",
    "# Repeat every configuration once per seed (configurations outermost, seeds innermost).\n",
    "seeds = [\"1\", \"2\", \"3\", \"4\", \"5\"]\n",
    "scripts = [script + [\"--seed\", seed] for script in scripts for seed in seeds]\n",
    "\n",
    "# Run the commands one after another. A failing run does not stop the remaining\n",
    "# ones; failures are collected so they can be reported together at the end\n",
    "# instead of being buried in the training output.\n",
    "failed_scripts = []\n",
    "for script in scripts:\n",
    "    print(f\"Running script {script}...\")\n",
    "    try:\n",
    "        subprocess.run(script, check=True)\n",
    "    except subprocess.CalledProcessError as error:\n",
    "        print(f\"Script {script} failed!\")\n",
    "        failed_scripts.append((script, error.returncode))\n",
    "\n",
    "if failed_scripts:\n",
    "    print(f\"{len(failed_scripts)} of {len(scripts)} runs failed:\")\n",
    "    for failed_script, returncode in failed_scripts:\n",
    "        print(f\"  exit code {returncode}: {failed_script}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the simulations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import run_simulation_from_dir, add_columns, make_agent_hue_kws\n",
    "from utils import load_data_from_dir, smooth_dataframe\n",
    "from qrsrm import weighted_sum_of_cvar\n",
    "from utils import AGENT_NAME_MAP\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "# NOTE(review): presumably works around a duplicate OpenMP runtime error - confirm it is still required.\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n",
    "\n",
    "import warnings\n",
    "# Ignore these three warning categories from here on.\n",
    "# NOTE(review): RuntimeWarning includes NumPy numerical warnings (e.g. the mean of an empty selection); confirm that hiding them is intended.\n",
    "warnings.filterwarnings('ignore', category=UserWarning)\n",
    "warnings.filterwarnings('ignore', category=DeprecationWarning)\n",
    "warnings.filterwarnings('ignore', category=RuntimeWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_base = \"runs/cliff_walking\"\n",
    "Nsimulations = 10000\n",
    "sim_seed = 6\n",
    "pkl_path = path_base+\"/df_exp.pkl\"\n",
    "\n",
    "# Cache the simulation results: reuse the pickle when it exists, otherwise run\n",
    "# the simulations once and store them. The file is tested directly instead of\n",
    "# listing the whole directory, and a freshly computed frame is not read back.\n",
    "# NOTE(review): the cache is not keyed on Nsimulations or sim_seed - delete the\n",
    "# pickle after changing either of them, otherwise the old results are reused.\n",
    "# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.\n",
    "if os.path.isfile(pkl_path):\n",
    "    df_exp = pd.read_pickle(pkl_path)\n",
    "else:\n",
    "    df_exp = run_simulation_from_dir(path_base + \"/\", Nsimulations=Nsimulations, sim_seed=sim_seed)\n",
    "    df_exp.to_pickle(pkl_path)\n",
    "\n",
    "# Post-process with add_columns, then order the rows deterministically.\n",
    "df_exp = df_exp.pipe(add_columns)\n",
    "df_exp = df_exp.sort_values(by=['agent', 'risk_measure', 'alpha', 'n_quantile', 'environment_name', 'agent_seed']).reset_index(drop=True)\n",
    "df_exp.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot the figures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython import display\n",
    "%matplotlib inline\n",
    "\n",
    "from matplotlib import ticker\n",
    "from matplotlib import rcParams\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Inline figure format, selectable through the form annotation on the next line.\n",
    "backend_format = \"retina\"  # @param [\"retina\", \"\"]\n",
    "%config InlineBackend.figure_format = backend_format\n",
    "\n",
    "# Seaborn theme first; the explicit rcParams below are applied on top of it.\n",
    "sns.set_context(\"notebook\")\n",
    "sns.set_style(\"ticks\")\n",
    "\n",
    "fig_size = (8, 5)\n",
    "\n",
    "rcParams.update({\n",
    "    # Axes: ticks on the right side as well, rounded limits, no extra margins.\n",
    "    'ytick.right': True,\n",
    "    'axes.autolimit_mode': 'round_numbers',\n",
    "    'axes.xmargin': 0,\n",
    "    'axes.ymargin': 0,\n",
    "    # Default figure size and resolution.\n",
    "    'figure.figsize': list(fig_size),\n",
    "    'figure.dpi': 150,\n",
    "    # Font type used for PDF/PS output (see the Matplotlib rcParams documentation).\n",
    "    'pdf.fonttype': 42,\n",
    "    'ps.fonttype': 42,\n",
    "})\n",
    "\n",
    "# Ten colours from the current seaborn palette.\n",
    "colors = sns.color_palette(n_colors=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys used to group the simulation results before aggregating.\n",
    "index_columns = [\n",
    "    \"agent\", \"risk_measure\", \"alpha\", \"weight\",\n",
    "    \"sim_seed\", \"Model\", \"agent_seed\", \"agent_id\",\n",
    "]\n",
    "# Column(s) whose values are aggregated within each group.\n",
    "columns = [\"rewards\"]\n",
    "\n",
    "def _lower_tail_mean(group, alpha):\n",
    "    \"\"\"Return the mean of the values at or below the empirical ``alpha``-quantile.\n",
    "\n",
    "    The comparison is inclusive (``<=``). With a strict ``<`` the selection is\n",
    "    empty whenever the smallest value is itself the quantile (constant data, or\n",
    "    many ties at the minimum), and ``np.mean`` of an empty array is ``nan``.\n",
    "\n",
    "    Args:\n",
    "        group: object exposing ``.values`` (e.g. a pandas Series of rewards).\n",
    "        alpha: quantile level in [0, 1] that bounds the lower tail.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the selected lower-tail values.\n",
    "    \"\"\"\n",
    "    r_values = group.values\n",
    "    threshold = np.quantile(r_values, alpha)\n",
    "    return np.mean(r_values[r_values <= threshold])\n",
    "\n",
    "\n",
    "def cvar1(group):\n",
    "    \"\"\"Lower-tail mean of ``group`` at level 0.1.\"\"\"\n",
    "    return _lower_tail_mean(group, 0.1)\n",
    "\n",
    "\n",
    "def cvar3(group):\n",
    "    \"\"\"Lower-tail mean of ``group`` at level 0.3.\"\"\"\n",
    "    return _lower_tail_mean(group, 0.3)\n",
    "\n",
    "\n",
    "def cvar5(group):\n",
    "    \"\"\"Lower-tail mean of ``group`` at level 0.5.\"\"\"\n",
    "    return _lower_tail_mean(group, 0.5)\n",
    "\n",
    "\n",
    "def cvar7(group):\n",
    "    \"\"\"Lower-tail mean of ``group`` at level 0.7.\"\"\"\n",
    "    return _lower_tail_mean(group, 0.7)\n",
    "\n",
    "def srm1(group):\n",
    "    \"\"\"Weighted average of the quantiles of ``group`` on a fine uniform grid.\n",
    "\n",
    "    The empirical quantiles are evaluated at the midpoints of a uniform grid on\n",
    "    [0, 1], multiplied by the values that ``weighted_sum_of_cvar`` returns for\n",
    "    those midpoints (alphas 0.1 and 1.0, weights 0.8 and 0.2), summed, and\n",
    "    divided by the number of midpoints.\n",
    "\n",
    "    Args:\n",
    "        group: object exposing ``.values`` (e.g. a pandas Series of rewards).\n",
    "\n",
    "    Returns:\n",
    "        The resulting scalar risk value.\n",
    "    \"\"\"\n",
    "    r_values = group.values\n",
    "    n_edges = 10001\n",
    "    taus = np.linspace(0.0, 1.0, n_edges)\n",
    "    taus_middle = (taus[:-1] + taus[1:]) / 2\n",
    "    # NOTE(review): assumes weighted_sum_of_cvar returns one weight per midpoint - confirm in qrsrm.py.\n",
    "    phi_values = weighted_sum_of_cvar(taus_middle, alphas=[0.1, 1.0], weights=[0.8, 0.2])\n",
    "    quantiles = np.quantile(r_values, taus_middle)\n",
    "    # Normalise by the number of terms in the sum (n_edges - 1 midpoints), not by\n",
    "    # the number of grid edges, which would scale the result by (n_edges - 1) / n_edges.\n",
    "    return np.matmul(quantiles, phi_values) / len(taus_middle)\n",
    "\n",
    "\n",
    "def mean_value(group):\n",
    "    \"\"\"Return ``group.mean()``, wrapped in a named function for use in ``agg``.\"\"\"\n",
    "    return group.mean()\n",
    "\n",
    "# (column label, aggregation function) pairs, in the order of the output columns.\n",
    "metric_functions = [\n",
    "    (r\"$\\mathbb{E}$\", mean_value),\n",
    "    (r\"CVaR$_{0.1}$\", cvar1),\n",
    "    (r\"CVaR$_{0.3}$\", cvar3),\n",
    "    (r\"CVaR$_{0.5}$\", cvar5),\n",
    "    (r\"CVaR$_{0.7}$\", cvar7),\n",
    "    (r\"SRM$_{0.1,1.0}$\", srm1),\n",
    "]\n",
    "\n",
    "df_grouped = (\n",
    "    df_exp.groupby(index_columns)[columns]\n",
    "    .agg(metric_functions)\n",
    "    # Drop the 'rewards' level so only the metric labels remain as column names.\n",
    "    .droplevel(0, axis=1)\n",
    "    # Drop the first six index levels; the last two grouping keys stay in the index ...\n",
    "    .droplevel([0, 1, 2, 3, 4, 5], axis=0)\n",
    "    # ... and become ordinary columns here.\n",
    "    .reset_index()\n",
    ")\n",
    "df_grouped.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-agent mean and standard deviation of every metric; the seed column is not a metric.\n",
    "per_agent = df_grouped.groupby(['agent_id'])\n",
    "mean_df = per_agent.mean().drop(columns=['agent_seed'])\n",
    "std_df = per_agent.std().drop(columns=['agent_seed'])\n",
    "\n",
    "# Start from a copy of the means and overwrite each column with \"mean±std\" strings.\n",
    "two_decimals = \"{:.2f}\".format\n",
    "result_df = mean_df.copy()\n",
    "for metric in result_df.columns:\n",
    "    formatted_mean = mean_df[metric].map(two_decimals)\n",
    "    formatted_std = std_df[metric].map(two_decimals)\n",
    "    result_df[metric] = formatted_mean + \"±\" + formatted_std\n",
    "\n",
    "result_df = result_df.reset_index()\n",
    "result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the column spec from the table: one left-aligned column for the first\n",
    "# column and one centred column for each remaining one, so the spec always\n",
    "# matches the number of columns written (index=False writes only the columns).\n",
    "n_value_columns = len(result_df.columns) - 1\n",
    "result_df.to_latex(path_base + \"/result_df.tex\", column_format=\"|l|\" + \"c\" * n_value_columns + \"|\", escape=False, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CleanRL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
