{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "from stable_baselines3 import PPO,DQN, A2C\n",
    "from gymnasium.wrappers.time_limit import TimeLimit\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifGym(gym.Env):\n",
    "    \"\"\"One-step classification task exposed as a Gymnasium environment.\n",
    "\n",
    "    Every distinct row of ``x`` is a state. The agent sees a state, answers\n",
    "    with a class index and receives ``rmax`` if the answer equals the label\n",
    "    stored for that row, ``rmin`` otherwise. Each step terminates the episode.\n",
    "\n",
    "    :param x: (numpy array) 2-D array, one sample per row\n",
    "    :param y: (numpy array) labels aligned with the rows of ``x``; they are\n",
    "        compared directly with actions, so they are assumed to be the integers\n",
    "        0..k-1 (TODO confirm when plugging in a new dataset)\n",
    "    :param rmax: reward given for a correct answer\n",
    "    :param rmin: reward given for a wrong answer\n",
    "    :param p: number of interior points of [0, 1] stored in ``all_p``\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, y, rmax=0, rmin=-1, p=1):\n",
    "        super(ClassifGym, self).__init__()\n",
    "        x = np.asarray(x)\n",
    "        y = np.asarray(y)\n",
    "        self.all_p = np.linspace(0, 1, p + 2)[1:-1]\n",
    "        self.nb_features = x.shape[1]\n",
    "        # Duplicate rows collapse into a single state; the label kept for it is\n",
    "        # the one found at the index np.unique reports for that row.\n",
    "        unique_x, idx_unique_x = np.unique(x, axis=0, return_index=True)\n",
    "        # Store the states in the dtype declared by observation_space so the\n",
    "        # observations handed out really belong to that space.\n",
    "        self.states = unique_x.astype(np.float32)\n",
    "        self.nb_states = self.states.shape[0]\n",
    "        self.x_mins = self.states.min(axis=0)\n",
    "        self.x_maxs = self.states.max(axis=0)\n",
    "        self.classes = y[idx_unique_x]\n",
    "        self.action_space = gym.spaces.Discrete(len(np.unique(self.classes)))\n",
    "        self.observation_space = gym.spaces.Box(\n",
    "            low=self.x_mins, high=self.x_maxs, dtype=np.float32\n",
    "        )\n",
    "        self.rmax = rmax\n",
    "        self.rmin = rmin\n",
    "        # Both stay None until reset() selects the first sample.\n",
    "        self.state = None\n",
    "        self.state_idx = None\n",
    "\n",
    "    def _draw_state(self):\n",
    "        \"\"\"Pick a state uniformly at random with the environment's own RNG.\"\"\"\n",
    "        self.state_idx = int(self.np_random.integers(self.nb_states))\n",
    "        # Copy so that callers cannot modify the stored table of states.\n",
    "        self.state = self.states[self.state_idx].copy()\n",
    "\n",
    "    def reset(self, *, seed=None, options=None, **kwargs):\n",
    "        \"\"\"Start a new episode on a uniformly drawn sample.\n",
    "\n",
    "        :param seed: (int or None) when given, re-seeds the environment RNG so\n",
    "            that the sequence of drawn samples is reproducible\n",
    "        :param options: unused, accepted for Gymnasium API compatibility\n",
    "        :return: (observation, info) with an empty info dict\n",
    "        \"\"\"\n",
    "        # Let gym.Env (re)seed self.np_random; the seed used to be dropped.\n",
    "        super().reset(seed=seed)\n",
    "        self._draw_state()\n",
    "        return self.state, {}\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Score ``action`` against the label of the current sample.\n",
    "\n",
    "        :param action: (int) predicted class index\n",
    "        :return: (observation, reward, terminated, truncated, info); the\n",
    "            observation is a freshly drawn sample and terminated is always True\n",
    "        :raises RuntimeError: if called before reset()\n",
    "        \"\"\"\n",
    "        if self.state_idx is None:\n",
    "            raise RuntimeError(\"reset() must be called before step()\")\n",
    "        if action == self.classes[self.state_idx]:\n",
    "            r = self.rmax\n",
    "        else:\n",
    "            r = self.rmin\n",
    "        self._draw_state()\n",
    "        return self.state, r, True, False, {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def cartesian_prod(x, y):\n",
    "    return np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])\n",
    "\n",
    "## Main Class: builds a zeta-obs depth IBMDP associated to a given MDP ##\n",
    "class IBMDPGym(gym.Env):\n",
    "    \"\"\"Information-gathering wrapper around a base MDP, limited to one test.\n",
    "\n",
    "    The observation is the concatenation of a lower and an upper bound for\n",
    "    every base feature. The first ``base_mdp.action_space.n`` actions are\n",
    "    forwarded to the base MDP and end the episode. Every other action is an\n",
    "    information-gathering action (feature index, split fraction): it compares\n",
    "    the hidden base state with a threshold, tightens one bound and yields the\n",
    "    reward ``zeta``. Only one such action is allowed per episode; a second one\n",
    "    ends the episode with reward -1/2.\n",
    "\n",
    "    :param sup_dt_mdp: base environment; must have a Discrete action space\n",
    "    :param p: number of split fractions per feature (interior points of [0, 1])\n",
    "    :param zeta: reward of an information-gathering action\n",
    "    :param gamma: stored on the instance, not used by this class\n",
    "    :param read_obs_path: accepted for interface compatibility, unused\n",
    "    :param max_tree_depth: stored on the instance; step() does not consult it\n",
    "    :param compute_obs: accepted for interface compatibility, unused\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        sup_dt_mdp,\n",
    "        p=1,\n",
    "        zeta=-0.2,\n",
    "        gamma=1,\n",
    "        read_obs_path=None,\n",
    "        max_tree_depth=None,\n",
    "        compute_obs=True,\n",
    "    ):\n",
    "        super(IBMDPGym, self).__init__()\n",
    "        self.gamma = gamma\n",
    "        self.p = p\n",
    "        self.max_tree_depth = max_tree_depth\n",
    "\n",
    "        self.base_mdp = sup_dt_mdp\n",
    "        assert isinstance(\n",
    "            self.base_mdp.action_space, gym.spaces.Discrete\n",
    "        ), \"NEEDS MDP WITH DISCRETE ACTION SPACE\"\n",
    "        self.lower_bounds = self.base_mdp.observation_space.low\n",
    "        self.upper_bounds = self.base_mdp.observation_space.high\n",
    "        self.base_states = self.base_mdp.observation_space\n",
    "        self.nb_base_features = self.base_mdp.observation_space.shape[0]\n",
    "\n",
    "        # Observation layout: [lower bound per feature, upper bound per feature].\n",
    "        self.initial_observation = np.append(self.lower_bounds, self.upper_bounds)\n",
    "        self.observation_space = gym.spaces.Box(\n",
    "            np.hstack((self.lower_bounds, self.lower_bounds)),\n",
    "            np.hstack((self.upper_bounds, self.upper_bounds)),\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "\n",
    "        # One information-gathering action per (feature index, split fraction).\n",
    "        self.all_p = np.linspace(0, 1, self.p + 2)[1:-1]\n",
    "        self.info_actions = cartesian_prod(\n",
    "            list(range(self.nb_base_features)), self.all_p\n",
    "        )\n",
    "        self.nb_info_actions = len(self.info_actions)\n",
    "        self.action_space = gym.spaces.Discrete(\n",
    "            self.nb_info_actions + self.base_mdp.action_space.n\n",
    "        )\n",
    "\n",
    "        self.zeta = zeta\n",
    "        self.depth = 0\n",
    "        # Both are populated by reset().\n",
    "        self.state = None\n",
    "        self.base_s = None\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Apply a base action or an information-gathering action.\n",
    "\n",
    "        :param action: (int) index below ``base_mdp.action_space.n`` for a base\n",
    "            action, otherwise an information-gathering action\n",
    "        :return: (observation, reward, terminated, truncated, info); the\n",
    "            observation is always a copy of the internal bounds\n",
    "        \"\"\"\n",
    "        nb_base_actions = self.base_mdp.action_space.n\n",
    "        if action < nb_base_actions:\n",
    "            _, r, _, _, _ = self.base_mdp.step(action)\n",
    "            self.state = self.initial_observation.copy()\n",
    "            return self.state.copy(), r, True, False, {}\n",
    "\n",
    "        if self.depth == 1:\n",
    "            # The single allowed test was already used: end with a penalty.\n",
    "            self.state = self.initial_observation.copy()\n",
    "            return self.state.copy(), -1/2, True, False, {}\n",
    "\n",
    "        self.depth += 1\n",
    "        info = action - nb_base_actions\n",
    "        feat = int(self.info_actions[info][0])\n",
    "        # Work on a copy: the array returned by the previous reset()/step() may\n",
    "        # still be held by the agent and must not change under it.\n",
    "        bounds = self.state.copy()\n",
    "        f_lower = bounds[feat]\n",
    "        f_upper = bounds[self.nb_base_features + feat]\n",
    "        value_p = (\n",
    "            self.info_actions[info][1] * (f_upper - f_lower) + f_lower\n",
    "        )  # unnormalized value\n",
    "        # deterministic transition\n",
    "        if self.base_s[feat] <= value_p:\n",
    "            bounds[self.nb_base_features + feat] = min(f_upper, value_p)\n",
    "        else:\n",
    "            bounds[feat] = max(f_lower, value_p)\n",
    "        self.state = bounds\n",
    "        return self.state.copy(), self.zeta, False, False, {}\n",
    "\n",
    "    def reset(self, *, seed=None, options=None, **kwargs):\n",
    "        \"\"\"Reset the base MDP and restore the widest bounds.\n",
    "\n",
    "        :param seed: (int or None) forwarded to the base MDP's reset()\n",
    "        :param options: unused, accepted for Gymnasium API compatibility\n",
    "        :return: (observation, info) with an empty info dict\n",
    "        \"\"\"\n",
    "        super().reset(seed=seed)\n",
    "        self.depth = 0\n",
    "        self.base_s, _ = self.base_mdp.reset(seed=seed)\n",
    "        self.state = self.initial_observation.copy()\n",
    "        return self.state.copy(), {}\n",
    "    \n",
    "class IBMDPGymD2(gym.Env):\n",
    "    \"\"\"Information-gathering wrapper around a base MDP, limited to two tests.\n",
    "\n",
    "    The observation is the concatenation of a lower and an upper bound for\n",
    "    every base feature. The first ``base_mdp.action_space.n`` actions are\n",
    "    forwarded to the base MDP and end the episode. Every other action is an\n",
    "    information-gathering action (feature index, split fraction): it compares\n",
    "    the hidden base state with a threshold, tightens one bound and yields the\n",
    "    reward ``zeta``. Two such actions are allowed per episode; a third one\n",
    "    ends the episode with reward -1/2.\n",
    "\n",
    "    :param sup_dt_mdp: base environment; must have a Discrete action space\n",
    "    :param p: number of split fractions per feature (interior points of [0, 1])\n",
    "    :param zeta: reward of an information-gathering action\n",
    "    :param gamma: stored on the instance, not used by this class\n",
    "    :param read_obs_path: accepted for interface compatibility, unused\n",
    "    :param max_tree_depth: stored on the instance; step() does not consult it\n",
    "    :param compute_obs: accepted for interface compatibility, unused\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        sup_dt_mdp,\n",
    "        p=1,\n",
    "        zeta=-0.2,\n",
    "        gamma=1,\n",
    "        read_obs_path=None,\n",
    "        max_tree_depth=None,\n",
    "        compute_obs=True,\n",
    "    ):\n",
    "        super(IBMDPGymD2, self).__init__()\n",
    "        self.gamma = gamma\n",
    "        self.p = p\n",
    "        self.max_tree_depth = max_tree_depth\n",
    "\n",
    "        self.base_mdp = sup_dt_mdp\n",
    "        assert isinstance(\n",
    "            self.base_mdp.action_space, gym.spaces.Discrete\n",
    "        ), \"NEEDS MDP WITH DISCRETE ACTION SPACE\"\n",
    "        self.lower_bounds = self.base_mdp.observation_space.low\n",
    "        self.upper_bounds = self.base_mdp.observation_space.high\n",
    "        self.base_states = self.base_mdp.observation_space\n",
    "        self.nb_base_features = self.base_mdp.observation_space.shape[0]\n",
    "\n",
    "        # Observation layout: [lower bound per feature, upper bound per feature].\n",
    "        self.initial_observation = np.append(self.lower_bounds, self.upper_bounds)\n",
    "        self.observation_space = gym.spaces.Box(\n",
    "            np.hstack((self.lower_bounds, self.lower_bounds)),\n",
    "            np.hstack((self.upper_bounds, self.upper_bounds)),\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "\n",
    "        # One information-gathering action per (feature index, split fraction).\n",
    "        self.all_p = np.linspace(0, 1, self.p + 2)[1:-1]\n",
    "        self.info_actions = cartesian_prod(\n",
    "            list(range(self.nb_base_features)), self.all_p\n",
    "        )\n",
    "        self.nb_info_actions = len(self.info_actions)\n",
    "        self.action_space = gym.spaces.Discrete(\n",
    "            self.nb_info_actions + self.base_mdp.action_space.n\n",
    "        )\n",
    "\n",
    "        self.zeta = zeta\n",
    "        self.depth = 0\n",
    "        # Both are populated by reset().\n",
    "        self.state = None\n",
    "        self.base_s = None\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Apply a base action or an information-gathering action.\n",
    "\n",
    "        :param action: (int) index below ``base_mdp.action_space.n`` for a base\n",
    "            action, otherwise an information-gathering action\n",
    "        :return: (observation, reward, terminated, truncated, info); the\n",
    "            observation is always a copy of the internal bounds\n",
    "        \"\"\"\n",
    "        nb_base_actions = self.base_mdp.action_space.n\n",
    "        if action < nb_base_actions:\n",
    "            _, r, _, _, _ = self.base_mdp.step(action)\n",
    "            self.state = self.initial_observation.copy()\n",
    "            return self.state.copy(), r, True, False, {}\n",
    "\n",
    "        if self.depth == 2:\n",
    "            # Both allowed tests were already used: end with a penalty.\n",
    "            self.state = self.initial_observation.copy()\n",
    "            return self.state.copy(), -1/2, True, False, {}\n",
    "\n",
    "        self.depth += 1\n",
    "        info = action - nb_base_actions\n",
    "        feat = int(self.info_actions[info][0])\n",
    "        # Only copy the bounds on the branch that actually edits them.\n",
    "        modif = self.state.copy()\n",
    "        f_lower = modif[feat]\n",
    "        f_upper = modif[self.nb_base_features + feat]\n",
    "        value_p = (\n",
    "            self.info_actions[info][1] * (f_upper - f_lower) + f_lower\n",
    "        )  # unnormalized value\n",
    "        # deterministic transition\n",
    "        if self.base_s[feat] <= value_p:\n",
    "            modif[self.nb_base_features + feat] = min(f_upper, value_p)\n",
    "        else:\n",
    "            modif[feat] = max(f_lower, value_p)\n",
    "        self.state = modif\n",
    "        return self.state.copy(), self.zeta, False, False, {}\n",
    "\n",
    "    def reset(self, *, seed=None, options=None, **kwargs):\n",
    "        \"\"\"Reset the base MDP and restore the widest bounds.\n",
    "\n",
    "        :param seed: (int or None) forwarded to the base MDP's reset(); it used\n",
    "            to be swallowed by ``**kwargs``\n",
    "        :param options: unused, accepted for Gymnasium API compatibility\n",
    "        :return: (observation, info) with an empty info dict\n",
    "        \"\"\"\n",
    "        super().reset(seed=seed)\n",
    "        self.depth = 0\n",
    "        self.base_s, _ = self.base_mdp.reset(seed=seed)\n",
    "        self.state = self.initial_observation.copy()\n",
    "        return self.state.copy(), {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.results_plotter import load_results, ts2xy\n",
    "import os\n",
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.patches import Patch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def moving_average(values, window):\n",
    "    \"\"\"\n",
    "    Smooth values by doing a moving average\n",
    "    :param values: (numpy array)\n",
    "    :param window: (int)\n",
    "    :return: (numpy array)\n",
    "    \"\"\"\n",
    "    weights = np.repeat(1.0, window) / window\n",
    "    return np.convolve(values, weights, \"valid\")\n",
    "\n",
    "\n",
    "def plot_results(log_folder, label=\"\", c=\"orange\", alpha=0.6, window=5000):\n",
    "    \"\"\"\n",
    "    Load the monitor logs of one run and return its smoothed learning curve.\n",
    "\n",
    "    Despite its name this function draws nothing; the caller plots the\n",
    "    returned arrays.\n",
    "\n",
    "    :param log_folder: (str) the save location of the results to load\n",
    "    :param label: unused, kept for backward compatibility\n",
    "    :param c: unused, kept for backward compatibility\n",
    "    :param alpha: unused, kept for backward compatibility\n",
    "    :param window: (int) number of episodes averaged by the moving average\n",
    "    :return: (x, y) where x is the walltime in seconds and y the smoothed\n",
    "        episode rewards, both of the same length\n",
    "    \"\"\"\n",
    "    x, y = ts2xy(load_results(log_folder), x_axis=\"walltime_hrs\")\n",
    "    y = moving_average(y, window=window)\n",
    "    # Smoothing shortens y; keep the trailing x values so both stay aligned.\n",
    "    x = x[len(x) - len(y) :]\n",
    "    # ts2xy reports walltime in hours; convert to seconds.\n",
    "    x = x * 3600\n",
    "\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy dataset: four 2-feature samples, each with its own class.\n",
    "X = np.array([[1,2], [2,1], [3,4], [4,3]])\n",
    "y = np.array([0,1,2,3])\n",
    "\n",
    "# One entry per algorithm: (log-dir template, algorithm class, extra kwargs).\n",
    "# The templates are the ones the plotting cell reads back.\n",
    "experiments = [\n",
    "    (\"./runs/depth2_{}_run{}_md/\", PPO, {}),\n",
    "    (\n",
    "        \"./runs/dqn_depth2_{}_run{}_md/\",\n",
    "        DQN,\n",
    "        dict(\n",
    "            learning_starts=10_000,\n",
    "            buffer_size=5000,\n",
    "            target_update_interval=1000,\n",
    "            exploration_final_eps=0,\n",
    "        ),\n",
    "    ),\n",
    "]\n",
    "\n",
    "for z in [0]:\n",
    "    for log_template, algo_cls, algo_kwargs in experiments:\n",
    "        for r in range(5):\n",
    "            log_dir = log_template.format(z, r)\n",
    "            os.makedirs(log_dir, exist_ok=True)\n",
    "\n",
    "            env = ClassifGym(X, y)\n",
    "            env = IBMDPGymD2(env, zeta=z)\n",
    "            env = Monitor(env, log_dir)\n",
    "\n",
    "            # Seeding with the run index makes every run reproducible while\n",
    "            # keeping the five runs of an algorithm distinct.\n",
    "            model = algo_cls(\"MlpPolicy\", env, seed=r, **algo_kwargs)\n",
    "            model.learn(500_000)\n",
    "            # Close the Monitor so its log file is released before the next run.\n",
    "            env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the smoothed learning curve of every run on the current axes:\n",
    "# DQN runs first (orange), then PPO runs (purple), then the reference line.\n",
    "ax = plt.gca()\n",
    "curve_groups = [\n",
    "    (\"./runs/dqn_depth2_{}_run{}_md/\", \"orange\"),\n",
    "    (\"./runs/depth2_{}_run{}_md/\", \"purple\"),\n",
    "]\n",
    "for i, z in enumerate([0]):\n",
    "    for template, color in curve_groups:\n",
    "        for r in range(5):\n",
    "            log_dir = template.format(z, r)\n",
    "            x, y = plot_results(log_dir, c=color)\n",
    "            ax.plot(x, y, alpha=0.7, c=color, linewidth=3)\n",
    "    # Horizontal reference at 2*z, labelled \"Opt. Depth 2 Tree\" in the legend.\n",
    "    ax.hlines(2*z, 0, 600, color=\"black\", linewidth=3)\n",
    "ax.set_xlabel(\"walltime in sec.\")\n",
    "ax.set_ylabel(\"cumulative MDP rewards\")\n",
    "\n",
    "# Proxy artists: one legend entry per algorithm rather than one per run.\n",
    "legend_elements = [\n",
    "    Line2D([0], [0], color='orange', label='DQN', linewidth=3),\n",
    "    Line2D([0], [0], color='purple', label='PPO', linewidth=3),\n",
    "    Line2D([0], [0], color=\"black\", label=\"Opt. Depth 2 Tree\", linewidth=3),\n",
    "]\n",
    "ax.legend(\n",
    "    handles=legend_elements,\n",
    "    labels=[\"DQN\", \"PPO\", \"Opt. Depth 2 Tree\"],\n",
    "    loc='best',\n",
    "    prop={\"size\": 14},\n",
    ")\n",
    "ax.set_xlim(1, 600)\n",
    "\n",
    "# Write the figure to disk.\n",
    "ax.figure.savefig(\"depth2RL.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
