{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "be9599db-297a-4778-a34e-b4a50fb05bb4",
   "metadata": {},
   "source": [
    "# Hybrid RL Coverage Experiments for Tabular MDPs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f0e89e29-3a5b-488d-b66c-ec51e7454d66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mticker\n",
    "import mdptoolbox.util as util\n",
    "import mdptoolbox.example as example\n",
    "import mdptoolbox.mdp as mdpt\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "sys.path.append('core')\n",
    "from core import core_mdp, core_env, core_ope\n",
    "from helpers import *\n",
    "plt.style.use('matplotlibrc')\n",
    "\n",
    "np.set_printoptions(suppress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6cd10b4-1a13-4097-9d3f-b231f16486ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, 0, 0, 0)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forest-management MDP from pymdptoolbox.\n",
    "# idea: do you wait for a forest to get old to support wildlife,\n",
    "#       or cut it down for wood each year?\n",
    "# S is how old the forest can get\n",
    "# r1 is reward for waiting in oldest state\n",
    "# r2 is reward for cutting in oldest state\n",
    "# p is probability forest burns down each year\n",
    "P, R = example.forest(S = 4, r1 = 4, r2 = 2, p=0.1)\n",
    "# Alternative environment (disabled):\n",
    "#P, R, x_dist = core_env.orig_gridworld_ope_tools()\n",
    "#R = R.mean(-1).T\n",
    "\n",
    "# Closest-to-1 discount that still suppresses mdptoolbox warnings\n",
    "# (kept for reference); a slightly smaller value is used below.\n",
    "#gamma = 0.9999999999999999\n",
    "gamma = 0.9999\n",
    "\n",
    "H = 20\n",
    "# NOTE: the 4th positional parameter of mdptoolbox's ValueIteration is\n",
    "# `epsilon` (stopping tolerance), not a horizon or iteration cap, so H\n",
    "# only sets the convergence tolerance here. Passing it by keyword keeps\n",
    "# the existing behavior but makes that explicit; use\n",
    "# mdptoolbox.mdp.FiniteHorizon for a true H-step horizon.\n",
    "forest = mdpt.ValueIteration(P, R, gamma, epsilon=H)\n",
    "forest.run()\n",
    "# mdptoolbox transition arrays are indexed (action, state, next_state).\n",
    "nActions = P.shape[0]\n",
    "nStates = P.shape[1]\n",
    "\n",
    "# (epsilon-)optimal stationary policy: one action index per state\n",
    "forest.policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "40e560fc-fdac-4f2c-8036-ce872f9dcbbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG so the offline datasets are reproducible.\n",
    "np.random.seed(1)\n",
    "forest_mdp = core_mdp.MDP(P, R, np.full(nStates, 1/nStates), 1)\n",
    "Noff = 100  # offline trajectories per behavior policy\n",
    "\n",
    "# Behavior policies (per-state action probabilities):\n",
    "#   pi_star: policy found by value iteration\n",
    "#   pi_unif: uniform over actions, shape (nStates, nActions)\n",
    "#   pi_bad : 60% spread over the other actions, 40% uniform\n",
    "#            (assuming pi_star is one-hot)\n",
    "pi_star = translate_policy(forest.policy, nActions)\n",
    "pi_unif = np.full((nStates, nActions), 1/nActions)\n",
    "pi_bad = 0.4*pi_unif + 0.6*(1 - pi_star)/(nActions - 1)\n",
    "\n",
    "# Offline datasets, drawn in the order optimal, uniform, adversarial.\n",
    "dataset_optimal, dataset_covered, dataset_bad = (\n",
    "    collect_sample(Noff, forest_mdp, behavior, H, stationary=True)\n",
    "    for behavior in (pi_star, pi_unif, pi_bad))\n",
    "\n",
    "# Empirical (state, action) occupancy of each dataset.\n",
    "occ_star, occ_covered, occ_bad = (\n",
    "    getOcc(data, nStates, nActions)\n",
    "    for data in (dataset_optimal, dataset_covered, dataset_bad))\n",
    "occ_all = np.ones(np.shape(occ_star))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7adc330a-f423-4bff-91f4-d9af947f7245",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.102 , 0.    ],\n",
       "       [0.0995, 0.    ],\n",
       "       [0.1045, 0.    ],\n",
       "       [0.694 , 0.    ]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Empirical occupancy of the optimal-policy dataset\n",
    "# (rows: states, columns: actions).\n",
    "occ_star"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "546a61b2-0b9c-4ab8-8146-38fd66b17d12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split (state, action) pairs by how well each offline dataset covers them:\n",
    "# True = empirical occupancy below the uniform level 1/(S*A) (the\n",
    "# 'online' partition); False = the 'offline' partition.\n",
    "thresh = 1/nStates/nActions\n",
    "part_on_opt = occ_star < thresh\n",
    "part_on_cov = occ_covered < thresh\n",
    "part_on_bad = occ_bad < thresh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "037cafa4-b382-4d92-8fcb-962966d8ff3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/30 [00:00<?, ?it/s]"
     ]
    }
   ],
   "source": [
    "def ucbvi(mdp, P, R, dataset, nStates, nActions, H, T, delta=0.01):\n",
    "    \"\"\"Run T online episodes of UCB value iteration warm-started on `dataset`.\n",
    "\n",
    "    Every episode re-estimates the model from all data gathered so far, adds\n",
    "    a count-based exploration bonus to the empirical rewards, plans with value\n",
    "    iteration, plays the resulting policy for one length-H episode in `mdp`,\n",
    "    and appends that trajectory to the data.\n",
    "\n",
    "    Args:\n",
    "        mdp: environment with `generate_trajectory(pi, H, stationary=True)`.\n",
    "        P, R: true model; not used here (kept for interface compatibility).\n",
    "        dataset: offline trajectories stacked along axis 0 (a single\n",
    "            trajectory gives the online-only baseline).\n",
    "        nStates, nActions: sizes of the state and action spaces.\n",
    "        H: episode length.\n",
    "        T: number of online episodes.\n",
    "        delta: failure probability in the confidence bonus.\n",
    "\n",
    "    Returns:\n",
    "        pi: policy played in the last episode (None if T == 0).\n",
    "        dataset: the input data followed by the T online trajectories.\n",
    "        pis: array of the T policies played.\n",
    "        occupancies: entry t is the empirical (state, action) occupancy of\n",
    "            online episodes 0..t.\n",
    "        visits: entry t is the (state, action) visit counts of episode t.\n",
    "    \"\"\"\n",
    "    # 0.9999999999999999 is the closest-to-1 discount that still suppresses\n",
    "    # mdptoolbox warnings; a smaller value is used in practice.\n",
    "    gamma = 0.999\n",
    "\n",
    "    pi = None\n",
    "    pis = []\n",
    "    occupancies = []\n",
    "    visits = []\n",
    "    n = len(dataset)  # number of offline trajectories\n",
    "    for t in range(T):\n",
    "        # Empirical transition probabilities. Rows with no data do not sum to\n",
    "        # one and are replaced by the uniform distribution; np.isclose (not\n",
    "        # an exact `!= 1`) keeps estimated rows whose sum is off by rounding.\n",
    "        phat = getPhat(dataset, nStates, nActions)\n",
    "        phat[~np.isclose(phat.sum(-1), 1)] = 1/nStates\n",
    "\n",
    "        # Optimistic rewards: empirical reward plus a count-based bonus.\n",
    "        # Counts are floored at 1 so unvisited pairs get the largest bonus;\n",
    "        # zeroing their 1/0 = inf bonus instead would make never-tried\n",
    "        # actions look worthless and suppress exploration.\n",
    "        n_sa = getN_sa(dataset, nStates, nActions)\n",
    "        bonus = 2*H*np.sqrt(np.log(nStates*nActions*H*T/delta)/np.maximum(n_sa, 1))\n",
    "        # A negative log term (delta > S*A*H*T) would make sqrt return nan.\n",
    "        bonus[np.isnan(bonus)] = 0\n",
    "        r = getR_sa(dataset, nStates, nActions) + bonus\n",
    "\n",
    "        # Plan on the optimistic model. The 4th positional parameter of\n",
    "        # ValueIteration is epsilon (stopping tolerance), not a horizon;\n",
    "        # H is passed by keyword to make that existing behavior explicit.\n",
    "        vi = mdpt.ValueIteration(phat, r, gamma, epsilon=H)\n",
    "        vi.run()\n",
    "\n",
    "        # Play the policy for one episode and add the trajectory to the data.\n",
    "        pi = translate_policy(vi.policy, nActions)\n",
    "        traj = mdp.generate_trajectory(pi, H, stationary=True)\n",
    "        dataset = np.vstack([dataset, traj[None, ...]])\n",
    "\n",
    "        pis.append(pi)\n",
    "        occupancies.append(getN_sa(dataset[n:], nStates, nActions)/H/(t + 1))\n",
    "        visits.append(getN_sa(dataset[-1:], nStates, nActions))\n",
    "\n",
    "    return pi, dataset, np.array(pis), np.array(occupancies), np.array(visits)\n",
    "\n",
    "# Run UCBVI n_trials times from each offline dataset (hybrid) and from a\n",
    "# single uniform-policy trajectory (online-only baseline).\n",
    "n_trials = 30\n",
    "T_online = 200    # online episodes per UCBVI run\n",
    "delta_ucb = 0.01  # failure probability used in the UCBVI bonus\n",
    "\n",
    "occs_ucb_bad_trials, visits_ucb_bad_trials, dataset_ucb_bad_trials = [], [], []\n",
    "occs_ucb_cov_trials, visits_ucb_cov_trials, dataset_ucb_cov_trials = [], [], []\n",
    "occs_ucb_opt_trials, visits_ucb_opt_trials, dataset_ucb_opt_trials = [], [], []\n",
    "occs_ucb_on_trials, visits_ucb_on_trials, dataset_ucb_on_trials = [], [], []\n",
    "\n",
    "# Starting data for each variant and the lists collecting its results.\n",
    "# Dict order fixes the per-trial run order (bad, cov, opt, on), so the\n",
    "# simulator's random draws happen in a fixed sequence.\n",
    "ucb_runs = {\n",
    "    'bad': (dataset_bad, occs_ucb_bad_trials, visits_ucb_bad_trials, dataset_ucb_bad_trials),\n",
    "    'cov': (dataset_covered, occs_ucb_cov_trials, visits_ucb_cov_trials, dataset_ucb_cov_trials),\n",
    "    'opt': (dataset_optimal, occs_ucb_opt_trials, visits_ucb_opt_trials, dataset_ucb_opt_trials),\n",
    "    'on': (dataset_covered[:1], occs_ucb_on_trials, visits_ucb_on_trials, dataset_ucb_on_trials),\n",
    "}\n",
    "\n",
    "for trial in tqdm(range(n_trials)):\n",
    "    for start_data, occs_out, visits_out, datasets_out in ucb_runs.values():\n",
    "        pi_ucb, dataset_ucb, pis_ucb, occs_ucb, visits_ucb = ucbvi(\n",
    "            forest_mdp, P, R, start_data, nStates, nActions,\n",
    "            H, T_online, delta=delta_ucb)\n",
    "        occs_out.append(occs_ucb)\n",
    "        visits_out.append(visits_ucb)\n",
    "        datasets_out.append(dataset_ucb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3584a551-1e59-4002-90f4-d7483e017343",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average per-step reward of each online episode, across trials. The last\n",
    "# field of every trajectory step is used as the reward, and the offline\n",
    "# prefix (Noff trajectories, or 1 for online-only) is dropped.\n",
    "titles = ['Adversarial Behavior Policy',\n",
    "          'Uniform Behavior Policy',\n",
    "          'Optimal Behavior Policy',\n",
    "          'Online Only']\n",
    "reward_curves = [np.array(dataset_ucb_bad_trials)[:,Noff:,:,-1].mean(-1),\n",
    "                 np.array(dataset_ucb_cov_trials)[:,Noff:,:,-1].mean(-1),\n",
    "                 np.array(dataset_ucb_opt_trials)[:,Noff:,:,-1].mean(-1),\n",
    "                 np.array(dataset_ucb_on_trials)[:,1:,:,-1].mean(-1)]\n",
    "smooth_window = 5  # episodes in the rolling mean\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12,4))\n",
    "for title, rewards in zip(titles, reward_curves):\n",
    "    # Smooth each trial (rows of `rewards`) over episodes first, then take\n",
    "    # mean/std across trials, so the line and its +/-1 std band describe\n",
    "    # the same smoothed quantity.\n",
    "    smoothed = pd.DataFrame(rewards.T).rolling(smooth_window).mean()\n",
    "    mean = smoothed.mean(axis=1)\n",
    "    std = smoothed.std(axis=1, ddof=0)\n",
    "    line, = ax.plot(mean, label=title)\n",
    "    ax.fill_between(np.arange(len(mean)), mean - std, mean + std,\n",
    "                    color=line.get_color(), alpha=0.2)\n",
    "# Reference: average per-step reward in the optimal-policy offline data.\n",
    "ax.axhline(dataset_optimal[:,:,-1].mean(-1).mean(), linestyle='dashed',\n",
    "           label='Optimal Policy Average Reward', color='black')\n",
    "# Build the legend after axhline so the reference line is listed too.\n",
    "ax.legend()\n",
    "ax.set_xlabel('Online Episodes')\n",
    "ax.set_ylabel('Average Reward')\n",
    "ax.set_title('Average Per-Episode Reward Over Online Timesteps')\n",
    "fig.tight_layout()\n",
    "fig.savefig('figs/reward_tabular.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "706ed806-7cc4-408c-b6c6-351e356de5a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# c_star (from helpers) between occ_star and the online occupancy, restricted\n",
    "# to each dataset's online partition; hybrid UCBVI vs. online-only.\n",
    "fig, ax = plt.subplots(1,3, figsize=(12,4))\n",
    "parts = [part_on_bad, part_on_cov, part_on_opt]\n",
    "occs = [np.array(occs_ucb_bad_trials), np.array(occs_ucb_cov_trials), np.array(occs_ucb_opt_trials)]\n",
    "titles = ['Adversarial Behavior Policy',\n",
    "          'Uniform Behavior Policy',\n",
    "          'Optimal Behavior Policy']\n",
    "\n",
    "for i in range(3):\n",
    "    mask = parts[i]\n",
    "    # One value per (trial, online episode).\n",
    "    cstars_hy = np.array([[c_star(occ_star[mask], occ[mask])\n",
    "                           for occ in occ_t] for occ_t in occs[i]])\n",
    "    cstars_on = np.array([[c_star(occ_star[mask], occ[mask])\n",
    "                           for occ in occ_t] for occ_t in occs_ucb_on_trials])\n",
    "    for cstars, style, label in ((cstars_hy, 'solid', 'Hybrid'),\n",
    "                                 (cstars_on, 'dashed', 'Online Only')):\n",
    "        mean, std = cstars.mean(0), cstars.std(0)\n",
    "        line, = ax[i].plot(mean, linestyle=style, label=label)\n",
    "        # Band: mean +/- 1.96 std across trials; the x-range is taken from\n",
    "        # this curve itself and the color from its line.\n",
    "        ax[i].fill_between(np.arange(len(mean)),\n",
    "                           mean - 1.96*std, mean + 1.96*std,\n",
    "                           color=line.get_color(), alpha=0.3)\n",
    "    ax[i].set_title(titles[i], fontsize=10)\n",
    "    ax[i].set_xlabel('Online Episodes', fontsize=10)\n",
    "    ax[i].yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f'))\n",
    "    ax[i].legend()\n",
    "    ax[i].set_ylim(0,8)\n",
    "\n",
    "ax[0].set_ylabel('Concentrability Coefficient', fontsize=10)\n",
    "\n",
    "plt.suptitle('Coverage Over Online Partition, Forest Tabular MDP', fontsize=16, y=0.99)\n",
    "plt.savefig('figs/cov_tabular_onpart.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25fed7b3-db26-4383-9a21-1db7aaf233e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# c_star (from helpers) between occ_star and the online occupancy, restricted\n",
    "# to each dataset's offline partition; hybrid UCBVI vs. online-only.\n",
    "fig, ax = plt.subplots(1,3, figsize=(12,4))\n",
    "parts = [part_on_bad, part_on_cov, part_on_opt]\n",
    "occs = [np.array(occs_ucb_bad_trials), np.array(occs_ucb_cov_trials), np.array(occs_ucb_opt_trials)]\n",
    "titles = ['Adversarial Behavior Policy',\n",
    "          'Uniform Behavior Policy',\n",
    "          'Optimal Behavior Policy']\n",
    "\n",
    "for i in range(3):\n",
    "    mask = ~parts[i]  # offline partition = complement of the online one\n",
    "    # One value per (trial, online episode).\n",
    "    cstars_hy = np.array([[c_star(occ_star[mask], occ[mask])\n",
    "                           for occ in occ_t] for occ_t in occs[i]])\n",
    "    cstars_on = np.array([[c_star(occ_star[mask], occ[mask])\n",
    "                           for occ in occ_t] for occ_t in occs_ucb_on_trials])\n",
    "    for cstars, style, label in ((cstars_hy, 'solid', 'Hybrid'),\n",
    "                                 (cstars_on, 'dashed', 'Online Only')):\n",
    "        mean, std = cstars.mean(0), cstars.std(0)\n",
    "        line, = ax[i].plot(mean, linestyle=style, label=label)\n",
    "        # Band: mean +/- 1.96 std across trials; the x-range is taken from\n",
    "        # this curve itself and the color from its line.\n",
    "        ax[i].fill_between(np.arange(len(mean)),\n",
    "                           mean - 1.96*std, mean + 1.96*std,\n",
    "                           color=line.get_color(), alpha=0.3)\n",
    "    ax[i].set_title(titles[i], fontsize=10)\n",
    "    ax[i].set_xlabel('Online Episodes', fontsize=10)\n",
    "    ax[i].set_ylim(0,8)\n",
    "    ax[i].yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1f'))\n",
    "    ax[i].legend()\n",
    "ax[0].set_ylabel('Concentrability Coefficient', fontsize=10)\n",
    "\n",
    "plt.suptitle('Coverage Over Offline Partition, Forest Tabular MDP', fontsize=16, y=0.99)\n",
    "plt.savefig('figs/cov_tabular_offpart.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3a92501-0484-47d9-a3a8-bd702f12995f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cumulative visits of online episodes to each dataset's online partition\n",
    "# (mean +/- 1.96 std across trials), hybrid UCBVI vs. online-only.\n",
    "fig, ax = plt.subplots(1,3, figsize=(12,4))\n",
    "parts = [part_on_bad, part_on_cov, part_on_opt]\n",
    "visits = [np.array(visits_ucb_bad_trials), np.array(visits_ucb_cov_trials),\n",
    "          np.array(visits_ucb_opt_trials)]\n",
    "titles = ['Adversarial Behavior Policy',\n",
    "          'Uniform Behavior Policy',\n",
    "          'Optimal Behavior Policy']\n",
    "\n",
    "for i, (mask, visit_trials) in enumerate(zip(parts, visits)):\n",
    "    # Per-episode visits inside the partition, accumulated over episodes\n",
    "    # -> one row per trial, one column per online episode.\n",
    "    visit_stat = np.cumsum(visit_trials[:, :, mask].sum(axis=-1), axis=1)\n",
    "    visit_stat_on = np.cumsum(\n",
    "        np.array(visits_ucb_on_trials)[:, :, mask].sum(axis=-1), axis=1)\n",
    "    for stat, style, label in ((visit_stat, 'solid', 'Hybrid'),\n",
    "                               (visit_stat_on, 'dashed', 'Online Only')):\n",
    "        ax[i].plot(stat.mean(0), linestyle=style, label=label)\n",
    "    for stat in (visit_stat, visit_stat_on):\n",
    "        centre, spread = stat.mean(0), 1.96*stat.std(0)\n",
    "        ax[i].fill_between(np.arange(len(centre)),\n",
    "                           centre - spread, centre + spread, alpha=0.3)\n",
    "    ax[i].set_title(titles[i], fontsize=10)\n",
    "    ax[i].set_xlabel('Online Episodes', fontsize=10, y=0.05)\n",
    "    # NOTE(review): fixed limit shared by all panels; may clip if a run\n",
    "    # exceeds 3400 cumulative visits.\n",
    "    ax[i].set_ylim(0,3400)\n",
    "    ax[i].legend()\n",
    "ax[0].set_ylabel('Cumulative Visits', fontsize=10)\n",
    "\n",
    "plt.suptitle('Cumulative Visits to Online Partition, Forest Tabular MDP', fontsize=16, y=0.99)\n",
    "plt.savefig('figs/visits_tabular_onpart.png', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa08727f-80ab-4daa-9c86-615089ac3660",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cumulative visits of online episodes to each dataset's offline partition\n",
    "# (mean +/- 1.96 std across trials), hybrid UCBVI vs. online-only.\n",
    "fig, ax = plt.subplots(1,3, figsize=(12,4))\n",
    "parts = [part_on_bad, part_on_cov, part_on_opt]\n",
    "visits = [np.array(visits_ucb_bad_trials), np.array(visits_ucb_cov_trials),\n",
    "          np.array(visits_ucb_opt_trials)]\n",
    "titles = ['Adversarial Behavior Policy',\n",
    "          'Uniform Behavior Policy',\n",
    "          'Optimal Behavior Policy']\n",
    "\n",
    "for i, (mask, visit_trials) in enumerate(zip(parts, visits)):\n",
    "    off_mask = ~mask  # offline partition = complement of the online one\n",
    "    # Per-episode visits inside the partition, accumulated over episodes\n",
    "    # -> one row per trial, one column per online episode.\n",
    "    visit_stat = np.cumsum(visit_trials[:, :, off_mask].sum(axis=-1), axis=1)\n",
    "    visit_stat_on = np.cumsum(\n",
    "        np.array(visits_ucb_on_trials)[:, :, off_mask].sum(axis=-1), axis=1)\n",
    "    for stat, style, label in ((visit_stat, 'solid', 'Hybrid'),\n",
    "                               (visit_stat_on, 'dashed', 'Online Only')):\n",
    "        ax[i].plot(stat.mean(0), linestyle=style, label=label)\n",
    "    for stat in (visit_stat, visit_stat_on):\n",
    "        centre, spread = stat.mean(0), 1.96*stat.std(0)\n",
    "        ax[i].fill_between(np.arange(len(centre)),\n",
    "                           centre - spread, centre + spread, alpha=0.3)\n",
    "    ax[i].set_title(titles[i], fontsize=10)\n",
    "    ax[i].set_xlabel('Online Episodes', fontsize=10, y=0.05)\n",
    "    # NOTE(review): fixed limit shared by all panels; may clip if a run\n",
    "    # exceeds 3400 cumulative visits.\n",
    "    ax[i].set_ylim(0,3400)\n",
    "    ax[i].legend()\n",
    "ax[0].set_ylabel('Cumulative Visits', fontsize=10)\n",
    "\n",
    "plt.suptitle('Cumulative Visits to Offline Partition, Forest Tabular MDP', fontsize=16, y=0.99)\n",
    "plt.savefig('figs/visits_tabular_offpart.png', dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "311",
   "language": "python",
   "name": "311"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
