{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Supplementary_LCB_Bonus.ipynb",
      "provenance": [
        {
          "file_id": "1BSHb_I-xnfsLqCJp0mM8TvBrw6SfEN86",
          "timestamp": 1643367807343
        },
        {
          "file_id": "1NUiaKBD2cbyXIHR_-sYq71B1LlEx6-G_",
          "timestamp": 1643242920367
        }
      ],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "'''\n",
        "Simulate toy MDP, offline dataset, and policy to evaluate, for discovering\n",
        "examples of optimistic LCB bonus.\n",
        "'''\n",
        "import numpy as np\n",
        "from scipy import linalg\n",
        "\n",
        "\n",
        "dist = np.random.standard_normal\n",
        "d_s = d_a = 30\n",
        "scale = 1.\n",
        "num_episodes = 5\n",
        "steps_per_episode = 5\n",
        "gamma = 0.5\n",
        "EPS = 1e-8\n",
        "\n",
        "\n",
        "def check_mdp_result(seed):\n",
        "  # Set the seed\n",
        "  np.random.seed(seed)\n",
        "\n",
        "  # Create the data matrices\n",
        "  X = []\n",
        "  Xprime = []\n",
        "  for _ in range(num_episodes):\n",
        "    s = dist((d_s)) / scale\n",
        "    for _ in range(steps_per_episode):\n",
        "      a = dist((d_a)) / scale\n",
        "      sprime = dist((d_s)) / scale\n",
        "      aprime = a\n",
        "      X.append(np.concatenate([s, a]))\n",
        "      Xprime.append(np.concatenate([sprime, aprime]))\n",
        "      s = sprime\n",
        "  X = np.array(X)\n",
        "  Xprime = np.array(Xprime)\n",
        "  R = dist((X.shape[0], 1))\n",
        "\n",
        "  A = X @ X.T + EPS * np.eye(X.shape[0]) # to avoid numerical issues with matrix inversion\n",
        "  C = Xprime @ X.T @ np.linalg.inv(A)\n",
        "\n",
        "  # Find largest singular value of gamma * C\n",
        "  singular_value = np.linalg.svd(C)[1][0] * gamma\n",
        "\n",
        "  # Compute stddev for LCB, i.e., sqrt(E[(Q0(X') - C * Q0(X))^2]).\n",
        "  # We assume Q-functions are linear function approximators\n",
        "  # We assume the initial weight distribution is a spherical Gaussian with dimension d_s + d_a.\n",
        "  # Back of the envelope calculations using this assumption and our derived equations leads to,\n",
        "  Xdiff = Xprime - C @ X\n",
        "  Xdiff_std = np.sqrt(np.sum(Xdiff ** 2, -1))\n",
        "\n",
        "  # if singular_value < 1.:\n",
        "  # Compute mean & penalty for LCB.\n",
        "  t = 1000\n",
        "\n",
        "  cum_C = np.eye(X.shape[0])\n",
        "  for _ in range(t):\n",
        "    cum_C = np.eye(X.shape[0]) + gamma * C @ cum_C\n",
        "\n",
        "  lcb_mean = cum_C @ C @ R\n",
        "  lcb_penalty = cum_C @ Xdiff_std\n",
        "\n",
        "  # If min is < 0, that means there exists a (s', \\pi(s')) where the LCB penalty is actually a bonus.\n",
        "  min_lcb = np.min(lcb_penalty)\n",
        "  max_lcb = np.max(lcb_penalty)\n",
        "\n",
        "  # print(f'{seed}: {min_lcb}, {singular_value}')\n",
        "  return (min_lcb < 0.), (singular_value < 1.)\n",
        "\n",
        "num_good_singular_values = 0\n",
        "bad_lcbs = 0\n",
        "num_good_examples = 0\n",
        "for i in range(1000):\n",
        "  bad_lcb, good_sv = check_mdp_result(i)\n",
        "  bad_lcbs += int(bad_lcb)\n",
        "  num_good_singular_values += int(good_sv)\n",
        "  num_good_examples += int(bad_lcb and good_sv)\n",
        "\n",
        "print(f'Number of examples found: {num_good_examples}')\n"
      ],
      "metadata": {
        "id": "VoynCxIdsyAW",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1643368914511,
          "user_tz": 480,
          "elapsed": 14814,
          "user": {
            "displayName": "Kamyar Ghasemipour",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhGM_ZPzw9GLH_fS8zuVmKXjoOReStH2VWVETyL=s64",
            "userId": "14496833910239793873"
          }
        },
        "outputId": "97581fc2-1716-482a-961d-962e7d924acb"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Number of examples found: 221\n"
          ]
        }
      ]
    }
  ]
}