{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "from stable_baselines3 import DQN\n",
    "import matplotlib.pyplot as plt\n",
    "from data import *\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.results_plotter import load_results, ts2xy\n",
    "import os\n",
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.patches import Patch\n",
    "from time import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifGym(gym.Env):\n",
    "    def __init__(self, x, y, rmax=0, rmin=-1, p=1):\n",
    "        super(ClassifGym, self).__init__()\n",
    "        self.all_p = np.linspace(0, 1, p + 2)[1:-1]\n",
    "        self.nb_features = x.shape[1]\n",
    "        unique_x, idx_unique_x = np.unique(x, axis=0, return_index=True)\n",
    "        self.states = unique_x\n",
    "        self.nb_states = self.states.shape[0]\n",
    "        self.x_mins = self.states.min(axis=0)\n",
    "        self.x_maxs = self.states.max(axis=0)\n",
    "        self.classes = y[idx_unique_x]\n",
    "        self.action_space = gym.spaces.Discrete(len(np.unique(self.classes)))\n",
    "        self.observation_space = gym.spaces.Box(\n",
    "            low=self.x_mins, high=self.x_maxs, dtype=np.float32\n",
    "        )\n",
    "        self.rmax = rmax\n",
    "        self.rmin = rmin\n",
    "        self.state = None\n",
    "        # self.states=np.array([list(s) for s in self.states])\n",
    "        # print(self.states)\n",
    "\n",
    "    def reset(\n",
    "        self,\n",
    "        **kwargs\n",
    "    ):\n",
    "        self.state_idx = np.random.choice(self.nb_states)\n",
    "        self.state = self.states[self.state_idx]\n",
    "        return self.state, {}\n",
    "\n",
    "    def step(self, action):\n",
    "        if action == self.classes[self.state_idx]:\n",
    "            r = self.rmax\n",
    "        else:\n",
    "            r = self.rmin\n",
    "        next_state_idx = np.random.choice(self.nb_states)\n",
    "        next_state = self.states[next_state_idx]\n",
    "        self.state_idx, self.state = next_state_idx, next_state\n",
    "        return self.state, r, True, False, {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def cartesian_prod(x, y):\n",
    "    return np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])\n",
    "\n",
    "\n",
    "class IBMDPGym(gym.Env):\n",
    "    def __init__(\n",
    "        self,\n",
    "        sup_dt_mdp,\n",
    "        p=4,\n",
    "        zeta=0,\n",
    "        gamma=1,\n",
    "        read_obs_path=None,\n",
    "        max_tree_depth=None,\n",
    "        compute_obs=True,\n",
    "    ):\n",
    "        super(IBMDPGym, self).__init__()\n",
    "        # self.obs_depth = obs_depth\n",
    "        self.gamma = gamma\n",
    "        self.p = p\n",
    "        self.max_tree_depth = max_tree_depth\n",
    "\n",
    "        self.base_mdp = sup_dt_mdp\n",
    "        assert isinstance(\n",
    "            self.base_mdp.action_space, gym.spaces.Discrete\n",
    "        ), \"NEEDS MDP WITH DISCRETE ACTION SPACE\"\n",
    "        self.lower_bounds = self.base_mdp.observation_space.low\n",
    "        self.upper_bounds = self.base_mdp.observation_space.high\n",
    "        self.base_states = self.base_mdp.observation_space\n",
    "        self.nb_base_features = self.base_mdp.observation_space.shape[0]\n",
    "\n",
    "        self.initial_observation = np.append(self.lower_bounds, self.upper_bounds)\n",
    "        # self.observation_space = gym.spaces.Box(\n",
    "        #     np.hstack((self.lower_bounds, self.lower_bounds, self.lower_bounds)),\n",
    "        #     np.hstack((self.upper_bounds, self.upper_bounds, self.upper_bounds)),\n",
    "        #     dtype=np.float32,\n",
    "        # )\n",
    "        self.observation_space = gym.spaces.Box(\n",
    "            np.hstack((self.lower_bounds, self.lower_bounds)),\n",
    "            np.hstack((self.upper_bounds, self.upper_bounds)),\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "\n",
    "        self.all_p = np.linspace(0, 1, self.p + 2)[1:-1]\n",
    "        self.info_actions = cartesian_prod(\n",
    "            list(range(self.nb_base_features)), self.all_p\n",
    "        )\n",
    "        self.nb_info_actions = len(self.info_actions)\n",
    "        self.action_space = gym.spaces.Discrete(\n",
    "            self.nb_info_actions + self.base_mdp.action_space.n\n",
    "        )\n",
    "\n",
    "        self.zeta = zeta\n",
    "        # self.final_s = [-1] * self.observation_space.shape[0]\n",
    "        self.depth = 0\n",
    "\n",
    "    def step(self, action):\n",
    "        if action < self.base_mdp.action_space.n:\n",
    "            _, r, _, _, _ = self.base_mdp.step(action)\n",
    "            self.state = self.initial_observation.copy()\n",
    "            return self.state.copy(), r, True, False, {}\n",
    "        else:\n",
    "            modif = self.state.copy()\n",
    "            if self.depth == 3:\n",
    "                # self.base_s, r, term, trunc, infos = self.base_mdp.step(action)\n",
    "                self.state = self.initial_observation.copy()\n",
    "                return self.state.copy(), -1, True, False, {}\n",
    "            else:\n",
    "                self.depth +=1\n",
    "                info = action - self.base_mdp.action_space.n\n",
    "                feat = int(self.info_actions[info][0])\n",
    "                f_lower = modif[feat]\n",
    "                f_upper = modif[self.nb_base_features + feat]\n",
    "                value_p = (\n",
    "                    self.info_actions[info][1] * (f_upper - f_lower) + f_lower\n",
    "                )  # unnormalized value\n",
    "                # deterministic transition\n",
    "                if self.base_s[feat] <= value_p:\n",
    "                    modif[self.nb_base_features + feat] = min(f_upper, value_p)\n",
    "                else:\n",
    "                    modif[feat] = max(f_lower, value_p)\n",
    "                self.state = modif\n",
    "                return self.state.copy(), self.zeta, False, False, {}\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        self.depth = 0\n",
    "        self.base_s, _ = self.base_mdp.reset()\n",
    "        self.state = self.initial_observation.copy()\n",
    "        # print(self.state)\n",
    "        return self.state.copy(), {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "functions_ = [\n",
    "    get_avila_data,\n",
    "    get_bank_data,\n",
    "    get_bean_data,\n",
    "    get_bidding_data,\n",
    "    get_eeg_data,\n",
    "    get_fault_data,\n",
    "    get_htru_data,\n",
    "    get_magic_data,\n",
    "    get_occupancy_data,\n",
    "    get_page_data,\n",
    "    get_raisin_data,\n",
    "    get_rice_data,\n",
    "    get_room_data,\n",
    "    get_segment_data,\n",
    "    get_skin_data,\n",
    "    get_wilt_data,\n",
    "]\n",
    "names_ = [\n",
    "    \"avila\",\n",
    "    \"bank\",\n",
    "    \"bean\",\n",
    "    \"bidding\",\n",
    "    \"eeg\",\n",
    "    \"fault\",\n",
    "    \"htru\",\n",
    "    \"magic\",\n",
    "    \"occupancy\",\n",
    "    \"page\",\n",
    "    \"raisin\",\n",
    "    \"rice\",\n",
    "    \"room\",\n",
    "    \"segment\",\n",
    "    \"skin\",\n",
    "    \"wilt\",\n",
    "]\n",
    "\n",
    "for f, dataset in enumerate(names_):\n",
    "    ##################################################\n",
    "    X, y = functions_[f]()\n",
    "    avg_perf = []\n",
    "    avg_perf_ = []\n",
    "\n",
    "    for run in range(5):\n",
    "        agent = DQN.load(\"dqn_{}_run{}.zip\".format(dataset,run))\n",
    "        env = ClassifGym(X,y)\n",
    "        env = IBMDPGym(env)\n",
    "        tot = 0\n",
    "        for x in X:\n",
    "            s, _  = env.reset()\n",
    "            env.base_s = x.copy()\n",
    "            rs = 0\n",
    "            done = False\n",
    "            while not done:\n",
    "                s, r, term, trunc, infos = env.step(int(agent.predict(s, deterministic=True)[0]))\n",
    "                rs+=r\n",
    "                done = term or trunc\n",
    "            tot += rs+1\n",
    "        avg_perf.append(tot/len(X))\n",
    "\n",
    "\n",
    "        agent = DQN.load(\"dqnp3_{}_run{}.zip\".format(dataset,run))\n",
    "        env = ClassifGym(X,y, p=2)\n",
    "        env = IBMDPGym(env)\n",
    "        tot_ = 0\n",
    "        for x in X:\n",
    "            s, _  = env.reset()\n",
    "            env.base_s = x.copy()\n",
    "            rs = 0\n",
    "            done = False\n",
    "            while not done:\n",
    "                s, r, term, trunc, infos = env.step(int(agent.predict(s, deterministic=True)[0]))\n",
    "                rs+=r\n",
    "                done = term or trunc\n",
    "            tot_ += rs+1\n",
    "        avg_perf_.append(tot_/len(X))\n",
    "\n",
    "    print(dataset, \"&${} \\pm {}$&${} \\pm {}$\".format(np.round(np.mean(avg_perf) * 100, 1), np.round(np.std(avg_perf)*100, 1), np.round(np.mean(avg_perf_) * 100, 1), np.round(np.std(avg_perf_)*100, 1)))\n",
    "    # print(dataset, np.round(np.mean(avg_perf) * 100, 1), np.round(np.std(avg_perf)*100, 1))\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def moving_average(values, window):\n",
    "    \"\"\"\n",
    "    Smooth values by doing a moving average\n",
    "    :param values: (numpy array)\n",
    "    :param window: (int)\n",
    "    :return: (numpy array)\n",
    "    \"\"\"\n",
    "    weights = np.repeat(1.0, window) / window\n",
    "    return np.convolve(values, weights, \"valid\")\n",
    "\n",
    "\n",
    "def plot_results(log_folder, label=\"\", c = \"orange\", alpha=0.6):\n",
    "    \"\"\"\n",
    "    plot the results\n",
    "\n",
    "    :param log_folder: (str) the save location of the results to plot\n",
    "    :param title: (str) the title of the task to plot\n",
    "    \"\"\"\n",
    "    x, y = ts2xy(load_results(log_folder), x_axis=\"walltime_hrs\")\n",
    "    y = moving_average(y, window=5000)\n",
    "    # Truncate x\n",
    "    x = x[len(x) - len(y) :]\n",
    "    x = x * 3600\n",
    "\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = [\"forestgreen\", \"limegreen\", \"slategrey\", \"burlywood\", \"orange\", \"gold\", \"brown\", \"red\", \"sienna\", \"tomato\", \"fuchsia\", \"green\", \"blue\", \"cyan\", \"magenta\", \"purple\"]\n",
    "legend_elements = []\n",
    "for f, dataset in enumerate(names_):\n",
    "    legend_elements.append(Line2D([0], [0], color=colors[f], label='DQN-{}'.format(dataset) ,linewidth=2))\n",
    "    for r in range(5):\n",
    "        log_dir = \"./dqn_{}_run{}_md/\".format(dataset, r)\n",
    "        x,y = plot_results(log_dir)\n",
    "        plt.plot(x,y, alpha=0.7, c = colors[f])\n",
    "# for r in range(10):\n",
    "#     log_dir = \"./A2C_depth1_{}_run{}_md/\".format(z,r)\n",
    "#     x, y = plot_results(log_dir, c=\"red\", alpha=0.7)\n",
    "#     lines.append(axs[i].plot(x,y, alpha=0.7, c = \"red\", label=\"A2C\")[0])\n",
    "# lines.append(axs[i].axhline(2 * z,  color=\"black\", label=\"Infinite Depth Tree\", linestyle=\"dotted\"))\n",
    "# lines.append(axs[i].axhline(z - 1/4, color=\"blue\", label=\"Best Depth 1 Tree\"))\n",
    "# lines.append(axs[i].axhline(z - 1/4, color=\"blue\", label=\"Depth 1 Tree\"))\n",
    "# lines.append(axs[i].axhline(-1/2, color=\"black\", label=\"Depth 0 Tree\", linestyle=\"dashed\"))\n",
    "# axs[i].set_title(r\"$\\alpha = {}$\".format(z))\n",
    "plt.xlabel(\"walltime in sec.\")\n",
    "plt.ylabel(\"DQN-5 cumulative MDP rewards\")\n",
    "plt.xlim(0,1000)\n",
    "    \n",
    "# Create a legend below the subplots\n",
    "\n",
    "                #    Line2D([0], [0], color='red', label='A2C'),\n",
    "                #    Line2D([0], [0], color=\"black\", label=\"Infinite Depth Tree\", linestyle=\"dotted\"),\n",
    "                #    Line2D([0], [0], color=\"blue\", label=\"Sub-Optimal Depth 1 Tree\"),\n",
    "                #    Line2D([0], [0], color=\"black\", label=\"Depth 0 Tree\", linestyle=\"dashed\")]\n",
    "plt.legend(handles = legend_elements, labels = [\"{}\".format(n) for n in names_], loc='best', prop={\"size\": 10})\n",
    "# Adjust spacing between subplots and legend\n",
    "# plt.tight_layout(pad=2.0)\n",
    "\n",
    "# Show the plot\n",
    "plt.savefig(\"dqn.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = [\"forestgreen\", \"limegreen\", \"slategrey\", \"burlywood\", \"orange\", \"gold\", \"brown\", \"red\", \"sienna\", \"tomato\", \"fuchsia\", \"green\", \"blue\", \"cyan\", \"magenta\", \"purple\"]\n",
    "legend_elements = []\n",
    "for f, dataset in enumerate(names_):\n",
    "    legend_elements.append(Line2D([0], [0], color=colors[f], label='DQN-{}'.format(dataset) ,linewidth=2))\n",
    "    for r in range(5):\n",
    "        log_dir = \"./dqnp3_{}_run{}_md/\".format(dataset, r)\n",
    "        x,y = plot_results(log_dir)\n",
    "        plt.plot(x,y, alpha=0.7, c = colors[f])\n",
    "# for r in range(10):\n",
    "#     log_dir = \"./A2C_depth1_{}_run{}_md/\".format(z,r)\n",
    "#     x, y = plot_results(log_dir, c=\"red\", alpha=0.7)\n",
    "#     lines.append(axs[i].plot(x,y, alpha=0.7, c = \"red\", label=\"A2C\")[0])\n",
    "# lines.append(axs[i].axhline(2 * z,  color=\"black\", label=\"Infinite Depth Tree\", linestyle=\"dotted\"))\n",
    "# lines.append(axs[i].axhline(z - 1/4, color=\"blue\", label=\"Best Depth 1 Tree\"))\n",
    "# lines.append(axs[i].axhline(z - 1/4, color=\"blue\", label=\"Depth 1 Tree\"))\n",
    "# lines.append(axs[i].axhline(-1/2, color=\"black\", label=\"Depth 0 Tree\", linestyle=\"dashed\"))\n",
    "# axs[i].set_title(r\"$\\alpha = {}$\".format(z))\n",
    "plt.xlabel(\"walltime in sec.\")\n",
    "plt.ylabel(\"cumulative MDP rewards\")\n",
    "plt.xlim(0,1000)\n",
    "    \n",
    "# Create a legend below the subplots\n",
    "\n",
    "                #    Line2D([0], [0], color='red', label='A2C'),\n",
    "                #    Line2D([0], [0], color=\"black\", label=\"Infinite Depth Tree\", linestyle=\"dotted\"),\n",
    "                #    Line2D([0], [0], color=\"blue\", label=\"Sub-Optimal Depth 1 Tree\"),\n",
    "                #    Line2D([0], [0], color=\"black\", label=\"Depth 0 Tree\", linestyle=\"dashed\")]\n",
    "plt.legend(handles = legend_elements, labels = [\"DQN-3-{}\".format(n) for n in names_], loc='best', prop={\"size\": 8})\n",
    "# Adjust spacing between subplots and legend\n",
    "# plt.tight_layout(pad=2.0)\n",
    "\n",
    "# Show the plot\n",
    "plt.savefig(\"dqn3.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
