{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c9e26963",
   "metadata": {
    "id": "uSUkKaMchXQ9"
   },
   "source": [
    "# Part 6: Custom learned optimizer architectures\n",
    "In [Part 1](https://learned-optimization.readthedocs.io/en/latest/notebooks/Part1_Introduction.html) we introduced the `LearnedOptimizer` abstraction. In this notebook we will discuss how to construct one. We will show 3 examples: Meta-learning hyper parameters, a per-parameter optimizer, and a hyper parameter controller."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96da5eed",
   "metadata": {
    "executionInfo": {
     "elapsed": 58,
     "status": "ok",
     "timestamp": 1644472716995,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "LxTj6OcNLswq"
   },
   "outputs": [],
   "source": [
    "import flax\n",
    "from typing import Any\n",
    "import jax.numpy as jnp\n",
    "import jax\n",
    "\n",
    "from learned_optimization.learned_optimizers import base as lopt_base\n",
    "from learned_optimization.optimizers import base as opt_base"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f62ec01",
   "metadata": {
    "id": "0fgVMtdVLuC0"
   },
   "source": [
    "## Meta-Learnable hyper parameters\n",
    "Let's first start by defining a learned optimizer with meta-learned hyper parameters. For this, we will choose SGD as the base optimizer, and meta-learn a learning rate and weight decay.\n",
    "\n",
    "\n",
    "First, we define the state of the learned optimizer. This state is used to keep track of the learned optimizer weights. It contains the inner parameters (`params`), the inner `model_state` which is None unless there are non-gradient updated parameters in the inner problem (such as batchnorm statistics), and `iteration` which contains the inner-training step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09f58c90",
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1644472718443,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "tmkJTQNSLvjj"
   },
   "outputs": [],
   "source": [
    "@flax.struct.dataclass\n",
    "class LOptState:\n",
    "  params: Any\n",
    "  model_state: Any\n",
    "  iteration: jnp.ndarray"
   ]
  },
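  {
   "cell_type": "markdown",
   "id": "5b2e9d14",
   "metadata": {},
   "source": [
    "`flax.struct.dataclass` registers the class as a JAX pytree, so a state like the one above can be passed through `jax.jit` and the `jax.tree_util` functions, and updated immutably with `.replace`. The following is a small sketch of that behaviour, using a hypothetical `SketchState` with the same fields (the name and the example values are assumptions for illustration only).\n",
    "\n",
    "```python\n",
    "from typing import Any\n",
    "\n",
    "import flax\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "@flax.struct.dataclass\n",
    "class SketchState:  # hypothetical stand-in for LOptState\n",
    "  params: Any\n",
    "  model_state: Any\n",
    "  iteration: jnp.ndarray\n",
    "\n",
    "\n",
    "state = SketchState(\n",
    "    params={\"w\": jnp.ones([2, 3])},\n",
    "    model_state=None,\n",
    "    iteration=jnp.asarray(0, dtype=jnp.int32))\n",
    "\n",
    "# The dataclass is a pytree: tree_map visits every array leaf.\n",
    "shapes = jax.tree_util.tree_map(lambda x: x.shape, state)\n",
    "\n",
    "# Instances are immutable; `.replace` returns an updated copy.\n",
    "next_state = state.replace(iteration=state.iteration + 1)\n",
    "```\n",
    "\n",
    "This is why the optimizers below always build and return a new state object in `update` rather than mutating the old one."
   ]
  },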
  {
   "cell_type": "markdown",
   "id": "de053eae",
   "metadata": {
    "id": "nSO2PgeqMF3X"
   },
   "source": [
    "Next for the main optimizer.\n",
    "See the comments inline the code description."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1b89aef",
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1644472737722,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "feQ6ZWlmNUWI"
   },
   "outputs": [],
   "source": [
    "MetaParams = Any  # typing definition to label some types below\n",
    "\n",
    "class MetaSGDWD(lopt_base.LearnedOptimizer):\n",
    "  def __init__(self, initial_lr=1e-3, initial_wd=1e-2):\n",
    "    self._initial_lr = initial_lr\n",
    "    self._initial_wd = initial_wd\n",
    "\n",
    "  def init(self, key) -> MetaParams:\n",
    "    \"\"\"Initialize the weights of the learned optimizer.\n",
    "\n",
    "    In this case the initial learning rate, and initial weight decay.\n",
    "    \"\"\"\n",
    "    # These are the initial values with which we would start meta-optimizing from\n",
    "    return {\n",
    "        \"log_lr\": jnp.log(self._initial_lr),\n",
    "        \"log_wd\": jnp.log(self._initial_wd)\n",
    "    }\n",
    "\n",
    "  def opt_fn(self, theta: MetaParams) -> opt_base.Optimizer:\n",
    "    # define an anonymous class which implements the optimizer.\n",
    "    # this captures over the meta-parameters, theta.\n",
    "\n",
    "    class _Opt(opt_base.Optimizer):\n",
    "      def init(self, params, model_state=None, **kwargs) -> LOptState:\n",
    "        # For our inital inner-opt state we pack the params, model state,\n",
    "        # and iteration into the LOptState dataclass.\n",
    "        return LOptState(\n",
    "            params=params,\n",
    "            model_state=model_state,\n",
    "            iteration=jnp.asarray(0, dtype=jnp.int32))\n",
    "\n",
    "      def update(self,\n",
    "                 opt_state: LOptState,\n",
    "                 grads,\n",
    "                 model_state=None,\n",
    "                 **kwargs) -> LOptState:\n",
    "        \"\"\"Perform the actual update.\"\"\"\n",
    "        # We grab the meta-parameters and transform them back to their original\n",
    "        # space\n",
    "        lr = jnp.exp(theta[\"log_lr\"])\n",
    "        wd = jnp.exp(theta[\"log_wd\"])\n",
    "\n",
    "        # Next we define the weight update.\n",
    "        def _update_one(p, g):\n",
    "          return p - g * lr - p * wd\n",
    "\n",
    "        next_params = jax.tree_util.tree_map(_update_one, opt_state.params, grads)\n",
    "        # Pack the new parameters back up\n",
    "        return LOptState(\n",
    "            params=next_params,\n",
    "            model_state=model_state,\n",
    "            iteration=opt_state.iteration + 1)\n",
    "    return _Opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "504ba5a0",
   "metadata": {
    "id": "KIJ2gyMBNpi4"
   },
   "source": [
    "To test this out, we can feed in a fake set of params and gradients and look at the new parameter values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa03f84e",
   "metadata": {
    "executionInfo": {
     "elapsed": 129,
     "status": "ok",
     "timestamp": 1644473371088,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "BB9kc27fNpAB",
    "outputId": "dcb51eea-4a4d-46e9-a9a2-8563799b797b"
   },
   "outputs": [],
   "source": [
    "lopt = MetaSGDWD()\n",
    "key = jax.random.PRNGKey(0)\n",
    "theta = lopt.init(key)\n",
    "opt = lopt.opt_fn(theta)\n",
    "fake_params = {\"a\": 1.0, \"b\": 2.0}\n",
    "opt_state = opt.init(fake_params)\n",
    "fake_grads = {\"a\": -1.0, \"b\": 1.0}\n",
    "new_opt_state = opt.update(opt_state, fake_grads)\n",
    "\n",
    "opt.get_params(new_opt_state)"
   ]
  },
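  {
   "cell_type": "markdown",
   "id": "4e1a7c02",
   "metadata": {},
   "source": [
    "As a sanity check, the update above can be reproduced by hand. The following is a minimal, self-contained sketch: the names `sgd_wd_step`, `sketch_theta`, and `meta_loss`, and the quadratic inner problem, are illustrative assumptions and not part of the library. It applies the same rule as `_update_one`, and then differentiates a one-step loss with respect to the meta-parameters, which is what makes them meta-learnable.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def sgd_wd_step(theta, params, grads):\n",
    "  # Same rule as `_update_one`: SGD step plus weight decay.\n",
    "  lr = jnp.exp(theta[\"log_lr\"])\n",
    "  wd = jnp.exp(theta[\"log_wd\"])\n",
    "  return jax.tree_util.tree_map(lambda p, g: p - g * lr - p * wd, params, grads)\n",
    "\n",
    "\n",
    "# Same initial values as the MetaSGDWD defaults.\n",
    "sketch_theta = {\"log_lr\": jnp.log(1e-3), \"log_wd\": jnp.log(1e-2)}\n",
    "stepped = sgd_wd_step(sketch_theta, {\"a\": 1.0, \"b\": 2.0}, {\"a\": -1.0, \"b\": 1.0})\n",
    "# a: 1.0 + 0.001 - 0.010 = 0.991\n",
    "# b: 2.0 - 0.001 - 0.020 = 1.979\n",
    "\n",
    "\n",
    "def meta_loss(theta):\n",
    "  # Hypothetical inner problem: minimize 0.5 * p**2 from p = 1.0 for one step.\n",
    "  inner_loss = lambda p: 0.5 * p**2\n",
    "  p0 = jnp.asarray(1.0)\n",
    "  p1 = sgd_wd_step(theta, p0, jax.grad(inner_loss)(p0))\n",
    "  return inner_loss(p1)\n",
    "\n",
    "\n",
    "# Gradients of the post-step loss with respect to log_lr and log_wd.\n",
    "meta_grads = jax.grad(meta_loss)(sketch_theta)\n",
    "```\n",
    "\n",
    "Storing the learning rate and weight decay in log space keeps both positive after `jnp.exp`, whatever value the meta-optimizer moves `theta` to."
   ]
  },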
  {
   "cell_type": "markdown",
   "id": "b8604a7b",
   "metadata": {
    "id": "xzGRP13ZN4rJ"
   },
   "source": [
    "## Per Parameter learned optimizer\n",
    "Per parameter learned optimizers involves computing some learned function on\n",
    "each parameter of the inner-model. Because these calculations are done on\n",
    "every parameter, the computational cost of applying the optimizer grows linearly\n",
    "with the number of parameters in the inner problem.\n",
    "\n",
    "To demonstrate this kind of optimizer, we implement a small MLP which operates on gradients,\n",
    "momentum values, and parameters and produces a scalar update.\n",
    "This MLP is applied to each parameter independently. As such, it takes in three\n",
    "scalar inputs (the gradient, momentum, and parameter value), and produces two\n",
    "outputs which are combined to form a single scalar.\n",
    "The same MLP is then applied to every weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "696adbd0",
   "metadata": {
    "executionInfo": {
     "elapsed": 60,
     "status": "ok",
     "timestamp": 1644473371311,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "WZY4cds2PUA1"
   },
   "outputs": [],
   "source": [
    "@flax.struct.dataclass\n",
    "class PerParamState:\n",
    "  params: Any\n",
    "  model_state: Any\n",
    "  iteration: jnp.ndarray\n",
    "  momentums: Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c610206",
   "metadata": {
    "executionInfo": {
     "elapsed": 55,
     "status": "ok",
     "timestamp": 1644473578782,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "RUCGENb7N6D9"
   },
   "outputs": [],
   "source": [
    "import haiku as hk\n",
    "\n",
    "class PerParamMLP(lopt_base.LearnedOptimizer):\n",
    "  def __init__(self, decay=0.9, hidden_size=64):\n",
    "    self.decay = decay\n",
    "    self.hidden_size = hidden_size\n",
    "\n",
    "    def forward(grads, momentum, params):\n",
    "      features = jnp.asarray([params, momentum, grads])\n",
    "      # transpose to have features dim last. The MLP will operate on this,\n",
    "      # and treat the leading dimensions as a batch dimension.\n",
    "      features = jnp.transpose(features,\n",
    "                               list(range(1, 1 + len(grads.shape))) + [0])\n",
    "\n",
    "      outs = hk.nets.MLP([self.hidden_size, 2])(features)\n",
    "\n",
    "      scale = outs[..., 0]\n",
    "      mag = outs[..., 1]\n",
    "      # Compute a step as follows.\n",
    "      return scale * 0.01 * jnp.exp(mag * 0.01)\n",
    "\n",
    "    self.net = hk.without_apply_rng(hk.transform(forward))\n",
    "\n",
    "\n",
    "\n",
    "  def init(self, key) -> MetaParams:\n",
    "    \"\"\"Initialize the weights of the learned optimizer.\"\"\"\n",
    "    # to initialize our neural network, we must pass in a batch that looks like\n",
    "    # data we might train on.\n",
    "    # Because we are operating per parameter, the shape of this batch doesn't\n",
    "    # matter.\n",
    "    fake_grads = fake_params = fake_mom = jnp.zeros([10, 10])\n",
    "    return {\"nn\": self.net.init(key, fake_grads, fake_mom, fake_params)}\n",
    "\n",
    "  def opt_fn(self, theta: MetaParams) -> opt_base.Optimizer:\n",
    "    # define an anonymous class which implements the optimizer.\n",
    "    # this captures over the meta-parameters, theta.\n",
    "\n",
    "    parent = self\n",
    "\n",
    "    class _Opt(opt_base.Optimizer):\n",
    "      def init(self, params, model_state=None, **kwargs) -> LOptState:\n",
    "        # In addition to params, model state, and iteration, we also need the\n",
    "        # initial momentum values.\n",
    "\n",
    "        momentums = jax.tree_util.tree_map(jnp.zeros_like, params)\n",
    "\n",
    "        return PerParamState(\n",
    "            params=params,\n",
    "            model_state=model_state,\n",
    "            iteration=jnp.asarray(0, dtype=jnp.int32),\n",
    "            momentums=momentums)\n",
    "\n",
    "      def update(self,\n",
    "                 opt_state: LOptState,\n",
    "                 grads,\n",
    "                 model_state=None,\n",
    "                 **kwargs) -> LOptState:\n",
    "        \"\"\"Perform the actual update.\"\"\"\n",
    "\n",
    "        # update all the momentums\n",
    "        def _update_one_momentum(m, g):\n",
    "          return m * parent.decay + (g * (1 - parent.decay))\n",
    "\n",
    "        next_moms = jax.tree_util.tree_map(_update_one_momentum, opt_state.momentums,\n",
    "                                 grads)\n",
    "\n",
    "        # Update all the params\n",
    "        def _update_one(g, m, p):\n",
    "          step = parent.net.apply(theta[\"nn\"], g, m, p)\n",
    "          return p - step\n",
    "\n",
    "        next_params = jax.tree_util.tree_map(_update_one, opt_state.params, grads,\n",
    "                                   next_moms)\n",
    "\n",
    "        # Pack the new parameters back up\n",
    "        return PerParamState(\n",
    "            params=next_params,\n",
    "            model_state=model_state,\n",
    "            iteration=opt_state.iteration + 1,\n",
    "            momentums=next_moms)\n",
    "    return _Opt()"
   ]
  },
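  {
   "cell_type": "markdown",
   "id": "7c3f0a58",
   "metadata": {},
   "source": [
    "The feature construction inside `forward` is the least obvious step, so here is a small standalone sketch of it. The tensors `p`, `m`, and `g` are made-up example values. Stacking places the feature axis first, and the transpose moves it last so that the MLP treats every leading dimension as a batch dimension.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Three same-shaped tensors, as the optimizer sees for one weight matrix.\n",
    "p = jnp.ones([2, 3])\n",
    "m = jnp.zeros([2, 3])\n",
    "g = -jnp.ones([2, 3])\n",
    "\n",
    "# Stacking puts the feature axis first: shape (3, 2, 3).\n",
    "features = jnp.asarray([p, m, g])\n",
    "\n",
    "# Move the feature axis last: shape (2, 3, 3).\n",
    "perm = list(range(1, 1 + g.ndim)) + [0]\n",
    "features_last = jnp.transpose(features, perm)\n",
    "\n",
    "# An equivalent, more direct construction.\n",
    "features_stacked = jnp.stack([p, m, g], axis=-1)\n",
    "```\n",
    "\n",
    "Each entry `features_last[i, j]` is then the 3-vector of (parameter, momentum, gradient) for one scalar weight, which is the input the per-parameter MLP expects."
   ]
  },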
  {
   "cell_type": "markdown",
   "id": "e10f7be6",
   "metadata": {
    "id": "EjiNPZSnQ4Ab"
   },
   "source": [
    "Now let's look at what these meta-parameters look like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8936a5e7",
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1644473615597,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "0FFWo-xjQFJu",
    "outputId": "7f736eb0-d3a5-4d25-c6e0-4dcde27f406c"
   },
   "outputs": [],
   "source": [
    "lopt = PerParamMLP()\n",
    "key = jax.random.PRNGKey(0)\n",
    "theta = lopt.init(key)\n",
    "jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), theta)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8a3cc3d",
   "metadata": {
    "id": "NSywvQWJRCIU"
   },
   "source": [
    "We have a 2 layer MLP. The first layer has 3 input channels (for grads, momentum, parameters), into 64 (hidden size), into 2 for output.\n",
    "\n",
    "We can again apply our optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "054ece76",
   "metadata": {
    "executionInfo": {
     "elapsed": 1,
     "status": "ok",
     "timestamp": 1644473678355,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "9oNudcVgQ72x"
   },
   "outputs": [],
   "source": [
    "opt = lopt.opt_fn(theta)\n",
    "fake_params = {\"a\": jnp.ones([2, 3]), \"b\": jnp.ones([1])}\n",
    "opt_state = opt.init(fake_params)\n",
    "fake_grads = {\"a\": -jnp.ones([2, 3]), \"b\": -jnp.ones([1])}\n",
    "new_opt_state = opt.update(opt_state, fake_grads)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "559869f8",
   "metadata": {
    "id": "wKCBWC6gRQPM"
   },
   "source": [
    "We can see both params, and momentum was updated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bd412f9",
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1644473698434,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "NBVFeyJRRP3X",
    "outputId": "3712a0fd-d3f2-49a7-da6f-47f5bdc33f35"
   },
   "outputs": [],
   "source": [
    "print(opt.get_params(new_opt_state))\n",
    "print(new_opt_state.momentums)"
   ]
  },
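  {
   "cell_type": "markdown",
   "id": "9d40b6e3",
   "metadata": {},
   "source": [
    "The momentum values are easy to verify by hand, because the accumulator is a plain exponential moving average. The sketch below (with the hypothetical helper `ema_update`) repeats the rule from `_update_one_momentum` for a constant gradient of -1 and compares it with the closed form `g * (1 - decay**k)` after `k` steps.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def ema_update(m, g, decay=0.9):\n",
    "  # Same rule as `_update_one_momentum`.\n",
    "  return m * decay + g * (1 - decay)\n",
    "\n",
    "\n",
    "m = jnp.zeros([])\n",
    "g = -jnp.ones([])\n",
    "history = []\n",
    "for _ in range(3):\n",
    "  m = ema_update(m, g)\n",
    "  history.append(float(m))\n",
    "# history is approximately [-0.1, -0.19, -0.271]\n",
    "\n",
    "# Closed form for a constant gradient of -1, starting from zero.\n",
    "closed_form = [-(1 - 0.9**k) for k in (1, 2, 3)]\n",
    "```\n",
    "\n",
    "The first entry, about -0.1, is the value every momentum takes after the single update with gradients of -1 above."
   ]
  },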
  {
   "cell_type": "markdown",
   "id": "c932a6bd",
   "metadata": {
    "id": "WB7LuabVRim_"
   },
   "source": [
    "## Meta-learned RNN Controllers\n",
    "\n",
    "Another kind of learned optimizer architecture consists of a recurrent \"controller\" which modifies and sets the hyper parameters of some base model.\n",
    "These optimizers often have low overhead as computing hparams to use is often much cheaper than computing the underlying gradients. These optimizers also don't require complex computations to be done at each parameter like the per parameter optimizers above.\n",
    "\n",
    "To demonstrate this family, we will implement an adaptive learning rate optimizer.\n",
    "\n",
    "The RNN we will use needs to operate on some set of features and outputs. For simplicity our learned optimizer will just use the loss as a feature, and produces a learning rate.\n",
    "Because it is a recurrent model, we must also take in the previous and next RNN state. This loss is NOT provided into all optimizers and thus some care should be taken -- anything using this optimizer must know about the loss.\n",
    "\n",
    "\n",
    "For this RNN, we use haiku for no particularly strong reason (Flax, or any other neural network library which allows for creating purely functional NN would work.)\n",
    "\n",
    "This optimizer will additionally have a meta-learnable initial RNN State. We desire this state to be meta-learned and thus it must be constructed by `LearnedOptimizer.init`. This state needs to be updated while applying the optimizer, so when we construct the inner-optimizer state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "630aa988",
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1644474159670,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "Y5qEGiPGTCAP"
   },
   "outputs": [],
   "source": [
    "@flax.struct.dataclass\n",
    "class HParamControllerInnerOptState:\n",
    "  params: Any\n",
    "  model_state: Any\n",
    "  iteration: Any\n",
    "  rnn_hidden_state: Any"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d607cb22",
   "metadata": {
    "id": "A81lpZa7TXdc"
   },
   "source": [
    "First we will define some helper functions which perform the compute of the learned optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1bdd127",
   "metadata": {
    "id": "VWjvA9AETb94"
   },
   "outputs": [],
   "source": [
    "import haiku as hk\n",
    "\n",
    "def rnn_mod():\n",
    "  return hk.LSTM(128)\n",
    "\n",
    "@hk.transform\n",
    "def initial_state_fn():\n",
    "  rnn_hidden_state = rnn_mod().initial_state(batch_size=1)\n",
    "  return rnn_hidden_state\n",
    "\n",
    "@hk.transform\n",
    "def forward_fn(hidden_state, input):\n",
    "  mod = rnn_mod()\n",
    "  output, next_state = mod(input, hidden_state)\n",
    "  log_lr = hk.Linear(1)(output)\n",
    "  return next_state, jnp.exp(log_lr) * 0.01"
   ]
  },
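  {
   "cell_type": "markdown",
   "id": "2fa8c671",
   "metadata": {},
   "source": [
    "To make the data flow concrete, here is a dependency-light sketch of the same pattern. It is a simplification, not the notebook's model: it swaps the Haiku LSTM for a hand-written vanilla RNN cell, and the names `init_controller`, `controller_step`, and `HIDDEN` are hypothetical. It shows how the hidden state is threaded from one step to the next while the loss goes in and a positive learning rate comes out.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "HIDDEN = 8  # hypothetical size; the LSTM above uses 128\n",
    "\n",
    "\n",
    "def init_controller(key):\n",
    "  k1, k2, k3 = jax.random.split(key, 3)\n",
    "  return {\n",
    "      \"w_in\": jax.random.normal(k1, (1, HIDDEN)) * 0.1,\n",
    "      \"w_h\": jax.random.normal(k2, (HIDDEN, HIDDEN)) * 0.1,\n",
    "      \"w_out\": jax.random.normal(k3, (HIDDEN, 1)) * 0.1,\n",
    "      \"h0\": jnp.zeros((1, HIDDEN)),  # stands in for the initial RNN state\n",
    "  }\n",
    "\n",
    "\n",
    "def controller_step(ctrl, hidden, loss):\n",
    "  # Add a batch dimension to the scalar loss.\n",
    "  x = jnp.reshape(loss, (1, 1))\n",
    "  next_hidden = jnp.tanh(x @ ctrl[\"w_in\"] + hidden @ ctrl[\"w_h\"])\n",
    "  log_lr = next_hidden @ ctrl[\"w_out\"]\n",
    "  # exp keeps the learning rate positive; 0.01 sets its scale.\n",
    "  lr = jnp.exp(log_lr) * 0.01\n",
    "  return next_hidden, lr[0, 0]\n",
    "\n",
    "\n",
    "ctrl = init_controller(jax.random.PRNGKey(0))\n",
    "losses = jnp.asarray([1.0, 0.8, 0.5])\n",
    "\n",
    "# Thread the hidden state through a sequence of losses.\n",
    "final_hidden, lrs = jax.lax.scan(\n",
    "    lambda h, l: controller_step(ctrl, h, l), ctrl[\"h0\"], losses)\n",
    "```\n",
    "\n",
    "As in `forward_fn`, the network predicts a log learning rate, so the output stays positive and equals 0.01 whenever the predicted log value is zero."
   ]
  },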
  {
   "cell_type": "markdown",
   "id": "3c592f74",
   "metadata": {
    "id": "sKVF8_UhTgH6"
   },
   "source": [
    "Now for the full optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9b860ac",
   "metadata": {
    "executionInfo": {
     "elapsed": 2,
     "status": "ok",
     "timestamp": 1644474352955,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "CNNjLi7Dm7Wz"
   },
   "outputs": [],
   "source": [
    "class HParamControllerLOPT(lopt_base.LearnedOptimizer):\n",
    "  def init(self, key):\n",
    "    \"\"\"Initialize weights of learned optimizer.\"\"\"\n",
    "    # Only one input -- just the loss.\n",
    "    n_input_features = 1\n",
    "    # This takes no input parameters -- hence the {}.\n",
    "    initial_state = initial_state_fn.apply({}, key)\n",
    "\n",
    "    fake_input_data = jnp.zeros([1, n_input_features])\n",
    "    rnn_params = forward_fn.init(key, initial_state, fake_input_data)\n",
    "    return {\"rnn_params\": rnn_params, \"initial_rnn_hidden_state\": initial_state}\n",
    "\n",
    "  def opt_fn(self, theta):\n",
    "    class _Opt(opt_base.Optimizer):\n",
    "      def init(self, params, model_state=None, **kwargs):\n",
    "        # Copy the initial, meta-learned rnn state into the inner-parameters\n",
    "        # so that it can be updated by the RNN.\n",
    "        return HParamControllerInnerOptState(\n",
    "            params=params,\n",
    "            model_state=model_state,\n",
    "            iteration=jnp.asarray(0, dtype=jnp.int32),\n",
    "            rnn_hidden_state=theta[\"initial_rnn_hidden_state\"])\n",
    "\n",
    "      def update(self, opt_state, grads, loss=None, model_state=None, **kwargs):\n",
    "        # As this loss is not part of the default Optimizer definition, we assert\n",
    "        # that it is non None\n",
    "        assert loss is not None\n",
    "\n",
    "        # Add a batch dimension to the loss\n",
    "        batched_loss = jnp.reshape(loss, [1, 1])\n",
    "\n",
    "        # run the RNN\n",
    "        rnn_forward = hk.without_apply_rng(forward_fn).apply\n",
    "        next_rnn_state, lr = rnn_forward(theta[\"rnn_params\"],\n",
    "                                         opt_state.rnn_hidden_state,\n",
    "                                         batched_loss)\n",
    "\n",
    "        # use the results of the RNN to update the parameters.\n",
    "        def update_one(p, g):\n",
    "          return p - g * lr\n",
    "\n",
    "        next_params = jax.tree_util.tree_map(update_one, opt_state.params, grads)\n",
    "\n",
    "        return HParamControllerInnerOptState(\n",
    "            params=next_params,\n",
    "            model_state=model_state,\n",
    "            iteration=opt_state.iteration + 1,\n",
    "            rnn_hidden_state=next_rnn_state)\n",
    "\n",
    "    return _Opt()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6633371",
   "metadata": {
    "id": "mVjY1-cHT01N"
   },
   "source": [
    "We can apply this optimizer on some fake parameters. If we look at the state, we will see the parameter values, as well as the rnn hidden state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d86e19",
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1644474625301,
     "user": {
      "displayName": "Luke Metz",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14Gif9m36RuSe53tMVslYQLofCkRX0_Y47HVoDh3u=s64",
      "userId": "07706439306199750899"
     },
     "user_tz": 480
    },
    "id": "Bei372EpT1MY",
    "outputId": "6807d2ba-07b1-4ad9-d53c-cf5a4cbc49e1"
   },
   "outputs": [],
   "source": [
    "lopt = HParamControllerLOPT()\n",
    "theta = lopt.init(key)\n",
    "opt = lopt.opt_fn(theta)\n",
    "\n",
    "params = {\"a\": jnp.ones([3, 2]), \"b\": jnp.ones([2, 1])}\n",
    "opt_state = opt.init(params)\n",
    "fake_grads = {\"a\": -jnp.ones([3, 2]), \"b\": -jnp.ones([2, 1])}\n",
    "opt_state = opt.update(opt_state, fake_grads, loss=1.0)\n",
    "jax.tree_util.tree_map(lambda x: x.shape, opt_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48c60fd7",
   "metadata": {
    "id": "RLAJP-MdU_5I"
   },
   "source": [
    "## More LearnedOptimizer architectures\n",
    "\n",
    "Many more learned optimizer architectures are implemented inside the [learned_optimization/learned_optimizers](https://github.com/google/learned_optimization/tree/main/learned_optimization/learned_optimizers) folder. These include:\n",
    "\n",
    "* `nn_adam`: which implements a more sophisticated hyper parameter controller which controls Adam hparams.\n",
    "\n",
    "* `mlp_lopt` and `adafac_mlp_lopt`: which implement more sophisticated per-parameter learned optimizers.\n",
    "\n",
    "* `rnn_mlp_opt`: Implements a hierarchical learned optimizer. A per tensor RNN is used to compute hidden state which is passed to a per-parameter MLP which does the actual weight updates."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,md:myst,py",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
