{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:27:10.470428102Z",
     "start_time": "2023-09-28T20:27:08.690373484Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from dataclasses import dataclass\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from tqdm.notebook import tnrange\n",
    "\n",
    "from frozen_lake import FrozenLakeEnv, frozen_lake_env_from_string, frozen_lake_policy_from_string\n",
    "from policies import TabularPolicy, LinearPolicy, POPLinearPolicy, EpsilonGreedyPolicy\n",
    "from q_models import LinearQModel, CQL, POPQ, LinearGModel"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Part 1: Find a Problem\n",
    "\n",
    "We find an (environment, data policy, optimal policy) such that the distribution over states consistent with the (environment + data policy) is different than the (environment + optimal policy). Crucially, we insist that:\n",
    "\n",
    " 1. The *full coverage* assumption holds (i.e. the effective[1] support of data policy is a superset of the effective support of the optimal policy.)\n",
    " 2. The environment and/or optimal policy has at least a little stochasticity in it (enough that the learning process doesn't \"get stuck\" during training)\n",
    " 3. The F-matrix generated by the transition matrix from the data policy but the state distribution from the optimal policy is NOT PSD.\n",
    "\n",
    "The first two conditions mean that it is possible to learn the true optimal policy while following the data policy. The final condition means that if we draw transitions from the data policy while learning the optimal policy, the resultant TD updates are not guaranteed to be non-expansive. Whether the resultant TD updates diverge depends on the chosen basis, random luck, etc. It also depends on how our insights from MDPs transfer to the case where actions exist.\n",
    "\n",
    "[1] *effective* support of a policy = the support of the distribution over states consistent with that policy, for some fixed environment."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "def sample_trajectories(env, policy, n=1000, progress_bar=False):\n",
    "    \"\"\"Sample trajectories from the environment following the given policy.\"\"\"\n",
    "    states = np.zeros([n+1], dtype=np.int32)\n",
    "    actions = np.zeros([n], dtype=np.int32)\n",
    "    rewards = np.zeros([n], dtype=np.float32)\n",
    "    dones = np.zeros([n], dtype=bool)\n",
    "    states[0] = env.reset()\n",
    "    for i in tnrange(n, desc='Sampling Trajectories', disable=not progress_bar):\n",
    "        actions[i] = policy.act(states[i])\n",
    "        states[i+1], rewards[i], done, _ = env.step(actions[i])\n",
    "        if done:\n",
    "            dones[i] = True\n",
    "            states[i+1] = env.reset()\n",
    "    return states, actions, rewards, dones\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class FrozenLakeDataset:\n",
    "    \"\"\"Immutable batch of transitions stored as parallel arrays (one entry per transition).\"\"\"\n",
    "    s: np.ndarray     # states\n",
    "    a: np.ndarray     # actions taken in `s`\n",
    "    sp: np.ndarray    # next states\n",
    "    r: np.ndarray     # rewards\n",
    "    done: np.ndarray  # episode-end flags\n",
    "\n",
    "    def to_torch(self, device=None):\n",
    "        \"\"\"Return a new FrozenLakeDataset holding torch tensors on `device`.\n",
    "\n",
    "        Index fields (s, a, sp) become int64, rewards float32 and flags bool.\n",
    "        \"\"\"\n",
    "        def as_tensor(array, dtype):\n",
    "            return torch.from_numpy(array).to(dtype=dtype, device=device)\n",
    "\n",
    "        return FrozenLakeDataset(\n",
    "            s=as_tensor(self.s, torch.int64),\n",
    "            a=as_tensor(self.a, torch.int64),\n",
    "            sp=as_tensor(self.sp, torch.int64),\n",
    "            r=as_tensor(self.r, torch.float32),\n",
    "            done=as_tensor(self.done, torch.bool),\n",
    "        )\n",
    "\n",
    "def generate_dataset(env, policy, n=1000, progress_bar=False):\n",
    "    \"\"\"Generate a dataset from the environment following the given policy.\"\"\"\n",
    "    states, actions, rewards, dones = sample_trajectories(env, policy, n, progress_bar=progress_bar)\n",
    "    return FrozenLakeDataset(\n",
    "        s=states[:-1],\n",
    "        a=actions,\n",
    "        sp=states[1:],\n",
    "        r=rewards,\n",
    "        done=dones,\n",
    "    )\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:27:11.238666295Z",
     "start_time": "2023-09-28T20:27:11.213135108Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data Policy and Distribution:\n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n"
     ]
    }
   ],
   "source": [
    "frzmap = [\n",
    "    \"FFFS\",\n",
    "    \"FHFH\",\n",
    "    \"FFFF\",\n",
    "    \"FFHG\"\n",
    "]\n",
    "\n",
    "# loop = False\n",
    "loop = True\n",
    "env = frozen_lake_env_from_string(\n",
    "    frzmap, loop=loop, slippery=0.1\n",
    ")\n",
    "num_states = env.observation_space.n\n",
    "num_actions = env.action_space.n\n",
    "R = env.get_reward_matrix().reshape((num_states * num_actions, 1))\n",
    "\n",
    "# Generate Data distribution\n",
    "data_policy = frozen_lake_policy_from_string(\n",
    "    [\n",
    "        \"↓←←←\",\n",
    "        \"↓↑↑↑\",\n",
    "        \"↓→→↓\",\n",
    "        \"→↑↑↑\",\n",
    "    ],\n",
    "    epsilon=.5,\n",
    ")\n",
    "\n",
    "print(\"Data Policy and Distribution:\")\n",
    "print(env.render_policy(data_policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:27:11.481103466Z",
     "start_time": "2023-09-28T20:27:11.434883503Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Policy and Distribution:\n",
      "↓→↓S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n"
     ]
    }
   ],
   "source": [
    "opt_policy = env.shortest_path_policy()\n",
    "print(\"Eval Policy and Distribution:\")\n",
    "print(env.render_policy(opt_policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:27:11.668771857Z",
     "start_time": "2023-09-28T20:27:11.655487540Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "def get_P_policy(policy, P):\n",
    "    # Compute the (S x A) x (S x A) transition matrix for the given policy using the S x A x S transition matrix.\n",
    "    num_states, num_actions, _ = P.shape\n",
    "    n = num_states * num_actions\n",
    "    policy_table = policy.dist(np.arange(num_states))\n",
    "    P_policy = (P[:, :, :, None] * policy_table[None, None, :, :]).reshape([n, n])\n",
    "\n",
    "    return P_policy\n",
    "\n",
    "def compute_F_s(policy, P, Phi):\n",
    "    # Compute F(s) = E_{s' ~ P(s'|s)}[F(s, s')] for the given policy and transition matrix.\n",
    "\n",
    "    P_policy = get_P_policy(policy, P)\n",
    "\n",
    "    A = Phi[:, :, None] @ Phi[:, None, :]\n",
    "    B = Phi[:, :, None] @ (P_policy @ Phi)[:, None, :]\n",
    "\n",
    "    return np.concatenate([np.concatenate([A, B], axis=2), np.concatenate([B.transpose(0, 2, 1), A], axis=2)], axis=1)\n",
    "\n",
    "def compute_expected_F(mu, policy, P, Phi):\n",
    "    # Compute $$E_{s ~ \\mu}[F(s)]$$ for the given state-action distribution, $$\\mu$$, and policy.\n",
    "\n",
    "    return (mu[:, None, None] * compute_F_s(policy, P, Phi)).sum(axis=0)\n",
    "\n",
    "def get_policy_distribution(env, policy):\n",
    "    P = get_P_policy(policy, env.get_transition_matrix())\n",
    "    evals, evecs = np.linalg.eig(P.T)\n",
    "    mu = np.real(evecs[:, np.argmax(np.real(evals))])\n",
    "    return mu / mu.sum()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:27:11.903981944Z",
     "start_time": "2023-09-28T20:27:11.879025039Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We found an (environment, data policy, optimal policy) that satisfies our requirements. `cfg` specifies the environment, `policy_data` is the data generation policy and `policy_test` is the optimal policy.\n",
    "\n",
    "We now generate an entire fixed dataset and examine the number of times each state is visited. (Coverage over actions can be assumed because the data policy is $\\epsilon$-greedy with a generous $\\epsilon=0.5$)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "data": {
      "text/plain": "Sampling Trajectories:   0%|          | 0/1000000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "4398758cacf943dfaf2b0f1c97cf16b1"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Off-policy dataset distribution:\n",
      "[[103031 108204 134467 148649]\n",
      " [ 82247  35839  23124  26935]\n",
      " [ 75666  48100  34867  23738]\n",
      " [ 72171  56311  12711  13940]]\n"
     ]
    },
    {
     "data": {
      "text/plain": "Sampling Trajectories:   0%|          | 0/1000000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "d02d59f239c94389b7536cd8331d3b5a"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "On-policy dataset distribution:\n",
      "[[  7919  34853 210961 242228]\n",
      " [  6732  27762 138184  62497]\n",
      " [  8491  22152 103317  70507]\n",
      " [  2639   5277  15085  41396]]\n"
     ]
    }
   ],
   "source": [
    "DATASET_SIZE = int(1e6)\n",
    "# DATASET_SIZE = int(1e4)\n",
    "\n",
    "def visit_map(env, dataset):\n",
    "    counts = np.bincount(dataset.s, minlength=env.observation_space.n)\n",
    "    return counts.reshape(env.map.shape)\n",
    "\n",
    "off_policy_dataset = generate_dataset(env, data_policy, n=DATASET_SIZE, progress_bar=True)\n",
    "mu = get_policy_distribution(env, data_policy)\n",
    "print('Off-policy dataset distribution:')\n",
    "print(visit_map(env, off_policy_dataset))\n",
    "print()\n",
    "\n",
    "eval_policy = EpsilonGreedyPolicy(opt_policy, env.action_space, epsilon=0.5)\n",
    "on_policy_dataset = generate_dataset(env, eval_policy, n=DATASET_SIZE, progress_bar=True)\n",
    "mu_on_policy = get_policy_distribution(env, eval_policy)\n",
    "print('On-policy dataset distribution:')\n",
    "print(visit_map(env, on_policy_dataset))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:28:15.525830351Z",
     "start_time": "2023-09-28T20:27:12.440864635Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We also care about the distribution of Q-values under the optimal policy. We pre-compute the state-action-distribution here:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "def get_q_table(q_model):\n",
    "    q_table = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "    for s in range(env.observation_space.n):\n",
    "        q_table[s] = q_model(s)\n",
    "    return q_table\n",
    "\n",
    "# Compute the TD-error:\n",
    "def compute_TD_error(env, q_table, policy=None, gamma=0.99):\n",
    "    target_q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "    R = env.get_reward_matrix()\n",
    "    P = env.get_transition_matrix()\n",
    "    for s in range(env.observation_space.n):\n",
    "        for a in range(env.action_space.n):\n",
    "            target_q[s, a] = R[s, a]\n",
    "            for sp in range(env.observation_space.n):\n",
    "                if policy is None:\n",
    "                    target_q[s, a] += P[s, a, sp] * gamma * np.max(q_table[sp])\n",
    "                else:\n",
    "                    target_q[s, a] += P[s, a, sp] * gamma * np.sum(q_table[sp] * policy.dist(sp))\n",
    "    return np.mean((target_q - q_table) ** 2)\n",
    "\n",
    "# Compute the Q-error using the dataset:\n",
    "def compute_Q_error(env, dataset, reference_q_table, q_table):\n",
    "    idx = dataset.s * env.action_space.n + dataset.a\n",
    "    counts = np.bincount(idx, minlength=env.observation_space.n * env.action_space.n)\n",
    "    dist = counts / np.sum(counts)\n",
    "    q_errors = (reference_q_table - q_table) ** 2\n",
    "    return np.sum(q_errors.flatten() * dist)\n",
    "\n",
    "# Compute the performance of the policy:\n",
    "def compute_performance(env, policy):\n",
    "    start_s = env.reset()\n",
    "    \n",
    "    P = env.get_transition_matrix()\n",
    "    R = env.get_reward_matrix()\n",
    "    done = env.get_terminal_matrix()\n",
    "    q_table = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "    q_table_prime = np.zeros_like(q_table)\n",
    "    \n",
    "    dist = policy.dist(np.arange(env.observation_space.n))\n",
    "\n",
    "    while True:\n",
    "        q_table_prime *= 0\n",
    "        for s in range(env.observation_space.n):\n",
    "            for a in range(env.action_space.n):\n",
    "                for sp in range(env.observation_space.n):\n",
    "                    v_next = (dist[sp] * q_table[sp]).sum()\n",
    "                    q_table_prime[s, a] += P[s, a, sp] * (R[s, a] + (1 - done[s, a]) * v_next)\n",
    "        if np.allclose(q_table, q_table_prime) or np.isnan(q_table_prime).any():\n",
    "            break\n",
    "        q_table = q_table_prime.copy()\n",
    "\n",
    "    if np.isnan(q_table).any():\n",
    "        return -np.inf\n",
    "    else:\n",
    "        return (q_table[start_s].flatten() * dist[start_s].flatten()).sum()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:28:15.621528540Z",
     "start_time": "2023-09-28T20:28:15.615909346Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Part 2: Showing This Problem Happens\n",
    "\n",
    "We wish to show that the problem we identified persists in $Q$-functions, not just value functions. This is a necessary step in showing that our POP algorithm works in real RL algorithms. To do this, we compare the performance of a linear Q-function trained on- and off-policy.\n",
    "\n",
    "Two caveats:\n",
    "\n",
    "1. The ground-truth Q-values are obtained using a simple tabular method, which is guaranteed to converge in this case.\n",
    "2. The \"on-policy\" data is not strictly on-policy, instead it is the optimal policy with some small amount of dithering (necessary to ensure reasonable coverage)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "## Construct feature matrices\n",
    "seed = 2\n",
    "np.random.seed(seed)\n",
    "\n",
    "# k = 63\n",
    "# k = 53\n",
    "# k = 47\n",
    "k = 60\n",
    "Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))\n",
    "Phi /= np.linalg.norm(Phi, keepdims=True, axis=-1)\n",
    "\n",
    "explore_epsilon = 0.2"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:28:15.621726547Z",
     "start_time": "2023-09-28T20:28:15.616250627Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "def dataset_occupancy(env, dataset):\n",
    "    \"\"\"Compute the occupancy of each state action pair in the dataset.\"\"\"\n",
    "    idx = dataset.s * env.action_space.n + dataset.a\n",
    "    counts = np.bincount(idx, minlength=env.observation_space.n * env.action_space.n)\n",
    "    return counts / np.sum(counts)\n",
    "\n",
    "default_train_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "def train(\n",
    "    q_model,\n",
    "    policy,\n",
    "    dataset=None,\n",
    "    iters=1024,\n",
    "    batch_size=256,\n",
    "    gamma=0.99,\n",
    "    progress_bar=False,\n",
    "    log_freq=int(1e3),\n",
    "    dataset_update_freq=int(1e2),\n",
    "    policy_update_freq=1,\n",
    "    policy_update_iters=1,\n",
    "    secondary_update_freq=-1,\n",
    "    secondary_update_iters=1,\n",
    "    bc=False,\n",
    "    device=None,\n",
    "):\n",
    "    log = {\n",
    "        'step': [],\n",
    "        'TD-error': [],\n",
    "        'q_loss': [],\n",
    "        'Q-error': []\n",
    "    }\n",
    "    if hasattr(q_model, 'g_model'):\n",
    "        log['g_loss'] = []\n",
    "        log['pop_obj'] = []\n",
    "\n",
    "    if device is None:\n",
    "        device = default_train_device\n",
    "\n",
    "    online = dataset is None\n",
    "\n",
    "    def sample_batch(d):\n",
    "        # sample mini-batch\n",
    "        idx = torch.randint(len(d.s), (batch_size,))\n",
    "\n",
    "        s = d.s[idx,...]\n",
    "        a = d.a[idx,...]\n",
    "        sp = d.sp[idx,...]\n",
    "        r = d.r[idx,...]\n",
    "        done = d.done[idx,...]\n",
    "        ap = policy.act(sp)\n",
    "\n",
    "        return s, a, r, sp, ap, done\n",
    "\n",
    "    if not online:\n",
    "        data_dist = dataset_occupancy(env, dataset)\n",
    "        torch_dataset = dataset.to_torch(device=device)\n",
    "    elif bc:\n",
    "        raise ValueError(\"Behavior cloning not implemented for online training.\")\n",
    "\n",
    "    update_info = {}\n",
    "    secondary_update_info = {}\n",
    "    policy_update_info = {}\n",
    "    for step in tnrange(iters, desc=f'Training {\"online\" if online else \"offline\"}', disable=not progress_bar):\n",
    "        if online and step % dataset_update_freq == 0:\n",
    "            dataset = generate_dataset(env, policy, n=int(batch_size*dataset_update_freq), progress_bar=False)\n",
    "            data_dist = get_policy_distribution(env, policy)\n",
    "            torch_dataset = dataset.to_torch(device=device)\n",
    "\n",
    "        if secondary_update_freq > 0 and step % secondary_update_freq == 0:\n",
    "            for step2 in range(secondary_update_iters):\n",
    "                s, a, r, sp, ap, done = sample_batch(torch_dataset)\n",
    "                secondary_update_info = q_model.secondary_update(s, a, r, sp, ap, done, gamma=gamma)\n",
    "\n",
    "            q_model.update_reweight()\n",
    "            F_s = compute_F_s(policy, env.get_transition_matrix(), Phi)\n",
    "            F_mu = (data_dist[:, None, None] * F_s).sum(0)\n",
    "            q = q_model.reweight.detach().cpu().numpy() * data_dist\n",
    "            obj = q.sum()\n",
    "            q = q / q.sum()\n",
    "            F_q = (q[:, None, None] * F_s).sum(0)\n",
    "            secondary_update_info['F_mu_min_eig'] = np.linalg.eigvalsh(F_mu).min()\n",
    "            secondary_update_info['F_q_min_eig'] = np.linalg.eigvalsh(F_q).min()\n",
    "            secondary_update_info['dual_obj'] = obj\n",
    "\n",
    "        s, a, r, sp, ap, done = sample_batch(torch_dataset)\n",
    "        if not bc:\n",
    "            update_info = q_model.update(s, a, r, sp, ap, done, gamma=gamma)\n",
    "\n",
    "        if policy_update_freq == 1 or (step + 1) % policy_update_freq == 0:\n",
    "            for step2 in range(policy_update_iters):\n",
    "                policy_update_info = policy.update(q_model, s, a, r, sp, done, gamma=gamma, bc=bc)\n",
    "\n",
    "        if step % log_freq == 0:\n",
    "            log['step'].append(step)\n",
    "\n",
    "            info = {}\n",
    "            info.update(update_info)\n",
    "            info.update(secondary_update_info)\n",
    "            info.update(policy_update_info)\n",
    "            for k, v in info.items():\n",
    "                if k not in log:\n",
    "                    log[k] = []\n",
    "                log[k].append(v)\n",
    "\n",
    "            if not bc:\n",
    "                td_error = compute_TD_error(env, get_q_table(q_model), policy=policy, gamma=gamma)\n",
    "                log['TD-error'].append(td_error)\n",
    "                print(f\"Step {step}: TD-error = {log['TD-error'][-1]:.4f}\", end='')\n",
    "                ref_q_table = compute_opt_q(env, gamma=gamma, policy=policy)\n",
    "                q_error = compute_Q_error(env, dataset, ref_q_table, get_q_table(q_model))\n",
    "                log['Q-error'].append(q_error)\n",
    "                print(f\", Q-error = {log['Q-error'][-1]:.4f}\", end='')\n",
    "                \n",
    "                ## Test\n",
    "                # print(f'\\n Q Values at state 11 = {q_model(11)}', end=', ')\n",
    "                # print(f'\\n Policy dist at state 11 = {policy.dist(11)}', end=', ')\n",
    "                print(f'\\n Q Values at state 2 = {q_model(2)}', end=', ')\n",
    "                print(f'\\n Policy dist at state 2 = {policy.dist(2)}', end=', ')\n",
    "                print(f'\\n Reweight at state 2 = {q_model.reweight.detach().cpu().numpy().reshape((num_states, num_actions))[2]}', end=', ')\n",
    "\n",
    "            if policy_update_info is not None and len(policy_update_info) > 0:\n",
    "                print('')\n",
    "                print(f\"Policy update info: \", end='')\n",
    "                for k, v in policy_update_info.items():\n",
    "                    print(f\"{k} = {policy_update_info[k]:.6f}\", end=', ')\n",
    "\n",
    "            if secondary_update_info is not None and len(secondary_update_info) > 0:\n",
    "                print('')\n",
    "                print(f\"Secondary update info: \", end='')\n",
    "                for k, v in secondary_update_info.items():\n",
    "                    print(f\"{k} = {secondary_update_info[k]:.6f}\", end=', ')\n",
    "\n",
    "            print('')\n",
    "            print(env.render_policy(policy))\n",
    "\n",
    "    return log"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:28:15.724365089Z",
     "start_time": "2023-09-28T20:28:15.720707990Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "REFERENCE OPTIMAL POLICY\n",
      "max(Q(s,*))\n",
      "[[ 9.32  9.43  9.58  9.44]\n",
      " [ 9.45  9.32  9.76  9.32]\n",
      " [ 9.59  9.76  9.96 10.18]\n",
      " [ 9.46  9.59  9.32 10.32]]\n",
      "\n",
      "Effective Policy\n",
      "↓→↓S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "↑↑#G\n",
      "\n",
      "F min eigenvalue\n",
      "-0.014972407336599346\n"
     ]
    }
   ],
   "source": [
    "def compute_opt_q(\n",
    "    env,\n",
    "    gamma=0.99,\n",
    "    epsilon=0.0,\n",
    "    atol=1e-6,\n",
    "    policy=None\n",
    "):\n",
    "    P = env.get_transition_matrix()\n",
    "    R = env.get_reward_matrix()\n",
    "    done = env.get_terminal_matrix()\n",
    "    q_table = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "    q_table_prime = np.zeros_like(q_table)\n",
    "\n",
    "    if policy is not None:\n",
    "        policy_table = policy.dist(np.arange(num_states))\n",
    "\n",
    "    while True:\n",
    "        q_table_prime *= 0\n",
    "        for s in range(env.observation_space.n):\n",
    "            for a in range(env.action_space.n):\n",
    "                for sp in range(env.observation_space.n):\n",
    "                    if policy is None:\n",
    "                        v_next = ((1 - epsilon) * q_table[sp].max() + epsilon * q_table[sp].mean())\n",
    "                    else:\n",
    "                        v_next = (policy_table[sp] * q_table[sp]).sum()\n",
    "                    q_table_prime[s, a] += P[s, a, sp] * (R[s, a] + gamma * (1 - done[s, a]) * v_next)\n",
    "        if np.allclose(q_table, q_table_prime, atol=atol):\n",
    "            break\n",
    "        q_table = q_table_prime.copy()\n",
    "    return q_table\n",
    "\n",
    "opt_q_table = compute_opt_q(env, epsilon=explore_epsilon)\n",
    "\n",
    "with np.printoptions(precision=2, suppress=True):\n",
    "    print(\"REFERENCE OPTIMAL POLICY\")\n",
    "    print(\"max(Q(s,*))\")\n",
    "    print(opt_q_table.max(-1).reshape(env.map.shape))\n",
    "    print()\n",
    "\n",
    "    print(\"Effective Policy\")\n",
    "    policy = TabularPolicy(env.observation_space, env.action_space, np.argmax(opt_q_table, axis=1))\n",
    "    print(env.render_policy(policy))\n",
    "    print()\n",
    "\n",
    "    print('F min eigenvalue')\n",
    "    F = compute_expected_F(mu_on_policy, policy, env.get_transition_matrix(), Phi)\n",
    "    print(np.real(np.linalg.eigvals(F)).min())"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T20:28:23.680322351Z",
     "start_time": "2023-09-28T20:28:15.728119366Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "data": {
      "text/plain": "Training online:   0%|          | 0/200000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "3484859ca7b545d1935135ece3357f82"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0: TD-error = 0.0678, Q-error = 0.0242\n",
      " Q Values at state 11 = [ 0.05148803 -0.00033141 -0.06965119 -0.03513265], \n",
      " Policy dist at state 11 = [[0.27545068 0.20878327 0.26607946 0.2496866 ]], \n",
      "Policy update info: policy_loss = -1.374757, alpha = 0.999900, alpha_loss = 0.000000, policy_entropy = 1.384201, \n",
      "←→←S\n",
      "→#↑#\n",
      "↓↓↓↑\n",
      "↑→#G\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[11], line 4\u001B[0m\n\u001B[1;32m      2\u001B[0m linear_q \u001B[38;5;241m=\u001B[39m LinearQModel(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, learning_rate\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m2e-2\u001B[39m)\n\u001B[1;32m      3\u001B[0m policy \u001B[38;5;241m=\u001B[39m LinearPolicy(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-2\u001B[39m, use_automatic_entropy_tuning\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m, target_entropy\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0.25\u001B[39m, alpha_lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-4\u001B[39m)\n\u001B[0;32m----> 4\u001B[0m log \u001B[38;5;241m=\u001B[39m \u001B[43mtrain\u001B[49m\u001B[43m(\u001B[49m\u001B[43mlinear_q\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpolicy\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43miters\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m2e5\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m32\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mprogress_bar\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlog_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m1e3\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdataset_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m1e2\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mpolicy_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m      6\u001B[0m ax \u001B[38;5;241m=\u001B[39m plt\u001B[38;5;241m.\u001B[39mgca()\n\u001B[1;32m      7\u001B[0m ax\u001B[38;5;241m.\u001B[39mset_yscale(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlog\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
      "Cell \u001B[0;32mIn[9], line 91\u001B[0m, in \u001B[0;36mtrain\u001B[0;34m(q_model, policy, dataset, iters, batch_size, gamma, progress_bar, log_freq, dataset_update_freq, policy_update_freq, policy_update_iters, secondary_update_freq, secondary_update_iters, bc, device)\u001B[0m\n\u001B[1;32m     89\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m policy_update_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m1\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m (step \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m%\u001B[39m policy_update_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m     90\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m step2 \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(policy_update_iters):\n\u001B[0;32m---> 91\u001B[0m         policy_update_info \u001B[38;5;241m=\u001B[39m \u001B[43mpolicy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mupdate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mq_model\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43ms\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43ma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msp\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdone\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgamma\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgamma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbc\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mbc\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     93\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m step \u001B[38;5;241m%\u001B[39m log_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m     94\u001B[0m     log[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mstep\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;241m.\u001B[39mappend(step)\n",
      "File \u001B[0;32m~/projects/POP-QL/small_scale/policies.py:130\u001B[0m, in \u001B[0;36mLinearPolicy.update\u001B[0;34m(self, q_model, state, action, reward, state_prime, done, gamma, bc)\u001B[0m\n\u001B[1;32m    128\u001B[0m policy_loss \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_policy_loss(q_model, state, action, reward, state_prime, done, alpha\u001B[38;5;241m=\u001B[39malpha, gamma\u001B[38;5;241m=\u001B[39mgamma, bc\u001B[38;5;241m=\u001B[39mbc)\n\u001B[1;32m    129\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moptim\u001B[38;5;241m.\u001B[39mzero_grad()\n\u001B[0;32m--> 130\u001B[0m \u001B[43mpolicy_loss\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    131\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moptim\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m    133\u001B[0m info \u001B[38;5;241m=\u001B[39m {\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mpolicy_loss\u001B[39m\u001B[38;5;124m'\u001B[39m: policy_loss\u001B[38;5;241m.\u001B[39mdetach()\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mnumpy()}\n",
      "File \u001B[0;32m~/micromamba/envs/JaxCQL/lib/python3.9/site-packages/torch/_tensor.py:487\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m    477\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m    478\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m    479\u001B[0m         Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m    480\u001B[0m         (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m    485\u001B[0m         inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m    486\u001B[0m     )\n\u001B[0;32m--> 487\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    488\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m    489\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/micromamba/envs/JaxCQL/lib/python3.9/site-packages/torch/autograd/__init__.py:200\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m    195\u001B[0m     retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m    197\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m    198\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m    199\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 200\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m  \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m    201\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    202\u001B[0m \u001B[43m    \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "# Baseline: linear Q-learning trained on-policy. No dataset is passed to `train`; presumably it\n",
    "# collects fresh transitions with the current policy every `dataset_update_freq` steps (confirm in train()).\n",
    "torch.manual_seed(seed + 123)  # seed like the other training cells so this run is reproducible\n",
    "Phi_torch = torch.tensor(Phi, device=default_train_device, dtype=torch.float32)\n",
    "linear_q = LinearQModel(env.observation_space, env.action_space, Phi_torch, learning_rate=2e-2)\n",
    "policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-2, use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-4)\n",
    "log = train(linear_q, policy, iters=int(2e5), batch_size=32, progress_bar=True, log_freq=int(1e3), dataset_update_freq=int(1e2), policy_update_freq=1)\n",
    "\n",
    "# Explicit figure/axes so the curve never lands on a stale figure, and a title so it stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(log['step'], log['TD-error'], label='TD-error')\n",
    "ax.set_yscale('log')\n",
    "ax.set(title='Linear Q-learning (on-policy): TD-error', xlabel='step', ylabel='TD-error')\n",
    "plt.show()\n",
    "\n",
    "with np.printoptions(precision=2, suppress=True):\n",
    "    print(\"LINEAR ON-POLICY\")\n",
    "    print(\"max(Q(s,*))\")\n",
    "    print(get_q_table(linear_q).max(-1).reshape(env.map.shape))\n",
    "    print()\n",
    "\n",
    "    print(\"Q-error\")\n",
    "    print(compute_Q_error(env, on_policy_dataset, opt_q_table, get_q_table(linear_q)))\n",
    "    print()\n",
    "\n",
    "    # On-policy: F is weighted by the learned policy's own distribution (get_policy_distribution).\n",
    "    print('F min eigenvalue')\n",
    "    F = compute_expected_F(get_policy_distribution(env, policy), policy, env.get_transition_matrix(), Phi)\n",
    "    print(np.real(np.linalg.eigvals(F)).min())\n",
    "\n",
    "    print(\"Effective Policy\")\n",
    "    print(env.render_policy(policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T14:21:06.721941345Z",
     "start_time": "2023-09-28T14:20:54.172082777Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Linear Q-learning trained purely off-policy, from the fixed `off_policy_dataset`.\n",
    "torch.manual_seed(seed + 123)\n",
    "Phi_torch = torch.tensor(Phi, device=default_train_device, dtype=torch.float32)\n",
    "linear_q = LinearQModel(env.observation_space, env.action_space, Phi_torch, learning_rate=1e-3)\n",
    "policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-3, use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-3)\n",
    "log = train(linear_q, policy, off_policy_dataset, iters=int(2e5), batch_size=32, progress_bar=True, log_freq=int(1e3), policy_update_freq=1)\n",
    "\n",
    "# Explicit figure/axes so the curve never lands on a stale figure, and a title so it stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(log['step'], log['TD-error'], label='TD-error')\n",
    "ax.set_yscale('log')\n",
    "ax.set(title='Linear Q-learning (off-policy): TD-error', xlabel='step', ylabel='TD-error')\n",
    "plt.show()\n",
    "\n",
    "with np.printoptions(precision=2, suppress=True):\n",
    "    print(\"LINEAR OFF-POLICY\")\n",
    "    print(\"max(Q(s,*))\")\n",
    "    print(get_q_table(linear_q).max(-1).reshape(env.map.shape))\n",
    "    print()\n",
    "\n",
    "    print(\"Q-error\")\n",
    "    print(compute_Q_error(env, off_policy_dataset, opt_q_table, get_q_table(linear_q)))\n",
    "    print()\n",
    "\n",
    "    # Off-policy: F is weighted by `mu` (presumably the data distribution) instead of the policy's own.\n",
    "    print('F min eigenvalue')\n",
    "    F = compute_expected_F(mu, policy, env.get_transition_matrix(), Phi)\n",
    "    print(np.real(np.linalg.eigvals(F)).min())\n",
    "\n",
    "    print(\"Effective Policy\")\n",
    "    print(env.render_policy(policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-09-26T12:55:45.662160310Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Great! We've established that off-policy training greatly increases the Q-error relative to on-policy training. We are now ready for:\n",
    "\n",
    "## Part 3: POP-Q-Learning Solves This Problem\n",
    "\n",
    "We show that our algorithm solves this problem by implementing it and observing the error:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "## Create heatmap of the reweighting\n",
    "\n",
    "def heatmap(reweight, shape=None, title=None):\n",
    "    \"\"\"Plot one value per state as an annotated heatmap laid out on the lake grid.\n",
    "\n",
    "    Args:\n",
    "        reweight: array-like with one value per state in row-major grid order\n",
    "            (state index = row * n_cols + col); a 2-D grid of that shape also works.\n",
    "        shape: (n_rows, n_cols) of the grid. Defaults to the notebook-global `env.map.shape`.\n",
    "        title: optional axes title.\n",
    "    \"\"\"\n",
    "    if shape is None:\n",
    "        shape = env.map.shape\n",
    "    grid = np.asarray(reweight).reshape(shape)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    im = ax.imshow(grid)\n",
    "\n",
    "    # Hide ticks: each cell is labelled with its value instead\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "    # Label cells; imshow places column j on the x-axis and row i on the y-axis\n",
    "    for (i, j), value in np.ndenumerate(grid):\n",
    "        ax.text(j, i, f\"{value:.2f}\", ha=\"center\", va=\"center\", color=\"w\")\n",
    "\n",
    "    fig.colorbar(im, ax=ax)\n",
    "    if title is not None:\n",
    "        ax.set_title(title)\n",
    "\n",
    "    plt.show()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-09-26T12:55:45.662442787Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Solve for importance sampling weights exactly\n",
    "\n",
    "By computing the importance sampling weights exactly, we can construct a lower bound on the error of the linear Q-function approximation with POP-Q."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "outputs": [
    {
     "data": {
      "text/plain": "Training offline:   0%|          | 0/200000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "5917b885b26d471fa98cfb9182624e8c"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0: TD-error = 0.0568, Q-error = 0.0503\n",
      " Q Values at state 11 = [ 0.09889679 -0.09164629  0.0328095   0.00244008], \n",
      " Policy dist at state 11 = [[0.26086292 0.23120742 0.26387045 0.24405923]], \n",
      "Policy update info: policy_loss = -1.397457, alpha = 0.999900, alpha_loss = 0.000000, policy_entropy = 1.384659, \n",
      "Secondary update info: F_mu_min_eig = -0.003050, F_q_min_eig = 0.000004, \n",
      "↑←↓S\n",
      "↑#↑#\n",
      "↓↓↑↓\n",
      "↓↑#G\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[33], line 17\u001B[0m\n\u001B[1;32m     15\u001B[0m policy \u001B[38;5;241m=\u001B[39m LinearPolicy(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-2\u001B[39m, use_automatic_entropy_tuning\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m, target_entropy\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0.25\u001B[39m, alpha_lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-4\u001B[39m)\n\u001B[1;32m     16\u001B[0m is_linear_q \u001B[38;5;241m=\u001B[39m LinearExactImportanceSamplingQModel(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, policy, dataset_occupancy(env, off_policy_dataset), learning_rate\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m2e-2\u001B[39m)\n\u001B[0;32m---> 17\u001B[0m log \u001B[38;5;241m=\u001B[39m \u001B[43mtrain\u001B[49m\u001B[43m(\u001B[49m\u001B[43mis_linear_q\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpolicy\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43moff_policy_dataset\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43miters\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m2e5\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m32\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mprogress_bar\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlog_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m1e3\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mpolicy_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msecondary_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m5e1\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     19\u001B[0m ax \u001B[38;5;241m=\u001B[39m plt\u001B[38;5;241m.\u001B[39mgca()\n\u001B[1;32m     20\u001B[0m ax\u001B[38;5;241m.\u001B[39mset_yscale(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlog\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
      "Cell \u001B[0;32mIn[32], line 89\u001B[0m, in \u001B[0;36mtrain\u001B[0;34m(q_model, policy, dataset, iters, batch_size, gamma, progress_bar, log_freq, dataset_update_freq, policy_update_freq, policy_update_iters, secondary_update_freq, secondary_update_iters, bc, device)\u001B[0m\n\u001B[1;32m     87\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m policy_update_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m1\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m (step \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m%\u001B[39m policy_update_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m     88\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m step2 \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(policy_update_iters):\n\u001B[0;32m---> 89\u001B[0m         policy_update_info \u001B[38;5;241m=\u001B[39m \u001B[43mpolicy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mupdate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mq_model\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43ms\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43ma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msp\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdone\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgamma\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgamma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbc\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mbc\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     91\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m step \u001B[38;5;241m%\u001B[39m log_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m     92\u001B[0m     log[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mstep\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;241m.\u001B[39mappend(step)\n",
      "File \u001B[0;32m~/projects/POP-QL/small_scale/policies.py:130\u001B[0m, in \u001B[0;36mLinearPolicy.update\u001B[0;34m(self, q_model, state, action, reward, state_prime, done, gamma, bc)\u001B[0m\n\u001B[1;32m    128\u001B[0m policy_loss \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_policy_loss(q_model, state, action, reward, state_prime, done, alpha\u001B[38;5;241m=\u001B[39malpha, gamma\u001B[38;5;241m=\u001B[39mgamma, bc\u001B[38;5;241m=\u001B[39mbc)\n\u001B[1;32m    129\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moptim\u001B[38;5;241m.\u001B[39mzero_grad()\n\u001B[0;32m--> 130\u001B[0m \u001B[43mpolicy_loss\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    131\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moptim\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m    133\u001B[0m info \u001B[38;5;241m=\u001B[39m {\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mpolicy_loss\u001B[39m\u001B[38;5;124m'\u001B[39m: policy_loss\u001B[38;5;241m.\u001B[39mdetach()\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mnumpy()}\n",
      "File \u001B[0;32m~/micromamba/envs/JaxCQL/lib/python3.9/site-packages/torch/_tensor.py:487\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m    477\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m    478\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m    479\u001B[0m         Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m    480\u001B[0m         (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m    485\u001B[0m         inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m    486\u001B[0m     )\n\u001B[0;32m--> 487\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    488\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m    489\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/micromamba/envs/JaxCQL/lib/python3.9/site-packages/torch/autograd/__init__.py:200\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m    195\u001B[0m     retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m    197\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m    198\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m    199\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 200\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m  \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m    201\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    202\u001B[0m \u001B[43m    \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "class LinearExactImportanceSamplingQModel(LinearQModel):\n",
    "    \"\"\"Linear Q-model reweighted by exact importance weights (an idealized baseline).\n",
    "\n",
    "    The weights are d_pi / d_data: d_pi comes from `get_policy_distribution` for the current\n",
    "    policy, and d_data is the fixed sampling distribution of the dataset. Unlike POP-Q, the\n",
    "    reweighting is not learned; it is recomputed from the environment on each update_reweight().\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, observation_space, action_space, Phi, policy, sampling_dist, learning_rate=1e-1, env=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            policy: policy whose distribution forms the numerator of the weights.\n",
    "            sampling_dist: dataset distribution (numpy array or tensor), presumably one entry\n",
    "                per (state, action) pair as returned by `dataset_occupancy` -- TODO confirm.\n",
    "            env: environment passed to `get_policy_distribution`. When omitted, the\n",
    "                notebook-global `env` is used at call time.\n",
    "        \"\"\"\n",
    "        super().__init__(observation_space, action_space, Phi, learning_rate=learning_rate)\n",
    "        self.policy = policy\n",
    "        self._env = env\n",
    "        self.sampling_dist = sampling_dist\n",
    "        if isinstance(self.sampling_dist, np.ndarray):\n",
    "            self.sampling_dist = torch.tensor(self.sampling_dist, dtype=torch.float32, device=self.Phi.device)\n",
    "\n",
    "    def update_reweight(self):\n",
    "        \"\"\"Recompute `self.reweight` = clip(d_pi / (d_data + 1e-6), 1e-2, 1e2).\"\"\"\n",
    "        # Resolve the environment at call time so the global fallback keeps working.\n",
    "        env_ = self._env if self._env is not None else env\n",
    "        policy_distribution = torch.tensor(get_policy_distribution(env_, self.policy), dtype=torch.float32, device=self.Phi.device)\n",
    "        # The 1e-6 avoids division by zero for pairs absent from the data; the clip bounds the\n",
    "        # weights so a few rare pairs cannot dominate (presumably) the reweighted TD loss.\n",
    "        self.reweight = torch.clip(policy_distribution / (self.sampling_dist + 1e-6), 1e-2, 1e2)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "Phi_torch = torch.tensor(Phi, device=default_train_device, dtype=torch.float32)\n",
    "policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-2, use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-4)\n",
    "# Importance weights use the exact dataset occupancy as the denominator.\n",
    "is_linear_q = LinearExactImportanceSamplingQModel(env.observation_space, env.action_space, Phi_torch, policy, dataset_occupancy(env, off_policy_dataset), learning_rate=2e-2)\n",
    "log = train(is_linear_q, policy, off_policy_dataset, iters=int(2e5), batch_size=32, progress_bar=True, log_freq=int(1e3), policy_update_freq=1, secondary_update_freq=int(5e1))\n",
    "\n",
    "# Explicit figure/axes so the curve never lands on a stale figure, and a title so it stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(log['step'], log['TD-error'], label='TD-error')\n",
    "ax.set_yscale('log')\n",
    "ax.set(title='Linear Q-learning with exact importance weights: TD-error', xlabel='step', ylabel='TD-error')\n",
    "plt.show()\n",
    "\n",
    "with np.printoptions(precision=2, suppress=True):\n",
    "    print(\"LINEAR IMPORTANCE SAMPLING\")\n",
    "    print(\"max(Q(s,*))\")\n",
    "    print(get_q_table(is_linear_q).max(-1).reshape(env.map.shape))\n",
    "    print()\n",
    "\n",
    "    print(\"Q-error\")\n",
    "    print(compute_Q_error(env, off_policy_dataset, opt_q_table, get_q_table(is_linear_q)))\n",
    "    print()\n",
    "\n",
    "    print(\"Effective Policy\")\n",
    "    print(env.render_policy(policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-26T13:17:20.741087852Z",
     "start_time": "2023-09-26T13:17:08.776479007Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Perform POP-Q learning\n",
    "\n",
    "Next, we perform POP-Q learning with a linear g-model (`LinearGModel`)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "data": {
      "text/plain": "Training offline:   0%|          | 0/200000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "824bd5a8c2f9419b8968f355ab607e9d"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 1",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mRuntimeError\u001B[0m                              Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[13], line 10\u001B[0m\n\u001B[1;32m      8\u001B[0m g_model \u001B[38;5;241m=\u001B[39m LinearGModel(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-3\u001B[39m)\n\u001B[1;32m      9\u001B[0m pop_q \u001B[38;5;241m=\u001B[39m POPQ(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, g_model, env\u001B[38;5;241m.\u001B[39mget_terminal_matrix(), rank\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m6\u001B[39m, q_lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-2\u001B[39m, dual_lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-4\u001B[39m)\n\u001B[0;32m---> 10\u001B[0m log \u001B[38;5;241m=\u001B[39m \u001B[43mtrain\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpop_q\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpolicy\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43moff_policy_dataset\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43miters\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m2e5\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m32\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mprogress_bar\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlog_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m2e2\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpolicy_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43msecondary_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m5e0\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msecondary_update_iters\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m1e1\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     12\u001B[0m ax \u001B[38;5;241m=\u001B[39m plt\u001B[38;5;241m.\u001B[39mgca()\n\u001B[1;32m     13\u001B[0m ax\u001B[38;5;241m.\u001B[39mset_yscale(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlog\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
      "Cell \u001B[0;32mIn[12], line 74\u001B[0m, in \u001B[0;36mtrain\u001B[0;34m(q_model, policy, dataset, iters, batch_size, gamma, progress_bar, log_freq, dataset_update_freq, policy_update_freq, policy_update_iters, secondary_update_freq, secondary_update_iters, bc, device)\u001B[0m\n\u001B[1;32m     71\u001B[0m     s, a, r, sp, ap, done \u001B[38;5;241m=\u001B[39m sample_batch(torch_dataset)\n\u001B[1;32m     72\u001B[0m     secondary_update_info \u001B[38;5;241m=\u001B[39m q_model\u001B[38;5;241m.\u001B[39msecondary_update(s, a, r, sp, ap, done, gamma\u001B[38;5;241m=\u001B[39mgamma)\n\u001B[0;32m---> 74\u001B[0m \u001B[43mq_model\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mupdate_reweight\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     75\u001B[0m F_s \u001B[38;5;241m=\u001B[39m compute_F_s(policy, env\u001B[38;5;241m.\u001B[39mget_transition_matrix(), Phi)\n\u001B[1;32m     76\u001B[0m F_mu \u001B[38;5;241m=\u001B[39m (data_dist[:, \u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;28;01mNone\u001B[39;00m] \u001B[38;5;241m*\u001B[39m F_s)\u001B[38;5;241m.\u001B[39msum(\u001B[38;5;241m0\u001B[39m)\n",
      "File \u001B[0;32m~/projects/POP-QL/small_scale/q_models.py:247\u001B[0m, in \u001B[0;36mPOPQ.update_reweight\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m    246\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mupdate_reweight\u001B[39m(\u001B[38;5;28mself\u001B[39m):\n\u001B[0;32m--> 247\u001B[0m     reweight \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_pop_values\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstate_vec\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43maction_vec\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdone_vec\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    248\u001B[0m     reweight \u001B[38;5;241m/\u001B[39m\u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mmean(reweight)\n\u001B[1;32m    249\u001B[0m     reweight \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mclip(reweight, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mreweight_min, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mreweight_max)\n",
      "File \u001B[0;32m~/projects/POP-QL/small_scale/q_models.py:205\u001B[0m, in \u001B[0;36mPOPQ._pop_values\u001B[0;34m(self, state, action, done, state_prime, action_prime)\u001B[0m\n\u001B[1;32m    201\u001B[0m m_b \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mb_mag \u001B[38;5;241m*\u001B[39m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mPhi \u001B[38;5;241m@\u001B[39m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mb \u001B[38;5;241m/\u001B[39m b_2_norm))[x]\n\u001B[1;32m    202\u001B[0m g \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mg_model(state, action)\n\u001B[1;32m    204\u001B[0m reweight \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mexp(\n\u001B[0;32m--> 205\u001B[0m     (m_a \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39m \u001B[38;5;241m2\u001B[39m \u001B[38;5;241m+\u001B[39m m_b \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39m \u001B[38;5;241m2\u001B[39m)\u001B[38;5;241m.\u001B[39msum(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m+\u001B[39m \u001B[38;5;241;43m2\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43m \u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mdone\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfloat\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43ma_mag\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mb_mag\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mg\u001B[49m\n\u001B[1;32m    206\u001B[0m     
\u001B[38;5;241m-\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mpop_margin \u001B[38;5;241m*\u001B[39m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39ma \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39m \u001B[38;5;241m2\u001B[39m \u001B[38;5;241m+\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mb \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39m \u001B[38;5;241m2\u001B[39m)\u001B[38;5;241m.\u001B[39msum()\n\u001B[1;32m    207\u001B[0m )\n\u001B[1;32m    209\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m state_prime \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;129;01mand\u001B[39;00m action_prime \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m    210\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m reweight\n",
      "\u001B[0;31mRuntimeError\u001B[0m: The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 1"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the last recorded run of this cell raised inside POPQ._pop_values\n",
    "# (\"size of tensor a (4) must match the size of tensor b (64)\") -- investigate q_models.py before relying on it.\n",
    "torch.manual_seed(seed + 123)\n",
    "Phi_torch = torch.tensor(Phi, device=default_train_device, dtype=torch.float32)\n",
    "# beta_multiplier=0 with automatic KL tuning off presumably disables the policy's KL term -- confirm in POPLinearPolicy.\n",
    "policy = POPLinearPolicy(\n",
    "    env.observation_space, env.action_space, Phi_torch, lr=1e-4,\n",
    "    use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-4,\n",
    "    use_automatic_kl_tuning=False, beta_multiplier=0,\n",
    ")\n",
    "g_model = LinearGModel(env.observation_space, env.action_space, Phi_torch, lr=1e-3)\n",
    "pop_q = POPQ(env.observation_space, env.action_space, Phi_torch, g_model, env.get_terminal_matrix(), rank=6, q_lr=1e-2, dual_lr=1e-4)\n",
    "log = train(pop_q, policy, off_policy_dataset, iters=int(2e5), batch_size=32, progress_bar=True, log_freq=int(2e2), policy_update_freq=1, secondary_update_freq=int(5e0), secondary_update_iters=int(1e1))\n",
    "\n",
    "# Explicit figure/axes so the curve never lands on a stale figure, and a title so it stands alone.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(log['step'], log['TD-error'], label='TD-error')\n",
    "ax.set_yscale('log')\n",
    "ax.set(title='POP-Q with linear g-model: TD-error', xlabel='step', ylabel='TD-error')\n",
    "plt.show()\n",
    "\n",
    "with np.printoptions(precision=2, suppress=True):\n",
    "    print(\"POP-Q with LINEAR g-model\")\n",
    "    print(\"max(Q(s,*))\")\n",
    "    print(get_q_table(pop_q).max(-1).reshape(env.map.shape))\n",
    "    print()\n",
    "\n",
    "    print(\"Q-error\")\n",
    "    print(compute_Q_error(env, off_policy_dataset, opt_q_table, get_q_table(pop_q)))\n",
    "    print()\n",
    "\n",
    "    print(\"Effective Policy\")\n",
    "    print(env.render_policy(policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-16T17:41:07.814131013Z",
     "start_time": "2023-09-16T17:41:04.488606819Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Why does Linear POP-Q fail?"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "↓←↓S\n",
      "↓#↓#\n",
      "↓←←←\n",
      "↓←#G\n"
     ]
    }
   ],
   "source": [
    "# Adversarial policy, written as an arrow grid: one string per map row, one arrow per cell.\n",
    "# NOTE(review): `explore_epsilon` must already exist when this cell runs; a later cell also assigns it (0.2).\n",
    "adv_policy = frozen_lake_policy_from_string(\n",
    "    [\"↓←↓←\",\n",
    "     \"↓↑↓↑\",\n",
    "     \"↓←←←\",\n",
    "     \"↓←↑↑\"],\n",
    "    epsilon=explore_epsilon,\n",
    ")\n",
    "print(env.render_policy(adv_policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T02:08:15.310377181Z",
     "start_time": "2023-09-15T02:08:15.281466230Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Compute the occupancy of the adversarial policy and its importance weights w.r.t. the data distribution `mu`\n",
    "adv_policy_distribution = get_policy_distribution(env, adv_policy)\n",
    "# Greedy (argmax) action of the adversarial policy in every state\n",
    "adv_action_table = np.argmax(adv_policy.dist(np.arange(num_states)), -1)\n",
    "print(adv_policy_distribution.reshape([num_states, num_actions])[np.arange(num_states), adv_action_table].reshape(env.map.shape))\n",
    "# NOTE(review): state-action pairs with mu == 0 would give inf/nan weights -- confirm mu has full support.\n",
    "importance_weights = adv_policy_distribution / mu\n",
    "adv_action_weights = importance_weights.reshape([num_states, num_actions])[np.arange(num_states), adv_action_table]\n",
    "\n",
    "# Show the greedy-action importance weights laid out on the map grid (plain matplotlib; no `heatmap` helper is in scope).\n",
    "fig, ax = plt.subplots()\n",
    "im = ax.imshow(adv_action_weights.reshape(env.map.shape))\n",
    "fig.colorbar(im, ax=ax)\n",
    "ax.set_title('Importance weights of the adversarial greedy actions')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Perform POP-Q learning with policy regularization\n",
    "\n",
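    "Next, we perform POP-Q learning with a tabular g-model. As setup, we first generate an on-policy dataset from an epsilon-greedy (epsilon = 0.5) version of the optimal policy and define a helper, `merge_datasets`, for mixing two datasets."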
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "Sampling Trajectories:   0%|          | 0/1000000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "cd61ca0f97354706bb181c8147196554"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "On-policy dataset distribution:\n",
      "[[  7854  34828 210399 242446]\n",
      " [  6631  27888 138044  62770]\n",
      " [  8329  22033 103521  70875]\n",
      " [  2593   5265  14938  41586]]\n"
     ]
    }
   ],
   "source": [
    "eval_policy = EpsilonGreedyPolicy(opt_policy, env.action_space, epsilon=0.5)\n",
    "on_policy_dataset = generate_dataset(env, eval_policy, n=DATASET_SIZE, progress_bar=True)\n",
    "mu_on_policy = get_policy_distribution(env, eval_policy)\n",
    "print('On-policy dataset distribution:')\n",
    "print(visit_map(env, on_policy_dataset))\n",
    "\n",
    "def merge_datasets(dataset1, dataset2, p=0.5, rng=None):\n",
    "    \"\"\"Mix two datasets, drawing a fraction `p` of the transitions from `dataset2`.\n",
    "\n",
    "    The result has n = min(len(dataset1.s), len(dataset2.s)) transitions: n - round(n * p)\n",
    "    sampled without replacement from `dataset1`, followed by round(n * p) sampled without\n",
    "    replacement from `dataset2` (the merged arrays are not shuffled).\n",
    "\n",
    "    Args:\n",
    "        dataset1, dataset2: datasets with parallel arrays `s`, `a`, `r`, `sp`, `done`.\n",
    "        p: fraction of the merged dataset taken from `dataset2`; must lie in [0, 1].\n",
    "        rng: optional numpy Generator/RandomState; defaults to the global `np.random` state.\n",
    "\n",
    "    Returns:\n",
    "        A new FrozenLakeDataset.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `p` is outside [0, 1].\n",
    "    \"\"\"\n",
    "    if not 0 <= p <= 1:\n",
    "        raise ValueError(f'p must lie in [0, 1], got {p}')\n",
    "    rng = np.random if rng is None else rng\n",
    "    n = min(len(dataset1.s), len(dataset2.s))\n",
    "    # Round one share and give the remainder to the other so the sizes always sum to n;\n",
    "    # truncating both int(n * p) and int(n * (1 - p)) can silently drop a transition.\n",
    "    n2 = int(round(n * p))\n",
    "    n1 = n - n2\n",
    "    idx1 = rng.choice(len(dataset1.s), size=n1, replace=False)\n",
    "    idx2 = rng.choice(len(dataset2.s), size=n2, replace=False)\n",
    "\n",
    "    return FrozenLakeDataset(\n",
    "        s=np.concatenate([dataset1.s[idx1], dataset2.s[idx2]]),\n",
    "        a=np.concatenate([dataset1.a[idx1], dataset2.a[idx2]]),\n",
    "        r=np.concatenate([dataset1.r[idx1], dataset2.r[idx2]]),\n",
    "        sp=np.concatenate([dataset1.sp[idx1], dataset2.sp[idx2]]),\n",
    "        done=np.concatenate([dataset1.done[idx1], dataset2.done[idx2]])\n",
    "    )\n",
    "\n",
    "k = 63\n",
    "Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))\n",
    "Phi /= np.linalg.norm(Phi, keepdims=True, axis=-1)\n",
    "\n",
    "explore_epsilon = 0.2"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T02:08:47.748763473Z",
     "start_time": "2023-09-15T02:08:17.839722895Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "60\n"
     ]
    },
    {
     "data": {
      "text/plain": "Training offline:   0%|          | 0/200000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "5d2da38b719e4e32a1852a0bbecdac7d"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0: TD-error = 0.0672, Q-error = 0.0519\n",
      " Q Values at state 2 = [ 0.05292317 -0.03512138  0.01848823  0.00703838], \n",
      " Policy dist at state 2 = [[0.24958214 0.2456169  0.27004865 0.23475233]], \n",
      " Reweight at state 2 = [1.1135256 1.1020331 1.1802144 1.0609591], \n",
      "Policy update info: policy_loss = -1.388063, pop_loss = 0.002302, alpha = 0.999900, alpha_loss = 0.000000, policy_entropy = 1.384402, \n",
      "Secondary update info: dual_loss = 0.094596, g_loss = 0.008233, a_2_norm = 0.999000, b_2_norm = 0.999000, F_mu_min_eig = -0.002493, F_q_min_eig = -0.003501, dual_obj = 1.116098, \n",
      "→↓↓S\n",
      "←#→#\n",
      "←↓↓↑\n",
      "→→#G\n",
      "Step 2000: TD-error = 0.0674, Q-error = 0.0689\n",
      " Q Values at state 2 = [-0.04226029 -0.05663896 -0.06702234 -0.04814854], \n",
      " Policy dist at state 2 = [[0.24495092 0.24630079 0.24766637 0.26108193]], \n",
      " Reweight at state 2 = [1.0744754  1.0563371  1.0829003  0.93888897], \n",
      "Policy update info: policy_loss = -1.094047, pop_loss = 0.042543, alpha = 0.818642, alpha_loss = -0.176861, policy_entropy = 1.384264, \n",
      "Secondary update info: dual_loss = -0.015604, g_loss = 0.000407, a_2_norm = 1.300500, b_2_norm = 1.039230, F_mu_min_eig = -0.002111, F_q_min_eig = -0.001275, dual_obj = 1.008093, \n",
      "→↑←S\n",
      "↓#↑#\n",
      "→→↓↓\n",
      "↓→#G\n",
      "Step 4000: TD-error = 0.0704, Q-error = 0.1134\n",
      " Q Values at state 2 = [-0.03557047 -0.0715676  -0.05912631 -0.07464728], \n",
      " Policy dist at state 2 = [[0.2469993  0.24343476 0.24698628 0.26257968]], \n",
      " Reweight at state 2 = [1.0858897 1.1177233 1.1386368 0.9145193], \n",
      "Policy update info: policy_loss = -0.840724, pop_loss = 0.089903, alpha = 0.670345, alpha_loss = -0.353571, policy_entropy = 1.384233, \n",
      "Secondary update info: dual_loss = 0.003483, g_loss = 0.000100, a_2_norm = 1.738148, b_2_norm = 1.426950, F_mu_min_eig = -0.001894, F_q_min_eig = -0.001059, dual_obj = 1.003333, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "←→→↓\n",
      "↓↑#G\n",
      "Step 6000: TD-error = 0.0711, Q-error = 0.1726\n",
      " Q Values at state 2 = [-0.02813738 -0.07341188 -0.04970355 -0.06904051], \n",
      " Policy dist at state 2 = [[0.24418758 0.24138784 0.24421065 0.27021393]], \n",
      " Reweight at state 2 = [1.127368  1.1043792 1.1992092 0.9016168], \n",
      "Policy update info: policy_loss = -0.677219, pop_loss = 0.110959, alpha = 0.549007, alpha_loss = -0.526379, policy_entropy = 1.377966, \n",
      "Secondary update info: dual_loss = -0.014558, g_loss = 0.000066, a_2_norm = 2.184845, b_2_norm = 1.862186, F_mu_min_eig = -0.001681, F_q_min_eig = -0.000850, dual_obj = 1.002533, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "↓↑#G\n",
      "Step 8000: TD-error = 0.0723, Q-error = 0.2917\n",
      " Q Values at state 2 = [-0.01151555 -0.07979279 -0.03313836 -0.07856226], \n",
      " Policy dist at state 2 = [[0.24811037 0.23522171 0.24329807 0.27336988]], \n",
      " Reweight at state 2 = [1.1202718 1.117304  1.1872766 0.8911423], \n",
      "Policy update info: policy_loss = -0.524771, pop_loss = 0.165146, alpha = 0.449710, alpha_loss = -0.693421, policy_entropy = 1.367804, \n",
      "Secondary update info: dual_loss = -0.039097, g_loss = 0.000045, a_2_norm = 2.623831, b_2_norm = 2.345867, F_mu_min_eig = -0.001522, F_q_min_eig = -0.000688, dual_obj = 1.000392, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "↓↑#G\n",
      "Step 10000: TD-error = 0.0735, Q-error = 0.4502\n",
      " Q Values at state 2 = [-0.01370227 -0.0611678  -0.02713603 -0.06101268], \n",
      " Policy dist at state 2 = [[0.24076414 0.22745696 0.24700682 0.28477207]], \n",
      " Reweight at state 2 = [1.1574657 1.1441613 1.2984346 0.8710343], \n",
      "Policy update info: policy_loss = -0.406396, pop_loss = 0.187646, alpha = 0.368458, alpha_loss = -0.866936, policy_entropy = 1.368386, \n",
      "Secondary update info: dual_loss = -0.018595, g_loss = 0.000052, a_2_norm = 3.015812, b_2_norm = 2.845588, F_mu_min_eig = -0.001263, F_q_min_eig = -0.000505, dual_obj = 0.997794, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 12000: TD-error = 0.0731, Q-error = 0.7394\n",
      " Q Values at state 2 = [-0.00601931 -0.08589998 -0.00095229 -0.06566235], \n",
      " Policy dist at state 2 = [[0.24150951 0.21457587 0.25014916 0.29376543]], \n",
      " Reweight at state 2 = [1.2233903 1.1053549 1.313463  0.8773817], \n",
      "Policy update info: policy_loss = -0.297502, pop_loss = 0.231181, alpha = 0.301980, alpha_loss = -1.027212, policy_entropy = 1.357945, \n",
      "Secondary update info: dual_loss = -0.024619, g_loss = 0.000022, a_2_norm = 3.247569, b_2_norm = 3.249876, F_mu_min_eig = -0.001080, F_q_min_eig = -0.000423, dual_obj = 1.004491, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 14000: TD-error = 0.0733, Q-error = 1.2099\n",
      " Q Values at state 2 = [ 0.04316439 -0.08592066  0.0230673  -0.06282678], \n",
      " Policy dist at state 2 = [[0.24965271 0.20636515 0.24707294 0.29690924]], \n",
      " Reweight at state 2 = [1.167137  1.1798804 1.3782631 0.8887168], \n",
      "Policy update info: policy_loss = -0.231607, pop_loss = 0.225321, alpha = 0.247627, alpha_loss = -1.165469, policy_entropy = 1.335023, \n",
      "Secondary update info: dual_loss = 0.004477, g_loss = 0.000027, a_2_norm = 3.435503, b_2_norm = 3.513789, F_mu_min_eig = -0.000959, F_q_min_eig = -0.000397, dual_obj = 1.007606, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 16000: TD-error = 0.0716, Q-error = 2.0155\n",
      " Q Values at state 2 = [ 0.05270653 -0.09580575  0.01838578 -0.04735515], \n",
      " Policy dist at state 2 = [[0.2557805  0.18642314 0.2465551  0.31124127]], \n",
      " Reweight at state 2 = [1.2205743 1.1813052 1.3191093 0.8966297], \n",
      "Policy update info: policy_loss = -0.190816, pop_loss = 0.245604, alpha = 0.203206, alpha_loss = -1.326393, policy_entropy = 1.332411, \n",
      "Secondary update info: dual_loss = 0.004629, g_loss = 0.000027, a_2_norm = 3.615943, b_2_norm = 3.766491, F_mu_min_eig = -0.000797, F_q_min_eig = -0.000328, dual_obj = 1.007152, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 18000: TD-error = 0.0711, Q-error = 3.0037\n",
      " Q Values at state 2 = [ 0.08686788 -0.09303267  0.03077451 -0.02457744], \n",
      " Policy dist at state 2 = [[0.26484543 0.17112282 0.2406053  0.32342649]], \n",
      " Reweight at state 2 = [1.238372  1.1711004 1.2649252 0.8854164], \n",
      "Policy update info: policy_loss = -0.130822, pop_loss = 0.243479, alpha = 0.166857, alpha_loss = -1.458166, policy_entropy = 1.314383, \n",
      "Secondary update info: dual_loss = 0.020463, g_loss = 0.000021, a_2_norm = 3.709278, b_2_norm = 3.925380, F_mu_min_eig = -0.000718, F_q_min_eig = -0.000284, dual_obj = 1.010420, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 20000: TD-error = 0.0715, Q-error = 4.7123\n",
      " Q Values at state 2 = [ 0.09349003 -0.10003158  0.05968314 -0.01916081], \n",
      " Policy dist at state 2 = [[0.26762533 0.14928503 0.24682476 0.33626485]], \n",
      " Reweight at state 2 = [1.2310163 1.1223621 1.2524254 0.9023896], \n",
      "Policy update info: policy_loss = -0.173180, pop_loss = 0.206887, alpha = 0.137179, alpha_loss = -1.527768, policy_entropy = 1.269124, \n",
      "Secondary update info: dual_loss = -0.013240, g_loss = 0.000017, a_2_norm = 3.746574, b_2_norm = 3.961258, F_mu_min_eig = -0.000661, F_q_min_eig = -0.000239, dual_obj = 1.010247, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 22000: TD-error = 0.0718, Q-error = 6.4557\n",
      " Q Values at state 2 = [ 0.16594592 -0.10263604  0.10409592  0.02502811], \n",
      " Policy dist at state 2 = [[0.283883   0.1314146  0.22903223 0.35567012]], \n",
      " Reweight at state 2 = [1.2181588  1.0917004  1.3261055  0.87999743], \n",
      "Policy update info: policy_loss = -0.143800, pop_loss = 0.241550, alpha = 0.112877, alpha_loss = -1.629131, policy_entropy = 1.246843, \n",
      "Secondary update info: dual_loss = -0.014536, g_loss = 0.000016, a_2_norm = 3.699224, b_2_norm = 4.013743, F_mu_min_eig = -0.000626, F_q_min_eig = -0.000357, dual_obj = 1.012366, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 24000: TD-error = 0.0732, Q-error = 8.7980\n",
      " Q Values at state 2 = [ 0.17650923 -0.10536849  0.12462192  0.05487415], \n",
      " Policy dist at state 2 = [[0.29292506 0.11048465 0.22854711 0.36804312]], \n",
      " Reweight at state 2 = [1.2560421 1.0735435 1.1930633 0.9502919], \n",
      "Policy update info: policy_loss = -0.193182, pop_loss = 0.215661, alpha = 0.092997, alpha_loss = -1.585309, policy_entropy = 1.167473, \n",
      "Secondary update info: dual_loss = -0.016748, g_loss = 0.000012, a_2_norm = 3.627347, b_2_norm = 4.027000, F_mu_min_eig = -0.000666, F_q_min_eig = -0.000250, dual_obj = 1.011971, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 26000: TD-error = 0.0777, Q-error = 11.5340\n",
      " Q Values at state 2 = [ 0.24506897 -0.10782644  0.15480153  0.10215309], \n",
      " Policy dist at state 2 = [[0.30496347 0.08719708 0.22191353 0.3859259 ]], \n",
      " Reweight at state 2 = [1.1990933  1.1094449  1.1945344  0.96005994], \n",
      "Policy update info: policy_loss = -0.263920, pop_loss = 0.189100, alpha = 0.076715, alpha_loss = -1.515876, policy_entropy = 1.090396, \n",
      "Secondary update info: dual_loss = 0.013947, g_loss = 0.000013, a_2_norm = 3.630518, b_2_norm = 4.124693, F_mu_min_eig = -0.000787, F_q_min_eig = -0.000299, dual_obj = 1.008874, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 28000: TD-error = 0.0888, Q-error = 13.4167\n",
      " Q Values at state 2 = [ 0.30458784 -0.08659786  0.2118012   0.15120685], \n",
      " Policy dist at state 2 = [[0.2944527  0.0757634  0.20227066 0.4275133 ]], \n",
      " Reweight at state 2 = [1.2535614 1.0784769 1.1674331 0.9694326], \n",
      "Policy update info: policy_loss = -0.214892, pop_loss = 0.242747, alpha = 0.063342, alpha_loss = -1.540776, policy_entropy = 1.058432, \n",
      "Secondary update info: dual_loss = -0.011220, g_loss = 0.000016, a_2_norm = 3.768024, b_2_norm = 4.311220, F_mu_min_eig = -0.000633, F_q_min_eig = -0.000283, dual_obj = 1.012612, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 30000: TD-error = 0.0958, Q-error = 14.5682\n",
      " Q Values at state 2 = [ 0.31944662 -0.05678731  0.21163239  0.21136737], \n",
      " Policy dist at state 2 = [[0.28419787 0.06538778 0.1819163  0.46849808]], \n",
      " Reweight at state 2 = [1.2358717 1.0663711 1.09913   1.0456576], \n",
      "Policy update info: policy_loss = -0.276000, pop_loss = 0.250391, alpha = 0.052309, alpha_loss = -1.592075, policy_entropy = 1.039595, \n",
      "Secondary update info: dual_loss = -0.008764, g_loss = 0.000015, a_2_norm = 3.856735, b_2_norm = 4.561563, F_mu_min_eig = -0.000620, F_q_min_eig = -0.000265, dual_obj = 1.009519, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 32000: TD-error = 0.0968, Q-error = 15.3612\n",
      " Q Values at state 2 = [ 0.3583842  -0.06383479  0.24983662  0.21784306], \n",
      " Policy dist at state 2 = [[0.27528057 0.05651838 0.16130526 0.5068958 ]], \n",
      " Reweight at state 2 = [1.2530767  1.0721618  1.0525509  0.99343854], \n",
      "Policy update info: policy_loss = -0.227964, pop_loss = 0.318857, alpha = 0.043212, alpha_loss = -1.550677, policy_entropy = 0.993603, \n",
      "Secondary update info: dual_loss = -0.047918, g_loss = 0.000011, a_2_norm = 3.987822, b_2_norm = 4.849348, F_mu_min_eig = -0.000663, F_q_min_eig = -0.000263, dual_obj = 1.013005, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 34000: TD-error = 0.1101, Q-error = 15.8773\n",
      " Q Values at state 2 = [ 0.46953046 -0.08049333  0.23111174  0.2798876 ], \n",
      " Policy dist at state 2 = [[0.26833707 0.04705541 0.1383635  0.546244  ]], \n",
      " Reweight at state 2 = [1.277939  0.9975698 0.9217186 1.0382216], \n",
      "Policy update info: policy_loss = -0.281645, pop_loss = 0.322500, alpha = 0.035728, alpha_loss = -1.497494, policy_entropy = 0.949464, \n",
      "Secondary update info: dual_loss = 0.010199, g_loss = 0.000010, a_2_norm = 4.156576, b_2_norm = 5.289792, F_mu_min_eig = -0.000751, F_q_min_eig = -0.000322, dual_obj = 1.012359, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 36000: TD-error = 0.1160, Q-error = 16.0573\n",
      " Q Values at state 2 = [ 0.44594717 -0.07867199  0.2997075   0.3146431 ], \n",
      " Policy dist at state 2 = [[0.25687572 0.03954525 0.12124586 0.58233315]], \n",
      " Reweight at state 2 = [1.2481482 1.0205773 0.9172254 1.0308964], \n",
      "Policy update info: policy_loss = -0.292475, pop_loss = 0.391360, alpha = 0.029564, alpha_loss = -1.387643, policy_entropy = 0.894093, \n",
      "Secondary update info: dual_loss = -0.063889, g_loss = 0.000031, a_2_norm = 4.394328, b_2_norm = 5.779626, F_mu_min_eig = -0.000851, F_q_min_eig = -0.000266, dual_obj = 1.012696, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 38000: TD-error = 0.1320, Q-error = 15.7852\n",
      " Q Values at state 2 = [ 0.5506476  -0.09226215  0.30460292  0.37559614], \n",
      " Policy dist at state 2 = [[0.25505847 0.03338952 0.10717499 0.6043771 ]], \n",
      " Reweight at state 2 = [1.3317937  1.0144665  0.77700377 1.0868939 ], \n",
      "Policy update info: policy_loss = -0.344151, pop_loss = 0.390232, alpha = 0.024415, alpha_loss = -1.401211, policy_entropy = 0.877433, \n",
      "Secondary update info: dual_loss = 0.038739, g_loss = 0.000017, a_2_norm = 4.650141, b_2_norm = 6.226761, F_mu_min_eig = -0.000951, F_q_min_eig = -0.000512, dual_obj = 1.011972, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 40000: TD-error = 0.1281, Q-error = 15.3529\n",
      " Q Values at state 2 = [ 0.5300673  -0.0310204   0.29102346  0.39168292], \n",
      " Policy dist at state 2 = [[0.24319774 0.03051227 0.0956242  0.6306658 ]], \n",
      " Reweight at state 2 = [1.3076551  0.98199815 0.7743329  1.0727555 ], \n",
      "Policy update info: policy_loss = -0.229398, pop_loss = 0.457020, alpha = 0.020135, alpha_loss = -1.395833, policy_entropy = 0.857431, \n",
      "Secondary update info: dual_loss = -0.012838, g_loss = 0.000017, a_2_norm = 4.883189, b_2_norm = 6.700201, F_mu_min_eig = -0.000999, F_q_min_eig = -0.000216, dual_obj = 1.024432, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 42000: TD-error = 0.1477, Q-error = 15.2610\n",
      " Q Values at state 2 = [ 0.6177785  -0.06187296  0.3151139   0.44817525], \n",
      " Policy dist at state 2 = [[0.24250312 0.02715197 0.08481669 0.6455283 ]], \n",
      " Reweight at state 2 = [1.290292   1.0027213  0.70532024 1.1255658 ], \n",
      "Policy update info: policy_loss = -0.394771, pop_loss = 0.600973, alpha = 0.016609, alpha_loss = -1.034201, policy_entropy = 0.752385, \n",
      "Secondary update info: dual_loss = 0.028082, g_loss = 0.000011, a_2_norm = 5.129105, b_2_norm = 7.170110, F_mu_min_eig = -0.001061, F_q_min_eig = -0.000194, dual_obj = 1.016301, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 44000: TD-error = 0.1522, Q-error = 14.7919\n",
      " Q Values at state 2 = [ 0.6236142  -0.05329931  0.3421739   0.4600379 ], \n",
      " Policy dist at state 2 = [[0.23633698 0.02549169 0.07955772 0.65861356]], \n",
      " Reweight at state 2 = [1.4070494  1.0135202  0.70799404 1.1580898 ], \n",
      "Policy update info: policy_loss = -0.215604, pop_loss = 0.735441, alpha = 0.013673, alpha_loss = -1.212486, policy_entropy = 0.782482, \n",
      "Secondary update info: dual_loss = -0.000660, g_loss = 0.000015, a_2_norm = 5.453345, b_2_norm = 7.719177, F_mu_min_eig = -0.001116, F_q_min_eig = -0.000350, dual_obj = 1.011493, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 46000: TD-error = 0.1612, Q-error = 14.2719\n",
      " Q Values at state 2 = [ 0.67768395 -0.07693553  0.3333525   0.4864654 ], \n",
      " Policy dist at state 2 = [[0.2404445  0.02403108 0.07134284 0.66418165]], \n",
      " Reweight at state 2 = [1.4298638  0.9430574  0.58644146 1.140635  ], \n",
      "Policy update info: policy_loss = -0.377389, pop_loss = 0.730701, alpha = 0.011241, alpha_loss = -1.489602, policy_entropy = 0.831903, \n",
      "Secondary update info: dual_loss = -0.024189, g_loss = 0.000019, a_2_norm = 5.783545, b_2_norm = 8.211271, F_mu_min_eig = -0.001143, F_q_min_eig = -0.000288, dual_obj = 1.020766, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 48000: TD-error = 0.1728, Q-error = 13.7869\n",
      " Q Values at state 2 = [ 0.741658   -0.09493017  0.37909538  0.53859985], \n",
      " Policy dist at state 2 = [[0.24299476 0.02193417 0.06553657 0.66953444]], \n",
      " Reweight at state 2 = [1.2604911  0.90675044 0.6951354  1.1163223 ], \n",
      "Policy update info: policy_loss = -0.338009, pop_loss = 0.835188, alpha = 0.009254, alpha_loss = -1.211265, policy_entropy = 0.758676, \n",
      "Secondary update info: dual_loss = -0.068103, g_loss = 0.000008, a_2_norm = 6.090206, b_2_norm = 8.697987, F_mu_min_eig = -0.001233, F_q_min_eig = -0.000297, dual_obj = 1.025528, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 50000: TD-error = 0.1763, Q-error = 13.2153\n",
      " Q Values at state 2 = [ 0.7619994 -0.0723815  0.340158   0.5568129], \n",
      " Policy dist at state 2 = [[0.24561517 0.02144261 0.06162507 0.6713172 ]], \n",
      " Reweight at state 2 = [1.3035475  0.98420435 0.56730336 1.1306393 ], \n",
      "Policy update info: policy_loss = -0.229209, pop_loss = 1.161189, alpha = 0.007595, alpha_loss = -1.137336, policy_entropy = 0.733052, \n",
      "Secondary update info: dual_loss = 0.004022, g_loss = 0.000011, a_2_norm = 6.543647, b_2_norm = 9.338019, F_mu_min_eig = -0.001247, F_q_min_eig = -0.000295, dual_obj = 1.023921, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 52000: TD-error = 0.1864, Q-error = 12.7674\n",
      " Q Values at state 2 = [ 0.79733765 -0.05089521  0.35316256  0.5816495 ], \n",
      " Policy dist at state 2 = [[0.22881922 0.0218262  0.05794232 0.69141227]], \n",
      " Reweight at state 2 = [1.4000727 0.8134816 0.6844901 1.1979002], \n",
      "Policy update info: policy_loss = -0.243147, pop_loss = 1.049856, alpha = 0.006239, alpha_loss = -1.300357, policy_entropy = 0.756137, \n",
      "Secondary update info: dual_loss = -0.152956, g_loss = 0.000008, a_2_norm = 6.916139, b_2_norm = 9.992876, F_mu_min_eig = -0.001291, F_q_min_eig = -0.000339, dual_obj = 1.044201, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 54000: TD-error = 0.1900, Q-error = 12.4448\n",
      " Q Values at state 2 = [ 0.7937789  -0.03768015  0.29799548  0.58473635], \n",
      " Policy dist at state 2 = [[0.21120706 0.02198656 0.05164276 0.71516365]], \n",
      " Reweight at state 2 = [1.2943939 0.91209   0.5766986 1.1182424], \n",
      "Policy update info: policy_loss = -0.150933, pop_loss = 1.032733, alpha = 0.005125, alpha_loss = -1.818067, policy_entropy = 0.844751, \n",
      "Secondary update info: dual_loss = 0.080392, g_loss = 0.000005, a_2_norm = 7.322421, b_2_norm = 10.572872, F_mu_min_eig = -0.001300, F_q_min_eig = -0.000701, dual_obj = 1.041657, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 56000: TD-error = 0.1849, Q-error = 11.8346\n",
      " Q Values at state 2 = [ 0.82047266 -0.04602957  0.32608747  0.5926752 ], \n",
      " Policy dist at state 2 = [[0.20643967 0.0246012  0.05436095 0.71459824]], \n",
      " Reweight at state 2 = [1.4266267 0.9117388 0.6751134 1.263648 ], \n",
      "Policy update info: policy_loss = -0.067604, pop_loss = 1.314583, alpha = 0.004188, alpha_loss = -1.931737, policy_entropy = 0.852797, \n",
      "Secondary update info: dual_loss = 0.109845, g_loss = 0.000011, a_2_norm = 7.720595, b_2_norm = 11.179291, F_mu_min_eig = -0.001270, F_q_min_eig = -0.000234, dual_obj = 1.086253, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 58000: TD-error = 0.1983, Q-error = 12.0999\n",
      " Q Values at state 2 = [ 0.86997133 -0.10233772  0.3485897   0.61549747], \n",
      " Policy dist at state 2 = [[0.22460817 0.02561457 0.05725133 0.6925259 ]], \n",
      " Reweight at state 2 = [1.3034847 0.8148291 0.6050161 1.1218373], \n",
      "Policy update info: policy_loss = -0.073434, pop_loss = 1.550645, alpha = 0.003423, alpha_loss = -1.415376, policy_entropy = 0.749316, \n",
      "Secondary update info: dual_loss = -0.070229, g_loss = 0.000005, a_2_norm = 8.036544, b_2_norm = 11.651432, F_mu_min_eig = -0.001347, F_q_min_eig = -0.000821, dual_obj = 1.075968, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 60000: TD-error = 0.1987, Q-error = 12.1885\n",
      " Q Values at state 2 = [ 0.83265966 -0.08133209  0.34113473  0.6433887 ], \n",
      " Policy dist at state 2 = [[0.22396669 0.02609765 0.05585787 0.69407773]], \n",
      " Reweight at state 2 = [1.4497777 1.0638493 0.49377   1.1433914], \n",
      "Policy update info: policy_loss = -0.052356, pop_loss = 1.468960, alpha = 0.002808, alpha_loss = -2.088706, policy_entropy = 0.855512, \n",
      "Secondary update info: dual_loss = -0.097374, g_loss = 0.000003, a_2_norm = 8.204698, b_2_norm = 12.006824, F_mu_min_eig = -0.001384, F_q_min_eig = -0.000451, dual_obj = 1.081726, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 62000: TD-error = 0.2032, Q-error = 11.7000\n",
      " Q Values at state 2 = [ 0.88356745 -0.0702554   0.3618372   0.65645456], \n",
      " Policy dist at state 2 = [[0.21715623 0.02893064 0.0572181  0.696695  ]], \n",
      " Reweight at state 2 = [1.6692141 0.8140859 0.556014  1.1743066], \n",
      "Policy update info: policy_loss = -0.080166, pop_loss = 1.514632, alpha = 0.002295, alpha_loss = -2.042166, policy_entropy = 0.836041, \n",
      "Secondary update info: dual_loss = -0.126709, g_loss = 0.000005, a_2_norm = 8.400396, b_2_norm = 12.365919, F_mu_min_eig = -0.001296, F_q_min_eig = -0.000206, dual_obj = 1.090089, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 64000: TD-error = 0.2141, Q-error = 11.3944\n",
      " Q Values at state 2 = [ 0.8646692  -0.09750092  0.44977915  0.66337377], \n",
      " Policy dist at state 2 = [[0.21607593 0.03035808 0.05486995 0.698696  ]], \n",
      " Reweight at state 2 = [1.4155244  0.7662257  0.57469165 1.2239615 ], \n",
      "Policy update info: policy_loss = -0.122364, pop_loss = 1.421267, alpha = 0.001871, alpha_loss = -2.002531, policy_entropy = 0.818805, \n",
      "Secondary update info: dual_loss = -0.069389, g_loss = 0.000004, a_2_norm = 8.641004, b_2_norm = 12.665238, F_mu_min_eig = -0.001241, F_q_min_eig = -0.001109, dual_obj = 1.089904, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 66000: TD-error = 0.2142, Q-error = 10.6845\n",
      " Q Values at state 2 = [ 0.86167634 -0.08822274  0.46202165  0.6421101 ], \n",
      " Policy dist at state 2 = [[0.22171731 0.03265879 0.05584293 0.689781  ]], \n",
      " Reweight at state 2 = [1.2831981  0.8455965  0.66444707 1.2347006 ], \n",
      "Policy update info: policy_loss = -0.027932, pop_loss = 1.621608, alpha = 0.001524, alpha_loss = -2.294648, policy_entropy = 0.853784, \n",
      "Secondary update info: dual_loss = 0.134615, g_loss = 0.000006, a_2_norm = 8.752163, b_2_norm = 12.855207, F_mu_min_eig = -0.001131, F_q_min_eig = -0.000499, dual_obj = 1.090128, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 68000: TD-error = 0.2265, Q-error = 9.5801\n",
      " Q Values at state 2 = [ 0.8531048  -0.08095241  0.517396    0.64677405], \n",
      " Policy dist at state 2 = [[0.22073774 0.03340329 0.05854293 0.687316  ]], \n",
      " Reweight at state 2 = [1.5454509 0.6710935 0.6010828 1.2032874], \n",
      "Policy update info: policy_loss = -0.070217, pop_loss = 1.419916, alpha = 0.001241, alpha_loss = -2.471275, policy_entropy = 0.869281, \n",
      "Secondary update info: dual_loss = 0.135972, g_loss = 0.000004, a_2_norm = 8.868203, b_2_norm = 12.901651, F_mu_min_eig = -0.001001, F_q_min_eig = -0.000642, dual_obj = 1.090641, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 70000: TD-error = 0.2297, Q-error = 8.8739\n",
      " Q Values at state 2 = [ 0.82771194 -0.09152472  0.5539625   0.67384744], \n",
      " Policy dist at state 2 = [[0.22148596 0.03339825 0.06232518 0.68279064]], \n",
      " Reweight at state 2 = [1.3328035  0.78872454 0.52446675 1.1983393 ], \n",
      "Policy update info: policy_loss = -0.156835, pop_loss = 1.373601, alpha = 0.001011, alpha_loss = -3.067446, policy_entropy = 0.944792, \n",
      "Secondary update info: dual_loss = 0.254934, g_loss = 0.000011, a_2_norm = 8.986444, b_2_norm = 13.032495, F_mu_min_eig = -0.000933, F_q_min_eig = -0.000710, dual_obj = 1.076564, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 72000: TD-error = 0.2246, Q-error = 8.7110\n",
      " Q Values at state 2 = [ 0.8487493  -0.08030927  0.54597664  0.69226336], \n",
      " Policy dist at state 2 = [[0.21300563 0.03099911 0.06002721 0.69596803]], \n",
      " Reweight at state 2 = [1.223648  1.0321392 0.5643264 1.2997736], \n",
      "Policy update info: policy_loss = -0.054568, pop_loss = 1.594860, alpha = 0.000827, alpha_loss = -2.656822, policy_entropy = 0.874303, \n",
      "Secondary update info: dual_loss = 0.053056, g_loss = 0.000003, a_2_norm = 9.115547, b_2_norm = 13.217407, F_mu_min_eig = -0.000992, F_q_min_eig = -0.000327, dual_obj = 1.089916, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 74000: TD-error = 0.2144, Q-error = 8.6676\n",
      " Q Values at state 2 = [ 0.8364893  -0.02423906  0.48813874  0.689189  ], \n",
      " Policy dist at state 2 = [[0.2049981  0.03079132 0.05486352 0.709347  ]], \n",
      " Reweight at state 2 = [1.4237939 0.9780152 0.5470658 1.2231727], \n",
      "Policy update info: policy_loss = -0.141936, pop_loss = 1.419919, alpha = 0.000677, alpha_loss = -2.961924, policy_entropy = 0.905901, \n",
      "Secondary update info: dual_loss = 0.093963, g_loss = 0.000002, a_2_norm = 9.147700, b_2_norm = 13.411889, F_mu_min_eig = -0.001049, F_q_min_eig = -0.000545, dual_obj = 1.112490, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 76000: TD-error = 0.2129, Q-error = 8.7282\n",
      " Q Values at state 2 = [ 0.867857   -0.04793143  0.5304866   0.65651685], \n",
      " Policy dist at state 2 = [[0.21719489 0.03149974 0.05363519 0.69767016]], \n",
      " Reweight at state 2 = [1.1918765 0.8498845 0.6373221 1.0655813], \n",
      "Policy update info: policy_loss = -0.052149, pop_loss = 1.453468, alpha = 0.000555, alpha_loss = -3.159130, policy_entropy = 0.921394, \n",
      "Secondary update info: dual_loss = 0.171942, g_loss = 0.000008, a_2_norm = 9.226649, b_2_norm = 13.631821, F_mu_min_eig = -0.001133, F_q_min_eig = -0.000616, dual_obj = 1.107949, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 78000: TD-error = 0.2053, Q-error = 8.8224\n",
      " Q Values at state 2 = [ 0.831921   -0.09170794  0.4542863   0.66083574], \n",
      " Policy dist at state 2 = [[0.21592453 0.03231713 0.05020732 0.7015511 ]], \n",
      " Reweight at state 2 = [1.4968156  1.1079733  0.48797745 1.2314353 ], \n",
      "Policy update info: policy_loss = -0.043311, pop_loss = 1.410882, alpha = 0.000454, alpha_loss = -2.837229, policy_entropy = 0.868652, \n",
      "Secondary update info: dual_loss = 0.173377, g_loss = 0.000006, a_2_norm = 9.370369, b_2_norm = 13.966276, F_mu_min_eig = -0.001146, F_q_min_eig = -0.000966, dual_obj = 1.099044, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 80000: TD-error = 0.2105, Q-error = 9.0694\n",
      " Q Values at state 2 = [ 0.8942325  -0.09859502  0.46869314  0.65466344], \n",
      " Policy dist at state 2 = [[0.21539615 0.03190995 0.04608615 0.70660776]], \n",
      " Reweight at state 2 = [1.2563653 0.8870447 0.7433867 1.1981547], \n",
      "Policy update info: policy_loss = -0.041845, pop_loss = 1.659818, alpha = 0.000373, alpha_loss = -3.205382, policy_entropy = 0.906113, \n",
      "Secondary update info: dual_loss = 0.013532, g_loss = 0.000005, a_2_norm = 9.537712, b_2_norm = 14.322846, F_mu_min_eig = -0.001213, F_q_min_eig = -0.000547, dual_obj = 1.094366, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 82000: TD-error = 0.2013, Q-error = 9.0967\n",
      " Q Values at state 2 = [ 0.84656906 -0.04998672  0.4698454   0.6470071 ], \n",
      " Policy dist at state 2 = [[0.23179203 0.03195555 0.04581497 0.69043744]], \n",
      " Reweight at state 2 = [1.4051083 0.9537449 0.7731819 1.1948068], \n",
      "Policy update info: policy_loss = -0.059078, pop_loss = 1.635543, alpha = 0.000306, alpha_loss = -2.872888, policy_entropy = 0.855045, \n",
      "Secondary update info: dual_loss = -0.106693, g_loss = 0.000007, a_2_norm = 9.685993, b_2_norm = 14.584094, F_mu_min_eig = -0.001254, F_q_min_eig = -0.000783, dual_obj = 1.113115, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 84000: TD-error = 0.2046, Q-error = 8.8924\n",
      " Q Values at state 2 = [ 0.8481185  -0.05331481  0.43401378  0.64635795], \n",
      " Policy dist at state 2 = [[0.22375189 0.03192384 0.0464493  0.69787496]], \n",
      " Reweight at state 2 = [1.4554404 0.6298012 0.7295616 1.2858144], \n",
      "Policy update info: policy_loss = 0.078311, pop_loss = 1.695038, alpha = 0.000251, alpha_loss = -3.086019, policy_entropy = 0.872210, \n",
      "Secondary update info: dual_loss = 0.118022, g_loss = 0.000007, a_2_norm = 9.831852, b_2_norm = 14.844179, F_mu_min_eig = -0.001253, F_q_min_eig = -0.000440, dual_obj = 1.151354, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 86000: TD-error = 0.2076, Q-error = 8.9822\n",
      " Q Values at state 2 = [ 0.8204638  -0.1135236   0.5052556   0.66945904], \n",
      " Policy dist at state 2 = [[0.22102743 0.03242769 0.04887155 0.69767326]], \n",
      " Reweight at state 2 = [1.7317606  0.76252794 0.46679115 1.2780502 ], \n",
      "Policy update info: policy_loss = -0.035734, pop_loss = 1.541684, alpha = 0.000206, alpha_loss = -3.212658, policy_entropy = 0.878421, \n",
      "Secondary update info: dual_loss = -0.155190, g_loss = 0.000005, a_2_norm = 9.991911, b_2_norm = 15.196174, F_mu_min_eig = -0.001273, F_q_min_eig = -0.000366, dual_obj = 1.146948, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "Step 88000: TD-error = 0.2019, Q-error = 9.1152\n",
      " Q Values at state 2 = [ 0.87901247 -0.15167427  0.42367652  0.65308887], \n",
      " Policy dist at state 2 = [[0.22525787 0.03351393 0.04825816 0.69297004]], \n",
      " Reweight at state 2 = [1.5337023  0.87395275 0.5419064  1.292882  ], \n",
      "Policy update info: policy_loss = -0.086969, pop_loss = 1.595341, alpha = 0.000169, alpha_loss = -3.264850, policy_entropy = 0.875778, \n",
      "Secondary update info: dual_loss = -0.535831, g_loss = 0.000006, a_2_norm = 10.231095, b_2_norm = 15.577420, F_mu_min_eig = -0.001321, F_q_min_eig = -0.001421, dual_obj = 1.183263, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→→↓\n",
      "→↑#G\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[13], line 30\u001B[0m\n\u001B[1;32m     28\u001B[0m \u001B[38;5;66;03m# pop_q = POPQ(env.observation_space, env.action_space, Phi_torch, g_model, env.get_terminal_matrix().flatten(), rank=4, q_lr=1e-3, dual_lr=1e-2, pop_margin=-1e-3)\u001B[39;00m\n\u001B[1;32m     29\u001B[0m pop_q \u001B[38;5;241m=\u001B[39m POPQ(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, g_model, env\u001B[38;5;241m.\u001B[39mget_terminal_matrix()\u001B[38;5;241m.\u001B[39mflatten(), rank\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m4\u001B[39m, q_lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-3\u001B[39m, dual_lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-3\u001B[39m, pop_gamma\u001B[38;5;241m=\u001B[39mnp\u001B[38;5;241m.\u001B[39msqrt(\u001B[38;5;241m0.99\u001B[39m))\n\u001B[0;32m---> 30\u001B[0m log \u001B[38;5;241m=\u001B[39m \u001B[43mtrain\u001B[49m\u001B[43m(\u001B[49m\u001B[43mpop_q\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpolicy\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43moff_policy_dataset\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43miters\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m2e5\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m32\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mprogress_bar\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlog_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m2e3\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mpolicy_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msecondary_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msecondary_update_iters\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m     32\u001B[0m ax \u001B[38;5;241m=\u001B[39m plt\u001B[38;5;241m.\u001B[39mgca()\n\u001B[1;32m     33\u001B[0m ax\u001B[38;5;241m.\u001B[39mset_yscale(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlog\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
      "Cell \u001B[0;32mIn[11], line 91\u001B[0m, in \u001B[0;36mtrain\u001B[0;34m(q_model, policy, dataset, iters, batch_size, gamma, progress_bar, log_freq, dataset_update_freq, policy_update_freq, policy_update_iters, secondary_update_freq, secondary_update_iters, bc, device)\u001B[0m\n\u001B[1;32m     89\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m policy_update_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m1\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m (step \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m%\u001B[39m policy_update_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m     90\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m step2 \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(policy_update_iters):\n\u001B[0;32m---> 91\u001B[0m         policy_update_info \u001B[38;5;241m=\u001B[39m \u001B[43mpolicy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mupdate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mq_model\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43ms\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43ma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msp\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdone\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgamma\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgamma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbc\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mbc\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     93\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m step \u001B[38;5;241m%\u001B[39m log_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m     94\u001B[0m     log[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mstep\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;241m.\u001B[39mappend(step)\n",
      "File \u001B[0;32m~/projects/POP-QL/small_scale/policies.py:177\u001B[0m, in \u001B[0;36mPOPLinearPolicy.update\u001B[0;34m(self, q_model, state, action, reward, state_prime, done, gamma, bc)\u001B[0m\n\u001B[1;32m    174\u001B[0m x \u001B[38;5;241m=\u001B[39m state \u001B[38;5;241m*\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnum_actions \u001B[38;5;241m+\u001B[39m action\n\u001B[1;32m    175\u001B[0m reweight \u001B[38;5;241m=\u001B[39m q_model\u001B[38;5;241m.\u001B[39mreweight[x]\n\u001B[0;32m--> 177\u001B[0m alpha, alpha_loss, alpha_info \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mupdate_alpha\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mforward\u001B[49m\u001B[43m(\u001B[49m\u001B[43mstate\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    178\u001B[0m beta, beta_loss, beta_info \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mupdate_beta(reweight)\n\u001B[1;32m    180\u001B[0m policy_loss \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_policy_loss(q_model, state, action, reward, state_prime, done, gamma\u001B[38;5;241m=\u001B[39mgamma, alpha\u001B[38;5;241m=\u001B[39malpha)\n",
      "File \u001B[0;32m~/projects/POP-QL/small_scale/policies.py:116\u001B[0m, in \u001B[0;36mLinearPolicy.update_alpha\u001B[0;34m(self, policy_dist)\u001B[0m\n\u001B[1;32m    110\u001B[0m     \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39malpha_optim\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m    112\u001B[0m     alpha \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mexp(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlog_alpha) \u001B[38;5;241m*\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39malpha_multiplier\n\u001B[1;32m    113\u001B[0m     info \u001B[38;5;241m=\u001B[39m {\n\u001B[1;32m    114\u001B[0m         \u001B[38;5;124m'\u001B[39m\u001B[38;5;124malpha\u001B[39m\u001B[38;5;124m'\u001B[39m: alpha\u001B[38;5;241m.\u001B[39mdetach()\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mnumpy(),\n\u001B[1;32m    115\u001B[0m         \u001B[38;5;124m'\u001B[39m\u001B[38;5;124malpha_loss\u001B[39m\u001B[38;5;124m'\u001B[39m: alpha_loss\u001B[38;5;241m.\u001B[39mdetach()\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mnumpy(),\n\u001B[0;32m--> 116\u001B[0m         \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mpolicy_entropy\u001B[39m\u001B[38;5;124m'\u001B[39m: \u001B[43mpolicy_entropy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mmean\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdetach\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcpu\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241m.\u001B[39mnumpy(),\n\u001B[1;32m    117\u001B[0m     }\n\u001B[1;32m    118\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m    119\u001B[0m     alpha \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39malpha_multiplier\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "print(k)\n",
    "\n",
    "torch.manual_seed(seed + 123)\n",
    "Phi_torch = torch.tensor(Phi, device=default_train_device, dtype=torch.float32)\n",
    "# policy = POPLinearPolicy(\n",
    "#     env.observation_space, env.action_space, Phi_torch, lr=1e-3,\n",
    "#     use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-2,\n",
    "#     use_automatic_kl_tuning=True, target_kl=1.5e0, beta_lr=1e-2, beta_multiplier=1e0,\n",
    "# )\n",
    "# policy = POPLinearPolicy(\n",
    "#     env.observation_space, env.action_space, Phi_torch, lr=1e-3,\n",
    "#     use_automatic_entropy_tuning=True, target_entropy=0.25, alpha_lr=1e-3,\n",
    "#     # use_automatic_kl_tuning=False, beta_multiplier=1e-2,\n",
    "#     use_automatic_kl_tuning=False, beta_multiplier=1e-1,\n",
    "# )\n",
    "# g_model = LinearGModel(env.observation_space, env.action_space, Phi_torch, lr=2e-3)\n",
    "# pop_q = POPQ(env.observation_space, env.action_space, Phi_torch, g_model, env.get_terminal_matrix().flatten(), rank=6, q_lr=2e-2, dual_lr=2e-3)\n",
    "# pop_q = POPQ(env.observation_space, env.action_space, Phi_torch, g_model, env.get_terminal_matrix().flatten(), rank=6, q_lr=1e-3, dual_lr=2e-3)\n",
    "# log = train(pop_q, policy, off_policy_dataset, iters=int(2e5), batch_size=32, progress_bar=True, log_freq=int(2e3), policy_update_freq=int(1e1), secondary_update_freq=int(1e1), secondary_update_iters=int(2e1))\n",
    "\n",
    "policy = POPLinearPolicy(\n",
    "    env.observation_space, env.action_space, Phi_torch, lr=1e-4,\n",
    "    use_automatic_entropy_tuning=True, target_entropy=0.5, alpha_lr=1e-4,\n",
    "    # use_automatic_kl_tuning=False, beta_multiplier=1e-2,\n",
    "    use_automatic_kl_tuning=False, beta_multiplier=5e-1,\n",
    ")\n",
    "g_model = LinearGModel(env.observation_space, env.action_space, Phi_torch, lr=1e-4)\n",
    "# pop_q = POPQ(env.observation_space, env.action_space, Phi_torch, g_model, env.get_terminal_matrix().flatten(), rank=4, q_lr=1e-3, dual_lr=1e-2, pop_margin=-1e-3)\n",
    "pop_q = POPQ(env.observation_space, env.action_space, Phi_torch, g_model, env.get_terminal_matrix().flatten(), rank=4, q_lr=1e-3, dual_lr=1e-3, pop_gamma=np.sqrt(0.99))\n",
    "log = train(pop_q, policy, off_policy_dataset, iters=int(2e5), batch_size=32, progress_bar=True, log_freq=int(2e3), policy_update_freq=1, secondary_update_freq=1, secondary_update_iters=1)\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.set_yscale('log')\n",
    "plt.plot(log['step'], log['TD-error'], label='TD-error')\n",
    "plt.xlabel('step')\n",
    "plt.ylabel('TD-error')\n",
    "plt.show()\n",
    "\n",
    "with np.printoptions(precision=2, suppress=True):\n",
    "    print(\"POP-Q with LINEAR g-model\")\n",
    "    print(\"max(Q(s,*))\")\n",
    "    print(get_q_table(pop_q).max(-1).reshape(env.map.shape))\n",
    "    print()\n",
    "\n",
    "    print(\"Q-error\")\n",
    "    print(compute_Q_error(env, off_policy_dataset, opt_q_table, get_q_table(pop_q)))\n",
    "    print()\n",
    "\n",
    "    print(\"Effective Policy\")\n",
    "    print(env.render_policy(policy)) "
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T21:31:28.791750452Z",
     "start_time": "2023-09-28T20:28:23.695369655Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "data": {
      "text/plain": "0.6664610185699616"
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "performance_env = frozen_lake_env_from_string(\n",
    "    frzmap, loop=False, slippery=0.1\n",
    ")\n",
    "compute_performance(performance_env, policy)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T00:00:01.006597771Z",
     "start_time": "2023-09-28T00:00:00.772493658Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [],
   "source": [
    "x = np.arange(num_states * num_actions)\n",
    "done = torch.tensor(env.get_terminal_matrix().flatten(), device=pop_q.Phi.device, dtype=torch.float32)\n",
    "        \n",
    "a_2_norm = torch.svd(pop_q.a, compute_uv=False)[1].max().detach()\n",
    "b_2_norm = torch.svd(pop_q.b, compute_uv=False)[1].max().detach()\n",
    "m_a = pop_q.a_mag * (pop_q.Phi @ (pop_q.a / a_2_norm))[x]\n",
    "m_b = pop_q.b_mag * (pop_q.Phi @ (pop_q.b / b_2_norm))[x]\n",
    "P_policy = torch.tensor(get_P_policy(policy, env.get_transition_matrix()), device=pop_q.Phi.device, dtype=torch.float32)\n",
    "\n",
    "g_tilde = (m_b * (P_policy @ m_a)).sum(-1)\n",
    "reweight = torch.exp((m_a ** 2 + m_b ** 2).sum(-1) + 2 * (1 - done.float()) * g_tilde)\n",
    "obj = (reweight * mu).sum()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-26T16:57:37.558896337Z",
     "start_time": "2023-09-26T16:57:37.496180773Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor([1.0877, 0.9170, 0.8926, 0.8973, 0.8218, 0.9053, 0.9884, 0.9753, 0.9331,\n        1.0443, 1.0141, 1.0554, 0.8251, 1.0759, 1.0887, 0.9511, 0.8020, 0.9934,\n        0.9885, 1.0744, 0.9794, 1.3938, 0.8704, 0.6938, 0.7966, 0.8904, 1.8461,\n        0.8603, 0.9515, 1.7645, 0.8010, 1.6593, 0.6643, 0.8631, 1.1014, 0.7621,\n        0.8771, 0.7823, 0.9191, 0.7502, 0.8950, 1.1594, 1.0481, 0.7519, 1.0616,\n        1.4225, 1.0945, 1.3327, 0.9833, 1.1731, 1.1797, 0.8146, 1.0152, 0.9453,\n        0.9758, 0.7439, 1.0035, 0.8624, 1.8142, 0.8493, 1.3365, 0.9011, 0.8165,\n        0.7996], device='cuda:0', grad_fn=<ExpBackward0>)"
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reweight"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-26T16:57:43.352922837Z",
     "start_time": "2023-09-26T16:57:43.313735532Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "reweight = pop_q.reweight.detach().cpu().numpy()\n",
    "reweight /= np.sum(reweight * mu)\n",
    "heatmap(reweight.reshape((env.observation_space.n, env.action_space.n))[np.arange(env.observation_space.n), policy.dist(np.arange(num_states)).argmax(-1)])"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "## Create histogram of the reweighting terms\n",
    "\n",
    "def histogram(reweight):\n",
    "    plt.hist(reweight, bins=100)\n",
    "    plt.xlabel('reweight')\n",
    "    plt.ylabel('count')\n",
    "    plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "histogram(reweight)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Conservative Q-Learning"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "Training offline:   0%|          | 0/200000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "6e36dfa66f664b7b965dc31bb69af00d"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0: TD-error = 0.0677, Q-error = 0.0579\n",
      " Q Values at state 11 = [-0.01566242  0.13557395  0.07399583  0.04026781], \n",
      " Policy dist at state 11 = [[0.23600623 0.28342348 0.24792464 0.23264566]], \n",
      "Policy update info: policy_loss = -1.390535, alpha = 0.999000, alpha_loss = 0.000000, policy_entropy = 1.384533, \n",
      "↓←←S\n",
      "→#↓#\n",
      "↓→←→\n",
      "↑→#G\n",
      "Step 2000: TD-error = 0.0936, Q-error = 3.5192\n",
      " Q Values at state 11 = [-0.03732778  0.5597771   0.9635792   0.26930273], \n",
      " Policy dist at state 11 = [[0.00951057 0.11678604 0.85458916 0.01911431]], \n",
      "Policy update info: policy_loss = -0.279507, alpha = 0.142228, alpha_loss = -1.430936, policy_entropy = 1.234033, \n",
      "→↑←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 4000: TD-error = 0.2220, Q-error = 0.4828\n",
      " Q Values at state 11 = [0.4253315  0.9167613  1.2400932  0.80030906], \n",
      " Policy dist at state 11 = [[0.00100589 0.05261168 0.94529516 0.00108729]], \n",
      "Policy update info: policy_loss = -0.548288, alpha = 0.051288, alpha_loss = 0.088490, policy_entropy = 0.470210, \n",
      "→↑←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 6000: TD-error = 0.4673, Q-error = 0.3398\n",
      " Q Values at state 11 = [1.0859752 1.3264542 1.8786938 1.487664 ], \n",
      " Policy dist at state 11 = [[5.6151900e-04 6.4419776e-02 9.3470466e-01 3.1408851e-04]], \n",
      "Policy update info: policy_loss = -1.025427, alpha = 0.141295, alpha_loss = 0.225551, policy_entropy = 0.384770, \n",
      "→↑←S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n",
      "Step 8000: TD-error = 0.7591, Q-error = 0.9064\n",
      " Q Values at state 11 = [1.4601432 1.5605959 2.2656343 2.0477958], \n",
      " Policy dist at state 11 = [[1.0410799e-03 6.6126019e-02 9.3261129e-01 2.2162794e-04]], \n",
      "Policy update info: policy_loss = -1.540367, alpha = 0.127964, alpha_loss = 0.123566, policy_entropy = 0.439911, \n",
      "→↑←S\n",
      "↑#↓#\n",
      "→→→↓\n",
      "→→#G\n",
      "Step 10000: TD-error = 1.1204, Q-error = 1.6572\n",
      " Q Values at state 11 = [1.8882686 1.7871032 2.7202716 2.5462937], \n",
      " Policy dist at state 11 = [[5.2407553e-04 3.8664512e-02 9.6068341e-01 1.2799983e-04]], \n",
      "Policy update info: policy_loss = -1.815752, alpha = 0.175946, alpha_loss = 0.065130, policy_entropy = 0.462516, \n",
      "→↑←S\n",
      "↑#↓#\n",
      "↓→→↓\n",
      "→→#G\n",
      "Step 12000: TD-error = 1.5093, Q-error = 2.7814\n",
      " Q Values at state 11 = [2.350888  2.0771198 3.1191454 2.9925566], \n",
      " Policy dist at state 11 = [[3.0788477e-04 2.1431068e-02 9.7818106e-01 7.9961617e-05]], \n",
      "Policy update info: policy_loss = -2.275506, alpha = 0.222625, alpha_loss = 0.092290, policy_entropy = 0.438571, \n",
      "→↑←S\n",
      "↑#↑#\n",
      "↓→→↓\n",
      "→→#G\n",
      "Step 14000: TD-error = 1.9701, Q-error = 4.0294\n",
      " Q Values at state 11 = [2.7952695 2.2786393 3.518561  3.4079356], \n",
      " Policy dist at state 11 = [[2.0981079e-04 1.5113403e-02 9.8461539e-01 6.1299092e-05]], \n",
      "Policy update info: policy_loss = -2.607279, alpha = 0.261708, alpha_loss = 0.020092, policy_entropy = 0.485013, \n",
      "→↑←S\n",
      "↑#↑#\n",
      "↓→→↓\n",
      "→→#G\n",
      "Step 16000: TD-error = 2.5770, Q-error = 5.4924\n",
      " Q Values at state 11 = [3.3059156 2.4940162 3.9003859 3.861562 ], \n",
      " Policy dist at state 11 = [[2.4228788e-04 1.3776648e-02 9.8589176e-01 8.9239380e-05]], \n",
      "Policy update info: policy_loss = -2.998998, alpha = 0.291478, alpha_loss = -0.042167, policy_entropy = 0.534206, \n",
      "→↑←S\n",
      "↑#↑#\n",
      "↓↓→↓\n",
      "→→#G\n",
      "Step 18000: TD-error = 3.2486, Q-error = 7.1185\n",
      " Q Values at state 11 = [3.7300658 2.7286344 4.2953987 4.2540984], \n",
      " Policy dist at state 11 = [[2.8582892e-04 1.1555706e-02 9.8803210e-01 1.2629929e-04]], \n",
      "Policy update info: policy_loss = -3.618083, alpha = 0.322575, alpha_loss = -0.001027, policy_entropy = 0.500908, \n",
      "→↑←S\n",
      "↑#↑#\n",
      "↓↓→↓\n",
      "→→#G\n",
      "Step 20000: TD-error = 3.9695, Q-error = 8.8789\n",
      " Q Values at state 11 = [4.1296535 3.0841238 4.697255  4.7172384], \n",
      " Policy dist at state 11 = [[7.5173296e-04 1.4840758e-02 9.8404974e-01 3.5771696e-04]], \n",
      "Policy update info: policy_loss = -3.718221, alpha = 0.353821, alpha_loss = 0.036891, policy_entropy = 0.464503, \n",
      "→↑←S\n",
      "↑#↑#\n",
      "↓↑→↓\n",
      "→→#G\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[14], line 6\u001B[0m\n\u001B[1;32m      4\u001B[0m linear_q \u001B[38;5;241m=\u001B[39m CQL(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, learning_rate\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-3\u001B[39m, alpha_prime\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-2\u001B[39m)\n\u001B[1;32m      5\u001B[0m policy \u001B[38;5;241m=\u001B[39m LinearPolicy(env\u001B[38;5;241m.\u001B[39mobservation_space, env\u001B[38;5;241m.\u001B[39maction_space, Phi_torch, lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-3\u001B[39m, use_automatic_entropy_tuning\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m, target_entropy\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0.5\u001B[39m, alpha_lr\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1e-3\u001B[39m)\n\u001B[0;32m----> 6\u001B[0m log \u001B[38;5;241m=\u001B[39m \u001B[43mtrain\u001B[49m\u001B[43m(\u001B[49m\u001B[43mlinear_q\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpolicy\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43moff_policy_dataset\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43miters\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m2e5\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m32\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mprogress_bar\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlog_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m2e3\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
\u001B[49m\u001B[43mpolicy_update_freq\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m      8\u001B[0m ax \u001B[38;5;241m=\u001B[39m plt\u001B[38;5;241m.\u001B[39mgca()\n\u001B[1;32m      9\u001B[0m ax\u001B[38;5;241m.\u001B[39mset_yscale(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlog\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
      "Cell \u001B[0;32mIn[9], line 89\u001B[0m, in \u001B[0;36mtrain\u001B[0;34m(q_model, policy, dataset, iters, batch_size, gamma, progress_bar, log_freq, dataset_update_freq, policy_update_freq, policy_update_iters, secondary_update_freq, secondary_update_iters, bc, device)\u001B[0m\n\u001B[1;32m     87\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m policy_update_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m1\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m (step \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m%\u001B[39m policy_update_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m     88\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m step2 \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(policy_update_iters):\n\u001B[0;32m---> 89\u001B[0m         policy_update_info \u001B[38;5;241m=\u001B[39m \u001B[43mpolicy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mupdate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mq_model\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43ms\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43ma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msp\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdone\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgamma\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mgamma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mbc\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mbc\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     91\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m step \u001B[38;5;241m%\u001B[39m log_freq \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m     92\u001B[0m     log[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mstep\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;241m.\u001B[39mappend(step)\n",
      "File \u001B[0;32m~/projects/POP-QL/small_scale/policies.py:130\u001B[0m, in \u001B[0;36mLinearPolicy.update\u001B[0;34m(self, q_model, state, action, reward, state_prime, done, gamma, bc)\u001B[0m\n\u001B[1;32m    128\u001B[0m policy_loss \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_policy_loss(q_model, state, action, reward, state_prime, done, alpha\u001B[38;5;241m=\u001B[39malpha, gamma\u001B[38;5;241m=\u001B[39mgamma, bc\u001B[38;5;241m=\u001B[39mbc)\n\u001B[1;32m    129\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moptim\u001B[38;5;241m.\u001B[39mzero_grad()\n\u001B[0;32m--> 130\u001B[0m \u001B[43mpolicy_loss\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    131\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moptim\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m    133\u001B[0m info \u001B[38;5;241m=\u001B[39m {\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mpolicy_loss\u001B[39m\u001B[38;5;124m'\u001B[39m: policy_loss\u001B[38;5;241m.\u001B[39mdetach()\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mnumpy()}\n",
      "File \u001B[0;32m~/micromamba/envs/JaxCQL/lib/python3.9/site-packages/torch/_tensor.py:487\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m    477\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m    478\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m    479\u001B[0m         Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m    480\u001B[0m         (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m    485\u001B[0m         inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m    486\u001B[0m     )\n\u001B[0;32m--> 487\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    488\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m    489\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/micromamba/envs/JaxCQL/lib/python3.9/site-packages/torch/autograd/__init__.py:200\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m    195\u001B[0m     retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m    197\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m    198\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m    199\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 200\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m  \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m    201\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    202\u001B[0m \u001B[43m    \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "# Offline CQL (linear Q-function + linear policy) trained on the off-policy dataset.\n",
    "CQL_ALPHA_PRIME = 1e-2  # value passed to CQL as `alpha_prime`; 5e0 was also tried\n",
    "\n",
    "torch.manual_seed(seed + 123)\n",
    "Phi_torch = torch.tensor(Phi, device=default_train_device, dtype=torch.float32)\n",
    "linear_q = CQL(env.observation_space, env.action_space, Phi_torch, learning_rate=1e-3, alpha_prime=CQL_ALPHA_PRIME)\n",
    "policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-3, use_automatic_entropy_tuning=True, target_entropy=0.5, alpha_lr=1e-3)\n",
    "log = train(linear_q, policy, off_policy_dataset, iters=int(2e5), batch_size=32, progress_bar=True, log_freq=int(2e3), policy_update_freq=1)\n",
    "\n",
    "# Explicit figure/axes instead of the pyplot state machine (plt.gca()).\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(log['step'], log['TD-error'])\n",
    "ax.set_yscale('log')\n",
    "ax.set(title='Linear CQL, off-policy data', xlabel='step', ylabel='TD-error')\n",
    "plt.show()\n",
    "\n",
    "with np.printoptions(precision=2, suppress=True):\n",
    "    print(\"LINEAR OFF-POLICY\")\n",
    "    print(\"max(Q(s,*))\")\n",
    "    print(get_q_table(linear_q).max(-1).reshape(env.map.shape))\n",
    "    print()\n",
    "\n",
    "    print(\"Q-error\")\n",
    "    print(compute_Q_error(env, off_policy_dataset, opt_q_table, get_q_table(linear_q)))\n",
    "    print()\n",
    "\n",
    "    print('F min eigenvalue')\n",
    "    F = np.asarray(compute_expected_F(mu, policy, env.get_transition_matrix(), Phi))\n",
    "    # PSD-ness (x^T F x >= 0 for all x) depends only on the symmetric part of F, whose\n",
    "    # eigenvalues are real. For an already-symmetric F this matches the eigvals-based check.\n",
    "    print(np.linalg.eigvalsh(0.5 * (F + F.T)).min())\n",
    "\n",
    "    print(\"Effective Policy\")\n",
    "    print(env.render_policy(policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-17T03:40:13.592733747Z",
     "start_time": "2023-09-17T03:36:08.440709338Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "data": {
      "text/plain": "0.15587145007105885"
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score data_policy with compute_performance on an env built with loop=False, slippery=0.1.\n",
    "performance_env = frozen_lake_env_from_string(frzmap, loop=False, slippery=0.1)\n",
    "compute_performance(performance_env, data_policy)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-17T03:25:47.256278217Z",
     "start_time": "2023-09-17T03:25:44.077686006Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Behavior Cloning"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "Training offline:   0%|          | 0/100000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "854362a135f04704a280b67370c2fb3b"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Policy update info: policy_loss = 1.224630, \n",
      "↓←←S\n",
      "→#↓#\n",
      "↓→←→\n",
      "↑→#G\n",
      "\n",
      "Policy update info: policy_loss = 0.967212, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→↑↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.026380, \n",
      "↓←←S\n",
      "↓#↓#\n",
      "↓→↓↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.052213, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.079863, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.946832, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.162197, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.014042, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.283407, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.167713, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.074266, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.113678, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.960092, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.012112, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.919760, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.806026, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.959419, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.193580, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.091449, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.935536, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.137999, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.933426, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.827785, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.011742, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.862795, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.942968, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.918620, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.162748, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.867599, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.888983, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.965386, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.038401, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.882515, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.873764, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.015906, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.998588, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.969070, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.958342, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.175124, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.754630, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.963021, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.918654, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.106214, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.012825, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.190285, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.749368, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.917832, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.082168, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.108547, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.914099, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.958198, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.957364, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.974648, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.828455, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.966653, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.779751, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.685161, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.936098, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.876270, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.785679, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.911053, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.886822, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.736450, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.869899, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.882647, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.739842, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.099620, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.876390, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.920548, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.010013, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.062185, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.097123, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.010809, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.957336, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.002213, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.777685, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.999938, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.007768, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.919792, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.053124, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.854454, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.967491, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.831083, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.151516, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.155019, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.828098, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.055665, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.920074, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.695077, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.823513, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.914840, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.237857, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.021542, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 1.007975, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.917389, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.783209, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.917200, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.729949, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.876953, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n",
      "\n",
      "Policy update info: policy_loss = 0.970193, \n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n"
     ]
    }
   ],
   "source": [
    "# Behavior cloning baseline: no Q-model (q_model=None), policy trained with bc=True.\n",
    "# Bound to its own names so the CQL training cell's `policy` / `log` are not overwritten.\n",
    "torch.manual_seed(seed + 123)\n",
    "Phi_torch = torch.tensor(Phi, device=default_train_device, dtype=torch.float32)\n",
    "bc_policy = LinearPolicy(env.observation_space, env.action_space, Phi_torch, lr=1e-3, alpha_multiplier=0.1)\n",
    "bc_log = train(None, bc_policy, off_policy_dataset, iters=int(1e5), batch_size=32, progress_bar=True, log_freq=int(1e3), policy_update_freq=1, bc=True)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-14T18:56:51.823862516Z",
     "start_time": "2023-09-14T18:49:21.857261155Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "outputs": [
    {
     "data": {
      "text/plain": "0.1380385807978683"
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the behavior-cloned policy explicitly, not whatever `policy` was last bound to.\n",
    "performance_env = frozen_lake_env_from_string(\n",
    "    frzmap, loop=False, slippery=0.1\n",
    ")\n",
    "compute_performance(performance_env, bc_policy)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-14T03:20:39.523289006Z",
     "start_time": "2023-09-14T03:20:11.597742671Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
