{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "20970d65",
   "metadata": {
    "id": "20970d65"
   },
   "source": [
    "# Part 2: Custom Tasks, Task Families, and Performance Improvements\n",
    "\n",
    "In this part, we will look at how to define custom tasks and datasets. We will also consider _families_ of tasks, which are common specifications of meta-learning problems. Finally, we will look at how to efficiently parallelize over tasks during training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef075664",
   "metadata": {
    "id": "ef075664"
   },
   "source": [
    "## Prerequisites\n",
    "\n",
    "This document assumes knowledge of JAX which is covered in depth at the [JAX Docs](https://jax.readthedocs.io/en/latest/index.html).\n",
    "In particular, we would recomend making your way through [JAX tutorial 101](https://jax.readthedocs.io/en/latest/jax-101/index.html). We also recommend that you have worked your way through Part 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f560fa24",
   "metadata": {
    "id": "f560fa24"
   },
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/google/learned_optimization.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "04db154b",
   "metadata": {
    "executionInfo": {
     "elapsed": 24640,
     "status": "ok",
     "timestamp": 1643173374165,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "04db154b"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import jax\n",
    "from matplotlib import pylab as plt\n",
    "\n",
    "from learned_optimization.outer_trainers import full_es\n",
    "from learned_optimization.outer_trainers import truncated_pes\n",
    "from learned_optimization.outer_trainers import gradient_learner\n",
    "from learned_optimization.outer_trainers import truncation_schedule\n",
    "\n",
    "from learned_optimization.tasks import quadratics\n",
    "from learned_optimization.tasks.fixed import image_mlp\n",
    "from learned_optimization.tasks import base as tasks_base\n",
    "from learned_optimization.tasks.datasets import base as datasets_base\n",
    "\n",
    "from learned_optimization.learned_optimizers import base as lopt_base\n",
    "from learned_optimization.learned_optimizers import mlp_lopt\n",
    "from learned_optimization.optimizers import base as opt_base\n",
    "\n",
    "from learned_optimization import optimizers\n",
    "from learned_optimization import eval_training\n",
    "\n",
    "import haiku as hk\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "707298d0",
   "metadata": {
    "id": "707298d0"
   },
   "source": [
    "## Defining a custom Dataset\n",
    "\n",
    "The dataset's in this library consists of iterators which yield batches of the corresponding data. For the provided tasks, these dataset have 4 splits of data rather than the traditional 3. We have \"train\" which is data used by the task to train a model, \"inner_valid\" which contains validation data for use when inner training (training an instance of a task). This could be use for, say, picking hparams. \"outer_valid\" which is used to meta-train with -- this is unseen in inner training and thus serves as a basis to train learned optimizers against. \"test\" which can be used to test the learned optimizer with.\n",
    "\n",
    "To make a dataset, simply write 4 iterators with these splits.\n",
    "\n",
    "For performance reasons, creating these iterators cannot be slow.\n",
    "The existing dataset's make extensive use of caching to share iterators across tasks which use the same data iterators.\n",
    "To account for this reuse, it is expected that these iterators are always randomly sampling data and have a large shuffle buffer so as to not run into any sampling issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "df73c83b",
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1643173374354,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "df73c83b",
    "outputId": "435d2986-d008-412e-bd71-bb7d9c404f3d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'data': array([[0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0.]])}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def data_iterator():\n",
    "  bs = 3\n",
    "  while True:\n",
    "    batch = {\"data\": np.zeros([bs, 5])}\n",
    "    yield batch\n",
    "\n",
    "\n",
    "@datasets_base.dataset_lru_cache\n",
    "def get_datasets():\n",
    "  return datasets_base.Datasets(\n",
    "      train=data_iterator(),\n",
    "      inner_valid=data_iterator(),\n",
    "      outer_valid=data_iterator(),\n",
    "      test=data_iterator())\n",
    "\n",
    "\n",
    "ds = get_datasets()\n",
    "next(ds.train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "410f2024",
   "metadata": {
    "id": "410f2024"
   },
   "source": [
    "## Defining a custom `Task`\n",
    "\n",
    "To define a custom class, one simply needs to write a base class of `Task`. Let's look at a simple task consisting of a quadratic task with noisy targets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "27dbabeb",
   "metadata": {
    "executionInfo": {
     "elapsed": 799,
     "status": "ok",
     "timestamp": 1643173375359,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "27dbabeb",
    "outputId": "394bb22e-4481-4ee6-8d3b-490d1b77f35c"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(10.503748, dtype=float32)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First we construct data iterators.\n",
    "def noise_datasets():\n",
    "\n",
    "  def _fn():\n",
    "    while True:\n",
    "      yield np.random.normal(size=[4, 2]).astype(dtype=np.float32)\n",
    "\n",
    "  return datasets_base.Datasets(\n",
    "      train=_fn(), inner_valid=_fn(), outer_valid=_fn(), test=_fn())\n",
    "\n",
    "\n",
    "class MyTask(tasks_base.Task):\n",
    "  datasets = noise_datasets()\n",
    "\n",
    "  def loss(self, params, rng, data):\n",
    "    return jnp.sum(jnp.square(params - data))\n",
    "\n",
    "  def init(self, key):\n",
    "    return jax.random.normal(key, shape=(4, 2))\n",
    "\n",
    "\n",
    "task = MyTask()\n",
    "key = jax.random.PRNGKey(0)\n",
    "key1, key = jax.random.split(key)\n",
    "params = task.init(key)\n",
    "\n",
    "task.loss(params, key1, next(task.datasets.train))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a16e5e3b",
   "metadata": {
    "id": "a16e5e3b"
   },
   "source": [
    "## Meta-training on multiple tasks: `TaskFamily`\n",
    "\n",
    "What we have shown previously was meta-training on a single task instance.\n",
    "While sometimes this is sufficient for a given situation, in many situations we seek to meta-train a meta-learning algorithm such as a learned optimizer on a mixture of different tasks.\n",
    "\n",
    "One path to do this is to simply run more than one meta-gradient computation, each with different tasks, average the gradients, and perform one meta-update.\n",
    "This works great when the tasks are quite different -- e.g. meta-gradients when training a convnet vs a MLP.\n",
    "A big negative to this is that these meta-gradient calculations are happening sequentially, and thus making poor use of hardware accelerators like GPU or TPU.\n",
    "\n",
    "As a solution to this problem, we have an abstraction of a `TaskFamily` to enable better use of hardware. A `TaskFamily` represents a distribution over a set of tasks and specifies particular samples from this distribution as a pytree of jax types.\n",
    "\n",
    "The function to sample these configurations is called `sample`, and the function to get a task from the sampled config is `task_fn`. `TaskFamily` also optionally contain datasets which are shared for all the `Task` it creates.\n",
    "\n",
    "As a simple example, let's consider a family of quadratics parameterized by meansquared error to some point which itself is sampled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f1c7d7f8",
   "metadata": {
    "executionInfo": {
     "elapsed": 64,
     "status": "ok",
     "timestamp": 1643173375565,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "f1c7d7f8"
   },
   "outputs": [],
   "source": [
    "PRNGKey = jnp.ndarray\n",
    "TaskParams = jnp.ndarray\n",
    "\n",
    "\n",
    "class FixedDimQuadraticFamily(tasks_base.TaskFamily):\n",
    "  \"\"\"A simple TaskFamily with a fixed dimensionality but sampled target.\"\"\"\n",
    "\n",
    "  def __init__(self, dim: int):\n",
    "    super().__init__()\n",
    "    self._dim = dim\n",
    "    self.datasets = None\n",
    "\n",
    "  def sample(self, key: PRNGKey) -> TaskParams:\n",
    "    # Sample the target for the quadratic task.\n",
    "    return jax.random.normal(key, shape=(self._dim,))\n",
    "\n",
    "  def task_fn(self, task_params: TaskParams) -> tasks_base.Task:\n",
    "    dim = self._dim\n",
    "\n",
    "    class _Task(tasks_base.Task):\n",
    "\n",
    "      def loss(self, params, rng, _):\n",
    "        # Compute MSE to the target task.\n",
    "        return jnp.sum(jnp.square(task_params - params))\n",
    "\n",
    "      def init(self, key):\n",
    "        return jax.random.normal(key, shape=(dim,))\n",
    "\n",
    "    return _Task()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37652293",
   "metadata": {
    "id": "37652293"
   },
   "source": [
    "*With* this task family defined, we can create instances by sampling a configuration and creating a task. This task acts like any other task in that it has an `init` and a `loss` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fba3b113",
   "metadata": {
    "executionInfo": {
     "elapsed": 334,
     "status": "ok",
     "timestamp": 1643173376069,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "fba3b113",
    "outputId": "1f62a1e6-8c99-4991-b2d7-380c5adee83a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(13.190405, dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task_family = FixedDimQuadraticFamily(10)\n",
    "key = jax.random.PRNGKey(0)\n",
    "task_cfg = task_family.sample(key)\n",
    "task = task_family.task_fn(task_cfg)\n",
    "\n",
    "key1, key = jax.random.split(key)\n",
    "params = task.init(key)\n",
    "batch = None\n",
    "task.loss(params, key, batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b25914f",
   "metadata": {
    "id": "8b25914f"
   },
   "source": [
    "To achive speedups, we can now leverage `jax.vmap` to train *multiple* task instances in parallel! Depending on the task, this can be considerably faster than serially executing them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7dded1ea",
   "metadata": {
    "executionInfo": {
     "elapsed": 1508,
     "status": "ok",
     "timestamp": 1643173377718,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "7dded1ea",
    "outputId": "d75a21c5-0210-4482-f088-1b5a0ce92c17"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "single loss 10.973224\n",
      "multiple losses [28.484756  15.884144  10.12129   17.281586  18.210754  17.650654\n",
      " 31.202633  20.745605  21.301374  36.30536   22.189842  21.358437\n",
      " 13.802605  16.462059  13.092703  25.175426  23.442476  13.078012\n",
      " 20.773136  15.165912  23.114235  24.486801  31.850758  11.04059\n",
      "  5.795575  26.002295  31.550493   2.9317625 10.598424  18.45548\n",
      " 24.402779  20.770353 ]\n"
     ]
    }
   ],
   "source": [
    "def train_task(cfg, key):\n",
    "  task = task_family.task_fn(cfg)\n",
    "  key1, key = jax.random.split(key)\n",
    "  params = task.init(key1)\n",
    "  opt = opt_base.Adam()\n",
    "\n",
    "  opt_state = opt.init(params)\n",
    "\n",
    "  for i in range(4):\n",
    "    params = opt.get_params(opt_state)\n",
    "    loss, grad = jax.value_and_grad(task.loss)(params, key, None)\n",
    "    opt_state = opt.update(opt_state, grad, loss=loss)\n",
    "  loss = task.loss(params, key, None)\n",
    "  return loss\n",
    "\n",
    "\n",
    "task_cfg = task_family.sample(key)\n",
    "print(\"single loss\", train_task(task_cfg, key))\n",
    "\n",
    "keys = jax.random.split(key, 32)\n",
    "task_cfgs = jax.vmap(task_family.sample)(keys)\n",
    "losses = jax.vmap(train_task)(task_cfgs, keys)\n",
    "print(\"multiple losses\", losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79f74adc",
   "metadata": {
    "id": "79f74adc"
   },
   "source": [
    "Because of this ability to apply vmap over task families, this is the main building block for a number of the high level libraries in this package. Single tasks can always be converted to a task family with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6cd2f682",
   "metadata": {
    "executionInfo": {
     "elapsed": 3041,
     "status": "ok",
     "timestamp": 1643173380925,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "6cd2f682"
   },
   "outputs": [],
   "source": [
    "single_task = image_mlp.ImageMLP_FashionMnist8_Relu32()\n",
    "task_family = tasks_base.single_task_to_family(single_task)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "905293c1",
   "metadata": {
    "id": "905293c1"
   },
   "source": [
    "This wrapper task family has no configuable value and always returns the base task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cb049afb",
   "metadata": {
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1643173381121,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 480
    },
    "id": "cb049afb",
    "outputId": "47250a96-577d-4d74-de88-b21d17f27fa3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "config only contains a dummy value: 0\n"
     ]
    }
   ],
   "source": [
    "cfg = task_family.sample(key)\n",
    "print(\"config only contains a dummy value:\", cfg)\n",
    "task = task_family.task_fn(cfg)\n",
    "# Tasks are the same\n",
    "assert task == single_task"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "760f8e76",
   "metadata": {
    "id": "760f8e76"
   },
   "source": [
    "## Limitations of `TaskFamily`\n",
    "Task families are designed for, and only work for variation that results in a static computation graph. This is required for `jax.vmap` to work.\n",
    "\n",
    "This means things like naively changing hidden sizes, or number of layers, activation functions is off the table.\n",
    "\n",
    "In some cases, one can leverage other jax control flow such as `jax.lax.cond` to select between implementations. For example, one could make a `TaskFamily` that used one of 2 activation functions. While this works, the resulting vectorized computation could be slow and thus profiling is required to determine if this is a good idea or not.\n",
    "\n",
    "In this code base, we use `TaskFamily` to mainly parameterize over different kinds of initializations."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "last_runtime": {
    "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
    "kind": "private"
   },
   "name": "Part2_CustomTasks.ipynb",
   "provenance": []
  },
  "jupytext": {
   "formats": "ipynb,md:myst,py",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
