{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Brax Training with PyTorch on GPU",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "trVNqxHmGISS"
      },
      "source": [
        "# Training in Brax with PyTorch on GPUs\n",
        "\n",
        "Brax is ready to integrate into other research toolkits by way of the [OpenAI Gym](https://gym.openai.com/) interface.  Brax environments convert to Gym environments using either [GymWrapper](https://github.com/google/brax/blob/main/brax/envs/wrappers.py) for single environments, or [VectorGymWrapper](https://github.com/google/brax/blob/main/brax/envs/wrappers.py) for batched (parallelized) environments."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GJhPpM5ZPrpq"
      },
      "source": [
        "#@title Import Brax and some helper modules\n",
        "from IPython.display import clear_output\n",
        "\n",
        "import collections\n",
        "from datetime import datetime\n",
        "import functools\n",
        "import math\n",
        "import time\n",
        "from typing import Any, Callable, Dict, Optional, Sequence\n",
        "\n",
        "try:\n",
        "  import brax\n",
        "except ImportError:\n",
        "  !pip install git+https://github.com/google/brax.git@main\n",
        "  clear_output()\n",
        "  import brax\n",
        "\n",
        "from brax import envs\n",
        "from brax.envs import to_torch\n",
        "from brax.io import metrics\n",
        "from brax.training import ppo\n",
        "import gym\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch\n",
        "from torch import nn\n",
        "from torch import optim\n",
        "import torch.nn.functional as F\n",
        "\n",
        "# Have torch allocate on the GPU first, so PyTorch sets up its CUDA context and\n",
        "# memory before JAX touches the device. By default JAX will pre-allocate 90% of\n",
        "# the available GPU memory, which would leave little room for torch:\n",
        "# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html\n",
        "v = torch.ones(1, device='cuda')"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vQFCkfu8Qwre"
      },
      "source": [
        "Here is a PPO Agent written in PyTorch:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fWJE4b5BHeH7"
      },
      "source": [
        "class Agent(nn.Module):\n",
        "  \"\"\"Standard PPO Agent with GAE and observation normalization.\"\"\"\n",
        "\n",
        "  def __init__(self,\n",
        "               policy_layers: Sequence[int],\n",
        "               value_layers: Sequence[int],\n",
        "               entropy_cost: float,\n",
        "               discounting: float,\n",
        "               reward_scaling: float,\n",
        "               device: str):\n",
        "    super(Agent, self).__init__()\n",
        "\n",
        "    policy = []\n",
        "    for w1, w2 in zip(policy_layers, policy_layers[1:]):\n",
        "      policy.append(nn.Linear(w1, w2))\n",
        "      policy.append(nn.SiLU())\n",
        "    policy.pop()  # drop the final activation\n",
        "    self.policy = nn.Sequential(*policy)\n",
        "\n",
        "    value = []\n",
        "    for w1, w2 in zip(value_layers, value_layers[1:]):\n",
        "      value.append(nn.Linear(w1, w2))\n",
        "      value.append(nn.SiLU())\n",
        "    value.pop()  # drop the final activation\n",
        "    self.value = nn.Sequential(*value)\n",
        "\n",
        "    self.num_steps = torch.zeros((), device=device)\n",
        "    self.running_mean = torch.zeros(policy_layers[0], device=device)\n",
        "    self.running_variance = torch.zeros(policy_layers[0], device=device)\n",
        "\n",
        "    self.entropy_cost = entropy_cost\n",
        "    self.discounting = discounting\n",
        "    self.reward_scaling = reward_scaling\n",
        "    self.lambda_ = 0.95\n",
        "    self.epsilon = 0.3\n",
        "    self.device = device\n",
        "\n",
        "  @torch.jit.export\n",
        "  def dist_create(self, logits):\n",
        "    \"\"\"Normal followed by tanh.\n",
        "\n",
        "    torch.distribution doesn't work with torch.jit, so we roll our own.\"\"\"\n",
        "    loc, scale = torch.split(logits, logits.shape[-1] // 2, dim=-1)\n",
        "    scale = F.softplus(scale) + .001\n",
        "    return loc, scale\n",
        "\n",
        "  @torch.jit.export\n",
        "  def dist_sample_no_postprocess(self, loc, scale):\n",
        "    return torch.normal(loc, scale)\n",
        "\n",
        "  @classmethod\n",
        "  def dist_postprocess(cls, x):\n",
        "    return torch.tanh(x)\n",
        "\n",
        "  @torch.jit.export\n",
        "  def dist_entropy(self, loc, scale):\n",
        "    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)\n",
        "    entropy = 0.5 + log_normalized\n",
        "    entropy = entropy * torch.ones_like(loc)\n",
        "    dist = torch.normal(loc, scale)\n",
        "    log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))\n",
        "    entropy = entropy + log_det_jacobian\n",
        "    return entropy.sum(dim=-1)\n",
        "\n",
        "  @torch.jit.export\n",
        "  def dist_log_prob(self, loc, scale, dist):\n",
        "    log_unnormalized = -0.5 * ((dist - loc) / scale).square()\n",
        "    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)\n",
        "    log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))\n",
        "    log_prob = log_unnormalized - log_normalized - log_det_jacobian\n",
        "    return log_prob.sum(dim=-1)\n",
        "\n",
        "  @torch.jit.export\n",
        "  def update_normalization(self, observation):\n",
        "    self.num_steps += observation.shape[0] * observation.shape[1]\n",
        "    input_to_old_mean = observation - self.running_mean\n",
        "    mean_diff = torch.sum(input_to_old_mean / self.num_steps, dim=(0, 1))\n",
        "    self.running_mean = self.running_mean + mean_diff\n",
        "    input_to_new_mean = observation - self.running_mean\n",
        "    var_diff = torch.sum(input_to_new_mean * input_to_old_mean, dim=(0, 1))\n",
        "    self.running_variance = self.running_variance + var_diff\n",
        "\n",
        "  @torch.jit.export\n",
        "  def normalize(self, observation):\n",
        "    variance = self.running_variance / (self.num_steps + 1.0)\n",
        "    variance = torch.clip(variance, 1e-6, 1e6)\n",
        "    return ((observation - self.running_mean) / variance.sqrt()).clip(-5, 5)\n",
        "\n",
        "  @torch.jit.export\n",
        "  def get_logits_action(self, observation):\n",
        "    observation = self.normalize(observation)\n",
        "    logits = self.policy(observation)\n",
        "    loc, scale = self.dist_create(logits)\n",
        "    action = self.dist_sample_no_postprocess(loc, scale)\n",
        "    return logits, action\n",
        "\n",
        "  @torch.jit.export\n",
        "  def compute_gae(self, truncation, termination, reward, values,\n",
        "                  bootstrap_value):\n",
        "    truncation_mask = 1 - truncation\n",
        "    # Append bootstrapped value to get [v1, ..., v_t+1]\n",
        "    values_t_plus_1 = torch.cat(\n",
        "        [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)\n",
        "    deltas = reward + self.discounting * (\n",
        "        1 - termination) * values_t_plus_1 - values\n",
        "    deltas *= truncation_mask\n",
        "\n",
        "    acc = torch.zeros_like(bootstrap_value)\n",
        "    vs_minus_v_xs = torch.zeros_like(truncation_mask)\n",
        "\n",
        "    for ti in range(truncation_mask.shape[0]):\n",
        "      ti = truncation_mask.shape[0] - ti - 1\n",
        "      acc = deltas[ti] + self.discounting * (\n",
        "          1 - termination[ti]) * truncation_mask[ti] * self.lambda_ * acc\n",
        "      vs_minus_v_xs[ti] = acc\n",
        "\n",
        "    # Add V(x_s) to get v_s.\n",
        "    vs = vs_minus_v_xs + values\n",
        "    vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], 0)\n",
        "    advantages = (reward + self.discounting *\n",
        "                  (1 - termination) * vs_t_plus_1 - values) * truncation_mask\n",
        "    return vs, advantages\n",
        "\n",
        "  @torch.jit.export\n",
        "  def loss(self, td: Dict[str, torch.Tensor]):\n",
        "    observation = self.normalize(td['observation'])\n",
        "    policy_logits = self.policy(observation[:-1])\n",
        "    baseline = self.value(observation)\n",
        "    baseline = torch.squeeze(baseline, dim=-1)\n",
        "\n",
        "    # Use last baseline value (from the value function) to bootstrap.\n",
        "    bootstrap_value = baseline[-1]\n",
        "    baseline = baseline[:-1]\n",
        "    reward = td['reward'] * self.reward_scaling\n",
        "    termination = td['done'] * (1 - td['truncation'])\n",
        "\n",
        "    loc, scale = self.dist_create(td['logits'])\n",
        "    behaviour_action_log_probs = self.dist_log_prob(loc, scale, td['action'])\n",
        "    loc, scale = self.dist_create(policy_logits)\n",
        "    target_action_log_probs = self.dist_log_prob(loc, scale, td['action'])\n",
        "\n",
        "    with torch.no_grad():\n",
        "      vs, advantages = self.compute_gae(\n",
        "          truncation=td['truncation'],\n",
        "          termination=termination,\n",
        "          reward=reward,\n",
        "          values=baseline,\n",
        "          bootstrap_value=bootstrap_value)\n",
        "\n",
        "    rho_s = torch.exp(target_action_log_probs - behaviour_action_log_probs)\n",
        "    surrogate_loss1 = rho_s * advantages\n",
        "    surrogate_loss2 = rho_s.clip(1 - self.epsilon,\n",
        "                                 1 + self.epsilon) * advantages\n",
        "    policy_loss = -torch.mean(torch.minimum(surrogate_loss1, surrogate_loss2))\n",
        "\n",
        "    # Value function loss\n",
        "    v_error = vs - baseline\n",
        "    v_loss = torch.mean(v_error * v_error) * 0.5 * 0.5\n",
        "\n",
        "    # Entropy reward\n",
        "    entropy = torch.mean(self.dist_entropy(loc, scale))\n",
        "    entropy_loss = self.entropy_cost * -entropy\n",
        "\n",
        "    return policy_loss + v_loss + entropy_loss"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CWbuk7IAR0SU"
      },
      "source": [
        "Finally, some code for unrolling and batching environment data, plus the PPO training loop that ties everything together:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D3y5o7-oSBm-"
      },
      "source": [
        "StepData = collections.namedtuple(\n",
        "    'StepData',\n",
        "    ('observation', 'logits', 'action', 'reward', 'done', 'truncation'))\n",
        "\n",
        "\n",
        "def sd_map(f: Callable[..., torch.Tensor], *sds) -> StepData:\n",
        "  \"\"\"Map a function over each field in StepData.\n",
        "\n",
        "  For every field (in StepData order), `f` receives that field's value from\n",
        "  each of `sds`, positionally and in order; the results form a new StepData.\n",
        "  \"\"\"\n",
        "  field_dicts = [sd._asdict() for sd in sds]\n",
        "  return StepData(**{\n",
        "      name: f(*(fields[name] for fields in field_dicts))\n",
        "      for name in field_dicts[0]\n",
        "  })\n",
        "\n",
        "\n",
        "def eval_unroll(agent, env, length):\n",
        "  \"\"\"Return number of episodes and average reward for a single unroll.\n",
        "\n",
        "  Steps all batched environments `length` times with actions sampled from\n",
        "  `agent`, counting `done` flags and summing rewards over every env. The\n",
        "  average is total reward / episode count (not guarded against zero).\n",
        "  \"\"\"\n",
        "  observation = env.reset()\n",
        "  episodes, episode_reward = (\n",
        "      torch.zeros((), device=agent.device) for _ in range(2))\n",
        "  for _step in range(length):\n",
        "    action = agent.get_logits_action(observation)[1]\n",
        "    observation, reward, done, _info = env.step(Agent.dist_postprocess(action))\n",
        "    episodes += done.sum()\n",
        "    episode_reward += reward.sum()\n",
        "  average_reward = episode_reward / episodes\n",
        "  return episodes, average_reward\n",
        "\n",
        "\n",
        "def train_unroll(agent, env, observation, num_unrolls, unroll_length):\n",
        "  \"\"\"Return step data over multiple unrolls.\n",
        "\n",
        "  Runs `num_unrolls` consecutive unrolls of `unroll_length` steps each,\n",
        "  continuing from `observation`. Each unroll records `unroll_length + 1`\n",
        "  observations (the start plus one per step) so the last one can bootstrap\n",
        "  the value function.\n",
        "\n",
        "  The rollout runs under `torch.no_grad()`: the collected data is only used\n",
        "  as fixed training targets, so recording an autograd graph for every env\n",
        "  step would only waste memory and time.\n",
        "\n",
        "  Returns:\n",
        "    (observation, td): the observation after the final step, and a StepData\n",
        "    whose fields are stacked as [num_unrolls, steps, ...].\n",
        "  \"\"\"\n",
        "  sd = StepData([], [], [], [], [], [])\n",
        "  with torch.no_grad():\n",
        "    for _ in range(num_unrolls):\n",
        "      one_unroll = StepData([observation], [], [], [], [], [])\n",
        "      for _ in range(unroll_length):\n",
        "        logits, action = agent.get_logits_action(observation)\n",
        "        observation, reward, done, info = env.step(\n",
        "            Agent.dist_postprocess(action))\n",
        "        one_unroll.observation.append(observation)\n",
        "        one_unroll.logits.append(logits)\n",
        "        one_unroll.action.append(action)\n",
        "        one_unroll.reward.append(reward)\n",
        "        one_unroll.done.append(done)\n",
        "        one_unroll.truncation.append(info['truncation'])\n",
        "      # Append each stacked field to its list in place (fields iterate in\n",
        "      # StepData order) instead of rebuilding every list per unroll.\n",
        "      for field_list, stacked in zip(sd, sd_map(torch.stack, one_unroll)):\n",
        "        field_list.append(stacked)\n",
        "    td = sd_map(torch.stack, sd)\n",
        "  return observation, td\n",
        "\n",
        "\n",
        "def train(\n",
        "    env_name: str = 'ant',\n",
        "    num_envs: int = 2048,\n",
        "    episode_length: int = 1000,\n",
        "    device: str = 'cuda',\n",
        "    num_timesteps: int = 30_000_000,\n",
        "    eval_frequency: int = 10,\n",
        "    unroll_length: int = 5,\n",
        "    batch_size: int = 1024,\n",
        "    num_minibatches: int = 32,\n",
        "    num_update_epochs: int = 4,\n",
        "    reward_scaling: float = .1,\n",
        "    entropy_cost: float = 1e-2,\n",
        "    discounting: float = .97,\n",
        "    learning_rate: float = 3e-4,\n",
        "    progress_fn: Optional[Callable[[int, Dict[str, Any]], None]] = None,\n",
        "):\n",
        "  \"\"\"Trains a policy via PPO.\n",
        "\n",
        "  Training runs in `eval_frequency` segments. Before each segment, and once\n",
        "  after the last one, the policy is evaluated for `episode_length` steps and\n",
        "  the metrics are reported via `progress_fn(total_steps, metrics)` if given.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if the batch settings yield no unrolls, no update epochs, or\n",
        "      no training epochs per segment.\n",
        "  \"\"\"\n",
        "  gym_name = f'brax-{env_name}-v0'\n",
        "  # Older gym keeps specs in `registry.env_specs`; newer releases make the\n",
        "  # registry itself a mapping from env id to spec.\n",
        "  registered_specs = getattr(gym.envs.registry, 'env_specs', gym.envs.registry)\n",
        "  if gym_name not in registered_specs:\n",
        "    entry_point = functools.partial(envs.create_gym_env, env_name=env_name)\n",
        "    gym.register(gym_name, entry_point=entry_point)\n",
        "  env = gym.make(gym_name, batch_size=num_envs, episode_length=episode_length)\n",
        "  # automatically convert between jax ndarrays and torch tensors:\n",
        "  env = to_torch.JaxToTorchWrapper(env, device=device)\n",
        "\n",
        "  # Loop-invariant schedule: env steps per epoch, unrolls per epoch, and\n",
        "  # epochs per training segment.\n",
        "  num_steps = batch_size * num_minibatches * unroll_length\n",
        "  num_unrolls = batch_size * num_minibatches // env.num_envs\n",
        "  num_epochs = 0\n",
        "  if eval_frequency > 0:\n",
        "    # Fail fast: these settings would otherwise crash only after the first\n",
        "    # evaluation (empty torch.stack or division by zero).\n",
        "    if num_steps <= 0 or num_unrolls < 1 or num_update_epochs < 1:\n",
        "      raise ValueError(\n",
        "          'Need unroll_length >= 1, num_update_epochs >= 1 and '\n",
        "          f'batch_size * num_minibatches >= num_envs ({env.num_envs}); got '\n",
        "          f'unroll_length={unroll_length}, '\n",
        "          f'num_update_epochs={num_update_epochs}, '\n",
        "          f'batch_size * num_minibatches={batch_size * num_minibatches}.')\n",
        "    num_epochs = num_timesteps // (num_steps * eval_frequency)\n",
        "    if num_epochs < 1:\n",
        "      raise ValueError(\n",
        "          f'num_timesteps={num_timesteps} is too small for one epoch of '\n",
        "          f'{num_steps} steps in each of {eval_frequency} segments.')\n",
        "\n",
        "  # env warmup\n",
        "  env.reset()\n",
        "  action = torch.zeros(env.action_space.shape, device=device)\n",
        "  env.step(action)\n",
        "\n",
        "  # create the agent\n",
        "  policy_layers = [\n",
        "      env.observation_space.shape[-1], 64, 64, env.action_space.shape[-1] * 2\n",
        "  ]\n",
        "  value_layers = [env.observation_space.shape[-1], 64, 64, 1]\n",
        "  agent = Agent(policy_layers, value_layers, entropy_cost, discounting,\n",
        "                reward_scaling, device)\n",
        "  agent = torch.jit.script(agent.to(device))\n",
        "  optimizer = optim.Adam(agent.parameters(), lr=learning_rate)\n",
        "\n",
        "  def unroll_first(data):\n",
        "    # [num_unrolls, steps, batch, ...] -> [steps, num_unrolls * batch, ...]\n",
        "    data = data.swapaxes(0, 1)\n",
        "    return data.reshape([data.shape[0], -1] + list(data.shape[3:]))\n",
        "\n",
        "  sps = 0\n",
        "  total_steps = 0\n",
        "  total_loss = 0\n",
        "  for eval_i in range(eval_frequency + 1):\n",
        "    if progress_fn:\n",
        "      t = time.time()\n",
        "      with torch.no_grad():\n",
        "        episode_count, episode_reward = eval_unroll(agent, env, episode_length)\n",
        "      duration = time.time() - t\n",
        "      # TODO: only count stats from completed episodes\n",
        "      episode_avg_length = env.num_envs * episode_length / episode_count\n",
        "      eval_sps = env.num_envs * episode_length / duration\n",
        "      progress = {\n",
        "          'eval/episode_reward': episode_reward,\n",
        "          'eval/completed_episodes': episode_count,\n",
        "          'eval/avg_episode_length': episode_avg_length,\n",
        "          'speed/sps': sps,\n",
        "          'speed/eval_sps': eval_sps,\n",
        "          'losses/total_loss': total_loss,\n",
        "      }\n",
        "      progress_fn(total_steps, progress)\n",
        "\n",
        "    if eval_i == eval_frequency:\n",
        "      break\n",
        "\n",
        "    observation = env.reset()\n",
        "    total_loss = 0\n",
        "    t = time.time()\n",
        "    for _ in range(num_epochs):\n",
        "      observation, td = train_unroll(agent, env, observation, num_unrolls,\n",
        "                                     unroll_length)\n",
        "\n",
        "      # make unroll first\n",
        "      td = sd_map(unroll_first, td)\n",
        "\n",
        "      # update normalization statistics\n",
        "      agent.update_normalization(td.observation)\n",
        "\n",
        "      for _ in range(num_update_epochs):\n",
        "        # shuffle and batch the data (no_grad also detaches it from any graph)\n",
        "        with torch.no_grad():\n",
        "          permutation = torch.randperm(td.observation.shape[1], device=device)\n",
        "          def shuffle_batch(data):\n",
        "            data = data[:, permutation]\n",
        "            data = data.reshape([data.shape[0], num_minibatches, -1] +\n",
        "                                list(data.shape[2:]))\n",
        "            return data.swapaxes(0, 1)\n",
        "          epoch_td = sd_map(shuffle_batch, td)\n",
        "\n",
        "        for minibatch_i in range(num_minibatches):\n",
        "          td_minibatch = sd_map(lambda d: d[minibatch_i], epoch_td)\n",
        "          loss = agent.loss(td_minibatch._asdict())\n",
        "          optimizer.zero_grad()\n",
        "          loss.backward()\n",
        "          optimizer.step()\n",
        "          total_loss += loss.detach()\n",
        "\n",
        "    duration = time.time() - t\n",
        "    total_steps += num_epochs * num_steps\n",
        "    total_loss = total_loss / (num_epochs * num_update_epochs * num_minibatches)\n",
        "    sps = num_epochs * num_steps / duration"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R2A9MMlHUajH"
      },
      "source": [
        "Let's go!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "B-lrKHvkUeYM",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 368
        },
        "outputId": "409e6da6-d877-4955-e0d0-ceaaae24e2ae"
      },
      "source": [
        "xdata = []\n",
        "ydata = []\n",
        "eval_sps = []\n",
        "train_sps = []\n",
        "times = [datetime.now()]\n",
        "\n",
        "def progress(num_steps, metrics):\n",
        "  \"\"\"Records eval/training metrics and redraws the reward curve.\n",
        "\n",
        "  The eval reward arrives as a 0-d torch tensor on the training device (the\n",
        "  GPU by default); metrics are stored as Python floats so matplotlib and\n",
        "  numpy can consume them without touching device tensors.\n",
        "  \"\"\"\n",
        "  times.append(datetime.now())\n",
        "  xdata.append(num_steps)\n",
        "  ydata.append(float(metrics['eval/episode_reward']))\n",
        "  eval_sps.append(float(metrics['speed/eval_sps']))\n",
        "  train_sps.append(float(metrics['speed/sps']))\n",
        "  clear_output(wait=True)\n",
        "  plt.xlim([0, 30_000_000])\n",
        "  plt.ylim([0, 6000])\n",
        "  plt.xlabel('# environment steps')\n",
        "  plt.ylabel('reward per episode')\n",
        "  plt.plot(xdata, ydata)\n",
        "  plt.show()\n",
        "\n",
        "# Run training; `progress` records metrics and redraws the reward curve.\n",
        "train(progress_fn=progress)\n",
        "\n",
        "# times[0] was taken before train() was called and times[1] at the first\n",
        "# evaluation, so the first interval covers env creation, agent scripting and\n",
        "# the first eval rollout (reported here as jit time).\n",
        "print(f'time to jit: {times[1] - times[0]}')\n",
        "print(f'time to train: {times[-1] - times[1]}')\n",
        "# Drop the first sample of each speed metric: train sps is still 0 at the\n",
        "# first evaluation, and the first eval presumably includes warm-up cost.\n",
        "print(f'eval steps/sec: {np.mean(eval_sps[1:])}')\n",
        "print(f'train steps/sec: {np.mean(train_sps[1:])}')\n",
        "# Show which GPU this run used.\n",
        "!nvidia-smi -L"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAEKCAYAAADXdbjqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3xV9f3H8deHDWGELXtvAYUAYh2gCGqtqMU90Fr3qO1P62gr1lXtcLYOWlGsA3CCoihDxAkEQfZeIewRdvbn98c5qREZNyQ3Nzd5Px+P+8g5557xOdxwPznfae6OiIhIYZSLdQAiIhL/lExERKTQlExERKTQlExERKTQlExERKTQlExERKTQoppMzCzRzN42s8VmtsjM+ppZHTObaGbLwp+1w33NzJ4xs+VmNtfMeuQ7z9Bw/2VmNjSaMYuISMFF+8nkaWCCu3cEugOLgHuAye7eDpgcrgOcBbQLX9cDzwOYWR1gGNAH6A0My0tAIiJSMkQtmZhZLeAU4CUAd8909zRgMDAy3G0kcF64PBh41QPfAolm1ggYBEx09+3uvgOYCJwZrbhFRKTgKkTx3K2ALcDLZtYdmAX8Bmjo7hvCfTYCDcPlJkBKvuPXhdsOtf1HzOx6gicaEhISenbs2LHo7kREJI5k5eSyY18W2/dmkpWTS3kzEqtVpE5CJapULH/I42bNmrXV3esfzTWjmUwqAD2A29x9upk9zQ9FWgC4u5tZkYzn4u7DgeEASUlJnpycXBSnFRGJC5nZuUxZvIlRM1OYtnQL5RwGt63Lxb2aM7Bzw8MmkTxmtuZorx/NZLIOWOfu08P1twmSySYza+TuG8JirM3h+6lAs3zHNw23pQL9Dtg+NYpxi4jEjeWb9zAmOYV3v1vH1j2ZNKxZmVv6t+XCns1oXrdascURtWTi7hvNLMXMOrj7EuB0YGH4Ggo8Fv4cGx4yDrjVzEYRVLbvDBPOJ8Cj+SrdBwL3RituEZGSbl9mNuPnbmBMcgozV++gQjnj9E4NuLhXM05pV58K5Yu/10c0n0wAbgNeN7NKwErgGoJK/zFmdi2wBrgo3Pcj4GxgObAv3Bd3325mDwEzw/0edPftUY5bRKREcXfmpe5k1MwUxs1Zz56MbFrXS+CeszpyQY8mNKhRJabxWWkcgl51JiJSWqTty+T92amMmpnC4o27qVKxHGd3bcQlvZrTq2VtzKzIrmVms9w96WiOjfaTiYiIFFBurvPtym2MmpnChAUbyczOpWuTWjx83rGce1xjalapGOsQf0LJRESkhNi4M523Z6UwJnkda7fvo2aVClzaqxkX9WpGl8a1Yh3eYSmZiIjEUFZOLlMWb2b0zBSmLtlMrkPf1nX5v4HtGdTlmIia9JYESiYiIjGwcsseRien8M6sVLbuyaBBjcrc1K8NFyU1o0XdhFiHV2BKJiIixSQ7J5cP5q7nzRkpzFi1nfLljNM6NuDipGb06xCbJr1FRclERCTK3J2pS7fwyPhFLN+8h5Z1q/H7MzswpEdTGtSMbZPeoqJkIiISRYs37uKR8Yv4YtlWWtVL4MUrezKwc8MibdJbEiiZiIhEwZbdGTwxcSmjZ66lRpWK3H9OZ644oQWVKsRvUdbhKJmIiBSh9KwcRny1iuc+W0F6Vg5Xn9iK209vS2K1SrEOLaqUTEREioC788HcDTz+8WJS0/ZzRueG3HtWR1rXrx7r0IqFkomISCF9t3YHD324kNlr0+jcqCZ/u7AbJ7apF+uwipWSiYjIUUrZvo+/frKED75fT4MalfnrkG78skdTypcrXZXrkVAyEREpoN3pWTw3dQUvfbmKcga3n96OG05pTULlsvuVWnbvXESkgLJzchmTvI4nJi5h655MLji+CXcO6kDjxKqxDi3mlExERCIwLex0uGTTbnq3rMOIqzvRrWlirMMqMZRMREQOY9mm3Tzy0SKmLt
lC8zrVeP7yHpx57DGlrtNhYSmZiIgcxLY9GTw1aRlvzFhLtUrl+cPZnbjqxBZUrhAfo/gWNyUTEZF8MrJzeOWr1fxzynL2ZeVwRZ/m/GZAe+oklO5Oh4WlZCIiQtDp8OP5G/nLx4tI2b6f0zo24L6zO9K2QY1YhxYXlExEpMybk5LGwx8uJHnNDjoeU4P/Xtubk9vVj3VYcUXJRETKrPVp+/nrhMW8P2c99apX4i8XdOWipGZlstNhYSmZiEiZszcjmxc+X8HwaStx4Jb+bbipX1uql+FOh4WlfzkRKTNycp23Z6Xw90+XsmV3BoOPa8xdgzrQtHa1WIcW95RMRKRMmLFqO8PGLWDRhl30aJ7Ii1f2pEfz2rEOq9RQMhGRUm3z7nQe+2gx785OpXGtKjx76fGc062ROh0WMSUTESmVsnNyGfnNGp6auJSM7Fxu7d+WW/q3pWoldTqMBiUTESl1pq/cxv1jF7Bk025ObV+fB87tQqt6CbEOq1SLajIxs9XAbiAHyHb3JDOrA4wGWgKrgYvcfYcFz5xPA2cD+4Cr3f278DxDgT+Gp33Y3UdGM24RiU+bd6Xz6EeLeH/OepokVuXFK3sysHNDFWkVg+J4Munv7lvzrd8DTHb3x8zsnnD9buAsoF346gM8D/QJk88wIAlwYJaZjXP3HcUQu4jEgaycXEZ+vZqnJi0jMzuX205ry839VKRVnGJRzDUY6BcujwSmEiSTwcCr7u7At2aWaGaNwn0nuvt2ADObCJwJvFm8YYtISfTNim0MGzefpZv20L9DfYb9ogstVaRV7KKdTBz41MwceNHdhwMN3X1D+P5GoGG43ARIyXfsunDbobaLSBm2aVc6j4xfxLjv19O0dlX+fVUSAzo1UJFWjEQ7mZzk7qlm1gCYaGaL87/p7h4mmkIzs+uB6wGaN29eFKcUkRIoKyeXl79axdOTlpGV69x+ejtu7teGKhVVpBVLUU0m7p4a/txsZu8BvYFNZtbI3TeExVibw91TgWb5Dm8abkvlh2KxvO1TD3Kt4cBwgKSkpCJJUCJSsny9YivDxi5g2eY9nNaxAcN+0ZkWdVWkVRKUi9aJzSzBzGrkLQMDgfnAOGBouNtQYGy4PA64ygInADvD4rBPgIFmVtvMaofn+SRacYtIybNxZzq3vvEdl/17OunZOfznqiRGXN1LiaQEieaTSUPgvbD8sgLwhrtPMLOZwBgzuxZYA1wU7v8RQbPg5QRNg68BcPftZvYQMDPc78G8yngRKd0ys3MZ8dUqnpm8jOxc544B7bjxVBVplUQWNJ4qXZKSkjw5OTnWYYhIIXy1fCv3j53Pii17GdCpAfef04XmdTUgYzSZ2Sx3TzqaY9UDXkRKlA079/Pwh4sYP28DzetUY8TVSZzWseGRD5SYUjIRkRIhMzuXl75cxbNTlpGT6/x2QHtuOLW1irTihJKJiMTcF8u2MGzcAlZu2csZnRty/zmdaVZHRVrxRMlERGJmfdp+Hh6/kI/mbaRF3Wq8fHUv+ndsEOuw5CgomYhIscvIzuE/X6zin1OW4zj/d0Z7rjtFRVrxTMlERIrV50u38MC4BazaupeBnRvyJxVplQpKJiJSLFLT9vPQBwuZsGAjLetW45VretGvg4q0SgslExGJqn2Z2bz4+UpenLYCgDsHBkValSuoSKs0UTIRkajIzXXGfp/K4x8vYeOudH7erRH3ntWRprVVpFUaKZmISJH7bu0OHvxgIXNS0ujapBbPXnY8vVrWiXVYEkVKJiJSZNan7efxCYsZO2c9DWpU5u8XdueC45tQrpzmGCntlExEpND2ZWbzwucrGT5tBbkOt/Zvy0392pBQWV8xZYU+aRE5agfWi5zTrRH3qF6kTFIyEZGjMmvNDh78cCHfp6TRrWkt/nnZ8SSpXqTMUjIRkQJZn7afxz5ezLjvVS8iP1AyEZGI5K8XcYfbTmvLjaeqXkQC+i0QkcPKzXXen5PK4xMWs2lXBr/o3p
i7z+ygehH5ESUTETmkWWu28+AHC/l+3U66Na3Fvy7roXoROSglExH5idS0/Twe1os0rFmZf1zYnfNVLyKHoWQiIv+zLzObF6au4MVpKwG4/bS23KB6EYmAfkNE5KD1Ivec1ZEmiVVjHZrEiSMmEzNrCDwKNHb3s8ysM9DX3V+KenQiEnX560W6N63Fc5f3oGcL1YtIwUTyZPIK8DLwh3B9KTAaUDIRiWOpYX+RD8J6kScu6s55x6leRI5OJMmknruPMbN7Adw928xyohyXiETJ3oxsXvw8X73I6e248dTWVKukUm85epH89uw1s7qAA5jZCcDOqEYlIkUuN9d5b3Yqf/0kqBc5t3tj7la9iBSRSJLJ74BxQBsz+wqoDwyJalQiUqSSV2/noQ/DepFmiaoXkSJ3xGTi7t+Z2alAB8CAJe6eFfXIRKTQvk9J48lJS5m6ZAvH1KzCkxd3Z3B31YtI0TtkMjGzCw7xVnszw93fjVJMIlJI81N38tSkpUxatJna1Spyz1kduapvC9WLSNQc7jfrF+HPBsCJwJRwvT/wNRBRMjGz8kAykOru55hZK2AUUBeYBVzp7plmVhl4FegJbAMudvfV4TnuBa4FcoDb3f2TiO9QpAxZtGEXT01ayicLNlGrakXuGtSBoSe2pLo6HUqUHfI3zN2vATCzT4HO7r4hXG9E0Fw4Ur8BFgE1w/XHgSfdfZSZvUCQJJ4Pf+5w97Zmdkm438Vhv5ZLgC5AY2CSmbV3d7UoEwkt27SbpyYtY/y8DdSoXIE7BrTjVye1omaVirEOTcqISP5caZaXSEKbgOaRnNzMmgI/Bx4BfmdmBpwGXBbuMhJ4gCCZDA6XAd4G/hnuPxgY5e4ZwCozWw70Br6JJAaR0mzFlj08M3kZ475fT7WK5bnttLb8+qTW1KqmJCLFK5JkMtnMPgHeDNcvBiZFeP6ngN8DNcL1ukCau2eH6+uAJuFyEyAF/teXZWe4fxPg23znzH/M/5jZ9cD1AM2bR5TrROLW6q17eWbKMt6fnUrlCuW54ZQ2XH9Ka+okVIp1aFJGRdKa61YzOx84Jdw03N3fO9JxZnYOsNndZ5lZv8KFeWTuPhwYDpCUlOTRvp5ILKRs38ezU5bxznepVChnXHtSK244tQ31qleOdWhSxkVaK/c1kE3QcXFGhMf8DDjXzM4GqhDUmTwNJJpZhfDppCmQGu6fCjQD1plZBaAWQUV83vY8+Y8RKRNS0/bzr8+WM2ZmCuXKGVf1bcFNp7ahQc0qsQ5NBIhsoMeLgL8BUwn6mTxrZne5+9uHO87d7wXuDc/RD7jT3S83s7cIOj2OAoYCY8NDxoXr34TvT3F3N7NxwBtm9gRBBXw7Ik9oInFt4850npu6nFEzUnCcy/o05+Z+bTmmlpKIlCyRPJn8Aejl7psBzKw+QZ3JYZPJYdwNjDKzh4HZ/DBg5EvAf8MK9u0ELbhw9wVmNgZYSPB0dItacklpt3l3Os9PXcHr09eSm+tc1KsZt/Rvq6FPpMSKJJmUy0skoW1AuYJcxN2nEjzZ4O4rCVpjHbhPOnDhIY5/hKBFmEiptnVPBi9+voL/fruGrBznlz2acNtp7WhWR/OtS8kWSTKZcJDWXB9FLySRsmfH3kxenLaSkV+vJiM7h/OOb8Ltp7WjZb2EWIcmEpFIWnPdFQ6tclK4KaLWXCJyZDv3ZfGfL1cy4stV7MvK4dzujbn99Ha0qV891qGJFEgkFfAJwFh3f9fMOgAdzKyiBnsUOXq70rMY8eUqXvpiFbszsvl510bcMaAd7RrWOPLBIiVQJMVc04CTzaw2MIFgnK2LgcujGZhIabQnI5tXvlrF8Gkr2ZWezaAuDbljQHs6Nap55INFSrBIkom5+z4zuxZ43t3/amZzoh2YSGmyLzObV79Zw4ufr2DHviwGdGrAHQPac2yTWrEOTaRIRJRMzKwvwZPIteG28tELSaT0SM/K4bVv1/D81BVs25tJvw71uWNAe4
5rlhjr0ESKVCTJ5A6CzofvhX0+WgOfRTcskfiWmZ3L6JlreXbKcjbvzuBnbevyuzPaa3ZDKbUiac31OfB5vvWVwO3RDEokXmXn5PLu7FSenrSM1LT99GpZm2cuPZ4TWteNdWgiUXW4mRafcvc7zOwDgjG5fsTdz41qZCJxJDfX+WDuep6atIxVW/fSrWktHr2gK6e0q0cwk4JI6Xa4J5P/hj//XhyBiMQjd+fThZt44tOlLNm0m47H1GD4lT05o3NDJREpUw430+Ks8OfnZlYJ6EjwhLLE3TOLKT6REsnd+XzpFp6YuJS563bSul4Cz1x6POd0bUS5ckoiUvZE0mnx58ALwAqCUYNbmdkN7v5xtIMTKYm+XbmNf3y6hJmrd9C0dlX+NqQb5x/fhArlCzRknUipEklrrn8A/d19OYCZtQHGA0omUqZ8t3YHT3y6lC+Xb6Vhzco8dN6xXJzUjEoVlEREIkkmu/MSSWglsDtK8YiUOAvW7+SJT5cyefFm6iZU4o8/78QVJ7SgSkV1txLJE0kySTazj4AxBHUmFwIzw8Efcfd3oxifSMws37ybJycuY/y8DdSsUoG7BnXg6hNbklA50glKRcqOSP5XVAE2AaeG61uAqsAvCJKLkomUKmu27eXpSct4f04qVSuW5/bT2nLtya2pVbVirEMTKbEi6bR4TXEEIhJr69P28+yU5byVnEL5csavT27Njae2oU5CpViHJlLiRdKaqz3wPNDQ3Y81s27Aue7+cNSjEykGm3en89xnK3hj+loc5/I+zbmlf1sa1NQ86yKRiqSY69/AXcCLAO4+18zeAJRMJK7ln90wMyeXIT2actvpbWlaW1PkihRUJMmkmrvPOKA3b3aU4hGJul3pWbz0xSpe+nIVezOzGdy9Mb8Z0J5WmiJX5KhFkky2hn1LHMDMhgAbohqVSBTsy8xm5NdreHHaCtL2ZXFml2P47Rnt6XCMZjcUKaxIksktwHCgo5mlAqvQLIsSR9Kzcnhj+lqem7qcrXsy6d+hPr87owNdm2piKpGiEklrrpXAgHAu+HLurg6LEhdyc523Z63jyUlL2bAznb6t6/LilZpTRCQaIu595e57oxmISFGat24nfxo7nzkpaRzfPJF/XNidE9vWi3VYIqWWuvJKqZK2L5O/f7qE16evpW5CZZ64qDvnH99Ew8GLRNlhk4mZlQNOcPeviykekaOSV6T12ITFpO3LZGjflvz2jPbqtS5STA6bTNw918z+BRxfTPGIFNj81KBIa/baNJJa1ObBwX3o3LhmrMMSKVMiGTt7spn90gpYTmBmVcxshpl9b2YLzOzP4fZWZjbdzJab2ehw4i3MrHK4vjx8v2W+c90bbl9iZoMKEoeUXjv3Z3H/2Pmc+88vWbttH3+/sDtjbuirRCISA5HUmdwA/A7IMbP9BBNkubsf6X9sBnCau+8xs4rAl2b2cXiuJ919lJm9AFxLMFzLtcAOd29rZpcAjwMXm1ln4BKgC9AYmGRm7d09p+C3K6VBbq7zznfreOzjxezYl8mVJ7TgdwM7qEhLJIYiaRp8VD263N2BPeFqxfDlwGnAZeH2kcADBMlkcLgM8Dbwz/BpaDAwyt0zgFVmthzoDXxzNHFJfFu4fhf3j51P8pod9GieyMhf9ebYJuovIhJrkQz0aASdFFu5+0Nm1gxo5O4zIji2PDALaAv8i2Dq3zR3zxuOZR3QJFxuAqQAuHu2me0E6obbv8132vzH5L/W9cD1AM2bNz9SaBJndu7P4smJS3n1m9UkVqvEX4d0Y0iPpppvXaSEiKSY6zkgl+CJ4iGCp41/Ab2OdGBYFHWcmSUC7wEdjz7UI15rOEFPfZKSkjxa15Hi5e68NzuVRz9azLa9GVzRpwV3DuxArWoq0hIpSSJJJn3cvYeZzQZw9x15leaRcvc0M/sM6AskmlmF8OmkKZAa7pYKNAPWmVkFoBawLd/2PPmPkVJs0YagSGvm6h0c1yyRl6/upS
FQREqoSFpzZYXFVXkDPdYneFI5LDOrHz6RYGZVgTOARcBnwJBwt6HA2HB5XLhO+P6UsN5lHHBJ2NqrFdAOOGIRm8SvXelZPPjBQs559kuWb97D47/syrs3nahEIlKCRfJk8gxBEVVDM3uE4Iv+jxEc1wgYGSaicsAYd//QzBYCo8zsYWA28FK4/0vAf8MK9u0ELbhw9wVmNgZYSDD0/S1qyVU6uTtj56znkY8WsXVPBpf1bs5dgzqQWE0zHYqUdBb88X+Encw6AqeHq1PcfVFUoyqkpKQkT05OjnUYUgBLNu7mT2PnM2PVdro3rcVD5x1Lt6aJsQ5LpEwxs1nunnQ0x0Y6Nlc1IK+oq+rRXEjkYHanZ/H0pGW8/PVqalSpwF8u6MrFSc3USkskzkTSNPh+4ELgHYIOiy+b2VuaA14Kw90Z9/16Hhm/iC17MrikV3N+P6gDtRNUpCUSjyJ5Mrkc6O7u6QBm9hgwB80BL0dp2aagSOvbldvp1rQWw69K4rhmKtISiWeRJJP1QBUgPVyvjJrmylHYk5HNM5OXMeLLVSRUrsAj5x/LJb2aU15FWiJxL5JkshNYYGYTCepMzgBmmNkzAO5+exTjk1LA3flw7gYeHr+QTbsyuKRXM35/ZkfqqEhLpNSIJJm8F77yTI1OKFIaLd+8m/vHLuDrFds4tklNnr+iJz2a1451WCJSxCIZ6HFkcQQipUt6Vg7PTF7G8GkrqVapPA+ddyyX9VaRlkhppWl7pcjNXL2du9+ey8qtexnSsyn3ntWRutUrxzosEYkiJRMpMnsysvnbhMW8+u0amiRW5b/X9ubkdvVjHZaIFAMlEykS05Zu4d5357F+536G9m3JXYM6kFBZv14iZcUh/7eb2QeEgzsejLufG5WIJK6k7cvk4fGLeHvWOtrUT+DtG/vSs0WdWIclIsXscH86/j38eQFwDPBauH4psCmaQUl8mDB/A398fwE79mVya/+23HpaW6pULB/rsEQkBg6ZTNz9cwAz+8cBA399YGYaRbEM27w7nWFjF/Dx/I10aVyTkb/qRZfGGh5epCyLpFA7wcxau/tKgHBOkYTohiUlkbvzznepPPThQvZn5fD7Mztw3cmtqVg+kmlxRKQ0iySZ3AFMNbOVBAM9tiCca13KjtS0/dz37jw+X7qFpBa1eXxIN9rUrx7rsESkhDhsMjGzcgTT57bjh/nbF7t7RrQDk5IhN9d5bfoaHv94MQ78+dwuXHlCCw0RLyI/cthk4u65ZvZ7dx8DfF9MMUkJsXLLHu5+Zy4zV+/g5Hb1ePT8rjSrUy3WYYlICRRJMdckM7sTGA3szdvo7tujFpXEVHZOLv/+YhVPTlpKlQrl+NuQbgzp2RQzPY2IyMFFkkwuDn/ekm+bA62LPhyJtYXrd/H7d75nfuouzuxyDA+e14UGNarEOiwRKeEiGeixVXEEIrGVkZ3DP6cs5/mpK0isVonnL+/BWV0bxTosEYkTEY13YWbHAp0JJskCwN1fjVZQUrxmrdnB3e/MZfnmPfyyR1P+dE4nEqtprhERiVwkc8APA/oRJJOPgLOALwElkzi3LzObv32yhFe+Xk3jWlV55Zpe9OvQINZhiUgciuTJZAjQHZjt7teYWUN+GFpF4tRXy7dyz7tzSdm+n6v6tuD3Z3akugZmFJGjFMm3x/6wiXC2mdUENgPNohyXRMnO/Vk8On4Ro5NTaFUvgTE39KV3Kw3MKCKFE0kySTazRODfwCxgD/BNVKOSqPh0wUb++P58tu3N5KZ+bfjN6e00MKOIFIlIWnPdHC6+YGYTgJruPje6YUlR2rongwfGLeDDuRvo1KgmLw3tRdemGphRRIpOJBXw/wWmAV+4++LohyRFxd0ZO2c9f/5gAXszcvi/M9pzY782GphRRIpcJN8qI4BGwLNmttLM3jGz3xzpIDNrZmafmdlCM1uQd4yZ1TGziWa2LPxZO9xuZvaMmS03s7lm1iPfuYaG+y8zs6FHea
9lysad6Vw7Mpk7Rs+hZb0Ext9+Ered3k6JRESiwtwPOZniDzuZlQd6Af2BGwkq5Tse4ZhGQCN3/87MahDUt5wHXA1sd/fHzOweoLa7321mZwO3AWcDfYCn3b2PmdUBkoEkgp73s4Ce7r7jUNdOSkry5OSyO+XKko27uWrEdHbtz+auQR0YemJLymtgRhE5AjObdcD8VRGLpJhrMsH8Jd8AXwC93H3zkY5z9w3AhnB5t5ktApoAgwn6rQCMBKYCd4fbX/Ugu31rZolhQuoHTMwbC8zMJgJnAm9GfJdlyIxV2/n1yJlUqVied28+kU6NasY6JBEpAyIp85gLZALHAt2AY82sakEuYmYtgeOB6UDDMNEAbAQahstNgJR8h60Ltx1q+4HXuN7Mks0secuWLQUJr9T4dMFGrnxpOvWqV+adm5RIRKT4HDGZuPtv3f0UgrngtwEvA2mRXsDMqgPvAHe4+64Dzu0ERVeF5u7D3T3J3ZPq169fFKeMK2/OWMuNr82iY6OavH3TiRoqXkSKVSTFXLcCJwM9gdUEFfJfRHJyM6tIkEhed/d3w82bzKyRu28Ii7HyisxS+XFnyKbhtlR+KBbL2z41kuuXBe7OM5OX8+SkpfTrUJ/nLu9BtUrqyS4ixSuSYq4qwBNAR3cf4O5/dvcpRzrIgskvXgIWufsT+d4aB+S1yBoKjM23/aqwVdcJwM6wOOwTYKCZ1Q5bfg0Mt5V5ObnOn8bO58lJS7mgRxP+fVWSEomIxEQknRb/bmYnAVcCL5tZfaC6u686wqE/C4+ZZ2Zzwm33AY8BY8zsWmANcFH43kcELbmWA/uAa8Lrbzezh4CZ4X4PamIuSM/K4bej5/Dx/I3ccGpr7jmzoyavEpGYOWLT4HDU4CSgg7u3N7PGwFvu/rPiCPBolPamwbvSs7huZDLTV23njz/vxK9P1jxlIlJ4UW0aDJxP0BLrOwB3Xx/2G5EY2LwrnaEvz2TZpt08dfFxnHf8Txq2iYgUu0iSSaa7u5k5gJklRDkmOYSVW/Zw1YgZbN+byYire3FK+7LXak1ESqZIkskYM3sRSDSz64BfEYwgLMXo+5Q0rnklqDZ687oT6N4sMcYRiYj84LDJJGyRNRroCOwCOgD3u/vEYohNQtLDP5QAABCTSURBVJ8v3cJNr82iTkIlXv1Vb1rXrx7rkEREfuSwySQs3vrI3bsCSiAx8P7sVO5863vaNazByGt60aBmlViHJCLyE5H0M/nOzHpFPRL5if98sZI7Rs8hqWVtRt9wghKJiJRYkdSZ9AEuN7M1wF7ACB5aukU1sjIsN9d5fMJiXpy2krO7HsOTFx9H5QqaEVFESq5IksmgqEch/5OVk8vdb8/l3dmpXNW3BcN+0UXDx4tIiRdJD/g1xRGIwN6MbG5+/Ts+X7qF/zujPbee1la92kUkLmggpxJi+95MrnllJvPWpfHYBV25pHfzWIckIhIxJZMSIGX7PoaOmEFq2n5euKInA7scE+uQREQKRMkkxhZt2MXQETNIz8rhtV/3oVfLOrEOSUSkwJRMYujbldu47tVkEipV4K0bT6TDMRryTETik5JJjEyYv4HbR82hWe2qvHptH5okFmgmZBGREkXJJAZe+3YN94+dT/dmiYwY2ovaCZViHZKISKEomRQjd+fpyct4atIyTuvYgH9d1oOqldQZUUTin5JJMcmbYveN6WsZ0rMpf7mgKxXLRzKajYhIyadkUgzSs3K4Y9QcJizYyM392nDXoA7qjCgipYqSSZTt3J/Fda8mM2PVdu4/pzO/OqlVrEMSESlySiZRtGlXOkNHzGDFlj08c+nxnNu9caxDEhGJCiWTKFmxZQ9XvTSDtH2ZvHx1b05qVy/WIYmIRI2SSRQsWL+TK/4znfLljFHX96Vr01qxDklEJKqUTIpYZnYuvx09h8oVyjPq+hNoWS8h1iGJiESdkkkR+8+XK1m6aQ8vDU1SIhGRMkMdHYrQ2m37eGbyMs7scg
ynd2oY63BERIqNkkkRcQ86JVYoV44Hzu0S63BERIqVkkkRGT9vQzBD4sD2HFOrSqzDEREpVlFLJmY2wsw2m9n8fNvqmNlEM1sW/qwdbjcze8bMlpvZXDPrke+YoeH+y8xsaLTiLYyd+7P48wcL6dqkFlf1bRnrcEREil00n0xeAc48YNs9wGR3bwdMDtcBzgLaha/rgechSD7AMKAP0BsYlpeASpK/f7KEbXsyePT8rpQvp2FSRKTsiVoycfdpwPYDNg8GRobLI4Hz8m1/1QPfAolm1ggYBEx09+3uvgOYyE8TVEzNXruD16avYeiJLdWfRETKrOKuM2no7hvC5Y1AXpOnJkBKvv3WhdsOtb1EyM7J5b735tOwRhX+b2CHWIcjIhIzMauAd3cHvKjOZ2bXm1mymSVv2bKlqE57WC9/tZpFG3bxwLldqF5ZXXZEpOwq7mSyKSy+Ivy5OdyeCjTLt1/TcNuhtv+Euw939yR3T6pfv36RB36gdTv28cTEpQzo1IBBXdSnRETKtuJOJuOAvBZZQ4Gx+bZfFbbqOgHYGRaHfQIMNLPaYcX7wHBbTLk7D4xbAMAD53bR3CQiUuZFrWzGzN4E+gH1zGwdQausx4AxZnYtsAa4KNz9I+BsYDmwD7gGwN23m9lDwMxwvwfd/cBK/WL3yYJNTFq0mT+c3YmmtavFOhwRkZizoOqidElKSvLk5OSonHtPRjYD/vE5tRMq8cGtP6OCpt4VkVLCzGa5e9LRHKtvwgL6x6dL2LQ7nUfPP1aJREQkpG/DApi3bicjv17NFX1acHzzEtd3UkQkZpRMIpST69z33jzqVq/MXWeqT4mISH5KJhF69ZvVzEvdybBfdKZmlYqxDkdEpERRMonAhp37+cenSzm1fX1+3rVRrMMRESlxlEwi8OdxC8nKyeWhwceqT4mIyEEomRzBpIWbmLBgI78Z0I7mddWnRETkYJRMDmNfZjbDxi2gfcPqXHdy61iHIyJSYml0wsN4atIyUtP28/aNfamoPiUiIoekb8hDWLh+Fy99uYpLezcjqWWdWIcjIlKiKZkcRF6fksSqFbn7zI6xDkdEpMRTMjmIN2asZU5KGn86pzOJ1SrFOhwRkRJPyeQAm3el89cJizmpbT0GH9c41uGIiMQFJZMDPPjhQjKyc3noPPUpERGJlJJJPlOXbObDuRu4tX9bWtVLiHU4IiJxQ8kktD8zhz+NnU/r+gnccKr6lIiIFIT6mYSenbKMlO37efO6E6hcoXyswxERiSt6MgGWbtrN8GkrGdKzKX3b1I11OCIicafMJ5PcXOe+d+dRo0oF7ju7U6zDERGJS2U+mYxJTiF5zQ7uPbsTdRLUp0RE5GiU6WSydU8Gf/l4Mb1b1eHCnk1jHY6ISNwq08nkkfGL2JeZzaPnd1WfEhGRQiizyeSr5Vt5b3YqN53ahrYNqsc6HBGRuFYmk0l6Vg5/fH8+LetW4+b+bWMdjohI3CuT/Uyem7qCVVv38tq1fahSUX1KREQKq8w9mSzfvIcXpq7gvOMac1K7erEOR0SkVChTycTd+cN786hSsRx/+HnnWIcjIlJqlKlk8s53qUxftZ17zupE/RqVYx2OiEipETfJxMzONLMlZrbczO4p6PHb92byyPiF9GxRm0t6NYtGiCIiZVZcJBMzKw/8CzgL6AxcamYFKqf6y0eL2J0e9CkpV059SkREilJcJBOgN7Dc3Ve6eyYwChgc6cHTV27jrVnr+PXJrelwTI2oBSkiUlbFS9PgJkBKvvV1QJ/8O5jZ9cD14WqGmc0/8CT3Pg73Ri3EYlUP2BrrIKJI9xffSvP9leZ7A+hwtAfGSzI5IncfDgwHMLNkd0+KcUhRo/uLb7q/+FWa7w2C+zvaY+OlmCsVyF9r3jTcJiIiJUC8JJOZQDsza2VmlYBLgHExjklEREJxUczl7tlmdivwCVAeGOHuCw5zyPDiiSxmdH/xTfcXv0rzvUEh7s/cvSgDERGRMi
heirlERKQEUzIREZFCi+tkcqQhVsysspmNDt+fbmYtiz/KoxfB/V1tZlvMbE74+nUs4jwaZjbCzDYfrD9Q+L6Z2TPhvc81sx7FHWNhRHB//cxsZ77P7v7ijvFomVkzM/vMzBaa2QIz+81B9onbzy/C+4vnz6+Kmc0ws+/D+/vzQfYp+Henu8fli6AifgXQGqgEfA90PmCfm4EXwuVLgNGxjruI7+9q4J+xjvUo7+8UoAcw/xDvnw18DBhwAjA91jEX8f31Az6MdZxHeW+NgB7hcg1g6UF+N+P284vw/uL58zOgerhcEZgOnHDAPgX+7oznJ5NIhlgZDIwMl98GTrf4mey9UEPIlHTuPg3YfphdBgOveuBbINHMGhVPdIUXwf3FLXff4O7fhcu7gUUEo1TkF7efX4T3F7fCz2RPuFoxfB3YEqvA353xnEwONsTKgR/4//Zx92xgJ1C3WKIrvEjuD+CXYTHC22ZWmoZDjvT+41nfsKjhYzPrEutgjkZY/HE8wV+3+ZWKz+8w9wdx/PmZWXkzmwNsBia6+yE/v0i/O+M5mQh8ALR0927ARH74S0JKvu+AFu7eHXgWeD/G8RSYmVUH3gHucPddsY6nqB3h/uL683P3HHc/jmA0kd5mdmxhzxnPySSSIVb+t4+ZVQBqAduKJbrCO+L9ufs2d88IV/8D9Cym2IpDqR5Cx9135RU1uPtHQEUzi5t5pM2sIsEX7evu/u5Bdonrz+9I9xfvn18ed08DPgPOPOCtAn93xnMyiWSIlXHA0HB5CDDFwxqlOHDE+zugDPpcgrLd0mIccFXYKugEYKe7b4h1UEXFzI7JK4M2s94E/xfj4g+dMO6XgEXu/sQhdovbzy+S+4vzz6++mSWGy1WBM4DFB+xW4O/OuBhO5WD8EEOsmNmDQLK7jyP4hfivmS0nqAy9JHYRF0yE93e7mZ0LZBPc39UxC7iAzOxNghYx9cxsHTCMoCIQd38B+IigRdByYB9wTWwiPToR3N8Q4CYzywb2A5fE0R86PwOuBOaF5e4A9wHNoVR8fpHcXzx/fo2AkRZMOlgOGOPuHxb2u1PDqYiISKHFczGXiIiUEEomIiJSaEomIiJSaEomIiJSaEomIiJx7kgDix6w75P5BqhcamZpRRGDkonEDTP7i5n1N7PzzOzeKF+rsZm9Hc1rFAUzu6+A+19tZo2jFY/EzCv8tOPhQbn7b939uLAH/LPAwTqdFpiSicSTPsC3wKnAtGheyN3Xu/uQA7eHvYFLkgIlE4K+SEompczBBhY1szZmNsHMZpnZF2bW8SCHXgq8WRQxKJlIiWdmfzOzuUAv4Bvg18DzB5tDIuzd+46ZzQxfPwu3PxAWBUw1s5Vmdnu4/TEzuyXf8Q+Y2Z1m1jKvyCD8a36cmU0BJptZHTN7Pxxg81sz63aEa7Q0s8Vm9kpYrPC6mQ0ws6/MbFnYgxozSwiPn2Fms81scL7rvxt+MSwzs7/mxQ5UDYsrXj/g36F8eL35ZjbPzH5rZkOAJOD18JiqZtbTzD4Pv3A+sXBUhfAeng73m58vxlPzFZHMNrMaRfU5S5EbDtzm7j2BO4Hn8r9pZi2AVsCUIrlarMfW10uvSF4EieRZgl7kXx1mvzeAk8Ll5gRDYgA8AHwNVAbqEQx9UZFgRNjP8x2/kGBMopaEc5EQ/DW/DqgTrj8LDAuXTwPmHOEaLQlGKehK8AfcLGAEwbwSg4H3w+MfBa4IlxMJ5tFICK+/kmB8pCrAGqBZuN+eQ/w79CQYDTZvPTH8ORVICpcrhvHWD9cvJhhpIW+/f4fLp+T7t/gA+Fm4XB2oEOvfDb3+9xnn/52tTtAzf06+16ID9r8beLaorl/SHtlFDqUHwQRhHTn8GGQDgM72w9QLNS0Y/RVgvAcDY2aY2WagobvPNrMGYT1CfWCHu6fYT2eWm+juecUIJwG/BHD3KWZW18xqHuoa4fZV7j4PwMwWAJPd3c1sHsGXAMBA4FwzuzNcr0
I4hEe4/87w+IVAC348xPuBVgKtzexZYDzw6UH26QAcC0wM/73KA/nHz3ozvMdpZlbTgvGcvgKeCJ+E3nX3dYeJQWKnHJDmQb3IoVwC3HKY9wtEyURKNDM7jqBysSmwFagWbLY5QF9333/AIeUIZo1LP+A8ABn5NuXww+//WwRjLR0DjD5EKHsjDPlQ18i/PTffem6+fQz4pbsvOSD2Poc570G5+w4z6w4MAm4ELgJ+dcBuBixw976HOs1PT+uPmdl4gnG3vjKzQe5+4CCBEmPuvsvMVpnZhe7+lgX/Abq5+/cAYf1JbYJi4yKhOhMp0dx9TvjX1VKgM0H57iAPWqMcmEgg+Av8tryVMBkdyWiCv9KGECSWI/kCuDw8fz9gqxfNfB6fALeF//Exs+MjOCbLguHSf8SC4dDLufs7wB8JnuwAdhNMRQuwBKhvZn3DYyrajyd5ujjcfhLBqL87zayNu89z98cJRrY+WKWuFDMLBhb9BuhgZuvM7FqC39Frzex7YAE/nqn1EmCUh+VdRUFPJlLimVle8VOumXV094WH2f124F9hhX0FglZfNx7u/B6MxlwDSPXIhkl/ABgRXmMfPwzVXVgPAU8Bc82sHLAKOOcIxwwP9//O3S/Pt70J8HJ4HoC8ptSvAC+Y2X6gL0ECfcbMahH8ez1F8MUDkG5mswnqVvKeau4ws/4ET1QLCOZ5lxhz90sP8dZBmwu7+wNFHYNGDRaRnzCzqcCd7p4c61gkPqiYS0RECk1PJiIiUmh6MhERkUJTMhERkUJTMhERkUJTMhERkUJTMhERkUL7fzn8gOYcuJELAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "time to jit: 0:00:19.228796\n",
            "time to train: 0:03:26.774632\n",
            "eval steps/sec: 556234.3976894745\n",
            "train steps/sec: 175898.37275935235\n",
            "GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-973380a2-ad8c-9269-e77c-4623d7678679)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y2p-20bCi4iI"
      },
      "source": [
        "In this arrangement, we can roll out environment steps much faster than we can train: the speed at which PyTorch can backpropagate the loss and step the optimizer is the bottleneck.  This PyTorch code can probably be sped up by adding [automatic mixed precision](https://pytorch.org/docs/stable/notes/amp_examples.html) and by following other recommendations in the [PyTorch performance tuning guide](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html).\n",
        "\n",
        "We know we have a fair bit of headroom to improve the PyTorch implementation, as the built-in Brax trainer (which uses [flax.optim](https://flax.readthedocs.io/en/latest/flax.optim.html)) runs at nearly double the steps per second:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Xmuz3I21p35H",
        "outputId": "686f64a2-6f58-49fe-c70b-5a304abb5503"
      },
      "source": [
        "# Collect the training throughput reported at each ppo.train logging step.\n",
        "train_sps = []\n",
        "\n",
        "def progress(_, progress_metrics):\n",
        "  \"\"\"Records the 'speed/sps' value from one ppo.train progress report.\"\"\"\n",
        "  # Named progress_metrics so it does not shadow the imported brax.io.metrics module.\n",
        "  train_sps.append(progress_metrics['speed/sps'])\n",
        "\n",
        "ppo.train(\n",
        "    environment_fn=envs.create_fn(env_name='ant'), num_timesteps=30_000_000,\n",
        "    log_frequency=10, reward_scaling=.1, episode_length=1000,\n",
        "    normalize_observations=True, action_repeat=1, unroll_length=5,\n",
        "    num_minibatches=32, num_update_epochs=4, discounting=0.97,\n",
        "    learning_rate=3e-4, entropy_cost=1e-2, num_envs=2048,\n",
        "    batch_size=1024, progress_fn=progress)\n",
        "\n",
        "# Skip the first report, which presumably includes JIT compilation time.\n",
        "print(f'train steps/sec: {np.mean(train_sps[1:])}')"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train steps/sec: 320521.9375\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eqXKdDwVL6L4"
      },
      "source": [
        "tunaalabagana! 👋"
      ]
    }
  ]
}
