{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1753b0d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import warnings\n",
    "\n",
    "from src.rl_agent import RLAgent\n",
    "from src.rl_experiments import RLExperiments\n",
    "from src.state_representation import StateRepresentation\n",
    "\n",
    "from src.environments.env_continual_redpillbluepill import EnvironmentContinualRedPillBluePill\n",
    "\n",
    "# Install a blanket 'ignore' filter: every warning category is silenced for this session.\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "# Device string handed to RLAgent in the experiment cells.\n",
    "# Swap in 'cuda' or 'mps' here to use a different backend.\n",
    "pytorch_device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e773e23",
   "metadata": {},
   "source": [
    "## tau-RPBP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "680d8232",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# --- environment (render_mode=None) ---\n",
    "env = EnvironmentContinualRedPillBluePill(render_mode=None)\n",
    "\n",
    "# --- agent: action/state lists are taken from the environment's dictionaries ---\n",
    "actions = [*env.action_dict.keys()]\n",
    "states = [*env.state_dict.values()]\n",
    "policy = None\n",
    "\n",
    "# Tabular differential Q-learning agent with CVaR enabled (var_quantile=0.9).\n",
    "agent = RLAgent(\n",
    "    agent_type='q_learning',\n",
    "    states=states,\n",
    "    actions=actions,\n",
    "    policy=policy,\n",
    "    avg_reward_method='differential',\n",
    "    initial_avg_reward=0.0,\n",
    "    action_type='discrete',\n",
    "    action_selection_rule='epsilon_greedy',\n",
    "    policy_type='tabular',\n",
    "    value_type='tabular',\n",
    "    pytorch_device=pytorch_device,\n",
    "    use_cvar=True,\n",
    "    var_quantile=0.9,\n",
    "    initial_var_reward=0.0,\n",
    ")\n",
    "\n",
    "# --- experiment: one step size per entry ('value', 'avg_reward', 'var') ---\n",
    "rl_experiments = RLExperiments()\n",
    "step_sizes = dict(value=0.02, avg_reward=0.1, var=0.1)\n",
    "\n",
    "# tau_change_step / tau_change_value are forwarded unchanged to the runner;\n",
    "# NOTE(review): presumably tau switches to 0.1 at step 50,000 - confirm in RLExperiments.\n",
    "df_tau_rpbp = rl_experiments.run_experiment_continuing(\n",
    "    experiment='tau_rpbp',\n",
    "    agent=agent,\n",
    "    env=env,\n",
    "    num_runs=50,\n",
    "    max_steps=110_000,\n",
    "    discount=1.0,\n",
    "    epsilon=0.1,\n",
    "    step_size=step_sizes,\n",
    "    tau_change_step=50_000,\n",
    "    tau_change_value=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e85879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# tau-RPBP figure: one labelled series built from the tau-RPBP experiment results.\n",
    "df_dict = {'RED CVaR': dict(df=df_tau_rpbp, color_percent='#FEB780')}\n",
    "\n",
    "# n_runs and tau_change_step repeat the values passed to run_experiment_continuing\n",
    "# in the tau-RPBP experiment cell (50 runs, change at step 50,000).\n",
    "rl_experiments = RLExperiments()\n",
    "rl_experiments.get_tau_results_figure(\n",
    "    experiment='tau_rpbp',\n",
    "    df_dict=df_dict,\n",
    "    n_runs=50,\n",
    "    tau_change_step=50_000,\n",
    "    rolling_average_amount=500,\n",
    "    x_max=100_000,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ed8e527",
   "metadata": {},
   "source": [
    "## s-RPBP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fad4db5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- environment (render_mode=None, change_dist=True) ---\n",
    "env = EnvironmentContinualRedPillBluePill(render_mode=None, change_dist=True)\n",
    "\n",
    "# --- agent: action/state lists are taken from the environment's dictionaries ---\n",
    "actions = [*env.action_dict.keys()]\n",
    "states = [*env.state_dict.values()]\n",
    "policy = None\n",
    "\n",
    "# Tabular differential Q-learning agent with CVaR enabled (var_quantile=0.25).\n",
    "agent = RLAgent(\n",
    "    agent_type='q_learning',\n",
    "    states=states,\n",
    "    actions=actions,\n",
    "    policy=policy,\n",
    "    avg_reward_method='differential',\n",
    "    initial_avg_reward=0.0,\n",
    "    action_type='discrete',\n",
    "    action_selection_rule='epsilon_greedy',\n",
    "    policy_type='tabular',\n",
    "    value_type='tabular',\n",
    "    pytorch_device=pytorch_device,\n",
    "    use_cvar=True,\n",
    "    var_quantile=0.25,\n",
    "    initial_var_reward=0.0,\n",
    ")\n",
    "\n",
    "# --- experiment: one step size per entry ('value', 'avg_reward', 'var') ---\n",
    "rl_experiments = RLExperiments()\n",
    "step_sizes = dict(value=0.02, avg_reward=0.1, var=0.1)\n",
    "\n",
    "df_s_rpbp = rl_experiments.run_experiment_continuing(\n",
    "    experiment='s_rpbp',\n",
    "    agent=agent,\n",
    "    env=env,\n",
    "    num_runs=10,\n",
    "    max_steps=120_000,\n",
    "    discount=1.0,\n",
    "    epsilon=0.1,\n",
    "    step_size=step_sizes,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f82aef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# s-RPBP figure: one labelled series built from the s-RPBP experiment results.\n",
    "df_dict = {'RED CVaR': dict(df=df_s_rpbp, color_cvar='#FEB780')}\n",
    "\n",
    "# quantile and epsilon repeat the settings used in the s-RPBP experiment cell\n",
    "# (var_quantile=0.25, epsilon=0.1); env is the environment defined in that cell.\n",
    "rl_experiments = RLExperiments()\n",
    "rl_experiments.get_s_results_figure(\n",
    "    experiment='s_rpbp',\n",
    "    df_dict=df_dict,\n",
    "    rolling_average_amount=1000,\n",
    "    x_max=119_900,\n",
    "    quantile=0.25,\n",
    "    epsilon=0.1,\n",
    "    env=env,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a942dc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
