{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "toy_pw.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "6ab874a911a14859a62975d12faa4a73": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_3e17d01a470b446187ff13ef3bdeb976",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_a3538dcf32aa44cca1e888351db37021",
              "IPY_MODEL_6d6a956a14e94a859435cb5298cb5950"
            ]
          }
        },
        "3e17d01a470b446187ff13ef3bdeb976": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a3538dcf32aa44cca1e888351db37021": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_37850fd921fa4dab9785cab2b3e82117",
            "_dom_classes": [],
            "description": "  0%",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 500,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 0,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_d759a16893b8403b9f5ed14cad0cd278"
          }
        },
        "6d6a956a14e94a859435cb5298cb5950": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_7b9e789dcf7e41ee87fcba3e427ed803",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 0/500 [00:00&lt;?, ?it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_bf1a5b53bea14396a1411575b8fe785c"
          }
        },
        "37850fd921fa4dab9785cab2b3e82117": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "d759a16893b8403b9f5ed14cad0cd278": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "7b9e789dcf7e41ee87fcba3e427ed803": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "bf1a5b53bea14396a1411575b8fe785c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "6ab874a911a14859a62975d12faa4a73",
            "3e17d01a470b446187ff13ef3bdeb976",
            "a3538dcf32aa44cca1e888351db37021",
            "6d6a956a14e94a859435cb5298cb5950",
            "37850fd921fa4dab9785cab2b3e82117",
            "d759a16893b8403b9f5ed14cad0cd278",
            "7b9e789dcf7e41ee87fcba3e427ed803",
            "bf1a5b53bea14396a1411575b8fe785c"
          ]
        },
        "id": "dPWuocIbmnxp",
        "outputId": "16ee8f59-d3dd-4a0f-e591-416fe0121659"
      },
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from torch import nn\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "from torch.nn.utils import spectral_norm\n",
        "from collections import defaultdict\n",
        "from tqdm.notebook import tqdm\n",
        "\n",
        "class GMMSampler:\n",
        "  def __init__(self, n_components, dim):\n",
        "    self.categorical = np.random.dirichlet(np.ones((n_components,)))\n",
        "    self.means = np.random.normal(scale=1, size=(n_components, dim))\n",
        "    tmp = np.random.normal(scale=0.1, size=(n_components, dim, dim))\n",
        "    self.var = np.einsum('...ij,...ik->...jk', tmp, tmp)\n",
        "    self.indices = np.arange(n_components)\n",
        "\n",
        "  def sample(self, n_samples):\n",
        "    events = np.random.choice(self.indices, p=self.categorical, size=n_samples)\n",
        "    unique, counts = np.unique(events, return_counts=True)\n",
        "    samples = []\n",
        "    for i, n_events in zip(unique, counts):\n",
        "      if n_events == 0:\n",
        "        continue\n",
        "      samples.append(np.random.multivariate_normal(self.means[i], self.var[i], size=n_events))\n",
        "    return np.concatenate(samples)\n",
        "\n",
        "class LinearModel(nn.Module):\n",
        "  def __init__(self, input_dim, hidden_dim=50):\n",
        "    super().__init__()\n",
        "    self.input_dim = input_dim\n",
        "    self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
        "    self.fc2 = nn.Linear(hidden_dim, 1)\n",
        "\n",
        "  def forward(self, x):\n",
        "    x = self.fc1(x)\n",
        "    return 10 * self.fc2(x)\n",
        "\n",
        "\n",
        "class SinusoidalModel(nn.Module):\n",
        "  def __init__(self, input_dim, hidden_dim=50):\n",
        "    super().__init__()\n",
        "    self.input_dim = input_dim\n",
        "    self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
        "    self.fc2 = nn.Linear(hidden_dim, 1)\n",
        "\n",
        "  def forward(self, x):\n",
        "    x = torch.sin(self.fc1(x))\n",
        "    return 10 * self.fc2(x)\n",
        "\n",
        "\n",
        "class ReluModel(nn.Module):\n",
        "  def __init__(self, input_dim, hidden_dim=50):\n",
        "    super().__init__()\n",
        "    self.input_dim = input_dim\n",
        "    self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
        "    self.fc2 = nn.Linear(hidden_dim, 1)\n",
        "\n",
        "  def forward(self, x):\n",
        "    x = torch.nn.functional.relu(self.fc1(x))\n",
        "    return 10 * self.fc2(x)\n",
        "\n",
        "\n",
        "class SigmoidModel(nn.Module):\n",
        "  def __init__(self, input_dim, hidden_dim=50):\n",
        "    super().__init__()\n",
        "    self.input_dim = input_dim\n",
        "    self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
        "    self.fc2 = nn.Linear(hidden_dim, 1)\n",
        "\n",
        "  def forward(self, x):\n",
        "    x = torch.sigmoid(self.fc1(x))\n",
        "    return 10 * self.fc2(x)\n",
        "\n",
        "\n",
        "class MRP:\n",
        "  def __init__(self, sampler, model, n_v0_samples=5, s0_var=0.01):\n",
        "    self.sampler = sampler\n",
        "    self.model = model\n",
        "    self.v0 = model(torch.from_numpy(sampler.sample(n_v0_samples)).float()).mean().item()\n",
        "    self.s0_var = s0_var\n",
        "\n",
        "  def sample(self, s=None, n_samples=1):\n",
        "    if s is None:\n",
        "      return np.random.normal(0, 0.01, size=(n_samples,))\n",
        "    return s + self.sampler.sample(n_samples)\n",
        "\n",
        "  def evaluate_s0(self, s0):\n",
        "    return self.v0 + s0\n",
        "\n",
        "  def evaluate_s1(self, s1):\n",
        "    return self.model(torch.from_numpy(s1).float()).detach().numpy()\n",
        "\n",
        "\n",
        "class PW:\n",
        "  def __init__(self, exp, factor=1):\n",
        "    self.exp = exp\n",
        "    self.n_visits = defaultdict(int)\n",
        "    self.children = defaultdict(list)\n",
        "    self.factor = factor\n",
        "\n",
        "  def should_expand(self, state):\n",
        "    return len(self.children[state]) <= self.factor * self.n_visits[state] ** self.exp\n",
        "\n",
        "  def sample_from(self, state, mrp):\n",
        "    if state is None:\n",
        "      self.n_visits[state] += 1\n",
        "    if self.should_expand(state):\n",
        "      if state is None:\n",
        "        next_state = mrp.sample().flatten()\n",
        "      else:\n",
        "        next_state = mrp.sample(np.array(state[1])).flatten()\n",
        "      next_state = (state, tuple(next_state))\n",
        "      self.children[state].append(next_state)\n",
        "      self.n_visits[next_state] += 1\n",
        "      return next_state, True\n",
        "    else:\n",
        "      next_state = self.children[state][np.random.choice(np.arange(len(self.children[state])))]\n",
        "      self.n_visits[next_state] += 1\n",
        "      return next_state, False\n",
        "\n",
        "\n",
        "class AR:\n",
        "  def __init__(self, exp, factor):\n",
        "    self.exp = exp\n",
        "    self.factor = factor\n",
        "    self.n_visits = defaultdict(int)\n",
        "    self.children = defaultdict(list)\n",
        "\n",
        "  def sample_from(self, state, sampler):\n",
        "    if state is None:\n",
        "      self.n_visits[state] += 1\n",
        "      new_state = mrp.sample().flatten()\n",
        "    else:\n",
        "      new_state = mrp.sample(np.array(state[1])).flatten()\n",
        "    distances = {child: np.sqrt(((np.array(child[1]) - new_state) ** 2).sum()) for child in self.children[state]}\n",
        "    if len(distances) == 0:\n",
        "      next_state = (state, tuple(new_state))\n",
        "      self.children[state].append(next_state)\n",
        "      self.n_visits[next_state] += 1\n",
        "      return next_state, True\n",
        "    nn = min(distances, key=distances.get)\n",
        "    dist = distances[nn]\n",
        "    if dist < self.factor * (self.n_visits[nn] ** (- self.exp)):\n",
        "      next_state = nn\n",
        "      self.n_visits[next_state] += 1\n",
        "      return next_state, False\n",
        "    next_state = (state, tuple(new_state))\n",
        "    self.children[state].append(next_state)\n",
        "    self.n_visits[next_state] += 1\n",
        "    return next_state, True\n",
        "    \n",
        "\n",
        "def evaluate(method, mrp, n_rollouts):\n",
        "  vals = []\n",
        "  for t in range(n_rollouts):\n",
        "    s_init = None\n",
        "    s0, new_state = method.sample_from(s_init, mrp)\n",
        "    if new_state:\n",
        "      vals.append(mrp.evaluate_s0(np.array(s0[1])))\n",
        "      continue\n",
        "    s1, new_state = method.sample_from(s0, mrp)\n",
        "    vals.append(mrp.evaluate_s1(np.array(s1[1])))\n",
        "  return np.mean(vals)\n",
        "\n",
        "def evaluate_empirical(sampler, model, supp_card):\n",
        "  points = sampler.sample(supp_card)\n",
        "  with torch.no_grad():\n",
        "    vals = model(torch.from_numpy(points).float()).squeeze().numpy()\n",
        "  return vals.mean()\n",
        "  \n",
        "# Experiment configuration.\n",
        "SEED = 0  # fixed so that Restart & Run All reproduces the saved results\n",
        "n_components = 10\n",
        "dim = 30\n",
        "ground_truth_n_samples = 100_000\n",
        "n_trials = 500\n",
        "num_rollouts = [10, 20, 50, 100, 200, 500]\n",
        "pw_factors = [0.5, 1, 2, 5, 10, 20, 50, 100]\n",
        "pw_exponents = [0.1, 0.3, 0.5, 0.7, 0.9]\n",
        "ar_factor = 1\n",
        "ar_exponent = 0.1\n",
        "\n",
        "vfs = {'Relu': ReluModel}\n",
        "\n",
        "\n",
        "def record_run(df, alg, method, exp, k, nr, vf_name, mrp, ground_truth):\n",
        "  \"\"\"Evaluate `alg` with `nr` rollouts on `mrp` and append one row to `df`.\n",
        "\n",
        "  `df` is the dict-of-lists results table. The children counts are read after\n",
        "  evaluate() because that call is what grows the tree stored in `alg`.\n",
        "  \"\"\"\n",
        "  value_error = abs(evaluate(alg, mrp, nr) - ground_truth)\n",
        "  df['method'].append(method)\n",
        "  df['exp'].append(exp)\n",
        "  df['k'].append(k)\n",
        "  df['num_rollouts'].append(nr)\n",
        "  df['value_error'].append(value_error)\n",
        "  df['level_one_children'].append(len(alg.children[None]))\n",
        "  df['level_two_children'].append(sum(len(val) for key, val in alg.children.items() if key is not None))\n",
        "  df['vf'].append(vf_name)\n",
        "\n",
        "\n",
        "def save_plots(results):\n",
        "  \"\"\"Save value-error and tree-size line plots (PNG and PDF), one facet per k.\"\"\"\n",
        "  specs = [\n",
        "      ('value_error', 'mrp', True),\n",
        "      ('level_one_children', 'level_one', False),\n",
        "      ('level_two_children', 'level_two', False),\n",
        "  ]\n",
        "  for y, stem, log_y in specs:\n",
        "    grid = sns.relplot(data=results, x='num_rollouts', y=y, hue='method', kind='line', col='k', facet_kws={'sharey': True, 'sharex': True})\n",
        "    if log_y:\n",
        "      grid.set(yscale='log')\n",
        "    plt.savefig(f'{stem}.png')\n",
        "    plt.savefig(f'{stem}.pdf')\n",
        "    # Close each figure: this runs once per trial and figures would pile up.\n",
        "    plt.close()\n",
        "\n",
        "\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "df = {'method': [], 'value_error': [], 'num_rollouts': [], 'vf': [], 'level_one_children': [], 'level_two_children': [], 'k': [], 'exp': []}\n",
        "for _ in tqdm(range(n_trials)):\n",
        "  sampler = GMMSampler(n_components, dim)\n",
        "  for vf_name, vf_class in vfs.items():\n",
        "    model = vf_class(dim)\n",
        "    mrp = MRP(sampler, model)\n",
        "    ground_truth = evaluate_empirical(sampler, model, ground_truth_n_samples)\n",
        "    for nr in num_rollouts:\n",
        "      for e in pw_exponents:\n",
        "        for k in pw_factors:\n",
        "          record_run(df, PW(e, k), f'PW{e}', e, k, nr, vf_name, mrp, ground_truth)\n",
        "      # AR rows carry k = ar_factor, so AR only shows up in that facet of the plots.\n",
        "      record_run(df, AR(ar_exponent, ar_factor), 'AR', ar_exponent, ar_factor, nr, vf_name, mrp, ground_truth)\n",
        "  # Results and figures are rewritten after every trial, so an interrupted run\n",
        "  # still leaves the latest checkpoint on disk.\n",
        "  df_ = pd.DataFrame(df)\n",
        "  plt.close()\n",
        "  df_.to_pickle('df.pkl')\n",
        "  df_['value_error'] = df_['value_error'].astype(float)\n",
        "  save_plots(df_)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6ab874a911a14859a62975d12faa4a73",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AnoIFxfTOC3K"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}