{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-09-28T00:28:27.224537713Z",
     "start_time": "2023-09-28T00:28:26.048850148Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import cvxpy as cp\n",
    "from dataclasses import dataclass\n",
    "import gym\n",
    "\n",
    "from tqdm.notebook import tnrange, tqdm_notebook as tqdm\n",
    "\n",
    "from frozen_lake import FrozenLakeEnv, frozen_lake_env_from_string, frozen_lake_policy_from_string\n",
    "from policies import EpsilonGreedyPolicy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data Policy and Distribution:\n",
      "↓←←S\n",
      "↓#↑#\n",
      "↓→→↓\n",
      "→↑#G\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mel/projects/POP-QL/small_scale/frozen_lake.py:192: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
      "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
      "  map = np.array(map, np.bool)\n"
     ]
    }
   ],
   "source": [
    "# Discount factor shared by every Q-value / TD computation in this notebook.\n",
    "gamma = 0.99\n",
    "# 4x4 map, one string per row. Presumably the usual FrozenLake legend: S = start,\n",
    "# F = frozen, H = hole (rendered as '#' in the printed policy), G = goal.\n",
    "frzmap = [\n",
    "    \"FFFS\",\n",
    "    \"FHFH\",\n",
    "    \"FFFF\",\n",
    "    \"FFHG\"\n",
    "]\n",
    "\n",
    "# NOTE(review): `loop` and `slippery` semantics are defined in frozen_lake.py; presumably\n",
    "# slippery=0.5 is the probability that an action slips - confirm there.\n",
    "env = frozen_lake_env_from_string(\n",
    "    frzmap, loop=True, slippery=0.5\n",
    ")\n",
    "num_states = env.observation_space.n\n",
    "num_actions = env.action_space.n\n",
    "# Rewards as a column vector over flattened (state, action) pairs, shape (num_states * num_actions, 1).\n",
    "# Assuming get_reward_matrix() returns shape (S, A), row s * num_actions + a matches get_P_policy.\n",
    "R = env.get_reward_matrix().reshape((num_states * num_actions, 1))\n",
    "\n",
    "# Data (behaviour) policy whose stationary distribution is the off-policy training distribution:\n",
    "# one arrow per grid cell, mixed with epsilon=0.2 (presumably uniformly random) actions.\n",
    "data_policy = frozen_lake_policy_from_string(\n",
    "    [\n",
    "        \"↓←←←\",\n",
    "        \"↓↑↑↑\",\n",
    "        \"↓→→↓\",\n",
    "        \"→↑↑↑\",\n",
    "    ],\n",
    "    epsilon=.2,\n",
    ")\n",
    "\n",
    "print(\"Data Policy and Distribution:\")\n",
    "print(env.render_policy(data_policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T00:28:27.236745289Z",
     "start_time": "2023-09-28T00:28:27.226808548Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eval Policy and Distribution:\n",
      "↓→↓S\n",
      "↓#↓#\n",
      "→→→↓\n",
      "→↑#G\n"
     ]
    }
   ],
   "source": [
    "# Evaluation (target) policy: the environment's shortest-path policy wrapped in\n",
    "# EpsilonGreedyPolicy with epsilon=0.2. Its Q-values are the ground truth used below.\n",
    "opt_policy = env.shortest_path_policy()\n",
    "eval_policy = EpsilonGreedyPolicy(opt_policy, env.action_space, epsilon=0.2)\n",
    "print(\"Eval Policy and Distribution:\")\n",
    "print(env.render_policy(eval_policy))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T00:28:31.309647265Z",
     "start_time": "2023-09-28T00:28:31.303642357Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "def get_P_policy(policy, P):\n",
    "    \"\"\"Build the (S*A) x (S*A) state-action transition matrix induced by `policy`.\n",
    "\n",
    "    Entry [s*A + a, s'*A + a'] equals P[s, a, s'] * pi(a' | s').\n",
    "\n",
    "    Args:\n",
    "        policy: object whose `dist(s)` returns the length-A action distribution at state s\n",
    "            (assumed to be a pure function of the state).\n",
    "        P: array of shape (S, A, S) with P[s, a, s'] the transition probability.\n",
    "\n",
    "    Returns:\n",
    "        float64 array of shape (S*A, S*A).\n",
    "    \"\"\"\n",
    "    num_states, num_actions, _ = P.shape\n",
    "    n = num_states * num_actions\n",
    "\n",
    "    # policy.dist(s') does not depend on (s, a): query it once per next state instead of\n",
    "    # S*A times inside a triple loop.\n",
    "    pi = np.stack([np.asarray(policy.dist(sp), dtype=float) for sp in range(num_states)])  # (S, A)\n",
    "\n",
    "    # P_policy[s, a, s', a'] = P[s, a, s'] * pi[s', a']; a row-major reshape maps (s, a) -> s*A + a.\n",
    "    P_policy = np.asarray(P, dtype=float)[:, :, :, None] * pi[None, None, :, :]\n",
    "    return P_policy.reshape(n, n)\n",
    "\n",
    "def compute_F_s_slow(policy, P, Phi):\n",
    "    # Loop-based reference implementation of F(x) = E_{x' ~ P_pi(x'|x)}[F(x, x')] for every\n",
    "    # state-action row x of Phi (Phi has shape (n, k); the result has shape (n, 2k, 2k)), with\n",
    "    #   F(x, x') = [[phi(x) phi(x)^T,   phi(x) phi(x')^T],\n",
    "    #               [phi(x') phi(x)^T,  phi(x) phi(x)^T ]].\n",
    "    # Intentionally unvectorized (O(n^2) outer products) so it can serve as ground truth for\n",
    "    # an equivalence check against the vectorized compute_F_s.\n",
    "    num_states, num_actions, _ = P.shape  # not used below\n",
    "    n, k = Phi.shape\n",
    "\n",
    "    # (n, n) state-action transition matrix under `policy`; row x gives the weights over x'.\n",
    "    P_policy = get_P_policy(policy, P)\n",
    "\n",
    "    F_s = np.zeros((n, 2*k, 2*k))\n",
    "    for x in range(n):\n",
    "        for xp in range(n):\n",
    "            # Accumulate each k x k quadrant weighted by P_pi(x' | x).\n",
    "            F_s[x, :k, :k] += P_policy[x, xp] * (Phi[x][:, None] @ Phi[x][None, :])\n",
    "            F_s[x, :k, k:] += P_policy[x, xp] * (Phi[x][:, None] @ Phi[xp][None, :])\n",
    "            F_s[x, k:, :k] += P_policy[x, xp] * (Phi[xp][:, None] @ Phi[x][None, :])\n",
    "            F_s[x, k:, k:] += P_policy[x, xp] * (Phi[x][:, None] @ Phi[x][None, :])\n",
    "    return F_s\n",
    "\n",
    "def compute_F_s(policy, P, Phi):\n",
    "    \"\"\"Vectorized F(x) = E_{x' ~ P_pi(x'|x)}[F(x, x')] for every state-action row x of Phi.\n",
    "\n",
    "    With phi = Phi[x] and phi_next = (P_pi @ Phi)[x], slice x of the result is\n",
    "        [[phi phi^T,       phi phi_next^T],\n",
    "         [phi_next phi^T,  phi phi^T     ]].\n",
    "    Returns an array of shape (n, 2k, 2k) for Phi of shape (n, k).\n",
    "    \"\"\"\n",
    "    expected_next_phi = get_P_policy(policy, P) @ Phi\n",
    "\n",
    "    diag_block = np.einsum('ni,nj->nij', Phi, Phi)\n",
    "    cross_block = np.einsum('ni,nj->nij', Phi, expected_next_phi)\n",
    "\n",
    "    # np.block joins inner lists along the last axis and the outer list along axis -2.\n",
    "    return np.block([\n",
    "        [diag_block, cross_block],\n",
    "        [np.swapaxes(cross_block, 1, 2), diag_block],\n",
    "    ])\n",
    "\n",
    "def compute_A_s_slow(policy, P, Phi):\n",
    "    # Compute A(x) = E_{x' ~ P_pi(x'|x)}[phi(x) (phi(x) - phi(x'))^T] for every state-action\n",
    "    # row x of Phi (no discount factor is applied). Phi has shape (n, k); returns (n, k, k).\n",
    "    # Loop-based reference implementation (O(n^2) outer products).\n",
    "    n, k = Phi.shape\n",
    "    P_policy = get_P_policy(policy, P)\n",
    "\n",
    "    A_s = np.zeros((n, k, k))\n",
    "    for x in range(n):\n",
    "        for xp in range(n):\n",
    "            A_s[x] += P_policy[x, xp] * np.outer(Phi[x], Phi[x] - Phi[xp])\n",
    "    return A_s\n",
    "\n",
    "def compute_expected_F(mu, policy, P, Phi):\n",
    "    # Return E_{x ~ mu}[F(x)] = sum_x mu[x] * F(x): a (2k, 2k) matrix formed by weighting each\n",
    "    # slice of compute_F_s by the state-action distribution mu (assumed 1-D of length n).\n",
    "    F_per_row = compute_F_s(policy, P, Phi)\n",
    "    weighted = mu[:, None, None] * F_per_row\n",
    "    return weighted.sum(axis=0)\n",
    "\n",
    "def get_stationary_distribution(P):\n",
    "    \"\"\"Return the stationary distribution d of the (assumed row-stochastic) matrix P.\n",
    "\n",
    "    Solves the stacked system [P^T - I; 1^T] d = [0; 1] (i.e. d^T P = d^T and sum(d) = 1)\n",
    "    in the least-squares sense; non-negativity is not enforced. Returns a 1-D array of\n",
    "    length P.shape[0].\n",
    "    \"\"\"\n",
    "    size = P.shape[0]\n",
    "    system = np.vstack([P.T - np.eye(size), np.ones((1, size))])\n",
    "    rhs = np.zeros(size + 1)\n",
    "    rhs[-1] = 1\n",
    "    solution, _residuals, _rank, _singular_values = np.linalg.lstsq(system, rhs, rcond=None)\n",
    "    return solution\n",
    "\n",
    "\n",
    "# Verify that the vectorized compute_F_s matches the loop-based reference on random features.\n",
    "# A dedicated generator and name keep this check from consuming the global NumPy RNG or\n",
    "# leaving a global `Phi` behind (the real feature matrix is built in a later cell).\n",
    "_check_rng = np.random.default_rng(0)\n",
    "Phi_check = _check_rng.uniform(-1, 1, (env.observation_space.n * env.action_space.n, 50))\n",
    "np.testing.assert_allclose(\n",
    "    compute_F_s_slow(eval_policy, env.P, Phi_check),\n",
    "    compute_F_s(eval_policy, env.P, Phi_check),\n",
    "    rtol=1e-5, atol=1e-8,  # same tolerances as np.allclose\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T00:28:31.713332348Z",
     "start_time": "2023-09-28T00:28:31.569745446Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Off-policy distribution:\n",
      "[[0.10371034 0.1102833  0.14215318 0.16371134]\n",
      " [0.07960952 0.03882821 0.02590268 0.03140544]\n",
      " [0.07084936 0.04305927 0.03053134 0.01975557]\n",
      " [0.06651885 0.0506403  0.01217575 0.01086556]]\n",
      "\n",
      "On-policy distribution:\n",
      "[[0.01024259 0.03937557 0.21288567 0.2538488 ]\n",
      " [0.00842316 0.03033267 0.13119609 0.06688607]\n",
      " [0.01017508 0.02322296 0.09405982 0.06086223]\n",
      " [0.00355091 0.00639582 0.01506834 0.03347423]]\n"
     ]
    }
   ],
   "source": [
    "## Stationary state-action distributions of the evaluation policy (\"on-policy\") and the\n",
    "## data policy (\"off-policy\").\n",
    "\n",
    "def visit_map(env, dist):\n",
    "    # Marginalize a flattened (state, action) distribution over actions and lay the per-state\n",
    "    # mass out on the grid, using the sizes of `env` itself rather than notebook globals.\n",
    "    n_states, n_actions = env.observation_space.n, env.action_space.n\n",
    "    return np.asarray(dist).reshape((n_states, n_actions)).sum(1).reshape(env.map.shape)\n",
    "\n",
    "# Off-policy: stationary state-action distribution of the chain induced by the data policy.\n",
    "# NOTE(review): this cell uses env.get_transition_matrix() while other cells pass env.P;\n",
    "# presumably both are the same (S, A, S) array - confirm in frozen_lake.py.\n",
    "P_data = get_P_policy(data_policy, env.get_transition_matrix())\n",
    "mu_off_policy = get_stationary_distribution(P_data)\n",
    "print('Off-policy distribution:')\n",
    "print(visit_map(env, mu_off_policy))\n",
    "print()\n",
    "\n",
    "# On-policy: same construction for the evaluation policy. P_eval is reused later as the\n",
    "# state-action transition matrix for TD and for the ground-truth Q-values.\n",
    "P_eval = get_P_policy(eval_policy, env.get_transition_matrix())\n",
    "mu_on_policy = get_stationary_distribution(P_eval)\n",
    "print('On-policy distribution:')\n",
    "print(visit_map(env, mu_on_policy))\n",
    "print()\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:53:57.666576147Z",
     "start_time": "2023-09-15T20:53:57.598882541Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "# Compute the Q-error for a given Q-table.\n",
    "def compute_Q_error(dist, reference_q_table, q_table):\n",
    "    \"\"\"Return the dist-weighted squared error sum_x dist[x] * (reference[x] - q[x])^2.\n",
    "\n",
    "    All three inputs are flattened, so column vectors of shape (n, 1) (e.g. Phi @ w) and\n",
    "    1-D arrays of shape (n,) can be mixed without accidental (n, n) broadcasting.\n",
    "    \"\"\"\n",
    "    q_errors = (np.ravel(reference_q_table) - np.ravel(q_table)) ** 2\n",
    "    return np.sum(q_errors * np.ravel(dist))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:53:59.129477049Z",
     "start_time": "2023-09-15T20:53:59.096217590Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True Q-values:\n",
      "max(Q(s,*))\n",
      "[[3.16 3.2  3.28 3.21]\n",
      " [3.22 3.17 3.4  3.17]\n",
      " [3.3  3.41 3.61 3.89]\n",
      " [3.25 3.3  3.17 4.17]]\n"
     ]
    }
   ],
   "source": [
    "# Ground-truth Q-values of the *evaluation* policy (epsilon-greedy around the shortest-path\n",
    "# policy), i.e. the solution of Q = R + gamma * P_eval @ Q. The name `opt_q` is kept because\n",
    "# later cells use it.\n",
    "n = env.observation_space.n * env.action_space.n\n",
    "# Solve the linear system rather than forming an explicit inverse (cheaper, more accurate).\n",
    "opt_q = np.linalg.solve(np.eye(n) - gamma * P_eval, R)\n",
    "\n",
    "with np.printoptions(precision=2, suppress=True):\n",
    "    print(\"True Q-values:\")\n",
    "    print(\"max(Q(s,*))\")\n",
    "    print(opt_q.reshape(env.observation_space.n, env.action_space.n).max(-1).reshape(env.map.shape))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:54:01.156772878Z",
     "start_time": "2023-09-15T20:54:01.096616855Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "## Construct feature matrices\n",
    "# Random features: one k-dimensional row per (state, action) pair, each row rescaled to unit\n",
    "# L2 norm. Seeded so the feature matrix is reproducible.\n",
    "np.random.seed(4)\n",
    "\n",
    "# NOTE(review): k is presumably chosen relative to n = num_states * num_actions - confirm.\n",
    "k = 63\n",
    "# Alternative setting tried previously: k = 53\n",
    "Phi = np.random.uniform(-1, 1, (env.observation_space.n * env.action_space.n, k))\n",
    "Phi /= np.linalg.norm(Phi, axis=-1, keepdims=True)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:54:01.396174167Z",
     "start_time": "2023-09-15T20:54:01.377706919Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "# Compute LS-TD solution\n",
    "def compute_ls_td(dist, P, R, Phi, gamma):\n",
    "    \"\"\"Return the TD fixed-point weights w (shape (k, 1)) solving\n",
    "\n",
    "        Phi^T D (Phi - gamma P Phi) w = Phi^T D R,   with D = diag(dist).\n",
    "\n",
    "    `dist` is flattened, so shape (n,) or (n, 1) both work. Uses a linear solve instead of an\n",
    "    explicit inverse and row-scales Phi instead of materializing the (n, n) matrix D.\n",
    "    Raises numpy.linalg.LinAlgError if the system matrix is singular.\n",
    "    \"\"\"\n",
    "    weighted_Phi_T = (Phi * np.ravel(dist)[:, None]).T  # == Phi.T @ np.diag(dist)\n",
    "    A = weighted_Phi_T @ (Phi - gamma * P @ Phi)\n",
    "    b = weighted_Phi_T @ R\n",
    "    return np.linalg.solve(A, b)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:54:04.065127635Z",
     "start_time": "2023-09-15T20:54:04.045517522Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "# Iterative TD solution\n",
    "def compute_iterative_td(dist, P, R, Phi, gamma, epsilon=1e-5, alpha=2e0, divergence_threshold=1e9,  max_iters=int(1e6), log_freq=int(1e3), progress_bar=None):\n",
    "    \"\"\"Run expected (full-batch) linear TD(0) updates under the weighting `dist`.\n",
    "\n",
    "    Each step applies w <- w + alpha * Phi^T D (R + gamma P Phi w - Phi w), D = diag(dist).\n",
    "    Stops when ||delta w|| < epsilon, when the dist-weighted squared TD error exceeds\n",
    "    `divergence_threshold` (prints a warning), or after `max_iters` steps without either\n",
    "    (prints a warning).\n",
    "\n",
    "    Args:\n",
    "        dist: weights over the n rows of Phi; shape (n,) or (n, 1).\n",
    "        P: (n, n) state-action transition matrix. R: (n, 1) rewards. Phi: (n, k) features.\n",
    "        progress_bar: optional tqdm-like bar to reuse (reset to `max_iters`); a new one is\n",
    "            created when None.\n",
    "\n",
    "    Returns:\n",
    "        Weight column vector of shape (k, 1).\n",
    "    \"\"\"\n",
    "    n, k = Phi.shape\n",
    "    dist = np.ravel(dist)\n",
    "    w = np.zeros((k, 1))\n",
    "    # Loop-invariant pieces. Phi.T * dist equals Phi.T @ np.diag(dist) without an (n, n) matrix.\n",
    "    weighted_Phi_T = Phi.T * dist[None, :]\n",
    "    next_Phi = P @ Phi\n",
    "\n",
    "    if progress_bar is None:\n",
    "        progress_bar = tqdm(total=max_iters)\n",
    "    else:\n",
    "        progress_bar.reset(total=max_iters)\n",
    "    for i in range(max_iters):\n",
    "        v = Phi @ w\n",
    "        v_target = R + gamma * next_Phi @ w\n",
    "        residual = v_target - v  # shape (n, 1)\n",
    "        change_w = weighted_Phi_T @ residual\n",
    "        w_next = w + alpha * change_w\n",
    "        # dist is 1-D and residual is a column: index dist as a column, otherwise\n",
    "        # `dist * residual**2` broadcasts to (n, n) and yields sum(dist) * sum(residual**2).\n",
    "        td_error = (dist[:, None] * residual ** 2).sum()\n",
    "\n",
    "        if td_error > divergence_threshold:\n",
    "            print(\"Warning: TD solution diverged.\")\n",
    "            break\n",
    "        if np.linalg.norm(w_next - w) < epsilon:\n",
    "            break\n",
    "        w = w_next\n",
    "\n",
    "        progress_bar.update(1)\n",
    "        if i % log_freq == 0:\n",
    "            progress_bar.set_description(f\"Computing Q... TD-Error: {td_error:.4f}\")\n",
    "    else:\n",
    "        # Reached only when the iteration budget ran out without converging or diverging;\n",
    "        # a break on the final iteration does not trigger this warning.\n",
    "        print(\"Warning: TD solution did not converge.\")\n",
    "    return w"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:54:05.425221608Z",
     "start_time": "2023-09-15T20:54:05.400066632Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "## Solve the dual problem exactly\n",
    "def compute_pop_q_exact(dist, policy, P, Phi, pop_margin=0):\n",
    "    \"\"\"Solve the POP-Q dual exactly with CVXPY and return the reweighted distribution.\n",
    "\n",
    "    Minimizes sum_x dist[x] * exp(tr(Z^T (F(x) - pop_margin * I))) over PSD Z (2k x 2k),\n",
    "    where F(x) = compute_F_s(policy, P, Phi)[x], then returns dist[x] * exp(tr(Z^T F(x)))\n",
    "    normalized to sum to 1. The margin multiplies every weight by the same factor\n",
    "    exp(-pop_margin * tr(Z)), so omitting it in the reweight cancels in the normalization.\n",
    "\n",
    "    Raises RuntimeError if the solver returns no solution.\n",
    "    \"\"\"\n",
    "    n, k = Phi.shape  # derive n locally instead of relying on a notebook-global `n`\n",
    "    dist = np.ravel(dist)\n",
    "    Z = cp.Variable((2*k, 2*k), PSD=True)\n",
    "    F_s = compute_F_s(policy, P, Phi)\n",
    "\n",
    "    obj = 0\n",
    "    for i in range(n):\n",
    "        obj += dist[i] * cp.exp(cp.trace(Z.T @ (F_s[i] - pop_margin * np.eye(2*k))))\n",
    "    prob = cp.Problem(cp.Minimize(obj), [Z >> 0])\n",
    "    prob.solve(verbose=True, solver=cp.MOSEK)\n",
    "    if Z.value is None:\n",
    "        raise RuntimeError(f\"POP-Q exact solve failed (status: {prob.status})\")\n",
    "\n",
    "    # tr(Z^T F(x)) = sum_ij Z_ij F(x)_ij for every x at once.\n",
    "    reweight = np.exp(np.einsum('ij,nij->n', Z.value, F_s))\n",
    "    return reweight * dist / np.sum(reweight * dist)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:54:06.428660835Z",
     "start_time": "2023-09-15T20:54:06.416483489Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "## Solve POP-TD with linear g\n",
    "def compute_pop_q_linear(dist, P, Phi, dual_lr=2e0, g_lr=1e0, rank=4, max_iters=int(1e5), log_freq=int(1e3), progress_bar=None):\n",
    "    \"\"\"Approximate the POP-TD reweighting with gradient steps on low-rank dual factors.\n",
    "\n",
    "    Maintains factors a, b of shape (k, rank) and a linear g(x) = Phi[x] @ w_g regressed onto\n",
    "    the angle term; the per-row weight is exp(||Phi a||^2 + ||Phi b||^2 + 2 g) row-wise.\n",
    "    Returns dist * reweight normalized to sum to 1 (1-D, length n), where reweight is the value\n",
    "    computed in the final iteration (before that iteration's parameter update).\n",
    "\n",
    "    Initialization draws from the global NumPy RNG, so seed it for reproducibility.\n",
    "    Raises ValueError if max_iters < 1.\n",
    "    \"\"\"\n",
    "    if max_iters < 1:\n",
    "        raise ValueError(\"max_iters must be at least 1\")\n",
    "    dist = dist.flatten()\n",
    "    k = Phi.shape[1]  # feature dimension; do not rely on a notebook-global `k`\n",
    "    Phi_prime = P @ Phi\n",
    "\n",
    "    a = 0.2 * np.random.normal(0, 1, (k, rank))\n",
    "    b = 0.2 * np.random.normal(0, 1, (k, rank))\n",
    "    w_g = 1e-4 * np.random.uniform(-1, 1, (k, 1))\n",
    "\n",
    "    if progress_bar is None:\n",
    "        progress_bar = tqdm(total=max_iters)\n",
    "    else:\n",
    "        progress_bar.reset(total=max_iters)\n",
    "    for i in range(max_iters):\n",
    "        m_a = Phi @ a\n",
    "        m_b = Phi @ b\n",
    "        m_a_prime = Phi_prime @ a\n",
    "        feature_term = (m_a ** 2 + m_b ** 2).sum(-1)\n",
    "        angle_term = (m_b * m_a_prime).sum(-1)\n",
    "        g = (Phi @ w_g).flatten()\n",
    "        reweight = np.exp(feature_term + 2 * g)\n",
    "\n",
    "        change_a = -((dist * reweight)[:, None, None] * (Phi[:, :, None] * m_a[:, None, :] + Phi_prime[:, :, None] * m_b[:, None, :])).mean(0)\n",
    "        change_b = -((dist * reweight)[:, None, None] * (Phi[:, :, None] * m_b[:, None, :] + Phi[:, :, None] * m_a_prime[:, None, :])).mean(0)\n",
    "\n",
    "        a += dual_lr * change_a\n",
    "        b += dual_lr * change_b\n",
    "        # Regress g onto the angle term under dist; (Phi.T * dist) == Phi.T @ np.diag(dist).\n",
    "        w_g += g_lr * (Phi.T * dist) @ (angle_term - g)[:, None]\n",
    "\n",
    "        progress_bar.update(1)\n",
    "        if i % log_freq == 0:\n",
    "            obj = (dist * reweight).sum()\n",
    "            g_error = (angle_term - g) ** 2\n",
    "            progress_bar.set_description(f\"Computing POP-TD... Objective: {obj:.4f}, G Error: {g_error.mean():.4f}\")\n",
    "\n",
    "    return reweight * dist / np.sum(reweight * dist)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:54:06.739986088Z",
     "start_time": "2023-09-15T20:54:06.706399424Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "data": {
      "text/plain": "  0%|          | 0/100 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "606f5415f51146f2bb73458853d42fca"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/1 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "f8a37f5c7fb1451db78b257d4a435be8"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: TD solution diverged.\n",
      "Warning: TD solution diverged.\n",
      "Warning: TD solution diverged.\n",
      "Warning: TD solution diverged.\n",
      "Warning: TD solution diverged.\n",
      "Warning: TD solution diverged.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n"
     ]
    }
   ],
   "source": [
    "# Compute approximation errors\n",
    "# Q-value error of plain TD as the training distribution mu moves linearly from the data\n",
    "# policy's stationary distribution (p = 0) to the evaluation policy's (p = 1).\n",
    "np.random.seed(0)\n",
    "p_values = np.linspace(0, 1, int(1e2))\n",
    "q_error = []\n",
    "pbar = None\n",
    "for p in tqdm(p_values):\n",
    "    mu = p * mu_on_policy + (1-p) * mu_off_policy\n",
    "    # One inner progress bar, created on the first iteration and reused afterwards\n",
    "    # (compute_iterative_td resets it each call).\n",
    "    pbar = tqdm(total=1) if pbar is None else pbar\n",
    "    # Closed-form alternative:\n",
    "    # w = compute_ls_td(mu, P_eval, R, Phi, gamma)\n",
    "    w = compute_iterative_td(mu, P_eval, R, Phi, gamma, progress_bar=pbar)\n",
    "    # Error is measured under the same mixture mu that was used for training.\n",
    "    q_error.append(compute_Q_error(mu, opt_q, Phi @ w))\n",
    "\n",
    "# Save results\n",
    "np.save(\"frozen_lake_eval_q_error.npy\", q_error)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-16T17:17:53.132223274Z",
     "start_time": "2023-08-16T15:50:33.461039299Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "  0%|          | 0/1000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "22c40959a8cb448c9f34fae84e8443bb"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Normalized minimum eigenvalue of E_{x ~ mu}[F(x)] along the same on/off-policy mixture path.\n",
    "np.random.seed(0)\n",
    "p_values = np.linspace(0, 1, int(1e3))\n",
    "min_eigs = []\n",
    "# F(x) does not depend on mu, so build it once instead of re-running compute_F_s (and its\n",
    "# get_P_policy construction) for each of the 1000 mixture weights.\n",
    "F_s_eval = compute_F_s(eval_policy, env.P, Phi)\n",
    "for p in tqdm(p_values):\n",
    "    mu = p * mu_on_policy + (1-p) * mu_off_policy\n",
    "    # Same quantity as compute_expected_F(mu, eval_policy, env.P, Phi).\n",
    "    expected_F = (mu[:, None, None] * F_s_eval).sum(axis=0)\n",
    "    # expected_F is symmetric by construction ([[A, B], [B^T, A]]), so use the symmetric\n",
    "    # eigen-solver, which returns real eigenvalues directly.\n",
    "    eig_vals = np.linalg.eigvalsh(expected_F)\n",
    "    # Stored value is the min/max eigenvalue ratio (scale-free), not the raw minimum.\n",
    "    min_eig = eig_vals.min() / eig_vals.max()\n",
    "    min_eigs.append(min_eig)\n",
    "\n",
    "# Save results\n",
    "np.save(\"frozen_lake_eval_min_eigs.npy\", min_eigs)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-16T17:18:52.755894501Z",
     "start_time": "2023-08-16T17:17:53.222679433Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "data": {
      "text/plain": "  0%|          | 0/100 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "53809797bc7b47f59b89323a7e1d764b"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/1 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "adc5144f45a1458cb198f5be048e8a0a"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/1 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "a0c16aa669e74dd88ccb3f10a4020900"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n",
      "Warning: TD solution did not converge.\n"
     ]
    }
   ],
   "source": [
    "# Compute approximation errors with POP-TD\n",
    "# For each mixture mu: compute the POP-TD reweighted distribution, run TD under it, and\n",
    "# measure the resulting Q-error under the original mu (same x-axis as the plain-TD sweep).\n",
    "np.random.seed(0)\n",
    "p_values = np.linspace(0, 1, int(1e2))\n",
    "pop_q_error = []\n",
    "pop_pbar = None\n",
    "q_pbar = None\n",
    "for p in tqdm(p_values):\n",
    "    mu = p * mu_on_policy + (1-p) * mu_off_policy\n",
    "    pop_pbar = tqdm(total=1) if pop_pbar is None else pop_pbar\n",
    "    # NOTE: `q` is a reweighted state-action distribution (sums to 1), not Q-values.\n",
    "    q = compute_pop_q_linear(mu, P_eval, Phi, progress_bar=pop_pbar)\n",
    "    q_pbar = tqdm(total=1) if q_pbar is None else q_pbar\n",
    "    w = compute_iterative_td(q, P_eval, R, Phi, gamma, progress_bar=q_pbar)\n",
    "    pop_q_error.append(compute_Q_error(mu, opt_q, Phi @ w))\n",
    "    \n",
    "# Save results\n",
    "np.save(\"frozen_lake_eval_pop_q_error.npy\", pop_q_error)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-16T19:44:03.144223469Z",
     "start_time": "2023-08-16T17:18:52.763466135Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mel/micromamba/envs/JaxCQL/lib/python3.10/site-packages/fontTools/misc/py23.py:11: DeprecationWarning: The py23 module has been deprecated and will be removed in a future release. Please update your code.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 640x480 with 0 Axes>"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Set matplotlib settings\n",
    "\n",
    "import matplotlib\n",
    "from matplotlib.backends.backend_pgf import FigureCanvasPgf\n",
    "# Render the 'pdf' output format through the PGF (LaTeX) canvas.\n",
    "matplotlib.backend_bases.register_backend('pdf', FigureCanvasPgf)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "# Start from matplotlib's defaults so settings from earlier runs do not leak in.\n",
    "matplotlib.rcParams.update(matplotlib.rcParamsDefault)\n",
    "\n",
    "# LaTeX preamble shared by text.usetex and the PGF backend; also defines the on/off-policy colours.\n",
    "latex_preamble = (r'\\usepackage{xcolor}'\n",
    "                  r'\\usepackage[scaled]{helvet}'\n",
    "                  r'\\usepackage{amssymb}'\n",
    "                  r'\\usepackage{amsmath}'\n",
    "                  r'\\usepackage{bm}'\n",
    "                  r'\\definecolor{offpolicycolor}{RGB}{112, 48, 160}'\n",
    "                  r'\\definecolor{onpolicycolor}{RGB}{68, 114, 196}')\n",
    "pgf_with_latex = {\n",
    "    \"text.usetex\": True,            # use LaTeX to write all text\n",
    "    \"pgf.rcfonts\": True,           # set up fonts from rc params (False defers fonts to the preamble)\n",
    "    # \"font.family\": \"phv\",\n",
    "    # \"font.serif\": 'Computer Modern Roman',\n",
    "    \"text.latex.preamble\": latex_preamble,\n",
    "    \"pgf.preamble\": latex_preamble,\n",
    "}\n",
    "# plt.clf() implicitly creates an empty current figure (the blank Figure shown in the output).\n",
    "plt.clf()\n",
    "matplotlib.rcParams.update(pgf_with_latex)\n",
    "\n",
    "# Choose colors\n",
    "# Tableau 20 palette as 0-255 RGB tuples; the loop converts them in place to 0-1 floats.\n",
    "# Note: the loop leaves `i`, `r`, `g`, `b` behind as notebook globals.\n",
    "tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),  \n",
    "             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),  \n",
    "             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),  \n",
    "             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),  \n",
    "             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)]\n",
    "for i in range(len(tableau20)):  \n",
    "    r, g, b = tableau20[i]  \n",
    "    tableau20[i] = (r / 255., g / 255., b / 255.) \n",
    "\n",
    "# Per-series colours, presumably consumed by the plotting cell that follows.\n",
    "vanilla_ql_color = tableau20[14]\n",
    "pop_ql_color = tableau20[2]\n",
    "nec_color = tableau20[4]\n",
    "min_eig_color = tableau20[6]"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-28T00:28:50.384278326Z",
     "start_time": "2023-09-28T00:28:50.161309780Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load results\n",
    "q_error = np.load(\"frozen_lake_eval_q_error.npy\")\n",
    "min_eigs = np.load(\"frozen_lake_eval_min_eigs.npy\")\n",
    "pop_q_error = np.load(\"frozen_lake_eval_pop_q_error.npy\")\n",
    "\n",
    "# Plot approximation error against the on-policy fraction (linear y-axis)\n",
    "\n",
    "# Set size of figure\n",
    "plt.figure(figsize=(6, 4))\n",
    "\n",
    "# Set font size (before the axes are created below, so all text uses it)\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "\n",
    "# Set axis limits: y in [0, 15], padded by 5% of the range on both sides\n",
    "plt.xlim(0, 1)\n",
    "pad_percent = 0.05\n",
    "y_max, y_min = 15, 0\n",
    "pad = pad_percent * (y_max - y_min)\n",
    "y_max += pad\n",
    "y_min -= pad\n",
    "plt.ylim(y_min, y_max)\n",
    "\n",
    "# Label only the two extremes and the midpoint of the off-/on-policy mixture\n",
    "plt.xticks([0, 0.5, 1.0], [r'\\quad \\quad \\textcolor{offpolicycolor}{100\\% Off-Policy}', r'\\textcolor{offpolicycolor}{50\\%}/\\textcolor{onpolicycolor}{50\\%}', r'\\textcolor{onpolicycolor}{100\\% On-Policy} \\quad \\quad'])\n",
    "plt.ylabel('Approximation Error')\n",
    "\n",
    "# Lighten the color of the axes\n",
    "ax = plt.gca()\n",
    "for side in ('top', 'bottom', 'left', 'right'):\n",
    "    ax.spines[side].set_color('lightgray')\n",
    "\n",
    "# Remove x tick marks (the labels stay)\n",
    "ax.xaxis.set_ticks_position('none')\n",
    "\n",
    "# Shade from the first mixture weight whose minimum eigenvalue is approximately\n",
    "# non-negative (>= -nec_tolerance) up to fully on-policy\n",
    "nec_tolerance = 0.005\n",
    "p_values = np.linspace(0, 1, len(min_eigs))\n",
    "non_expansive = np.asarray(min_eigs) >= -nec_tolerance\n",
    "nec_area = None\n",
    "if non_expansive.any():  # .min() below would raise on an empty selection\n",
    "    min_p_non_expansive = p_values[non_expansive].min()\n",
    "    nec_area = plt.fill_between([min_p_non_expansive, 1], [y_min, y_min], [y_max, y_max], color=nec_color, alpha=0.2)\n",
    "    plt.plot([min_p_non_expansive, min_p_non_expansive], [y_min, y_max], color=nec_color, alpha=0.4, linestyle='-', linewidth=2)\n",
    "\n",
    "# Plot the results\n",
    "pop_ql_line, = plt.plot(np.linspace(0, 1, len(pop_q_error)), pop_q_error, linewidth=2, color=pop_ql_color)\n",
    "vanilla_ql_line, = plt.plot(np.linspace(0, 1, len(q_error)), q_error, linewidth=2, color=vanilla_ql_color, linestyle='--')\n",
    "\n",
    "# Build one legend; the shaded region only gets an entry if it was drawn\n",
    "legend_handles = [vanilla_ql_line, pop_ql_line]\n",
    "legend_labels = ['Vanilla Q-Learning', 'POP-QL (Our Method)']\n",
    "if nec_area is not None:\n",
    "    legend_handles.append(nec_area)\n",
    "    legend_labels.append(r'$\\lambda_{\\text{min}}(\\mathbb{E}_{\\mu}[F^\\pi(s, a)]) \\approx 0$')\n",
    "plt.legend(legend_handles, legend_labels, loc='upper right')\n",
    "plt.savefig('frozen_lake_q_error.pdf', bbox_inches='tight')\n",
    "plt.close('all')"
   ],
   "metadata": {
    "collapsed": false,
    "is_executing": true,
    "ExecuteTime": {
     "start_time": "2023-09-28T00:30:23.196087799Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "# Load results\n",
    "q_error = np.load(\"frozen_lake_eval_q_error.npy\")\n",
    "min_eigs = np.load(\"frozen_lake_eval_min_eigs.npy\")\n",
    "pop_q_error = np.load(\"frozen_lake_eval_pop_q_error.npy\")\n",
    "\n",
    "# Plot approximation error against the on-policy fraction (log y-axis)\n",
    "\n",
    "# Set size of figure\n",
    "plt.figure(figsize=(6, 4))\n",
    "\n",
    "# Set font size (before the axes are created below, so all text uses it)\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "\n",
    "# Set to log scale\n",
    "plt.yscale('log')\n",
    "\n",
    "# Set axis limits (fixed log-scale range)\n",
    "plt.xlim(0, 1)\n",
    "y_max, y_min = 1e3, 1e-2\n",
    "plt.ylim(y_min, y_max)\n",
    "\n",
    "# Label only the two extremes and the midpoint of the off-/on-policy mixture\n",
    "plt.xticks([0, 0.5, 1.0], [r'\\quad \\quad \\textcolor{offpolicycolor}{100\\% Off-Policy}', r'\\textcolor{offpolicycolor}{50\\%}/\\textcolor{onpolicycolor}{50\\%}', r'\\textcolor{onpolicycolor}{100\\% On-Policy} \\quad \\quad'])\n",
    "plt.ylabel('Approximation Error')\n",
    "\n",
    "# Lighten the color of the axes\n",
    "ax = plt.gca()\n",
    "for side in ('top', 'bottom', 'left', 'right'):\n",
    "    ax.spines[side].set_color('lightgray')\n",
    "\n",
    "# Remove x tick marks (the labels stay)\n",
    "ax.xaxis.set_ticks_position('none')\n",
    "\n",
    "# Shade from the first mixture weight whose minimum eigenvalue is approximately\n",
    "# non-negative (>= -nec_tolerance) up to fully on-policy\n",
    "nec_tolerance = 0.005\n",
    "p_values = np.linspace(0, 1, len(min_eigs))\n",
    "non_expansive = np.asarray(min_eigs) >= -nec_tolerance\n",
    "nec_area = None\n",
    "if non_expansive.any():  # .min() below would raise on an empty selection\n",
    "    min_p_non_expansive = p_values[non_expansive].min()\n",
    "    nec_area = plt.fill_between([min_p_non_expansive, 1], [y_min, y_min], [y_max, y_max], color=nec_color, alpha=0.2)\n",
    "    plt.plot([min_p_non_expansive, min_p_non_expansive], [y_min, y_max], color=nec_color, alpha=0.4, linestyle='-', linewidth=2)\n",
    "\n",
    "# Plot the results\n",
    "pop_ql_line, = plt.plot(np.linspace(0, 1, len(pop_q_error)), pop_q_error, linewidth=2, color=pop_ql_color)\n",
    "vanilla_ql_line, = plt.plot(np.linspace(0, 1, len(q_error)), q_error, linewidth=2, color=vanilla_ql_color, linestyle='--')\n",
    "\n",
    "# Build one legend; the shaded-region label states the exact threshold used above\n",
    "legend_handles = [vanilla_ql_line, pop_ql_line]\n",
    "legend_labels = ['Vanilla Q-Learning', 'POP-QL (Our Method)']\n",
    "if nec_area is not None:\n",
    "    legend_handles.append(nec_area)\n",
    "    legend_labels.append(r'$\\lambda_{\\text{min}}(\\mathbb{E}_{\\mu}[F(s, a)]) \\geq -0.005$')\n",
    "plt.legend(legend_handles, legend_labels, loc='upper right')\n",
    "plt.savefig('frozen_lake_q_error_logscale.pdf', bbox_inches='tight')\n",
    "plt.close('all')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-09-15T20:55:09.034200230Z",
     "start_time": "2023-09-15T20:55:07.111781837Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [],
   "source": [
    "# Load results\n",
    "q_error = np.load(\"frozen_lake_eval_q_error.npy\")\n",
    "min_eigs = np.load(\"frozen_lake_eval_min_eigs.npy\")\n",
    "pop_q_error = np.load(\"frozen_lake_eval_pop_q_error.npy\")\n",
    "\n",
    "# Plot approximation error (left axis, log scale) together with the minimum eigenvalue (right axis)\n",
    "\n",
    "# Set font size first: subplots() below already creates tick labels and titles\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "\n",
    "# Set size of figure; ax2 shares the x-axis and carries the eigenvalue curve\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "ax2 = ax.twinx()\n",
    "\n",
    "# Left axis: log-scale approximation error\n",
    "ax.set_yscale('log')\n",
    "ax.set_xlim(0, 1)\n",
    "y_max, y_min = 2e3, 2e-2\n",
    "ax.set_ylim(y_min, y_max)\n",
    "\n",
    "# Right axis: linear, with a little headroom below 0; ticks coloured like the eigenvalue curve\n",
    "y_bottom2, y_top2 = -0.01, 0.16\n",
    "ax2.set_ylim(y_bottom2, y_top2)\n",
    "y_ticks2 = np.linspace(0, 0.15, 4)\n",
    "ax2.set_yticks(y_ticks2, [r'$0$'] + [f'${y:.2f}$' for y in y_ticks2[1:]], color=min_eig_color)\n",
    "\n",
    "# Set axis labels and label x ticks\n",
    "ax.set_xticks([0, 0.5, 1.0], [r'\\quad \\quad \\textcolor{offpolicycolor}{100\\% Off-Policy}', r'\\textcolor{offpolicycolor}{50\\%}/\\textcolor{onpolicycolor}{50\\%}', r'\\textcolor{onpolicycolor}{100\\% On-Policy} \\quad \\quad'])\n",
    "ax.set_ylabel('Approximation Error', color='black')\n",
    "ax2.set_ylabel(r'$\\lambda_{\\min}$', color=min_eig_color)\n",
    "\n",
    "# Lighten the frame. The twin axes draws its own spines on top of `ax`,\n",
    "# so both axes have to be recoloured for the change to be visible.\n",
    "for axis in (ax, ax2):\n",
    "    for side in ('top', 'bottom', 'left', 'right'):\n",
    "        axis.spines[side].set_color('lightgray')\n",
    "\n",
    "# Plot the minimum eigenvalues in the same colour as the right-axis ticks and label.\n",
    "# NOTE(review): the curve is -min_eigs while the axis/legend read lambda_min -- confirm the sign convention.\n",
    "min_eig_line, = ax2.plot(np.linspace(0, 1, len(min_eigs)), -min_eigs, color=min_eig_color, linestyle='-', linewidth=2)\n",
    "\n",
    "# Plot the results\n",
    "pop_ql_line, = ax.plot(np.linspace(0, 1, len(pop_q_error)), pop_q_error, linewidth=2, color=pop_ql_color)\n",
    "vanilla_ql_line, = ax.plot(np.linspace(0, 1, len(q_error)), q_error, linewidth=2, color=vanilla_ql_color, linestyle='--')\n",
    "\n",
    "# Legend on the twin axes so it sits above the curves of both axes\n",
    "ax2.legend([vanilla_ql_line, pop_ql_line, min_eig_line], ['Vanilla Q-Learning', 'POP-QL (Our Method)', r'$\\lambda_{\\min}(\\mathbb{E}_{(s, a) \\sim \\mu}[F(s, a)])$'], loc='upper right')\n",
    "# Fixed crop box in inches, wider than the 6x4 figure (presumably to keep both y-axis labels)\n",
    "fig.savefig('frozen_lake_q_error_min_eig.pdf', bbox_inches=matplotlib.transforms.Bbox([[-0.3, 0.0], [6.5, 4.0]]))\n",
    "plt.close('all')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-16T19:51:28.281441721Z",
     "start_time": "2023-08-16T19:51:27.240281031Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [
    {
     "data": {
      "text/plain": "  0%|          | 0/100000 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "acbc78bbb98146249ef8f6f2c7057481"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Presumably the POP-QL counterpart of the off-policy distribution; it is shown as the\n",
    "# \"POP-QL\" panel of the density plot below (iterative, reports progress via tqdm)\n",
    "q_off_policy = compute_pop_q_linear(\n",
    "    mu_off_policy,\n",
    "    P_eval,\n",
    "    Phi,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-16T19:44:39.037476159Z",
     "start_time": "2023-08-16T19:44:08.058208381Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [],
   "source": [
    "# Plot densities of the off-policy, POP-QL and on-policy distributions over the map\n",
    "\n",
    "# Set font size first: subplots() below already creates the (empty) panel titles\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "\n",
    "# Set size of figure: three panels side by side sharing the y-axis\n",
    "fig, axes = plt.subplots(1, 3, figsize=(4*3 + 1, 4), sharey=True)\n",
    "\n",
    "# Load the background image of the map\n",
    "image = plt.imread(\"frozen_lake.png\")\n",
    "\n",
    "\n",
    "def _state_grid(dist):\n",
    "    \"\"\"Reshape `dist` to (env.wd, env.ht, -1) and sum the trailing axis into a 2-D grid.\"\"\"\n",
    "    return dist.reshape((env.wd, env.ht, -1)).sum(-1)\n",
    "\n",
    "\n",
    "# Shared colour scale: the largest per-cell value of the on-policy grid. Taken from the\n",
    "# same aggregation that is plotted, so summed cells cannot exceed it and saturate.\n",
    "vmax = _state_grid(mu_on_policy).max()\n",
    "\n",
    "\n",
    "def plot_part(ax, data, title, plot_cbar=False):\n",
    "    \"\"\"Draw one density panel on `ax`: map image, heat map and per-cell percentages.\n",
    "\n",
    "    Args:\n",
    "        ax: matplotlib axes to draw into.\n",
    "        data: array that can be reshaped to (env.wd, env.ht, -1); it is summed over the\n",
    "            trailing axis, so both a 2-D grid and a flat per-state(-action) array work.\n",
    "        title: panel title.\n",
    "        plot_cbar: accepted for call-site compatibility; currently unused (no colour bar is drawn).\n",
    "    \"\"\"\n",
    "    cell_mass = _state_grid(data)\n",
    "    ax.imshow(image, extent=[-0.02, env.wd, -0.02, env.ht], alpha=1, interpolation='nearest')\n",
    "    ax.imshow(cell_mass, vmin=0, vmax=vmax, cmap=\"Reds\", alpha=0.9, extent=[0.02, env.wd-0.02, 0.00, env.ht-0.00], interpolation='nearest')\n",
    "    # Annotate every cell with its value in percent; white text on the darkest cells for contrast.\n",
    "    # NOTE(review): rows run over env.wd and columns over env.ht while the extents put wd on x --\n",
    "    # only consistent for square maps such as this 4x4 one; confirm before using other maps.\n",
    "    for i in range(env.wd):\n",
    "        for j in range(env.ht):\n",
    "            ax.text(j+0.5, env.wd-i-0.5, f\"{cell_mass[i, j]*100:.0f}%\",\n",
    "                    ha=\"center\", va=\"center\", color=\"black\" if cell_mass[i, j] <= vmax * 0.45 else \"white\", fontsize=12\n",
    "            )\n",
    "    ax.title.set_text(title)\n",
    "\n",
    "    # Thin black gridlines and frame\n",
    "    grid_color = 'black'\n",
    "    ax.grid(color=grid_color, linestyle='-', linewidth=0.5)\n",
    "    ax.set_axisbelow(True)\n",
    "    for side in ('bottom', 'top', 'right', 'left'):\n",
    "        ax.spines[side].set_color(grid_color)\n",
    "        ax.spines[side].set_linewidth(0.5)\n",
    "\n",
    "    # Remove tick labels and tick marks\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_yticklabels([])\n",
    "    ax.xaxis.set_ticks_position('none')\n",
    "    ax.yaxis.set_ticks_position('none')\n",
    "\n",
    "\n",
    "plot_part(axes[0], mu_off_policy, \"Off-Policy\")\n",
    "plot_part(axes[1], q_off_policy, \"POP-QL\")\n",
    "plot_part(axes[2], mu_on_policy, \"On-Policy\", plot_cbar=True)\n",
    "\n",
    "# # Create colorbar\n",
    "# cbar = fig.colorbar(axes[2].images[1], ax=axes, shrink=0.76)\n",
    "\n",
    "fig.savefig('frozen_lake_densities.pdf')\n",
    "plt.close('all')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-08-16T19:44:40.300098427Z",
     "start_time": "2023-08-16T19:44:39.038048258Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
     "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.10"
   }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
