{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "146acba4",
   "metadata": {
    "id": "146acba4"
   },
   "source": [
    "# Part 1: Introduction\n",
    "The goal of this colab is to introduce the core abstractions used within this library.\n",
    "These include the `Task` and `Optimizer` objects.\n",
    "\n",
    "We will first introduce these abstractions and illustrate basic functionality. We will then show how to define a custom `Optimizer`, and how to optimize optimizers via gradient-based meta-training.\n",
    "\n",
    "This colab serves as a brief, limited introduction to the capabilities of the library. Further notebooks introduce further functionality as well as more complex learned optimizer models."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dab76c7",
   "metadata": {
    "id": "6dab76c7"
   },
   "source": [
    "## Prerequisites\n",
    "\n",
    "This document assumes knowledge of JAX which is covered in depth at the [JAX Docs](https://jax.readthedocs.io/en/latest/index.html).\n",
    "In particular, we would recomend making your way through [JAX tutorial 101](https://jax.readthedocs.io/en/latest/jax-101/index.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5c47834",
   "metadata": {
    "id": "d5c47834"
   },
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/google/learned_optimization.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "374e391d",
   "metadata": {
    "id": "374e391d"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax\n",
    "from matplotlib import pylab as plt\n",
    "\n",
    "from learned_optimization.outer_trainers import full_es\n",
    "from learned_optimization.outer_trainers import truncated_pes\n",
    "from learned_optimization.outer_trainers import gradient_learner\n",
    "from learned_optimization.outer_trainers import truncation_schedule\n",
    "\n",
    "from learned_optimization.tasks import quadratics\n",
    "from learned_optimization.tasks.fixed import image_mlp\n",
    "from learned_optimization.tasks import base as tasks_base\n",
    "from learned_optimization.tasks.datasets import base as datasets_base\n",
    "\n",
    "from learned_optimization.learned_optimizers import base as lopt_base\n",
    "from learned_optimization.learned_optimizers import mlp_lopt\n",
    "from learned_optimization.optimizers import base as opt_base\n",
    "\n",
    "from learned_optimization import optimizers\n",
    "from learned_optimization import eval_training\n",
    "\n",
    "import haiku as hk\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e341b64a",
   "metadata": {
    "id": "e341b64a"
   },
   "source": [
    "## Tasks\n",
    "\n",
    "A `Task` is an object containing a specification of a machine learning or optimization problem. The `Task` requires:\n",
    "  * Parameters: for example, these may include the decision variables of an arbitrary optimization problem, or parameters of a predictive model such as a neural network. These are initialized through the `init` method.\n",
    "  * Optionally a model state: this includes model parameters which are not to be updated via gradients. One example is the running population statistics used within batch norm.\n",
    "  * Optionally, a `.dataset` attribute with iterators of datasets.\n",
    "  * A loss function: this maps from the parameters, and possibly a batch of data to a loss.\n",
    "\n",
    "This object can be thought of as a loss function, and these are the base objects we train learned optimizers to perform well on.\n",
    "\n",
    "Tasks contain the following:\n",
    "  * A `init` function which initializes the parameters of the task.\n",
    "  * A `loss` function, which evaluates the loss given parameters and data.\n",
    "  * Optionally a `.dataset` attribute with iterators of datasets.\n",
    "\n",
    "For tasks which make use of a model state (e.g. tasks with batchnorm), a `init_with_state` and `loss_with_state` will also be provided.\n",
    "\n",
    "We'll begin by looking at some built-in tasks in the library. In future colabs, we will discuss how custom tasks can be designed, and how families of tasks can be efficiently designed for parallelization.\n",
    "\n",
    "We will look at the `ImageMLP_FashionMnist8_Relu32` task. This task consists of a 1 hidden layer MLP trained on Fashion MNIST resized to 8x8.\n",
    "\n",
    "First, let's initialize the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b53916",
   "metadata": {
    "executionInfo": {
     "elapsed": 5659,
     "status": "ok",
     "timestamp": 1643173202623,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "49b53916",
    "outputId": "fb69ee9e-cf0c-4e7f-d601-b9dea16f473f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'mlp/~/linear_0': {'b': (32,), 'w': (64, 32)},\n",
       " 'mlp/~/linear_1': {'b': (10,), 'w': (32, 10)}}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "key = jax.random.PRNGKey(0)\n",
    "task = image_mlp.ImageMLP_FashionMnist8_Relu32()\n",
    "\n",
    "params = task.init(key)\n",
    "jax.tree_util.tree_map(lambda x: x.shape, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7ce6c75",
   "metadata": {
    "id": "a7ce6c75"
   },
   "source": [
    "We can see we initialized parameters which correspond to the weights of the MLP.\n",
    "\n",
    "Next, let's look at the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf030fd0",
   "metadata": {
    "executionInfo": {
     "elapsed": 4,
     "status": "ok",
     "timestamp": 1643173202813,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "bf030fd0",
    "outputId": "c6b2ee31-b1ef-4fa9-8f07-04e041ac9537"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FlatMap({\n",
       "  'image': ((128, 8, 8, 1), dtype('float32')),\n",
       "  'label': ((128,), dtype('int32')),\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch = next(task.datasets.train)\n",
    "jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ad6d9b8",
   "metadata": {
    "id": "7ad6d9b8"
   },
   "source": [
    "We get batches of 128 with images of size 8x8 and labels stored as integers.\n",
    "\n",
    "To compute losses, we can call the `loss` function. Some loss functions can be stochastic. For these, in addition to passing in params, and the batch of data, we also pass in a random number."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5033451",
   "metadata": {
    "executionInfo": {
     "elapsed": 318,
     "status": "ok",
     "timestamp": 1643173203307,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "d5033451",
    "outputId": "f95171bb-5087-4ed1-8ddd-a9a42ff89e88"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(2.3404593, dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "key, key1 = jax.random.split(key)\n",
    "\n",
    "loss = task.loss(params, key1, batch)\n",
    "loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72c0828a",
   "metadata": {
    "id": "72c0828a"
   },
   "source": [
    "Function transformations can also be used to compute gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c78a15a",
   "metadata": {
    "executionInfo": {
     "elapsed": 672,
     "status": "ok",
     "timestamp": 1643173204150,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "0c78a15a",
    "outputId": "7706bf0f-741d-401a-f70b-8c66c8f10859"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'mlp/~/linear_0': {'b': (32,), 'w': (64, 32)},\n",
       " 'mlp/~/linear_1': {'b': (10,), 'w': (32, 10)}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss, grad = jax.value_and_grad(task.loss)(params, key1, batch)\n",
    "jax.tree_util.tree_map(lambda x: x.shape, grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbe89c3b",
   "metadata": {
    "id": "fbe89c3b"
   },
   "source": [
    "Now let's pull this together to train this task with SGD. Note that we will _jit_ the loss gradient computation for improved performance---if this is not familiar, we recommend reading about [Just in Time Compilation with JAX](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49cbc7ba",
   "metadata": {
    "executionInfo": {
     "elapsed": 960,
     "status": "ok",
     "timestamp": 1643173205253,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "49cbc7ba",
    "outputId": "cc472f4d-2695-4bc0-df78-6178eea5b0eb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss at 0: 2.3431787490844727. Test loss: 2.3290696144104004\n",
      "train loss at 100: 1.3063488006591797. Test loss: 1.3161423206329346\n",
      "train loss at 200: 1.0361652374267578. Test loss: 0.9567877054214478\n",
      "train loss at 300: 0.8567416667938232. Test loss: 0.8595446348190308\n",
      "train loss at 400: 0.7901570796966553. Test loss: 0.7775527238845825\n",
      "train loss at 500: 0.7749090194702148. Test loss: 0.8601065874099731\n",
      "train loss at 600: 0.717944860458374. Test loss: 1.0189208984375\n",
      "train loss at 700: 0.664547324180603. Test loss: 0.639930784702301\n",
      "train loss at 800: 0.5981862545013428. Test loss: 0.6370106935501099\n",
      "train loss at 900: 0.6088892817497253. Test loss: 0.8031588792800903\n"
     ]
    }
   ],
   "source": [
    "grad_fn = jax.jit(jax.value_and_grad(task.loss))\n",
    "key = jax.random.PRNGKey(0)\n",
    "params = task.init(key)\n",
    "lr = 0.1\n",
    "\n",
    "for i in range(1000):\n",
    "  key, key1 = jax.random.split(key)\n",
    "  batch = next(task.datasets.train)\n",
    "  l, grads = grad_fn(params, key1, batch)\n",
    "  # apply SGD to each parameter\n",
    "  params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)\n",
    "  if i % 100 == 0:\n",
    "    test_l = task.loss(params, key, next(task.datasets.test))\n",
    "    print(f\"train loss at {i}: {float(l)}. Test loss: {float(test_l)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32604359",
   "metadata": {
    "id": "32604359"
   },
   "source": [
    "Note the evaluations in the above are quite noisy as they are only done on a single batch of data."
   ]
  },
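  {
   "cell_type": "markdown",
   "id": "avgtestlossmd",
   "metadata": {
    "id": "avgtestlossmd"
   },
   "source": [
    "One way to reduce this noise is to average the test loss over several batches. The cell below is a minimal sketch of that idea: `mean_test_loss` and `num_batches` are names introduced here for illustration and are not part of the library. It relies only on `task.loss` and `task.datasets.test`, used the same way as in the training loop above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "avgtestlosscode",
   "metadata": {
    "id": "avgtestlosscode"
   },
   "outputs": [],
   "source": [
    "def mean_test_loss(task, params, key, num_batches=10):\n",
    "  \"\"\"Averages `task.loss` over `num_batches` test batches (hypothetical helper).\"\"\"\n",
    "  losses = []\n",
    "  for _ in range(num_batches):\n",
    "    # use a fresh key for every evaluation\n",
    "    key, key1 = jax.random.split(key)\n",
    "    losses.append(task.loss(params, key1, next(task.datasets.test)))\n",
    "  return jnp.mean(jnp.asarray(losses))\n",
    "\n",
    "\n",
    "print(float(mean_test_loss(task, params, jax.random.PRNGKey(1))))"
   ]
  },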
  {
   "cell_type": "markdown",
   "id": "2115d037",
   "metadata": {
    "id": "2115d037"
   },
   "source": [
    "## Optimizers\n",
    "We have so far implemented a rough SGD optimizer to train our model parameters. In this section, we will develop useful abstractions to create more powerful optimizers.\n",
    "\n",
    "Sadly there is no gold standard interface for optimizers in Jax: there are Flax's optimizers, optax optimizers, optimizers from jaxopt, and optix. This library uses it's own interface to expose additional types of inputs to the optimizer. These additional inputs will become more obvious when we discuss learned optimizers later in this colab, as well as in future colabs.\n",
    "\n",
    "\n",
    "In this library, optimizers are stateless classes that implement:\n",
    "  * an `init` which creates an `OptimizerState` instance which wraps parameters and optionally a model stats as well as contains any additional optimizer state needed (e.g. momentum values)\n",
    "  * a `get_params` and `get_state`  which return the parameters and state of the `OptimizerState`.\n",
    "  * an `update` function which takes in a previous optimizer state, gradients, and optionally a loss values to produce a new `OptimizerState` (with new parameters).\n",
    "\n",
    "\n",
    "Let's look at a couple examples. First, SGD:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9673b31",
   "metadata": {
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1643173205421,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "e9673b31",
    "outputId": "848b36a3-3bb4-421d-b558-170e0730a18b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OptaxState(params={'a': DeviceArray([0., 0.], dtype=float32)}, state=None, optax_opt_state=(EmptyState(), EmptyState()), iteration=0)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fake_params = {\"a\": jnp.zeros((2,))}\n",
    "\n",
    "opt = opt_base.SGD(1e-4)\n",
    "opt_state = opt.init(fake_params)\n",
    "opt_state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cf3fb23",
   "metadata": {
    "id": "9cf3fb23"
   },
   "source": [
    "We can see the `opt_state` has parameter values, and a couple other values such as current iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5db1e09c",
   "metadata": {
    "executionInfo": {
     "elapsed": 75,
     "status": "ok",
     "timestamp": 1643173205656,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "5db1e09c",
    "outputId": "4eab9598-17a9-4ae0-a3ea-04639ab0dcd7"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OptaxState(params={'a': DeviceArray([0., 0.], dtype=float32)}, state=None, optax_opt_state=(ScaleByAdamState(count=DeviceArray(0, dtype=int32), mu={'a': DeviceArray([0., 0.], dtype=float32)}, nu={'a': DeviceArray([0., 0.], dtype=float32)}), EmptyState()), iteration=0)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "opt = opt_base.Adam(1e-4)\n",
    "opt_state = opt.init(fake_params)\n",
    "opt_state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "337afafe",
   "metadata": {
    "id": "337afafe"
   },
   "source": [
    "Adam, on the other hand, has more data inside as it contains first and second moment accumulators.\n",
    "\n",
    "Now let's take one step with an optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "357962b9",
   "metadata": {
    "executionInfo": {
     "elapsed": 68,
     "status": "ok",
     "timestamp": 1643173205854,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "357962b9",
    "outputId": "de9a3d77-11d5-48b9-f293-bed832537cee"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': DeviceArray([-9.9999335e-05, -9.9999335e-05], dtype=float32)}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fake_grads = {\"a\": jnp.ones((2,))}\n",
    "fake_loss = 10.\n",
    "\n",
    "next_opt_state = opt.update(opt_state, fake_grads, fake_loss)\n",
    "opt.get_params(next_opt_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65114ea8",
   "metadata": {
    "id": "65114ea8"
   },
   "source": [
    "We can see the parameters of our model have been updated slightly.\n",
    "\n",
    "Now let's pull this all together and train a Task with this optimizer API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52fa2b0d",
   "metadata": {
    "executionInfo": {
     "elapsed": 543,
     "status": "ok",
     "timestamp": 1643173206548,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "52fa2b0d",
    "outputId": "afb7c02f-0baf-4ec6-9847-f65d8ae90ba2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3285213\n",
      "2.2702265\n",
      "2.2235658\n",
      "2.1345925\n",
      "2.0540721\n",
      "2.039238\n",
      "1.922206\n",
      "1.847055\n",
      "1.8123329\n",
      "1.7331252\n"
     ]
    }
   ],
   "source": [
    "task = image_mlp.ImageMLP_FashionMnist8_Relu32()\n",
    "key = jax.random.PRNGKey(0)\n",
    "params = task.init(key)\n",
    "\n",
    "opt = opt_base.Adam(1e-2)\n",
    "opt_state = opt.init(params)\n",
    "\n",
    "for i in range(10):\n",
    "  batch = next(task.datasets.train)\n",
    "  key, key1 = jax.random.split(key)\n",
    "  params = opt.get_params(opt_state)\n",
    "  loss, grads = jax.value_and_grad(task.loss)(params, key1, batch)\n",
    "  opt_state = opt.update(opt_state, grads, loss)\n",
    "  print(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "631099e9",
   "metadata": {
    "id": "631099e9"
   },
   "source": [
    "The above doesn't make use of any sort of `jax.jit` and thus it is slow. In practice, we often like to create one update function which maps from one `opt_state` to the next and jit this entire function. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2ebdbd5",
   "metadata": {
    "executionInfo": {
     "elapsed": 692,
     "status": "ok",
     "timestamp": 1643173207374,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "a2ebdbd5",
    "outputId": "a141a7c1-baa9-47b4-eb35-e3e45a8c7493"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3393626\n",
      "2.259151\n",
      "2.1479445\n",
      "2.0973787\n",
      "2.089923\n",
      "2.012719\n",
      "1.8739626\n",
      "1.8503203\n",
      "1.7333521\n",
      "1.6747149\n"
     ]
    }
   ],
   "source": [
    "task = image_mlp.ImageMLP_FashionMnist8_Relu32()\n",
    "key = jax.random.PRNGKey(0)\n",
    "params = task.init(key)\n",
    "\n",
    "opt = opt_base.Adam(1e-2)\n",
    "opt_state = opt.init(params)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def update(opt_state, key, batch):\n",
    "  key, key1 = jax.random.split(key)\n",
    "  params, model_state = opt.get_params_state(opt_state)\n",
    "  loss, grads = jax.value_and_grad(task.loss)(params, key1, batch)\n",
    "  opt_state = opt.update(opt_state, grads, loss=loss)\n",
    "\n",
    "  return opt_state, key, loss\n",
    "\n",
    "\n",
    "for i in range(10):\n",
    "  batch = next(task.datasets.train)\n",
    "  opt_state, key, loss = update(opt_state, key, batch)\n",
    "  print(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc2f1728",
   "metadata": {
    "id": "dc2f1728"
   },
   "source": [
    "### Defining a custom `Optimizer`\n",
    "\n",
    "To define a custom optimizer, one simply needs to define a stateless instance of the `Optimizer` class and some pytree object with the optimizer state.\n",
    "\n",
    "As an example let's implement the momentum optimizer. As our state we will use a flax dataclass (though a simple dictionary or named tuple would also suffice)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65a556ce",
   "metadata": {
    "executionInfo": {
     "elapsed": 167,
     "status": "ok",
     "timestamp": 1643173207718,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "65a556ce",
    "outputId": "966a8921-c6d8-448e-9f0a-fcd9c5f75cc6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MomentumOptState(params={'a': DeviceArray(1.1, dtype=float32, weak_type=True), 'b': DeviceArray(1.9, dtype=float32, weak_type=True)}, model_state=None, iteration=DeviceArray(1, dtype=int32), momentums={'a': DeviceArray(1.1, dtype=float32, weak_type=True), 'b': DeviceArray(1.9, dtype=float32, weak_type=True)})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import flax\n",
    "from typing import Any\n",
    "\n",
    "\n",
    "@flax.struct.dataclass\n",
    "class MomentumOptState:\n",
    "  params: Any\n",
    "  model_state: Any\n",
    "  iteration: jnp.ndarray\n",
    "  momentums: Any\n",
    "\n",
    "\n",
    "class MomentumOptimizer(opt_base.Optimizer):\n",
    "\n",
    "  def __init__(self, lr=1e-3, momentum=0.9):\n",
    "    super().__init__()\n",
    "    self._lr = lr\n",
    "    self._momentum = momentum\n",
    "\n",
    "  def get_state(self, opt_state):\n",
    "    return opt_state.model_state\n",
    "\n",
    "  def get_params(self, opt_state):\n",
    "    return opt_state.params\n",
    "\n",
    "  def init(self, params, model_state=None, **kwargs):\n",
    "    return MomentumOptState(\n",
    "        params=params,\n",
    "        model_state=model_state,\n",
    "        momentums=jax.tree_util.tree_map(jnp.zeros_like, params),\n",
    "        iteration=jnp.asarray(0, dtype=jnp.int32))\n",
    "\n",
    "  def update(self, opt_state, grads, loss, model_state=None, **kwargs):\n",
    "    struct = jax.tree_util.tree_structure(grads)\n",
    "    flat_momentum = jax.tree_util.tree_leaves(opt_state.momentums)\n",
    "    flat_grads = jax.tree_util.tree_leaves(grads)\n",
    "    flat_params = jax.tree_util.tree_leaves(opt_state.params)\n",
    "\n",
    "    output_params = []\n",
    "    output_momentums = []\n",
    "    for m, g, p in zip(flat_momentum, flat_grads, flat_params):\n",
    "      next_m = m * self._momentum + g * (1 - self._momentum)\n",
    "      next_p = p - next_m * self._lr\n",
    "      output_params.append(next_p)\n",
    "      output_momentums.append(next_m)\n",
    "    return MomentumOptState(\n",
    "        params=jax.tree_util.tree_unflatten(struct, output_params),\n",
    "        model_state=model_state,\n",
    "        iteration=opt_state.iteration + 1,\n",
    "        momentums=jax.tree_util.tree_unflatten(struct, output_params),\n",
    "    )\n",
    "\n",
    "\n",
    "opt = MomentumOptimizer(lr=1)\n",
    "opt_state = opt.init({\"a\": 1.0, \"b\": 2.0})\n",
    "opt.update(opt_state, {\"a\": -1.0, \"b\": 1.0}, 1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80bbfe9e",
   "metadata": {
    "id": "80bbfe9e"
   },
   "source": [
    "## Learned Optimizers\n",
    "\n",
    "Learned optimizers are simply optimizers parameterized by some additional set of variables, often called `theta` by convention.\n",
    "\n",
    "Like before, instances of `LearnedOptimizer` should contain no immutable state.\n",
    "\n",
    "They implement 2 functions:\n",
    "  * `init` which initializes the weights of the learned optimizer (e.g. randomly as done with neural networks, or with some fixed values).\n",
    "  * `opt_fn` which takes in the parameters of the learned optimizer, and produces an `Optimizer` instance.\n",
    "\n",
    "\n",
    "One of the simplest forms of learned optimizer is a hand-designed optimizer with meta-learnable hyperparameters. Let's look at `LearnableAdam`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7a8bd53",
   "metadata": {
    "executionInfo": {
     "elapsed": 60,
     "status": "ok",
     "timestamp": 1643173207903,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "d7a8bd53",
    "outputId": "62d867f0-2095-415e-cb03-a1bea9951794"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'log_epsilon': DeviceArray(-18.420681, dtype=float32, weak_type=True),\n",
       " 'log_lr': DeviceArray(-6.9077554, dtype=float32, weak_type=True),\n",
       " 'one_minus_beta1': DeviceArray(-2.3025851, dtype=float32, weak_type=True),\n",
       " 'one_minus_beta2': DeviceArray(-6.9077554, dtype=float32, weak_type=True)}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lopt = lopt_base.LearnableAdam()\n",
    "theta = lopt.init(key)\n",
    "theta"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22c5ac36",
   "metadata": {
    "id": "22c5ac36"
   },
   "source": [
    "We see this optimizer has 4 meta-learnable parameters corresponding to log learning rate, 2 values for beta (parameterized as the log of one minus the beta values), and log epsilon.\n",
    "\n",
    "We can access an instance of the optimizer with the opt_fn, and use that optimizer just like the ones in the previous section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f326f969",
   "metadata": {
    "id": "f326f969"
   },
   "outputs": [],
   "source": [
    "opt = lopt.opt_fn(theta)\n",
    "opt_state = opt.init({\"p\": jnp.zeros([\n",
    "    2,\n",
    "])})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c05e64e",
   "metadata": {
    "id": "8c05e64e"
   },
   "source": [
    "With our optimizers split up in this way we can now write functions that are a function of the learned optimizer weights.\n",
    "\n",
    "As an example, let us define a function, `meta_loss` which is the result of applying a learned optimizer to a given problem for some number of steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d83689fc",
   "metadata": {
    "executionInfo": {
     "elapsed": 416,
     "status": "ok",
     "timestamp": 1643173208653,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "d83689fc",
    "outputId": "454b36b2-97b2-4615-fce9-b2c58ff254b5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(2.2636137, dtype=float32)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task = image_mlp.ImageMLP_FashionMnist8_Relu32()\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "lopt = lopt_base.LearnableAdam()\n",
    "\n",
    "\n",
    "def meta_loss(theta, key, batch):\n",
    "  opt = lopt.opt_fn(theta)\n",
    "  key1, key = jax.random.split(key)\n",
    "  param = task.init(key1)\n",
    "  opt_state = opt.init(param)\n",
    "  for i in range(4):\n",
    "    param = opt.get_params(opt_state)\n",
    "    key1, key = jax.random.split(key)\n",
    "    l, grad = jax.value_and_grad(task.loss)(param, key1, batch)\n",
    "    opt_state = opt.update(opt_state, grad, l)\n",
    "\n",
    "  param, state = opt.get_params_state(opt_state)\n",
    "  key1, key = jax.random.split(key)\n",
    "  final_loss = task.loss(param, key1, batch)\n",
    "  return final_loss\n",
    "\n",
    "\n",
    "batch = next(task.datasets.train)\n",
    "meta_loss(theta, key, batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc9d8820",
   "metadata": {
    "id": "bc9d8820"
   },
   "source": [
    "But let's not stop there, we can leverage jax now to easily compute meta-gradients, or gradients with respect to the weights of the learned optimizer. This will take a bit to compile (~20 seconds on my machine) as this computation graph is a bit complex. Note: this can be greatly reduced by leveraging `jax.lax.scan`!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcffb853",
   "metadata": {
    "executionInfo": {
     "elapsed": 14706,
     "status": "ok",
     "timestamp": 1643173223561,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "bcffb853",
    "outputId": "26b32b58-50f6-42ff-abc0-04ff33c103cc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'log_epsilon': DeviceArray(7.518652e-08, dtype=float32, weak_type=True),\n",
       " 'log_lr': DeviceArray(-0.04048448, dtype=float32, weak_type=True),\n",
       " 'one_minus_beta1': DeviceArray(-3.1022726e-05, dtype=float32),\n",
       " 'one_minus_beta2': DeviceArray(-2.4318695e-07, dtype=float32)}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "meta_value_and_grad = jax.jit(jax.value_and_grad(meta_loss))\n",
    "\n",
    "ml, meta_grad = meta_value_and_grad(theta, key, batch)\n",
    "meta_grad"
   ]
  },
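  {
   "cell_type": "markdown",
   "id": "scantoymd",
   "metadata": {
    "id": "scantoymd"
   },
   "source": [
    "As an aside on the `jax.lax.scan` note above, the toy cell below sketches the pattern on a stand-in problem rather than on `meta_loss` itself. It is an illustration under assumptions, not the library's implementation: `quadratic_loss`, `meta_loss_scan`, `toy_theta`, and the plain SGD inner update are all made up here. Because `jax.lax.scan` traces the inner step once instead of unrolling it `num_steps` times, the compiled graph stays small as the number of inner steps grows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "scantoycode",
   "metadata": {
    "id": "scantoycode"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def quadratic_loss(x):\n",
    "  # toy inner problem standing in for `task.loss`\n",
    "  return jnp.sum(x**2)\n",
    "\n",
    "\n",
    "def meta_loss_scan(toy_theta, x0, num_steps=4):\n",
    "  # the only meta-learnable weight here is a log learning rate\n",
    "  lr = jnp.exp(toy_theta[\"log_lr\"])\n",
    "\n",
    "  def step(x, _):\n",
    "    # one inner SGD step; the carry is the inner parameter vector\n",
    "    l, g = jax.value_and_grad(quadratic_loss)(x)\n",
    "    return x - lr * g, l\n",
    "\n",
    "  # scan applies `step` num_steps times without unrolling the loop\n",
    "  x_final, _ = jax.lax.scan(step, x0, None, length=num_steps)\n",
    "  return quadratic_loss(x_final)\n",
    "\n",
    "\n",
    "toy_theta = {\"log_lr\": jnp.log(0.1)}\n",
    "toy_value_and_grad = jax.jit(jax.value_and_grad(meta_loss_scan))\n",
    "toy_ml, toy_meta_grad = toy_value_and_grad(toy_theta, jnp.ones((3,)))"
   ]
  },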
  {
   "cell_type": "markdown",
   "id": "683ab7ab",
   "metadata": {
    "id": "683ab7ab"
   },
   "source": [
    "We can see that this meta-gradient is saying we should increase the log learning rate to improve performance.\n",
    "\n",
    "We can now meta-train by using an additional optimizer -- this time to optimize `theta`, the weights of the learned optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2df4120e",
   "metadata": {
    "executionInfo": {
     "elapsed": 17330,
     "status": "ok",
     "timestamp": 1643173241067,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "2df4120e",
    "outputId": "6c5663d7-647e-4fc7-cae2-d75dd5ba8280"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2463012\n",
      "2.2374353\n",
      "1.75054\n",
      "1.3767055\n",
      "1.0847464\n",
      "1.1990764\n",
      "1.3325281\n",
      "1.2866051\n",
      "1.1671047\n",
      "1.1733787\n",
      "1.5384903\n",
      "1.1215551\n",
      "1.123384\n",
      "1.2072926\n",
      "1.1499238\n",
      "1.1710639\n",
      "1.3856463\n",
      "1.4461384\n",
      "1.1795483\n",
      "1.1503576\n"
     ]
    }
   ],
   "source": [
    "theta_opt = opt_base.Adam(1e-2)\n",
    "\n",
    "key = jax.random.PRNGKey(0)\n",
    "theta = lopt.init(key)\n",
    "theta_opt_state = theta_opt.init(theta)\n",
    "\n",
    "learning_rates = []\n",
    "learnable_adam_meta_losses = []\n",
    "for i in range(2000):\n",
    "  batch = next(task.datasets.train)\n",
    "  key, key1 = jax.random.split(key)\n",
    "  theta = theta_opt.get_params(theta_opt_state)\n",
    "  ml, meta_grad = meta_value_and_grad(theta, key, batch)\n",
    "  theta_opt_state = theta_opt.update(theta_opt_state, meta_grad, ml)\n",
    "  learning_rates.append(theta[\"log_lr\"])\n",
    "  learnable_adam_meta_losses.append(ml)\n",
    "  if i % 100 == 0:\n",
    "    print(ml)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7456b62",
   "metadata": {
    "colab": {
     "height": 296
    },
    "executionInfo": {
     "elapsed": 934,
     "status": "ok",
     "timestamp": 1643173242450,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "a7456b62",
    "outputId": "57cac3cc-cc6e-4dbc-f717-22ebaad14428"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 0, 'meta-iteration')"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEGCAYAAAB7DNKzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAAoiElEQVR4nO3dd3hcV53/8fdX1SqWZFnusi33XqMUp4c4IcnihISWsLAJJYEl\nwAI/+tKXpWUDu3QMpBAgjZACJCQmOHFinDjuXXKTLcmWZEnWqFl1zu+PGQnZlsajMlWf1/Po0ehq\nZu5Xd0b3M/ece88x5xwiIiJ9SYh0ASIiEt0UFCIiEpCCQkREAlJQiIhIQAoKEREJKCnSBYRCXl6e\nKygoiHQZIiIxZfPmzdXOuTFnLo/LoCgoKGDTpk2RLkNEJKaY2ZHelqvpSUREAlJQiIhIQHEVFGa2\nysxWezyeSJciIhI34ioonHN/cs7dlZ2dHelSRETiRlwFhYiIDD0FhYiIBKSgEBGRgBQUIjKsNbZ2\n8NvXjrC9tC7SpUQtBYUEtPXoSb7+p900tXZEupQ+7Sr3cO8LRTz02hFa2jsBqKxvYePhWrzeyM+3\n0t7p5URDK68dqqHn/C9er0PzwfRPVX0LB6oa6OzjdW3r8NLcFvx7taW9k/c/8AZfemoXb/3pep7e\nVj6gulo7Orvfe/Eorq7MNrNVwKqZM2dGupSoVNvURnZaMokJBkBRRQPf/Mse5k3I4os3zDvr/rvK\nPfzbrzfS0NrBnmP1PPC+C0hLSQy4jjV7Knl253FuWjqRK+eMDcnfAb6d73O7Knh+VwXP7TpO137j\ngfWHufX8Kax+5RAnGlpZNCmbK2aP4V3nT2ZybnrI6unN3uP1rF53iGd3Hqe1wwvA56+fy4evmMHa\noio+8cg2cjNSeOxDKxgzMjWktRyoaqCmsY0LpuViZue8f1uHl+REC+q+AB2dXp7edowVM0YzMSdt\nsOWe5VjdKf7vb/t5YksZHV5HekoiK+eN44s3zGN89ggAXtl/go/+fittHV7uecdi3rJ44jlr/uSj\n29h4uJb/vnkhT287xmf+sIP5E7KYNW4kAE2tHTy5tZykBOPiGXlMGX36e+hITRP3PF/E2n1VdHgd\nV80Zy3+snMW8CVmD+nurG1tJMCM3I2VQzzNULB4/0RQWFjoN4fFPXq/jgX+U8F9/2cO88Vn88Lal\n5I9K5y0/epUDVY0A/OK95/HmBeMBOHiikbX7qvi/F/eTNSKZ9186jf/+yx4WTsrmoumjmZKbztuW\n558WGs45vvHnPdy/vgQzcA7uuLiAqaPTOdnUxm0XTmFCdt87EOccXkd3iAXS0t7J3b/bwov7qsjL\nTOGtSyfx8ZWz2HLkJF/4406Oe1oYnzWCGxZNYMOhGoorGwC487LpfPbNc0jwr6O5rYPUpMSz1tnU\n2kFG6tmfoTq9jp+sPcCGgzXcfnEB1y0cf1ZdT20tZ+PhWrYcPUlJTTOZqUmsWjKR8Vkj+HtRFUUV\n9Tx61wpuv38jGSlJVNa3sGBiFitm5LF0cjYXz8wjPTmR454W8kelnXNH3d7pZfORk4zPGkFGahIn\nGlqZO34kCQmGc45fvnKI+9eXcNzTAvhek6+umt/n83Z6HavXHeLeF4pITUpgUX42NyyaQEpiAjnp\nKRQWjCItOZG/7a1k4+FaXtlfTafXcaq9k9qmNibnpvGDdy6luLKR9QeqqaxvISHB+NbNC5k5dmSv\n6yytbSYxwfoMmMPVTdz441dp7fBy2/mTWZSfw6aSWh7bVEpyYgLfedsi0pKT+PjDW5k6Op3kxAT2\nHK/nM2+ew91Xnf2hsbmtgzV7KvnFy4fYc7yeL/3LPD542XSqG1u5+t6XmZybxqN3reCvuyr47l/3\nUdXQCkBKYgLfvHkh7yycDMDmI7V84M
FNNLV2cP3CCWSlJfGn7cepb2nnw1fM4DPX/vO91tHp5VB1\nE7PH9b4NnHO8VHyCDQdrqKxv4bldFYxMTeKpuy855wecozXNNLZ2kJRo5GWmDipczGyzc67wrOUK\nivi04WANn31iOwlmHK9roa3Ty4KJWRz3tNDc1sG8CVlsPVrHfXcUcs/zxRypaeKLN8zj7/uq+Pu+\nKgCWTcnhh7cuY3JuOk9vK+cHa4o55mmhrcPLBdNyeegDF5Ca5AuLxzaV8tk/7OD2FVP5zHVz+daz\ne3l441G63l5TctP5zi2LGDkimdcP19Da4eV9lxRQ19zO2qIqfrBmPwA3L5vIxTPzWJqfQ0568mk7\ntI2Ha3l8UynrD1RzzNPCN25awHsunNr9zwi+nfWGgzUsnZzDKP8/zLG6U/xgTTGPby7j/IJRvOO8\nyTzyxlG2ltaRmZrEjDGZHK5u4lR7J23+T/6FU0fxqWtmc/HMvO7n/uaf9/CrVw8DkJKUwFMfuYT5\nE32fHPccq+ejD2/h0IkmcjNSWD5lFBdNz+Ud500mOz0ZgJLqJlZ+/2U6vI6MlESe+dilbCqp5ftr\niqlpbKPD60hMMJITjZZ2L5Ny0vjyW+azfGoOeRmpp/2dzjm+93wRv3/9KJ5T7ae99nmZqdz7ziW8\ncbiWH689wCUzR3PZrDEcrGrk8c1lXDlnDNfOH+973sxUUpISuP/VEp7YUkZDSzsnm9t509yxTB2d\nzp+2H6e6sbX7uVMSE0hLScRzqp3M1CQunJZLdloyXucYlz2C3792lAZ/M+WknDQm56ax51g9GalJ\n/PLfCkn3P3bkiCRyM1L5xbqD/OqVw3R6HdPyMnj3BVO48/Lp3es7UNXAXb/ZTE1TG09+5GKmj8ns\n/l1JdRMff2QrO8p8F9jOGTeSxz60gvTURD756Db+uquCJz9yCYvy/3ld1e5jHm6/byPVjW1MyB7B\nV1fN57qFE7p//+LeSj74m00Y4HWwZHIOX/6XeeRmpPCVp3fz6oFqFk7KIicthY0ltUzKSeP+O86n\nIC8DgLrmNr717F4e21TGBQW5/Ojdy0hNSuCu32xmY0kt/3nDvNP+vq7X8nNP7OCxTWWkJCWQl5HC\nzHEjef1QDR1ex6yxmXzhhnlcMds3Vt/+ygYOVzfhdfDEljLW7Knsfq7EBOPpuy9h4aSBXUumoBhG\nmts6uOb762hp7+Si6aOZmDOCKaMzeFfhZOqa2/j0H3bw6v4T3HnZdL5wwzwq61u4/b6N7KtoICMl\nkY9cNZNViycyOffsT7TOOZ7YUs6nH9/O3PEjuX7hBEZlJPPd5/axYGI2j9x1UfcOraqhBeegvO4U\nt61+rbv5pUtKYgJtnb5l8yZkMT4rlXX+T6gA2WnJTMpJwwzyR6Xx/O5KRqYmccG0XN594RSunjcu\n6G3inOOxTaXc83wx1Y2tTMgewTsKJ1Na20x53Slmj8skMzWZ5ESjqbWTv+46zjFPC8um5PDxq2dR\nVtvMl5/ezR0XF/CxN83k6u+/zKyxmfzotuWU153iAw++wYikRP775oVcNWfsaTv1nn73+hH+uKWc\n/3ftbC6e8c8Qau/0svVoHa8eqKahpZ0puek8srGUIv/RUF5mKlfNGYOZb3sWVTRQ3djGZbPyuH7h\nBBIToKXdS0KCcf/6wxw60QTAredP5ls3L+o+wvjVK4e5d00RLe3es2q7bFYe47JGsHLeWK6dP56E\nBKO1o5OTTe14naOyvoUnt5bT1NrJ28/L5/yCUSQlnt7NWVXfwpajdYxKT+5u5tp9zNeEWdPU1us2\necd5+UzJTeeFPZXsLPfw9RsXcO2Ccfz2tSM8sL6EEcmJ/Ow953HBtNyzHtvW4WX9gWpONLZy/cLx\njBzhC2VPcztv/t91pKUk8rP3LGfu+Cy2l9bx3l+/TmZqEve+cykXTMvt9Qh2XfEJntxazpVzxrBq\n8c
Tu17KlvZMf/K2YX7x8iDEjU7luwXg+fvWss5oNnXM8vrmMrz2zm+TEBJITE6hrbiMhwWjr8PLe\ni6by2evmdNf60GtH+PJTu/jQ5dP59JvnkOzfpq8fquE3G46w5ehJPKfa+fqNC1h/oJqnth3rXtfI\n1CTed+k0pvuDqqSmiY+9aVZQR+a9UVAMI995bh8/f/kgj31oRa//XOA7FO75T17f0s6zO45z+ewx\nQbUxP7vzOD996QC7yusBmD4mg9998MI+m5cqPC3sr2qgscV3NFNZ38Jjm8qYOTaTS2aOZtGkbMwM\nT3M7eyvq2VZax4GqRo57TrHlSB0dXi/vv3Qan1w5mxHJgftJAmlu8/W3LJiYHbC/paW9k0ffKOW+\n9Yc5UtMMwMp5Y/nFewtJTLDuI6guk3LSePjOi85qwx4Mz6l2nt9dwam2Tl7cV8Xmklo6nWN6Xibz\nJmRx0fRcblmef9ZOobG1g9+9doRRGSm8fXn+WaHV0NJObVMbj75RSqfXkZ6SxOLJ2VwVwj6lk01t\nPL2tvLuJ6binhdLaZm5Zns+c8b7mmE6v40MPbeJve31HtGZwzbxxfP2mBQGbLfvy+qEa7npoM/Ut\n7aycN471B6oZnZnC7z940aD6q3aWeZial06Wf0ffl+LKBn6y9gDVja186prZLJ08im/+xdc8O2ts\nJt+6ZRFPbi3n968f5ao5Y/jV7ef3uoOvqm/hrT9ZzzFPCymJCdx1+XSuXTAO56BgdEb3EetQUFAM\nEzvK6rjlp//gluWT+N7bl4R8fU2tHRyrO8XU0RmkJIXmJLqW9k5a271D+g8RrNaOTn789wO0tHfy\nqWvmnBYuxZUNvFx0guRE4y1LJpKXGdoO6eHA09zOvWuKSEtJ5D0XTh30CQh1zW38YE0xD28sZXF+\nNj9+9/Luzu9I+ceBau78zSaa2nxnSd2+YiqfvW5ur/1iXbo+QM0dP5Kc9NB1cCsohoGq+hZu/PF6\nEhOMP3/s0u42epHhrtPf/xMtSmub2VHmYVpeRnc/VzToKyji6vTY4cg5x/6qRl4uOsGDG0qob2nn\n8Q+vUEiI9BBNIQEwOTc97KdrD4aCIoY55/jQQ5t5wX/Ww7wJWfzotmUsmKjRc0Vk6MRVUAy3C+5e\nKj7BC3squXnZJP7ftbPJHxU7n1BEJHbE1RAew2k+CuccP3pxP5Ny0vje2xcrJEQkZOIqKIaTF/dW\nseVoHR++Ynr3edciIqGgPUwM8nod33p2L7PGZnLrBVMiXY6IxDkFRQxas7eSQ9VNfPzqWTqaEJGQ\n014mBv1y3SHyR6Vx/RmD0omIhIKCIsZsPnKSTUdO8oFLp501zo6ISChoTxNjfrnuENlpyd1DHYuI\nhJqCIoaUVDfx/J4K3nPRlIDjwoiIDCUFRQx54B8lJCckcPuKgkiXIiLDiIIiRjjneHbnca6eN5ax\nWZEd/VJEhpe4CgozW2Vmqz0eT6RLGXK7yuupamhlZT8m6xERGQpxFRTxPITH3/ZWYgZXzhkT6VJE\nZJiJq6CIZy/uq2T5lFGM1uQ4IhJmCooYUF53il3l9Vw7X81OIhJ+CooYsGZ3BQDXLtCV2CISfgqK\nGPDCnkpmjc1kWl5GpEsRkWFIQRHl6prbeP1wLdcuULOTiESGgiLKrS2qotPruGa+mp1EJDIUFFHu\n1f01jEpPZvGk+DvlV0Rig4Iiijnn2HCwmhUzRpOQYJEuR0SGKQVFFDvmaeGYp4ULp42OdCkiMowp\nKKLYjtI6AJZMzoloHSIyvCkootj2Mg/Jica8CSMjXYqIDGNxFRTxNijg9tI65k3IIjUpMdKliMgw\nFldBEU+DAnq9jl3lHhbnx/7fIiKxLa6CIp4cqm6iobWDJfk5kS5FRIY5BUWU2q6ObBGJEgqKKLWj\nrI6MlERmjMmMdCkiMswpKKLUtjIPCydlk6gL7UQkwhQUUaitw8ve
Y/VqdhKRqKCgiEJFFQ20dXrV\nkS0iUUFBEYW2ldUB6NRYEYkKCoootKO0jtyMFPJHpUW6FBERBUU02l5Wx5L8bMzUkS0ikaegiDJN\nrR0cqGpksfonRCRKKCiizK5yD14HS3XGk4hECQVFlNmujmwRiTIKiiizvczDpJw0RmemRroUERFA\nQRF1dpTVqdlJRKKKgiKK1DS2Ulp7Ss1OIhJVFBRRZEe5b8IlDd0hItFEQRFFtpfWYQYLJ+mIQkSi\nh4Iiiuws8zBzTCaZqUmRLkVEpFtcBUWsz5m993g9CyZmRboMEZHTxFVQxPKc2fUt7RzztDB7/MhI\nlyIicpq4CopYVlzRAMBcBYWIRBkFRZTY5w+K2eMUFCISXRQUUaK4soGRqUlMytHQ4iISXRQUUWJf\nRQOzx4/U0OIiEnUUFFHAOUdRRYOanUQkKikookBVQyueU+3qyBaRqKSgiAJdHdlzFBQiEoUUFFGg\n69TYOWp6EpEopKCIAvsqGhg7MpVRGSmRLkVE5CwKiihQVFmvZicRiVoKigjr9Dr2Vzaq2UlEopaC\nIsKO1DTR2uHVEYWIRC0FRYQVdY/xpFFjRSQ6nTMozGy2mb1oZrv8Py82sy+FvrThoaiyATOYOTYz\n0qWIiPQqmCOKXwJfANoBnHM7gFtDWdRwUlTRQMHoDNJSEiNdiohIr4IJinTn3MYzlnWEopjhqKii\nQR3ZIhLVggmKajObATgAM3s7cDykVQ0TLe2dlNQ0abIiEYlqwUzOfDewGphrZuXAYeBfQ1rVMHGg\nqhGv02RFIhLdggkK55xbaWYZQIJzrsHMpoW6sOGgSGM8iUgMCKbp6QkA51yTc67Bv+wPoStp+Ciq\nbCAlKYGpuemRLkVEpE99HlGY2VxgAZBtZrf0+FUWMCLUhQ0H+yoamDU2k6REXc4iItErUNPTHOAt\nQA6wqsfyBuDOENY0bBRXNHDxjNGRLkNEJKA+g8I59zTwtJmtcM5tCGNNw4KnuZ2K+hb1T4hI1Aum\nM3urmd2Nrxmqu8nJOff+kFU1DOyrqAfUkS0i0S+YxvGHgPHAm4GXgXx8zU8yCMWVOuNJRGJDMEEx\n0zn3ZaDJOfcg8C/AotCWFf+KKxsZOSKJ8Vk6L0BEolswQdHu/15nZguBbKAgZBUNE8WVvjOezCzS\npYiIBBRMUKw2s1HAl4BngD3Ad0Na1TBwoKqR2RrjSURiQMDObDNLAOqdcyeBdcD0sFR1eg1vxdfc\nNRb4iXPuhXDXMNRqGlupaWrT0OIiEhMCHlE457zARwf65GZ2n5lVdc1l0WP5dWZWZGYHzOzz56jh\nKefcncAdwLsGWks02V/VCKAjChGJCcGcHrvGzD4NPAo0dS10ztUG8dgHgB8Dv+laYGaJwE+Aa4Ay\n4A0zewZIBL59xuPf75yr8t/+kv9xMW+//4ynWeN0RCEi0S+YoOi6XuLuHsscQTRDOefWmVnBGYsv\nAA445w4BmNkjwE3OuW/juxL8NObr7f0O8Jxzbktf6zKzu4C7AKZMmXKu0iKquLKRkak640lEYsM5\ng8I5N9QjxU4CSnv8XAZcGOD+HwNW4htzaqZz7ue93ck5txrfcOgUFha6Iao1JPZXNTBznM54EpHY\nEMwRxVDrbe/Y547dOfdD4IehKyf89lc2snLeuEiXISISlEgMW1oGTO7xcz5wLAJ1RETXGU/qnxCR\nWBGJoHgDmGVm08wsBbgV3/UZw0LXGU+zdMaTiMSIczY9mdnyXhZ7gCPOuY5zPPZh4Eogz8zKgK86\n535tZh8Fnsd3ptN9zrnd/a689/WtAlbNnDlzKJ4uJLqDQtdQiEiMCKaP4qfAcmAHvv6Fhf7bo83s\nw4EugHPO3dbH8meBZ/tfbmDOuT8BfyosLIza+TL2VzaQmZrEhGyd8SQisSGYpqcSYJlzrtA5dx6w\nDNiF70yk74WwtrhUXNnATI3x
JCIxJJigmNuzacg5twdfcBwKXVnxyTlHUUUDc9Q/ISIxJJimpyIz\n+xnwiP/ndwHFZpbKP0eWlSBUN7Zxsrldc1CISEwJ5ojiDuAA8Angk8Ah/7J24KoQ1RWXNFmRiMSi\nYK7MPgXc6/86U+OQVzQI0X7W074KX1BoMEARiSXnPKIws0vMbI2ZFZvZoa6vcBTXX865Pznn7srO\nzo50Kb0qrmhgdEYKY0amRroUEZGgBdNH8Wt8TU6bgc7QlhPfiiobdDQhIjEnmD4Kj3PuOedclXOu\npusr5JXFGa/XUVzZoP4JEYk5wRxRrDWze4A/Aq1dCwMN+S1nK687RXNbp4JCRGJOMEHRNQR4YY9l\nDnjT0JcTv4rUkS0iMSqYs55i5hTYaD7rqaiyKyg0xpOIxJY+g8LM3uOc+62Zfaq33zvnvh+6sgYm\nmsd6KqpoYFJOGiNHJEe6FBGRfgl0RJHh/662kiGgjmwRiVV9BoVz7hf+718PXznxqb3Ty8ETjVwx\nZ0ykSxER6bdg5qMYA9wJFPS8v3Pu/aErK77sr2ykvdMxf0JWpEsREem3YM56ehp4BfgbuuBuQHYd\n8wCwcFJ0XjEuIhJIMEGR7pz7XMgriWO7yz1kpCQybXTGue8sIhJlgrky+89mdkPIK4lju4/VM39i\nFgkJmqxIRGJPMEHxH/jC4pSZ1ZtZg5nVh7qwgTCzVWa22uPxRLqUbp1ex57j9SyYqGYnEYlNAYPC\nzBKA65xzCc65NOdclnNupHMuKntlo3H02MPVTTS3dbJgYlRuMhGRcwoYFM45L/A/YaolLu1WR7aI\nxLhgmp5eMLO3mZka2Adg97F6UpISmDlWQ3eISGwK5qynT+G7SrvDzFoAA1y0Nj9Fm93HPMwdP5Lk\nxGAyWUQk+gQzKKDGnRgg5xy7yuu5YdGESJciIjJgwRxRYGajgFnAiK5lzrl1oSoqXpSdPIXnVLs6\nskUkpgUzhMcH8Z0imw9sAy4CNqD5KM5pe1kdAEvycyJah4jIYAR7HcX5wBH/3BTLgBMhrSpObC+t\nIyUpgbkT1HonIrErmKBocc61AJhZqnNuHzAntGUNTLRdcLettI6FE7PUkS0iMS2YPViZmeUATwFr\nzOxp4FgoixqoaLrgrr3Ty85yD0snj4p0KSIigxLMWU83+29+zczWAtnAX0NaVRwormygpd3LksmR\nDy0RkcEI9qynS4FZzrn7/fNTTAIOh7SyGLettA6AZTqiEJEYd86mJzP7KvA54Av+RcnAb0NZVDzY\nXlpHbkYKk3PTIl2KiMigBNNHcTNwI9AE4Jw7hubRPqdtpXUsyc9GI5+ISKwLJijanHMOcABmptl3\nzqGhpZ39VY0smZwT6VJERAYtmKB4zMx+AeSY2Z34pkT9ZWjLim07yz04B0sVFCISB4I56+l/zOwa\noB7f9RNfcc6tCXllMayrI1tXZItIPAjqrCd/MCgcgrS9tI6C0emMykiJdCkiIoPWZ1CYWQP+fokz\nf0WUDjNuZquAVTNnzoxYDc45tpXWcdH00RGrQURkKPXZR9E15WkvX5oKNYCyk6eorG/lvKm6fkJE\n4oMGIRpim4+cBFBQiEjcUFAMsTdKaslMTWLu+Kg86BIR6TcFxRDbfOQky6bkkJigC+1EJD4oKIaQ\n51Q7RZUNanYSkbiioBhCW4+exDk4vyA30qWIiAwZBcUQ2nzkJIkJpiuyRSSuKCiG0KaSk8ybMJKM\n1KCuYxQRiQkKiiHS3ullW2kdhVPV7CQi8UVBMUT2Hq/nVHunOrJFJO4oKIbIGyW+C+0KCxQUIhJf\nFBRDZPORWiblpDEhWzPaiUh8UVAMAeccm0pOqtlJROJSXAWFma0ys9Uejyes6y07eYqqhlY1O4lI\nXIqroIjU6LGbjtQCGghQROJTXAVFpGwqOamBAEUkbikohoAGAhSReKagGKSugQB1oZ2IxCsFxS
Bt\n8Q8EqI5sEYlXCopB2lRSq4EARSSuKSgGacPBGhZNytZAgCIStxQUg9DU2sGOMg8rZoyOdCkiIiGj\noBiEN0pq6fA6VkxXUIhI/FJQDMKGQzUkJ5o6skUkrikoBuG1gzUsyc8hPUX9EyISvxQUA1Tf0s7O\ncvVPiEj8U1AM0BuHa/E61D8hInFPQTFAGw7WkJKYwHINBCgicU5BMUAbDtWwbEoOI5ITI12KiEhI\nKSgGoK65jT3H69U/ISLDgoJiAF4/XItT/4SIDBMKigHYcLCG1KQElk7JiXQpIiIhp6AYgNcO1VBY\nMIrUJPVPiEj8i6ugCMec2TWNreyraFCzk4gMG3EVFOGYM/uV/dUAXDprTMjWISISTeIqKMLhpaIq\ncjNSWDwpdGEkIhJNFBT94PU61u2v5vJZeSRofmwRGSYUFP2ws9xDbVMbV84ZG+lSRETCRkHRDy8V\nncAMLp+t/gkRGT4UFP3wUnEVi/NzyM1IiXQpIiJho6AI0smmNraV1nGljiZEZJhRUARp3f4TOAdX\nzFFQiMjwoqAI0svFJxiVnsyS/JxIlyIiElYKiiB4vY51xSe4bNYYEnVarIgMMwqKIOw+Vk91YxtX\nqtlJRIYhBUUQXiqqAnRarIgMTwqKILxUfILF+dnkZaZGuhQRkbBTUJxDXXMbW4+e1GmxIjJsKSjO\n4eXiE3gdXKFhO0RkmFJQnMMLeyrJy0xl2eScSJciIhIRCooAWjs6ebnoBNfMH6vRYkVk2FJQBLDh\nYA2NrR1cM39cpEsREYkYBUUAa/ZUkp6SyMUz8iJdiohIxCgo+uD1Ov62t5IrZo9hRHJipMsREYkY\nBUUfdpR7qKxvVbOTiAx7Coo+rNlTQWKC8aa5Oi1WRIY3BUUfXthdyQUFueSka5IiERneFBS9OFzd\nxP6qRq5doGYnEREFRS/W7KkAUP+EiAgKil6t2VPJ/AlZ5I9Kj3QpIiIRp6A4Q3VjK5uPnNTRhIiI\nn4LiDH/fW4XXof4JERE/BcUZXthTyaScNOZPyIp0KSIiUSHqg8LM5pnZz83sD2b276FcV3NbB6/s\nP8E188dhpkEARUQgxEFhZveZWZWZ7Tpj+XVmVmRmB8zs84Gewzm31zn3YeCdQGEo631lfzWtHV6u\nVf+EiEi3UB9RPABc13OBmSUCPwGuB+YDt5nZfDNbZGZ/PuNrrP8xNwKvAi+GstgXdleSnZbM+dNy\nQ7kaEZGYkhTKJ3fOrTOzgjMWXwAccM4dAjCzR4CbnHPfBt7Sx/M8AzxjZn8Bft/bfczsLuAugClT\npgyo3uljMnhP9hSSE6O+RU5EJGxCGhR9mASU9vi5DLiwrzub2ZXALUAq8Gxf93POrQZWAxQWFrqB\nFHb3VTMH8jARkbgWiaDorZe4zx27c+4l4KVQFSMiIoFFoo2lDJjc4+d84FgE6hARkSBEIijeAGaZ\n2TQzSwFuBZ6JQB0iIhKEUJ8e+zCwAZhjZmVm9gHnXAfwUeB5YC/wmHNu9xCtb5WZrfZ4PEPxdCIi\nAphzA+r3jWqFhYVu06ZNkS5DRCSmmNlm59xZ16vpPFAREQlIQSEiIgEpKEREJKC47KMwsxPAkQE+\nPA+oHsJyhorq6h/V1T+qq3+itS4YXG1TnXNjzlwYl0ExGGa2qbfOnEhTXf2juvpHdfVPtNYFoalN\nTU8iIhKQgkJERAJSUJxtdaQL6IPq6h/V1T+qq3+itS4IQW3qoxARkYB0RCEiIgEpKEREJCAFhV9/\n5vEOwbonm9laM9trZrvN7D/8y79mZuVmts3/dUOPx3zBX2uRmb05hLWVmNlO//o3+ZflmtkaM9vv\n/z4qnHWZ2Zwe22SbmdWb2Scitb16mxt+INvIzM7zb+sDZvZDM+tt7pbB1nWPme0zsx1m9qSZ5fiX\nF5jZqR7b7udhrqvfr12Y6nq0R00lZrbNvzyc26uv/UP43m
POuWH/BSQCB4HpQAqwHZgfxvVPAJb7\nb48EivHNJ/414NO93H++v8ZUYJq/9sQQ1VYC5J2x7HvA5/23Pw98N9x1nfHaVQBTI7W9gMuB5cCu\nwWwjYCOwAt/kXs8B14egrmuBJP/t7/aoq6Dn/c54nnDU1e/XLhx1nfH7e4GvRGB79bV/CNt7TEcU\nPt3zeDvn2oBHgJvCtXLn3HHn3Bb/7QZ8w69PCvCQm4BHnHOtzrnDwAF8f0O43AQ86L/9IPDWCNZ1\nNXDQORfoSvyQ1uWcWwfU9rLOoLeRmU0AspxzG5zvP/o3PR4zZHU5515wvqH+AV7DN3FYn8JVVwAR\n3V5d/J+83wk8HOg5QlRXX/uHsL3HFBQ+vc3jHWhHHTJmVgAsA173L/qov5ngvh6HluGs1wEvmNlm\nM7vLv2ycc+44+N7EwNgI1NXlVk7/54309urS3200yX87nDW+H9+nyi7TzGyrmb1sZpf5l4Wzrv68\nduHeXpcBlc65/T2WhX17nbF/CNt7TEHh0695vENWhFkm8ATwCedcPfAzYAawFDiO79AXwlvvJc65\n5cD1wN1mdnmA+4Z1O5pvhsQbgcf9i6Jhe51LX7WEe9v9J9AB/M6/6DgwxTm3DPgU8HszywpjXf19\n7cL9mt7G6R9Iwr69etk/9HnXPmoYcG0KCp+Iz+NtZsn43gS/c879EcA5V+mc63TOeYFf8s/mkrDV\n65w75v9eBTzpr6HSfxjbdahdFe66/K4HtjjnKv01Rnx79dDfbVTG6c1AIavRzG4H3gL8q78JAn8z\nRY3/9mZ87dqzw1XXAF67cG6vJOAW4NEe9YZ1e/W2fyCM7zEFhU9E5/H2t3/+GtjrnPt+j+UTetzt\nZqDrbIxngFvNLNXMpgGz8HVSDXVdGWY2sus2vo7QXf713+6/2+3A0+Gsq4fTPuVFenudoV/byN90\n0GBmF/nfD//W4zFDxsyuAz4H3Oica+6xfIyZJfpvT/fXdSiMdfXrtQtXXX4rgX3Oue5mm3Bur772\nD4TzPTaY3vh4+gJuwHc2wUHgP8O87kvxHQLuALb5v24AHgJ2+pc/A0zo8Zj/9NdaxCDPqghQ13R8\nZ09sB3Z3bRdgNPAisN//PTecdfnXkw7UANk9lkVke+ELq+NAO75PbR8YyDYCCvHtIA8CP8Y/csIQ\n13UAX/t11/vs5/77vs3/Gm8HtgCrwlxXv1+7cNTlX/4A8OEz7hvO7dXX/iFs7zEN4SEiIgGp6UlE\nRAJSUIiISEAKChERCUhBISIiASkoREQkIAWFSABmttR6jGTaj8d9w8xW+m9/wszSh7Cmt5rZ/N7W\nJRIKOj1WJAAzuwModM59dBDPUeJ/jup+PCbROdfZx+8eAP7snPvDQGsS6Q8dUUjcM9/cAfvM7Fdm\ntsvMfmdmK81svX8s/wv8V6HfZ2Zv+Ad6u8l/lf43gHeZb86Bd/nv+w//ff5hZnP6WOcDZvZ2M/s4\nMBFYa2Zr/b+71sw2mNkWM3vcP4ZP19wfXzGzV4F3mNmd/nq2m9kTZpZuZhfjG9/qHn9NM7rW5X+O\nq/217fT/Pak9nvvr/nXuNLO5Id/wEjcUFDJczAT+D1gMzAXeje+K108DX8R3JevfnXPnA1cB9wDJ\nwFeAR51zS51zjwL7gMudbzC4rwDfCrRS59wP8Y2nc5Vz7iozywO+BKx0vsEWN+EbVK5Li3PuUufc\nI8AfnXPnO+eW4Bta+gPOuX/gu3L5M/6aDnY90MxG4LuK+F3OuUVAEvDvPZ672r/On/n/bpGgJEW6\nAJEwOeyc2wlgZruBF51zzsx24puEJh+40cy6dqAjgCm9PE828KCZzcI3rEJyP+u4CN/EMut9w+2Q\nAmzo8ftHe9xeaGbfBHKATOD5czz3HHx/Z7H/5weBu4H/9f/cNZjcZnyD3IkERUEhw0Vrj9veHj97\n8f0fdAJvc84V9XyQmV
14xvP8F7DWOXez+eYGeMl/v/vxzRNwzDkXqPPbgDXOudv6+H1Tj9sPAG91\nzm3395VcGeB5u547kK6/uRP970s/qOlJxOd54GP+UTUxs2X+5Q34pp/skg2U+2/f0bXQOfc+f1NQ\nbyHR8zleAy4xs5n+9aSb2ew+ahoJHDffENP/2sfz9bQPKOh6buC9wMt9PLdI0BQUIj7/ha8ZaYeZ\n7fL/DLAWmN/VmY1vnuJvm9l6fPN1B2M18JyZrXXOncAXMA+b2Q58wdFXx/KX8c1ktgZfCHR5BPiM\nv9N6RtdC51wL8D7gcX+Tmhf4eZA1ivRJp8eKiEhAOqIQEZGAFBQiIhKQgkJERAJSUIiISEAKChER\nCUhBISIiASkoREQkoP8Pm3N3cv06x/AAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 600x400 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pylab as plt\n",
    "\n",
    "plt.semilogy(np.exp(learning_rates))\n",
    "plt.ylabel(\"learning rate\")\n",
    "plt.xlabel(\"meta-iteration\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2342786b",
   "metadata": {
    "id": "2342786b"
   },
   "source": [
    "And there you have it: we have used gradient-based meta-training to train the hyperparameters of our Adam optimizer! This is the core idea behind learned optimizers.\n",
    "\n",
    "Fitting a handful of scalars is a relatively simple application of the tools we have developed. This library also contains a number of more complex learned optimizers. We will explore these models, as well as more advanced library functionality, in the next colab notebook."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
    "kind": "private"
   },
   "name": "Part1_Introduction.ipynb",
   "provenance": []
  },
  "jupytext": {
   "formats": "ipynb,md:myst,py",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
