{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This file is licensed under the MIT License.\n",
    "# See the LICENSE file in the project root for full license information.\n",
    "import numpy as np\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.special import rel_entr\n",
    "from envs.forked_riverswim import ForkedRiverSwim\n",
    "from envs.riverswim import RiverSwim\n",
    "from typing import Dict, Tuple\n",
    "from itertools import product\n",
    "import lzma\n",
    "from utils.utils import  policy_iteration\n",
    "from scipy.stats import t\n",
    "from make_agent import AgentType\n",
    "from simulation_parameters import EnvType\n",
    "from run import SequencedResults\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "\n",
    "def TV(p,q):\n",
    "    return np.sum(np.abs(p-q), -1) * 0.5\n",
    "\n",
    "def CE(x, c=0.95):\n",
    "    N = x.shape[0]\n",
    "    alpha = c + (1-c)/2\n",
    "    c = t.ppf(alpha, N)\n",
    "    s = np.std(x, axis=0, ddof=1)\n",
    "    return x.mean(0), c * s/ np.sqrt(N)\n",
    "\n",
    "    \n",
    "compute_dist_omega = lambda x,y: TV(x,y)\n",
    "compute_dist_value = lambda V, mdp: np.linalg.norm(V -mdp.V_greedy[np.newaxis, np.newaxis], axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "import subprocess\n",
    "\n",
    "TITLE_SIZE = 22\n",
    "LEGEND_SIZE = 15\n",
    "TICK_SIZE = 14\n",
    "AXIS_TITLE = 22\n",
    "XAXIS_LABEL = 22\n",
    "AXIS_LABEL = 22\n",
    "FONT_SIZE = 14\n",
    "\n",
    "\n",
    "def _latex_has_package(pkg: str) -> bool:\n",
    "    kpse = shutil.which(\"kpsewhich\")\n",
    "    if not kpse:\n",
    "        return False\n",
    "    try:\n",
    "        result = subprocess.run([kpse, pkg], check=False, capture_output=True, text=True)\n",
    "        return result.stdout.strip() != \"\"\n",
    "    except Exception:\n",
    "        return False\n",
    "\n",
    "\n",
    "use_tex = (\n",
    "    shutil.which(\"latex\") is not None\n",
    "    and _latex_has_package(\"newtxtext.sty\")\n",
    "    and _latex_has_package(\"newtxmath.sty\")\n",
    ")\n",
    "\n",
    "rc_parameters = {\n",
    "    \"font.size\": FONT_SIZE,\n",
    "    \"axes.titlesize\": AXIS_TITLE,\n",
    "    \"axes.labelsize\": AXIS_LABEL,\n",
    "    \"xtick.labelsize\": TICK_SIZE,\n",
    "    \"ytick.labelsize\": TICK_SIZE,\n",
    "    \"legend.fontsize\": LEGEND_SIZE,\n",
    "    \"figure.titlesize\": TITLE_SIZE,\n",
    "    \"font.family\": \"serif\",  # use serif/main font for text elements\n",
    "    \"text.usetex\": use_tex,  # render via LaTeX when available\n",
    "    \"pdf.fonttype\": 42,\n",
    "    \"ps.fonttype\": 42,\n",
    "}\n",
    "if use_tex:\n",
    "    rc_parameters[\"text.latex.preamble\"] = r\"\"\"\n",
    "        \\usepackage{newtxtext}\n",
    "        \\usepackage{newtxmath}\n",
    "    \"\"\"\n",
    "\n",
    "plt.rcParams.update(rc_parameters)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lengths = [\n",
    "    (5, EnvType.RIVERSWIM),\n",
    "    (3, EnvType.FORKED_RIVERSWIM),\n",
    "    (10, EnvType.RIVERSWIM),\n",
    "    (5, EnvType.FORKED_RIVERSWIM),\n",
    "    (20, EnvType.RIVERSWIM),\n",
    "    (10, EnvType.FORKED_RIVERSWIM),\n",
    "    (30, EnvType.RIVERSWIM),\n",
    "    (15, EnvType.FORKED_RIVERSWIM),\n",
    "    (50, EnvType.RIVERSWIM),\n",
    "    (25, EnvType.FORKED_RIVERSWIM)\n",
    "]\n",
    "agents = [\n",
    "    AgentType.MDP_NAS, AgentType.PS_MDP_NAS, AgentType.O_BPI,\n",
    "    AgentType.Q_UCB, AgentType.PSRL, AgentType.BAYES_MFBPI,\n",
    "    AgentType.Q_LEARNING, AgentType.VarDE_Q_LEARNING,\n",
    "]\n",
    "\n",
    "data: Dict[Tuple[int, EnvType, AgentType],  SequencedResults] = {}\n",
    "for env_type, agent in product(lengths, agents):\n",
    "    length, env = env_type\n",
    "    data[(length, env, agent)] = None\n",
    "    \n",
    "    print(f\"> Opening {env.value}/{agent.value}/{length}\")\n",
    "    with lzma.open(f'data/{env.value}/{agent.value}_{length}.pkl.lzma', 'rb') as f:\n",
    "        res  = pickle.load(f)\n",
    "        data[(length, env, agent)] = res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_agents = {\n",
    "    AgentType.PSRL: 'PSRL',\n",
    "    AgentType.Q_UCB: 'Q-UCB',\n",
    "    AgentType.MDP_NAS: 'MDP-NaS',\n",
    "    AgentType.BAYES_MFBPI: 'MF-BPI',\n",
    "    AgentType.O_BPI: 'O-BPI',\n",
    "    AgentType.PS_MDP_NAS: 'PS-MDP-NaS',\n",
    "    AgentType.Q_LEARNING: 'Q-Learning',\n",
    "    AgentType.VarDE_Q_LEARNING: 'VarDE Q-Learning',\n",
    "}\n",
    "\n",
    "def find_first_below_threshold(matrix, threshold=0.05):\n",
    "    # Create a boolean mask of elements less than the threshold\n",
    "    mask = matrix < threshold\n",
    "    \n",
    "    # For each row, find the first occurrence of a value below the threshold and return the index\n",
    "    indices = np.array([np.argmax(row) if np.any(row) else len(row) for row in mask])\n",
    "    \n",
    "    return indices\n",
    "\n",
    "df = pd.DataFrame({\n",
    "    'N': [],\n",
    "    'Environment': [],\n",
    "    'Agent': [],\n",
    "    't_hit': []\n",
    "})\n",
    "for length, env_type in lengths:\n",
    "    if env_type == EnvType.RIVERSWIM:\n",
    "        env = RiverSwim(length)\n",
    "    else:\n",
    "        env = ForkedRiverSwim(length)\n",
    "    Vgreedy, _, _ =policy_iteration(0.99, env.transitions, env.rewards)\n",
    "    for agent in agents:\n",
    "        x = data[(length, env_type, agent)].dist_value_infinity / np.max(Vgreedy)\n",
    "        idxs =  1 - x[:,-1] #find_first_below_threshold(x)* data[(length, env_type, agent)].simulation_parameters.frequency_evaluation\n",
    "        for i in idxs:\n",
    "            _length = length * 2 -1 if env_type == EnvType.FORKED_RIVERSWIM else length\n",
    "            df.loc[len(df)] = [_length, env_type.value, labels_agents[agent], i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.patches import Patch\n",
    "colors = [\"#6BAED6\",\"#1F77B4\",\"#0B3C5D\",\"#8BC34A\",\"#2CA02C\",\"#0B6623\",\"#FF8C00\",\"#D62728\"]\n",
    "filtered_df = df\n",
    "unique_agents = filtered_df['Agent'].unique()\n",
    "agents_colors = ['MDP-NaS', 'PS-MDP-NaS', 'O-BPI', 'Q-UCB', 'PSRL', 'MF-BPI', 'Q-Learning', 'VarDE Q-Learning']\n",
    "\n",
    "color_dict = {agent: color for agent, color in zip(agents_colors, colors)}\n",
    "\n",
    "fig, axes = plt.subplots(1,2, sharey=True, figsize=(15, 3),)\n",
    "for ax_id, environment in enumerate([EnvType.RIVERSWIM.value, EnvType.FORKED_RIVERSWIM.value]):\n",
    "    env_data = filtered_df[filtered_df['Environment'] == environment]\n",
    "    sns.barplot(x=\"N\", y=\"t_hit\", hue=\"Agent\", data=env_data, palette=color_dict, ax=axes[ax_id], hue_order=[labels_agents[x]for x in agents], legend=False)\n",
    "    axes[ax_id].set_xlabel('$|S|$')#if environment == EnvType.RIVERSWIM.value else '$N$')\n",
    "    axes[ax_id].set_title(r'\\rm{' + environment + '}')\n",
    "\n",
    "axes[1].set_ylabel('')\n",
    "axes[0].set_ylabel(r\"$1-\\frac{\\|V^* - V^{\\hat\\pi_T}\\|_\\infty}{\\|V^*\\|_\\infty}$\")\n",
    "plt.subplots_adjust(wspace=0.05)\n",
    "legend_labels = [labels_agents[x] for x in agents]\n",
    "handles = [Patch(facecolor=color_dict[label], label=label) for label in legend_labels]\n",
    "labels = legend_labels\n",
    "\n",
    "# Manually reorder the handles and labels\n",
    "# ordered_handles = [handles[1], handles[2], handles[5],handles[0], handles[3], handles[4], handles[6]]  # Change the order as desired\n",
    "# ordered_labels = [labels[1], labels[2], labels[5], labels[0], labels[3], labels[4], labels[6]]  # Change the order as desired\n",
    "\n",
    "fig.legend(handles, labels, loc='center right', bbox_to_anchor=(1.08, 0.5), title='Agent')\n",
    "\n",
    "plt.savefig('figures/bpi.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(5, 2 ,figsize=(13,25))\n",
    "\n",
    "linestyle = {\n",
    "    AgentType.Q_UCB: \"solid\", \n",
    "    AgentType.Q_LEARNING: \"solid\", \n",
    "    AgentType.BAYES_MFBPI: \"solid\",\n",
    "    AgentType.FORCED_MFBPI: \"solid\",\n",
    "    AgentType.PSRL: \"solid\",\n",
    "    AgentType.MDP_NAS: \"solid\",\n",
    "    AgentType.O_BPI: 'solid',\n",
    "    AgentType.PS_MDP_NAS: 'solid',\n",
    "    AgentType.VarDE_Q_LEARNING: 'solid',\n",
    "    AgentType.VarDEMFBPI: 'solid',\n",
    "}\n",
    "\n",
    "id_l = 0\n",
    "for length, env_type in lengths:\n",
    "    if env_type == EnvType.RIVERSWIM:\n",
    "        env = RiverSwim(length)\n",
    "    else:\n",
    "        env = ForkedRiverSwim(length)\n",
    "    Vgreedy, _, _ =policy_iteration(0.99, env.transitions, env.rewards)\n",
    "    \n",
    "    id_e = 0 if env_type == EnvType.RIVERSWIM else 1\n",
    "\n",
    "    if env_type == EnvType.RIVERSWIM:\n",
    "        id_l = 0 if length == 5 else 1 if length == 10 else 2 if length == 20 else 3 if length == 30 else 4\n",
    "    else:\n",
    "        id_l = 0  if length == 3 else 1 if length == 5 else 2 if length == 10 else 3 if length == 15 else 4\n",
    "    for agent_id,  agent in enumerate(agents):\n",
    "        parameters=  data[(length, env_type, agent)].simulation_parameters\n",
    "        x = range(0, parameters.horizon + 1, parameters.frequency_evaluation)\n",
    "\n",
    "        df = pd.DataFrame(data[(length, env_type, agent)].dist_value / np.linalg.norm(Vgreedy))\n",
    "        df = df.stack().reset_index()\n",
    "        df.columns = ['Run', 't', 'Value']\n",
    "        alpha = 0.05  # Adjust this for the desired confidence level (e.g., 0.05 for 95% confidence intervals)\n",
    "        grouped_data = df.groupby('t')['Value'].agg(['mean', 'sem'])\n",
    "        grouped_data['lower'] = grouped_data['mean'] - t.ppf(1 - alpha/2, len(data)-1) * grouped_data['sem']\n",
    "        grouped_data['upper'] = grouped_data['mean'] + t.ppf(1 - alpha/2, len(data)-1) * grouped_data['sem']\n",
    "        grouped_data.reset_index(inplace=True)\n",
    "        ax[id_l, id_e].plot(grouped_data['t'] * parameters.frequency_evaluation, grouped_data['mean'],label=labels_agents[agent], linestyle=linestyle[agent], color=colors[agent_id])\n",
    "        ax[id_l, id_e].fill_between(grouped_data['t'] * parameters.frequency_evaluation, grouped_data['lower'], grouped_data['upper'], alpha=0.2,  color=colors[agent_id])\n",
    "\n",
    "        #mu, ce = CE(data[(length, env_type, agent)].dist_value / np.linalg.norm(Vgreedy))\n",
    "        #ax[id_l, id_e].plot(x, mu, label=labels[agent], linestyle=linestyle[agent], color=colors[agent_id])\n",
    "        #ax[id_l, id_e].fill_between(x, mu-ce, mu+ce, alpha=0.2)\n",
    "        #sns.lineplot(x='t', y='Value', data=df, errorbar=('ci', 95))\n",
    "    #ax[id_l, id_e].grid()\n",
    "    #if id_l == 0:\n",
    "    ax[id_l, id_e].set_title(r'\\textrm{' + f'{env_type.value}({length})' + '}')\n",
    "\n",
    "    ax[id_l, id_e].set_ylabel(r'$\\frac{\\|V^* - V^{\\hat\\pi_t}\\|_2}{\\|V^*\\|_2}$')\n",
    "\n",
    "    if id_l == 4:\n",
    "        ax[id_l, id_e].set_xlabel('$t$')\n",
    "\n",
    "# plt.subplots_adjust(wspace=0.05)\n",
    "handles, labels = ax[0,0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, .922), frameon=True, ncols = 4)\n",
    "\n",
    "plt.savefig('figures/full.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dude",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
