{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Ablation study Final Off-Policy MDP 4 Layer Balanced Tree and 2 arms.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qZMEFqwgtp52"
      },
      "source": [
        "# Build Environment"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XQaUBAhftz5A",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7aa5ccd5-56c8-4fe5-b8a2-8c5f0b0eff9e"
      },
      "source": [
        "import numpy as np\n",
        "from tqdm import tqdm\n",
        "\n",
        "class Environment(object):\n",
        "  '''Deterministic balanced binary-tree MDP used for the policy-evaluation ablation.\n",
        "\n",
        "  State 0 is the root. From an internal state s, action a moves to child\n",
        "  2*s + a + 1; from a leaf every action self-loops. Taking action a in state s\n",
        "  pays R[s][a] plus zero-mean Gaussian noise with variance var[s][a].\n",
        "  '''\n",
        "\n",
        "  # Declare environment variables\n",
        "  def __init__(self, ):\n",
        "\n",
        "    # Tree size. The transition rule assumes a complete binary tree\n",
        "    # (num_actions == 2): 15 states = 4 layers, 31 states = 5 layers.\n",
        "    # The horizon is equal to L + 1.\n",
        "    self.num_states = 15 # 31 # 15\n",
        "    self.num_actions = 2\n",
        "    self.L = 4 # 5 # 4\n",
        "    self.horizon = self.L + 1\n",
        "\n",
        "    self.episodes = 5000\n",
        "    self.num_trials = 20\n",
        "    self.discount_factor = 1.0\n",
        "\n",
        "    self.start_state = 0\n",
        "\n",
        "    self.S = np.arange(self.num_states)\n",
        "\n",
        "    # Target (evaluation) policy: the same action distribution in every state.\n",
        "    self.pi_e = np.tile([0.95, 0.05], (self.num_states, 1)) # alternative: [0.9, 0.1]\n",
        "\n",
        "    # Mean reward of every (state, action) pair.\n",
        "    self.R = np.ones((self.num_states, self.num_actions))\n",
        "\n",
        "    # Reward-noise variance: action 0 is nearly noiseless, action 1 is very noisy.\n",
        "    self.var = np.zeros((self.num_states, self.num_actions))\n",
        "    self.var[:, 0] = 0.001\n",
        "    self.var[:, 1] = 20 # 10 works good\n",
        "\n",
        "    # Hand-tuned exploration constant. A range-based alternative is\n",
        "    # ((3*sqrt(max var) + max R) - (-3*sqrt(min var) + min R))**2 / 2.\n",
        "    self.ucb_constant = 4\n",
        "\n",
        "    # In a complete binary tree the last ceil(num_states / 2) states are the leaves.\n",
        "    self.leaf_set = list(range(self.num_states // 2, self.num_states))\n",
        "\n",
        "    self.special_state = np.array([-1,-1])\n",
        "    self.goal_state = np.copy(self.leaf_set)\n",
        "\n",
        "    self.level = []\n",
        "    self.calculate_level()\n",
        "    self.reset()\n",
        "\n",
        "  def calculate_level(self, ):\n",
        "    '''Append a (state, depth) pair to self.level for every state of the tree.\n",
        "\n",
        "    Depth ell holds num_actions**ell consecutive states; the enumeration stops\n",
        "    at num_states so no pairs are produced for states that do not exist.\n",
        "    '''\n",
        "    self.level.append((0,0))\n",
        "    curr_state = 1\n",
        "    ell = 1\n",
        "    while curr_state < self.num_states:\n",
        "      layer_end = min(curr_state + self.num_actions**ell, self.num_states)\n",
        "      for i in range(curr_state, layer_end):\n",
        "        self.level.append((i,ell))\n",
        "      curr_state = layer_end\n",
        "      ell += 1\n",
        "\n",
        "  def find_leaves(self, s):\n",
        "    '''Return (successor_states, actions) for state s.\n",
        "\n",
        "    For a leaf the only successor is s itself and every action is listed.\n",
        "    For an internal state, successor_states[i] is the child reached by actions[i].\n",
        "    '''\n",
        "    actions = [i for i in range(self.num_actions)]\n",
        "    if s in self.leaf_set:\n",
        "      return [s], actions\n",
        "    return [2*s + (action + 1) for action in actions], actions # child = 2s + (a + 1)\n",
        "\n",
        "  def get_level(self, s):\n",
        "    '''Return the depth of state s, or None if s is not listed in self.level.'''\n",
        "    for state, ell in self.level:\n",
        "      if state == s:\n",
        "        return ell\n",
        "    return None\n",
        "\n",
        "  def reset(self, ):\n",
        "    '''Put the agent back at the root.'''\n",
        "    self.curr_state = 0\n",
        "    self.next_state = 0\n",
        "\n",
        "  def step(self, action):\n",
        "    '''Take action from the current state and return (next_state, noisy_reward).\n",
        "\n",
        "    The noise variance is var[s][action] for the state s the action was taken\n",
        "    in (not the state it leads to); np.sqrt turns the variance into a std.\n",
        "\n",
        "    Raises:\n",
        "      ValueError: if action is not in [0, num_actions).\n",
        "    '''\n",
        "    if not 0 <= action < self.num_actions:\n",
        "      raise ValueError('action must be in [0, %d), got %r' % (self.num_actions, action))\n",
        "\n",
        "    s = self.curr_state\n",
        "    if s in self.leaf_set:\n",
        "      self.next_state = s # leaves self-loop\n",
        "    else:\n",
        "      self.next_state = 2*s + action + 1\n",
        "    reward = self.R[s][action]\n",
        "    noise = np.sqrt(self.var[s][action])*np.random.randn()\n",
        "\n",
        "    self.curr_state = self.next_state\n",
        "    return self.next_state, reward + noise\n",
        "\n",
        "# Build the environment once and show the tree layout.\n",
        "env = Environment()\n",
        "print(\"leaf set: {}\".format(env.leaf_set))\n",
        "print(\"env level: {}\".format(env.level))\n"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "leaf set: [7, 8, 9, 10, 11, 12, 13, 14]\n",
            "env level: [(0, 0), (1, 1), (2, 1), (3, 2), (4, 2), (5, 2), (6, 2), (7, 3), (8, 3), (9, 3), (10, 3), (11, 3), (12, 3), (13, 3), (14, 3), (15, 4), (16, 4), (17, 4), (18, 4), (19, 4), (20, 4), (21, 4), (22, 4), (23, 4), (24, 4), (25, 4), (26, 4), (27, 4), (28, 4), (29, 4), (30, 4)]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Debug"
      ],
      "metadata": {
        "id": "6UiEgTt-VvS8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Sampling weights b[t][s][a] and subtree totals B[t][s] computed from the true\n",
        "# reward variances. Every t-layer runs the same backward recursion (leaves\n",
        "# first, then internal states from the highest index down); layer 0 is shown.\n",
        "b = np.zeros((env.L, env.num_states, env.num_actions))\n",
        "B = np.zeros((env.L, env.num_states))\n",
        "\n",
        "for layer in reversed(range(env.L)):\n",
        "  # Leaves: b = sqrt(pi_e^2 * var^2) per action, B = sum over actions.\n",
        "  for leaf in env.leaf_set:\n",
        "    for a in range(env.num_actions):\n",
        "      b[layer][leaf][a] = np.sqrt(env.pi_e[leaf][a]**2 * env.var[leaf][a]**2)\n",
        "    B[layer][leaf] = sum(b[layer][leaf])\n",
        "\n",
        "  # Internal states: children have higher indices, so they are already filled.\n",
        "  for node in reversed(range(env.num_states)):\n",
        "    if node in env.leaf_set:\n",
        "      continue\n",
        "    # NOTE(review): both children enter every action's weight, although step()\n",
        "    # sends action a only to child 2*node + a + 1 -- confirm this is intended.\n",
        "    child_sq = 0.0\n",
        "    for child in (2*node + 1, 2*node + 2):\n",
        "      child_sq += B[layer][child]**2\n",
        "    for a in range(env.num_actions):\n",
        "      b[layer][node][a] = np.sqrt(env.pi_e[node][a]**2 * (env.var[node][a]**2 + env.discount_factor**2 * child_sq))\n",
        "    B[layer][node] = np.sum(b[layer][node])\n",
        "\n",
        "print(b[0])\n",
        "print(B[0])\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "B7Ie5A0rkKbh",
        "outputId": "b4568518-a91b-4928-88be-ae039d44ee92"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[[5.59872579e+00 1.04251152e+00]\n",
            " [3.15357917e+00 1.01368074e+00]\n",
            " [3.15357917e+00 1.01368074e+00]\n",
            " [1.34477955e+00 1.00250162e+00]\n",
            " [1.34477955e+00 1.00250162e+00]\n",
            " [1.34477955e+00 1.00250162e+00]\n",
            " [1.34477955e+00 1.00250162e+00]\n",
            " [9.50000000e-04 1.00000000e+00]\n",
            " [9.50000000e-04 1.00000000e+00]\n",
            " [9.50000000e-04 1.00000000e+00]\n",
            " [9.50000000e-04 1.00000000e+00]\n",
            " [9.50000000e-04 1.00000000e+00]\n",
            " [9.50000000e-04 1.00000000e+00]\n",
            " [9.50000000e-04 1.00000000e+00]\n",
            " [9.50000000e-04 1.00000000e+00]]\n",
            "[6.64123731 4.16725991 4.16725991 2.34728117 2.34728117 2.34728117\n",
            " 2.34728117 1.00095    1.00095    1.00095    1.00095    1.00095\n",
            " 1.00095    1.00095    1.00095   ]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Value Iteration"
      ],
      "metadata": {
        "id": "okLWVTAndPDo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class value_iteration(object):\n",
        "  '''Finite-horizon evaluation of pi_e on the tree MDP by backward induction.'''\n",
        "\n",
        "  def __init__(self,):\n",
        "    '''Create the Environment whose tree, policy and horizon drive the backups.'''\n",
        "    self.env = Environment()\n",
        "\n",
        "  def run_value_iteration(self, b):\n",
        "    '''Evaluate pi_e with per-step reward table b by backward induction.\n",
        "\n",
        "    Args:\n",
        "      b: table indexed as b[s][a], the reward credited for taking action a in\n",
        "        state s (e.g. env.R or an estimated q table).\n",
        "\n",
        "    Returns:\n",
        "      self.V, an array of shape (horizon, num_states) where V[m][s] is the\n",
        "      expected discounted sum of b collected from step m to the end of the\n",
        "      horizon when starting in s and following pi_e.\n",
        "    '''\n",
        "    env = self.env # the evaluator's own environment, not a notebook global\n",
        "    self.V = np.zeros((env.horizon, env.num_states)) # Initialize to 0\n",
        "\n",
        "    for m in range(env.horizon - 1, -1, -1):\n",
        "      is_last_step = (m == env.horizon - 1)\n",
        "      for s in range(env.num_states - 1, -1, -1): # Start from the end states\n",
        "        next_states, next_actions = env.find_leaves(s)\n",
        "        if len(next_states) == 1:\n",
        "          successors = next_states * len(next_actions) # leaf: every action self-loops\n",
        "        else:\n",
        "          successors = next_states # successors[i] is reached by next_actions[i]\n",
        "        val = []\n",
        "        for action, s1 in zip(next_actions, successors):\n",
        "          # Transitions are deterministic: each action leads to exactly one state,\n",
        "          # so only that state's value is backed up (not every child of s).\n",
        "          val_next_state = 0.0 if is_last_step else self.V[m+1][s1]\n",
        "          val.append(env.pi_e[s][action]*( b[s][action] + env.discount_factor*val_next_state ))\n",
        "        self.V[m][s] = np.sum(val)\n",
        "\n",
        "    return self.V\n",
        "    \n",
        "# Ground-truth value of pi_e: back up the true mean rewards env.R.\n",
        "env = Environment()\n",
        "vi_pi_e = value_iteration()\n",
        "V = vi_pi_e.run_value_iteration(env.R)\n",
        "print(V[0][0])\n",
        "print(V[0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gMYIiAPkdRfH",
        "outputId": "3df79195-b19b-48b4-845b-47ad04739116"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "22.999999999999996\n",
            "[23. 15. 15.  9.  9.  9.  9.  5.  5.  5.  5.  5.  5.  5.  5.]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# On Policy Sampling"
      ],
      "metadata": {
        "id": "Lyzx1-KEnFxE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class OnPolicySampling(object):\n",
        "  '''Baseline estimator: act with pi_e itself and plug empirical mean rewards into value iteration.'''\n",
        "\n",
        "  def __init__(self, env):\n",
        "    self.num_actions = env.num_actions\n",
        "    self.num_states = env.num_states\n",
        "    self.discount_factor = env.discount_factor # taken from env like the other policies\n",
        "    self.episodes = env.episodes\n",
        "    self.env = env\n",
        "    self.reset()\n",
        "\n",
        "  def reset(self, ):\n",
        "    '''Clear all per-(state, action) statistics and create a fresh evaluator.'''\n",
        "    self.q = np.zeros((self.num_states, self.num_actions)) # empirical mean reward\n",
        "    self.sum_reward = np.zeros((self.num_states, self.num_actions))\n",
        "    self.visit = np.zeros((self.num_states, self.num_actions))\n",
        "    self.pi_e = self.env.pi_e # read from the stored env, not a notebook-level global\n",
        "    self.value_iter = value_iteration()\n",
        "\n",
        "  def onpolicy(self, s, action_list, pi_e):\n",
        "    '''Sample an action for state s from the distribution pi_e[s].'''\n",
        "    return np.random.choice(action_list, p = pi_e[s])\n",
        "\n",
        "  def run_OnPolicySampling(self, env):\n",
        "    '''Run self.episodes on-policy episodes in env.\n",
        "\n",
        "    Returns:\n",
        "      Array of shape (episodes, num_states); row k is V[0] from value iteration\n",
        "      on the mean-reward estimates available after episode k.\n",
        "    '''\n",
        "    self.reset()\n",
        "    tqdm._instances.clear()\n",
        "    self.v_pi = np.zeros((self.episodes,self.num_states))\n",
        "\n",
        "    action_list = [i for i in range(self.num_actions)]\n",
        "\n",
        "    for eps in tqdm(range(0,self.episodes), total=self.episodes):\n",
        "      env.reset()\n",
        "      for t in range(env.horizon-1): # the last step is taken in a leaf, which self-loops\n",
        "        curr_state = env.curr_state\n",
        "        action = self.onpolicy(curr_state, action_list, self.pi_e)\n",
        "        next_state, reward = env.step(action)\n",
        "\n",
        "        self.visit[curr_state][action] += 1\n",
        "        self.sum_reward[curr_state][action] += reward\n",
        "        self.q[curr_state][action] = self.sum_reward[curr_state][action]/self.visit[curr_state][action]\n",
        "\n",
        "      self.v_pi[eps] = self.value_iter.run_value_iteration(self.q)[0]\n",
        "\n",
        "    return self.v_pi"
      ],
      "metadata": {
        "id": "-UDZTq1rnIQs"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Confidence Bound Method with Variance\n",
        "\n",
        "Select the next action as $\\arg\\max_{a\\in[A]} CB_a$, where $CB_a = R\\pi^2_e\\left(\\sqrt{\\dfrac{2 \\widehat{\\sigma}_a^2 \\log t}{n_a}} + \\dfrac{7 \\log t}{3 n_a}\\right)$ and $R$ is an upper bound on the maximum possible reward."
      ],
      "metadata": {
        "id": "7qoQIiKPOez7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import cvxpy as cp\n",
        "\n",
        "class CBVar_policy(object):\n",
        "  '''Behavior policy that takes, in each state, the action with the largest variance-aware confidence bound.'''\n",
        "\n",
        "  def __init__(self, env):\n",
        "    self.num_actions = env.num_actions\n",
        "    self.num_states = env.num_states\n",
        "    self.discount_factor = env.discount_factor\n",
        "    self.ucb = 100*np.ones(self.num_actions) # replaced by a (num_states, num_actions) table in reset()\n",
        "    self.episodes = env.episodes\n",
        "    self.env = env\n",
        "    self.reset()\n",
        "\n",
        "  def reset(self, ):\n",
        "    '''Clear all per-(state, action) statistics and create a fresh evaluator.'''\n",
        "    self.q = np.zeros((self.num_states, self.num_actions)) # running mean reward\n",
        "    self.var = np.zeros((self.num_states, self.num_actions)) # empirical variance (M2 / n)\n",
        "    self.sum_reward = np.zeros((self.num_states, self.num_actions))\n",
        "    self.sum_reward_sq = np.zeros((self.num_states, self.num_actions)) # running sum of squared deviations (Welford M2)\n",
        "    self.visit = np.zeros((self.num_states, self.num_actions))\n",
        "    self.pi_e = self.env.pi_e # read from the stored env, not a notebook-level global\n",
        "    self.value_iter = value_iteration()\n",
        "\n",
        "    self.ucb = np.zeros((self.num_states, self.num_actions))\n",
        "\n",
        "  def cbvar_policy(self, s, action_list, t):\n",
        "    '''Choose the action to take in state s.\n",
        "\n",
        "    While some action in s has been taken at most once, return the least-visited\n",
        "    action (forced exploration); otherwise return argmax_a ucb[s][a].\n",
        "    action_list and t are accepted for interface compatibility and unused.\n",
        "    '''\n",
        "    if np.any(self.visit[s] <= 1):\n",
        "      return np.argmin(self.visit[s])\n",
        "    return np.argmax(self.ucb[s]) # deterministically the largest confidence bound\n",
        "\n",
        "  def run_behavior_policy(self, env, action_list, eps):\n",
        "    '''Roll out one episode in env with cbvar_policy, updating statistics after every step.\n",
        "\n",
        "    Returns:\n",
        "      traj, an array of shape (horizon, 4) whose row t is\n",
        "      [state, action, reward, next_state]; the last row stays zero because\n",
        "      only horizon - 1 steps are taken.\n",
        "    '''\n",
        "    traj = np.zeros((env.horizon, 4))\n",
        "    # The confidence level is fixed for the whole run, so compute it once.\n",
        "    log_term = np.log(self.num_actions*self.num_states*self.env.L*self.episodes + 1.0)\n",
        "    scale = self.env.ucb_constant**2\n",
        "\n",
        "    env.reset()\n",
        "    for t in range(env.horizon-1): # the last step is taken in a leaf, which self-loops\n",
        "      curr_state = env.curr_state\n",
        "      action = self.cbvar_policy(curr_state, action_list, eps)\n",
        "      next_state, reward = env.step(action)\n",
        "      traj[t] = [curr_state, action, reward, next_state]\n",
        "\n",
        "      # Welford update: M2 += (x - old_mean) * (x - new_mean) keeps M2 equal to the\n",
        "      # exact sum of squared deviations from the current mean.\n",
        "      self.visit[curr_state][action] += 1\n",
        "      self.sum_reward[curr_state][action] += reward\n",
        "      old_mean = self.q[curr_state][action]\n",
        "      self.q[curr_state][action] = self.sum_reward[curr_state][action]/self.visit[curr_state][action]\n",
        "      self.sum_reward_sq[curr_state][action] += (reward - old_mean)*(reward - self.q[curr_state][action])\n",
        "\n",
        "      # Refresh variance and confidence bound of every visited action in this state.\n",
        "      # Unvisited actions are skipped (0/0); forced exploration picks them anyway.\n",
        "      for a in range(self.num_actions):\n",
        "        n = self.visit[curr_state][a]\n",
        "        if n == 0:\n",
        "          continue\n",
        "        self.var[curr_state][a] = self.sum_reward_sq[curr_state][a]/n\n",
        "        ucb1 = scale*np.sqrt(2.0*self.var[curr_state][a]*log_term/n)\n",
        "        ucb2 = scale*(7.0/3.0)*log_term/n\n",
        "        self.ucb[curr_state][a] = (self.pi_e[curr_state][a]**2) * ( ucb1 + ucb2 )\n",
        "\n",
        "    return traj\n",
        "\n",
        "  def run_cbvar_policy(self, env):\n",
        "    '''Run self.episodes episodes of the CB policy and record plug-in value estimates.\n",
        "\n",
        "    Returns:\n",
        "      Array of shape (episodes, num_states); row k is V[0] from value iteration\n",
        "      on the mean-reward estimates after episode k. Trajectories go to self.traj.\n",
        "    '''\n",
        "    self.reset()\n",
        "    tqdm._instances.clear()\n",
        "    self.v_pi = np.zeros((self.episodes,self.num_states))\n",
        "\n",
        "    action_list = [i for i in range(self.num_actions)]\n",
        "    self.traj = np.zeros((self.episodes, env.horizon, 4))\n",
        "\n",
        "    for eps in tqdm(range(0,self.episodes), total=self.episodes):\n",
        "      self.traj[eps] = self.run_behavior_policy(env, action_list, eps)\n",
        "      self.v_pi[eps] = self.value_iter.run_value_iteration(self.q)[0]\n",
        "\n",
        "    return self.v_pi\n"
      ],
      "metadata": {
        "id": "t4Oc1WclOhx1"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Plugin Policy (UCB)"
      ],
      "metadata": {
        "id": "b-IqtMviJ3_B"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import cvxpy as cp\n",
        "from multiprocessing import Process, Queue\n",
        "import time\n",
        "import pickle\n",
        "\n",
        "class ucb_exploration_policy(object):\n",
        "  '''Plug-in behavior policy: sample according to b(s, a) built from optimistic variance estimates.'''\n",
        "\n",
        "  def __init__(self, env, ucb_constant = 2):\n",
        "    self.num_actions = env.num_actions\n",
        "    self.num_states = env.num_states\n",
        "    self.discount_factor = env.discount_factor\n",
        "    self.ucb = 100*np.ones(self.num_actions) # not read by this policy; kept for attribute compatibility\n",
        "    self.episodes = env.episodes\n",
        "    self.env = env\n",
        "    self.ucb_constant = ucb_constant\n",
        "    self.reset()\n",
        "\n",
        "  def reset(self, ):\n",
        "    '''Clear all per-(state, action) statistics and create a fresh evaluator.'''\n",
        "    self.q = np.zeros((self.num_states, self.num_actions)) # empirical mean reward\n",
        "    self.var = np.zeros((self.num_states, self.num_actions)) # smoothed variance + exploration bonus\n",
        "    self.sum_reward = np.zeros((self.num_states, self.num_actions))\n",
        "    self.sum_reward_sq = np.zeros((self.num_states, self.num_actions)) # sum of squared deviations (Welford M2)\n",
        "    self.visit = np.zeros((self.num_states, self.num_actions))\n",
        "    self.pi_e = self.env.pi_e # read from the stored env, not a notebook-level global\n",
        "    self.value_iter = value_iteration()\n",
        "\n",
        "  def calculate_bsa(self, ):\n",
        "    '''Return b[0], the (num_states, num_actions) sampling weights built from self.var.\n",
        "\n",
        "    Leaves get b = sqrt(pi_e^2 * var^2); internal states also add the squared\n",
        "    subtree totals B of both children. Every t-layer repeats the same\n",
        "    recursion, so only layer 0 is returned.\n",
        "    '''\n",
        "    env = self.env\n",
        "    b = np.zeros((env.L, env.num_states, env.num_actions))\n",
        "    B = np.zeros((env.L, env.num_states))\n",
        "\n",
        "    for t in range(env.L-1, -1, -1):\n",
        "      for s in env.leaf_set:\n",
        "        for a in range(env.num_actions):\n",
        "          b[t][s][a] = np.sqrt(env.pi_e[s][a]**2 * self.var[s][a]**2)\n",
        "        B[t][s] = sum(b[t][s])\n",
        "\n",
        "      for s in range(env.num_states-1, -1, -1): # children (higher indices) are filled first\n",
        "        if s in env.leaf_set:\n",
        "          continue\n",
        "        # NOTE(review): both children enter every action's weight, although the\n",
        "        # environment sends action a only to child 2*s + a + 1 -- confirm intended.\n",
        "        B_ = 0.0\n",
        "        for s1 in (2*s + 1, 2*s + 2):\n",
        "          B_ += B[t][s1]**2\n",
        "        for a in range(env.num_actions):\n",
        "          b[t][s][a] = np.sqrt(env.pi_e[s][a]**2 * (self.var[s][a]**2 + env.discount_factor**2 * B_ ) )\n",
        "        B[t][s] = np.sum(b[t][s])\n",
        "\n",
        "    return b[0]\n",
        "\n",
        "  def opt_plugin_ucb_policy(self, s, action_list, t):\n",
        "    '''Choose the action to take in state s.\n",
        "\n",
        "    Until every action in s has been taken at least four times, return the\n",
        "    least-visited action (forced exploration). Afterwards recompute self.b and\n",
        "    return argmax_a b[s][a] / visit[s][a]. action_list and t are unused.\n",
        "    '''\n",
        "    if np.any(self.visit[s] <= 3):\n",
        "      return np.argmin(self.visit[s])\n",
        "    self.b = self.calculate_bsa()\n",
        "    return np.argmax(self.b[s]/self.visit[s]) # most under-sampled relative to its weight\n",
        "\n",
        "  def run_behavior_policy(self, env, action_list, eps):\n",
        "    '''Roll out one episode in env with opt_plugin_ucb_policy, updating statistics after every step.\n",
        "\n",
        "    Returns:\n",
        "      traj, an array of shape (horizon, 4) whose row t is\n",
        "      [state, action, reward, next_state]; the last row stays zero because\n",
        "      only horizon - 1 steps are taken.\n",
        "    '''\n",
        "    traj = np.zeros((env.horizon, 4))\n",
        "    log_term = np.log(self.num_actions*self.num_states*self.env.L*self.episodes + 1.0) # fixed for the whole run\n",
        "\n",
        "    env.reset()\n",
        "    for t in range(env.horizon-1): # the last step is taken in a leaf, which self-loops\n",
        "      curr_state = env.curr_state\n",
        "      action = self.opt_plugin_ucb_policy(curr_state, action_list, eps)\n",
        "      next_state, reward = env.step(action)\n",
        "      traj[t] = [curr_state, action, reward, next_state]\n",
        "\n",
        "      self.visit[curr_state][action] += 1\n",
        "      self.sum_reward[curr_state][action] += reward\n",
        "      old_mean = self.q[curr_state][action]\n",
        "      # Empirical mean over the visit count (>= 1 here). Dividing by visit + 1\n",
        "      # would shrink every estimate towards 0 and bias the plug-in value.\n",
        "      self.q[curr_state][action] = self.sum_reward[curr_state][action]/self.visit[curr_state][action]\n",
        "      # Welford update of the sum of squared deviations from the current mean.\n",
        "      self.sum_reward_sq[curr_state][action] += (reward - old_mean)*(reward - self.q[curr_state][action])\n",
        "\n",
        "      # Optimistic variance of every action in this state: smoothed empirical\n",
        "      # variance plus an exploration bonus; the +1 keeps unvisited actions finite.\n",
        "      for a in range(self.num_actions):\n",
        "        n_smoothed = self.visit[curr_state][a] + 1\n",
        "        ucb = (self.ucb_constant**2) * np.sqrt(log_term/n_smoothed)\n",
        "        self.var[curr_state][a] = self.sum_reward_sq[curr_state][a]/n_smoothed + ucb\n",
        "\n",
        "    return traj\n",
        "\n",
        "  def _result_path(self, tr, out_dir = '/content'):\n",
        "    '''Path of the pickle holding trial tr's v_pi for this ucb_constant.'''\n",
        "    return out_dir + '/f_plugin_' + str(self.ucb_constant) + '_' + str(tr)\n",
        "\n",
        "  def run_plugin_ucb_policy(self, env, tr, seed = None, out_dir = '/content'):\n",
        "    '''Run one trial of self.episodes episodes and pickle v_pi to _result_path(tr, out_dir).\n",
        "\n",
        "    Intended as a multiprocessing target. The global NumPy RNG is reseeded\n",
        "    first (from OS entropy when seed is None): forked children otherwise\n",
        "    inherit the parent's RNG state and every trial replays the same noise.\n",
        "    '''\n",
        "    np.random.seed(seed)\n",
        "    self.reset()\n",
        "    tqdm._instances.clear()\n",
        "    self.v_pi = np.zeros((self.episodes,self.num_states))\n",
        "\n",
        "    action_list = [i for i in range(self.num_actions)]\n",
        "    self.traj = np.zeros((self.episodes, env.horizon, 4))\n",
        "\n",
        "    for eps in tqdm(range(0,self.episodes), total=self.episodes):\n",
        "      self.traj[eps] = self.run_behavior_policy(env, action_list, eps)\n",
        "      self.v_pi[eps] = self.value_iter.run_value_iteration(self.q)[0]\n",
        "\n",
        "    with open(self._result_path(tr, out_dir), 'wb') as file:\n",
        "      pickle.dump(self.v_pi, file)\n",
        "\n",
        "  def run_plugin_ucb(self, env, seed = None, out_dir = '/content'):\n",
        "    '''Run env.num_trials trials in parallel processes and collect their v_pi arrays.\n",
        "\n",
        "    Args:\n",
        "      env: environment handed to every trial.\n",
        "      seed: optional base seed; trial tr uses seed + tr. With None every trial\n",
        "        reseeds from OS entropy, so trials differ but are not reproducible.\n",
        "      out_dir: directory for the per-trial pickle files.\n",
        "\n",
        "    Returns:\n",
        "      Array of shape (num_trials, episodes, num_states).\n",
        "\n",
        "    Raises:\n",
        "      RuntimeError: if a trial process exits with a non-zero code; its result\n",
        "        file would otherwise be missing or left over from an earlier run.\n",
        "    '''\n",
        "    v_pi_ucb_exploration_policy = np.zeros((env.num_trials, env.episodes, env.num_states))\n",
        "    pool = []\n",
        "    for tr in range(env.num_trials):\n",
        "      trial_seed = None if seed is None else seed + tr\n",
        "      pool.append(Process(target = self.run_plugin_ucb_policy, args=(env, tr, trial_seed, out_dir)))\n",
        "\n",
        "    for p in pool:\n",
        "      p.start()\n",
        "    for p in pool:\n",
        "      p.join()\n",
        "\n",
        "    failed = [tr for tr, p in enumerate(pool) if p.exitcode != 0]\n",
        "    for p in pool:\n",
        "      p.close()\n",
        "    if failed:\n",
        "      raise RuntimeError('plugin UCB trials failed: ' + str(failed))\n",
        "\n",
        "    for tr in range(env.num_trials):\n",
        "      # pickle.load can execute arbitrary code: only read files written by this notebook.\n",
        "      with open(self._result_path(tr, out_dir), 'rb') as file:\n",
        "        v_pi_ucb_exploration_policy[tr] = pickle.load(file)\n",
        "\n",
        "    return v_pi_ucb_exploration_policy"
      ],
      "metadata": {
        "id": "94rX7iW9J7mw"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Oracle Policy OptSol\n",
        "\n",
        "Sample arm $i$ with the proportion given by the solution of the optimization problem, $b_i = \\dfrac{\\pi^2_i \\sigma^2_i}{\\sum_{j=1}^K \\pi^2_j \\sigma^2_j}$, where $\\sigma^2_i$ is the true variance of arm $i$."
      ],
      "metadata": {
        "id": "qKWu0HX893zf"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import cvxpy as cp\n",
        "from multiprocessing import Process, Queue\n",
        "import time\n",
        "import pickle\n",
        "\n",
        "class oracle_policy(object):\n",
        "\n",
        "  def __init__(self, env):\n",
        "    '''Copy sizes and run settings from env, keep a reference to it, and reset all statistics.'''\n",
        "    self.env = env\n",
        "    self.num_states = env.num_states\n",
        "    self.num_actions = env.num_actions\n",
        "    self.discount_factor = env.discount_factor\n",
        "    self.episodes = env.episodes\n",
        "    self.ucb = 100*np.ones(self.num_actions)\n",
        "    self.reset()\n",
        "\n",
        "  def reset(self, ):\n",
        "    '''Clear all per-(state, action) statistics and create a fresh evaluator.'''\n",
        "    self.q = np.zeros((self.num_states, self.num_actions))\n",
        "    self.var = np.zeros((self.num_states, self.num_actions))\n",
        "    self.sum_reward = np.zeros((self.num_states, self.num_actions))\n",
        "    self.sum_reward_sq = np.zeros((self.num_states, self.num_actions))\n",
        "    self.visit = np.zeros((self.num_states, self.num_actions))\n",
        "    self.pi_e = self.env.pi_e # read from the stored env, not a notebook-level global\n",
        "    self.value_iter = value_iteration()\n",
        "    \n",
        "  \n",
        "\n",
        "  def calculate_bsa(self, ):\n",
        "\n",
        "    b = np.zeros((self.env.L, self.env.num_states, self.env.num_actions))\n",
        "    B = np.zeros((self.env.L, self.env.num_states))\n",
        "\n",
        "    for t in range(self.env.L-1, -1, -1):\n",
        "      for s in self.env.leaf_set:\n",
        "        sum_ = 0.0\n",
        "        for a in range(self.env.num_actions):\n",
        "          b[t][s][a] = np.sqrt(self.env.pi_e[s][a]**2 * self.env.var[s][a]**2)\n",
        "          sum_ += b[t][s][a]\n",
        "        B[t][s] = sum_ \n",
        "        \n",
        "\n",
        "      for s in range(self.env.num_states-1,-1,-1):\n",
        "        # print(\"s: \",s)\n",
        "        if s not in self.env.leaf_set:\n",
        "          for a in range(self.env.num_actions):\n",
        "            B_ = 0.0\n",
        "            child_set = [2*s + 1, 2*s + 2]\n",
        "            # print(child_set)\n",
        "            for s1 in child_set:\n",
        "              B_ += B[t][s1]**2\n",
        "            # print(\"B_: \", B_)\n",
        "            b[t][s][a] = np.sqrt(self.env.pi_e[s][a]**2 * (self.env.var[s][a]**2 + self.env.discount_factor**2 * B_ ) )\n",
        "          B[t][s] = np.sum(b[t][s])\n",
        "          # print(b[t][s])\n",
        "\n",
        "    # print(b[0])\n",
        "    # print(B[0])\n",
        "    return b[0]\n",
        "\n",
        "  def oracle_policy(self, s, action_list, t):\n",
        "    \n",
        "    \n",
        "    level = self.env.get_level(s)\n",
        "    explore_arms = []\n",
        "    for i in range(self.num_actions):\n",
        "      # if self.visit[s][i] <= ((self.num_actions)**(self.env.L - level + 1))*np.sqrt(t + 10):\n",
        "      if self.visit[s][i] <= 3:\n",
        "        explore_arms.append(i)\n",
        "\n",
        "    # print(s,explore_arms)\n",
        "    if len(explore_arms) == 0: # No more forced exploration\n",
        "\n",
        "      self.b = self.calculate_bsa()\n",
        "      # prob = self.b[s]/np.sum(self.b[s])\n",
        "      # return np.random.choice(action_list, p = prob)\n",
        "\n",
        "      return np.argmax(self.b[s]/(self.visit[s]))\n",
        "\n",
        "    else:\n",
        "      return np.argmin(self.visit[s]) # return the arm sampled least in state s when unexplored set is non-empty\n",
        "\n",
        "    \n",
        "\n",
        "    \n",
        "  \n",
        "  \n",
        "  def run_behavior_policy(self, env, action_list, eps):\n",
        "    \n",
        "    traj = np.zeros((env.horizon, 4))\n",
        "    \n",
        "    # for eps in tqdm(range(0,self.episodes), total=self.episodes):\n",
        "    env.reset()\n",
        "    for t in range(env.horizon-1): # At the leaf state, take an action and transition to the same state, No need to reset the environment as we are not running the policy\n",
        "    # for t in range(env.horizon):\n",
        "      curr_state = env.curr_state\n",
        "      action = self.oracle_policy(curr_state,action_list, eps)\n",
        "      next_state, reward = env.step(action)\n",
        "      traj[t] = [curr_state, action, reward, next_state]\n",
        "        \n",
        "    return traj\n",
        "\n",
        "\n",
        "  def run_oracle_policy(self, env, tr):\n",
        "    \n",
        "    \n",
        "    self.reset()\n",
        "    tqdm._instances.clear()\n",
        "    self.v_pi = np.zeros((self.episodes,self.num_states))\n",
        "\n",
        "    action_list = [i for i in range(self.num_actions)]\n",
        "    # explore_arms = np.ones((self.num_states, self.num_actions))\n",
        "\n",
        "    self.traj = np.zeros((self.episodes, env.horizon, 4))\n",
        "\n",
        "    imps_ratio_sum = 0\n",
        "    for eps in tqdm(range(0,self.episodes), total=self.episodes):\n",
        "\n",
        "      self.traj[eps] = self.run_behavior_policy(env,action_list, eps)\n",
        "      for t in range(env.horizon-1): # At the leaf state, take an action and transition to the same state, No need to reset the environment as we are not running the policy\n",
        "      \n",
        "        action = int(self.traj[eps][t][1])\n",
        "        curr_state = int(self.traj[eps][t][0])\n",
        "        next_state, reward = self.traj[eps][t][3], self.traj[eps][t][2]\n",
        "        \n",
        "        \n",
        "\n",
        "        ## Update parameters\n",
        "        self.visit[curr_state][action] += 1\n",
        "        self.sum_reward[curr_state][action] += reward\n",
        "        self.q[curr_state][action] =  self.sum_reward[curr_state][action]/self.visit[curr_state][action] \n",
        "\n",
        "        \n",
        "        self.sum_reward_sq[curr_state][action] += (reward - self.q[curr_state][action])**2 \n",
        "        self.var[curr_state][action] = self.sum_reward_sq[curr_state][action]/self.visit[curr_state][action]\n",
        "        \n",
        "      \n",
        "      \n",
        "      # self.v_pi[eps] = np.sum(np.multiply(self.q,env.pi_e), axis = 1)\n",
        "      self.v_pi[eps] = self.value_iter.run_value_iteration(self.q)[0]\n",
        "    \n",
        "    file = open(\"/content/f_oracle_\"+str(tr), 'wb')\n",
        "    data = pickle.dump(self.v_pi,file)\n",
        "    file.close()\n",
        "    # return self.v_pi\n",
        "  \n",
        "  def run_oracle(self, env):\n",
        "\n",
        "      \n",
        "    # self.error = np.zeros((self.num_trials, self.T))  \n",
        "    v_pi_oraclepolicy = np.zeros((env.num_trials, env.episodes, env.num_states))\n",
        "    pool = []\n",
        "    # Q = Queue()\n",
        "\n",
        "    for tr in range(env.num_trials):\n",
        "          \n",
        "      p = Process(target = self.run_oracle_policy, args=(env, tr)) # takes in tuple\n",
        "      pool.append(p)\n",
        "          \n",
        "    for tr in range(env.num_trials):\n",
        "      pool[tr].start()\n",
        "\n",
        "    \n",
        "    for tr in range(env.num_trials):\n",
        "      # self.error[tr] = Q.get()\n",
        "      pool[tr].join()\n",
        "      \n",
        "    for tr in range(env.num_trials):\n",
        "      file = open(\"/content/f_oracle_\"+str(tr), 'rb')\n",
        "      v_pi_oraclepolicy[tr] = pickle.load(file)\n",
        "      file.close()\n",
        "\n",
        "    for tr in range(env.num_trials):\n",
        "      pool[tr].close()\n",
        "\n",
        "    return v_pi_oraclepolicy"
      ],
      "metadata": {
        "id": "3CloNz3E93MA"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pBVljOFVWptj"
      },
      "source": [
        "# Run Main"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "v4odxFUtWoZi",
        "outputId": "a54963d6-0ba7-4f57-e5a2-89fb8f060346"
      },
      "source": [
        "if __name__ == '__main__':\n",
        "\n",
        "  env = Environment()\n",
        "\n",
        "  # UCB exploration constants swept in this ablation; c[i] drives\n",
        "  # agent_ucb_exploration_policy{i+1} / v_pi_ucb_exploration_policy{i+1}.\n",
        "  c = [0, 0.1, 0.01, 1, 10, 100]\n",
        "\n",
        "  # Build every agent before running any of them.\n",
        "  ucb_agents = [ucb_exploration_policy(env, c_i) for c_i in c]\n",
        "  (agent_ucb_exploration_policy1, agent_ucb_exploration_policy2,\n",
        "   agent_ucb_exploration_policy3, agent_ucb_exploration_policy4,\n",
        "   agent_ucb_exploration_policy5, agent_ucb_exploration_policy6) = ucb_agents\n",
        "  agent_oraclepolicy = oracle_policy(env)\n",
        "\n",
        "  # run_plugin_ucb / run_oracle return their own result arrays, so no\n",
        "  # np.zeros pre-allocation is needed.\n",
        "  ucb_results = []\n",
        "  for agent in ucb_agents:\n",
        "    print('Run Plugin UCB policy')\n",
        "    ucb_results.append(agent.run_plugin_ucb(env))\n",
        "  (v_pi_ucb_exploration_policy1, v_pi_ucb_exploration_policy2,\n",
        "   v_pi_ucb_exploration_policy3, v_pi_ucb_exploration_policy4,\n",
        "   v_pi_ucb_exploration_policy5, v_pi_ucb_exploration_policy6) = ucb_results\n",
        "\n",
        "  print('Run Oracle policy')\n",
        "  v_pi_oraclepolicy = agent_oraclepolicy.run_oracle(env)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Run Plugin UCB policy\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 99%|█████████▉| 4954/5000 [11:32<00:05,  7.93it/s]\n",
            "100%|█████████▉| 4978/5000 [11:32<00:03,  6.40it/s]\n",
            "100%|█████████▉| 4988/5000 [11:35<00:01,  7.90it/s]\n",
            "100%|█████████▉| 4982/5000 [11:34<00:02,  8.25it/s]\n",
            "100%|█████████▉| 4975/5000 [11:35<00:03,  7.60it/s]\n",
            "100%|█████████▉| 4991/5000 [11:35<00:01,  8.59it/s]\n",
            "100%|█████████▉| 4999/5000 [11:35<00:00, 10.65it/s]\n",
            "100%|█████████▉| 4985/5000 [11:34<00:01,  8.02it/s]\n",
            "100%|█████████▉| 4995/5000 [11:35<00:00, 10.75it/s]\n",
            "100%|█████████▉| 4987/5000 [11:36<00:01, 11.76it/s]\n",
            "100%|█████████▉| 4993/5000 [11:35<00:00, 13.65it/s]\n",
            " 99%|█████████▉| 4971/5000 [11:35<00:01, 14.58it/s]\n",
            "100%|█████████▉| 4995/5000 [11:36<00:00, 16.22it/s]\n",
            "100%|██████████| 5000/5000 [11:36<00:00,  7.18it/s]\n",
            "100%|█████████▉| 4991/5000 [11:36<00:00, 22.33it/s]\n",
            "100%|█████████▉| 4994/5000 [11:36<00:00, 23.81it/s]\n",
            "\n",
            "100%|█████████▉| 4986/5000 [11:35<00:00, 25.47it/s]\n",
            "100%|██████████| 5000/5000 [11:36<00:00,  7.17it/s]\n",
            "100%|██████████| 5000/5000 [11:36<00:00,  7.18it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Run Plugin UCB policy\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 99%|█████████▉| 4959/5000 [11:19<00:04,  8.44it/s]\n",
            "100%|█████████▉| 4978/5000 [11:21<00:03,  7.21it/s]\n",
            " 99%|█████████▉| 4971/5000 [11:22<00:03,  8.00it/s]\n",
            "100%|█████████▉| 4987/5000 [11:23<00:01,  9.05it/s]\n",
            "100%|█████████▉| 4979/5000 [11:23<00:02,  9.84it/s]\n",
            "100%|█████████▉| 4983/5000 [11:24<00:01, 11.04it/s]\n",
            "100%|█████████▉| 4990/5000 [11:23<00:00, 10.68it/s]\n",
            "100%|█████████▉| 4986/5000 [11:24<00:01, 10.77it/s]\n",
            " 99%|█████████▉| 4967/5000 [11:24<00:02, 12.28it/s]\n",
            "100%|█████████▉| 4997/5000 [11:24<00:00, 11.96it/s]\n",
            "100%|█████████▉| 4998/5000 [11:25<00:00, 13.42it/s]\n",
            "100%|█████████▉| 4992/5000 [11:25<00:00, 12.42it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.29it/s]\n",
            "100%|██████████| 5000/5000 [11:24<00:00,  7.30it/s]\n",
            "100%|█████████▉| 4991/5000 [11:25<00:00, 18.80it/s]\n",
            "100%|█████████▉| 4997/5000 [11:25<00:00, 21.53it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.29it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.29it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Run Plugin UCB policy\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 99%|█████████▉| 4958/5000 [11:20<00:05,  7.27it/s]\n",
            " 99%|█████████▉| 4964/5000 [11:22<00:04,  7.41it/s]\n",
            " 99%|█████████▉| 4964/5000 [11:22<00:04,  8.08it/s]\n",
            "100%|█████████▉| 4994/5000 [11:22<00:00,  7.69it/s]\n",
            "100%|██████████| 5000/5000 [11:23<00:00,  7.32it/s]\n",
            "100%|█████████▉| 4980/5000 [11:23<00:01, 10.39it/s]\n",
            " 99%|█████████▊| 4929/5000 [11:23<00:06, 10.16it/s]\n",
            " 99%|█████████▉| 4948/5000 [11:24<00:04, 10.52it/s]\n",
            " 99%|█████████▉| 4950/5000 [11:24<00:04, 11.54it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n",
            " 99%|█████████▉| 4947/5000 [11:24<00:03, 15.62it/s]\n",
            "100%|█████████▉| 4998/5000 [11:25<00:00, 15.68it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.29it/s]\n",
            "\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.29it/s]\n",
            "100%|██████████| 5000/5000 [11:26<00:00,  7.29it/s]\n",
            " 99%|█████████▉| 4971/5000 [11:25<00:00, 31.90it/s]\n",
            "100%|██████████| 5000/5000 [11:26<00:00,  7.29it/s]\n",
            "100%|██████████| 5000/5000 [11:26<00:00,  7.29it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Run Plugin UCB policy\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 99%|█████████▊| 4927/5000 [11:22<00:08,  8.21it/s]\n",
            " 99%|█████████▉| 4964/5000 [11:25<00:04,  8.75it/s]\n",
            "100%|█████████▉| 4990/5000 [11:24<00:01,  8.59it/s]\n",
            "100%|█████████▉| 4999/5000 [11:26<00:00,  9.03it/s]\n",
            "100%|█████████▉| 4998/5000 [11:26<00:00,  9.72it/s]\n",
            "100%|██████████| 5000/5000 [11:26<00:00,  7.29it/s]\n",
            "\n",
            "100%|█████████▉| 4982/5000 [11:27<00:01,  9.65it/s]\n",
            " 99%|█████████▉| 4971/5000 [11:28<00:02, 12.69it/s]\n",
            "100%|█████████▉| 4994/5000 [11:27<00:00, 13.32it/s]\n",
            " 99%|█████████▉| 4967/5000 [11:28<00:02, 16.19it/s]\n",
            " 99%|█████████▉| 4972/5000 [11:28<00:01, 16.99it/s]\n",
            "100%|██████████| 5000/5000 [11:28<00:00,  7.26it/s]\n",
            "100%|█████████▉| 4987/5000 [11:29<00:00, 18.31it/s]\n",
            "100%|██████████| 5000/5000 [11:29<00:00,  7.25it/s]\n",
            "100%|█████████▉| 4983/5000 [11:29<00:00, 23.52it/s]\n",
            "100%|██████████| 5000/5000 [11:29<00:00,  7.25it/s]\n",
            "100%|██████████| 5000/5000 [11:29<00:00,  7.25it/s]\n",
            "100%|██████████| 5000/5000 [11:29<00:00,  7.25it/s]\n",
            "100%|██████████| 5000/5000 [11:29<00:00,  7.25it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Run Plugin UCB policy\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 99%|█████████▉| 4954/5000 [11:20<00:06,  6.63it/s]\n",
            " 99%|█████████▉| 4965/5000 [11:21<00:04,  7.72it/s]\n",
            " 99%|█████████▉| 4952/5000 [11:22<00:06,  7.62it/s]\n",
            "100%|█████████▉| 4977/5000 [11:22<00:03,  7.37it/s]\n",
            "100%|█████████▉| 4993/5000 [11:22<00:00, 10.07it/s]\n",
            "100%|█████████▉| 4998/5000 [11:24<00:00, 12.48it/s]\n",
            " 99%|█████████▉| 4964/5000 [11:24<00:03, 11.36it/s]\n",
            "100%|█████████▉| 4995/5000 [11:24<00:00, 11.79it/s]\n",
            "100%|█████████▉| 4978/5000 [11:24<00:01, 11.85it/s]\n",
            "100%|██████████| 5000/5000 [11:24<00:00,  7.30it/s]\n",
            " 99%|█████████▉| 4970/5000 [11:24<00:02, 14.33it/s]\n",
            "100%|█████████▉| 4975/5000 [11:24<00:01, 16.95it/s]\n",
            "100%|██████████| 5000/5000 [11:23<00:00,  7.31it/s]\n",
            "100%|█████████▉| 4999/5000 [11:24<00:00, 16.17it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Run Plugin UCB policy\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|█████████▉| 4977/5000 [11:24<00:01, 16.98it/s]\n",
            "100%|██████████| 5000/5000 [11:24<00:00,  7.30it/s]\n",
            "100%|█████████▉| 4980/5000 [11:25<00:01, 19.24it/s]\n",
            "100%|██████████| 5000/5000 [11:24<00:00,  7.30it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n",
            " 99%|█████████▉| 4959/5000 [11:21<00:05,  7.10it/s]\n",
            " 99%|█████████▉| 4941/5000 [11:22<00:07,  7.68it/s]\n",
            " 99%|█████████▉| 4949/5000 [11:23<00:05,  9.39it/s]\n",
            "100%|█████████▉| 4988/5000 [11:23<00:01,  8.59it/s]\n",
            " 99%|█████████▉| 4969/5000 [11:23<00:03,  9.47it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n",
            "100%|█████████▉| 4992/5000 [11:25<00:00, 10.78it/s]\n",
            " 99%|█████████▉| 4965/5000 [11:25<00:02, 11.69it/s]\n",
            " 99%|█████████▉| 4967/5000 [11:25<00:02, 11.98it/s]\n",
            "100%|█████████▉| 4976/5000 [11:26<00:01, 12.81it/s]\n",
            "100%|█████████▉| 4998/5000 [11:26<00:00, 13.61it/s]\n",
            "100%|█████████▉| 4980/5000 [11:26<00:01, 15.64it/s]\n",
            "100%|██████████| 5000/5000 [11:25<00:00,  7.30it/s]\n",
            "\n",
            "100%|██████████| 5000/5000 [11:26<00:00,  7.29it/s]\n",
            "100%|█████████▉| 4987/5000 [11:26<00:00, 18.72it/s]\n",
            "100%|█████████▉| 4983/5000 [11:26<00:00, 18.43it/s]\n",
            "100%|█████████▉| 4994/5000 [11:26<00:00, 33.46it/s]\n",
            "100%|██████████| 5000/5000 [11:26<00:00,  7.28it/s]\n",
            "100%|██████████| 5000/5000 [11:27<00:00,  7.28it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Run Oracle policy\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            " 20%|█▉        | 983/5000 [02:07<09:16,  7.22it/s]"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pUmXNOLfs0kX"
      },
      "source": [
        "# Plot error"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6dAH8Aucl-11"
      },
      "source": [
        "from numpy.core.fromnumeric import shape\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "plt.style.use('ggplot')\n",
        "plt.figure(dpi=200)\n",
        "\n",
        "vi_true = value_iteration()\n",
        "v_true = vi_true.run_value_iteration(env.R)[0]\n",
        "print(v_true)\n",
        "\n",
        "\n",
        "def _start_state_sq_error(v_pi):\n",
        "  \"\"\"Squared error of the start-state (index 0) value vs v_true, shaped (num_trials, episodes).\"\"\"\n",
        "  return np.reshape(((v_pi - v_true)**2)[:, :, 0], (env.num_trials, env.episodes))\n",
        "\n",
        "\n",
        "error_ucb_exploration_policy1 = _start_state_sq_error(v_pi_ucb_exploration_policy1)\n",
        "error_ucb_exploration_policy2 = _start_state_sq_error(v_pi_ucb_exploration_policy2)\n",
        "error_ucb_exploration_policy3 = _start_state_sq_error(v_pi_ucb_exploration_policy3)\n",
        "error_ucb_exploration_policy4 = _start_state_sq_error(v_pi_ucb_exploration_policy4)\n",
        "error_ucb_exploration_policy5 = _start_state_sq_error(v_pi_ucb_exploration_policy5)\n",
        "error_ucb_exploration_policy6 = _start_state_sq_error(v_pi_ucb_exploration_policy6)\n",
        "error_oraclepolicy = _start_state_sq_error(v_pi_oraclepolicy)\n",
        "\n",
        "scale = np.arange(0, env.episodes, 50)\n",
        "k = 0.5  # error bars span +/- k standard deviations across trials\n",
        "\n",
        "# Legend values must match `c` in the Run Main cell: c = [0, 0.1, 0.01, 1, 10, 100].\n",
        "curves = [\n",
        "  (error_oraclepolicy, '#D55E00', 'Oracle (Ours)'),\n",
        "  (error_ucb_exploration_policy1, '#FF0000', 'Plugin-UCB (Ours), c = 0'),\n",
        "  (error_ucb_exploration_policy2, '#009E73', 'Plugin-UCB (Ours), c = 0.1'),\n",
        "  (error_ucb_exploration_policy3, '#0000FF', 'Plugin-UCB (Ours), c = 0.01'),\n",
        "  (error_ucb_exploration_policy4, '#00FFFF', 'Plugin-UCB (Ours), c = 1'),\n",
        "  (error_ucb_exploration_policy5, '#F0FF0F', 'Plugin-UCB (Ours), c = 10'),\n",
        "  (error_ucb_exploration_policy6, '#0FF00F', 'Plugin-UCB (Ours), c = 100'),\n",
        "]\n",
        "\n",
        "for errors, color, label in curves:\n",
        "  # Mean/std over trials are computed once per curve, then sampled at `scale`.\n",
        "  err_mean = np.average(errors, axis=0)[scale]\n",
        "  err_std = np.std(errors, axis=0)[scale]\n",
        "  # alpha must lie in [0, 1]; Matplotlib rejects larger values.\n",
        "  plt.errorbar(scale, err_mean, k * err_std, color=color, linewidth=2, capsize=3.0, capthick=1.0, alpha=1.0, label=label, linestyle='-.', marker='o', markersize=6.0)\n",
        "\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "\n",
        "plt.xlabel('Episodes')\n",
        "plt.ylabel('MSE')\n",
        "\n",
        "plt.xlim(1, env.episodes)\n",
        "\n",
        "plt.legend()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load Drive"
      ],
      "metadata": {
        "id": "o2IvGYSqZ5uy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Mount Google Drive; the Write file / Read Data cells use paths under /content/drive/My Drive.\n",
        "from google.colab import drive\n",
        "\n",
        "drive.mount(mountpoint='/content/drive')"
      ],
      "metadata": {
        "id": "_gXVQWqBZ7bk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Write file"
      ],
      "metadata": {
        "id": "dmJg-HQXW3rz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pickle\n",
        "\n",
        "# Entries 0-3 keep the original layout (UCB runs with c = 0, 0.1, 0.01, then the\n",
        "# oracle) so existing readers still work; the c = 1, 10, 100 UCB runs, which were\n",
        "# not saved before, are appended as entries 4-6.\n",
        "ablation_results = [\n",
        "    v_pi_ucb_exploration_policy1,\n",
        "    v_pi_ucb_exploration_policy2,\n",
        "    v_pi_ucb_exploration_policy3,\n",
        "    v_pi_oraclepolicy,\n",
        "    v_pi_ucb_exploration_policy4,\n",
        "    v_pi_ucb_exploration_policy5,\n",
        "    v_pi_ucb_exploration_policy6,\n",
        "]\n",
        "\n",
        "with open('/content/drive/My Drive/Dataset/ReVar/tree_data_ablation.pickle', 'wb') as handle:\n",
        "    pickle.dump(ablation_results, handle, protocol=pickle.HIGHEST_PROTOCOL)\n"
      ],
      "metadata": {
        "id": "T_XrwT5TWkor"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Read Data"
      ],
      "metadata": {
        "id": "W1pksr6PZmvc"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pickle\n",
        "\n",
        "# Only unpickle files this notebook wrote itself: pickle.load can execute arbitrary code.\n",
        "with open('/content/drive/My Drive/Dataset/ReVar/tree_data_ablation.pickle', 'rb') as handle:\n",
        "    b = pickle.load(handle)\n",
        "\n",
        "v_pi_ucb_exploration_policy1 = b[0]\n",
        "v_pi_ucb_exploration_policy2 = b[1]\n",
        "v_pi_ucb_exploration_policy3 = b[2]\n",
        "v_pi_oraclepolicy = b[3]\n",
        "# When the file holds more than four entries, entries 4-6 are the c = 1, 10, 100 UCB runs.\n",
        "if len(b) > 4:\n",
        "    v_pi_ucb_exploration_policy4, v_pi_ucb_exploration_policy5, v_pi_ucb_exploration_policy6 = b[4:7]\n"
      ],
      "metadata": {
        "id": "r4cFw0mgXh3U"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Plot from Read File"
      ],
      "metadata": {
        "id": "QX8cERfabISI"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from numpy.core.fromnumeric import shape\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "plt.style.use('ggplot')\n",
        "fig, ax = plt.subplots(dpi=200)\n",
        "\n",
        "# Reference values of the evaluation policy: value iteration on the true rewards env.R.\n",
        "vi_true = value_iteration()\n",
        "v_true = vi_true.run_value_iteration(env.R)[0]\n",
        "print(v_true)\n",
        "\n",
        "\n",
        "def squared_error_curve(v_pi, v_ref):\n",
        "    \"\"\"Squared error of value estimates `v_pi` against `v_ref`, one row per trial.\n",
        "\n",
        "    Squares `v_pi - v_ref` element-wise, keeps index 0 of the last axis and reshapes\n",
        "    the result to (env.num_trials, env.episodes).\n",
        "    NOTE(review): assumes the last axis indexes states and index 0 is the start state\n",
        "    (env.start_state == 0) -- TODO confirm against how the v_pi_* arrays are built.\n",
        "    \"\"\"\n",
        "    squared_error = (v_pi - v_ref) ** 2\n",
        "    return np.reshape(squared_error[:, :, 0], (env.num_trials, env.episodes))\n",
        "\n",
        "\n",
        "error_ucb_exploration_policy1 = squared_error_curve(v_pi_ucb_exploration_policy1, v_true)\n",
        "error_ucb_exploration_policy2 = squared_error_curve(v_pi_ucb_exploration_policy2, v_true)\n",
        "error_ucb_exploration_policy3 = squared_error_curve(v_pi_ucb_exploration_policy3, v_true)\n",
        "error_oraclepolicy = squared_error_curve(v_pi_oraclepolicy, v_true)\n",
        "\n",
        "# 20 log-spaced episode indices: int(1.0076**p) for p in [1, 1000], i.e. roughly 1 .. 1941.\n",
        "# int() truncation can repeat small indices (the first two are both 1).\n",
        "x = np.logspace(1, 1000, 20, base=1.0076, endpoint=True)\n",
        "scale = [int(i) for i in x]\n",
        "print(scale)\n",
        "\n",
        "# Error bars span +/- k standard deviations across trials.\n",
        "k = 0.5\n",
        "\n",
        "# (legend label, squared-error array, colour). The on-policy and CB-Var baselines are not plotted here.\n",
        "curves = [\n",
        "    ('Oracle (Ours)', error_oraclepolicy, '#D55E00'),\n",
        "    ('Plugin-UCB (Ours), c = 1', error_ucb_exploration_policy1, '#FF0000'),\n",
        "    ('Plugin-UCB (Ours), c = 10', error_ucb_exploration_policy2, '#009E73'),\n",
        "    ('Plugin-UCB (Ours), c = 100', error_ucb_exploration_policy3, '#0000FF'),\n",
        "]\n",
        "\n",
        "for label, error, color in curves:\n",
        "    # Aggregate over trials once, then sample the plotted episodes.\n",
        "    mean_error = np.average(error, axis=0)\n",
        "    std_error = np.std(error, axis=0)\n",
        "    # alpha must lie in [0, 1]; matplotlib rejects larger values.\n",
        "    ax.errorbar(scale, mean_error[scale], k * std_error[scale], color=color, linewidth=2,\n",
        "                capsize=3.0, capthick=1.0, alpha=1.0, label=label, linestyle='-.',\n",
        "                marker='o', markersize=6.0)\n",
        "\n",
        "ax.set_title(f'Ablation Study in a ${env.L}$-Depth, ${env.num_actions}$-Action Tree', size=20, fontweight='bold')\n",
        "ax.set_yscale('log')\n",
        "ax.set_xscale('log')\n",
        "ax.set_xlabel('Episodes')\n",
        "ax.set_ylabel('MSE')\n",
        "ax.set_xlim(1, env.episodes)\n",
        "ax.legend()\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "_BfNnyqpYTbz"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}