{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "from stable_baselines3 import DQN\n",
    "import matplotlib.pyplot as plt\n",
    "from data import *\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.results_plotter import load_results, ts2xy\n",
    "import os\n",
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.patches import Patch\n",
    "from time import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifGym(gym.Env):\n",
    "    \"\"\"One-step classification MDP built from a supervised dataset (x, y).\n",
    "\n",
    "    Each episode shows one deduplicated row of ``x`` as the observation; the\n",
    "    action is a predicted class index. The reward is ``rmax`` when the\n",
    "    prediction matches the row's label and ``rmin`` otherwise, and every\n",
    "    step terminates the episode. Sampling uses NumPy's global RNG.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, y, rmax=0, rmin=-1, p=1):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: 2-D feature matrix, one row per sample.\n",
    "            y: labels aligned with the rows of ``x``.\n",
    "            rmax: reward for a correct prediction.\n",
    "            rmin: reward for a wrong prediction.\n",
    "            p: number of evenly spaced interior fractions stored in\n",
    "                ``self.all_p`` (not used by this class itself).\n",
    "        \"\"\"\n",
    "        super(ClassifGym, self).__init__()\n",
    "        # Plain arrays make ``y[idx_unique_x]`` positional even for pandas\n",
    "        # inputs carrying a non-default index.\n",
    "        x = np.asarray(x)\n",
    "        y = np.asarray(y)\n",
    "        self.all_p = np.linspace(0, 1, p + 2)[1:-1]\n",
    "        self.nb_features = x.shape[1]\n",
    "        # Duplicate rows collapse to one state that keeps the label of the\n",
    "        # first occurrence.\n",
    "        unique_x, idx_unique_x = np.unique(x, axis=0, return_index=True)\n",
    "        self.states = unique_x\n",
    "        self.nb_states = self.states.shape[0]\n",
    "        self.x_mins = self.states.min(axis=0)\n",
    "        self.x_maxs = self.states.max(axis=0)\n",
    "        # Re-encode labels as 0..K-1 so they coincide with the Discrete(K)\n",
    "        # action indices even when raw labels are strings, 1-based or\n",
    "        # non-contiguous. ``class_labels[k]`` maps action k back to its label.\n",
    "        self.class_labels, encoded = np.unique(\n",
    "            y[idx_unique_x], return_inverse=True\n",
    "        )\n",
    "        self.classes = encoded.reshape(-1)\n",
    "        self.action_space = gym.spaces.Discrete(len(self.class_labels))\n",
    "        self.observation_space = gym.spaces.Box(\n",
    "            low=self.x_mins, high=self.x_maxs, dtype=np.float32\n",
    "        )\n",
    "        self.rmax = rmax\n",
    "        self.rmin = rmin\n",
    "        self.state = None\n",
    "\n",
    "    def reset(\n",
    "        self,\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"\"\"Draw a uniformly random state; keyword arguments (e.g. ``seed``) are ignored.\"\"\"\n",
    "        self.state_idx = np.random.choice(self.nb_states)\n",
    "        self.state = self.states[self.state_idx]\n",
    "        return self.state, {}\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Score ``action`` against the current state's label, then draw a new random state.\n",
    "\n",
    "        Returns ``(next_state, reward, True, False, {})``: every step terminates.\n",
    "        \"\"\"\n",
    "        if action == self.classes[self.state_idx]:\n",
    "            r = self.rmax\n",
    "        else:\n",
    "            r = self.rmin\n",
    "        next_state_idx = np.random.choice(self.nb_states)\n",
    "        self.state_idx, self.state = next_state_idx, self.states[next_state_idx]\n",
    "        return self.state, r, True, False, {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def cartesian_prod(x, y):\n",
    "    \"\"\"Return every (x_i, y_j) pair as a row of a 2-column array, with x varying fastest (for 1-D x, y).\"\"\"\n",
    "    x_col = np.tile(x, len(y))\n",
    "    y_col = np.repeat(y, len(x))\n",
    "    return np.stack((x_col, y_col), axis=1)\n",
    "\n",
    "\n",
    "class IBMDPGym(gym.Env):\n",
    "    \"\"\"Iterative bounding MDP wrapper around a discrete-action base MDP.\n",
    "\n",
    "    The observation is the current box of feature bounds, laid out as\n",
    "    ``[lower_0, ..., lower_{d-1}, upper_0, ..., upper_{d-1}]``. Actions:\n",
    "\n",
    "    * ``0 .. n-1`` (``n`` = base action count): forwarded to the base MDP;\n",
    "      its reward is returned and the episode terminates.\n",
    "    * ``n .. n + nb_info_actions - 1``: information-gathering split of one\n",
    "      feature's current interval at a fixed fraction, keeping the side that\n",
    "      contains the hidden base state; reward ``zeta``.\n",
    "\n",
    "    Once ``max_tree_depth`` splits have been taken in an episode, a further\n",
    "    split terminates the episode with reward -1.\n",
    "    \"\"\"\n",
    "\n",
    "    # Split budget used when ``max_tree_depth`` is None (the historical value).\n",
    "    DEFAULT_MAX_TREE_DEPTH = 3\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        sup_dt_mdp,\n",
    "        p=4,\n",
    "        zeta=0,\n",
    "        gamma=1,\n",
    "        read_obs_path=None,\n",
    "        max_tree_depth=None,\n",
    "        compute_obs=True,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            sup_dt_mdp: base environment with a ``Discrete`` action space and\n",
    "                an observation space exposing ``low``, ``high`` and ``shape``.\n",
    "            p: number of split fractions per feature, evenly spaced in (0, 1).\n",
    "            zeta: reward returned for every information-gathering action.\n",
    "            gamma: stored on the instance; not used by this class.\n",
    "            read_obs_path: accepted for compatibility; not used by this class.\n",
    "            max_tree_depth: number of splits allowed per episode before the\n",
    "                next one is penalized; None means ``DEFAULT_MAX_TREE_DEPTH``.\n",
    "            compute_obs: accepted for compatibility; not used by this class.\n",
    "        \"\"\"\n",
    "        super(IBMDPGym, self).__init__()\n",
    "        self.gamma = gamma\n",
    "        self.p = p\n",
    "        self.max_tree_depth = max_tree_depth\n",
    "\n",
    "        self.base_mdp = sup_dt_mdp\n",
    "        assert isinstance(\n",
    "            self.base_mdp.action_space, gym.spaces.Discrete\n",
    "        ), \"NEEDS MDP WITH DISCRETE ACTION SPACE\"\n",
    "        self.lower_bounds = self.base_mdp.observation_space.low\n",
    "        self.upper_bounds = self.base_mdp.observation_space.high\n",
    "        self.base_states = self.base_mdp.observation_space\n",
    "        self.nb_base_features = self.base_mdp.observation_space.shape[0]\n",
    "\n",
    "        # Root of the tree: every feature spans its full [low, high] range.\n",
    "        self.initial_observation = np.append(self.lower_bounds, self.upper_bounds)\n",
    "        self.observation_space = gym.spaces.Box(\n",
    "            np.hstack((self.lower_bounds, self.lower_bounds)),\n",
    "            np.hstack((self.upper_bounds, self.upper_bounds)),\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "\n",
    "        self.all_p = np.linspace(0, 1, self.p + 2)[1:-1]\n",
    "        # Row k of info_actions is (feature index, split fraction).\n",
    "        self.info_actions = cartesian_prod(\n",
    "            list(range(self.nb_base_features)), self.all_p\n",
    "        )\n",
    "        self.nb_info_actions = len(self.info_actions)\n",
    "        self.action_space = gym.spaces.Discrete(\n",
    "            self.nb_info_actions + self.base_mdp.action_space.n\n",
    "        )\n",
    "\n",
    "        self.zeta = zeta\n",
    "        self.depth = 0\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Apply ``action``; return ``(obs, reward, terminated, truncated, info)``.\"\"\"\n",
    "        nb_base_actions = self.base_mdp.action_space.n\n",
    "        if action < nb_base_actions:\n",
    "            # Base action: delegate to the base MDP and end the episode.\n",
    "            _, r, _, _, _ = self.base_mdp.step(action)\n",
    "            self.state = self.initial_observation.copy()\n",
    "            return self.state.copy(), r, True, False, {}\n",
    "\n",
    "        depth_limit = (\n",
    "            self.DEFAULT_MAX_TREE_DEPTH\n",
    "            if self.max_tree_depth is None\n",
    "            else self.max_tree_depth\n",
    "        )\n",
    "        if self.depth >= depth_limit:\n",
    "            # Split budget exhausted: terminate with a penalty.\n",
    "            self.state = self.initial_observation.copy()\n",
    "            return self.state.copy(), -1, True, False, {}\n",
    "\n",
    "        self.depth += 1\n",
    "        info = action - nb_base_actions\n",
    "        feat = int(self.info_actions[info][0])\n",
    "        modif = self.state.copy()\n",
    "        f_lower = modif[feat]\n",
    "        f_upper = modif[self.nb_base_features + feat]\n",
    "        # Split point in the feature's original units: the chosen fraction\n",
    "        # of the feature's current interval.\n",
    "        value_p = self.info_actions[info][1] * (f_upper - f_lower) + f_lower\n",
    "        # Deterministic transition: keep the side containing the hidden base state.\n",
    "        if self.base_s[feat] <= value_p:\n",
    "            modif[self.nb_base_features + feat] = min(f_upper, value_p)\n",
    "        else:\n",
    "            modif[feat] = max(f_lower, value_p)\n",
    "        self.state = modif\n",
    "        return self.state.copy(), self.zeta, False, False, {}\n",
    "\n",
    "    def reset(self, *, seed=None, options=None, **kwargs):\n",
    "        \"\"\"Start an episode: resample the hidden base state and return the root bounds.\n",
    "\n",
    "        ``seed`` seeds this env (gymnasium API) and is forwarded to the base MDP.\n",
    "        \"\"\"\n",
    "        super().reset(seed=seed)\n",
    "        self.depth = 0\n",
    "        self.base_s, _ = self.base_mdp.reset(seed=seed)\n",
    "        self.state = self.initial_observation.copy()\n",
    "        return self.state.copy(), {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset loaders (from `data`) and their names, aligned by position.\n",
    "functions_ = [\n",
    "    get_avila_data,\n",
    "    get_bank_data,\n",
    "    get_bean_data,\n",
    "    get_bidding_data,\n",
    "    get_eeg_data,\n",
    "    get_fault_data,\n",
    "    get_htru_data,\n",
    "    get_magic_data,\n",
    "    get_occupancy_data,\n",
    "    get_page_data,\n",
    "    get_raisin_data,\n",
    "    get_rice_data,\n",
    "    get_room_data,\n",
    "    get_segment_data,\n",
    "    get_skin_data,\n",
    "    get_wilt_data,\n",
    "]\n",
    "names_ = [\n",
    "    \"avila\",\n",
    "    \"bank\",\n",
    "    \"bean\",\n",
    "    \"bidding\",\n",
    "    \"eeg\",\n",
    "    \"fault\",\n",
    "    \"htru\",\n",
    "    \"magic\",\n",
    "    \"occupancy\",\n",
    "    \"page\",\n",
    "    \"raisin\",\n",
    "    \"rice\",\n",
    "    \"room\",\n",
    "    \"segment\",\n",
    "    \"skin\",\n",
    "    \"wilt\",\n",
    "]\n",
    "assert len(functions_) == len(names_), \"functions_ and names_ must be aligned\"\n",
    "\n",
    "# Experiment configuration.\n",
    "N_RUNS = 5  # independent DQN runs per dataset; run r is seeded with r\n",
    "TOTAL_TIMESTEPS = 500_000\n",
    "DQN_KWARGS = dict(\n",
    "    learning_starts=10_000,\n",
    "    buffer_size=5000,\n",
    "    target_update_interval=1000,\n",
    "    exploration_final_eps=0,\n",
    ")\n",
    "\n",
    "for dataset, load_data in zip(names_, functions_):\n",
    "    X, y = load_data()\n",
    "    total_time = 0.0\n",
    "    for r in range(N_RUNS):\n",
    "        log_dir = \"./runs/dqn_{}_run{}_md/\".format(dataset, r)\n",
    "        os.makedirs(log_dir, exist_ok=True)\n",
    "\n",
    "        # Classification MDP -> IBMDP wrapper -> Monitor logging to log_dir.\n",
    "        env = Monitor(IBMDPGym(ClassifGym(X, y)), log_dir)\n",
    "        try:\n",
    "            ts = time()\n",
    "            model = DQN(\"MlpPolicy\", env, seed=r, **DQN_KWARGS)\n",
    "            model.learn(TOTAL_TIMESTEPS)\n",
    "            total_time += time() - ts\n",
    "            model.save(\"./deeprl_policies/dqn_{}_run{}\".format(dataset, r))\n",
    "        finally:\n",
    "            # Close the Monitor so its results file in log_dir is flushed and\n",
    "            # released, even if training fails.\n",
    "            env.close()\n",
    "    # Mean wall-clock seconds per run (DQN construction + training).\n",
    "    print(dataset, total_time / N_RUNS)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
