{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory of this experiment: passed to the training script as --dir, then reused for the\n",
    "# simulation cache (df_exp.pkl) and for the exported figures.\n",
    "path_base = 'runs/spectrums'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import sys\n",
    "\n",
    "# Launch training with the kernel's own interpreter; a bare \"python\" may resolve to a different environment on PATH.\n",
    "# Every run shares these flags; the risk-measure-specific flags are appended below.\n",
    "base_script = [sys.executable, \"qrsrm.py\", \"--env-id\", \"TradingEnv-v0\", \"--save-model\", \"--n-quantiles\", \"200\", \"--total-timesteps\", \"1000000\", \"--gamma\", \"0.99\", \"--dir\", path_base]\n",
    "\n",
    "# One entry per spectral risk measure: exponential (ERM), dual power (DPRM) and weighted sum of CVaRs (WSCVaR).\n",
    "risk_measure_args = [\n",
    "    [\"--risk-measure\", \"Exp\", \"--alpha\", \"12\"],\n",
    "    [\"--risk-measure\", \"Dual\", \"--alpha\", \"4\"],\n",
    "    [\"--risk-measure\", \"WSCVaR\", \"--alphas\", \"0.1,0.6,1.0\", \"--weights\", \"0.2,0.3,0.5\"],\n",
    "]\n",
    "\n",
    "seeds = [\"1\"]\n",
    "# Every risk measure is trained once per seed (risk measure is the outer loop, seed the inner one).\n",
    "scripts = [base_script + args + [\"--seed\", seed] for args in risk_measure_args for seed in seeds]\n",
    "\n",
    "# Run sequentially; a failed run is reported with its exit code and the remaining runs still execute.\n",
    "for script in scripts:\n",
    "    print(f\"Running script {script}...\")\n",
    "    try:\n",
    "        subprocess.run(script, check=True)\n",
    "    except subprocess.CalledProcessError as err:\n",
    "        print(f\"Script {script} failed with exit code {err.returncode}!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the simulations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Allow more than one OpenMP runtime in the process. It is set before the numeric libraries are imported so it is\n",
    "# already in place when they initialise OpenMP. NOTE(review): presumably a workaround for the\n",
    "# \"OMP: Error #15 ... already initialized\" crash - confirm it is still needed.\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "\n",
    "from utils import run_simulation_from_dir, add_columns, make_agent_hue_kws\n",
    "from utils import AGENT_NAME_MAP, AGG_RISK_VALUES\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import warnings\n",
    "# Silence library warnings to keep the output readable.\n",
    "# NOTE(review): ignoring RuntimeWarning also hides numeric issues such as NaN/empty-slice statistics.\n",
    "warnings.filterwarnings('ignore', category=UserWarning)\n",
    "warnings.filterwarnings('ignore', category=DeprecationWarning)\n",
    "warnings.filterwarnings('ignore', category=RuntimeWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this binds `display` to the IPython.display *module*, shadowing the notebook's display() function.\n",
    "from IPython import display\n",
    "%matplotlib inline\n",
    "\n",
    "from matplotlib import rcParams, ticker\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Colab-style form parameter: \"retina\" renders inline figures at high resolution.\n",
    "backend_format = \"retina\"  # @param [\"retina\", \"\"]\n",
    "%config InlineBackend.figure_format = backend_format\n",
    "\n",
    "# Seaborn theme first: set_style rewrites rcParams, so the overrides below must come after it.\n",
    "sns.set_context(\"notebook\")\n",
    "sns.set_style(\"ticks\")\n",
    "\n",
    "rcParams.update({\n",
    "    'ytick.right': True,                      # ticks on the right-hand y axis as well\n",
    "    'axes.autolimit_mode': 'round_numbers',   # snap autoscaled limits to tick values\n",
    "    'axes.xmargin': 0,                        # no padding around the data\n",
    "    'axes.ymargin': 0,\n",
    "    'figure.figsize': [8, 5],\n",
    "    'figure.dpi': 150,\n",
    "    'pdf.fonttype': 42,                       # embed TrueType (Type 42) fonts in PDF/EPS exports\n",
    "    'ps.fonttype': 42,\n",
    "})\n",
    "\n",
    "# Shared colour palette and default figure size for the plots below.\n",
    "colors = sns.color_palette(n_colors=10)\n",
    "fig_size = (8, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Nsimulations = 10000\n",
    "sim_seed = 6\n",
    "pkl_path = path_base+\"/df_exp.pkl\"\n",
    "\n",
    "# Simulating is expensive, so results are cached as a pickle and reused on later runs.\n",
    "# NOTE: the cache is not keyed on Nsimulations/sim_seed or on the trained models - delete the file to re-simulate.\n",
    "# Test for the cache file directly (os.listdir(path_base) would raise if the directory does not exist).\n",
    "if not os.path.exists(pkl_path):\n",
    "    df_exp = run_simulation_from_dir(path_base + \"/\", Nsimulations=Nsimulations, sim_seed=sim_seed)\n",
    "    df_exp.to_pickle(pkl_path)\n",
    "\n",
    "# Always load from the cache so fresh and cached runs follow the same path.\n",
    "# Unpickling can execute code: only load a df_exp.pkl this notebook wrote itself.\n",
    "df_exp = (\n",
    "    pd.read_pickle(pkl_path)\n",
    "    .pipe(add_columns)\n",
    "    .sort_values(by=['agent', 'risk_measure', 'alpha', 'n_quantile', 'environment_name', 'agent_seed'])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "print(df_exp.shape)\n",
    "df_exp.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the reward samples of every (agent, risk_measure, alpha, sim_seed, Model, agent_seed) group with AGG_RISK_VALUES.\n",
    "index_columns = ['agent', 'risk_measure', 'alpha', 'sim_seed', 'Model', 'agent_seed']\n",
    "columns = ['rewards']\n",
    "\n",
    "df_grouped = (\n",
    "    df_exp.groupby(index_columns)[columns]\n",
    "    .agg(AGG_RISK_VALUES)\n",
    "    .droplevel(0, axis=1)      # drop the 'rewards' level, keeping only the aggregation names as columns\n",
    "    .droplevel([0, 1, 2, 3])   # keep only Model and agent_seed as row identifiers\n",
    "    .reset_index()\n",
    ")\n",
    "df_grouped.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise each model across its agent seeds: mean and standard deviation of every aggregated value.\n",
    "# agent_seed is an identifier, not a metric: drop it before aggregating so mean()/std() never see it\n",
    "# (a non-numeric seed column makes mean() raise on pandas >= 2).\n",
    "by_model = df_grouped.drop(columns=['agent_seed']).groupby(['Model'])\n",
    "mean_df = by_model.mean()\n",
    "std_df = by_model.std()  # NaN for a model with a single seed (seeds = [\"1\"] above), rendered as \"nan\"\n",
    "\n",
    "# Table cells read \"mean±std\" with two decimals; mean_df itself stays numeric for the histogram markers.\n",
    "result_df = mean_df.copy()\n",
    "for col in result_df.columns:\n",
    "    result_df[col] = mean_df[col].map(\"{:.2f}\".format) + \"±\" + std_df[col].map(\"{:.2f}\".format)\n",
    "\n",
    "result_df = result_df.reset_index()\n",
    "result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import exponential_risk_measures, dual_power, weighted_sum_of_cvar, CVaR\n",
    "fig, ax = plt.subplots(figsize=fig_size)\n",
    "\n",
    "# Risk spectra phi(alpha) of the three trained risk measures, evaluated on a 1001-point grid over [0, 1].\n",
    "alpha_grid = np.linspace(0.0, 1.0, 1001)\n",
    "spectra = [\n",
    "    (r\"$\\operatorname{ERM}$\", lambda a: exponential_risk_measures(a, alpha=12)),\n",
    "    (r\"$\\operatorname{DPRM}$\", lambda a: dual_power(a, alpha=4)),\n",
    "    (r\"$\\operatorname{WSCVaR}$\", lambda a: weighted_sum_of_cvar(a, alphas=[0.1, 0.6, 1.0], weights=[0.2, 0.3, 0.5])),\n",
    "]\n",
    "for idx, (label, spectrum) in enumerate(spectra):\n",
    "    ax.plot(alpha_grid, spectrum(alpha_grid), label=label, color=colors[idx])\n",
    "\n",
    "ax.set_title('')\n",
    "# Draw the full frame around the axes.\n",
    "for side in ('right', 'top', 'left', 'bottom'):\n",
    "    ax.spines[side].set_visible(True)\n",
    "ax.set_xlabel(r\"$\\alpha$\")\n",
    "ax.set_ylabel(r\"$\\phi(\\alpha)$\")\n",
    "\n",
    "legend = ax.legend(loc='upper right', title='', frameon=False)\n",
    "plt.setp(legend.get_lines(), linewidth=3, alpha=0.5)  # thicker, semi-transparent legend handles\n",
    "\n",
    "fig.set_facecolor('white')\n",
    "fig.tight_layout()\n",
    "\n",
    "# NOTE(review): \"specrums\" looks like a typo for \"spectrums\"; kept unchanged in case these paths are referenced elsewhere.\n",
    "fig.savefig(path_base + '/risk_specrums.pdf', transparent=True)\n",
    "fig.savefig(path_base + '/risk_specrums.eps', format='eps', dpi=1200)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (agent_id, legend label) per trained agent; the ids mirror the --alpha/--alphas values used for training.\n",
    "# Position in this list also selects the agent's colour from the shared palette.\n",
    "agent_specs = [\n",
    "    (\"qrsrm_12.0\", r'QR-SRM($\\lambda$=12.0)'),\n",
    "    (\"qrsrm_4.0\", r'QR-SRM($\\nu$=4.0)'),\n",
    "    (\"qrsrm_0.1,0.6,1.0\", r'QR-SRM($\\alpha$=0.1,0.6,1.0)'),\n",
    "]\n",
    "experiments_ordered_by_performance = [\n",
    "    dict(agent_id=agent_id, agent_name=agent_name, color=colors[position])\n",
    "    for position, (agent_id, agent_name) in enumerate(agent_specs)\n",
    "]\n",
    "\n",
    "agent_names, hue_kws = make_agent_hue_kws(experiments_ordered_by_performance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=fig_size)\n",
    "\n",
    "# Histogram colours come from hue_kws, the same source as the vertical markers below, so bars and markers\n",
    "# stay matched if the agent list changes. (Assumes hue_kws['color'] is aligned with agent_names, as the\n",
    "# marker loop below already relies on.)\n",
    "agent_palette = list(hue_kws['color'])\n",
    "\n",
    "sns.histplot(data=df_exp, \n",
    "             x=\"rewards\", \n",
    "             hue=\"Model\", \n",
    "             stat=\"density\", \n",
    "             common_norm=False,   # normalise each model's histogram separately\n",
    "             alpha=0.5,\n",
    "             bins=200, \n",
    "             kde=True,\n",
    "             hue_order=agent_names,\n",
    "             palette=agent_palette,\n",
    "             legend='brief',\n",
    "             ax=ax,\n",
    "    )\n",
    "\n",
    "# One vertical marker per (model, risk measure) at the model's mean value from mean_df; the line style\n",
    "# identifies the risk measure. Raw strings: \"\\o\" is not a valid Python escape sequence.\n",
    "risk_line_styles = {\n",
    "    r'$\\operatorname{WSCVaR}_{0.1,0.6,1.0}$': 'solid',\n",
    "    r'$\\operatorname{ERM}_{12.0}$': 'dashed',\n",
    "    r'$\\operatorname{DPRM}_{4.0}$': 'dotted',\n",
    "}\n",
    "model_risk_means = mean_df.loc[mean_df.index.isin(df_exp['Model']), list(risk_line_styles)]\n",
    "\n",
    "for model, risk_values in model_risk_means.iterrows():\n",
    "    model_color = hue_kws['color'][agent_names.index(model)]\n",
    "    for risk_name, risk_value in risk_values.items():\n",
    "        ax.vlines(x=risk_value, \n",
    "                  ymin=0, \n",
    "                  ymax=1.2,  # data units; NOTE(review): markers are cut short if a density exceeds 1.2\n",
    "                  colors=model_color, \n",
    "                  linestyles=risk_line_styles[risk_name],\n",
    "                  alpha=1,\n",
    "                  linewidth=1,\n",
    "                  )\n",
    "\n",
    "ax.set_title('')\n",
    "for side in ('right', 'top', 'left', 'bottom'):\n",
    "    ax.spines[side].set_visible(True)\n",
    "ax.set_xlabel('Discounted Future Rewards')\n",
    "ax.set_ylabel('Density')\n",
    "ax.set_xlim([-3, 4])\n",
    "\n",
    "# Restyle the legend seaborn created for the hue levels.\n",
    "legend = ax.get_legend()\n",
    "legend.set_title('')\n",
    "legend.set_frame_on(False)\n",
    "plt.setp(legend.get_lines(), linewidth=3, alpha=0.5)\n",
    "\n",
    "fig.set_facecolor('white')\n",
    "fig.tight_layout()\n",
    "\n",
    "fig.savefig(path_base + '/comparison_srm_final.pdf', transparent=True)\n",
    "fig.savefig(path_base + '/comparison_srm_final.eps', format='eps', dpi=1200)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CleanRL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
