{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VxxB73fm2z7i"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sVGPqbuAm4qJ",
        "outputId": "77e9c4e0-5422-4f44-f33e-d565f45dab09"
      },
      "outputs": [],
      "source": [
        "# Optional dependency installation (e.g. on a fresh Colab runtime).\n",
        "# Uncomment what you need. Prefer `%pip` over `!pip` so that packages are\n",
        "# installed into the environment of the running kernel.\n",
        "\n",
        "# !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
        "# %pip install -q mediapy dm_env\n",
        "# %pip install -q tensorflow==2.16\n",
        "# %pip install -q git+https://github.com/deepmind/dm-haiku\n",
        "\n",
        "# The JAX substrate of TensorFlow Probability (used below) needs the [jax] extra:\n",
        "# %pip install -q --upgrade \"tensorflow-probability[jax]\" jax jaxlib"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 347
        },
        "id": "HPfbzUlRtTjE",
        "outputId": "23c19ccd-3e44-45e4-b16d-dca62f4dcf45"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "# Set before jax is imported; an empty value leaves backend selection to JAX.\n",
        "os.environ['JAX_PLATFORMS'] = ''\n",
        "\n",
        "import io\n",
        "import pickle\n",
        "from typing import Optional, Any, NamedTuple, Callable, Mapping, Sequence, Dict, Tuple\n",
        "import dataclasses\n",
        "import enum\n",
        "import time\n",
        "from collections import defaultdict\n",
        "from pprint import pprint\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas\n",
        "from PIL import Image\n",
        "import mediapy as media\n",
        "from IPython.display import clear_output\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "import haiku as hk\n",
        "\n",
        "# Import the JAX substrate submodule explicitly: `import tensorflow_probability`\n",
        "# alone does not necessarily populate `tensorflow_probability.substrates`\n",
        "# (accessing it that way raised AttributeError). Needs the\n",
        "# `tensorflow-probability[jax]` extra.\n",
        "import tensorflow_probability.substrates.jax\n",
        "tfp = tensorflow_probability.substrates.jax\n",
        "tfd = tfp.distributions\n",
        "\n",
        "import dm_env\n",
        "from dm_env import specs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q8RKBq1t_FxZ"
      },
      "outputs": [],
      "source": [
        "# Sanity check that the JAX substrate of TFP works: a standard normal has\n",
        "# log-density -0.5 * log(2 * pi) (~ -0.9189) at 0.\n",
        "std_normal_log_prob = tfd.Normal(loc=0., scale=1.).log_prob(0.)\n",
        "assert jnp.allclose(std_normal_log_prob, -0.5 * jnp.log(2. * jnp.pi), atol=1e-5)\n",
        "std_normal_log_prob"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z48CKYN8Zb1k"
      },
      "source": [
        "# Create the pointmass environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zEhJawAQZec3"
      },
      "outputs": [],
      "source": [
        "BOUNDS_X = np.array([-1., 1.], dtype=np.float32)\n",
        "BOUNDS_Y = np.array([-1., 1.], dtype=np.float32)\n",
        "\n",
        "DPI = 100\n",
        "RENDER_HEIGHT_INCHES = 5\n",
        "\n",
        "class Point2D(dm_env.Environment):\n",
        "  \"\"\"A 2D point mass that has to reach a goal position.\n",
        "\n",
        "  Each `step` applies the action for `_physics_substeps` substeps: the action\n",
        "  is added to the velocity and the velocity to the position. The episode ends\n",
        "  (StepType.LAST) once the point is within `_success_radius` of the goal, and\n",
        "  the reward is the negative distance to the goal. Positions are not clipped\n",
        "  to BOUNDS_X / BOUNDS_Y. Start and goal positions are drawn from numpy's\n",
        "  global RNG.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    self._cur_pos = np.zeros(2, dtype=np.float32)\n",
        "    self._goal_pos = np.zeros(2, dtype=np.float32)\n",
        "    self._cur_vel = np.zeros(2, dtype=np.float32)\n",
        "    self._cur_episode_traj = []  # positions visited this episode, for render()\n",
        "    self._physics_substeps = 10\n",
        "    self._success_radius = 0.15\n",
        "\n",
        "  def sample_goal(self):\n",
        "    \"\"\"Samples a goal uniformly, keeping 5% of each range as margin from the bounds.\"\"\"\n",
        "    border_x = (BOUNDS_X[1] - BOUNDS_X[0]) * 0.05\n",
        "    border_y = (BOUNDS_Y[1] - BOUNDS_Y[0]) * 0.05\n",
        "    goal_x = np.random.uniform(\n",
        "        BOUNDS_X[0] + border_x, BOUNDS_X[1] - border_x)\n",
        "    goal_y = np.random.uniform(\n",
        "        BOUNDS_Y[0] + border_y, BOUNDS_Y[1] - border_y)\n",
        "    return np.array([goal_x, goal_y], dtype=np.float32)\n",
        "\n",
        "  def set_goal(self, goal_pos):\n",
        "    \"\"\"Replaces the goal with a float32 copy of `goal_pos`.\"\"\"\n",
        "    # Copy so later in-place changes to the caller's array cannot move the goal.\n",
        "    self._goal_pos = np.array(goal_pos, dtype=np.float32)\n",
        "\n",
        "  def reset(self):\n",
        "    \"\"\"Samples a new goal and a uniformly random start position at rest.\"\"\"\n",
        "    self._goal_pos = self.sample_goal()\n",
        "    cur_x = np.random.uniform(BOUNDS_X[0], BOUNDS_X[1])\n",
        "    cur_y = np.random.uniform(BOUNDS_Y[0], BOUNDS_Y[1])\n",
        "    self._cur_pos = np.array([cur_x, cur_y], dtype=np.float32)\n",
        "    cur_pos_copy = self._cur_pos.copy()\n",
        "\n",
        "    self._cur_vel = np.zeros(2, dtype=np.float32)\n",
        "    cur_vel_copy = self._cur_vel.copy()\n",
        "\n",
        "    obs = {\n",
        "        'cur_pos': cur_pos_copy,\n",
        "        'cur_vel': cur_vel_copy,\n",
        "        'goal_pos': self._goal_pos.copy()}\n",
        "    ts = dm_env.TimeStep(\n",
        "        step_type=dm_env.StepType.FIRST,\n",
        "        reward=None,\n",
        "        discount=None,\n",
        "        observation=obs,)\n",
        "\n",
        "    self._cur_episode_traj = [cur_pos_copy]\n",
        "    return ts\n",
        "\n",
        "  def step(self, action):\n",
        "    \"\"\"Applies `action` for every physics substep and returns the new TimeStep.\"\"\"\n",
        "    for _ in range(self._physics_substeps):\n",
        "      self._cur_vel += action\n",
        "      self._cur_pos += self._cur_vel\n",
        "\n",
        "    cur_pos_copy = self._cur_pos.copy()\n",
        "    cur_vel_copy = self._cur_vel.copy()\n",
        "    obs = {\n",
        "        'cur_pos': cur_pos_copy,\n",
        "        'cur_vel': cur_vel_copy,\n",
        "        'goal_pos': self._goal_pos.copy()}\n",
        "\n",
        "    if self.success():\n",
        "      step_type = dm_env.StepType.LAST\n",
        "    else:\n",
        "      step_type = dm_env.StepType.MID\n",
        "    ts = dm_env.TimeStep(\n",
        "        step_type=step_type,\n",
        "        reward=-1. * np.linalg.norm(self._cur_pos - self._goal_pos),\n",
        "        discount=1.,\n",
        "        observation=obs,)\n",
        "\n",
        "    self._cur_episode_traj.append(cur_pos_copy)\n",
        "    return ts\n",
        "\n",
        "  def success(self, waypoint: Optional[np.ndarray] = None):\n",
        "    \"\"\"Whether the point is within the success radius of `waypoint` (default: the goal).\"\"\"\n",
        "    if waypoint is not None:\n",
        "      goal_pos = waypoint\n",
        "    else:\n",
        "      goal_pos = self._goal_pos\n",
        "    return np.linalg.norm(self._cur_pos - goal_pos) < self._success_radius\n",
        "\n",
        "  def observation_spec(self):\n",
        "    return {\n",
        "        'cur_pos': specs.Array((2,), dtype=np.float32),\n",
        "        'cur_vel': specs.Array((2,), dtype=np.float32),\n",
        "        'goal_pos': specs.Array((2,), dtype=np.float32),}\n",
        "\n",
        "  def action_spec(self):\n",
        "    return specs.Array((2,), dtype=np.float32)\n",
        "\n",
        "  def render(self, title=''):\n",
        "    \"\"\"Draws the episode trajectory, goal and current position as a uint8 RGB image.\"\"\"\n",
        "    fig, ax = plt.subplots(\n",
        "        figsize=(RENDER_HEIGHT_INCHES, RENDER_HEIGHT_INCHES), dpi=DPI)\n",
        "    ax.set_xlim(BOUNDS_X[0], BOUNDS_X[1])\n",
        "    ax.set_ylim(BOUNDS_Y[0], BOUNDS_Y[1])\n",
        "    ax.set_aspect('equal')\n",
        "\n",
        "    points = np.array(self._cur_episode_traj)\n",
        "    ax.plot(points[:, 0], points[:, 1], marker='.', color='blue')\n",
        "    ax.scatter(\n",
        "        self._goal_pos[0], self._goal_pos[1], marker='*', s=200, color='orange')\n",
        "    ax.scatter(\n",
        "        self._cur_pos[0], self._cur_pos[1], marker='o', s=100, color='red')\n",
        "\n",
        "    if title != '':\n",
        "      ax.set_title(title)\n",
        "\n",
        "    plt.tight_layout()\n",
        "\n",
        "    # Render with the Agg canvas and copy out the RGB channels. `tostring_rgb`\n",
        "    # is deprecated since matplotlib 3.8 and removed in later releases, so read\n",
        "    # the RGBA buffer, drop alpha and copy before the figure is closed.\n",
        "    canvas = FigureCanvas(fig)\n",
        "    canvas.draw()\n",
        "    image = np.ascontiguousarray(np.asarray(canvas.buffer_rgba())[:, :, :3])\n",
        "\n",
        "    plt.close(fig)\n",
        "    return image\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DtApnonezrGQ"
      },
      "source": [
        "# Create the PD controller for data generation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a82KfzE7zvDG"
      },
      "outputs": [],
      "source": [
        "def pd_controller(cur_pos, cur_vel, goal_pos):\n",
        "  \"\"\"PD control law used to generate demonstrations.\n",
        "\n",
        "  The returned action is proportional to the position error towards\n",
        "  `goal_pos` and damped by the current velocity.\n",
        "  \"\"\"\n",
        "  kp, kd = 0.0002, 0.0125  # proportional and derivative gains\n",
        "  position_error = goal_pos - cur_pos\n",
        "  return kp * position_error - kd * cur_vel"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VelqDWf40BgB"
      },
      "source": [
        "# Visualize Trajectories from the PD Controller"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B83AbBxwjbBM"
      },
      "outputs": [],
      "source": [
        "# Roll out the PD controller towards several successive goals in one episode.\n",
        "num_goals = 7\n",
        "max_steps_per_goal = 1000  # safety cap in case a rollout never converges\n",
        "\n",
        "env = Point2D()\n",
        "\n",
        "imgs = []\n",
        "ts = env.reset()\n",
        "goal_pos = ts.observation['goal_pos']\n",
        "for _ in range(num_goals):\n",
        "  for _ in range(max_steps_per_goal):\n",
        "    if env.success():\n",
        "      break\n",
        "    act = pd_controller(\n",
        "        ts.observation['cur_pos'], ts.observation['cur_vel'], goal_pos)\n",
        "    ts = env.step(act)\n",
        "    imgs.append(env.render())\n",
        "\n",
        "  # Keep our own handle on the new goal: `ts.observation['goal_pos']` still\n",
        "  # holds the previous goal until the next env.step.\n",
        "  goal_pos = env.sample_goal()\n",
        "  env.set_goal(goal_pos)\n",
        "\n",
        "media.show_video(imgs, fps=10)"
      ]
    },
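    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fhfQQeVvVZ1d"
      },
      "source": [
        "# Generate a Dataset\n",
        "In each episode of the dataset, the pointmass first visits `num_waypoints_per_episode` randomly sampled waypoints and then heads to the episode's goal position. An episode ends as soon as the goal is reached (even if that happens on the way to a waypoint), and episodes with fewer than `episode_len_discard_thresh` transitions are discarded."
      ]
    },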
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gw4QYwZQVZid"
      },
      "outputs": [],
      "source": [
        "NestedArray = Any\n",
        "\n",
        "num_episodes = 10000 #@param {type: \"number\"}\n",
        "num_waypoints_per_episode = 5 #@param {type: \"number\"}\n",
        "episode_len_discard_thresh = 10 #@param {type: \"number\"}\n",
        "dataset_seed = 0 #@param {type: \"number\"}\n",
        "\n",
        "# Each transition is stored as (observation, action, time_to_success, reward,\n",
        "# discount, next_observation). Storing next_observation is inefficient but\n",
        "# makes downstream code simpler. Two versions are built:\n",
        "#   1) `episodes`: one stacked DataTuple per trajectory,\n",
        "#   2) `all_tuples`: the transitions of all episodes concatenated.\n",
        "\n",
        "class DataTuple(NamedTuple):\n",
        "  observation: NestedArray\n",
        "  action: NestedArray\n",
        "  time_to_success: NestedArray  # transitions left until the goal is reached\n",
        "  reward: NestedArray\n",
        "  discount: NestedArray  # 0. on the final (goal-reached) tuple, 1. otherwise\n",
        "  next_observation: NestedArray\n",
        "\n",
        "\n",
        "# Point2D draws starts and goals from numpy's global RNG: seed it so the\n",
        "# dataset is reproducible, and use a fresh env rather than the one left over\n",
        "# from the visualization cell.\n",
        "np.random.seed(dataset_seed)\n",
        "env = Point2D()\n",
        "\n",
        "episodes = []\n",
        "while len(episodes) < num_episodes:\n",
        "  traj = []\n",
        "\n",
        "  ts = env.reset()\n",
        "  cur_obs = ts.observation\n",
        "  succ = env.success()\n",
        "\n",
        "  # Visit `num_waypoints_per_episode` random waypoints, then the real goal.\n",
        "  waypoint_idx = 0\n",
        "  if num_waypoints_per_episode == 0:\n",
        "    cur_waypoint = cur_obs['goal_pos']\n",
        "  else:\n",
        "    cur_waypoint = env.sample_goal()\n",
        "  waypoint_succ = env.success(waypoint=cur_waypoint)\n",
        "\n",
        "  while not succ:\n",
        "    if waypoint_succ:\n",
        "      waypoint_idx = min(waypoint_idx + 1, num_waypoints_per_episode)\n",
        "      if waypoint_idx == num_waypoints_per_episode:\n",
        "        cur_waypoint = cur_obs['goal_pos']\n",
        "      else:\n",
        "        cur_waypoint = env.sample_goal()\n",
        "\n",
        "    act = pd_controller(\n",
        "        cur_obs['cur_pos'], cur_obs['cur_vel'], cur_waypoint)\n",
        "    ts = env.step(act)\n",
        "    next_obs = ts.observation\n",
        "\n",
        "    traj.append(DataTuple(\n",
        "        observation=cur_obs,\n",
        "        action=act,\n",
        "        time_to_success=0.,  # placeholder, relabelled after stacking\n",
        "        reward=ts.reward if ts.reward is not None else 0.,\n",
        "        discount=1.,\n",
        "        next_observation=next_obs,))\n",
        "\n",
        "    cur_obs = next_obs\n",
        "    succ = env.success()\n",
        "    waypoint_succ = env.success(waypoint=cur_waypoint)\n",
        "\n",
        "  # add the last (goal-reached) timestep as a terminal self-transition\n",
        "  act = pd_controller(\n",
        "      cur_obs['cur_pos'], cur_obs['cur_vel'], cur_waypoint)\n",
        "  traj.append(DataTuple(\n",
        "      observation=cur_obs,\n",
        "      action=act,\n",
        "      time_to_success=0.,\n",
        "      reward=ts.reward if ts.reward is not None else 0.,\n",
        "      discount=0.,\n",
        "      next_observation=cur_obs,))\n",
        "\n",
        "  # discard if episode is too short\n",
        "  traj_len = len(traj)\n",
        "  if traj_len < episode_len_discard_thresh:\n",
        "    continue\n",
        "\n",
        "  # stack the traj arrays\n",
        "  new_traj = jax.tree.map(\n",
        "      lambda *xs: np.stack(xs, dtype=np.float32), *traj)\n",
        "\n",
        "  # label traj arrays with time to success: traj_len - 1, ..., 1, 0\n",
        "  new_traj = new_traj._replace(\n",
        "      time_to_success=np.arange(\n",
        "          traj_len - 1, -1, -1, dtype=np.float32))\n",
        "  episodes.append(new_traj)\n",
        "\n",
        "# concat all the trajs\n",
        "all_tuples = jax.tree.map(\n",
        "    lambda *xs: np.concatenate(xs, dtype=np.float32), *episodes)\n",
        "\n",
        "episode_lens = np.array(\n",
        "    [ep.observation['cur_pos'].shape[0] for ep in episodes])\n",
        "assert all_tuples.observation['cur_pos'].shape[0] == episode_lens.sum()\n",
        "\n",
        "# print stats\n",
        "print('\\nEpisodes 0-1:')\n",
        "pprint(jax.tree.map(lambda x: (x.shape, x.dtype), episodes[:2]))\n",
        "\n",
        "print('\\nData Tuples:')\n",
        "pprint(jax.tree.map(lambda x: (x.shape, x.dtype), all_tuples))\n",
        "\n",
        "print('\\n')\n",
        "print(f'Num Episodes: {episode_lens.shape[0]}')\n",
        "print(f'Episode Lens: {np.mean(episode_lens)} +/- {np.std(episode_lens)}')\n",
        "print(f'Max Episode Len: {np.max(episode_lens)}')\n",
        "print(f'Min Episode Len: {np.min(episode_lens)}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xfQfj0e24_te"
      },
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots()\n",
        "ax.hist(episode_lens, bins=30)\n",
        "ax.set(\n",
        "    title='Histogram of Episode Lengths in the Dataset',\n",
        "    xlabel='Episode length (transitions)',\n",
        "    ylabel='Number of episodes')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "41eBwCqyjDjg"
      },
      "outputs": [],
      "source": [
        "def render_frame(\n",
        "    points: np.ndarray,\n",
        "    goal_pos: np.ndarray,\n",
        "    title: str=''):\n",
        "  \"\"\"Renders a (partial) trajectory as an RGB image.\n",
        "\n",
        "  Args:\n",
        "    points: visited positions, indexed as points[:, 0] / points[:, 1]\n",
        "      (presumably shape (T, 2)); the last row is drawn as the current position.\n",
        "    goal_pos: goal position, drawn as a star.\n",
        "    title: optional plot title.\n",
        "\n",
        "  Returns:\n",
        "    A uint8 array of shape (height, width, 3).\n",
        "  \"\"\"\n",
        "  fig, ax = plt.subplots(figsize=(RENDER_HEIGHT_INCHES, RENDER_HEIGHT_INCHES), dpi=DPI)\n",
        "  ax.set_xlim(BOUNDS_X[0], BOUNDS_X[1])\n",
        "  ax.set_ylim(BOUNDS_Y[0], BOUNDS_Y[1])\n",
        "  ax.set_aspect('equal')\n",
        "\n",
        "  ax.plot(points[:, 0], points[:, 1], marker='.', color='blue')\n",
        "  ax.scatter(\n",
        "      goal_pos[0], goal_pos[1], marker='*', s=200, color='orange')\n",
        "  cur_pos = points[-1]\n",
        "  ax.scatter(\n",
        "      cur_pos[0], cur_pos[1], marker='o', s=100, color='red')\n",
        "\n",
        "  if title != '':\n",
        "    ax.set_title(title)\n",
        "\n",
        "  plt.tight_layout()\n",
        "\n",
        "  # Render with the Agg canvas and copy out the RGB channels. `tostring_rgb`\n",
        "  # is deprecated since matplotlib 3.8 and removed in later releases, so read\n",
        "  # the RGBA buffer, drop alpha and copy before the figure is closed.\n",
        "  canvas = FigureCanvas(fig)\n",
        "  canvas.draw()\n",
        "  image = np.ascontiguousarray(np.asarray(canvas.buffer_rgba())[:, :, :3])\n",
        "\n",
        "  plt.close(fig)\n",
        "  return image\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XpFrajDQ1DzD"
      },
      "outputs": [],
      "source": [
        "# Render one dataset episode frame by frame to make sure things look good.\n",
        "episode_num = 4\n",
        "debug_episode = episodes[episode_num]\n",
        "points = debug_episode.observation['cur_pos']\n",
        "# The goal does not change within a dataset episode, so take the first one.\n",
        "goal_pos = debug_episode.observation['goal_pos'][0]\n",
        "frame_title = f'Dataset Episode {episode_num}'\n",
        "imgs = [\n",
        "    render_frame(points[:t + 1], goal_pos, frame_title)\n",
        "    for t in range(len(points))\n",
        "]\n",
        "\n",
        "media.show_video(imgs, fps=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wLMuvyczjNM6"
      },
      "outputs": [],
      "source": [
        "save_datasets = True #@param {type: \"boolean\"}\n",
        "save_path = 'pointmass_dataset.pkl' #@param {type: \"string\"}\n",
        "\n",
        "if save_datasets:\n",
        "  head, tail = os.path.splitext(save_path)\n",
        "  trajs_path = f'{head}_trajs{tail}'\n",
        "  tuples_path = f'{head}_tuples{tail}'\n",
        "  # Use the built-in open(): tf.io.gfile is never imported in this notebook\n",
        "  # (its import is commented out), so `gfile` raised NameError.\n",
        "  # NOTE: only unpickle these files from trusted sources.\n",
        "\n",
        "  # save version with trajs\n",
        "  with open(trajs_path, 'wb') as fp:\n",
        "    pickle.dump(episodes, fp)\n",
        "\n",
        "  # save version with just data tuples\n",
        "  with open(tuples_path, 'wb') as fp:\n",
        "    pickle.dump(all_tuples, fp)\n",
        "  print(f'Saved {trajs_path} and {tuples_path}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4UGQw5JK67l5"
      },
      "source": [
        "# Create the networks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8vg2tP_c6_oA"
      },
      "outputs": [],
      "source": [
        "# Type aliases: all resolve to Any at runtime and only serve as documentation.\n",
        "Params = Any\n",
        "PRNGKey = Any\n",
        "Observation = Any\n",
        "Action = Any\n",
        "NetworkOutput = Any\n",
        "Entropy = Any\n",
        "FeedForwardPolicyWithExtra = Any\n",
        "LogProbFn = Any\n",
        "SampleFn = Any\n",
        "\n",
        "# Parameters of the action and distance-to-success distributions.\n",
        "ActDistParams = Params\n",
        "DistanceToSuccessDistParams = Params\n",
        "\n",
        "# entropy_fn(dist_params, key) -> entropy\n",
        "EntropyFn = Callable[[Params, PRNGKey], Entropy]\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class FeedForwardNetwork:\n",
        "  \"\"\"Bundles the two pure functions that define a feed-forward network.\n",
        "\n",
        "  Attributes:\n",
        "    init: pure JAX function creating the parameters,\n",
        "      ``params = init(rng, *a, **k)``.\n",
        "    apply: pure JAX function running a forward pass,\n",
        "      ``out = apply(params, *a, **k)`` (some networks also take an rng).\n",
        "  \"\"\"\n",
        "  init: Callable[..., Params]\n",
        "  apply: Callable[..., NetworkOutput]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XVLCjIzr7BZZ"
      },
      "outputs": [],
      "source": [
        "MIN_ACT_SCALE = 1e-2 #@param {type: \"number\"}\n",
        "\n",
        "\n",
        "class MVNDiagParams(NamedTuple):\n",
        "  \"\"\"Parameters for a diagonal multi-variate normal distribution.\"\"\"\n",
        "  loc: jnp.ndarray\n",
        "  scale_diag: jnp.ndarray\n",
        "\n",
        "\n",
        "class CategoricalParams(NamedTuple):\n",
        "  \"\"\"Parameters for a categorical distribution.\"\"\"\n",
        "  logits: jnp.ndarray\n",
        "\n",
        "\n",
        "class TIMERNetworkOutput(NamedTuple):\n",
        "  \"\"\"Distribution parameters produced by the TIMER network's two heads.\"\"\"\n",
        "  act_dist_params: ActDistParams\n",
        "  dist_to_succ_dist_params: DistanceToSuccessDistParams\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class TIMERNetworks:\n",
        "  \"\"\"Network and pure functions for the TIMER agent.\n",
        "\n",
        "  network: outputs TIMERNetworkOutput\n",
        "  act_log_prob: log probability of an action\n",
        "  act_entropy: optional method for entropy of an action distribution\n",
        "  sample_act: samples an action given [ActDistParams, PRNGKey]\n",
        "  sample_act_mode: optional separate action sampling procedure (used by\n",
        "    make_policy_fn in evaluation mode)\n",
        "  dist_log_prob: log probability of a distance\n",
        "  dist_entropy: optional method for entropy of a distance distribution\n",
        "  sample_dist: samples a distance given [DistanceToSuccessDistParams, PRNGKey]\n",
        "  sample_dist_mode: optional separate distance sampling procedure\n",
        "  \"\"\"\n",
        "  network: FeedForwardNetwork\n",
        "  act_log_prob: LogProbFn\n",
        "  sample_act: SampleFn\n",
        "  dist_log_prob: LogProbFn\n",
        "  sample_dist: SampleFn\n",
        "  act_entropy: Optional[EntropyFn] = None\n",
        "  sample_act_mode: Optional[SampleFn] = None\n",
        "  dist_entropy: Optional[EntropyFn] = None\n",
        "  sample_dist_mode: Optional[SampleFn] = None\n",
        "\n",
        "\n",
        "def make_policy_fn(\n",
        "    timer_networks: TIMERNetworks,\n",
        "    evaluation: bool) -> FeedForwardPolicyWithExtra:\n",
        "  \"\"\"Returns a policy function for the TIMER agent.\n",
        "\n",
        "  The returned function maps (params, key, observations) to (actions, {}).\n",
        "  With `evaluation=True` actions come from `sample_act_mode` (falling back to\n",
        "  `sample_act` if no mode sampler was provided); otherwise from `sample_act`.\n",
        "  \"\"\"\n",
        "  # NOTE: TIMERNetworks has no `sample_act_eval` field; the evaluation-time\n",
        "  # sampler is `sample_act_mode`.\n",
        "  if evaluation and timer_networks.sample_act_mode is not None:\n",
        "    sample_fn = timer_networks.sample_act_mode\n",
        "  else:\n",
        "    sample_fn = timer_networks.sample_act\n",
        "\n",
        "  def _policy_fn(\n",
        "      params: Params,\n",
        "      key: PRNGKey,\n",
        "      observations: Observation,\n",
        "  ):\n",
        "    timer_network_output: TIMERNetworkOutput = timer_networks.network.apply(\n",
        "        params, observations)\n",
        "    actions = sample_fn(timer_network_output.act_dist_params, key)\n",
        "    return actions, {}\n",
        "\n",
        "  return _policy_fn\n",
        "\n",
        "\n",
        "def build_continuous_act_discrete_dist_v0(\n",
        "    layer_sizes: Sequence[int],\n",
        "    act_dim: int,\n",
        "    num_dist_bins: int,\n",
        "    dummy_input,\n",
        ") -> TIMERNetworks:\n",
        "  \"\"\"Builds TIMERNetworks for continuous action and discrete distance.\n",
        "\n",
        "  Two separate MLP torsos (hidden sizes `layer_sizes`) feed an action head\n",
        "  (diagonal Gaussian over `act_dim` dims) and a distance head\n",
        "  (`num_dist_bins` categorical logits).\n",
        "\n",
        "  Args:\n",
        "    layer_sizes: hidden layer sizes used by each MLP torso.\n",
        "    act_dim: action dimensionality.\n",
        "    num_dist_bins: number of distance-to-success bins.\n",
        "    dummy_input: example observation batch, only used to initialize params.\n",
        "  \"\"\"\n",
        "\n",
        "  def _network(\n",
        "      x: Observation) -> TIMERNetworkOutput:\n",
        "    #### Build the action part\n",
        "    h_act = hk.nets.MLP(\n",
        "        output_sizes=layer_sizes,\n",
        "        activation=jax.nn.relu,\n",
        "        activate_final=True,)(x)\n",
        "    # Tiny weight init and zero bias keep the initial action head near zero.\n",
        "    act_loc = hk.Linear(\n",
        "        act_dim,\n",
        "        w_init=hk.initializers.VarianceScaling(1e-4),\n",
        "        b_init=hk.initializers.Constant(0.))(h_act)\n",
        "    act_scale = hk.Linear(\n",
        "        act_dim,\n",
        "        w_init=hk.initializers.VarianceScaling(1e-4),\n",
        "        b_init=hk.initializers.Constant(0.))(h_act)\n",
        "    # softplus(.) + MIN_ACT_SCALE keeps the scale strictly positive.\n",
        "    act_scale = jax.nn.softplus(act_scale) + MIN_ACT_SCALE\n",
        "    act_dist = MVNDiagParams(loc=act_loc, scale_diag=act_scale)\n",
        "\n",
        "    #### Build the distance part\n",
        "    h_dist = hk.nets.MLP(\n",
        "        output_sizes=layer_sizes,\n",
        "        activation=jax.nn.relu,\n",
        "        activate_final=True,)(x)\n",
        "    dist_logits = hk.Linear(num_dist_bins, with_bias=False)(h_dist)\n",
        "    distance_dist = CategoricalParams(logits=dist_logits)\n",
        "\n",
        "    return TIMERNetworkOutput(\n",
        "        act_dist_params=act_dist,\n",
        "        dist_to_succ_dist_params=distance_dist,)\n",
        "\n",
        "  transformed_network = hk.without_apply_rng(hk.transform(_network))\n",
        "  def init_closure(rng: PRNGKey):\n",
        "    return transformed_network.init(rng, dummy_input)\n",
        "  network = FeedForwardNetwork(\n",
        "      init=init_closure,\n",
        "      apply=transformed_network.apply,)\n",
        "\n",
        "  # Distribution helpers: build the tfd distribution from the head's params.\n",
        "  # The `key` argument is ignored by the entropy and mode helpers.\n",
        "  def act_log_prob(params: MVNDiagParams, action):\n",
        "    return tfd.MultivariateNormalDiag(\n",
        "        loc=params.loc, scale_diag=params.scale_diag).log_prob(action)\n",
        "\n",
        "  def act_entropy(\n",
        "      params: MVNDiagParams, key: PRNGKey\n",
        "  ) -> Entropy:\n",
        "    del key\n",
        "    return tfd.MultivariateNormalDiag(\n",
        "        loc=params.loc, scale_diag=params.scale_diag).entropy()\n",
        "\n",
        "  def sample_act(params: MVNDiagParams, key: PRNGKey):\n",
        "    return tfd.MultivariateNormalDiag(\n",
        "        loc=params.loc, scale_diag=params.scale_diag).sample(seed=key)\n",
        "\n",
        "  def sample_act_mode(params: MVNDiagParams, key: PRNGKey):\n",
        "    del key\n",
        "    return tfd.MultivariateNormalDiag(\n",
        "        loc=params.loc, scale_diag=params.scale_diag).mode()\n",
        "\n",
        "  def dist_log_prob(params: CategoricalParams, action):\n",
        "    return tfd.Categorical(logits=params.logits).log_prob(action)\n",
        "\n",
        "  def dist_entropy(\n",
        "      params: CategoricalParams, key: PRNGKey\n",
        "  ) -> Entropy:\n",
        "    del key\n",
        "    return tfd.Categorical(logits=params.logits).entropy()\n",
        "\n",
        "  def sample_dist(params: CategoricalParams, key: PRNGKey):\n",
        "    return tfd.Categorical(logits=params.logits).sample(seed=key)\n",
        "\n",
        "  def sample_dist_mode(params: CategoricalParams, key: PRNGKey):\n",
        "    del key\n",
        "    return tfd.Categorical(logits=params.logits).mode()\n",
        "\n",
        "  return TIMERNetworks(\n",
        "      network=network,\n",
        "      act_log_prob=act_log_prob,\n",
        "      sample_act=sample_act,\n",
        "      dist_log_prob=dist_log_prob,\n",
        "      sample_dist=sample_dist,\n",
        "      act_entropy=act_entropy,\n",
        "      sample_act_mode=sample_act_mode,\n",
        "      dist_entropy=dist_entropy,\n",
        "      sample_dist_mode=sample_dist_mode,\n",
        "  )\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WqIWvPLD-v3e"
      },
      "outputs": [],
      "source": [
        "# Sanity checks: output shapes of every head and distribution helper on a\n",
        "# batch of dummy observations.\n",
        "n_batch, n_obs, n_act, n_bins = 4, 6, 2, 10\n",
        "dummy_obs = np.ones((n_batch, n_obs), dtype=np.float32)\n",
        "sanity_key = jax.random.PRNGKey(42)\n",
        "\n",
        "timer_networks = build_continuous_act_discrete_dist_v0((64, 64), n_act, n_bins, dummy_obs)\n",
        "params = timer_networks.network.init(sanity_key)\n",
        "pprint(jax.tree.map(lambda p: p.shape, params))\n",
        "x = timer_networks.network.apply(params, dummy_obs)\n",
        "pprint(jax.tree.map(lambda o: o.shape, x))\n",
        "\n",
        "act_dist, succ_dist = x.act_dist_params, x.dist_to_succ_dist_params\n",
        "dummy_act = np.ones((n_batch, n_act), dtype=np.float32)\n",
        "dummy_bins = np.array([3, 1, 5, 0], dtype=np.int32)\n",
        "checks = [\n",
        "    ('act_log_prob', timer_networks.act_log_prob(act_dist, dummy_act), (n_batch,)),\n",
        "    ('sample_act', timer_networks.sample_act(act_dist, sanity_key), (n_batch, n_act)),\n",
        "    ('dist_log_prob', timer_networks.dist_log_prob(succ_dist, dummy_bins), (n_batch,)),\n",
        "    ('sample_dist', timer_networks.sample_dist(succ_dist, sanity_key), (n_batch,)),\n",
        "    ('act_entropy', timer_networks.act_entropy(act_dist, sanity_key), (n_batch,)),\n",
        "    ('sample_act_mode', timer_networks.sample_act_mode(act_dist, sanity_key), (n_batch, n_act)),\n",
        "    ('dist_entropy', timer_networks.dist_entropy(succ_dist, sanity_key), (n_batch,)),\n",
        "    ('sample_dist_mode', timer_networks.sample_dist_mode(succ_dist, sanity_key), (n_batch,)),\n",
        "]\n",
        "for name, value, expected_shape in checks:\n",
        "  print(name, value.shape)\n",
        "  assert value.shape == expected_shape, (name, value.shape, expected_shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uJiFyu_WEms4"
      },
      "source": [
        "# Timestep prediction converters between discrete predictions and continuous values"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OumQQiwoEo-l"
      },
      "outputs": [],
      "source": [
        "class DistanceConverters(NamedTuple):\n",
        "  \"\"\"Batched converters between distances and the network's output format.\"\"\"\n",
        "  # distances -> bin indices, vmapped over the leading axis\n",
        "  distance_to_network_format: Callable[\n",
        "      [jnp.ndarray], NetworkOutput]\n",
        "  # logits -> expected distance, vmapped over the leading axis\n",
        "  network_format_to_distance: Callable[\n",
        "      [NetworkOutput], jnp.ndarray]\n",
        "\n",
        "\n",
        "def build_discrete_distance_converter(\n",
        "    min_distance: float,\n",
        "    max_distance: float,\n",
        "    num_bins: int = 100) -> DistanceConverters:\n",
        "  \"\"\"Builds converters for a distance discretized into `num_bins` equal bins.\n",
        "\n",
        "  A distance is clipped into range and mapped to the index of the bin that\n",
        "  contains it (computed with floor division, so the index is not cast to int).\n",
        "  Logits are mapped back to the softmax-weighted average of the bins' left\n",
        "  edges.\n",
        "  \"\"\"\n",
        "  bin_width = (max_distance - min_distance) / num_bins\n",
        "  # Clip half a bin below the maximum so that max_distance itself falls into\n",
        "  # the last bin (index num_bins - 1).\n",
        "  upper_clip = max_distance - bin_width / 2.\n",
        "\n",
        "  def _to_bin_index(distance):\n",
        "    clipped = jnp.clip(distance, min_distance, upper_clip)\n",
        "    return jnp.floor_divide(clipped - min_distance, bin_width)\n",
        "\n",
        "  # num_bins + 1 bin edges; dropping the final one leaves the left edges.\n",
        "  left_edges = jnp.linspace(\n",
        "      min_distance, max_distance, num_bins + 1,\n",
        "      endpoint=True, dtype=jnp.float32)[:-1]\n",
        "\n",
        "  def _to_expected_distance(logits):\n",
        "    return jnp.sum(left_edges * jax.nn.softmax(logits))\n",
        "\n",
        "  return DistanceConverters(\n",
        "      jax.vmap(_to_bin_index),\n",
        "      jax.vmap(_to_expected_distance),)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jPiJFgVzHJ3Y"
      },
      "outputs": [],
      "source": [
        "# sanity check\n",
        "dc = build_discrete_distance_converter(0, 55, 50)\n",
        "print(dc.distance_to_network_format(np.arange(60, dtype=np.float32)))\n",
        "print(dc.network_format_to_distance(np.ones((4, 50))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HO_rXeS_A3r9"
      },
      "source": [
        "# Implementations for Stage 1: Supervised Fine-Tuning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "awa9FUOFOrO6"
      },
      "outputs": [],
      "source": [
        "def get_from_first_device(x, as_numpy=False):\n",
        "  if as_numpy:\n",
        "    return jax.device_get(jax.tree.map(lambda x: x[0], x))\n",
        "  else:\n",
        "    return jax.tree.map(lambda x: x[0], x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "18Nxgh59BBdG"
      },
      "outputs": [],
      "source": [
        "TIMERParams = Params\n",
        "\n",
        "\n",
        "class TrainingState(NamedTuple):\n",
        "  \"\"\"Training state for the TIMER learner.\"\"\"\n",
        "  params: TIMERParams\n",
        "  opt_state: optax.OptState\n",
        "  random_key: PRNGKey\n",
        "\n",
        "\n",
        "class PretrainLearner():\n",
        "  def __init__(\n",
        "      self,\n",
        "      timer_networks: TIMERNetworks,\n",
        "      distance_converters: DistanceConverters,\n",
        "      optimizer: optax.GradientTransformation,\n",
        "      random_key: PRNGKey,\n",
        "      global_minibatch_size: int,\n",
        "      num_minibatches: int,):\n",
        "\n",
        "    self.local_learner_devices = jax.local_devices()\n",
        "    self.num_local_learner_devices = jax.local_device_count()\n",
        "    self.learner_devices = jax.devices()\n",
        "    per_device_minibatch_size = (\n",
        "        global_minibatch_size // len(self.learner_devices))\n",
        "\n",
        "    self._num_full_update_steps = 0\n",
        "    self.global_minibatch_size = global_minibatch_size\n",
        "    self.num_minibatches = num_minibatches\n",
        "\n",
        "    def pretrain_loss(params, minibatch: DataTuple):\n",
        "      obs = minibatch.observation\n",
        "      obs = jnp.concatenate(\n",
        "          [obs['cur_pos'], obs['cur_vel'], obs['goal_pos']], axis=-1)\n",
        "      acts = minibatch.action\n",
        "      dist_idx = dc.distance_to_network_format(minibatch.time_to_success)\n",
        "\n",
        "      preds = timer_networks.network.apply(params, obs)\n",
        "      act_dist_params = preds.act_dist_params\n",
        "      dist_to_succ_dist_params = preds.dist_to_succ_dist_params\n",
        "\n",
        "      # bc loss\n",
        "      act_log_prob = jnp.mean(\n",
        "          timer_networks.act_log_prob(act_dist_params, acts))\n",
        "      bc_loss = -1.0 * act_log_prob\n",
        "\n",
        "      # Distance to success loss\n",
        "      dist_log_prob = jnp.mean(timer_networks.dist_log_prob(\n",
        "          dist_to_succ_dist_params, minibatch.time_to_success))\n",
        "      dist_loss = -1.0 * dist_log_prob\n",
        "\n",
        "      total_loss = bc_loss + dist_loss\n",
        "\n",
        "      return total_loss, {\n",
        "          'pretrain_loss': total_loss,\n",
        "          'act_log_prob': act_log_prob,\n",
        "          'dist_log_prob': dist_log_prob,}\n",
        "\n",
        "    pretrain_grad = jax.grad(pretrain_loss, has_aux=True)\n",
        "\n",
        "    def per_device_pretrain_step(\n",
        "        state: TrainingState,\n",
        "        minibatch: DataTuple,\n",
        "    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:\n",
        "      pretrain_loss_grad, metrics = pretrain_grad(state.params, minibatch)\n",
        "      pretrain_loss_grad = jax.lax.pmean(\n",
        "          pretrain_loss_grad, axis_name='devices'\n",
        "      )\n",
        "      updates, new_opt_state = optimizer.update(\n",
        "          pretrain_loss_grad, state.opt_state\n",
        "      )\n",
        "      new_params = optax.apply_updates(state.params, updates)\n",
        "      state = state._replace(params=new_params, opt_state=new_opt_state)\n",
        "      return state, metrics\n",
        "\n",
        "    def scanned_per_device_pretrain_step(\n",
        "        state: TrainingState, batch: DataTuple):\n",
        "      def reshape_for_scan(x):\n",
        "        new_shape = [\n",
        "            num_minibatches,\n",
        "            per_device_minibatch_size,\n",
        "        ] + list(x.shape[1:])\n",
        "        return jnp.reshape(x, new_shape)\n",
        "\n",
        "      minibatches = jax.tree.map(reshape_for_scan, batch)\n",
        "      state, metrics = jax.lax.scan(\n",
        "          per_device_pretrain_step, state, minibatches, length=num_minibatches)\n",
        "      metrics = jax.tree.map(jnp.mean, metrics)\n",
        "\n",
        "      return state, metrics\n",
        "\n",
        "    self._pmapped_scanned_pretrain_step = jax.pmap(\n",
        "        scanned_per_device_pretrain_step,\n",
        "        axis_name='devices',\n",
        "        devices=self.learner_devices)\n",
        "\n",
        "    def per_device_compute_loss(\n",
        "        state: TrainingState,\n",
        "        minibatch: DataTuple,\n",
        "    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:\n",
        "      _, metrics = pretrain_loss(state.params, minibatch)\n",
        "      return state, metrics\n",
        "\n",
        "    def scanned_per_device_compute_loss(\n",
        "        state: TrainingState, batch: DataTuple):\n",
        "      def reshape_for_scan(x):\n",
        "        new_shape = [\n",
        "            num_minibatches,\n",
        "            per_device_minibatch_size,\n",
        "        ] + list(x.shape[1:])\n",
        "        return jnp.reshape(x, new_shape)\n",
        "\n",
        "      minibatches = jax.tree.map(reshape_for_scan, batch)\n",
        "      state, metrics = jax.lax.scan(\n",
        "          per_device_compute_loss, state, minibatches, length=num_minibatches)\n",
        "      metrics = jax.tree.map(jnp.mean, metrics)\n",
        "\n",
        "      return state, metrics\n",
        "\n",
        "    self._pmapped_scanned_compute_loss = jax.pmap(\n",
        "        scanned_per_device_compute_loss,\n",
        "        axis_name='devices',\n",
        "        devices=self.learner_devices)\n",
        "\n",
        "    def make_initial_state(random_key: PRNGKey) -> TrainingState:\n",
        "      all_keys = jax.random.split(\n",
        "          random_key, num=self.num_local_learner_devices + 1)\n",
        "      key_init, key_state = all_keys[0], all_keys[1:]\n",
        "      key_state = [key_state[i] for i in range(self.num_local_learner_devices)]\n",
        "      key_state = jax.device_put_sharded(key_state, self.local_learner_devices)\n",
        "\n",
        "      initial_params = timer_networks.network.init(key_init)\n",
        "      initial_opt_state = optimizer.init(initial_params)\n",
        "\n",
        "      initial_params = jax.device_put_replicated(initial_params,\n",
        "                                                 self.local_learner_devices)\n",
        "      initial_opt_state = jax.device_put_replicated(initial_opt_state,\n",
        "                                                    self.local_learner_devices)\n",
        "\n",
        "      return TrainingState(\n",
        "          params=initial_params,\n",
        "          opt_state=initial_opt_state,\n",
        "          random_key=key_state,)\n",
        "\n",
        "    # Initialise training state (parameters and optimizer state).\n",
        "    self._state = make_initial_state(random_key)\n",
        "\n",
        "  def step(self, batch: DataTuple):\n",
        "    self._state, results = self._pmapped_scanned_pretrain_step(\n",
        "        self._state, batch)\n",
        "\n",
        "    self._num_full_update_steps += self.num_minibatches\n",
        "\n",
        "    results = jax.tree.map(jnp.mean, results)\n",
        "    return results\n",
        "\n",
        "  def compute_loss(self, batch: DataTuple):\n",
        "    _, results = self._pmapped_scanned_pretrain_step(\n",
        "        self._state, batch)\n",
        "    results = jax.tree.map(jnp.mean, results)\n",
        "    return results\n",
        "\n",
        "  def get_state(self):\n",
        "    return get_from_first_device(self._state, as_numpy=True)\n",
        "\n",
        "  def restore(self, state: TrainingState):\n",
        "    random_key = state.random_key\n",
        "    random_key = jax.random.split(\n",
        "        random_key, num=self.num_local_learner_devices)\n",
        "    random_key = jax.device_put_sharded(\n",
        "        [random_key[i] for i in range(self.num_local_learner_devices)],\n",
        "        self.local_learner_devices)\n",
        "\n",
        "    state = jax.device_put_replicated(state, self.local_learner_devices)\n",
        "    state = state._replace(random_key=random_key)\n",
        "    self._state = state\n",
        "\n",
        "\n",
        "# sanity checks\n",
        "optimizer = optax.chain(\n",
        "    optax.clip_by_global_norm(0.5),\n",
        "    optax.scale_by_adam(eps=1e-7),\n",
        "    optax.scale(-3e-4))\n",
        "learner = PretrainLearner(\n",
        "    timer_networks,\n",
        "    dc,\n",
        "    optimizer,\n",
        "    jax.random.PRNGKey(0),\n",
        "    64,\n",
        "    4,)\n",
        "pprint(jax.tree.map(lambda x: x.shape, learner.get_state()))\n",
        "batch = jax.tree.map(lambda x: x[:256], all_tuples)\n",
        "def reshape_for_pmap(x):\n",
        "  new_shape = [\n",
        "      jax.device_count(),\n",
        "      (64 // jax.device_count()) * 4,\n",
        "  ] + list(x.shape[1:])\n",
        "  return jnp.reshape(x, new_shape)\n",
        "batch = jax.tree.map(lambda x: reshape_for_pmap(x), batch)\n",
        "print(learner.step(batch))\n",
        "print(learner.compute_loss(batch))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TQWhLW8O1uFA"
      },
      "outputs": [],
      "source": [
        "print(jnp.mean(all_tuples.observation['cur_pos'], axis=0))\n",
        "print(jnp.std(all_tuples.observation['cur_pos'], axis=0))\n",
        "print(jnp.mean(all_tuples.observation['cur_vel'], axis=0))\n",
        "print(jnp.std(all_tuples.observation['cur_vel'], axis=0))\n",
        "print(jnp.mean(all_tuples.observation['goal_pos'], axis=0))\n",
        "print(jnp.std(all_tuples.observation['goal_pos'], axis=0))\n",
        "print(jnp.mean(all_tuples.action, axis=0))\n",
        "print(jnp.std(all_tuples.action, axis=0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gbbl3wiv2STe"
      },
      "outputs": [],
      "source": [
        "cur_pos_mean = jnp.mean(all_tuples.observation['cur_pos'], axis=0, keepdims=True)\n",
        "cur_pos_std = jnp.std(all_tuples.observation['cur_pos'], axis=0, keepdims=True)\n",
        "cur_vel_mean = jnp.mean(all_tuples.observation['cur_vel'], axis=0, keepdims=True)\n",
        "cur_vel_std = jnp.std(all_tuples.observation['cur_vel'], axis=0, keepdims=True)\n",
        "act_mean = jnp.mean(all_tuples.action, axis=0, keepdims=True)\n",
        "act_std = jnp.std(all_tuples.action, axis=0, keepdims=True)\n",
        "\n",
        "def normalize_obs(obs):\n",
        "  normalized_obs = {\n",
        "      'cur_pos': (obs['cur_pos'] - cur_pos_mean) / cur_pos_std,\n",
        "      'cur_vel': (obs['cur_vel'] - cur_vel_mean) / cur_vel_std,\n",
        "      'goal_pos': (obs['goal_pos'] - cur_pos_mean) / cur_pos_std,}\n",
        "  return normalized_obs\n",
        "\n",
        "def normalize_action(action):\n",
        "  return (action - act_mean) / act_std\n",
        "\n",
        "def unnormalize_action(action):\n",
        "  return action * act_std + act_mean\n",
        "\n",
        "normalized_all_tuples = all_tuples._replace(\n",
        "    observation=normalize_obs(all_tuples.observation),\n",
        "    action=normalize_action(all_tuples.action))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lkpV0oTZ53GE"
      },
      "outputs": [],
      "source": [
        "print(jnp.mean(normalized_all_tuples.observation['cur_pos'], axis=0))\n",
        "print(jnp.std(normalized_all_tuples.observation['cur_pos'], axis=0))\n",
        "print(jnp.mean(normalized_all_tuples.observation['cur_vel'], axis=0))\n",
        "print(jnp.std(normalized_all_tuples.observation['cur_vel'], axis=0))\n",
        "print(jnp.mean(normalized_all_tuples.observation['goal_pos'], axis=0))\n",
        "print(jnp.std(normalized_all_tuples.observation['goal_pos'], axis=0))\n",
        "print(jnp.mean(normalized_all_tuples.action, axis=0))\n",
        "print(jnp.std(normalized_all_tuples.action, axis=0))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v4VyUpQJZPfv"
      },
      "source": [
        "# Train Stage 1: Supervised Fine-Tuning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MI0ABDh6ZR7B"
      },
      "outputs": [],
      "source": [
        "# Make the data loader\n",
        "global_minibatch_size = 256 #@param {type: \"number\"}\n",
        "# num_minibatches is the is the number of SGD update steps we do per call to the learner\n",
        "# the updates steps are scanned using jax.lax.scan for efficiency.\n",
        "num_minibatches = 128 #@param {type: \"number\"}\n",
        "\n",
        "batch_size = global_minibatch_size * num_minibatches\n",
        "\n",
        "num_learners = jax.device_count()\n",
        "def reshape_for_pmap(x):\n",
        "  new_shape = [\n",
        "      num_learners,\n",
        "      (global_minibatch_size * num_minibatches) // num_learners,\n",
        "  ] + list(x.shape[1:])\n",
        "  return tf.reshape(x, new_shape)\n",
        "\n",
        "def make_dataset_from_tuples(data_tuples):\n",
        "  dataset = tf_data.Dataset.from_tensor_slices(data_tuples).cache()\n",
        "  dataset = dataset.shuffle(\n",
        "      all_tuples.observation['cur_pos'].shape[0], reshuffle_each_iteration=True)\n",
        "  dataset = dataset.repeat().batch(batch_size, drop_remainder=True)\n",
        "  dataset = dataset.map(lambda x: jax.tree.map(reshape_for_pmap, x))\n",
        "  dataset = dataset.prefetch(tf_data.experimental.AUTOTUNE)\n",
        "  dataset = dataset.as_numpy_iterator()\n",
        "  return dataset\n",
        "\n",
        "normalized_all_tuples_size = normalized_all_tuples.observation['cur_pos'].shape[0]\n",
        "train_set_ratio = 0.9\n",
        "train_set_size = int(normalized_all_tuples_size * train_set_ratio)\n",
        "train_dataset = make_dataset_from_tuples(\n",
        "    jax.tree.map(lambda x: np.array(x[:train_set_size]), normalized_all_tuples))\n",
        "val_dataset = make_dataset_from_tuples(\n",
        "    jax.tree.map(lambda x: np.array(x[train_set_size:]), normalized_all_tuples))\n",
        "\n",
        "print(jax.tree.map(lambda x: x.shape, next(train_dataset)))\n",
        "print(jax.tree.map(lambda x: x.shape, next(val_dataset)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o97Iiy1kh1Gy"
      },
      "outputs": [],
      "source": [
        "# Make the distance converters\n",
        "min_distance = 0 #@param {type: \"number\"}\n",
        "max_distance = 140 #@param {type: \"number\"}\n",
        "num_distance_bins = 50 #@param {type: \"number\"}\n",
        "\n",
        "distance_converter = build_discrete_distance_converter(\n",
        "    min_distance, max_distance, num_distance_bins)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LWpODPiZgsDh"
      },
      "outputs": [],
      "source": [
        "# Make the networks\n",
        "layer_sizes = (256, 256, 256) #@param\n",
        "\n",
        "timer_networks = build_continuous_act_discrete_dist_v0(\n",
        "    layer_sizes,\n",
        "    env.action_spec().shape[0],\n",
        "    num_distance_bins,\n",
        "    np.ones((4, 6), dtype=np.float32))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XTSqx_DViOuo"
      },
      "outputs": [],
      "source": [
        "# Make the optimizer\n",
        "learning_rate = 3e-4 #@param {type: \"number\"}\n",
        "global_norm_clip = 1.0 #@param {type: \"number\"}\n",
        "\n",
        "optimizer = optax.chain(\n",
        "    # optax.clip_by_global_norm(global_norm_clip),\n",
        "    optax.scale_by_adam(eps=1e-7),\n",
        "    optax.scale(-1. * learning_rate))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CLMbthPlZZEP"
      },
      "outputs": [],
      "source": [
        "# Make the learner\n",
        "learner = PretrainLearner(\n",
        "    timer_networks,\n",
        "    distance_converter,\n",
        "    optimizer,\n",
        "    jax.random.PRNGKey(42),\n",
        "    global_minibatch_size,\n",
        "    num_minibatches,)\n",
        "\n",
        "losses = []\n",
        "act_log_probs = []\n",
        "dist_log_probs = []\n",
        "\n",
        "val_losses = []\n",
        "val_act_log_probs = []\n",
        "val_dist_log_probs = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CM1uau00Zbod"
      },
      "outputs": [],
      "source": [
        "# Stage 1 Supervised Fine-Tuning Train Loop\n",
        "# You can keep rerunning this cell if you would like to continue the training\n",
        "# for more iterations.\n",
        "\n",
        "# Number of SGD steps to perform\n",
        "num_steps = 32768 #@param\n",
        "assert num_steps % num_minibatches == 0\n",
        "display_rate = 1\n",
        "\n",
        "for i in range(num_steps // num_minibatches):\n",
        "  batch = next(train_dataset)\n",
        "  results = learner.step(batch)\n",
        "  cur_loss = results['pretrain_loss'].item()\n",
        "  losses.append(cur_loss)\n",
        "  act_log_probs.append(results['act_log_prob'].item())\n",
        "  dist_log_probs.append(results['dist_log_prob'].item())\n",
        "\n",
        "  if i % display_rate == 0:\n",
        "    val_batch = next(val_dataset)\n",
        "    val_results = learner.compute_loss(val_batch)\n",
        "    val_cur_loss = val_results['pretrain_loss'].item()\n",
        "    val_losses.append(val_cur_loss)\n",
        "    val_act_log_probs.append(val_results['act_log_prob'].item())\n",
        "    val_dist_log_probs.append(val_results['dist_log_prob'].item())\n",
        "\n",
        "    clear_output(wait=True)\n",
        "    plt.figure(figsize=[4*3,4*1])\n",
        "\n",
        "    plt.subplot(1, 3, 1)\n",
        "    plt.title('Stage 1 SFT Loss')\n",
        "    plt.plot(\n",
        "        np.arange(len(losses)) * num_minibatches,\n",
        "        losses,\n",
        "        color='blue',\n",
        "        label='train',)\n",
        "    plt.plot(\n",
        "        np.arange(len(val_losses)) * num_minibatches * display_rate,\n",
        "        val_losses,\n",
        "        color='red',\n",
        "        label='val',)\n",
        "    plt.legend()\n",
        "\n",
        "    plt.subplot(1, 3, 2)\n",
        "    plt.title('Action Log Prob')\n",
        "    plt.plot(\n",
        "        np.arange(len(losses)) * num_minibatches,\n",
        "        act_log_probs,\n",
        "        color='blue',\n",
        "        label='train',)\n",
        "    plt.plot(\n",
        "        np.arange(len(val_losses)) * num_minibatches * display_rate,\n",
        "        val_act_log_probs,\n",
        "        color='red',\n",
        "        label='val',)\n",
        "    plt.legend()\n",
        "\n",
        "    plt.subplot(1, 3, 3)\n",
        "    plt.title('Timestep Log Prob')\n",
        "    plt.plot(\n",
        "        np.arange(len(losses)) * num_minibatches,\n",
        "        dist_log_probs,\n",
        "        color='blue',\n",
        "        label='train',)\n",
        "    plt.plot(\n",
        "        np.arange(len(val_losses)) * num_minibatches * display_rate,\n",
        "        val_dist_log_probs,\n",
        "        color='red',\n",
        "        label='val',)\n",
        "    plt.legend()\n",
        "\n",
        "    plt.show()\n",
        "\n",
        "pretrain_cpu_state = learner.get_state()\n",
        "print(f'\\nTotal Steps: {len(losses) * num_minibatches}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bMnKtjzY5Zs6"
      },
      "outputs": [],
      "source": [
        "batch = next(train_dataset)\n",
        "print(jnp.mean(batch.observation['cur_pos'], axis=(0, 1)))\n",
        "print(jnp.std(batch.observation['cur_pos'], axis=(0, 1)))\n",
        "print(jnp.mean(batch.observation['cur_vel'], axis=(0, 1)))\n",
        "print(jnp.std(batch.observation['cur_vel'], axis=(0, 1)))\n",
        "print(jnp.mean(batch.observation['goal_pos'], axis=(0, 1)))\n",
        "print(jnp.std(batch.observation['goal_pos'], axis=(0, 1)))\n",
        "print(jnp.mean(batch.action, axis=(0, 1)))\n",
        "print(jnp.std(batch.action, axis=(0, 1)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O2n2JfrXxC8F"
      },
      "source": [
        "# Visualize Policies after the Stage 1 Supervised-Finetuning process"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9ZHhzGFrxGsH"
      },
      "outputs": [],
      "source": [
        "# Get the learner state and compile a CPU policy\n",
        "cpu_state = learner.get_state()\n",
        "cpu_params = cpu_state.params\n",
        "\n",
        "def _policy(obs, rng):\n",
        "  obs = jax.tree.map(lambda x: x[None], obs)\n",
        "  obs = normalize_obs(obs)\n",
        "  obs = jnp.concatenate(\n",
        "      [obs['cur_pos'], obs['cur_vel'], obs['goal_pos']], axis=-1)\n",
        "  pred = timer_networks.network.apply(cpu_params, obs)\n",
        "  act = timer_networks.sample_act(pred.act_dist_params, rng)\n",
        "  # print(act)\n",
        "  act = unnormalize_action(act)\n",
        "  # print(act)\n",
        "  dist_to_succ = distance_converter.network_format_to_distance(\n",
        "      pred.dist_to_succ_dist_params.logits)\n",
        "  extras = {\n",
        "      'pred_dist_to_succ': dist_to_succ,\n",
        "      'pred_dist_to_succ_dist_params': pred.dist_to_succ_dist_params,}\n",
        "  return act[0], extras\n",
        "\n",
        "policy = jax.jit(_policy, backend='cpu')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vBEs8Le7yLsg"
      },
      "outputs": [],
      "source": [
        "def get_distance_plot(distances, logits):\n",
        "  fig, ax = plt.subplots(\n",
        "      figsize=(2 * RENDER_HEIGHT_INCHES, RENDER_HEIGHT_INCHES), dpi=DPI)\n",
        "  plt.clf()\n",
        "\n",
        "  plt.subplot(1, 2, 1)\n",
        "  plt.plot(distances, color='blue')\n",
        "  plt.ylim(min_distance, max_distance)\n",
        "  plt.xlabel('Episode Step')\n",
        "  plt.ylabel('E[Timesteps Until Success]')\n",
        "  plt.title('E[Timesteps Until Success] Under Model Distribution')\n",
        "\n",
        "\n",
        "  plt.subplot(1, 2, 2)\n",
        "  probs = jax.nn.softmax(logits, axis=-1)\n",
        "  plt.bar(\n",
        "      np.linspace(\n",
        "          min_distance,\n",
        "          max_distance,\n",
        "          probs.shape[0] + 1,\n",
        "          endpoint=True)[:-1],\n",
        "      probs)\n",
        "  plt.ylim(0., 1.)\n",
        "  plt.xlabel('Timesteps Until Success')\n",
        "  plt.ylabel('Probability')\n",
        "  plt.title('p(Timesteps Until Success) at Current Frame')\n",
        "\n",
        "  plt.tight_layout()\n",
        "\n",
        "  # Render the plot using FigureCanvasAgg\n",
        "  canvas = FigureCanvas(fig)\n",
        "  canvas.draw()\n",
        "\n",
        "  # Convert the rendered image to a numpy array\n",
        "  width, height = fig.get_size_inches() * fig.get_dpi()\n",
        "  image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')\n",
        "  image = image.reshape(int(height), int(width), 3)\n",
        "\n",
        "  plt.close(fig)\n",
        "\n",
        "  return image\n",
        "\n",
        "def generate_distance_plots(all_extras):\n",
        "  imgs = []\n",
        "  for i in range(all_extras['pred_dist_to_succ'].shape[0]):\n",
        "    dist_preds = all_extras['pred_dist_to_succ'][:i+1]\n",
        "    logits = all_extras['pred_dist_to_succ_dist_params'].logits[i]\n",
        "    img = get_distance_plot(dist_preds, logits)\n",
        "    imgs.append(img)\n",
        "  return imgs\n",
        "\n",
        "def generate_policy_traj(policy, title):\n",
        "  imgs = []\n",
        "  all_extras = []\n",
        "  ts = env.reset()\n",
        "  t = 0\n",
        "  succ = env.success()\n",
        "  key = jax.random.PRNGKey(42)\n",
        "  imgs.append(env.render(title=title))\n",
        "\n",
        "  while (not succ) and t < max_distance:\n",
        "    obs = ts.observation\n",
        "    key, sub_key = jax.random.split(key)\n",
        "    act, extras = policy(obs, sub_key)\n",
        "    all_extras.append(jax.tree.map(lambda x: x[0], extras))\n",
        "    ts = env.step(act)\n",
        "    imgs.append(env.render(title=title))\n",
        "    succ = env.success()\n",
        "    t += 1\n",
        "\n",
        "  # repeat the last all_extras to match imgs len\n",
        "  all_extras.append(all_extras[-1])\n",
        "\n",
        "  all_extras = jax.tree.map(lambda *xs: np.stack(xs), *all_extras)\n",
        "\n",
        "  return imgs, all_extras\n",
        "\n",
        "def get_trajectory_visualization(policy, title):\n",
        "  imgs, all_extras = generate_policy_traj(policy, title)\n",
        "  plot_imgs = generate_distance_plots(all_extras)\n",
        "  video_imgs = []\n",
        "  for (x, y) in zip(imgs, plot_imgs):\n",
        "    video_imgs.append(np.concatenate([x, y], axis=1))\n",
        "  return video_imgs\n",
        "\n",
        "media.show_video(get_trajectory_visualization(policy, 'Stage 1 Supervised Fine-Tuning Policy'), fps=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77pHfuNKCXhU"
      },
      "source": [
        "# RL Utilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vqT9JxC4CZfu"
      },
      "outputs": [],
      "source": [
        "def _timer_rollout_policy(params, normalized_obs, rng):\n",
        "  normalized_obs = jnp.concatenate(\n",
        "      [normalized_obs['cur_pos'], normalized_obs['cur_vel'], normalized_obs['goal_pos']], axis=-1)\n",
        "  preds = timer_networks.network.apply(params, normalized_obs)\n",
        "  normalized_act = timer_networks.sample_act(preds.act_dist_params, rng)\n",
        "  dist_pred = distance_converter.network_format_to_distance(\n",
        "      preds.dist_to_succ_dist_params.logits)\n",
        "  return normalized_act, {'pred_dist_to_succ': dist_pred}\n",
        "\n",
        "timer_rollout_policy = jax.jit(_timer_rollout_policy, backend='cpu')\n",
        "# timer_rollout_policy = _timer_rollout_policy\n",
        "\n",
        "def evaluate_timer_rollout_policy(env, params, num_episodes):\n",
        "  stats = []\n",
        "  for eps_num in range(num_episodes):\n",
        "    timesteps = []\n",
        "    ts = env.reset()\n",
        "    ts = ts._replace(reward=0.)\n",
        "    timesteps.append(ts)\n",
        "\n",
        "    key = jax.random.PRNGKey(42)\n",
        "\n",
        "    while (not env.success()) and len(timesteps) < max_distance:\n",
        "      key, sub_key = jax.random.split(key)\n",
        "      cur_obs = ts.observation\n",
        "      cur_obs = jax.tree.map(lambda x: x[None], cur_obs)\n",
        "      cur_obs = normalize_obs(cur_obs)  # 1 x dims\n",
        "      norm_act, _ = timer_rollout_policy(params, cur_obs, sub_key)\n",
        "      unnorm_act = unnormalize_action(norm_act)\n",
        "      ts = env.step(unnorm_act[0])\n",
        "      timesteps.append(ts)\n",
        "\n",
        "    episode_stats = {}\n",
        "    episode_stats['success'] = env.success()\n",
        "    episode_stats['return'] = sum(x.reward for x in timesteps)\n",
        "    episode_stats['len'] = len(timesteps)\n",
        "    stats.append(episode_stats)\n",
        "\n",
        "  stats = jax.tree.map(lambda *xs: np.stack(xs), *stats)\n",
        "  print(f'Success Rate: {np.mean(stats[\"success\"])}')\n",
        "  print(f'Returns: {np.mean(stats[\"return\"]):.2f} +/- {np.std(stats[\"return\"]):.2f}')\n",
        "  print(f'Episode Lengths: {np.mean(stats[\"len\"]):.2f} +/- {np.std(stats[\"len\"]):.2f}')\n",
        "  print(f'Max Episode Length: {np.max(stats[\"len\"])}')\n",
        "  print(f'Min Episode Length: {np.min(stats[\"len\"])}')\n",
        "\n",
        "evaluate_timer_rollout_policy(env, cpu_params, num_episodes=25)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "75NL-EzV_jV6"
      },
      "outputs": [],
      "source": [
        "class REINFORCETuple(NamedTuple):\n",
        "  observation: NestedArray\n",
        "  action: NestedArray\n",
        "  weight: NestedArray\n",
        "\n",
        "def _add_batch_dim_and_normalize(obs):\n",
        "  \"\"\"Adds a leading size-1 batch axis to every leaf, then normalizes.\"\"\"\n",
        "  obs = jax.tree.map(lambda x: x[None], obs)\n",
        "  return normalize_obs(obs)  # 1 x dims\n",
        "\n",
        "\n",
        "def _concat_leaves(*xs):\n",
        "  \"\"\"Concatenates corresponding leaves along their leading axis.\"\"\"\n",
        "  return np.concatenate(xs)\n",
        "\n",
        "\n",
        "def _discounted_returns(rewards, gamma):\n",
        "  \"\"\"Reverse discounted cumulative sum: G_t = r_t + gamma * G_{t+1}.\n",
        "\n",
        "  The return after the final reward is taken to be 0. The result is float32.\n",
        "  \"\"\"\n",
        "  returns = []\n",
        "  running = 0.\n",
        "  for reward in rewards[::-1]:\n",
        "    running = reward + gamma * running\n",
        "    returns.append(running)\n",
        "  return np.array(returns[::-1], dtype=np.float32)\n",
        "\n",
        "\n",
        "def generate_timer_reinforce_dataset(\n",
        "    env, params, num_steps, gamma, seed=42, max_episode_steps=None):\n",
        "  \"\"\"Rolls out the timer policy and builds a flattened REINFORCE dataset.\n",
        "\n",
        "  Episodes are collected until at least `num_steps` transitions have been\n",
        "  gathered (the last episode may overshoot). The per-step reward is the\n",
        "  decrease in the policy network's own predicted distance-to-success,\n",
        "  r_t = d_t - d_{t+1}, and each action is weighted by its discounted return.\n",
        "\n",
        "  Args:\n",
        "    env: Environment exposing `reset()`, `step(action)` and `success()`.\n",
        "    params: Network parameters passed to `timer_rollout_policy`.\n",
        "    num_steps: Minimum number of transitions to collect; must be positive.\n",
        "    gamma: Discount factor for the return.\n",
        "    seed: Seed for the action-sampling PRNG. Defaults to 42, the previously\n",
        "      hard-coded value; pass a different seed per call to avoid reusing the\n",
        "      same sampling noise across data-collection rounds.\n",
        "    max_episode_steps: Per-episode step cap. Defaults to the global\n",
        "      `max_distance`.\n",
        "\n",
        "  Returns:\n",
        "    (data_tuples, data_stats): a `REINFORCETuple` whose leaves are\n",
        "    concatenated over all episodes, and a dict of per-episode 'success',\n",
        "    'return' and 'len' arrays.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If `num_steps` is not positive (no episode would be collected).\n",
        "  \"\"\"\n",
        "  if num_steps <= 0:\n",
        "    raise ValueError(f'num_steps must be positive, got {num_steps}')\n",
        "  if max_episode_steps is None:\n",
        "    max_episode_steps = max_distance\n",
        "\n",
        "  key = jax.random.PRNGKey(seed)\n",
        "  total_steps = 0\n",
        "  data_tuples = []\n",
        "  data_stats = []\n",
        "\n",
        "  while total_steps < num_steps:\n",
        "    traj_obs = []\n",
        "    traj_acts = []\n",
        "    traj_dist_preds = []\n",
        "\n",
        "    episode_steps = 0\n",
        "    episode_return = 0.\n",
        "    ts = env.reset()\n",
        "    cur_obs = _add_batch_dim_and_normalize(ts.observation)\n",
        "    traj_obs.append(cur_obs)\n",
        "\n",
        "    while (not env.success()) and episode_steps < max_episode_steps:\n",
        "      sub_key, key = jax.random.split(key)\n",
        "      norm_act, extras = timer_rollout_policy(params, cur_obs, sub_key)\n",
        "      traj_acts.append(norm_act)\n",
        "      traj_dist_preds.append(extras['pred_dist_to_succ'])\n",
        "\n",
        "      ts = env.step(unnormalize_action(norm_act)[0])\n",
        "      episode_return += ts.reward\n",
        "      cur_obs = _add_batch_dim_and_normalize(ts.observation)\n",
        "      traj_obs.append(cur_obs)\n",
        "      episode_steps += 1\n",
        "\n",
        "    # An episode that is already successful at reset yields no transitions.\n",
        "    if episode_steps < 1:\n",
        "      continue\n",
        "\n",
        "    episode_stats = {\n",
        "        'success': env.success(),\n",
        "        'return': episode_return,\n",
        "        'len': episode_steps,\n",
        "    }\n",
        "    total_steps += len(traj_acts)\n",
        "\n",
        "    # Distance prediction at the final observation, needed for the last reward.\n",
        "    sub_key, key = jax.random.split(key)\n",
        "    _, extras = timer_rollout_policy(params, cur_obs, sub_key)\n",
        "    traj_dist_preds.append(extras['pred_dist_to_succ'])\n",
        "\n",
        "    traj_obs = jax.tree.map(_concat_leaves, *traj_obs)\n",
        "    traj_acts = jax.tree.map(_concat_leaves, *traj_acts)\n",
        "    traj_dist_preds = jax.tree.map(_concat_leaves, *traj_dist_preds)\n",
        "\n",
        "    # Shaped reward: how much the predicted distance-to-success shrank.\n",
        "    rews = -1. * (traj_dist_preds[1:] - traj_dist_preds[:-1])\n",
        "    weights = _discounted_returns(rews, gamma)\n",
        "\n",
        "    data_tuples.append(REINFORCETuple(\n",
        "        # Drop the final observation: no action was taken from it.\n",
        "        observation=jax.tree.map(lambda x: x[:-1], traj_obs),\n",
        "        action=traj_acts,\n",
        "        weight=weights,\n",
        "    ))\n",
        "    data_stats.append(episode_stats)\n",
        "\n",
        "  data_tuples = jax.tree.map(_concat_leaves, *data_tuples)\n",
        "  data_stats = jax.tree.map(lambda *xs: np.stack(xs), *data_stats)\n",
        "  return data_tuples, data_stats\n",
        "\n",
        "\n",
        "# Sanity check: collect one dataset and summarize its contents.\n",
        "tick = time.time()\n",
        "reinforce_data, data_stats = generate_timer_reinforce_dataset(\n",
        "    env, cpu_params, num_steps=2048, gamma=0.9)\n",
        "print(f'Took {time.time() - tick:.2f} seconds')\n",
        "pprint(jax.tree.map(lambda x: x.shape, reinforce_data))\n",
        "\n",
        "# Per-dimension mean, then std, of each field in a fixed order.\n",
        "for summary_array in (\n",
        "    reinforce_data.observation['cur_pos'],\n",
        "    reinforce_data.observation['cur_vel'],\n",
        "    reinforce_data.observation['goal_pos'],\n",
        "    reinforce_data.action,\n",
        "    reinforce_data.weight,\n",
        "):\n",
        "  print(jnp.mean(summary_array, axis=0))\n",
        "  print(jnp.std(summary_array, axis=0))\n",
        "\n",
        "stats = data_stats\n",
        "episode_success = stats['success']\n",
        "episode_returns = stats['return']\n",
        "episode_lengths = stats['len']\n",
        "print(f'Success Rate: {np.mean(episode_success)}')\n",
        "print(f'Returns: {np.mean(episode_returns):.2f} +/- {np.std(episode_returns):.2f}')\n",
        "print(f'Episode Lengths: {np.mean(episode_lengths):.2f} +/- {np.std(episode_lengths):.2f}')\n",
        "print(f'Max Episode Length: {np.max(episode_lengths)}')\n",
        "print(f'Min Episode Length: {np.min(episode_lengths)}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ey0z3SZA_eu5"
      },
      "source": [
        "# Train Stage 2 Self-Improvement"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7b19so0xkR9h"
      },
      "outputs": [],
      "source": [
        "reinforce_global_minibatch_size = 64  # @param {type: \"number\"}\n",
        "reinforce_global_batch_size = 2048  # @param {type: \"number\"}\n",
        "num_devices = jax.device_count()\n",
        "\n",
        "# The reshapes in the REINFORCE step below need exact divisibility: the\n",
        "# global batch must split into whole minibatches, and each minibatch (hence\n",
        "# each batch) must split evenly across devices.\n",
        "assert reinforce_global_batch_size % reinforce_global_minibatch_size == 0, (\n",
        "    'reinforce_global_batch_size must be a multiple of '\n",
        "    'reinforce_global_minibatch_size')\n",
        "assert reinforce_global_minibatch_size % num_devices == 0, (\n",
        "    f'reinforce_global_minibatch_size must be a multiple of {num_devices}')\n",
        "\n",
        "reinforce_num_minibatches = (\n",
        "    reinforce_global_batch_size // reinforce_global_minibatch_size)\n",
        "per_device_reinforce_minibatch_size = reinforce_global_minibatch_size // num_devices\n",
        "per_device_reinforce_batch_size = reinforce_global_batch_size // num_devices\n",
        "\n",
        "def reinforce_loss(params, minibatch: REINFORCETuple):\n",
        "  \"\"\"Weighted negative log-likelihood (REINFORCE) loss for one minibatch.\n",
        "\n",
        "  Each action's log-probability under the current policy is multiplied by\n",
        "  its return weight divided by `max_distance`, and the negated mean is taken.\n",
        "\n",
        "  Returns:\n",
        "    (loss, metrics): the scalar loss and a dict holding it as 'reinforce_loss'.\n",
        "  \"\"\"\n",
        "  obs_fields = minibatch.observation\n",
        "  net_input = jnp.concatenate(\n",
        "      [obs_fields[name] for name in ('cur_pos', 'cur_vel', 'goal_pos')],\n",
        "      axis=-1)\n",
        "  net_out = timer_networks.network.apply(params, net_input)\n",
        "  log_probs = timer_networks.act_log_prob(\n",
        "      net_out.act_dist_params, minibatch.action)\n",
        "  scaled_weights = minibatch.weight / float(max_distance)\n",
        "  loss = -jnp.mean(scaled_weights * log_probs)\n",
        "  return loss, {'reinforce_loss': loss}\n",
        "\n",
        "# Gradient w.r.t. params; the metrics dict is passed through as aux output.\n",
        "reinforce_loss_grad = jax.grad(reinforce_loss, has_aux=True)\n",
        "\n",
        "def per_device_reinforce_step(\n",
        "    state: TrainingState,\n",
        "    minibatch: REINFORCETuple,) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:\n",
        "  \"\"\"Applies one optimizer update on a single device's minibatch.\n",
        "\n",
        "  Gradients are averaged over the 'devices' pmap axis before the update, so\n",
        "  replicas that start from the same state stay in sync. Used as the body of\n",
        "  `jax.lax.scan` in `scanned_per_device_reinforce_step`.\n",
        "  \"\"\"\n",
        "  local_grads, metrics = reinforce_loss_grad(state.params, minibatch)\n",
        "  synced_grads = jax.lax.pmean(local_grads, axis_name='devices')\n",
        "  updates, opt_state = optimizer.update(synced_grads, state.opt_state)\n",
        "  return (\n",
        "      state._replace(\n",
        "          params=optax.apply_updates(state.params, updates),\n",
        "          opt_state=opt_state),\n",
        "      metrics,\n",
        "  )\n",
        "\n",
        "def scanned_per_device_reinforce_step(state: TrainingState, batch: REINFORCETuple):\n",
        "  \"\"\"Runs `reinforce_num_minibatches` sequential updates on one device's batch.\n",
        "\n",
        "  The per-device batch is split, in order along its leading axis, into\n",
        "  minibatches of `per_device_reinforce_minibatch_size`, and\n",
        "  `per_device_reinforce_step` is scanned over them.\n",
        "\n",
        "  Returns:\n",
        "    (state, metrics): the updated state and each metric averaged over the\n",
        "    minibatch updates.\n",
        "  \"\"\"\n",
        "  def split_into_minibatches(x):\n",
        "    return jnp.reshape(\n",
        "        x,\n",
        "        (reinforce_num_minibatches, per_device_reinforce_minibatch_size)\n",
        "        + tuple(x.shape[1:]))\n",
        "\n",
        "  minibatches = jax.tree.map(split_into_minibatches, batch)\n",
        "  state, metrics = jax.lax.scan(\n",
        "      per_device_reinforce_step, state, minibatches,\n",
        "      length=reinforce_num_minibatches)\n",
        "  return state, jax.tree.map(jnp.mean, metrics)\n",
        "\n",
        "pmapped_scanned_per_device_reinforce_step = jax.pmap(\n",
        "    scanned_per_device_reinforce_step,\n",
        "    axis_name='devices',\n",
        "    devices=jax.devices())\n",
        "\n",
        "def _full_reinforce_step(state: TrainingState, batch: REINFORCETuple):\n",
        "  \"\"\"Runs one data-parallel pass of REINFORCE updates over a global batch.\n",
        "\n",
        "  Args:\n",
        "    state: Device-replicated training state (see `restore_reinforce_state`).\n",
        "    batch: REINFORCETuple whose leaves have leading dimension\n",
        "      `reinforce_global_batch_size`.\n",
        "\n",
        "  Returns:\n",
        "    (state, metrics): the updated replicated state and metrics averaged over\n",
        "    devices.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If a leaf's leading dimension is not\n",
        "      `reinforce_global_batch_size`.\n",
        "  \"\"\"\n",
        "  per_device_size = reinforce_global_batch_size // num_devices\n",
        "\n",
        "  def reshape_for_pmap(x):\n",
        "    if x.shape[0] != reinforce_global_batch_size:\n",
        "      raise ValueError(\n",
        "          f'Expected leading dimension {reinforce_global_batch_size}, '\n",
        "          f'got {x.shape[0]}')\n",
        "    # For NumPy inputs (as produced by generate_timer_reinforce_dataset),\n",
        "    # reshaping on the host lets pmap send each shard straight to its device\n",
        "    # instead of first staging the whole batch on the default device.\n",
        "    return np.reshape(x, (num_devices, per_device_size) + tuple(x.shape[1:]))\n",
        "\n",
        "  batch = jax.tree.map(reshape_for_pmap, batch)\n",
        "  state, metrics = pmapped_scanned_per_device_reinforce_step(state, batch)\n",
        "  metrics = jax.tree.map(jnp.mean, metrics)\n",
        "  return state, metrics\n",
        "\n",
        "full_reinforce_step = _full_reinforce_step\n",
        "\n",
        "def restore_reinforce_state(state: TrainingState):\n",
        "  \"\"\"Replicates `state` across `jax.devices()` for use with pmap.\n",
        "\n",
        "  Every field is replicated, except `random_key`, which is split so each\n",
        "  device receives its own distinct key.\n",
        "  \"\"\"\n",
        "  devices = jax.devices()\n",
        "  per_device_keys = jax.random.split(\n",
        "      state.random_key, num=jax.device_count())\n",
        "  sharded_keys = jax.device_put_sharded(list(per_device_keys), devices)\n",
        "  replicated = jax.device_put_replicated(state, devices)\n",
        "  return replicated._replace(random_key=sharded_keys)\n",
        "\n",
        "# Sanity check: one REINFORCE step on the dataset collected above.\n",
        "reinforce_state = restore_reinforce_state(pretrain_cpu_state)\n",
        "first_global_batch = jax.tree.map(\n",
        "    lambda x: x[:reinforce_global_batch_size], reinforce_data)\n",
        "reinforce_state, metrics = full_reinforce_step(reinforce_state, first_global_batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hiNeBDIAzfHg"
      },
      "outputs": [],
      "source": [
        "# Fresh replicated state and an empty metric history for the training loop.\n",
        "reinforce_metrics_list_form = defaultdict(list)\n",
        "reinforce_metrics = {}\n",
        "reinforce_state = restore_reinforce_state(pretrain_cpu_state)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1tAnkuuFyTgc"
      },
      "outputs": [],
      "source": [
        "# Stage 2 Self-Improvement Train Loop\n",
        "# You can keep rerunning this cell if you would like to continue the training\n",
        "# for more iterations.\n",
        "\n",
        "# Number of REINFORCE steps to perform\n",
        "num_reinforce_sgd_steps = 2048  # @param {type: \"number\"}\n",
        "# Discount factor for the shaped returns. 0.9 is the default that we use in\n",
        "# our work and it works; 1. does not work, as expected; 0. works, surprisingly,\n",
        "# but we know this is not a good idea.\n",
        "reinforce_gamma = 0.9  # @param {type: \"number\"}\n",
        "assert num_reinforce_sgd_steps % reinforce_num_minibatches == 0\n",
        "\n",
        "env_steps_per_batch = reinforce_global_batch_size\n",
        "\n",
        "\n",
        "def _plot_mean_std(ax, x, mean, std, title, std_label):\n",
        "  \"\"\"Plots `mean` against `x` with a shaded +/- `std` band and a legend.\"\"\"\n",
        "  ax.plot(x, mean, label='Mean')\n",
        "  ax.fill_between(x, mean - std, mean + std,\n",
        "                  color='b', alpha=.1, label=std_label)\n",
        "  ax.set_title(title)\n",
        "  ax.set_xlabel('REINFORCE Steps')\n",
        "  ax.legend()\n",
        "\n",
        "\n",
        "for i in range(num_reinforce_sgd_steps // reinforce_num_minibatches):\n",
        "  # Collect fresh on-policy data with the current (first-device) params.\n",
        "  reinforce_data, data_stats = generate_timer_reinforce_dataset(\n",
        "      env,\n",
        "      get_from_first_device(reinforce_state, as_numpy=True).params,\n",
        "      num_steps=env_steps_per_batch,\n",
        "      gamma=reinforce_gamma)\n",
        "  reinforce_state, metrics = full_reinforce_step(\n",
        "      reinforce_state,\n",
        "      jax.tree.map(lambda x: x[:reinforce_global_batch_size], reinforce_data))\n",
        "\n",
        "  # Mean loss over this round's updates, plus statistics of the episodes that\n",
        "  # fed it (collected with the pre-update params).\n",
        "  reinforce_metrics_list_form['reinforce_loss'].append(metrics['reinforce_loss'].item())\n",
        "  reinforce_metrics_list_form['success_rate'].append(np.mean(data_stats['success']))\n",
        "  reinforce_metrics_list_form['return_mean'].append(np.mean(data_stats['return']))\n",
        "  reinforce_metrics_list_form['return_std'].append(np.std(data_stats['return']))\n",
        "  reinforce_metrics_list_form['len_mean'].append(np.mean(data_stats['len']))\n",
        "  reinforce_metrics_list_form['len_std'].append(np.std(data_stats['len']))\n",
        "  reinforce_metrics_list_form['max_len'].append(np.max(data_stats['len']))\n",
        "  reinforce_metrics_list_form['min_len'].append(np.min(data_stats['len']))\n",
        "  reinforce_metrics = {\n",
        "      k: np.array(v) for k, v in reinforce_metrics_list_form.items()}\n",
        "\n",
        "  clear_output(wait=True)\n",
        "\n",
        "  fig, axs = plt.subplots(1, 4, figsize=(20, 5))  # creating 4 subplots\n",
        "  # One point per data-collection round, at the number of SGD steps before it.\n",
        "  X = np.arange(len(reinforce_metrics['reinforce_loss'])) * reinforce_num_minibatches\n",
        "\n",
        "  axs[0].plot(X, reinforce_metrics['reinforce_loss'])\n",
        "  axs[0].set_title('REINFORCE Loss')\n",
        "  axs[0].set_xlabel('REINFORCE Steps')\n",
        "\n",
        "  axs[1].plot(X, reinforce_metrics['success_rate'])\n",
        "  axs[1].set_title('Success Rate')\n",
        "  axs[1].set_xlabel('REINFORCE Steps')\n",
        "\n",
        "  _plot_mean_std(axs[2], X,\n",
        "                 reinforce_metrics['return_mean'], reinforce_metrics['return_std'],\n",
        "                 'Return Mean and Std', 'Std deviation')\n",
        "  _plot_mean_std(axs[3], X,\n",
        "                 reinforce_metrics['len_mean'], reinforce_metrics['len_std'],\n",
        "                 'Episode Length Mean and Std', 'Std')\n",
        "\n",
        "  plt.show()\n",
        "  # Release the figure so non-inline backends do not accumulate one open\n",
        "  # figure per training round.\n",
        "  plt.close(fig)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SdU9uEo8dlwh"
      },
      "source": [
        "# Visualize policies after the Stage 2 Self-Improvement process"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "up3ZUE3z7iws"
      },
      "outputs": [],
      "source": [
        "# Re-implements the policy with the Stage 2 params baked in, as an\n",
        "# (obs, rng) -> (action, extras) callable, so that\n",
        "# `get_trajectory_visualization` can be reused unchanged.\n",
        "reinforce_cpu_params = get_from_first_device(\n",
        "    reinforce_state, as_numpy=True).params\n",
        "\n",
        "def _timer_eval_policy(obs, rng):\n",
        "  \"\"\"Samples an unnormalized action for a single (unbatched) observation.\n",
        "\n",
        "  `reinforce_cpu_params` is read from the enclosing scope, so under jit it is\n",
        "  captured when the function is first traced.\n",
        "\n",
        "  Returns:\n",
        "    (action, extras): the action without its batch axis, and a dict with the\n",
        "    predicted distance-to-success and its distribution parameters.\n",
        "  \"\"\"\n",
        "  batched = normalize_obs(jax.tree.map(lambda x: x[None], obs))\n",
        "  net_input = jnp.concatenate(\n",
        "      [batched[name] for name in ('cur_pos', 'cur_vel', 'goal_pos')], axis=-1)\n",
        "  pred = timer_networks.network.apply(reinforce_cpu_params, net_input)\n",
        "  sampled = timer_networks.sample_act(pred.act_dist_params, rng)\n",
        "  action = unnormalize_action(sampled)\n",
        "  extras = {\n",
        "      'pred_dist_to_succ': distance_converter.network_format_to_distance(\n",
        "          pred.dist_to_succ_dist_params.logits),\n",
        "      'pred_dist_to_succ_dist_params': pred.dist_to_succ_dist_params,\n",
        "  }\n",
        "  return action[0], extras\n",
        "\n",
        "timer_eval_policy = jax.jit(_timer_eval_policy, backend='cpu')\n",
        "\n",
        "media.show_video(get_trajectory_visualization(timer_eval_policy, 'Stage 2 Self-Improvement Policy'), fps=10)\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [
        "z48CKYN8Zb1k",
        "DtApnonezrGQ",
        "VelqDWf40BgB"
      ],
      "gpuType": "V28",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.16"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
