{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3Qxk2eeYfqUH"
   },
   "source": [
    "# Part 4: GradientEstimators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 110,
     "status": "ok",
     "timestamp": 1647560713279,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "c-TW0zBs3ggj"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax\n",
    "import functools\n",
    "from matplotlib import pylab as plt\n",
    "from typing import Optional, Tuple, Mapping\n",
    "\n",
    "from learned_optimization.outer_trainers import full_es\n",
    "from learned_optimization.outer_trainers import truncated_pes\n",
    "from learned_optimization.outer_trainers import truncated_grad\n",
    "from learned_optimization.outer_trainers import gradient_learner\n",
    "from learned_optimization.outer_trainers import truncation_schedule\n",
    "from learned_optimization.outer_trainers import common\n",
    "from learned_optimization.outer_trainers import lopt_truncated_step\n",
    "from learned_optimization.outer_trainers import truncated_step as truncated_step_mod\n",
    "from learned_optimization.outer_trainers.gradient_learner import WorkerWeights, GradientEstimatorState, GradientEstimatorOut\n",
    "from learned_optimization.outer_trainers import common\n",
    "\n",
    "from learned_optimization.tasks import quadratics\n",
    "from learned_optimization.tasks.fixed import image_mlp\n",
    "from learned_optimization.tasks import base as tasks_base\n",
    "\n",
    "from learned_optimization.learned_optimizers import base as lopt_base\n",
    "from learned_optimization.learned_optimizers import mlp_lopt\n",
    "from learned_optimization.optimizers import base as opt_base\n",
    "\n",
    "from learned_optimization import optimizers\n",
    "from learned_optimization import training\n",
    "from learned_optimization import eval_training\n",
    "\n",
    "import haiku as hk\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0iM51dVseYZk"
   },
   "source": [
    "Gradient estimators provide an interface for estimating gradients of some loss with respect to the parameters of a meta-learned system.\n",
    "`GradientEstimator`s are not specific to learned optimizers and can be applied to any unrolled system defined by a `TruncatedStep` (see the previous colab).\n",
    "\n",
    "`learned_optimization` supports a handful of estimators, each with different strengths and weaknesses. Which estimator is right for which situation is an open research question. After a short introduction to the `GradientEstimator` class, we give a quick tour of the estimators implemented here.\n",
    "\n",
    "The `GradientEstimator` base class signature is below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1647560713907,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "9UI2k2uAhVUP"
   },
   "outputs": [],
   "source": [
    "PRNGKey = jnp.ndarray\n",
    "\n",
    "\n",
    "class GradientEstimator:\n",
    "  truncated_step: truncated_step_mod.TruncatedStep\n",
    "\n",
    "  def init_worker_state(self, worker_weights: WorkerWeights,\n",
    "                        key: PRNGKey) -> GradientEstimatorState:\n",
    "    raise NotImplementedError()\n",
    "\n",
    "  def compute_gradient_estimate(\n",
    "      self, worker_weights: WorkerWeights, key: PRNGKey,\n",
    "      state: GradientEstimatorState, with_summary: Optional[bool]\n",
    "  ) -> Tuple[GradientEstimatorOut, Mapping[str, jnp.ndarray]]:\n",
    "    raise NotImplementedError()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e3LP8MWZhqO6"
   },
   "source": [
    "A gradient estimator has three parts: a `TruncatedStep` instance, which defines the unrolled system whose gradients are being estimated; an `init_worker_state` function, which initializes the state of the gradient estimator; and a `compute_gradient_estimate` function, which takes that state and returns a `GradientEstimatorOut` containing the computed gradients with respect to the learned optimizer, meta-loss values, and other information about the unroll. It also returns a mapping of metrics.\n",
    "\n",
    "Both methods take a `WorkerWeights` instance. It holds everything needed to compute gradients: it always contains the weights of the meta-learned algorithm (e.g. an optimizer), called `theta`, and it can also carry non-learnable state. If the learned optimizer uses batch norm, for example, the running averages can be stored here as well.\n",
    "\n",
    "In the following examples, we show gradient estimation on learned optimizers using the `VectorizedLOptTruncatedStep`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "executionInfo": {
     "elapsed": 53,
     "status": "ok",
     "timestamp": 1647560728420,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "1foCms9R2a10"
   },
   "outputs": [],
   "source": [
    "task_family = quadratics.FixedDimQuadraticFamily(10)\n",
    "lopt = lopt_base.LearnableAdam()\n",
    "# With FullES, there are no truncations, so we set trunc_sched to never ending.\n",
    "trunc_sched = truncation_schedule.NeverEndingTruncationSchedule()\n",
    "truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(\n",
    "    task_family,\n",
    "    lopt,\n",
    "    trunc_sched,\n",
    "    num_tasks=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IsfRHPaK-80z"
   },
   "source": [
    "## FullES\n",
    "\n",
    "The FullES estimator is one of the simplest and most reliable estimators, but it can be slow in practice because it does not make use of truncations. Instead, it uses antithetic sampling to estimate a gradient via ES over an entire inner optimization (hence the \"full\" in the name).\n",
    "\n",
    "First we define a meta-objective, $f(\\theta)$, which could be the loss at the end of training or the average loss over training. Next, we compute a gradient estimate via ES:\n",
    "\n",
    "$\\nabla_\\theta f \\approx \\dfrac{\\epsilon}{2\\sigma^2} (f(\\theta + \\epsilon) - f(\\theta - \\epsilon))$\n",
    "\n",
    "where $\\epsilon$ is a random Gaussian perturbation of $\\theta$ with standard deviation $\\sigma$.\n",
    "\n",
    "We can instantiate one of these as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "executionInfo": {
     "elapsed": 54,
     "status": "ok",
     "timestamp": 1647560729615,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "W5vQVk7o_VDq"
   },
   "outputs": [],
   "source": [
    "# FullES takes its own truncation schedule, which sets the length of each full unroll.\n",
    "es_trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)\n",
    "gradient_estimator = full_es.FullES(\n",
    "    truncated_step, truncation_schedule=es_trunc_sched)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 251,
     "status": "ok",
     "timestamp": 1647560730818,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "nLOAKLXX_nX4"
   },
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(0)\n",
    "theta = truncated_step.outer_init(key)\n",
    "worker_weights = gradient_learner.WorkerWeights(\n",
    "    theta=theta,\n",
    "    theta_model_state=None,\n",
    "    outer_state=gradient_learner.OuterState(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6Mmm0894_poZ"
   },
   "source": [
    "Because we are working with full-length unrolls, this gradient estimator has no state: there is nothing to keep track of from one truncation to the next."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "executionInfo": {
     "elapsed": 57,
     "status": "ok",
     "timestamp": 1647560731861,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "zyaPyPLY_nX5",
    "outputId": "e8995470-6525-49c1-edb0-8962e427d009"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "UnrollState()"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gradient_estimator_state = gradient_estimator.init_worker_state(\n",
    "    worker_weights, key=key)\n",
    "gradient_estimator_state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VwBwRmmw_zin"
   },
   "source": [
    "Gradients can be computed with the `compute_gradient_estimate` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "executionInfo": {
     "elapsed": 8023,
     "status": "ok",
     "timestamp": 1647560740470,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "rbSr9tFc_vth"
   },
   "outputs": [],
   "source": [
    "out, metrics = gradient_estimator.compute_gradient_estimate(\n",
    "    worker_weights, key=key, state=gradient_estimator_state, with_summary=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 55,
     "status": "ok",
     "timestamp": 1647560740635,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "XuoeYAt9_1hL",
    "outputId": "aa75236d-0ed0-4ffd-cf25-77c1790c9ba3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'log_epsilon': DeviceArray(-0.0173279, dtype=float32),\n",
       " 'log_lr': DeviceArray(-0.00474211, dtype=float32),\n",
       " 'one_minus_beta1': DeviceArray(-0.02331395, dtype=float32),\n",
       " 'one_minus_beta2': DeviceArray(0.00497994, dtype=float32)}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.grad"
   ]
  },
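  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a short sketch of why the antithetic ES formula in this section estimates the gradient, take $\\epsilon \\sim \\mathcal{N}(0, \\sigma^2 I)$ (the usual ES choice, assumed here) and assume $f$ is smooth enough for a first-order Taylor expansion around $\\theta$:\n",
    "\n",
    "$f(\\theta + \\epsilon) - f(\\theta - \\epsilon) \\approx 2\\,\\epsilon^\\top \\nabla_\\theta f$\n",
    "\n",
    "Substituting this into the estimator and taking the expectation over $\\epsilon$ gives\n",
    "\n",
    "$\\mathbb{E}\\left[\\dfrac{\\epsilon}{2\\sigma^2} (f(\\theta + \\epsilon) - f(\\theta - \\epsilon))\\right] \\approx \\dfrac{1}{\\sigma^2}\\,\\mathbb{E}\\left[\\epsilon \\epsilon^\\top\\right] \\nabla_\\theta f = \\nabla_\\theta f$\n",
    "\n",
    "since $\\mathbb{E}[\\epsilon \\epsilon^\\top] = \\sigma^2 I$. A single sample, like the `out.grad` shown above, is therefore a noisy estimate; averaging over more perturbations or tasks reduces the variance."
   ]
  },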
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tjiUWowcwJ1f"
   },
   "source": [
    "## TruncatedPES\n",
    "\n",
    "Truncated Persistent Evolution Strategies (PES) is an unbiased truncation method based on ES. It was proposed in [Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies](https://arxiv.org/abs/2112.13835) and has been a promising tool for training learned optimizers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "executionInfo": {
     "elapsed": 53,
     "status": "ok",
     "timestamp": 1647560742648,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "ailS8_Jbr8CT"
   },
   "outputs": [],
   "source": [
    "trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)\n",
    "truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(\n",
    "    task_family,\n",
    "    lopt,\n",
    "    trunc_sched,\n",
    "    num_tasks=3,\n",
    "    random_initial_iteration_offset=10)\n",
    "\n",
    "gradient_estimator = truncated_pes.TruncatedPES(\n",
    "    truncated_step=truncated_step, trunc_length=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "executionInfo": {
     "elapsed": 53,
     "status": "ok",
     "timestamp": 1647560743357,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "Nx1BTPIG4gEJ"
   },
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(1)\n",
    "theta = truncated_step.outer_init(key)\n",
    "worker_weights = gradient_learner.WorkerWeights(\n",
    "    theta=theta,\n",
    "    theta_model_state=None,\n",
    "    outer_state=gradient_learner.OuterState(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "executionInfo": {
     "elapsed": 1429,
     "status": "ok",
     "timestamp": 1647560745100,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "NlCnF8LT4HBx"
   },
   "outputs": [],
   "source": [
    "gradient_estimator_state = gradient_estimator.init_worker_state(\n",
    "    worker_weights, key=key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EvCBA9Z541sn"
   },
   "source": [
    "Now let's look at what this state contains."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "executionInfo": {
     "elapsed": 55,
     "status": "ok",
     "timestamp": 1647560745260,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "u1QQxUYf31fy",
    "outputId": "5afa0dd5-a4af-4c40-feaa-f4b89438c8d5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PESWorkerState(pos_state=TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,)), neg_state=TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,)), accumulator={'log_epsilon': (3,), 'log_lr': (3,), 'one_minus_beta1': (3,), 'one_minus_beta2': (3,)})"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_util.tree_map(lambda x: x.shape, gradient_estimator_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6meGBWzt45KV"
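   },
   "source": [
    "First, this contains two instances of `TruncatedUnrollState`: `pos_state` for the positive perturbation and `neg_state` for the negative perturbation. Each one holds all the state required to keep track of the training run: the inner optimizer state (`inner_opt_state`), details of the truncation (`truncation_state`), the task parameters sampled from the task family (`task_param`), the `inner_step`, and a bool (`is_done`) indicating whether the unroll has finished. In addition, the state holds an `accumulator` with one entry per parameter in `theta`. Every leaf has a leading dimension of 3, one per task (`num_tasks=3`).\n",
    "\n",
    "We can compute one gradient estimate as follows."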
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "executionInfo": {
     "elapsed": 5000,
     "status": "ok",
     "timestamp": 1647560751440,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "MSpQTFc45lz2"
   },
   "outputs": [],
   "source": [
    "out, metrics = gradient_estimator.compute_gradient_estimate(\n",
    "    worker_weights, key=key, state=gradient_estimator_state, with_summary=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vFDZSW5h6Iri"
   },
   "source": [
    "This `out` object contains various outputs from the gradient estimator, including the gradients with respect to the learned optimizer (`out.grad`) and the next state of the inner training runs (`out.unroll_state`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "executionInfo": {
     "elapsed": 55,
     "status": "ok",
     "timestamp": 1647560751632,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "74AnlkqB4xCV",
    "outputId": "3390c735-d894-4591-c000-2a0e765850e1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'log_epsilon': DeviceArray(0.00452795, dtype=float32),\n",
       " 'log_lr': DeviceArray(-0.0123316, dtype=float32),\n",
       " 'one_minus_beta1': DeviceArray(0.00704127, dtype=float32),\n",
       " 'one_minus_beta2': DeviceArray(0.00493074, dtype=float32)}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "executionInfo": {
     "elapsed": 55,
     "status": "ok",
     "timestamp": 1647560751802,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "82oSxk2i5-3L",
    "outputId": "099bd011-2590-4da0-8e6a-11d72c09d347"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PESWorkerState(pos_state=TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,)), neg_state=TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,)), accumulator={'log_epsilon': (3,), 'log_lr': (3,), 'one_minus_beta1': (3,), 'one_minus_beta2': (3,)})"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_util.tree_map(lambda x: x.shape, out.unroll_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MLqCPmkx6cja"
   },
   "source": [
    "One could simply use these gradients to meta-train and then pass `out.unroll_state` as the `state` argument of the next `compute_gradient_estimate` call. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1647560751941,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "VTPgbwtj6X0I",
    "outputId": "9986c5b6-3a35-46a6-d6fb-24e40261d074"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Progress on inner problem before [1 8 3]\n",
      "Progress on inner problem after [0 7 2]\n"
     ]
    }
   ],
   "source": [
    "print(\"Progress on inner problem before\", out.unroll_state.pos_state.inner_step)\n",
    "out, metrics = gradient_estimator.compute_gradient_estimate(\n",
    "    worker_weights, key=key, state=out.unroll_state, with_summary=False)\n",
    "print(\"Progress on inner problem after\", out.unroll_state.pos_state.inner_step)"
   ]
  },
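  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough sketch of how the `accumulator` enters the PES estimate, following the PES paper linked in this section rather than the exact code in `truncated_pes` (the symbols $\\xi_t$, $\\epsilon_t$, $L_t^{+}$, $L_t^{-}$ and $\\sigma$ are notation assumed for this note). At truncation $t$ a fresh perturbation $\\epsilon_t$ is sampled and added to a running sum:\n",
    "\n",
    "$\\xi_t = \\xi_{t-1} + \\epsilon_t$\n",
    "\n",
    "The antithetic PES estimate then weights the loss difference of the two unrolls by the accumulated perturbation instead of only the current one:\n",
    "\n",
    "$\\hat{g}_t = \\dfrac{\\xi_t}{2\\sigma^2} (L_t^{+} - L_t^{-})$\n",
    "\n",
    "where $L_t^{+}$ is the loss of truncation $t$ unrolled from `pos_state` with $\\theta + \\epsilon_t$, and $L_t^{-}$ is the loss unrolled from `neg_state` with $\\theta - \\epsilon_t$. Because the two states carry the effect of all earlier perturbations, multiplying by $\\xi_t$ rather than $\\epsilon_t$ is what removes the truncation bias. In the paper, $\\xi$ is reset to zero whenever an inner problem restarts; details such as how losses are averaged within a truncation may differ in the implementation."
   ]
  },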
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xODaAMI531O3"
   },
   "source": [
    "## TruncatedGrad\n",
    "\n",
    "TruncatedGrad performs truncated backpropagation through time. This works well for short unrolls, but can run into memory issues and/or exploding gradients for longer unrolls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "executionInfo": {
     "elapsed": 53,
     "status": "ok",
     "timestamp": 1647560756579,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "p-8rk74x4Dn9"
   },
   "outputs": [],
   "source": [
    "truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(\n",
    "    task_family,\n",
    "    lopt,\n",
    "    trunc_sched,\n",
    "    num_tasks=3,\n",
    "    random_initial_iteration_offset=10)\n",
    "\n",
    "gradient_estimator = truncated_grad.TruncatedGrad(\n",
    "    truncated_step=truncated_step, unroll_length=5, steps_per_jit=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1647560757368,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "vfrlWs2T4Dn9"
   },
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(1)\n",
    "theta = truncated_step.outer_init(key)\n",
    "worker_weights = gradient_learner.WorkerWeights(\n",
    "    theta=theta,\n",
    "    theta_model_state=None,\n",
    "    outer_state=gradient_learner.OuterState(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1647560757501,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "Jij6dNdr4Dn9"
   },
   "outputs": [],
   "source": [
    "gradient_estimator_state = gradient_estimator.init_worker_state(\n",
    "    worker_weights, key=key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1647560757822,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "AnawJAj84Dn-",
    "outputId": "b8821827-9bec-4671-d9ba-f8b063e24e52"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,))"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_util.tree_map(lambda x: x.shape, gradient_estimator_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "executionInfo": {
     "elapsed": 4768,
     "status": "ok",
     "timestamp": 1647560762830,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "b9NXPpmc4Dn-"
   },
   "outputs": [],
   "source": [
    "out, metrics = gradient_estimator.compute_gradient_estimate(\n",
    "    worker_weights, key=key, state=gradient_estimator_state, with_summary=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "executionInfo": {
     "elapsed": 58,
     "status": "ok",
     "timestamp": 1647560763002,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 240
    },
    "id": "LogPYNnP4Dn-",
    "outputId": "637695b0-b522-489c-e158-ff6b63846226"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'log_epsilon': DeviceArray(1.5270639e-10, dtype=float32),\n",
       " 'log_lr': DeviceArray(-0.03582412, dtype=float32),\n",
       " 'one_minus_beta1': DeviceArray(1.0147129e-06, dtype=float32),\n",
       " 'one_minus_beta2': DeviceArray(-3.5097173e-08, dtype=float32)}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.grad"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "last_runtime": {
    "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
    "kind": "private"
   },
   "name": "Part 4: GradientEstimators",
   "provenance": [
    {
     "file_id": "1cCx9eq3rOgIldNlVITkhK4TzR_caakvw",
     "timestamp": 1647560833887
    }
   ],
   "toc_visible": true
  },
  "jupytext": {
   "formats": "ipynb,md:myst,py",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
